Jiajie Zhong
3 years ago
committed by
GitHub
4 changed files with 291 additions and 0 deletions
@ -0,0 +1,128 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""Task sql.""" |
||||
|
||||
import re |
||||
from typing import Dict, Optional |
||||
|
||||
from pydolphinscheduler.constants import TaskType |
||||
from pydolphinscheduler.core.task import Task, TaskParams |
||||
from pydolphinscheduler.java_gateway import launch_gateway |
||||
|
||||
|
||||
class SqlType: |
||||
"""SQL type, for now it just contain `SELECT` and `NO_SELECT`.""" |
||||
|
||||
SELECT = 0 |
||||
NOT_SELECT = 1 |
||||
|
||||
|
||||
class SqlTaskParams(TaskParams): |
||||
"""Parameter only for Sql task type.""" |
||||
|
||||
def __init__( |
||||
self, |
||||
type: str, |
||||
datasource: str, |
||||
sql: str, |
||||
sql_type: Optional[int] = SqlType.NOT_SELECT, |
||||
display_rows: Optional[int] = 10, |
||||
pre_statements: Optional[str] = None, |
||||
post_statements: Optional[str] = None, |
||||
*args, |
||||
**kwargs |
||||
): |
||||
super().__init__(*args, **kwargs) |
||||
self.type = type |
||||
self.datasource = datasource |
||||
self.sql = sql |
||||
self.sql_type = sql_type |
||||
self.display_rows = display_rows |
||||
self.pre_statements = pre_statements or [] |
||||
self.post_statements = post_statements or [] |
||||
|
||||
|
||||
class Sql(Task): |
||||
"""Task SQL object, declare behavior for SQL task to dolphinscheduler. |
||||
|
||||
It should run sql job in multiply sql lik engine, such as: |
||||
- ClickHouse |
||||
- DB2 |
||||
- HIVE |
||||
- MySQL |
||||
- Oracle |
||||
- Postgresql |
||||
- Presto |
||||
- SQLServer |
||||
You provider datasource_name contain connection information, it decisions which |
||||
database type and database instance would run this sql. |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
name: str, |
||||
datasource_name: str, |
||||
sql: str, |
||||
pre_sql: Optional[str] = None, |
||||
post_sql: Optional[str] = None, |
||||
display_rows: Optional[int] = 10, |
||||
*args, |
||||
**kwargs |
||||
): |
||||
self._sql = sql |
||||
self._datasource_name = datasource_name |
||||
self._datasource = {} |
||||
task_params = SqlTaskParams( |
||||
type=self.get_datasource_type(), |
||||
datasource=self.get_datasource_id(), |
||||
sql=sql, |
||||
sql_type=self.sql_type, |
||||
display_rows=display_rows, |
||||
pre_statements=pre_sql, |
||||
post_statements=post_sql, |
||||
) |
||||
super().__init__(name, TaskType.SQL, task_params, *args, **kwargs) |
||||
|
||||
def get_datasource_type(self) -> str: |
||||
"""Get datasource type from java gateway, a wrapper for :func:`get_datasource_info`.""" |
||||
return self.get_datasource_info(self._datasource_name).get("type") |
||||
|
||||
def get_datasource_id(self) -> str: |
||||
"""Get datasource id from java gateway, a wrapper for :func:`get_datasource_info`.""" |
||||
return self.get_datasource_info(self._datasource_name).get("id") |
||||
|
||||
def get_datasource_info(self, name) -> Dict: |
||||
"""Get datasource info from java gateway, contains datasource id, type, name.""" |
||||
if self._datasource: |
||||
return self._datasource |
||||
else: |
||||
gateway = launch_gateway() |
||||
self._datasource = gateway.entry_point.getDatasourceInfo(name) |
||||
return self._datasource |
||||
|
||||
@property |
||||
def sql_type(self) -> int: |
||||
"""Judgement sql type, use regexp to check which type of the sql is.""" |
||||
pattern_select_str = ( |
||||
"^(?!(.* |)insert |(.* |)delete |(.* |)drop |(.* |)update |(.* |)alter ).*" |
||||
) |
||||
pattern_select = re.compile(pattern_select_str, re.IGNORECASE) |
||||
if pattern_select.match(self._sql) is None: |
||||
return SqlType.NOT_SELECT |
||||
else: |
||||
return SqlType.SELECT |
@ -0,0 +1,131 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one |
||||
# or more contributor license agreements. See the NOTICE file |
||||
# distributed with this work for additional information |
||||
# regarding copyright ownership. The ASF licenses this file |
||||
# to you under the Apache License, Version 2.0 (the |
||||
# "License"); you may not use this file except in compliance |
||||
# with the License. You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, |
||||
# software distributed under the License is distributed on an |
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
||||
# KIND, either express or implied. See the License for the |
||||
# specific language governing permissions and limitations |
||||
# under the License. |
||||
|
||||
"""Test Task Sql.""" |
||||
|
||||
|
||||
from unittest.mock import patch |
||||
|
||||
import pytest |
||||
|
||||
from pydolphinscheduler.tasks.sql import Sql, SqlType |
||||
|
||||
|
||||
@patch( |
||||
"pydolphinscheduler.core.task.Task.gen_code_and_version", |
||||
return_value=(123, 1), |
||||
) |
||||
@patch( |
||||
"pydolphinscheduler.tasks.sql.Sql.get_datasource_info", |
||||
return_value=({"id": 1, "type": "mock_type"}), |
||||
) |
||||
def test_get_datasource_detail(mock_datasource, mock_code_version): |
||||
"""Test :func:`get_datasource_type` and :func:`get_datasource_id` can return expect value.""" |
||||
name = "test_get_sql_type" |
||||
datasource_name = "test_datasource" |
||||
sql = "select 1" |
||||
task = Sql(name, datasource_name, sql) |
||||
assert 1 == task.get_datasource_id() |
||||
assert "mock_type" == task.get_datasource_type() |
||||
|
||||
|
||||
@pytest.mark.parametrize( |
||||
"sql, sql_type", |
||||
[ |
||||
("select 1", SqlType.SELECT), |
||||
(" select 1", SqlType.SELECT), |
||||
(" select 1 ", SqlType.SELECT), |
||||
(" select 'insert' ", SqlType.SELECT), |
||||
(" select 'insert ' ", SqlType.SELECT), |
||||
("with tmp as (select 1) select * from tmp ", SqlType.SELECT), |
||||
("insert into table_name(col1, col2) value (val1, val2)", SqlType.NOT_SELECT), |
||||
( |
||||
"insert into table_name(select, col2) value ('select', val2)", |
||||
SqlType.NOT_SELECT, |
||||
), |
||||
("update table_name SET col1=val1 where col1=val2", SqlType.NOT_SELECT), |
||||
("update table_name SET col1='select' where col1=val2", SqlType.NOT_SELECT), |
||||
("delete from table_name where id < 10", SqlType.NOT_SELECT), |
||||
("delete from table_name where id < 10", SqlType.NOT_SELECT), |
||||
("alter table table_name add column col1 int", SqlType.NOT_SELECT), |
||||
], |
||||
) |
||||
@patch( |
||||
"pydolphinscheduler.core.task.Task.gen_code_and_version", |
||||
return_value=(123, 1), |
||||
) |
||||
@patch( |
||||
"pydolphinscheduler.tasks.sql.Sql.get_datasource_info", |
||||
return_value=({"id": 1, "type": "mock_type"}), |
||||
) |
||||
def test_get_sql_type(mock_datasource, mock_code_version, sql, sql_type): |
||||
"""Test property sql_type could return correct type.""" |
||||
name = "test_get_sql_type" |
||||
datasource_name = "test_datasource" |
||||
task = Sql(name, datasource_name, sql) |
||||
assert ( |
||||
sql_type == task.sql_type |
||||
), f"Sql {sql} expect sql type is {sql_type} but got {task.sql_type}" |
||||
|
||||
|
||||
@patch( |
||||
"pydolphinscheduler.tasks.sql.Sql.get_datasource_info", |
||||
return_value=({"id": 1, "type": "MYSQL"}), |
||||
) |
||||
def test_sql_to_dict(mock_datasource): |
||||
"""Test task sql function to_dict.""" |
||||
code = 123 |
||||
version = 1 |
||||
name = "test_sql_dict" |
||||
command = "select 1" |
||||
datasource_name = "test_datasource" |
||||
expect = { |
||||
"code": code, |
||||
"name": name, |
||||
"version": 1, |
||||
"description": None, |
||||
"delayTime": 0, |
||||
"taskType": "SQL", |
||||
"taskParams": { |
||||
"type": "MYSQL", |
||||
"datasource": 1, |
||||
"sql": command, |
||||
"sqlType": SqlType.SELECT, |
||||
"displayRows": 10, |
||||
"preStatements": [], |
||||
"postStatements": [], |
||||
"localParams": [], |
||||
"resourceList": [], |
||||
"dependence": {}, |
||||
"conditionResult": {"successNode": [""], "failedNode": [""]}, |
||||
"waitStartTimeout": {}, |
||||
}, |
||||
"flag": "YES", |
||||
"taskPriority": "MEDIUM", |
||||
"workerGroup": "default", |
||||
"failRetryTimes": 0, |
||||
"failRetryInterval": 1, |
||||
"timeoutFlag": "CLOSE", |
||||
"timeoutNotifyStrategy": None, |
||||
"timeout": 0, |
||||
} |
||||
with patch( |
||||
"pydolphinscheduler.core.task.Task.gen_code_and_version", |
||||
return_value=(code, version), |
||||
): |
||||
task = Sql(name, datasource_name, command) |
||||
assert task.to_dict() == expect |
Loading…
Reference in new issue