From 4c2f77ee9cbd599edfb38e4bf82755f74e96e7c6 Mon Sep 17 00:00:00 2001
From: ououtt <734164350@qq.com>
Date: Wed, 5 Jan 2022 17:55:08 +0800
Subject: [PATCH] [DS-7016][feat] Auto create workflow while importing SQL
 script with specific hints (#7214)

Automatically create a workflow when a SQL script is imported with
specific hints. Each SQL file in the uploaded zip archive becomes one
SQL task, and "--" comment hints inside the script control how the task
is wired into the workflow:

* datasource: datasource name
* name: task name
* upstream: pre task names

Follow-up commits squashed into this patch: remove excess blank lines,
code optimization, and a guard against Zip Bomb attacks.
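
For illustration, a script inside the uploaded zip could look like the
following (the task and datasource names are examples only):

    -- name: task_b
    -- upstream: task_a
    -- datasource: mysql_1
    select 1;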

Co-authored-by: eye
---
 .../ProcessDefinitionController.java          |   7 +-
 .../api/service/ProcessDefinitionService.java |  12 +
 .../impl/ProcessDefinitionServiceImpl.java    | 211 ++++++++++++++++++
 .../service/ProcessDefinitionServiceTest.java |  46 ++++
 .../dao/mapper/DataSourceMapper.java          |   9 +-
 .../dao/mapper/DataSourceMapper.xml           |  13 ++
 6 files changed, 296 insertions(+), 2 deletions(-)

diff --git a/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/controller/ProcessDefinitionController.java b/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/controller/ProcessDefinitionController.java
index d9516d4856..5dba775317 100644
--- a/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/controller/ProcessDefinitionController.java
+++ b/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/controller/ProcessDefinitionController.java
@@ -696,7 +696,12 @@ public class ProcessDefinitionController extends BaseController {
     public Result importProcessDefinition(@ApiIgnore @RequestAttribute(value = Constants.SESSION_USER) User loginUser,
                                           @ApiParam(name = "projectCode", value = "PROJECT_CODE", required = true) @PathVariable long projectCode,
                                           @RequestParam("file") MultipartFile file) {
-        Map<String, Object> result = processDefinitionService.importProcessDefinition(loginUser, projectCode, file);
+        Map<String, Object> result;
+        if ("application/zip".equals(file.getContentType())) {
+            result = processDefinitionService.importSqlProcessDefinition(loginUser, projectCode, file);
+        } else {
+            result = processDefinitionService.importProcessDefinition(loginUser, projectCode, file);
+        }
         return returnDataList(result);
     }
 
diff --git a/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/service/ProcessDefinitionService.java b/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/service/ProcessDefinitionService.java
index 64a0fbe989..99a6d95365 100644
--- a/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/service/ProcessDefinitionService.java
+++ b/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/service/ProcessDefinitionService.java
@@ -242,6 +242,18 @@ public interface ProcessDefinitionService {
                                                 long projectCode,
                                                 MultipartFile file);
 
+    /**
+     * import sql process definition
+     *
+     * @param loginUser login user
+     * @param projectCode project code
+     * @param file zip file containing sql scripts
+     * @return import result
+     */
+    Map<String, Object> importSqlProcessDefinition(User loginUser,
+                                                   long projectCode,
+                                                   MultipartFile file);
+
     /**
      * check the process task relation json
      *
diff --git a/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/service/impl/ProcessDefinitionServiceImpl.java b/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/service/impl/ProcessDefinitionServiceImpl.java
index eb6aa5a959..20d109a2b2 100644
--- a/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/service/impl/ProcessDefinitionServiceImpl.java
+++ b/dolphinscheduler-api/src/main/java/org/apache/dolphinscheduler/api/service/impl/ProcessDefinitionServiceImpl.java
@@ -18,6 +18,7 @@ package org.apache.dolphinscheduler.api.service.impl;
 
 import static org.apache.dolphinscheduler.common.Constants.CMD_PARAM_SUB_PROCESS_DEFINE_CODE;
+import static org.apache.dolphinscheduler.common.Constants.DEFAULT_WORKER_GROUP;
 
 import org.apache.dolphinscheduler.api.dto.DagDataSchedule;
 import org.apache.dolphinscheduler.api.dto.ScheduleParam;
@@ -34,22 +35,29 @@ import org.apache.dolphinscheduler.api.utils.FileUtils;
 import org.apache.dolphinscheduler.api.utils.PageInfo;
 import org.apache.dolphinscheduler.api.utils.Result;
 import org.apache.dolphinscheduler.common.Constants;
+import org.apache.dolphinscheduler.common.enums.ConditionType;
 import org.apache.dolphinscheduler.common.enums.FailureStrategy;
+import org.apache.dolphinscheduler.common.enums.Flag;
 import org.apache.dolphinscheduler.common.enums.Priority;
 import org.apache.dolphinscheduler.common.enums.ProcessExecutionTypeEnum;
 import org.apache.dolphinscheduler.common.enums.ReleaseState;
+import org.apache.dolphinscheduler.common.enums.TaskTimeoutStrategy;
 import org.apache.dolphinscheduler.common.enums.TaskType;
+import org.apache.dolphinscheduler.common.enums.TimeoutFlag;
 import org.apache.dolphinscheduler.common.enums.UserType;
 import org.apache.dolphinscheduler.common.enums.WarningType;
 import org.apache.dolphinscheduler.common.graph.DAG;
 import org.apache.dolphinscheduler.common.model.TaskNode;
 import org.apache.dolphinscheduler.common.model.TaskNodeRelation;
+import org.apache.dolphinscheduler.common.task.sql.SqlParameters;
+import org.apache.dolphinscheduler.common.task.sql.SqlType;
 import org.apache.dolphinscheduler.common.thread.Stopper;
 import org.apache.dolphinscheduler.common.utils.CodeGenerateUtils;
 import org.apache.dolphinscheduler.common.utils.CodeGenerateUtils.CodeGenerateException;
 import org.apache.dolphinscheduler.common.utils.DateUtils;
 import org.apache.dolphinscheduler.common.utils.JSONUtils;
 import org.apache.dolphinscheduler.dao.entity.DagData;
+import org.apache.dolphinscheduler.dao.entity.DataSource;
 import org.apache.dolphinscheduler.dao.entity.ProcessDefinition;
 import org.apache.dolphinscheduler.dao.entity.ProcessDefinitionLog;
 import org.apache.dolphinscheduler.dao.entity.ProcessInstance;
@@ -62,6 +70,7 @@ import org.apache.dolphinscheduler.dao.entity.TaskDefinitionLog;
 import org.apache.dolphinscheduler.dao.entity.TaskInstance;
 import org.apache.dolphinscheduler.dao.entity.Tenant;
 import org.apache.dolphinscheduler.dao.entity.User;
+import org.apache.dolphinscheduler.dao.mapper.DataSourceMapper;
 import org.apache.dolphinscheduler.dao.mapper.ProcessDefinitionLogMapper;
 import org.apache.dolphinscheduler.dao.mapper.ProcessDefinitionMapper;
 import org.apache.dolphinscheduler.dao.mapper.ProcessTaskRelationLogMapper;
@@ -79,11 +88,14 @@ import org.apache.commons.collections4.CollectionUtils;
 import org.apache.commons.lang3.StringUtils;
 
 import java.io.BufferedOutputStream;
+import java.io.BufferedReader;
 import java.io.IOException;
+import java.io.InputStreamReader;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.Date;
 import java.util.HashMap;
 import java.util.Iterator;
@@ -93,6 +105,8 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipInputStream;
 
 import javax.servlet.ServletOutputStream;
 import javax.servlet.http.HttpServletResponse;
@@ -167,6 +181,9 @@ public class ProcessDefinitionServiceImpl extends BaseServiceImpl implements Pro
     @Autowired
     private TenantMapper tenantMapper;
 
+    @Autowired
+    private DataSourceMapper dataSourceMapper;
+
     /**
      * create process definition
     *
@@ -867,6 +884,200 @@ public class ProcessDefinitionServiceImpl extends BaseServiceImpl implements Pro
         return result;
     }
 
+    @Override
+    @Transactional(rollbackFor = RuntimeException.class)
+    public Map<String, Object> importSqlProcessDefinition(User loginUser, long projectCode, MultipartFile file) {
+        Map<String, Object> result = new HashMap<>();
+        String processDefinitionName = file.getOriginalFilename() == null ? file.getName() : file.getOriginalFilename();
+        int index = processDefinitionName.lastIndexOf(".");
+        if (index > 0) {
+            processDefinitionName = processDefinitionName.substring(0, index);
+        }
+        processDefinitionName = processDefinitionName + "_import_" + DateUtils.getCurrentTimeStamp();
+
+        ProcessDefinition processDefinition;
+        List<TaskDefinitionLog> taskDefinitionList = new ArrayList<>();
+        List<ProcessTaskRelationLog> processTaskRelationList = new ArrayList<>();
+
+        // thresholds guarding against a Zip Bomb Attack
+        int THRESHOLD_ENTRIES = 10000;
+        int THRESHOLD_SIZE = 1000000000; // 1 GB
+        double THRESHOLD_RATIO = 10;
+        int totalEntryArchive = 0;
+        int totalSizeEntry = 0;
+        int totalSizeArchive = 0;
+        // In most cases, there will be only one data source
+        Map<String, DataSource> dataSourceCache = new HashMap<>(1);
+        Map<String, Long> taskNameToCode = new HashMap<>(16);
+        Map<String, List<String>> taskNameToUpstream = new HashMap<>(16);
+        try (ZipInputStream zIn = new ZipInputStream(file.getInputStream());
+             BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(zIn))) {
+            // build process definition
+            processDefinition = new ProcessDefinition(projectCode,
+                    processDefinitionName,
+                    CodeGenerateUtils.getInstance().genCode(),
+                    "",
+                    "[]", null,
+                    0, loginUser.getId(), loginUser.getTenantId());
+
+            ZipEntry entry;
+            while ((entry = zIn.getNextEntry()) != null) {
+                totalEntryArchive++;
+                totalSizeEntry = 0;
+                if (!entry.isDirectory()) {
+                    StringBuilder sql = new StringBuilder();
+                    String taskName = null;
+                    String datasourceName = null;
+                    List<String> upstreams = Collections.emptyList();
+                    String line;
+                    while ((line = bufferedReader.readLine()) != null) {
+                        int nBytes = line.getBytes(StandardCharsets.UTF_8).length;
+                        totalSizeEntry += nBytes;
+                        totalSizeArchive += nBytes;
+                        long compressionRatio = totalSizeEntry / entry.getCompressedSize();
+                        if (compressionRatio > THRESHOLD_RATIO) {
+                            throw new IllegalStateException("ratio between compressed and uncompressed data is highly suspicious, looks like a Zip Bomb Attack");
+                        }
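+                        // hint lines take the form "-- key: value"; the recognized keys are
+                        // name (task name), upstream (comma separated pre task names)
+                        // and datasource (datasource name)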
+                        int commentIndex = line.indexOf("-- ");
+                        if (commentIndex >= 0) {
+                            int colonIndex = line.indexOf(":", commentIndex);
+                            if (colonIndex > 0) {
+                                String key = line.substring(commentIndex + 3, colonIndex).trim().toLowerCase();
+                                String value = line.substring(colonIndex + 1).trim();
+                                switch (key) {
+                                    case "name":
+                                        taskName = value;
+                                        line = line.substring(0, commentIndex);
+                                        break;
+                                    case "upstream":
+                                        upstreams = Arrays.stream(value.split(",")).map(String::trim)
+                                                .filter(s -> !"".equals(s)).collect(Collectors.toList());
+                                        line = line.substring(0, commentIndex);
+                                        break;
+                                    case "datasource":
+                                        datasourceName = value;
+                                        line = line.substring(0, commentIndex);
+                                        break;
+                                    default:
+                                        break;
+                                }
+                            }
+                        }
+                        if (!"".equals(line)) {
+                            sql.append(line).append("\n");
+                        }
+                    }
+                    // import/sql1.sql -> sql1
+                    if (taskName == null) {
+                        taskName = entry.getName();
+                        index = taskName.indexOf("/");
+                        if (index > 0) {
+                            taskName = taskName.substring(index + 1);
+                        }
+                        index = taskName.lastIndexOf(".");
+                        if (index > 0) {
+                            taskName = taskName.substring(0, index);
+                        }
+                    }
+                    DataSource dataSource = dataSourceCache.get(datasourceName);
+                    if (dataSource == null) {
+                        dataSource = queryDatasourceByNameAndUser(datasourceName, loginUser);
+                    }
+                    if (dataSource == null) {
+                        putMsg(result, Status.DATASOURCE_NAME_ILLEGAL);
+                        return result;
+                    }
+                    dataSourceCache.put(datasourceName, dataSource);
+
+                    TaskDefinitionLog taskDefinition = buildNormalSqlTaskDefinition(taskName, dataSource, sql.substring(0, sql.length() - 1));
+
+                    taskDefinitionList.add(taskDefinition);
+                    taskNameToCode.put(taskDefinition.getName(), taskDefinition.getCode());
+                    taskNameToUpstream.put(taskDefinition.getName(), upstreams);
+                }
+
+                if (totalSizeArchive > THRESHOLD_SIZE) {
+                    throw new IllegalStateException("the uncompressed data size is too large for the application resource capacity");
+                }
+
+                if (totalEntryArchive > THRESHOLD_ENTRIES) {
+                    throw new IllegalStateException("too many entries in this archive, which can lead to inodes exhaustion of the system");
+                }
+            }
+        } catch (Exception e) {
+            logger.error(e.getMessage(), e);
+            putMsg(result, Status.IMPORT_PROCESS_DEFINE_ERROR);
+            return result;
+        }
+
+        // build task relation
+        for (Map.Entry<String, Long> entry : taskNameToCode.entrySet()) {
+            List<String> upstreams = taskNameToUpstream.get(entry.getKey());
+            if (CollectionUtils.isEmpty(upstreams)
+                    || (upstreams.size() == 1 && upstreams.contains("root") && !taskNameToCode.containsKey("root"))) {
+                ProcessTaskRelationLog processTaskRelation = buildNormalTaskRelation(0, entry.getValue());
+                processTaskRelationList.add(processTaskRelation);
+                continue;
+            }
+            for (String upstream : upstreams) {
+                ProcessTaskRelationLog processTaskRelation = buildNormalTaskRelation(taskNameToCode.get(upstream), entry.getValue());
+                processTaskRelationList.add(processTaskRelation);
+            }
+        }
+
+        return createDagDefine(loginUser, processTaskRelationList, processDefinition, taskDefinitionList);
+    }
+
+    private ProcessTaskRelationLog buildNormalTaskRelation(long preTaskCode, long postTaskCode) {
+        ProcessTaskRelationLog processTaskRelation = new ProcessTaskRelationLog();
+        processTaskRelation.setPreTaskCode(preTaskCode);
+        processTaskRelation.setPreTaskVersion(0);
+        processTaskRelation.setPostTaskCode(postTaskCode);
+        processTaskRelation.setPostTaskVersion(0);
+        processTaskRelation.setConditionType(ConditionType.NONE);
+        processTaskRelation.setName("");
+        return processTaskRelation;
+    }
+
+    private DataSource queryDatasourceByNameAndUser(String datasourceName, User loginUser) {
+        if (isAdmin(loginUser)) {
+            List<DataSource> dataSources = dataSourceMapper.queryDataSourceByName(datasourceName);
+            if (CollectionUtils.isNotEmpty(dataSources)) {
+                return dataSources.get(0);
+            }
+        } else {
+            return dataSourceMapper.queryDataSourceByNameAndUserId(loginUser.getId(), datasourceName);
+        }
+        return null;
+    }
+
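+    /**
+     * build a SQL task definition with default attributes: no retries, default worker
+     * group, medium priority and timeout disabled; the SQL type is fixed to NON_QUERY
+     * because the import path does not parse the statement to detect query types
+     */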
"datasource": + datasourceName = value; + line = line.substring(0, commentIndex); + break; + default: + break; + } + } + } + if (!"".equals(line)) { + sql.append(line).append("\n"); + } + } + // import/sql1.sql -> sql1 + if (taskName == null) { + taskName = entry.getName(); + index = taskName.indexOf("/"); + if (index > 0) { + taskName = taskName.substring(index + 1); + } + index = taskName.lastIndexOf("."); + if (index > 0) { + taskName = taskName.substring(0, index); + } + } + DataSource dataSource = dataSourceCache.get(datasourceName); + if (dataSource == null) { + dataSource = queryDatasourceByNameAndUser(datasourceName, loginUser); + } + if (dataSource == null) { + putMsg(result, Status.DATASOURCE_NAME_ILLEGAL); + return result; + } + dataSourceCache.put(datasourceName, dataSource); + + TaskDefinitionLog taskDefinition = buildNormalSqlTaskDefinition(taskName, dataSource, sql.substring(0, sql.length() - 1)); + + taskDefinitionList.add(taskDefinition); + taskNameToCode.put(taskDefinition.getName(), taskDefinition.getCode()); + taskNameToUpstream.put(taskDefinition.getName(), upstreams); + } + + if(totalSizeArchive > THRESHOLD_SIZE) { + throw new IllegalStateException("the uncompressed data size is too much for the application resource capacity"); + } + + if(totalEntryArchive > THRESHOLD_ENTRIES) { + throw new IllegalStateException("too much entries in this archive, can lead to inodes exhaustion of the system"); + } + } + } catch (Exception e) { + logger.error(e.getMessage(), e); + putMsg(result, Status.IMPORT_PROCESS_DEFINE_ERROR); + return result; + } + + // build task relation + for (Map.Entry entry : taskNameToCode.entrySet()) { + List upstreams = taskNameToUpstream.get(entry.getKey()); + if (CollectionUtils.isEmpty(upstreams) + || (upstreams.size() == 1 && upstreams.contains("root") && !taskNameToCode.containsKey("root"))) { + ProcessTaskRelationLog processTaskRelation = buildNormalTaskRelation(0, entry.getValue()); + processTaskRelationList.add(processTaskRelation); + continue; + } + for (String upstream : upstreams) { + ProcessTaskRelationLog processTaskRelation = buildNormalTaskRelation(taskNameToCode.get(upstream), entry.getValue()); + processTaskRelationList.add(processTaskRelation); + } + } + + return createDagDefine(loginUser, processTaskRelationList, processDefinition, taskDefinitionList); + } + + private ProcessTaskRelationLog buildNormalTaskRelation(long preTaskCode, long postTaskCode) { + ProcessTaskRelationLog processTaskRelation = new ProcessTaskRelationLog(); + processTaskRelation.setPreTaskCode(preTaskCode); + processTaskRelation.setPreTaskVersion(0); + processTaskRelation.setPostTaskCode(postTaskCode); + processTaskRelation.setPostTaskVersion(0); + processTaskRelation.setConditionType(ConditionType.NONE); + processTaskRelation.setName(""); + return processTaskRelation; + } + + private DataSource queryDatasourceByNameAndUser(String datasourceName, User loginUser) { + if (isAdmin(loginUser)) { + List dataSources = dataSourceMapper.queryDataSourceByName(datasourceName); + if (CollectionUtils.isNotEmpty(dataSources)) { + return dataSources.get(0); + } + } else { + return dataSourceMapper.queryDataSourceByNameAndUserId(loginUser.getId(), datasourceName); + } + return null; + } + + private TaskDefinitionLog buildNormalSqlTaskDefinition(String taskName, DataSource dataSource, String sql) throws CodeGenerateException { + TaskDefinitionLog taskDefinition = new TaskDefinitionLog(); + taskDefinition.setName(taskName); + taskDefinition.setFlag(Flag.YES); + SqlParameters 
+        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+        ZipOutputStream outputStream = new ZipOutputStream(byteArrayOutputStream);
+        outputStream.putNextEntry(new ZipEntry("import_sql/"));
+
+        outputStream.putNextEntry(new ZipEntry("import_sql/a.sql"));
+        outputStream.write("-- upstream: start_auto_dag\n-- datasource: mysql_1\nselect 1;".getBytes(StandardCharsets.UTF_8));
+
+        outputStream.putNextEntry(new ZipEntry("import_sql/b.sql"));
+        outputStream.write("-- name: start_auto_dag\n-- datasource: mysql_1\nselect 1;".getBytes(StandardCharsets.UTF_8));
+
+        outputStream.close();
+
+        MockMultipartFile mockMultipartFile = new MockMultipartFile("import_sql.zip", byteArrayOutputStream.toByteArray());
+
+        DataSource dataSource = Mockito.mock(DataSource.class);
+        Mockito.when(dataSource.getId()).thenReturn(1);
+        Mockito.when(dataSource.getType()).thenReturn(DbType.MYSQL);
+
+        Mockito.when(dataSourceMapper.queryDataSourceByNameAndUserId(userId, "mysql_1")).thenReturn(dataSource);
+
+        long projectCode = 1001;
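+        // stub the persistence layer: createDagDefine saves two task definitions and
+        // one process definition; saveTaskRelation returning 0 signals success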
+        Mockito.when(processService.saveTaskDefine(Mockito.same(loginUser), Mockito.eq(projectCode), Mockito.notNull())).thenReturn(2);
+        Mockito.when(processService.saveProcessDefine(Mockito.same(loginUser), Mockito.notNull(), Mockito.notNull())).thenReturn(1);
+        Mockito.when(processService.saveTaskRelation(Mockito.same(loginUser), Mockito.eq(projectCode), Mockito.anyLong(),
+                Mockito.eq(1), Mockito.notNull(), Mockito.notNull())).thenReturn(0);
+
+        Map<String, Object> result = processDefinitionService.importSqlProcessDefinition(loginUser, projectCode, mockMultipartFile);
+
+        Assert.assertEquals(Status.SUCCESS, result.get(Constants.STATUS));
+    }
+
     /**
      * get mock processDefinition
      *
diff --git a/dolphinscheduler-dao/src/main/java/org/apache/dolphinscheduler/dao/mapper/DataSourceMapper.java b/dolphinscheduler-dao/src/main/java/org/apache/dolphinscheduler/dao/mapper/DataSourceMapper.java
index 0c3238a5c5..bfb0640386 100644
--- a/dolphinscheduler-dao/src/main/java/org/apache/dolphinscheduler/dao/mapper/DataSourceMapper.java
+++ b/dolphinscheduler-dao/src/main/java/org/apache/dolphinscheduler/dao/mapper/DataSourceMapper.java
@@ -87,5 +87,12 @@ public interface DataSourceMapper extends BaseMapper<DataSource> {
      */
     <T> List<DataSource> listAuthorizedDataSource(@Param("userId") int userId,@Param("dataSourceIds")T[] dataSourceIds);
 
-
+    /**
+     * query datasource by name and user id
+     *
+     * @param userId user id
+     * @param name datasource name
+     * @return the datasource, or null if the name does not exist or the user has no permission on it
+     */
+    DataSource queryDataSourceByNameAndUserId(@Param("userId") int userId, @Param("name") String name);
 }
diff --git a/dolphinscheduler-dao/src/main/resources/org/apache/dolphinscheduler/dao/mapper/DataSourceMapper.xml b/dolphinscheduler-dao/src/main/resources/org/apache/dolphinscheduler/dao/mapper/DataSourceMapper.xml
index acf310ecfc..2241608cbe 100644
--- a/dolphinscheduler-dao/src/main/resources/org/apache/dolphinscheduler/dao/mapper/DataSourceMapper.xml
+++ b/dolphinscheduler-dao/src/main/resources/org/apache/dolphinscheduler/dao/mapper/DataSourceMapper.xml
@@ -98,4 +98,17 @@
         </if>
     </select>
 
+    <select id="queryDataSourceByNameAndUserId" resultType="org.apache.dolphinscheduler.dao.entity.DataSource">
+        select
+        <include refid="baseSql"/>
+        from t_ds_datasource
+        where name = #{name}
+        and user_id = #{userId}
+    </select>
 </mapper>