@ -18,7 +18,9 @@ package org.apache.dolphinscheduler.server.worker.task.sql;
import com.fasterxml.jackson.databind.node.ArrayNode ;
import com.fasterxml.jackson.databind.node.ObjectNode ;
import org.apache.commons.lang.StringUtils ;
import org.apache.dolphinscheduler.alert.utils.MailUtils ;
import org.apache.dolphinscheduler.common.Constants ;
import org.apache.dolphinscheduler.common.enums.* ;
@ -41,6 +43,7 @@ import org.apache.dolphinscheduler.server.utils.ParamUtils;
import org.apache.dolphinscheduler.server.utils.UDFUtils ;
import org.apache.dolphinscheduler.server.worker.task.AbstractTask ;
import org.apache.dolphinscheduler.service.bean.SpringApplicationContext ;
import org.slf4j.Logger ;
import java.sql.* ;
@ -51,6 +54,7 @@ import java.util.stream.Collectors;
import static org.apache.dolphinscheduler.common.Constants.* ;
import static org.apache.dolphinscheduler.common.enums.DbType.HIVE ;
/ * *
* sql task
* /
@ -148,10 +152,11 @@ public class SqlTask extends AbstractTask {
/ * *
* ready to execute SQL and parameter entity Map
*
* @return SqlBinds
* /
private SqlBinds getSqlAndSqlParamsMap ( String sql ) {
Map < Integer , Property > sqlParamsMap = new HashMap < > ( ) ;
Map < Integer , Property > sqlParamsMap = new HashMap < > ( ) ;
StringBuilder sqlBuilder = new StringBuilder ( ) ;
// find process instance by task id
@ -164,15 +169,15 @@ public class SqlTask extends AbstractTask {
taskExecutionContext . getScheduleTime ( ) ) ;
// spell SQL according to the final user-defined variable
if ( paramsMap = = null ) {
if ( paramsMap = = null ) {
sqlBuilder . append ( sql ) ;
return new SqlBinds ( sqlBuilder . toString ( ) , sqlParamsMap ) ;
}
if ( StringUtils . isNotEmpty ( sqlParameters . getTitle ( ) ) ) {
if ( StringUtils . isNotEmpty ( sqlParameters . getTitle ( ) ) ) {
String title = ParameterUtils . convertParameterPlaceholders ( sqlParameters . getTitle ( ) ,
ParamUtils . convert ( paramsMap ) ) ;
logger . info ( "SQL title : {}" , title ) ;
logger . info ( "SQL title : {}" , title ) ;
sqlParameters . setTitle ( title ) ;
}
@ -182,7 +187,9 @@ public class SqlTask extends AbstractTask {
// special characters need to be escaped, ${} needs to be escaped
String rgex = "['\"]*\\$\\{(.*?)\\}['\"]*" ;
setSqlParamsMap ( sql , rgex , sqlParamsMap , paramsMap ) ;
//Replace the original value in sql !{...} ,Does not participate in precompilation
String rgexo = "['\"]*\\!\\{(.*?)\\}['\"]*" ;
sql = replaceOriginalValue ( sql , rgexo , paramsMap ) ;
// replace the ${} of the SQL statement with the Placeholder
String formatSql = sql . replaceAll ( rgex , "?" ) ;
sqlBuilder . append ( formatSql ) ;
@ -192,6 +199,20 @@ public class SqlTask extends AbstractTask {
return new SqlBinds ( sqlBuilder . toString ( ) , sqlParamsMap ) ;
}
public String replaceOriginalValue ( String content , String rgex , Map < String , Property > sqlParamsMap ) {
Pattern pattern = Pattern . compile ( rgex ) ;
while ( true ) {
Matcher m = pattern . matcher ( content ) ;
if ( ! m . find ( ) ) {
break ;
}
String paramName = m . group ( 1 ) ;
String paramValue = sqlParamsMap . get ( paramName ) . getValue ( ) ;
content = m . replaceFirst ( paramValue ) ;
}
return content ;
}
@Override
public AbstractParameters getParameters ( ) {
return this . sqlParameters ;
@ -199,6 +220,7 @@ public class SqlTask extends AbstractTask {
/ * *
* execute function and sql
*
* @param mainSqlBinds main sql binds
* @param preStatementsBinds pre statements binds
* @param postStatementsBinds post statements binds
@ -207,7 +229,7 @@ public class SqlTask extends AbstractTask {
public void executeFuncAndSql ( SqlBinds mainSqlBinds ,
List < SqlBinds > preStatementsBinds ,
List < SqlBinds > postStatementsBinds ,
List < String > createFuncs ) {
List < String > createFuncs ) {
Connection connection = null ;
PreparedStatement stmt = null ;
ResultSet resultSet = null ;
@ -218,11 +240,11 @@ public class SqlTask extends AbstractTask {
connection = createConnection ( ) ;
// create temp function
if ( CollectionUtils . isNotEmpty ( createFuncs ) ) {
createTempFunction ( connection , createFuncs ) ;
createTempFunction ( connection , createFuncs ) ;
}
// pre sql
preSql ( connection , preStatementsBinds ) ;
preSql ( connection , preStatementsBinds ) ;
stmt = prepareStatementAndBind ( connection , mainSqlBinds ) ;
// decide whether to executeQuery or executeUpdate based on sqlType
@ -236,13 +258,13 @@ public class SqlTask extends AbstractTask {
stmt . executeUpdate ( ) ;
}
postSql ( connection , postStatementsBinds ) ;
postSql ( connection , postStatementsBinds ) ;
} catch ( Exception e ) {
logger . error ( "execute sql error" , e ) ;
logger . error ( "execute sql error" , e ) ;
throw new RuntimeException ( "execute sql error" ) ;
} finally {
close ( resultSet , stmt , connection ) ;
close ( resultSet , stmt , connection ) ;
}
}
@ -252,7 +274,7 @@ public class SqlTask extends AbstractTask {
* @param resultSet resultSet
* @throws Exception Exception
* /
private void resultProcess ( ResultSet resultSet ) throws Exception {
private void resultProcess ( ResultSet resultSet ) throws Exception {
ArrayNode resultJSONArray = JSONUtils . createArrayNode ( ) ;
ResultSetMetaData md = resultSet . getMetaData ( ) ;
int num = md . getColumnCount ( ) ;
@ -271,7 +293,7 @@ public class SqlTask extends AbstractTask {
logger . debug ( "execute sql : {}" , result ) ;
sendAttachment ( StringUtils . isNotEmpty ( sqlParameters . getTitle ( ) ) ?
sqlParameters . getTitle ( ) : taskExecutionContext . getTaskName ( ) + " query result sets" ,
sqlParameters . getTitle ( ) : taskExecutionContext . getTaskName ( ) + " query result sets" ,
JSONUtils . toJsonString ( resultJSONArray ) ) ;
}
@ -282,11 +304,11 @@ public class SqlTask extends AbstractTask {
* @param preStatementsBinds preStatementsBinds
* /
private void preSql ( Connection connection ,
List < SqlBinds > preStatementsBinds ) throws Exception {
for ( SqlBinds sqlBind : preStatementsBinds ) {
try ( PreparedStatement pstmt = prepareStatementAndBind ( connection , sqlBind ) ) {
List < SqlBinds > preStatementsBinds ) throws Exception {
for ( SqlBinds sqlBind : preStatementsBinds ) {
try ( PreparedStatement pstmt = prepareStatementAndBind ( connection , sqlBind ) ) {
int result = pstmt . executeUpdate ( ) ;
logger . info ( "pre statement execute result: {}, for sql: {}" , result , sqlBind . getSql ( ) ) ;
logger . info ( "pre statement execute result: {}, for sql: {}" , result , sqlBind . getSql ( ) ) ;
}
}
@ -297,26 +319,25 @@ public class SqlTask extends AbstractTask {
*
* @param connection connection
* @param postStatementsBinds postStatementsBinds
* @throws Exception
* /
private void postSql ( Connection connection ,
List < SqlBinds > postStatementsBinds ) throws Exception {
for ( SqlBinds sqlBind : postStatementsBinds ) {
try ( PreparedStatement pstmt = prepareStatementAndBind ( connection , sqlBind ) ) {
List < SqlBinds > postStatementsBinds ) throws Exception {
for ( SqlBinds sqlBind : postStatementsBinds ) {
try ( PreparedStatement pstmt = prepareStatementAndBind ( connection , sqlBind ) ) {
int result = pstmt . executeUpdate ( ) ;
logger . info ( "post statement execute result: {},for sql: {}" , result , sqlBind . getSql ( ) ) ;
logger . info ( "post statement execute result: {},for sql: {}" , result , sqlBind . getSql ( ) ) ;
}
}
}
/ * *
* create temp function
*
* @param connection connection
* @param createFuncs createFuncs
* @throws Exception
* /
private void createTempFunction ( Connection connection ,
List < String > createFuncs ) throws Exception {
List < String > createFuncs ) throws Exception {
try ( Statement funcStmt = connection . createStatement ( ) ) {
for ( String createFunc : createFuncs ) {
logger . info ( "hive create function sql: {}" , createFunc ) ;
@ -331,7 +352,7 @@ public class SqlTask extends AbstractTask {
* @return connection
* @throws Exception Exception
* /
private Connection createConnection ( ) throws Exception {
private Connection createConnection ( ) throws Exception {
// if hive , load connection params if exists
Connection connection = null ;
if ( HIVE = = DbType . valueOf ( sqlParameters . getType ( ) ) ) {
@ -345,7 +366,7 @@ public class SqlTask extends AbstractTask {
connection = DriverManager . getConnection ( baseDataSource . getJdbcUrl ( ) ,
paramProp ) ;
} else {
} else {
connection = DriverManager . getConnection ( baseDataSource . getJdbcUrl ( ) ,
baseDataSource . getUser ( ) ,
baseDataSource . getPassword ( ) ) ;
@ -362,34 +383,35 @@ public class SqlTask extends AbstractTask {
* /
private void close ( ResultSet resultSet ,
PreparedStatement pstmt ,
Connection connection ) {
if ( resultSet ! = null ) {
Connection connection ) {
if ( resultSet ! = null ) {
try {
resultSet . close ( ) ;
} catch ( SQLException e ) {
logger . error ( "close result set error : {}" , e . getMessage ( ) , e ) ;
logger . error ( "close result set error : {}" , e . getMessage ( ) , e ) ;
}
}
if ( pstmt ! = null ) {
if ( pstmt ! = null ) {
try {
pstmt . close ( ) ;
} catch ( SQLException e ) {
logger . error ( "close prepared statement error : {}" , e . getMessage ( ) , e ) ;
logger . error ( "close prepared statement error : {}" , e . getMessage ( ) , e ) ;
}
}
if ( connection ! = null ) {
if ( connection ! = null ) {
try {
connection . close ( ) ;
} catch ( SQLException e ) {
logger . error ( "close connection error : {}" , e . getMessage ( ) , e ) ;
logger . error ( "close connection error : {}" , e . getMessage ( ) , e ) ;
}
}
}
/ * *
* preparedStatement bind
*
* @param connection connection
* @param sqlBinds sqlBinds
* @return PreparedStatement
@ -400,11 +422,11 @@ public class SqlTask extends AbstractTask {
boolean timeoutFlag = TaskTimeoutStrategy . of ( taskExecutionContext . getTaskTimeoutStrategy ( ) ) = = TaskTimeoutStrategy . FAILED | |
TaskTimeoutStrategy . of ( taskExecutionContext . getTaskTimeoutStrategy ( ) ) = = TaskTimeoutStrategy . WARNFAILED ;
PreparedStatement stmt = connection . prepareStatement ( sqlBinds . getSql ( ) ) ;
if ( timeoutFlag ) {
if ( timeoutFlag ) {
stmt . setQueryTimeout ( taskExecutionContext . getTaskTimeout ( ) ) ;
}
Map < Integer , Property > params = sqlBinds . getParamsMap ( ) ;
if ( params ! = null ) {
if ( params ! = null ) {
for ( Map . Entry < Integer , Property > entry : params . entrySet ( ) ) {
Property prop = entry . getValue ( ) ;
ParameterUtils . setInParameter ( entry . getKey ( ) , stmt , prop . getType ( ) , prop . getValue ( ) ) ;
@ -416,23 +438,24 @@ public class SqlTask extends AbstractTask {
/ * *
* send mail as an attachment
*
* @param title title
* @param content content
* /
public void sendAttachment ( String title , String content ) {
public void sendAttachment ( String title , String content ) {
List < User > users = alertDao . queryUserByAlertGroupId ( taskExecutionContext . getSqlTaskExecutionContext ( ) . getWarningGroupId ( ) ) ;
// receiving group list
List < String > receiversList = new ArrayList < > ( ) ;
for ( User user : users ) {
for ( User user : users ) {
receiversList . add ( user . getEmail ( ) . trim ( ) ) ;
}
// custom receiver
String receivers = sqlParameters . getReceivers ( ) ;
if ( StringUtils . isNotEmpty ( receivers ) ) {
if ( StringUtils . isNotEmpty ( receivers ) ) {
String [ ] splits = receivers . split ( COMMA ) ;
for ( String receiver : splits ) {
for ( String receiver : splits ) {
receiversList . add ( receiver . trim ( ) ) ;
}
}
@ -441,34 +464,35 @@ public class SqlTask extends AbstractTask {
List < String > receiversCcList = new ArrayList < > ( ) ;
// Custom Copier
String receiversCc = sqlParameters . getReceiversCc ( ) ;
if ( StringUtils . isNotEmpty ( receiversCc ) ) {
if ( StringUtils . isNotEmpty ( receiversCc ) ) {
String [ ] splits = receiversCc . split ( COMMA ) ;
for ( String receiverCc : splits ) {
for ( String receiverCc : splits ) {
receiversCcList . add ( receiverCc . trim ( ) ) ;
}
}
String showTypeName = sqlParameters . getShowType ( ) . replace ( COMMA , "" ) . trim ( ) ;
if ( EnumUtils . isValidEnum ( ShowType . class , showTypeName ) ) {
String showTypeName = sqlParameters . getShowType ( ) . replace ( COMMA , "" ) . trim ( ) ;
if ( EnumUtils . isValidEnum ( ShowType . class , showTypeName ) ) {
Map < String , Object > mailResult = MailUtils . sendMails ( receiversList ,
receiversCcList , title , content , ShowType . valueOf ( showTypeName ) . getDescp ( ) ) ;
if ( ! ( boolean ) mailResult . get ( STATUS ) ) {
if ( ! ( boolean ) mailResult . get ( STATUS ) ) {
throw new RuntimeException ( "send mail failed!" ) ;
}
} else {
logger . error ( "showType: {} is not valid " , showTypeName ) ;
throw new RuntimeException ( String . format ( "showType: %s is not valid " , showTypeName ) ) ;
} else {
logger . error ( "showType: {} is not valid " , showTypeName ) ;
throw new RuntimeException ( String . format ( "showType: %s is not valid " , showTypeName ) ) ;
}
}
/ * *
* regular expressions match the contents between two specified strings
*
* @param content content
* @param rgex rgex
* @param sqlParamsMap sql params map
* @param paramsPropsMap params props map
* /
public void setSqlParamsMap ( String content , String rgex , Map < Integer , Property > sqlParamsMap , Map < String , Property > paramsPropsMap ) {
public void setSqlParamsMap ( String content , String rgex , Map < Integer , Property > sqlParamsMap , Map < String , Property > paramsPropsMap ) {
Pattern pattern = Pattern . compile ( rgex ) ;
Matcher m = pattern . matcher ( content ) ;
int index = 1 ;
@ -477,24 +501,25 @@ public class SqlTask extends AbstractTask {
String paramName = m . group ( 1 ) ;
Property prop = paramsPropsMap . get ( paramName ) ;
sqlParamsMap . put ( index , prop ) ;
index + + ;
sqlParamsMap . put ( index , prop ) ;
index + + ;
}
}
/ * *
* print replace sql
*
* @param content content
* @param formatSql format sql
* @param rgex rgex
* @param sqlParamsMap sql params map
* /
public void printReplacedSql ( String content , String formatSql , String rgex , Map < Integer , Property > sqlParamsMap ) {
public void printReplacedSql ( String content , String formatSql , String rgex , Map < Integer , Property > sqlParamsMap ) {
//parameter print style
logger . info ( "after replace sql , preparing : {}" , formatSql ) ;
logger . info ( "after replace sql , preparing : {}" , formatSql ) ;
StringBuilder logPrint = new StringBuilder ( "replaced sql , parameters:" ) ;
for ( int i = 1 ; i < = sqlParamsMap . size ( ) ; i + + ) {
logPrint . append ( sqlParamsMap . get ( i ) . getValue ( ) + "(" + sqlParamsMap . get ( i ) . getType ( ) + ")" ) ;
for ( int i = 1 ; i < = sqlParamsMap . size ( ) ; i + + ) {
logPrint . append ( sqlParamsMap . get ( i ) . getValue ( ) + "(" + sqlParamsMap . get ( i ) . getType ( ) + ")" ) ;
}
logger . info ( "Sql Params are {}" , logPrint ) ;
}