Jiajie Zhong
3 years ago
committed by
GitHub
4 changed files with 291 additions and 0 deletions
@ -0,0 +1,128 @@ |
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one |
||||||
|
# or more contributor license agreements. See the NOTICE file |
||||||
|
# distributed with this work for additional information |
||||||
|
# regarding copyright ownership. The ASF licenses this file |
||||||
|
# to you under the Apache License, Version 2.0 (the |
||||||
|
# "License"); you may not use this file except in compliance |
||||||
|
# with the License. You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, |
||||||
|
# software distributed under the License is distributed on an |
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||||
|
# KIND, either express or implied. See the License for the |
||||||
|
# specific language governing permissions and limitations |
||||||
|
# under the License. |
||||||
|
|
||||||
|
"""Task sql.""" |
||||||
|
|
||||||
|
import re |
||||||
|
from typing import Dict, Optional |
||||||
|
|
||||||
|
from pydolphinscheduler.constants import TaskType |
||||||
|
from pydolphinscheduler.core.task import Task, TaskParams |
||||||
|
from pydolphinscheduler.java_gateway import launch_gateway |
||||||
|
|
||||||
|
|
||||||
|
class SqlType: |
||||||
|
"""SQL type, for now it just contain `SELECT` and `NO_SELECT`.""" |
||||||
|
|
||||||
|
SELECT = 0 |
||||||
|
NOT_SELECT = 1 |
||||||
|
|
||||||
|
|
||||||
|
class SqlTaskParams(TaskParams): |
||||||
|
"""Parameter only for Sql task type.""" |
||||||
|
|
||||||
|
def __init__( |
||||||
|
self, |
||||||
|
type: str, |
||||||
|
datasource: str, |
||||||
|
sql: str, |
||||||
|
sql_type: Optional[int] = SqlType.NOT_SELECT, |
||||||
|
display_rows: Optional[int] = 10, |
||||||
|
pre_statements: Optional[str] = None, |
||||||
|
post_statements: Optional[str] = None, |
||||||
|
*args, |
||||||
|
**kwargs |
||||||
|
): |
||||||
|
super().__init__(*args, **kwargs) |
||||||
|
self.type = type |
||||||
|
self.datasource = datasource |
||||||
|
self.sql = sql |
||||||
|
self.sql_type = sql_type |
||||||
|
self.display_rows = display_rows |
||||||
|
self.pre_statements = pre_statements or [] |
||||||
|
self.post_statements = post_statements or [] |
||||||
|
|
||||||
|
|
||||||
|
class Sql(Task): |
||||||
|
"""Task SQL object, declare behavior for SQL task to dolphinscheduler. |
||||||
|
|
||||||
|
It should run sql job in multiply sql lik engine, such as: |
||||||
|
- ClickHouse |
||||||
|
- DB2 |
||||||
|
- HIVE |
||||||
|
- MySQL |
||||||
|
- Oracle |
||||||
|
- Postgresql |
||||||
|
- Presto |
||||||
|
- SQLServer |
||||||
|
You provider datasource_name contain connection information, it decisions which |
||||||
|
database type and database instance would run this sql. |
||||||
|
""" |
||||||
|
|
||||||
|
def __init__( |
||||||
|
self, |
||||||
|
name: str, |
||||||
|
datasource_name: str, |
||||||
|
sql: str, |
||||||
|
pre_sql: Optional[str] = None, |
||||||
|
post_sql: Optional[str] = None, |
||||||
|
display_rows: Optional[int] = 10, |
||||||
|
*args, |
||||||
|
**kwargs |
||||||
|
): |
||||||
|
self._sql = sql |
||||||
|
self._datasource_name = datasource_name |
||||||
|
self._datasource = {} |
||||||
|
task_params = SqlTaskParams( |
||||||
|
type=self.get_datasource_type(), |
||||||
|
datasource=self.get_datasource_id(), |
||||||
|
sql=sql, |
||||||
|
sql_type=self.sql_type, |
||||||
|
display_rows=display_rows, |
||||||
|
pre_statements=pre_sql, |
||||||
|
post_statements=post_sql, |
||||||
|
) |
||||||
|
super().__init__(name, TaskType.SQL, task_params, *args, **kwargs) |
||||||
|
|
||||||
|
def get_datasource_type(self) -> str: |
||||||
|
"""Get datasource type from java gateway, a wrapper for :func:`get_datasource_info`.""" |
||||||
|
return self.get_datasource_info(self._datasource_name).get("type") |
||||||
|
|
||||||
|
def get_datasource_id(self) -> str: |
||||||
|
"""Get datasource id from java gateway, a wrapper for :func:`get_datasource_info`.""" |
||||||
|
return self.get_datasource_info(self._datasource_name).get("id") |
||||||
|
|
||||||
|
def get_datasource_info(self, name) -> Dict: |
||||||
|
"""Get datasource info from java gateway, contains datasource id, type, name.""" |
||||||
|
if self._datasource: |
||||||
|
return self._datasource |
||||||
|
else: |
||||||
|
gateway = launch_gateway() |
||||||
|
self._datasource = gateway.entry_point.getDatasourceInfo(name) |
||||||
|
return self._datasource |
||||||
|
|
||||||
|
@property |
||||||
|
def sql_type(self) -> int: |
||||||
|
"""Judgement sql type, use regexp to check which type of the sql is.""" |
||||||
|
pattern_select_str = ( |
||||||
|
"^(?!(.* |)insert |(.* |)delete |(.* |)drop |(.* |)update |(.* |)alter ).*" |
||||||
|
) |
||||||
|
pattern_select = re.compile(pattern_select_str, re.IGNORECASE) |
||||||
|
if pattern_select.match(self._sql) is None: |
||||||
|
return SqlType.NOT_SELECT |
||||||
|
else: |
||||||
|
return SqlType.SELECT |
@ -0,0 +1,131 @@ |
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one |
||||||
|
# or more contributor license agreements. See the NOTICE file |
||||||
|
# distributed with this work for additional information |
||||||
|
# regarding copyright ownership. The ASF licenses this file |
||||||
|
# to you under the Apache License, Version 2.0 (the |
||||||
|
# "License"); you may not use this file except in compliance |
||||||
|
# with the License. You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, |
||||||
|
# software distributed under the License is distributed on an |
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||||
|
# KIND, either express or implied. See the License for the |
||||||
|
# specific language governing permissions and limitations |
||||||
|
# under the License. |
||||||
|
|
||||||
|
"""Test Task Sql.""" |
||||||
|
|
||||||
|
|
||||||
|
from unittest.mock import patch |
||||||
|
|
||||||
|
import pytest |
||||||
|
|
||||||
|
from pydolphinscheduler.tasks.sql import Sql, SqlType |
||||||
|
|
||||||
|
|
||||||
|
@patch( |
||||||
|
"pydolphinscheduler.core.task.Task.gen_code_and_version", |
||||||
|
return_value=(123, 1), |
||||||
|
) |
||||||
|
@patch( |
||||||
|
"pydolphinscheduler.tasks.sql.Sql.get_datasource_info", |
||||||
|
return_value=({"id": 1, "type": "mock_type"}), |
||||||
|
) |
||||||
|
def test_get_datasource_detail(mock_datasource, mock_code_version): |
||||||
|
"""Test :func:`get_datasource_type` and :func:`get_datasource_id` can return expect value.""" |
||||||
|
name = "test_get_sql_type" |
||||||
|
datasource_name = "test_datasource" |
||||||
|
sql = "select 1" |
||||||
|
task = Sql(name, datasource_name, sql) |
||||||
|
assert 1 == task.get_datasource_id() |
||||||
|
assert "mock_type" == task.get_datasource_type() |
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize( |
||||||
|
"sql, sql_type", |
||||||
|
[ |
||||||
|
("select 1", SqlType.SELECT), |
||||||
|
(" select 1", SqlType.SELECT), |
||||||
|
(" select 1 ", SqlType.SELECT), |
||||||
|
(" select 'insert' ", SqlType.SELECT), |
||||||
|
(" select 'insert ' ", SqlType.SELECT), |
||||||
|
("with tmp as (select 1) select * from tmp ", SqlType.SELECT), |
||||||
|
("insert into table_name(col1, col2) value (val1, val2)", SqlType.NOT_SELECT), |
||||||
|
( |
||||||
|
"insert into table_name(select, col2) value ('select', val2)", |
||||||
|
SqlType.NOT_SELECT, |
||||||
|
), |
||||||
|
("update table_name SET col1=val1 where col1=val2", SqlType.NOT_SELECT), |
||||||
|
("update table_name SET col1='select' where col1=val2", SqlType.NOT_SELECT), |
||||||
|
("delete from table_name where id < 10", SqlType.NOT_SELECT), |
||||||
|
("delete from table_name where id < 10", SqlType.NOT_SELECT), |
||||||
|
("alter table table_name add column col1 int", SqlType.NOT_SELECT), |
||||||
|
], |
||||||
|
) |
||||||
|
@patch( |
||||||
|
"pydolphinscheduler.core.task.Task.gen_code_and_version", |
||||||
|
return_value=(123, 1), |
||||||
|
) |
||||||
|
@patch( |
||||||
|
"pydolphinscheduler.tasks.sql.Sql.get_datasource_info", |
||||||
|
return_value=({"id": 1, "type": "mock_type"}), |
||||||
|
) |
||||||
|
def test_get_sql_type(mock_datasource, mock_code_version, sql, sql_type): |
||||||
|
"""Test property sql_type could return correct type.""" |
||||||
|
name = "test_get_sql_type" |
||||||
|
datasource_name = "test_datasource" |
||||||
|
task = Sql(name, datasource_name, sql) |
||||||
|
assert ( |
||||||
|
sql_type == task.sql_type |
||||||
|
), f"Sql {sql} expect sql type is {sql_type} but got {task.sql_type}" |
||||||
|
|
||||||
|
|
||||||
|
@patch( |
||||||
|
"pydolphinscheduler.tasks.sql.Sql.get_datasource_info", |
||||||
|
return_value=({"id": 1, "type": "MYSQL"}), |
||||||
|
) |
||||||
|
def test_sql_to_dict(mock_datasource): |
||||||
|
"""Test task sql function to_dict.""" |
||||||
|
code = 123 |
||||||
|
version = 1 |
||||||
|
name = "test_sql_dict" |
||||||
|
command = "select 1" |
||||||
|
datasource_name = "test_datasource" |
||||||
|
expect = { |
||||||
|
"code": code, |
||||||
|
"name": name, |
||||||
|
"version": 1, |
||||||
|
"description": None, |
||||||
|
"delayTime": 0, |
||||||
|
"taskType": "SQL", |
||||||
|
"taskParams": { |
||||||
|
"type": "MYSQL", |
||||||
|
"datasource": 1, |
||||||
|
"sql": command, |
||||||
|
"sqlType": SqlType.SELECT, |
||||||
|
"displayRows": 10, |
||||||
|
"preStatements": [], |
||||||
|
"postStatements": [], |
||||||
|
"localParams": [], |
||||||
|
"resourceList": [], |
||||||
|
"dependence": {}, |
||||||
|
"conditionResult": {"successNode": [""], "failedNode": [""]}, |
||||||
|
"waitStartTimeout": {}, |
||||||
|
}, |
||||||
|
"flag": "YES", |
||||||
|
"taskPriority": "MEDIUM", |
||||||
|
"workerGroup": "default", |
||||||
|
"failRetryTimes": 0, |
||||||
|
"failRetryInterval": 1, |
||||||
|
"timeoutFlag": "CLOSE", |
||||||
|
"timeoutNotifyStrategy": None, |
||||||
|
"timeout": 0, |
||||||
|
} |
||||||
|
with patch( |
||||||
|
"pydolphinscheduler.core.task.Task.gen_code_and_version", |
||||||
|
return_value=(code, version), |
||||||
|
): |
||||||
|
task = Sql(name, datasource_name, command) |
||||||
|
assert task.to_dict() == expect |
Loading…
Reference in new issue