Browse Source

[feat-4496][server] Add to! {} is used to mark the custom parameters to be output as-is in sql (#4497)

* feat([server]): Add to! {} is used to mark the custom parameters to be output as-is in sql

Before pre-compiling sql, replace the custom parameters marked with !{}
to prevent the parameters in the hive plus partition path from being
replaced with single quotes

Closes This closes #4496

* feat([server]): Add to! {} is used to mark the custom parameters to be output as-is in sql

Before pre-compiling sql, replace the custom parameters marked with !{}
to prevent the parameters in the hive plus partition path from being
replaced with single quotes

Closes This closes #4496

* feat([server]): Add to! {} is used to mark the custom parameters to be output as-is in sql

Before pre-compiling sql, replace the custom parameters marked with !{}
to prevent the parameters in the hive plus partition path from being
replaced with single quotes

Closes This closes #4496
pull/3/MERGE
liuxuedongcn 3 years ago committed by GitHub
parent
commit
43586da376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 181
      dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/worker/task/sql/SqlTask.java

181
dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/worker/task/sql/SqlTask.java

@ -18,7 +18,9 @@ package org.apache.dolphinscheduler.server.worker.task.sql;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.commons.lang.StringUtils;
import org.apache.dolphinscheduler.alert.utils.MailUtils;
import org.apache.dolphinscheduler.common.Constants;
import org.apache.dolphinscheduler.common.enums.*;
@ -41,6 +43,7 @@ import org.apache.dolphinscheduler.server.utils.ParamUtils;
import org.apache.dolphinscheduler.server.utils.UDFUtils;
import org.apache.dolphinscheduler.server.worker.task.AbstractTask;
import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;
import org.slf4j.Logger;
import java.sql.*;
@ -51,17 +54,18 @@ import java.util.stream.Collectors;
import static org.apache.dolphinscheduler.common.Constants.*;
import static org.apache.dolphinscheduler.common.enums.DbType.HIVE;
/**
* sql task
*/
public class SqlTask extends AbstractTask {
/**
* sql parameters
* sql parameters
*/
private SqlParameters sqlParameters;
/**
* alert dao
* alert dao
*/
private AlertDao alertDao;
/**
@ -148,10 +152,11 @@ public class SqlTask extends AbstractTask {
/**
* ready to execute SQL and parameter entity Map
*
* @return SqlBinds
*/
private SqlBinds getSqlAndSqlParamsMap(String sql) {
Map<Integer,Property> sqlParamsMap = new HashMap<>();
Map<Integer, Property> sqlParamsMap = new HashMap<>();
StringBuilder sqlBuilder = new StringBuilder();
// find process instance by task id
@ -164,25 +169,27 @@ public class SqlTask extends AbstractTask {
taskExecutionContext.getScheduleTime());
// spell SQL according to the final user-defined variable
if(paramsMap == null){
if (paramsMap == null) {
sqlBuilder.append(sql);
return new SqlBinds(sqlBuilder.toString(), sqlParamsMap);
}
if (StringUtils.isNotEmpty(sqlParameters.getTitle())){
if (StringUtils.isNotEmpty(sqlParameters.getTitle())) {
String title = ParameterUtils.convertParameterPlaceholders(sqlParameters.getTitle(),
ParamUtils.convert(paramsMap));
logger.info("SQL title : {}",title);
logger.info("SQL title : {}", title);
sqlParameters.setTitle(title);
}
//new
//replace variable TIME with $[YYYYmmddd...] in sql when history run job and batch complement job
sql = ParameterUtils.replaceScheduleTime(sql, taskExecutionContext.getScheduleTime());
// special characters need to be escaped, ${} needs to be escaped
String rgex = "['\"]*\\$\\{(.*?)\\}['\"]*";
setSqlParamsMap(sql, rgex, sqlParamsMap, paramsMap);
//Replace the original value in sql !{...} ,Does not participate in precompilation
String rgexo = "['\"]*\\!\\{(.*?)\\}['\"]*";
sql = replaceOriginalValue(sql, rgexo, paramsMap);
// replace the ${} of the SQL statement with the Placeholder
String formatSql = sql.replaceAll(rgex, "?");
sqlBuilder.append(formatSql);
@ -192,6 +199,20 @@ public class SqlTask extends AbstractTask {
return new SqlBinds(sqlBuilder.toString(), sqlParamsMap);
}
public String replaceOriginalValue(String content, String rgex, Map<String, Property> sqlParamsMap) {
Pattern pattern = Pattern.compile(rgex);
while (true) {
Matcher m = pattern.matcher(content);
if (!m.find()) {
break;
}
String paramName = m.group(1);
String paramValue = sqlParamsMap.get(paramName).getValue();
content = m.replaceFirst(paramValue);
}
return content;
}
@Override
public AbstractParameters getParameters() {
return this.sqlParameters;
@ -199,15 +220,16 @@ public class SqlTask extends AbstractTask {
/**
* execute function and sql
* @param mainSqlBinds main sql binds
* @param preStatementsBinds pre statements binds
* @param postStatementsBinds post statements binds
* @param createFuncs create functions
*
* @param mainSqlBinds main sql binds
* @param preStatementsBinds pre statements binds
* @param postStatementsBinds post statements binds
* @param createFuncs create functions
*/
public void executeFuncAndSql(SqlBinds mainSqlBinds,
List<SqlBinds> preStatementsBinds,
List<SqlBinds> postStatementsBinds,
List<String> createFuncs){
List<SqlBinds> preStatementsBinds,
List<SqlBinds> postStatementsBinds,
List<String> createFuncs) {
Connection connection = null;
PreparedStatement stmt = null;
ResultSet resultSet = null;
@ -218,11 +240,11 @@ public class SqlTask extends AbstractTask {
connection = createConnection();
// create temp function
if (CollectionUtils.isNotEmpty(createFuncs)) {
createTempFunction(connection,createFuncs);
createTempFunction(connection, createFuncs);
}
// pre sql
preSql(connection,preStatementsBinds);
preSql(connection, preStatementsBinds);
stmt = prepareStatementAndBind(connection, mainSqlBinds);
// decide whether to executeQuery or executeUpdate based on sqlType
@ -236,13 +258,13 @@ public class SqlTask extends AbstractTask {
stmt.executeUpdate();
}
postSql(connection,postStatementsBinds);
postSql(connection, postStatementsBinds);
} catch (Exception e) {
logger.error("execute sql error",e);
logger.error("execute sql error", e);
throw new RuntimeException("execute sql error");
} finally {
close(resultSet,stmt,connection);
close(resultSet, stmt, connection);
}
}
@ -252,7 +274,7 @@ public class SqlTask extends AbstractTask {
* @param resultSet resultSet
* @throws Exception Exception
*/
private void resultProcess(ResultSet resultSet) throws Exception{
private void resultProcess(ResultSet resultSet) throws Exception {
ArrayNode resultJSONArray = JSONUtils.createArrayNode();
ResultSetMetaData md = resultSet.getMetaData();
int num = md.getColumnCount();
@ -271,22 +293,22 @@ public class SqlTask extends AbstractTask {
logger.debug("execute sql : {}", result);
sendAttachment(StringUtils.isNotEmpty(sqlParameters.getTitle()) ?
sqlParameters.getTitle(): taskExecutionContext.getTaskName() + " query result sets",
sqlParameters.getTitle() : taskExecutionContext.getTaskName() + " query result sets",
JSONUtils.toJsonString(resultJSONArray));
}
/**
* pre sql
* pre sql
*
* @param connection connection
* @param preStatementsBinds preStatementsBinds
*/
private void preSql(Connection connection,
List<SqlBinds> preStatementsBinds) throws Exception{
for (SqlBinds sqlBind: preStatementsBinds) {
try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)){
List<SqlBinds> preStatementsBinds) throws Exception {
for (SqlBinds sqlBind : preStatementsBinds) {
try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)) {
int result = pstmt.executeUpdate();
logger.info("pre statement execute result: {}, for sql: {}",result,sqlBind.getSql());
logger.info("pre statement execute result: {}, for sql: {}", result, sqlBind.getSql());
}
}
@ -297,26 +319,25 @@ public class SqlTask extends AbstractTask {
*
* @param connection connection
* @param postStatementsBinds postStatementsBinds
* @throws Exception
*/
private void postSql(Connection connection,
List<SqlBinds> postStatementsBinds) throws Exception{
for (SqlBinds sqlBind: postStatementsBinds) {
try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)){
List<SqlBinds> postStatementsBinds) throws Exception {
for (SqlBinds sqlBind : postStatementsBinds) {
try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)) {
int result = pstmt.executeUpdate();
logger.info("post statement execute result: {},for sql: {}",result,sqlBind.getSql());
logger.info("post statement execute result: {},for sql: {}", result, sqlBind.getSql());
}
}
}
/**
* create temp function
*
* @param connection connection
* @param createFuncs createFuncs
* @throws Exception
*/
private void createTempFunction(Connection connection,
List<String> createFuncs) throws Exception{
List<String> createFuncs) throws Exception {
try (Statement funcStmt = connection.createStatement()) {
for (String createFunc : createFuncs) {
logger.info("hive create function sql: {}", createFunc);
@ -324,14 +345,14 @@ public class SqlTask extends AbstractTask {
}
}
}
/**
* create connection
*
* @return connection
* @throws Exception Exception
*/
private Connection createConnection() throws Exception{
private Connection createConnection() throws Exception {
// if hive , load connection params if exists
Connection connection = null;
if (HIVE == DbType.valueOf(sqlParameters.getType())) {
@ -345,7 +366,7 @@ public class SqlTask extends AbstractTask {
connection = DriverManager.getConnection(baseDataSource.getJdbcUrl(),
paramProp);
}else{
} else {
connection = DriverManager.getConnection(baseDataSource.getJdbcUrl(),
baseDataSource.getUser(),
baseDataSource.getPassword());
@ -354,7 +375,7 @@ public class SqlTask extends AbstractTask {
}
/**
* close jdbc resource
* close jdbc resource
*
* @param resultSet resultSet
* @param pstmt pstmt
@ -362,36 +383,37 @@ public class SqlTask extends AbstractTask {
*/
private void close(ResultSet resultSet,
PreparedStatement pstmt,
Connection connection){
if (resultSet != null){
Connection connection) {
if (resultSet != null) {
try {
resultSet.close();
} catch (SQLException e) {
logger.error("close result set error : {}",e.getMessage(),e);
logger.error("close result set error : {}", e.getMessage(), e);
}
}
if (pstmt != null){
if (pstmt != null) {
try {
pstmt.close();
} catch (SQLException e) {
logger.error("close prepared statement error : {}",e.getMessage(),e);
logger.error("close prepared statement error : {}", e.getMessage(), e);
}
}
if (connection != null){
if (connection != null) {
try {
connection.close();
} catch (SQLException e) {
logger.error("close connection error : {}",e.getMessage(),e);
logger.error("close connection error : {}", e.getMessage(), e);
}
}
}
/**
* preparedStatement bind
*
* @param connection connection
* @param sqlBinds sqlBinds
* @param sqlBinds sqlBinds
* @return PreparedStatement
* @throws Exception Exception
*/
@ -400,11 +422,11 @@ public class SqlTask extends AbstractTask {
boolean timeoutFlag = TaskTimeoutStrategy.of(taskExecutionContext.getTaskTimeoutStrategy()) == TaskTimeoutStrategy.FAILED ||
TaskTimeoutStrategy.of(taskExecutionContext.getTaskTimeoutStrategy()) == TaskTimeoutStrategy.WARNFAILED;
PreparedStatement stmt = connection.prepareStatement(sqlBinds.getSql());
if(timeoutFlag){
if (timeoutFlag) {
stmt.setQueryTimeout(taskExecutionContext.getTaskTimeout());
}
Map<Integer, Property> params = sqlBinds.getParamsMap();
if(params != null) {
if (params != null) {
for (Map.Entry<Integer, Property> entry : params.entrySet()) {
Property prop = entry.getValue();
ParameterUtils.setInParameter(entry.getKey(), stmt, prop.getType(), prop.getValue());
@ -416,23 +438,24 @@ public class SqlTask extends AbstractTask {
/**
* send mail as an attachment
* @param title title
* @param content content
*
* @param title title
* @param content content
*/
public void sendAttachment(String title,String content){
public void sendAttachment(String title, String content) {
List<User> users = alertDao.queryUserByAlertGroupId(taskExecutionContext.getSqlTaskExecutionContext().getWarningGroupId());
// receiving group list
List<String> receiversList = new ArrayList<>();
for(User user:users){
for (User user : users) {
receiversList.add(user.getEmail().trim());
}
// custom receiver
String receivers = sqlParameters.getReceivers();
if (StringUtils.isNotEmpty(receivers)){
if (StringUtils.isNotEmpty(receivers)) {
String[] splits = receivers.split(COMMA);
for (String receiver : splits){
for (String receiver : splits) {
receiversList.add(receiver.trim());
}
}
@ -441,60 +464,62 @@ public class SqlTask extends AbstractTask {
List<String> receiversCcList = new ArrayList<>();
// Custom Copier
String receiversCc = sqlParameters.getReceiversCc();
if (StringUtils.isNotEmpty(receiversCc)){
if (StringUtils.isNotEmpty(receiversCc)) {
String[] splits = receiversCc.split(COMMA);
for (String receiverCc : splits){
for (String receiverCc : splits) {
receiversCcList.add(receiverCc.trim());
}
}
String showTypeName = sqlParameters.getShowType().replace(COMMA,"").trim();
if(EnumUtils.isValidEnum(ShowType.class,showTypeName)){
String showTypeName = sqlParameters.getShowType().replace(COMMA, "").trim();
if (EnumUtils.isValidEnum(ShowType.class, showTypeName)) {
Map<String, Object> mailResult = MailUtils.sendMails(receiversList,
receiversCcList, title, content, ShowType.valueOf(showTypeName).getDescp());
if(!(boolean) mailResult.get(STATUS)){
if (!(boolean) mailResult.get(STATUS)) {
throw new RuntimeException("send mail failed!");
}
}else{
logger.error("showType: {} is not valid " ,showTypeName);
throw new RuntimeException(String.format("showType: %s is not valid ",showTypeName));
} else {
logger.error("showType: {} is not valid ", showTypeName);
throw new RuntimeException(String.format("showType: %s is not valid ", showTypeName));
}
}
/**
* regular expressions match the contents between two specified strings
* @param content content
* @param rgex rgex
* @param sqlParamsMap sql params map
* @param paramsPropsMap params props map
*
* @param content content
* @param rgex rgex
* @param sqlParamsMap sql params map
* @param paramsPropsMap params props map
*/
public void setSqlParamsMap(String content, String rgex, Map<Integer,Property> sqlParamsMap, Map<String,Property> paramsPropsMap){
public void setSqlParamsMap(String content, String rgex, Map<Integer, Property> sqlParamsMap, Map<String, Property> paramsPropsMap) {
Pattern pattern = Pattern.compile(rgex);
Matcher m = pattern.matcher(content);
int index = 1;
while (m.find()) {
String paramName = m.group(1);
Property prop = paramsPropsMap.get(paramName);
Property prop = paramsPropsMap.get(paramName);
sqlParamsMap.put(index,prop);
index ++;
sqlParamsMap.put(index, prop);
index++;
}
}
/**
* print replace sql
* @param content content
* @param formatSql format sql
* @param rgex rgex
* @param sqlParamsMap sql params map
*
* @param content content
* @param formatSql format sql
* @param rgex rgex
* @param sqlParamsMap sql params map
*/
public void printReplacedSql(String content, String formatSql,String rgex, Map<Integer,Property> sqlParamsMap){
public void printReplacedSql(String content, String formatSql, String rgex, Map<Integer, Property> sqlParamsMap) {
//parameter print style
logger.info("after replace sql , preparing : {}" , formatSql);
logger.info("after replace sql , preparing : {}", formatSql);
StringBuilder logPrint = new StringBuilder("replaced sql , parameters:");
for(int i=1;i<=sqlParamsMap.size();i++){
logPrint.append(sqlParamsMap.get(i).getValue()+"("+sqlParamsMap.get(i).getType()+")");
for (int i = 1; i <= sqlParamsMap.size(); i++) {
logPrint.append(sqlParamsMap.get(i).getValue() + "(" + sqlParamsMap.get(i).getType() + ")");
}
logger.info("Sql Params are {}", logPrint);
}

Loading…
Cancel
Save