From 41e8836c91f41035e4be806a11640f44438e4dc0 Mon Sep 17 00:00:00 2001 From: Jiajie Zhong Date: Tue, 23 Nov 2021 16:58:00 +0800 Subject: [PATCH] [python] Add task sql (#6968) * [python] Add task sql * Add java gateway function doc --- .../src/pydolphinscheduler/constants.py | 1 + .../src/pydolphinscheduler/tasks/sql.py | 128 +++++++++++++++++ .../tests/tasks/test_sql.py | 131 ++++++++++++++++++ .../server/PythonGatewayServer.java | 31 +++++ 4 files changed, 291 insertions(+) create mode 100644 dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/sql.py create mode 100644 dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_sql.py diff --git a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py index c27f9bd958..ca0f368e0a 100644 --- a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py +++ b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py @@ -70,6 +70,7 @@ class TaskType(str): SHELL = "SHELL" HTTP = "HTTP" PYTHON = "PYTHON" + SQL = "SQL" class DefaultTaskCodeNum(str): diff --git a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/sql.py b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/sql.py new file mode 100644 index 0000000000..62da964d58 --- /dev/null +++ b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/sql.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Task sql.""" + +import re +from typing import Dict, Optional + +from pydolphinscheduler.constants import TaskType +from pydolphinscheduler.core.task import Task, TaskParams +from pydolphinscheduler.java_gateway import launch_gateway + + +class SqlType: + """SQL type, for now it just contain `SELECT` and `NO_SELECT`.""" + + SELECT = 0 + NOT_SELECT = 1 + + +class SqlTaskParams(TaskParams): + """Parameter only for Sql task type.""" + + def __init__( + self, + type: str, + datasource: str, + sql: str, + sql_type: Optional[int] = SqlType.NOT_SELECT, + display_rows: Optional[int] = 10, + pre_statements: Optional[str] = None, + post_statements: Optional[str] = None, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.type = type + self.datasource = datasource + self.sql = sql + self.sql_type = sql_type + self.display_rows = display_rows + self.pre_statements = pre_statements or [] + self.post_statements = post_statements or [] + + +class Sql(Task): + """Task SQL object, declare behavior for SQL task to dolphinscheduler. + + It should run sql job in multiply sql lik engine, such as: + - ClickHouse + - DB2 + - HIVE + - MySQL + - Oracle + - Postgresql + - Presto + - SQLServer + You provider datasource_name contain connection information, it decisions which + database type and database instance would run this sql. + """ + + def __init__( + self, + name: str, + datasource_name: str, + sql: str, + pre_sql: Optional[str] = None, + post_sql: Optional[str] = None, + display_rows: Optional[int] = 10, + *args, + **kwargs + ): + self._sql = sql + self._datasource_name = datasource_name + self._datasource = {} + task_params = SqlTaskParams( + type=self.get_datasource_type(), + datasource=self.get_datasource_id(), + sql=sql, + sql_type=self.sql_type, + display_rows=display_rows, + pre_statements=pre_sql, + post_statements=post_sql, + ) + super().__init__(name, TaskType.SQL, task_params, *args, **kwargs) + + def get_datasource_type(self) -> str: + """Get datasource type from java gateway, a wrapper for :func:`get_datasource_info`.""" + return self.get_datasource_info(self._datasource_name).get("type") + + def get_datasource_id(self) -> str: + """Get datasource id from java gateway, a wrapper for :func:`get_datasource_info`.""" + return self.get_datasource_info(self._datasource_name).get("id") + + def get_datasource_info(self, name) -> Dict: + """Get datasource info from java gateway, contains datasource id, type, name.""" + if self._datasource: + return self._datasource + else: + gateway = launch_gateway() + self._datasource = gateway.entry_point.getDatasourceInfo(name) + return self._datasource + + @property + def sql_type(self) -> int: + """Judgement sql type, use regexp to check which type of the sql is.""" + pattern_select_str = ( + "^(?!(.* |)insert |(.* |)delete |(.* |)drop |(.* |)update |(.* |)alter ).*" + ) + pattern_select = re.compile(pattern_select_str, re.IGNORECASE) + if pattern_select.match(self._sql) is None: + return SqlType.NOT_SELECT + else: + return SqlType.SELECT diff --git a/dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_sql.py b/dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_sql.py new file mode 100644 index 0000000000..499b46b4bb --- /dev/null +++ b/dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_sql.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test Task Sql.""" + + +from unittest.mock import patch + +import pytest + +from pydolphinscheduler.tasks.sql import Sql, SqlType + + +@patch( + "pydolphinscheduler.core.task.Task.gen_code_and_version", + return_value=(123, 1), +) +@patch( + "pydolphinscheduler.tasks.sql.Sql.get_datasource_info", + return_value=({"id": 1, "type": "mock_type"}), +) +def test_get_datasource_detail(mock_datasource, mock_code_version): + """Test :func:`get_datasource_type` and :func:`get_datasource_id` can return expect value.""" + name = "test_get_sql_type" + datasource_name = "test_datasource" + sql = "select 1" + task = Sql(name, datasource_name, sql) + assert 1 == task.get_datasource_id() + assert "mock_type" == task.get_datasource_type() + + +@pytest.mark.parametrize( + "sql, sql_type", + [ + ("select 1", SqlType.SELECT), + (" select 1", SqlType.SELECT), + (" select 1 ", SqlType.SELECT), + (" select 'insert' ", SqlType.SELECT), + (" select 'insert ' ", SqlType.SELECT), + ("with tmp as (select 1) select * from tmp ", SqlType.SELECT), + ("insert into table_name(col1, col2) value (val1, val2)", SqlType.NOT_SELECT), + ( + "insert into table_name(select, col2) value ('select', val2)", + SqlType.NOT_SELECT, + ), + ("update table_name SET col1=val1 where col1=val2", SqlType.NOT_SELECT), + ("update table_name SET col1='select' where col1=val2", SqlType.NOT_SELECT), + ("delete from table_name where id < 10", SqlType.NOT_SELECT), + ("delete from table_name where id < 10", SqlType.NOT_SELECT), + ("alter table table_name add column col1 int", SqlType.NOT_SELECT), + ], +) +@patch( + "pydolphinscheduler.core.task.Task.gen_code_and_version", + return_value=(123, 1), +) +@patch( + "pydolphinscheduler.tasks.sql.Sql.get_datasource_info", + return_value=({"id": 1, "type": "mock_type"}), +) +def test_get_sql_type(mock_datasource, mock_code_version, sql, sql_type): + """Test property sql_type could return correct type.""" + name = "test_get_sql_type" + datasource_name = "test_datasource" + task = Sql(name, datasource_name, sql) + assert ( + sql_type == task.sql_type + ), f"Sql {sql} expect sql type is {sql_type} but got {task.sql_type}" + + +@patch( + "pydolphinscheduler.tasks.sql.Sql.get_datasource_info", + return_value=({"id": 1, "type": "MYSQL"}), +) +def test_sql_to_dict(mock_datasource): + """Test task sql function to_dict.""" + code = 123 + version = 1 + name = "test_sql_dict" + command = "select 1" + datasource_name = "test_datasource" + expect = { + "code": code, + "name": name, + "version": 1, + "description": None, + "delayTime": 0, + "taskType": "SQL", + "taskParams": { + "type": "MYSQL", + "datasource": 1, + "sql": command, + "sqlType": SqlType.SELECT, + "displayRows": 10, + "preStatements": [], + "postStatements": [], + "localParams": [], + "resourceList": [], + "dependence": {}, + "conditionResult": {"successNode": [""], "failedNode": [""]}, + "waitStartTimeout": {}, + }, + "flag": "YES", + "taskPriority": "MEDIUM", + "workerGroup": "default", + "failRetryTimes": 0, + "failRetryInterval": 1, + "timeoutFlag": "CLOSE", + "timeoutNotifyStrategy": None, + "timeout": 0, + } + with patch( + "pydolphinscheduler.core.task.Task.gen_code_and_version", + return_value=(code, version), + ): + task = Sql(name, datasource_name, command) + assert task.to_dict() == expect diff --git a/dolphinscheduler-python/src/main/java/org/apache/dolphinscheduler/server/PythonGatewayServer.java b/dolphinscheduler-python/src/main/java/org/apache/dolphinscheduler/server/PythonGatewayServer.java index 9ae966fe4f..77c23449de 100644 --- a/dolphinscheduler-python/src/main/java/org/apache/dolphinscheduler/server/PythonGatewayServer.java +++ b/dolphinscheduler-python/src/main/java/org/apache/dolphinscheduler/server/PythonGatewayServer.java @@ -37,6 +37,7 @@ import org.apache.dolphinscheduler.common.enums.TaskDependType; import org.apache.dolphinscheduler.common.enums.UserType; import org.apache.dolphinscheduler.common.enums.WarningType; import org.apache.dolphinscheduler.common.utils.CodeGenerateUtils; +import org.apache.dolphinscheduler.dao.entity.DataSource; import org.apache.dolphinscheduler.dao.entity.ProcessDefinition; import org.apache.dolphinscheduler.dao.entity.Project; import org.apache.dolphinscheduler.dao.entity.Queue; @@ -44,6 +45,7 @@ import org.apache.dolphinscheduler.dao.entity.Schedule; import org.apache.dolphinscheduler.dao.entity.TaskDefinition; import org.apache.dolphinscheduler.dao.entity.Tenant; import org.apache.dolphinscheduler.dao.entity.User; +import org.apache.dolphinscheduler.dao.mapper.DataSourceMapper; import org.apache.dolphinscheduler.dao.mapper.ProcessDefinitionMapper; import org.apache.dolphinscheduler.dao.mapper.ProjectMapper; import org.apache.dolphinscheduler.dao.mapper.ScheduleMapper; @@ -124,6 +126,9 @@ public class PythonGatewayServer extends SpringBootServletInitializer { @Autowired private ScheduleMapper scheduleMapper; + @Autowired + private DataSourceMapper dataSourceMapper; + // TODO replace this user to build in admin user if we make sure build in one could not be change private final User dummyAdminUser = new User() { { @@ -360,6 +365,32 @@ public class PythonGatewayServer extends SpringBootServletInitializer { } } + /** + * Get datasource by given datasource name. It return map contain datasource id, type, name. + * Useful in Python API create sql task which need datasource information. + * + * @param datasourceName user who create or update schedule + */ + public Map getDatasourceInfo(String datasourceName) { + Map result = new HashMap<>(); + List dataSourceList = dataSourceMapper.queryDataSourceByName(datasourceName); + if (dataSourceList.size() > 1) { + String msg = String.format("Get more than one datasource by name %s", datasourceName); + logger.error(msg); + throw new IllegalArgumentException(msg); + } else if (dataSourceList.size() == 0) { + String msg = String.format("Can not find any datasource by name %s", datasourceName); + logger.error(msg); + throw new IllegalArgumentException(msg); + } else { + DataSource dataSource = dataSourceList.get(0); + result.put("id", dataSource.getId()); + result.put("type", dataSource.getType().name()); + result.put("name", dataSource.getName()); + } + return result; + } + @PostConstruct public void run() { GatewayServer server = new GatewayServer(this);