From ad683c3c428876916d9550111fc9bbb77afd9e9f Mon Sep 17 00:00:00 2001 From: JieguangZhou Date: Sun, 18 Sep 2022 16:28:18 +0800 Subject: [PATCH] [feat][python] Support MLflow task in python api (#11962) --- .../docs/source/tasks/index.rst | 1 + .../docs/source/tasks/mlflow.rst | 42 +++ .../examples/yaml_define/mlflow.yaml | 78 ++++++ .../src/pydolphinscheduler/constants.py | 1 + .../examples/task_mlflow_example.py | 104 +++++++ .../src/pydolphinscheduler/tasks/__init__.py | 10 + .../src/pydolphinscheduler/tasks/mlflow.py | 265 ++++++++++++++++++ .../tests/tasks/test_mlflow.py | 211 ++++++++++++++ 8 files changed, 712 insertions(+) create mode 100644 dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/mlflow.rst create mode 100644 dolphinscheduler-python/pydolphinscheduler/examples/yaml_define/mlflow.yaml create mode 100644 dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/examples/task_mlflow_example.py create mode 100644 dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/mlflow.py create mode 100644 dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_mlflow.py diff --git a/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst b/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst index c0cea593d8..3f83f92675 100644 --- a/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst +++ b/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst @@ -42,6 +42,7 @@ In this section sub_process sagemaker + mlflow openmldb pytorch dvc diff --git a/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/mlflow.rst b/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/mlflow.rst new file mode 100644 index 0000000000..b83903c26f --- /dev/null +++ b/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/mlflow.rst @@ -0,0 +1,42 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +MLflow +========= + + +A MLflow task type's example and dive into information of **PyDolphinScheduler**. + +Example +------- + +.. literalinclude:: ../../../src/pydolphinscheduler/examples/task_mlflow_example.py + :start-after: [start workflow_declare] + :end-before: [end workflow_declare] + +Dive Into +--------- + +.. automodule:: pydolphinscheduler.tasks.mlflow + + +YAML file example +----------------- + +.. literalinclude:: ../../../examples/yaml_define/mlflow.yaml + :start-after: # under the License. + :language: yaml diff --git a/dolphinscheduler-python/pydolphinscheduler/examples/yaml_define/mlflow.yaml b/dolphinscheduler-python/pydolphinscheduler/examples/yaml_define/mlflow.yaml new file mode 100644 index 0000000000..232442a186 --- /dev/null +++ b/dolphinscheduler-python/pydolphinscheduler/examples/yaml_define/mlflow.yaml @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# Define variable `mlflow_tracking_uri` +mlflow_tracking_uri: &mlflow_tracking_uri "http://127.0.0.1:5000" + +# Define the workflow +workflow: + name: "MLflow" + +# Define the tasks under the workflow +tasks: + - name: train_xgboost_native + task_type: MLFlowProjectsCustom + repository: https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native + mlflow_tracking_uri: *mlflow_tracking_uri + parameters: -P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9 + experiment_name: xgboost + + + - name: deploy_mlflow + deps: [train_xgboost_native] + task_type: MLflowModels + model_uri: models:/xgboost_native/Production + mlflow_tracking_uri: *mlflow_tracking_uri + deploy_mode: MLFLOW + port: 7001 + + - name: train_automl + task_type: MLFlowProjectsAutoML + mlflow_tracking_uri: *mlflow_tracking_uri + parameters: time_budget=30;estimator_list=['lgbm'] + experiment_name: automl_iris + model_name: iris_A + automl_tool: flaml + data_path: /data/examples/iris + + - name: deploy_docker + task_type: MLflowModels + deps: [train_automl] + model_uri: models:/iris_A/Production + mlflow_tracking_uri: *mlflow_tracking_uri + deploy_mode: DOCKER + port: 7002 + + - name: train_basic_algorithm + task_type: MLFlowProjectsBasicAlgorithm + mlflow_tracking_uri: *mlflow_tracking_uri + parameters: n_estimators=200;learning_rate=0.2 + experiment_name: basic_algorithm_iris + model_name: iris_B + algorithm: lightgbm + data_path: /data/examples/iris + search_params: max_depth=[5, 10];n_estimators=[100, 200] + + + - name: deploy_docker_compose + task_type: MLflowModels + deps: [train_basic_algorithm] + model_uri: models:/iris_B/Production + mlflow_tracking_uri: *mlflow_tracking_uri + deploy_mode: DOCKER COMPOSE + port: 7003 diff --git a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py index b4a89bb585..fd640c512f 100644 --- a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py +++ b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py @@ -58,6 +58,7 @@ class TaskType(str): SPARK = "SPARK" MR = "MR" SAGEMAKER = "SAGEMAKER" + MLFLOW = "MLFLOW" OPENMLDB = "OPENMLDB" PYTORCH = "PYTORCH" DVC = "DVC" diff --git a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/examples/task_mlflow_example.py b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/examples/task_mlflow_example.py new file mode 100644 index 0000000000..328688e646 --- /dev/null +++ b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/examples/task_mlflow_example.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# [start workflow_declare] +"""A example workflow for task mlflow.""" + +from pydolphinscheduler.core.process_definition import ProcessDefinition +from pydolphinscheduler.tasks.mlflow import ( + MLflowDeployType, + MLflowModels, + MLFlowProjectsAutoML, + MLFlowProjectsBasicAlgorithm, + MLFlowProjectsCustom, +) + +mlflow_tracking_uri = "http://127.0.0.1:5000" + +with ProcessDefinition( + name="task_mlflow_example", + tenant="tenant_exists", +) as pd: + + # run custom mlflow project to train model + train_custom = MLFlowProjectsCustom( + name="train_xgboost_native", + repository="https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native", + mlflow_tracking_uri=mlflow_tracking_uri, + parameters="-P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9", + experiment_name="xgboost", + ) + + # Using MLFLOW to deploy model from custom mlflow project + deploy_mlflow = MLflowModels( + name="deploy_mlflow", + model_uri="models:/xgboost_native/Production", + mlflow_tracking_uri=mlflow_tracking_uri, + deploy_mode=MLflowDeployType.MLFLOW, + port=7001, + ) + + train_custom >> deploy_mlflow + + # run automl to train model + train_automl = MLFlowProjectsAutoML( + name="train_automl", + mlflow_tracking_uri=mlflow_tracking_uri, + parameters="time_budget=30;estimator_list=['lgbm']", + experiment_name="automl_iris", + model_name="iris_A", + automl_tool="flaml", + data_path="/data/examples/iris", + ) + + # Using DOCKER to deploy model from train_automl + deploy_docker = MLflowModels( + name="deploy_docker", + model_uri="models:/iris_A/Production", + mlflow_tracking_uri=mlflow_tracking_uri, + deploy_mode=MLflowDeployType.DOCKER, + port=7002, + ) + + train_automl >> deploy_docker + + # run lightgbm to train model + train_basic_algorithm = MLFlowProjectsBasicAlgorithm( + name="train_basic_algorithm", + mlflow_tracking_uri=mlflow_tracking_uri, + parameters="n_estimators=200;learning_rate=0.2", + experiment_name="basic_algorithm_iris", + model_name="iris_B", + algorithm="lightgbm", + data_path="/data/examples/iris", + search_params="max_depth=[5, 10];n_estimators=[100, 200]", + ) + + # Using DOCKER COMPOSE to deploy model from train_basic_algorithm + deploy_docker_compose = MLflowModels( + name="deploy_docker_compose", + model_uri="models:/iris_B/Production", + mlflow_tracking_uri=mlflow_tracking_uri, + deploy_mode=MLflowDeployType.DOCKER_COMPOSE, + port=7003, + ) + + train_basic_algorithm >> deploy_docker_compose + + pd.submit() + +# [end workflow_declare] diff --git a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/__init__.py b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/__init__.py index cefc0024ca..972b1b76dd 100644 --- a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/__init__.py +++ b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/__init__.py @@ -24,6 +24,12 @@ from pydolphinscheduler.tasks.dvc import DVCDownload, DVCInit, DVCUpload from pydolphinscheduler.tasks.flink import Flink from pydolphinscheduler.tasks.http import Http from pydolphinscheduler.tasks.map_reduce import MR +from pydolphinscheduler.tasks.mlflow import ( + MLflowModels, + MLFlowProjectsAutoML, + MLFlowProjectsBasicAlgorithm, + MLFlowProjectsCustom, +) from pydolphinscheduler.tasks.openmldb import OpenMLDB from pydolphinscheduler.tasks.procedure import Procedure from pydolphinscheduler.tasks.python import Python @@ -47,6 +53,10 @@ __all__ = [ "Http", "MR", "OpenMLDB", + "MLFlowProjectsBasicAlgorithm", + "MLFlowProjectsCustom", + "MLFlowProjectsAutoML", + "MLflowModels", "Procedure", "Python", "Pytorch", diff --git a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/mlflow.py b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/mlflow.py new file mode 100644 index 0000000000..44e6634822 --- /dev/null +++ b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/mlflow.py @@ -0,0 +1,265 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Task mlflow.""" +from copy import deepcopy +from typing import Dict, Optional + +from pydolphinscheduler.constants import TaskType +from pydolphinscheduler.core.task import Task + + +class MLflowTaskType(str): + """MLflow task type.""" + + MLFLOW_PROJECTS = "MLflow Projects" + MLFLOW_MODELS = "MLflow Models" + + +class MLflowJobType(str): + """MLflow job type.""" + + AUTOML = "AutoML" + BASIC_ALGORITHM = "BasicAlgorithm" + CUSTOM_PROJECT = "CustomProject" + + +class MLflowDeployType(str): + """MLflow deploy type.""" + + MLFLOW = "MLFLOW" + DOCKER = "DOCKER" + DOCKER_COMPOSE = "DOCKER COMPOSE" + + +DEFAULT_MLFLOW_TRACKING_URI = "http://127.0.0.1:5000" +DEFAULT_VERSION = "master" + + +class BaseMLflow(Task): + """Base MLflow task.""" + + mlflow_task_type = None + + _task_custom_attr = { + "mlflow_tracking_uri", + "mlflow_task_type", + } + + _child_task_mlflow_attr = set() + + def __init__(self, name: str, mlflow_tracking_uri: str, *args, **kwargs): + super().__init__(name, TaskType.MLFLOW, *args, **kwargs) + self.mlflow_tracking_uri = mlflow_tracking_uri + + @property + def task_params(self) -> Dict: + """Return task params.""" + self._task_custom_attr = deepcopy(self._task_custom_attr) + self._task_custom_attr.update(self._child_task_mlflow_attr) + return super().task_params + + +class MLflowModels(BaseMLflow): + """Task MLflow models object, declare behavior for MLflow models task to dolphinscheduler. + + Deploy machine learning models in diverse serving environments. + + :param name: task name + :param model_uri: Model-URI of MLflow , support models://suffix format and runs:/ format. + See https://mlflow.org/docs/latest/tracking.html#artifact-stores + :param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000 + :param deploy_mode: MLflow deploy mode, support MLFLOW, DOCKER, DOCKER COMPOSE, default is DOCKER + :param port: deploy port, default is 7000 + :param cpu_limit: cpu limit, default is 1.0 + :param memory_limit: memory limit, default is 500M + """ + + mlflow_task_type = MLflowTaskType.MLFLOW_MODELS + + _child_task_mlflow_attr = { + "deploy_type", + "deploy_model_key", + "deploy_port", + "cpu_limit", + "memory_limit", + } + + def __init__( + self, + name: str, + model_uri: str, + mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI, + deploy_mode: Optional[str] = MLflowDeployType.DOCKER, + port: Optional[int] = 7000, + cpu_limit: Optional[float] = 1.0, + memory_limit: Optional[str] = "500M", + *args, + **kwargs + ): + """Init mlflow models task.""" + super().__init__(name, mlflow_tracking_uri, *args, **kwargs) + self.deploy_type = deploy_mode.upper() + self.deploy_model_key = model_uri + self.deploy_port = port + self.cpu_limit = cpu_limit + self.memory_limit = memory_limit + + +class MLFlowProjectsCustom(BaseMLflow): + """Task MLflow projects object, declare behavior for MLflow Custom projects task to dolphinscheduler. + + :param name: task name + :param repository: Repository url of MLflow Project, Support git address and directory on worker. + If it's in a subdirectory, We add # to support this (same as mlflow run) , + for example https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native. + :param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000 + :param experiment_name: MLflow experiment name, default is empty + :param parameters: MLflow project parameters, default is empty + :param version: MLflow project version, default is master + + """ + + mlflow_task_type = MLflowTaskType.MLFLOW_PROJECTS + mlflow_job_type = MLflowJobType.CUSTOM_PROJECT + + _child_task_mlflow_attr = { + "mlflow_job_type", + "experiment_name", + "params", + "mlflow_project_repository", + "mlflow_project_version", + } + + def __init__( + self, + name: str, + repository: str, + mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI, + experiment_name: Optional[str] = "", + parameters: Optional[str] = "", + version: Optional[str] = "master", + *args, + **kwargs + ): + """Init mlflow projects task.""" + super().__init__(name, mlflow_tracking_uri, *args, **kwargs) + self.mlflow_project_repository = repository + self.experiment_name = experiment_name + self.params = parameters + self.mlflow_project_version = version + + +class MLFlowProjectsAutoML(BaseMLflow): + """Task MLflow projects object, declare behavior for AutoML task to dolphinscheduler. + + :param name: task name + :param data_path: data path of MLflow Project, Support git address and directory on worker. + :param automl_tool: The AutoML tool used, currently supports autosklearn and flaml. + :param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000 + :param experiment_name: MLflow experiment name, default is empty + :param model_name: MLflow model name, default is empty + :param parameters: MLflow project parameters, default is empty + + """ + + mlflow_task_type = MLflowTaskType.MLFLOW_PROJECTS + mlflow_job_type = MLflowJobType.AUTOML + + _child_task_mlflow_attr = { + "mlflow_job_type", + "experiment_name", + "model_name", + "register_model", + "data_path", + "params", + "automl_tool", + } + + def __init__( + self, + name: str, + data_path: str, + automl_tool: Optional[str] = "flaml", + mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI, + experiment_name: Optional[str] = "", + model_name: Optional[str] = "", + parameters: Optional[str] = "", + *args, + **kwargs + ): + """Init mlflow projects task.""" + super().__init__(name, mlflow_tracking_uri, *args, **kwargs) + self.data_path = data_path + self.experiment_name = experiment_name + self.model_name = model_name + self.params = parameters + self.automl_tool = automl_tool.lower() + self.register_model = bool(model_name) + + +class MLFlowProjectsBasicAlgorithm(BaseMLflow): + """Task MLflow projects object, declare behavior for BasicAlgorithm task to dolphinscheduler. + + :param name: task name + :param data_path: data path of MLflow Project, Support git address and directory on worker. + :param algorithm: The selected algorithm currently supports LR, SVM, LightGBM and XGboost + based on scikit-learn form. + :param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000 + :param experiment_name: MLflow experiment name, default is empty + :param model_name: MLflow model name, default is empty + :param parameters: MLflow project parameters, default is empty + :param search_params: Whether to search the parameters, default is empty + + """ + + mlflow_job_type = MLflowJobType.BASIC_ALGORITHM + mlflow_task_type = MLflowTaskType.MLFLOW_PROJECTS + + _child_task_mlflow_attr = { + "mlflow_job_type", + "experiment_name", + "model_name", + "register_model", + "data_path", + "params", + "algorithm", + "search_params", + } + + def __init__( + self, + name: str, + data_path: str, + algorithm: Optional[str] = "lightgbm", + mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI, + experiment_name: Optional[str] = "", + model_name: Optional[str] = "", + parameters: Optional[str] = "", + search_params: Optional[str] = "", + *args, + **kwargs + ): + """Init mlflow projects task.""" + super().__init__(name, mlflow_tracking_uri, *args, **kwargs) + self.data_path = data_path + self.experiment_name = experiment_name + self.model_name = model_name + self.params = parameters + self.algorithm = algorithm.lower() + self.search_params = search_params + self.register_model = bool(model_name) diff --git a/dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_mlflow.py b/dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_mlflow.py new file mode 100644 index 0000000000..2159b6c77e --- /dev/null +++ b/dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_mlflow.py @@ -0,0 +1,211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test Task MLflow.""" +from copy import deepcopy +from unittest.mock import patch + +from pydolphinscheduler.tasks.mlflow import ( + MLflowDeployType, + MLflowJobType, + MLflowModels, + MLFlowProjectsAutoML, + MLFlowProjectsBasicAlgorithm, + MLFlowProjectsCustom, + MLflowTaskType, +) + +CODE = 123 +VERSION = 1 +MLFLOW_TRACKING_URI = "http://127.0.0.1:5000" + +EXPECT = { + "code": CODE, + "version": VERSION, + "description": None, + "delayTime": 0, + "taskType": "MLFLOW", + "taskParams": { + "resourceList": [], + "localParams": [], + "dependence": {}, + "conditionResult": {"successNode": [""], "failedNode": [""]}, + "waitStartTimeout": {}, + }, + "flag": "YES", + "taskPriority": "MEDIUM", + "workerGroup": "default", + "environmentCode": None, + "failRetryTimes": 0, + "failRetryInterval": 1, + "timeoutFlag": "CLOSE", + "timeoutNotifyStrategy": None, + "timeout": 0, +} + + +def test_mlflow_models_get_define(): + """Test task mlflow models function get_define.""" + name = "mlflow_models" + model_uri = "models:/xgboost_native/Production" + port = 7001 + cpu_limit = 2.0 + memory_limit = "600M" + + expect = deepcopy(EXPECT) + expect["name"] = name + task_params = expect["taskParams"] + task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI + task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_MODELS + task_params["deployType"] = MLflowDeployType.DOCKER_COMPOSE + task_params["deployModelKey"] = model_uri + task_params["deployPort"] = port + task_params["cpuLimit"] = cpu_limit + task_params["memoryLimit"] = memory_limit + + with patch( + "pydolphinscheduler.core.task.Task.gen_code_and_version", + return_value=(CODE, VERSION), + ): + task = MLflowModels( + name=name, + model_uri=model_uri, + mlflow_tracking_uri=MLFLOW_TRACKING_URI, + deploy_mode=MLflowDeployType.DOCKER_COMPOSE, + port=port, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + assert task.get_define() == expect + + +def test_mlflow_project_custom_get_define(): + """Test task mlflow project custom function get_define.""" + name = ("train_xgboost_native",) + repository = "https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native" + mlflow_tracking_uri = MLFLOW_TRACKING_URI + parameters = "-P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9" + experiment_name = "xgboost" + + expect = deepcopy(EXPECT) + expect["name"] = name + task_params = expect["taskParams"] + + task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI + task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_PROJECTS + task_params["mlflowJobType"] = MLflowJobType.CUSTOM_PROJECT + task_params["experimentName"] = experiment_name + task_params["params"] = parameters + task_params["mlflowProjectRepository"] = repository + task_params["mlflowProjectVersion"] = "dev" + + with patch( + "pydolphinscheduler.core.task.Task.gen_code_and_version", + return_value=(CODE, VERSION), + ): + task = MLFlowProjectsCustom( + name=name, + repository=repository, + mlflow_tracking_uri=mlflow_tracking_uri, + parameters=parameters, + experiment_name=experiment_name, + version="dev", + ) + assert task.get_define() == expect + + +def test_mlflow_project_automl_get_define(): + """Test task mlflow project automl function get_define.""" + name = ("train_automl",) + mlflow_tracking_uri = MLFLOW_TRACKING_URI + parameters = "time_budget=30;estimator_list=['lgbm']" + experiment_name = "automl_iris" + model_name = "iris_A" + automl_tool = "flaml" + data_path = "/data/examples/iris" + + expect = deepcopy(EXPECT) + expect["name"] = name + task_params = expect["taskParams"] + + task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI + task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_PROJECTS + task_params["mlflowJobType"] = MLflowJobType.AUTOML + task_params["experimentName"] = experiment_name + task_params["modelName"] = model_name + task_params["registerModel"] = bool(model_name) + task_params["dataPath"] = data_path + task_params["params"] = parameters + task_params["automlTool"] = automl_tool + + with patch( + "pydolphinscheduler.core.task.Task.gen_code_and_version", + return_value=(CODE, VERSION), + ): + task = MLFlowProjectsAutoML( + name=name, + mlflow_tracking_uri=mlflow_tracking_uri, + parameters=parameters, + experiment_name=experiment_name, + model_name=model_name, + automl_tool=automl_tool, + data_path=data_path, + ) + assert task.get_define() == expect + + +def test_mlflow_project_basic_algorithm_get_define(): + """Test task mlflow project BasicAlgorithm function get_define.""" + name = "train_basic_algorithm" + mlflow_tracking_uri = MLFLOW_TRACKING_URI + parameters = "n_estimators=200;learning_rate=0.2" + experiment_name = "basic_algorithm_iris" + model_name = "iris_B" + algorithm = "lightgbm" + data_path = "/data/examples/iris" + search_params = "max_depth=[5, 10];n_estimators=[100, 200]" + + expect = deepcopy(EXPECT) + expect["name"] = name + task_params = expect["taskParams"] + + task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI + task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_PROJECTS + task_params["mlflowJobType"] = MLflowJobType.BASIC_ALGORITHM + task_params["experimentName"] = experiment_name + task_params["modelName"] = model_name + task_params["registerModel"] = bool(model_name) + task_params["dataPath"] = data_path + task_params["params"] = parameters + task_params["algorithm"] = algorithm + task_params["searchParams"] = search_params + + with patch( + "pydolphinscheduler.core.task.Task.gen_code_and_version", + return_value=(CODE, VERSION), + ): + task = MLFlowProjectsBasicAlgorithm( + name=name, + mlflow_tracking_uri=mlflow_tracking_uri, + parameters=parameters, + experiment_name=experiment_name, + model_name=model_name, + algorithm=algorithm, + data_path=data_path, + search_params=search_params, + ) + assert task.get_define() == expect