@@ -18,6 +18,7 @@
package org.apache.dolphinscheduler.api.service.impl;
import static org.apache.dolphinscheduler.common.Constants.CMD_PARAM_SUB_PROCESS_DEFINE_CODE;
import static org.apache.dolphinscheduler.common.Constants.DEFAULT_WORKER_GROUP;
import org.apache.dolphinscheduler.api.dto.DagDataSchedule;
import org.apache.dolphinscheduler.api.dto.ScheduleParam;
@@ -34,22 +35,29 @@ import org.apache.dolphinscheduler.api.utils.FileUtils;
import org.apache.dolphinscheduler.api.utils.PageInfo;
import org.apache.dolphinscheduler.api.utils.Result;
import org.apache.dolphinscheduler.common.Constants;
import org.apache.dolphinscheduler.common.enums.ConditionType;
import org.apache.dolphinscheduler.common.enums.FailureStrategy;
import org.apache.dolphinscheduler.common.enums.Flag;
import org.apache.dolphinscheduler.common.enums.Priority;
import org.apache.dolphinscheduler.common.enums.ProcessExecutionTypeEnum;
import org.apache.dolphinscheduler.common.enums.ReleaseState;
import org.apache.dolphinscheduler.common.enums.TaskTimeoutStrategy;
import org.apache.dolphinscheduler.common.enums.TaskType;
import org.apache.dolphinscheduler.common.enums.TimeoutFlag;
import org.apache.dolphinscheduler.common.enums.UserType;
import org.apache.dolphinscheduler.common.enums.WarningType;
import org.apache.dolphinscheduler.common.graph.DAG;
import org.apache.dolphinscheduler.common.model.TaskNode;
import org.apache.dolphinscheduler.common.model.TaskNodeRelation;
import org.apache.dolphinscheduler.common.task.sql.SqlParameters;
import org.apache.dolphinscheduler.common.task.sql.SqlType;
import org.apache.dolphinscheduler.common.thread.Stopper;
import org.apache.dolphinscheduler.common.utils.CodeGenerateUtils;
import org.apache.dolphinscheduler.common.utils.CodeGenerateUtils.CodeGenerateException;
import org.apache.dolphinscheduler.common.utils.DateUtils;
import org.apache.dolphinscheduler.common.utils.JSONUtils;
import org.apache.dolphinscheduler.dao.entity.DagData;
import org.apache.dolphinscheduler.dao.entity.DataSource;
import org.apache.dolphinscheduler.dao.entity.ProcessDefinition;
import org.apache.dolphinscheduler.dao.entity.ProcessDefinitionLog;
import org.apache.dolphinscheduler.dao.entity.ProcessInstance;
@@ -62,6 +70,7 @@ import org.apache.dolphinscheduler.dao.entity.TaskDefinitionLog;
import org.apache.dolphinscheduler.dao.entity.TaskInstance;
import org.apache.dolphinscheduler.dao.entity.Tenant;
import org.apache.dolphinscheduler.dao.entity.User;
import org.apache.dolphinscheduler.dao.mapper.DataSourceMapper;
import org.apache.dolphinscheduler.dao.mapper.ProcessDefinitionLogMapper;
import org.apache.dolphinscheduler.dao.mapper.ProcessDefinitionMapper;
import org.apache.dolphinscheduler.dao.mapper.ProcessTaskRelationLogMapper;
@@ -79,11 +88,14 @@ import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
@@ -93,6 +105,8 @@ import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
@@ -167,6 +181,9 @@ public class ProcessDefinitionServiceImpl extends BaseServiceImpl implements Pro
    @Autowired
    private TenantMapper tenantMapper;

    @Autowired
    private DataSourceMapper dataSourceMapper;

    /**
     * create process definition
     *
@@ -867,6 +884,200 @@ public class ProcessDefinitionServiceImpl extends BaseServiceImpl implements Pro
        return result;
    }

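    /**
     * Import a process definition from a zip of SQL files. Each file in the archive becomes one
     * SQL task; task metadata is read from "-- key: value" comment directives inside the file.
     * Supported keys are "name" (task name, defaults to the file name without its extension),
     * "upstream" (comma-separated upstream task names) and "datasource" (datasource name).
     *
     * An illustrative entry (task and datasource names here are examples, not fixed values):
     *
     *   -- name: create_table_task
     *   -- datasource: mysql_test
     *   -- upstream: root
     *   CREATE TABLE IF NOT EXISTS t_demo (id INT);
     */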
    @Override
    @Transactional(rollbackFor = RuntimeException.class)
    public Map<String, Object> importSqlProcessDefinition(User loginUser, long projectCode, MultipartFile file) {
        Map<String, Object> result = new HashMap<>();
        String processDefinitionName = file.getOriginalFilename() == null ? file.getName() : file.getOriginalFilename();
        int index = processDefinitionName.lastIndexOf(".");
        if (index > 0) {
            processDefinitionName = processDefinitionName.substring(0, index);
        }
        processDefinitionName = processDefinitionName + "_import_" + DateUtils.getCurrentTimeStamp();
        ProcessDefinition processDefinition;
        List<TaskDefinitionLog> taskDefinitionList = new ArrayList<>();
        List<ProcessTaskRelationLog> processTaskRelationList = new ArrayList<>();
        // for Zip Bomb Attack
        int THRESHOLD_ENTRIES = 10000;
        int THRESHOLD_SIZE = 1000000000; // 1 GB
        double THRESHOLD_RATIO = 10;
        int totalEntryArchive = 0;
        int totalSizeArchive = 0;
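        // Three guards against zip bombs: cap the number of entries, cap the total uncompressed
        // size of the archive, and reject any entry whose uncompressed/compressed ratio exceeds
        // THRESHOLD_RATIO.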
        // In most cases, there will be only one data source
        Map<String, DataSource> dataSourceCache = new HashMap<>(1);
        Map<String, Long> taskNameToCode = new HashMap<>(16);
        Map<String, List<String>> taskNameToUpstream = new HashMap<>(16);
        try (ZipInputStream zIn = new ZipInputStream(file.getInputStream());
             BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(zIn))) {
            // build process definition
            processDefinition = new ProcessDefinition(projectCode,
                    processDefinitionName,
                    CodeGenerateUtils.getInstance().genCode(),
                    "",
                    "[]", null,
                    0, loginUser.getId(), loginUser.getTenantId());
            ZipEntry entry;
            while ((entry = zIn.getNextEntry()) != null) {
                totalEntryArchive++;
                int totalSizeEntry = 0;
                if (!entry.isDirectory()) {
                    StringBuilder sql = new StringBuilder();
                    String taskName = null;
                    String datasourceName = null;
                    List<String> upstreams = Collections.emptyList();
                    String line;
                    while ((line = bufferedReader.readLine()) != null) {
                        int nBytes = line.getBytes(StandardCharsets.UTF_8).length;
                        totalSizeEntry += nBytes;
                        totalSizeArchive += nBytes;
                        long compressionRatio = totalSizeEntry / entry.getCompressedSize();
                        if (compressionRatio > THRESHOLD_RATIO) {
                            throw new IllegalStateException("ratio between compressed and uncompressed data is highly suspicious, looks like a Zip Bomb Attack");
                        }
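                        // a "-- key: value" comment sets task metadata; recognized directives are
                        // stripped from the SQL text, any other comment is kept verbatim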
                        int commentIndex = line.indexOf("-- ");
                        if (commentIndex >= 0) {
                            int colonIndex = line.indexOf(":", commentIndex);
                            if (colonIndex > 0) {
                                String key = line.substring(commentIndex + 3, colonIndex).trim().toLowerCase();
                                String value = line.substring(colonIndex + 1).trim();
                                switch (key) {
                                    case "name":
                                        taskName = value;
                                        line = line.substring(0, commentIndex);
                                        break;
                                    case "upstream":
                                        upstreams = Arrays.stream(value.split(",")).map(String::trim)
                                                .filter(s -> !"".equals(s)).collect(Collectors.toList());
                                        line = line.substring(0, commentIndex);
                                        break;
                                    case "datasource":
                                        datasourceName = value;
                                        line = line.substring(0, commentIndex);
                                        break;
                                    default:
                                        break;
                                }
                            }
                        }
                        if (!"".equals(line)) {
                            sql.append(line).append("\n");
                        }
                    }
                    // import/sql1.sql -> sql1
                    if (taskName == null) {
                        taskName = entry.getName();
                        index = taskName.indexOf("/");
                        if (index > 0) {
                            taskName = taskName.substring(index + 1);
                        }
                        index = taskName.lastIndexOf(".");
                        if (index > 0) {
                            taskName = taskName.substring(0, index);
                        }
                    }
                    DataSource dataSource = dataSourceCache.get(datasourceName);
                    if (dataSource == null) {
                        dataSource = queryDatasourceByNameAndUser(datasourceName, loginUser);
                    }
                    if (dataSource == null) {
                        putMsg(result, Status.DATASOURCE_NAME_ILLEGAL);
                        return result;
                    }
                    dataSourceCache.put(datasourceName, dataSource);
                    TaskDefinitionLog taskDefinition = buildNormalSqlTaskDefinition(taskName, dataSource, sql.toString());
                    taskDefinitionList.add(taskDefinition);
                    taskNameToCode.put(taskDefinition.getName(), taskDefinition.getCode());
                    taskNameToUpstream.put(taskDefinition.getName(), upstreams);
                }
                if (totalSizeArchive > THRESHOLD_SIZE) {
                    throw new IllegalStateException("the uncompressed data size is too much for the application resource capacity");
                }
                if (totalEntryArchive > THRESHOLD_ENTRIES) {
                    throw new IllegalStateException("too much entries in this archive, can lead to inodes exhaustion of the system");
                }
            }
        } catch (Exception e) {
            logger.error(e.getMessage(), e);
            putMsg(result, Status.IMPORT_PROCESS_DEFINE_ERROR);
            return result;
        }
        // build task relation
        for (Map.Entry<String, Long> entry : taskNameToCode.entrySet()) {
            List<String> upstreams = taskNameToUpstream.get(entry.getKey());
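            // a task with no upstream, or whose only upstream is the virtual "root" node,
            // is connected to the DAG entry point (preTaskCode 0)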
            if (CollectionUtils.isEmpty(upstreams)
                    || (upstreams.size() == 1 && upstreams.contains("root") && !taskNameToCode.containsKey("root"))) {
                ProcessTaskRelationLog processTaskRelation = buildNormalTaskRelation(0, entry.getValue());
                processTaskRelationList.add(processTaskRelation);
                continue;
            }
            for (String upstream : upstreams) {
                ProcessTaskRelationLog processTaskRelation = buildNormalTaskRelation(taskNameToCode.get(upstream), entry.getValue());
                processTaskRelationList.add(processTaskRelation);
            }
        }
        return createDagDefine(loginUser, processTaskRelationList, processDefinition, taskDefinitionList);
    }

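    /**
     * Build a relation log entry between two tasks; preTaskCode 0 marks a DAG entry task.
     */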
    private ProcessTaskRelationLog buildNormalTaskRelation(long preTaskCode, long postTaskCode) {
        ProcessTaskRelationLog processTaskRelation = new ProcessTaskRelationLog();
        processTaskRelation.setPreTaskCode(preTaskCode);
        processTaskRelation.setPreTaskVersion(0);
        processTaskRelation.setPostTaskCode(postTaskCode);
        processTaskRelation.setPostTaskVersion(0);
        processTaskRelation.setConditionType(ConditionType.NONE);
        processTaskRelation.setName("");
        return processTaskRelation;
    }

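    /**
     * Query a datasource by name. Admin users may match any datasource with the given name;
     * ordinary users only match datasources they own.
     */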
    private DataSource queryDatasourceByNameAndUser(String datasourceName, User loginUser) {
        if (isAdmin(loginUser)) {
            List<DataSource> dataSources = dataSourceMapper.queryDataSourceByName(datasourceName);
            if (CollectionUtils.isNotEmpty(dataSources)) {
                return dataSources.get(0);
            }
        } else {
            return dataSourceMapper.queryDataSourceByNameAndUserId(loginUser.getId(), datasourceName);
        }
        return null;
    }

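    /**
     * Build a SQL task definition with default settings (default worker group, medium priority,
     * no retries, timeout disabled). The trailing newline appended while reading the file is
     * stripped from the SQL here.
     */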
    private TaskDefinitionLog buildNormalSqlTaskDefinition(String taskName, DataSource dataSource, String sql) throws CodeGenerateException {
        TaskDefinitionLog taskDefinition = new TaskDefinitionLog();
        taskDefinition.setName(taskName);
        taskDefinition.setFlag(Flag.YES);
        SqlParameters sqlParameters = new SqlParameters();
        sqlParameters.setType(dataSource.getType().name());
        sqlParameters.setDatasource(dataSource.getId());
        sqlParameters.setSql(sql.substring(0, sql.length() - 1));
        // it may be a query type, but it can only be determined by parsing SQL
        sqlParameters.setSqlType(SqlType.NON_QUERY.ordinal());
        sqlParameters.setLocalParams(Collections.emptyList());
        taskDefinition.setTaskParams(JSONUtils.toJsonString(sqlParameters));
        taskDefinition.setCode(CodeGenerateUtils.getInstance().genCode());
        taskDefinition.setTaskType(TaskType.SQL.getDesc());
        taskDefinition.setFailRetryTimes(0);
        taskDefinition.setFailRetryInterval(0);
        taskDefinition.setTimeoutFlag(TimeoutFlag.CLOSE);
        taskDefinition.setWorkerGroup(DEFAULT_WORKER_GROUP);
        taskDefinition.setTaskPriority(Priority.MEDIUM);
        taskDefinition.setEnvironmentCode(-1);
        taskDefinition.setTimeout(0);
        taskDefinition.setDelayTime(0);
        taskDefinition.setTimeoutNotifyStrategy(TaskTimeoutStrategy.WARN);
        taskDefinition.setVersion(0);
        taskDefinition.setResourceIds("");
        return taskDefinition;
    }

    /**
     * check and import
     */