diff --git a/dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/worker/task/sql/SqlTask.java b/dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/worker/task/sql/SqlTask.java index f28f5804b0..523d2e0848 100644 --- a/dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/worker/task/sql/SqlTask.java +++ b/dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/worker/task/sql/SqlTask.java @@ -18,7 +18,9 @@ package org.apache.dolphinscheduler.server.worker.task.sql; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; + import org.apache.commons.lang.StringUtils; + import org.apache.dolphinscheduler.alert.utils.MailUtils; import org.apache.dolphinscheduler.common.Constants; import org.apache.dolphinscheduler.common.enums.*; @@ -41,6 +43,7 @@ import org.apache.dolphinscheduler.server.utils.ParamUtils; import org.apache.dolphinscheduler.server.utils.UDFUtils; import org.apache.dolphinscheduler.server.worker.task.AbstractTask; import org.apache.dolphinscheduler.service.bean.SpringApplicationContext; + import org.slf4j.Logger; import java.sql.*; @@ -51,17 +54,18 @@ import java.util.stream.Collectors; import static org.apache.dolphinscheduler.common.Constants.*; import static org.apache.dolphinscheduler.common.enums.DbType.HIVE; + /** * sql task */ public class SqlTask extends AbstractTask { /** - * sql parameters + * sql parameters */ private SqlParameters sqlParameters; /** - * alert dao + * alert dao */ private AlertDao alertDao; /** @@ -148,10 +152,11 @@ public class SqlTask extends AbstractTask { /** * ready to execute SQL and parameter entity Map + * * @return SqlBinds */ private SqlBinds getSqlAndSqlParamsMap(String sql) { - Map sqlParamsMap = new HashMap<>(); + Map sqlParamsMap = new HashMap<>(); StringBuilder sqlBuilder = new StringBuilder(); // find process instance by task id @@ -164,25 +169,27 @@ public class SqlTask extends AbstractTask { taskExecutionContext.getScheduleTime()); // spell SQL according to the final user-defined variable - if(paramsMap == null){ + if (paramsMap == null) { sqlBuilder.append(sql); return new SqlBinds(sqlBuilder.toString(), sqlParamsMap); } - if (StringUtils.isNotEmpty(sqlParameters.getTitle())){ + if (StringUtils.isNotEmpty(sqlParameters.getTitle())) { String title = ParameterUtils.convertParameterPlaceholders(sqlParameters.getTitle(), ParamUtils.convert(paramsMap)); - logger.info("SQL title : {}",title); + logger.info("SQL title : {}", title); sqlParameters.setTitle(title); } - + //new //replace variable TIME with $[YYYYmmddd...] in sql when history run job and batch complement job sql = ParameterUtils.replaceScheduleTime(sql, taskExecutionContext.getScheduleTime()); // special characters need to be escaped, ${} needs to be escaped String rgex = "['\"]*\\$\\{(.*?)\\}['\"]*"; setSqlParamsMap(sql, rgex, sqlParamsMap, paramsMap); - + //Replace the original value in sql !{...} ,Does not participate in precompilation + String rgexo = "['\"]*\\!\\{(.*?)\\}['\"]*"; + sql = replaceOriginalValue(sql, rgexo, paramsMap); // replace the ${} of the SQL statement with the Placeholder String formatSql = sql.replaceAll(rgex, "?"); sqlBuilder.append(formatSql); @@ -192,6 +199,20 @@ public class SqlTask extends AbstractTask { return new SqlBinds(sqlBuilder.toString(), sqlParamsMap); } + public String replaceOriginalValue(String content, String rgex, Map sqlParamsMap) { + Pattern pattern = Pattern.compile(rgex); + while (true) { + Matcher m = pattern.matcher(content); + if (!m.find()) { + break; + } + String paramName = m.group(1); + String paramValue = sqlParamsMap.get(paramName).getValue(); + content = m.replaceFirst(paramValue); + } + return content; + } + @Override public AbstractParameters getParameters() { return this.sqlParameters; @@ -199,15 +220,16 @@ public class SqlTask extends AbstractTask { /** * execute function and sql - * @param mainSqlBinds main sql binds - * @param preStatementsBinds pre statements binds - * @param postStatementsBinds post statements binds - * @param createFuncs create functions + * + * @param mainSqlBinds main sql binds + * @param preStatementsBinds pre statements binds + * @param postStatementsBinds post statements binds + * @param createFuncs create functions */ public void executeFuncAndSql(SqlBinds mainSqlBinds, - List preStatementsBinds, - List postStatementsBinds, - List createFuncs){ + List preStatementsBinds, + List postStatementsBinds, + List createFuncs) { Connection connection = null; PreparedStatement stmt = null; ResultSet resultSet = null; @@ -218,11 +240,11 @@ public class SqlTask extends AbstractTask { connection = createConnection(); // create temp function if (CollectionUtils.isNotEmpty(createFuncs)) { - createTempFunction(connection,createFuncs); + createTempFunction(connection, createFuncs); } // pre sql - preSql(connection,preStatementsBinds); + preSql(connection, preStatementsBinds); stmt = prepareStatementAndBind(connection, mainSqlBinds); // decide whether to executeQuery or executeUpdate based on sqlType @@ -236,13 +258,13 @@ public class SqlTask extends AbstractTask { stmt.executeUpdate(); } - postSql(connection,postStatementsBinds); + postSql(connection, postStatementsBinds); } catch (Exception e) { - logger.error("execute sql error",e); + logger.error("execute sql error", e); throw new RuntimeException("execute sql error"); } finally { - close(resultSet,stmt,connection); + close(resultSet, stmt, connection); } } @@ -252,7 +274,7 @@ public class SqlTask extends AbstractTask { * @param resultSet resultSet * @throws Exception Exception */ - private void resultProcess(ResultSet resultSet) throws Exception{ + private void resultProcess(ResultSet resultSet) throws Exception { ArrayNode resultJSONArray = JSONUtils.createArrayNode(); ResultSetMetaData md = resultSet.getMetaData(); int num = md.getColumnCount(); @@ -271,22 +293,22 @@ public class SqlTask extends AbstractTask { logger.debug("execute sql : {}", result); sendAttachment(StringUtils.isNotEmpty(sqlParameters.getTitle()) ? - sqlParameters.getTitle(): taskExecutionContext.getTaskName() + " query result sets", + sqlParameters.getTitle() : taskExecutionContext.getTaskName() + " query result sets", JSONUtils.toJsonString(resultJSONArray)); } /** - * pre sql + * pre sql * * @param connection connection * @param preStatementsBinds preStatementsBinds */ private void preSql(Connection connection, - List preStatementsBinds) throws Exception{ - for (SqlBinds sqlBind: preStatementsBinds) { - try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)){ + List preStatementsBinds) throws Exception { + for (SqlBinds sqlBind : preStatementsBinds) { + try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)) { int result = pstmt.executeUpdate(); - logger.info("pre statement execute result: {}, for sql: {}",result,sqlBind.getSql()); + logger.info("pre statement execute result: {}, for sql: {}", result, sqlBind.getSql()); } } @@ -297,26 +319,25 @@ public class SqlTask extends AbstractTask { * * @param connection connection * @param postStatementsBinds postStatementsBinds - * @throws Exception */ private void postSql(Connection connection, - List postStatementsBinds) throws Exception{ - for (SqlBinds sqlBind: postStatementsBinds) { - try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)){ + List postStatementsBinds) throws Exception { + for (SqlBinds sqlBind : postStatementsBinds) { + try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)) { int result = pstmt.executeUpdate(); - logger.info("post statement execute result: {},for sql: {}",result,sqlBind.getSql()); + logger.info("post statement execute result: {},for sql: {}", result, sqlBind.getSql()); } } } + /** * create temp function * * @param connection connection * @param createFuncs createFuncs - * @throws Exception */ private void createTempFunction(Connection connection, - List createFuncs) throws Exception{ + List createFuncs) throws Exception { try (Statement funcStmt = connection.createStatement()) { for (String createFunc : createFuncs) { logger.info("hive create function sql: {}", createFunc); @@ -324,14 +345,14 @@ public class SqlTask extends AbstractTask { } } } - + /** * create connection * * @return connection * @throws Exception Exception */ - private Connection createConnection() throws Exception{ + private Connection createConnection() throws Exception { // if hive , load connection params if exists Connection connection = null; if (HIVE == DbType.valueOf(sqlParameters.getType())) { @@ -345,7 +366,7 @@ public class SqlTask extends AbstractTask { connection = DriverManager.getConnection(baseDataSource.getJdbcUrl(), paramProp); - }else{ + } else { connection = DriverManager.getConnection(baseDataSource.getJdbcUrl(), baseDataSource.getUser(), baseDataSource.getPassword()); @@ -354,7 +375,7 @@ public class SqlTask extends AbstractTask { } /** - * close jdbc resource + * close jdbc resource * * @param resultSet resultSet * @param pstmt pstmt @@ -362,36 +383,37 @@ public class SqlTask extends AbstractTask { */ private void close(ResultSet resultSet, PreparedStatement pstmt, - Connection connection){ - if (resultSet != null){ + Connection connection) { + if (resultSet != null) { try { resultSet.close(); } catch (SQLException e) { - logger.error("close result set error : {}",e.getMessage(),e); + logger.error("close result set error : {}", e.getMessage(), e); } } - if (pstmt != null){ + if (pstmt != null) { try { pstmt.close(); } catch (SQLException e) { - logger.error("close prepared statement error : {}",e.getMessage(),e); + logger.error("close prepared statement error : {}", e.getMessage(), e); } } - if (connection != null){ + if (connection != null) { try { connection.close(); } catch (SQLException e) { - logger.error("close connection error : {}",e.getMessage(),e); + logger.error("close connection error : {}", e.getMessage(), e); } } } /** * preparedStatement bind + * * @param connection connection - * @param sqlBinds sqlBinds + * @param sqlBinds sqlBinds * @return PreparedStatement * @throws Exception Exception */ @@ -400,11 +422,11 @@ public class SqlTask extends AbstractTask { boolean timeoutFlag = TaskTimeoutStrategy.of(taskExecutionContext.getTaskTimeoutStrategy()) == TaskTimeoutStrategy.FAILED || TaskTimeoutStrategy.of(taskExecutionContext.getTaskTimeoutStrategy()) == TaskTimeoutStrategy.WARNFAILED; PreparedStatement stmt = connection.prepareStatement(sqlBinds.getSql()); - if(timeoutFlag){ + if (timeoutFlag) { stmt.setQueryTimeout(taskExecutionContext.getTaskTimeout()); } Map params = sqlBinds.getParamsMap(); - if(params != null) { + if (params != null) { for (Map.Entry entry : params.entrySet()) { Property prop = entry.getValue(); ParameterUtils.setInParameter(entry.getKey(), stmt, prop.getType(), prop.getValue()); @@ -416,23 +438,24 @@ public class SqlTask extends AbstractTask { /** * send mail as an attachment - * @param title title - * @param content content + * + * @param title title + * @param content content */ - public void sendAttachment(String title,String content){ + public void sendAttachment(String title, String content) { List users = alertDao.queryUserByAlertGroupId(taskExecutionContext.getSqlTaskExecutionContext().getWarningGroupId()); // receiving group list List receiversList = new ArrayList<>(); - for(User user:users){ + for (User user : users) { receiversList.add(user.getEmail().trim()); } // custom receiver String receivers = sqlParameters.getReceivers(); - if (StringUtils.isNotEmpty(receivers)){ + if (StringUtils.isNotEmpty(receivers)) { String[] splits = receivers.split(COMMA); - for (String receiver : splits){ + for (String receiver : splits) { receiversList.add(receiver.trim()); } } @@ -441,60 +464,62 @@ public class SqlTask extends AbstractTask { List receiversCcList = new ArrayList<>(); // Custom Copier String receiversCc = sqlParameters.getReceiversCc(); - if (StringUtils.isNotEmpty(receiversCc)){ + if (StringUtils.isNotEmpty(receiversCc)) { String[] splits = receiversCc.split(COMMA); - for (String receiverCc : splits){ + for (String receiverCc : splits) { receiversCcList.add(receiverCc.trim()); } } - String showTypeName = sqlParameters.getShowType().replace(COMMA,"").trim(); - if(EnumUtils.isValidEnum(ShowType.class,showTypeName)){ + String showTypeName = sqlParameters.getShowType().replace(COMMA, "").trim(); + if (EnumUtils.isValidEnum(ShowType.class, showTypeName)) { Map mailResult = MailUtils.sendMails(receiversList, receiversCcList, title, content, ShowType.valueOf(showTypeName).getDescp()); - if(!(boolean) mailResult.get(STATUS)){ + if (!(boolean) mailResult.get(STATUS)) { throw new RuntimeException("send mail failed!"); } - }else{ - logger.error("showType: {} is not valid " ,showTypeName); - throw new RuntimeException(String.format("showType: %s is not valid ",showTypeName)); + } else { + logger.error("showType: {} is not valid ", showTypeName); + throw new RuntimeException(String.format("showType: %s is not valid ", showTypeName)); } } /** * regular expressions match the contents between two specified strings - * @param content content - * @param rgex rgex - * @param sqlParamsMap sql params map - * @param paramsPropsMap params props map + * + * @param content content + * @param rgex rgex + * @param sqlParamsMap sql params map + * @param paramsPropsMap params props map */ - public void setSqlParamsMap(String content, String rgex, Map sqlParamsMap, Map paramsPropsMap){ + public void setSqlParamsMap(String content, String rgex, Map sqlParamsMap, Map paramsPropsMap) { Pattern pattern = Pattern.compile(rgex); Matcher m = pattern.matcher(content); int index = 1; while (m.find()) { String paramName = m.group(1); - Property prop = paramsPropsMap.get(paramName); + Property prop = paramsPropsMap.get(paramName); - sqlParamsMap.put(index,prop); - index ++; + sqlParamsMap.put(index, prop); + index++; } } /** * print replace sql - * @param content content - * @param formatSql format sql - * @param rgex rgex - * @param sqlParamsMap sql params map + * + * @param content content + * @param formatSql format sql + * @param rgex rgex + * @param sqlParamsMap sql params map */ - public void printReplacedSql(String content, String formatSql,String rgex, Map sqlParamsMap){ + public void printReplacedSql(String content, String formatSql, String rgex, Map sqlParamsMap) { //parameter print style - logger.info("after replace sql , preparing : {}" , formatSql); + logger.info("after replace sql , preparing : {}", formatSql); StringBuilder logPrint = new StringBuilder("replaced sql , parameters:"); - for(int i=1;i<=sqlParamsMap.size();i++){ - logPrint.append(sqlParamsMap.get(i).getValue()+"("+sqlParamsMap.get(i).getType()+")"); + for (int i = 1; i <= sqlParamsMap.size(); i++) { + logPrint.append(sqlParamsMap.get(i).getValue() + "(" + sqlParamsMap.get(i).getType() + ")"); } logger.info("Sql Params are {}", logPrint); }