Browse Source

[feat-4496][server] Add to! {} is used to mark the custom parameters to be output as-is in sql (#4497)

* feat([server]): Add to! {} is used to mark the custom parameters to be output as-is in sql

Before pre-compiling sql, replace the custom parameters marked with !{}
to prevent the parameters in the hive plus partition path from being
replaced with single quotes

Closes This closes #4496

* feat([server]): Add to! {} is used to mark the custom parameters to be output as-is in sql

Before pre-compiling sql, replace the custom parameters marked with !{}
to prevent the parameters in the hive plus partition path from being
replaced with single quotes

Closes This closes #4496

* feat([server]): Add to! {} is used to mark the custom parameters to be output as-is in sql

Before pre-compiling sql, replace the custom parameters marked with !{}
to prevent the parameters in the hive plus partition path from being
replaced with single quotes

Closes This closes #4496
pull/3/MERGE
liuxuedongcn 4 years ago committed by GitHub
parent
commit
43586da376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 177
      dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/worker/task/sql/SqlTask.java

177
dolphinscheduler-server/src/main/java/org/apache/dolphinscheduler/server/worker/task/sql/SqlTask.java

@ -18,7 +18,9 @@ package org.apache.dolphinscheduler.server.worker.task.sql;
import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.StringUtils;
import org.apache.dolphinscheduler.alert.utils.MailUtils; import org.apache.dolphinscheduler.alert.utils.MailUtils;
import org.apache.dolphinscheduler.common.Constants; import org.apache.dolphinscheduler.common.Constants;
import org.apache.dolphinscheduler.common.enums.*; import org.apache.dolphinscheduler.common.enums.*;
@ -41,6 +43,7 @@ import org.apache.dolphinscheduler.server.utils.ParamUtils;
import org.apache.dolphinscheduler.server.utils.UDFUtils; import org.apache.dolphinscheduler.server.utils.UDFUtils;
import org.apache.dolphinscheduler.server.worker.task.AbstractTask; import org.apache.dolphinscheduler.server.worker.task.AbstractTask;
import org.apache.dolphinscheduler.service.bean.SpringApplicationContext; import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;
import org.slf4j.Logger; import org.slf4j.Logger;
import java.sql.*; import java.sql.*;
@ -51,17 +54,18 @@ import java.util.stream.Collectors;
import static org.apache.dolphinscheduler.common.Constants.*; import static org.apache.dolphinscheduler.common.Constants.*;
import static org.apache.dolphinscheduler.common.enums.DbType.HIVE; import static org.apache.dolphinscheduler.common.enums.DbType.HIVE;
/** /**
* sql task * sql task
*/ */
public class SqlTask extends AbstractTask { public class SqlTask extends AbstractTask {
/** /**
* sql parameters * sql parameters
*/ */
private SqlParameters sqlParameters; private SqlParameters sqlParameters;
/** /**
* alert dao * alert dao
*/ */
private AlertDao alertDao; private AlertDao alertDao;
/** /**
@ -148,10 +152,11 @@ public class SqlTask extends AbstractTask {
/** /**
* ready to execute SQL and parameter entity Map * ready to execute SQL and parameter entity Map
*
* @return SqlBinds * @return SqlBinds
*/ */
private SqlBinds getSqlAndSqlParamsMap(String sql) { private SqlBinds getSqlAndSqlParamsMap(String sql) {
Map<Integer,Property> sqlParamsMap = new HashMap<>(); Map<Integer, Property> sqlParamsMap = new HashMap<>();
StringBuilder sqlBuilder = new StringBuilder(); StringBuilder sqlBuilder = new StringBuilder();
// find process instance by task id // find process instance by task id
@ -164,15 +169,15 @@ public class SqlTask extends AbstractTask {
taskExecutionContext.getScheduleTime()); taskExecutionContext.getScheduleTime());
// spell SQL according to the final user-defined variable // spell SQL according to the final user-defined variable
if(paramsMap == null){ if (paramsMap == null) {
sqlBuilder.append(sql); sqlBuilder.append(sql);
return new SqlBinds(sqlBuilder.toString(), sqlParamsMap); return new SqlBinds(sqlBuilder.toString(), sqlParamsMap);
} }
if (StringUtils.isNotEmpty(sqlParameters.getTitle())){ if (StringUtils.isNotEmpty(sqlParameters.getTitle())) {
String title = ParameterUtils.convertParameterPlaceholders(sqlParameters.getTitle(), String title = ParameterUtils.convertParameterPlaceholders(sqlParameters.getTitle(),
ParamUtils.convert(paramsMap)); ParamUtils.convert(paramsMap));
logger.info("SQL title : {}",title); logger.info("SQL title : {}", title);
sqlParameters.setTitle(title); sqlParameters.setTitle(title);
} }
@ -182,7 +187,9 @@ public class SqlTask extends AbstractTask {
// special characters need to be escaped, ${} needs to be escaped // special characters need to be escaped, ${} needs to be escaped
String rgex = "['\"]*\\$\\{(.*?)\\}['\"]*"; String rgex = "['\"]*\\$\\{(.*?)\\}['\"]*";
setSqlParamsMap(sql, rgex, sqlParamsMap, paramsMap); setSqlParamsMap(sql, rgex, sqlParamsMap, paramsMap);
//Replace the original value in sql !{...} ,Does not participate in precompilation
String rgexo = "['\"]*\\!\\{(.*?)\\}['\"]*";
sql = replaceOriginalValue(sql, rgexo, paramsMap);
// replace the ${} of the SQL statement with the Placeholder // replace the ${} of the SQL statement with the Placeholder
String formatSql = sql.replaceAll(rgex, "?"); String formatSql = sql.replaceAll(rgex, "?");
sqlBuilder.append(formatSql); sqlBuilder.append(formatSql);
@ -192,6 +199,20 @@ public class SqlTask extends AbstractTask {
return new SqlBinds(sqlBuilder.toString(), sqlParamsMap); return new SqlBinds(sqlBuilder.toString(), sqlParamsMap);
} }
public String replaceOriginalValue(String content, String rgex, Map<String, Property> sqlParamsMap) {
Pattern pattern = Pattern.compile(rgex);
while (true) {
Matcher m = pattern.matcher(content);
if (!m.find()) {
break;
}
String paramName = m.group(1);
String paramValue = sqlParamsMap.get(paramName).getValue();
content = m.replaceFirst(paramValue);
}
return content;
}
@Override @Override
public AbstractParameters getParameters() { public AbstractParameters getParameters() {
return this.sqlParameters; return this.sqlParameters;
@ -199,15 +220,16 @@ public class SqlTask extends AbstractTask {
/** /**
* execute function and sql * execute function and sql
* @param mainSqlBinds main sql binds *
* @param preStatementsBinds pre statements binds * @param mainSqlBinds main sql binds
* @param postStatementsBinds post statements binds * @param preStatementsBinds pre statements binds
* @param createFuncs create functions * @param postStatementsBinds post statements binds
* @param createFuncs create functions
*/ */
public void executeFuncAndSql(SqlBinds mainSqlBinds, public void executeFuncAndSql(SqlBinds mainSqlBinds,
List<SqlBinds> preStatementsBinds, List<SqlBinds> preStatementsBinds,
List<SqlBinds> postStatementsBinds, List<SqlBinds> postStatementsBinds,
List<String> createFuncs){ List<String> createFuncs) {
Connection connection = null; Connection connection = null;
PreparedStatement stmt = null; PreparedStatement stmt = null;
ResultSet resultSet = null; ResultSet resultSet = null;
@ -218,11 +240,11 @@ public class SqlTask extends AbstractTask {
connection = createConnection(); connection = createConnection();
// create temp function // create temp function
if (CollectionUtils.isNotEmpty(createFuncs)) { if (CollectionUtils.isNotEmpty(createFuncs)) {
createTempFunction(connection,createFuncs); createTempFunction(connection, createFuncs);
} }
// pre sql // pre sql
preSql(connection,preStatementsBinds); preSql(connection, preStatementsBinds);
stmt = prepareStatementAndBind(connection, mainSqlBinds); stmt = prepareStatementAndBind(connection, mainSqlBinds);
// decide whether to executeQuery or executeUpdate based on sqlType // decide whether to executeQuery or executeUpdate based on sqlType
@ -236,13 +258,13 @@ public class SqlTask extends AbstractTask {
stmt.executeUpdate(); stmt.executeUpdate();
} }
postSql(connection,postStatementsBinds); postSql(connection, postStatementsBinds);
} catch (Exception e) { } catch (Exception e) {
logger.error("execute sql error",e); logger.error("execute sql error", e);
throw new RuntimeException("execute sql error"); throw new RuntimeException("execute sql error");
} finally { } finally {
close(resultSet,stmt,connection); close(resultSet, stmt, connection);
} }
} }
@ -252,7 +274,7 @@ public class SqlTask extends AbstractTask {
* @param resultSet resultSet * @param resultSet resultSet
* @throws Exception Exception * @throws Exception Exception
*/ */
private void resultProcess(ResultSet resultSet) throws Exception{ private void resultProcess(ResultSet resultSet) throws Exception {
ArrayNode resultJSONArray = JSONUtils.createArrayNode(); ArrayNode resultJSONArray = JSONUtils.createArrayNode();
ResultSetMetaData md = resultSet.getMetaData(); ResultSetMetaData md = resultSet.getMetaData();
int num = md.getColumnCount(); int num = md.getColumnCount();
@ -271,22 +293,22 @@ public class SqlTask extends AbstractTask {
logger.debug("execute sql : {}", result); logger.debug("execute sql : {}", result);
sendAttachment(StringUtils.isNotEmpty(sqlParameters.getTitle()) ? sendAttachment(StringUtils.isNotEmpty(sqlParameters.getTitle()) ?
sqlParameters.getTitle(): taskExecutionContext.getTaskName() + " query result sets", sqlParameters.getTitle() : taskExecutionContext.getTaskName() + " query result sets",
JSONUtils.toJsonString(resultJSONArray)); JSONUtils.toJsonString(resultJSONArray));
} }
/** /**
* pre sql * pre sql
* *
* @param connection connection * @param connection connection
* @param preStatementsBinds preStatementsBinds * @param preStatementsBinds preStatementsBinds
*/ */
private void preSql(Connection connection, private void preSql(Connection connection,
List<SqlBinds> preStatementsBinds) throws Exception{ List<SqlBinds> preStatementsBinds) throws Exception {
for (SqlBinds sqlBind: preStatementsBinds) { for (SqlBinds sqlBind : preStatementsBinds) {
try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)){ try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)) {
int result = pstmt.executeUpdate(); int result = pstmt.executeUpdate();
logger.info("pre statement execute result: {}, for sql: {}",result,sqlBind.getSql()); logger.info("pre statement execute result: {}, for sql: {}", result, sqlBind.getSql());
} }
} }
@ -297,26 +319,25 @@ public class SqlTask extends AbstractTask {
* *
* @param connection connection * @param connection connection
* @param postStatementsBinds postStatementsBinds * @param postStatementsBinds postStatementsBinds
* @throws Exception
*/ */
private void postSql(Connection connection, private void postSql(Connection connection,
List<SqlBinds> postStatementsBinds) throws Exception{ List<SqlBinds> postStatementsBinds) throws Exception {
for (SqlBinds sqlBind: postStatementsBinds) { for (SqlBinds sqlBind : postStatementsBinds) {
try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)){ try (PreparedStatement pstmt = prepareStatementAndBind(connection, sqlBind)) {
int result = pstmt.executeUpdate(); int result = pstmt.executeUpdate();
logger.info("post statement execute result: {},for sql: {}",result,sqlBind.getSql()); logger.info("post statement execute result: {},for sql: {}", result, sqlBind.getSql());
} }
} }
} }
/** /**
* create temp function * create temp function
* *
* @param connection connection * @param connection connection
* @param createFuncs createFuncs * @param createFuncs createFuncs
* @throws Exception
*/ */
private void createTempFunction(Connection connection, private void createTempFunction(Connection connection,
List<String> createFuncs) throws Exception{ List<String> createFuncs) throws Exception {
try (Statement funcStmt = connection.createStatement()) { try (Statement funcStmt = connection.createStatement()) {
for (String createFunc : createFuncs) { for (String createFunc : createFuncs) {
logger.info("hive create function sql: {}", createFunc); logger.info("hive create function sql: {}", createFunc);
@ -331,7 +352,7 @@ public class SqlTask extends AbstractTask {
* @return connection * @return connection
* @throws Exception Exception * @throws Exception Exception
*/ */
private Connection createConnection() throws Exception{ private Connection createConnection() throws Exception {
// if hive , load connection params if exists // if hive , load connection params if exists
Connection connection = null; Connection connection = null;
if (HIVE == DbType.valueOf(sqlParameters.getType())) { if (HIVE == DbType.valueOf(sqlParameters.getType())) {
@ -345,7 +366,7 @@ public class SqlTask extends AbstractTask {
connection = DriverManager.getConnection(baseDataSource.getJdbcUrl(), connection = DriverManager.getConnection(baseDataSource.getJdbcUrl(),
paramProp); paramProp);
}else{ } else {
connection = DriverManager.getConnection(baseDataSource.getJdbcUrl(), connection = DriverManager.getConnection(baseDataSource.getJdbcUrl(),
baseDataSource.getUser(), baseDataSource.getUser(),
baseDataSource.getPassword()); baseDataSource.getPassword());
@ -354,7 +375,7 @@ public class SqlTask extends AbstractTask {
} }
/** /**
* close jdbc resource * close jdbc resource
* *
* @param resultSet resultSet * @param resultSet resultSet
* @param pstmt pstmt * @param pstmt pstmt
@ -362,36 +383,37 @@ public class SqlTask extends AbstractTask {
*/ */
private void close(ResultSet resultSet, private void close(ResultSet resultSet,
PreparedStatement pstmt, PreparedStatement pstmt,
Connection connection){ Connection connection) {
if (resultSet != null){ if (resultSet != null) {
try { try {
resultSet.close(); resultSet.close();
} catch (SQLException e) { } catch (SQLException e) {
logger.error("close result set error : {}",e.getMessage(),e); logger.error("close result set error : {}", e.getMessage(), e);
} }
} }
if (pstmt != null){ if (pstmt != null) {
try { try {
pstmt.close(); pstmt.close();
} catch (SQLException e) { } catch (SQLException e) {
logger.error("close prepared statement error : {}",e.getMessage(),e); logger.error("close prepared statement error : {}", e.getMessage(), e);
} }
} }
if (connection != null){ if (connection != null) {
try { try {
connection.close(); connection.close();
} catch (SQLException e) { } catch (SQLException e) {
logger.error("close connection error : {}",e.getMessage(),e); logger.error("close connection error : {}", e.getMessage(), e);
} }
} }
} }
/** /**
* preparedStatement bind * preparedStatement bind
*
* @param connection connection * @param connection connection
* @param sqlBinds sqlBinds * @param sqlBinds sqlBinds
* @return PreparedStatement * @return PreparedStatement
* @throws Exception Exception * @throws Exception Exception
*/ */
@ -400,11 +422,11 @@ public class SqlTask extends AbstractTask {
boolean timeoutFlag = TaskTimeoutStrategy.of(taskExecutionContext.getTaskTimeoutStrategy()) == TaskTimeoutStrategy.FAILED || boolean timeoutFlag = TaskTimeoutStrategy.of(taskExecutionContext.getTaskTimeoutStrategy()) == TaskTimeoutStrategy.FAILED ||
TaskTimeoutStrategy.of(taskExecutionContext.getTaskTimeoutStrategy()) == TaskTimeoutStrategy.WARNFAILED; TaskTimeoutStrategy.of(taskExecutionContext.getTaskTimeoutStrategy()) == TaskTimeoutStrategy.WARNFAILED;
PreparedStatement stmt = connection.prepareStatement(sqlBinds.getSql()); PreparedStatement stmt = connection.prepareStatement(sqlBinds.getSql());
if(timeoutFlag){ if (timeoutFlag) {
stmt.setQueryTimeout(taskExecutionContext.getTaskTimeout()); stmt.setQueryTimeout(taskExecutionContext.getTaskTimeout());
} }
Map<Integer, Property> params = sqlBinds.getParamsMap(); Map<Integer, Property> params = sqlBinds.getParamsMap();
if(params != null) { if (params != null) {
for (Map.Entry<Integer, Property> entry : params.entrySet()) { for (Map.Entry<Integer, Property> entry : params.entrySet()) {
Property prop = entry.getValue(); Property prop = entry.getValue();
ParameterUtils.setInParameter(entry.getKey(), stmt, prop.getType(), prop.getValue()); ParameterUtils.setInParameter(entry.getKey(), stmt, prop.getType(), prop.getValue());
@ -416,23 +438,24 @@ public class SqlTask extends AbstractTask {
/** /**
* send mail as an attachment * send mail as an attachment
* @param title title *
* @param content content * @param title title
* @param content content
*/ */
public void sendAttachment(String title,String content){ public void sendAttachment(String title, String content) {
List<User> users = alertDao.queryUserByAlertGroupId(taskExecutionContext.getSqlTaskExecutionContext().getWarningGroupId()); List<User> users = alertDao.queryUserByAlertGroupId(taskExecutionContext.getSqlTaskExecutionContext().getWarningGroupId());
// receiving group list // receiving group list
List<String> receiversList = new ArrayList<>(); List<String> receiversList = new ArrayList<>();
for(User user:users){ for (User user : users) {
receiversList.add(user.getEmail().trim()); receiversList.add(user.getEmail().trim());
} }
// custom receiver // custom receiver
String receivers = sqlParameters.getReceivers(); String receivers = sqlParameters.getReceivers();
if (StringUtils.isNotEmpty(receivers)){ if (StringUtils.isNotEmpty(receivers)) {
String[] splits = receivers.split(COMMA); String[] splits = receivers.split(COMMA);
for (String receiver : splits){ for (String receiver : splits) {
receiversList.add(receiver.trim()); receiversList.add(receiver.trim());
} }
} }
@ -441,60 +464,62 @@ public class SqlTask extends AbstractTask {
List<String> receiversCcList = new ArrayList<>(); List<String> receiversCcList = new ArrayList<>();
// Custom Copier // Custom Copier
String receiversCc = sqlParameters.getReceiversCc(); String receiversCc = sqlParameters.getReceiversCc();
if (StringUtils.isNotEmpty(receiversCc)){ if (StringUtils.isNotEmpty(receiversCc)) {
String[] splits = receiversCc.split(COMMA); String[] splits = receiversCc.split(COMMA);
for (String receiverCc : splits){ for (String receiverCc : splits) {
receiversCcList.add(receiverCc.trim()); receiversCcList.add(receiverCc.trim());
} }
} }
String showTypeName = sqlParameters.getShowType().replace(COMMA,"").trim(); String showTypeName = sqlParameters.getShowType().replace(COMMA, "").trim();
if(EnumUtils.isValidEnum(ShowType.class,showTypeName)){ if (EnumUtils.isValidEnum(ShowType.class, showTypeName)) {
Map<String, Object> mailResult = MailUtils.sendMails(receiversList, Map<String, Object> mailResult = MailUtils.sendMails(receiversList,
receiversCcList, title, content, ShowType.valueOf(showTypeName).getDescp()); receiversCcList, title, content, ShowType.valueOf(showTypeName).getDescp());
if(!(boolean) mailResult.get(STATUS)){ if (!(boolean) mailResult.get(STATUS)) {
throw new RuntimeException("send mail failed!"); throw new RuntimeException("send mail failed!");
} }
}else{ } else {
logger.error("showType: {} is not valid " ,showTypeName); logger.error("showType: {} is not valid ", showTypeName);
throw new RuntimeException(String.format("showType: %s is not valid ",showTypeName)); throw new RuntimeException(String.format("showType: %s is not valid ", showTypeName));
} }
} }
/** /**
* regular expressions match the contents between two specified strings * regular expressions match the contents between two specified strings
* @param content content *
* @param rgex rgex * @param content content
* @param sqlParamsMap sql params map * @param rgex rgex
* @param paramsPropsMap params props map * @param sqlParamsMap sql params map
* @param paramsPropsMap params props map
*/ */
public void setSqlParamsMap(String content, String rgex, Map<Integer,Property> sqlParamsMap, Map<String,Property> paramsPropsMap){ public void setSqlParamsMap(String content, String rgex, Map<Integer, Property> sqlParamsMap, Map<String, Property> paramsPropsMap) {
Pattern pattern = Pattern.compile(rgex); Pattern pattern = Pattern.compile(rgex);
Matcher m = pattern.matcher(content); Matcher m = pattern.matcher(content);
int index = 1; int index = 1;
while (m.find()) { while (m.find()) {
String paramName = m.group(1); String paramName = m.group(1);
Property prop = paramsPropsMap.get(paramName); Property prop = paramsPropsMap.get(paramName);
sqlParamsMap.put(index,prop); sqlParamsMap.put(index, prop);
index ++; index++;
} }
} }
/** /**
* print replace sql * print replace sql
* @param content content *
* @param formatSql format sql * @param content content
* @param rgex rgex * @param formatSql format sql
* @param sqlParamsMap sql params map * @param rgex rgex
* @param sqlParamsMap sql params map
*/ */
public void printReplacedSql(String content, String formatSql,String rgex, Map<Integer,Property> sqlParamsMap){ public void printReplacedSql(String content, String formatSql, String rgex, Map<Integer, Property> sqlParamsMap) {
//parameter print style //parameter print style
logger.info("after replace sql , preparing : {}" , formatSql); logger.info("after replace sql , preparing : {}", formatSql);
StringBuilder logPrint = new StringBuilder("replaced sql , parameters:"); StringBuilder logPrint = new StringBuilder("replaced sql , parameters:");
for(int i=1;i<=sqlParamsMap.size();i++){ for (int i = 1; i <= sqlParamsMap.size(); i++) {
logPrint.append(sqlParamsMap.get(i).getValue()+"("+sqlParamsMap.get(i).getType()+")"); logPrint.append(sqlParamsMap.get(i).getValue() + "(" + sqlParamsMap.get(i).getType() + ")");
} }
logger.info("Sql Params are {}", logPrint); logger.info("Sql Params are {}", logPrint);
} }

Loading…
Cancel
Save