Browse Source

[feat][python] Support MLflow task in python api (#11962)

(cherry picked from commit ad683c3c42)
3.1.0-release
JieguangZhou 2 years ago committed by Jiajie Zhong
parent
commit
cb063732d7
  1. 1
      dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst
  2. 42
      dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/mlflow.rst
  3. 78
      dolphinscheduler-python/pydolphinscheduler/examples/yaml_define/mlflow.yaml
  4. 1
      dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py
  5. 104
      dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/examples/task_mlflow_example.py
  6. 10
      dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/__init__.py
  7. 265
      dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/mlflow.py
  8. 211
      dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_mlflow.py

1
dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst

@ -42,5 +42,6 @@ In this section
sub_process
sagemaker
mlflow
openmldb
pytorch

42
dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/mlflow.rst

@ -0,0 +1,42 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
MLflow
=========
A MLflow task type's example and dive into information of **PyDolphinScheduler**.
Example
-------
.. literalinclude:: ../../../src/pydolphinscheduler/examples/task_mlflow_example.py
:start-after: [start workflow_declare]
:end-before: [end workflow_declare]
Dive Into
---------
.. automodule:: pydolphinscheduler.tasks.mlflow
YAML file example
-----------------
.. literalinclude:: ../../../examples/yaml_define/mlflow.yaml
:start-after: # under the License.
:language: yaml

78
dolphinscheduler-python/pydolphinscheduler/examples/yaml_define/mlflow.yaml

@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Define variable `mlflow_tracking_uri`
mlflow_tracking_uri: &mlflow_tracking_uri "http://127.0.0.1:5000"
# Define the workflow
workflow:
name: "MLflow"
# Define the tasks under the workflow
tasks:
- name: train_xgboost_native
task_type: MLFlowProjectsCustom
repository: https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native
mlflow_tracking_uri: *mlflow_tracking_uri
parameters: -P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9
experiment_name: xgboost
- name: deploy_mlflow
deps: [train_xgboost_native]
task_type: MLflowModels
model_uri: models:/xgboost_native/Production
mlflow_tracking_uri: *mlflow_tracking_uri
deploy_mode: MLFLOW
port: 7001
- name: train_automl
task_type: MLFlowProjectsAutoML
mlflow_tracking_uri: *mlflow_tracking_uri
parameters: time_budget=30;estimator_list=['lgbm']
experiment_name: automl_iris
model_name: iris_A
automl_tool: flaml
data_path: /data/examples/iris
- name: deploy_docker
task_type: MLflowModels
deps: [train_automl]
model_uri: models:/iris_A/Production
mlflow_tracking_uri: *mlflow_tracking_uri
deploy_mode: DOCKER
port: 7002
- name: train_basic_algorithm
task_type: MLFlowProjectsBasicAlgorithm
mlflow_tracking_uri: *mlflow_tracking_uri
parameters: n_estimators=200;learning_rate=0.2
experiment_name: basic_algorithm_iris
model_name: iris_B
algorithm: lightgbm
data_path: /data/examples/iris
search_params: max_depth=[5, 10];n_estimators=[100, 200]
- name: deploy_docker_compose
task_type: MLflowModels
deps: [train_basic_algorithm]
model_uri: models:/iris_B/Production
mlflow_tracking_uri: *mlflow_tracking_uri
deploy_mode: DOCKER COMPOSE
port: 7003

1
dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py

@ -58,6 +58,7 @@ class TaskType(str):
SPARK = "SPARK"
MR = "MR"
SAGEMAKER = "SAGEMAKER"
MLFLOW = "MLFLOW"
OPENMLDB = "OPENMLDB"
PYTORCH = "PYTORCH"

104
dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/examples/task_mlflow_example.py

@ -0,0 +1,104 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# [start workflow_declare]
"""A example workflow for task mlflow."""
from pydolphinscheduler.core.process_definition import ProcessDefinition
from pydolphinscheduler.tasks.mlflow import (
MLflowDeployType,
MLflowModels,
MLFlowProjectsAutoML,
MLFlowProjectsBasicAlgorithm,
MLFlowProjectsCustom,
)
mlflow_tracking_uri = "http://127.0.0.1:5000"
with ProcessDefinition(
name="task_mlflow_example",
tenant="tenant_exists",
) as pd:
# run custom mlflow project to train model
train_custom = MLFlowProjectsCustom(
name="train_xgboost_native",
repository="https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native",
mlflow_tracking_uri=mlflow_tracking_uri,
parameters="-P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9",
experiment_name="xgboost",
)
# Using MLFLOW to deploy model from custom mlflow project
deploy_mlflow = MLflowModels(
name="deploy_mlflow",
model_uri="models:/xgboost_native/Production",
mlflow_tracking_uri=mlflow_tracking_uri,
deploy_mode=MLflowDeployType.MLFLOW,
port=7001,
)
train_custom >> deploy_mlflow
# run automl to train model
train_automl = MLFlowProjectsAutoML(
name="train_automl",
mlflow_tracking_uri=mlflow_tracking_uri,
parameters="time_budget=30;estimator_list=['lgbm']",
experiment_name="automl_iris",
model_name="iris_A",
automl_tool="flaml",
data_path="/data/examples/iris",
)
# Using DOCKER to deploy model from train_automl
deploy_docker = MLflowModels(
name="deploy_docker",
model_uri="models:/iris_A/Production",
mlflow_tracking_uri=mlflow_tracking_uri,
deploy_mode=MLflowDeployType.DOCKER,
port=7002,
)
train_automl >> deploy_docker
# run lightgbm to train model
train_basic_algorithm = MLFlowProjectsBasicAlgorithm(
name="train_basic_algorithm",
mlflow_tracking_uri=mlflow_tracking_uri,
parameters="n_estimators=200;learning_rate=0.2",
experiment_name="basic_algorithm_iris",
model_name="iris_B",
algorithm="lightgbm",
data_path="/data/examples/iris",
search_params="max_depth=[5, 10];n_estimators=[100, 200]",
)
# Using DOCKER COMPOSE to deploy model from train_basic_algorithm
deploy_docker_compose = MLflowModels(
name="deploy_docker_compose",
model_uri="models:/iris_B/Production",
mlflow_tracking_uri=mlflow_tracking_uri,
deploy_mode=MLflowDeployType.DOCKER_COMPOSE,
port=7003,
)
train_basic_algorithm >> deploy_docker_compose
pd.submit()
# [end workflow_declare]

10
dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/__init__.py

@ -23,6 +23,12 @@ from pydolphinscheduler.tasks.dependent import Dependent
from pydolphinscheduler.tasks.flink import Flink
from pydolphinscheduler.tasks.http import Http
from pydolphinscheduler.tasks.map_reduce import MR
from pydolphinscheduler.tasks.mlflow import (
MLflowModels,
MLFlowProjectsAutoML,
MLFlowProjectsBasicAlgorithm,
MLFlowProjectsCustom,
)
from pydolphinscheduler.tasks.openmldb import OpenMLDB
from pydolphinscheduler.tasks.procedure import Procedure
from pydolphinscheduler.tasks.python import Python
@ -43,6 +49,10 @@ __all__ = [
"Http",
"MR",
"OpenMLDB",
"MLFlowProjectsBasicAlgorithm",
"MLFlowProjectsCustom",
"MLFlowProjectsAutoML",
"MLflowModels",
"Procedure",
"Python",
"Pytorch",

265
dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/mlflow.py

@ -0,0 +1,265 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Task mlflow."""
from copy import deepcopy
from typing import Dict, Optional
from pydolphinscheduler.constants import TaskType
from pydolphinscheduler.core.task import Task
class MLflowTaskType(str):
"""MLflow task type."""
MLFLOW_PROJECTS = "MLflow Projects"
MLFLOW_MODELS = "MLflow Models"
class MLflowJobType(str):
"""MLflow job type."""
AUTOML = "AutoML"
BASIC_ALGORITHM = "BasicAlgorithm"
CUSTOM_PROJECT = "CustomProject"
class MLflowDeployType(str):
"""MLflow deploy type."""
MLFLOW = "MLFLOW"
DOCKER = "DOCKER"
DOCKER_COMPOSE = "DOCKER COMPOSE"
DEFAULT_MLFLOW_TRACKING_URI = "http://127.0.0.1:5000"
DEFAULT_VERSION = "master"
class BaseMLflow(Task):
"""Base MLflow task."""
mlflow_task_type = None
_task_custom_attr = {
"mlflow_tracking_uri",
"mlflow_task_type",
}
_child_task_mlflow_attr = set()
def __init__(self, name: str, mlflow_tracking_uri: str, *args, **kwargs):
super().__init__(name, TaskType.MLFLOW, *args, **kwargs)
self.mlflow_tracking_uri = mlflow_tracking_uri
@property
def task_params(self) -> Dict:
"""Return task params."""
self._task_custom_attr = deepcopy(self._task_custom_attr)
self._task_custom_attr.update(self._child_task_mlflow_attr)
return super().task_params
class MLflowModels(BaseMLflow):
"""Task MLflow models object, declare behavior for MLflow models task to dolphinscheduler.
Deploy machine learning models in diverse serving environments.
:param name: task name
:param model_uri: Model-URI of MLflow , support models:/<model_name>/suffix format and runs:/ format.
See https://mlflow.org/docs/latest/tracking.html#artifact-stores
:param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000
:param deploy_mode: MLflow deploy mode, support MLFLOW, DOCKER, DOCKER COMPOSE, default is DOCKER
:param port: deploy port, default is 7000
:param cpu_limit: cpu limit, default is 1.0
:param memory_limit: memory limit, default is 500M
"""
mlflow_task_type = MLflowTaskType.MLFLOW_MODELS
_child_task_mlflow_attr = {
"deploy_type",
"deploy_model_key",
"deploy_port",
"cpu_limit",
"memory_limit",
}
def __init__(
self,
name: str,
model_uri: str,
mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI,
deploy_mode: Optional[str] = MLflowDeployType.DOCKER,
port: Optional[int] = 7000,
cpu_limit: Optional[float] = 1.0,
memory_limit: Optional[str] = "500M",
*args,
**kwargs
):
"""Init mlflow models task."""
super().__init__(name, mlflow_tracking_uri, *args, **kwargs)
self.deploy_type = deploy_mode.upper()
self.deploy_model_key = model_uri
self.deploy_port = port
self.cpu_limit = cpu_limit
self.memory_limit = memory_limit
class MLFlowProjectsCustom(BaseMLflow):
"""Task MLflow projects object, declare behavior for MLflow Custom projects task to dolphinscheduler.
:param name: task name
:param repository: Repository url of MLflow Project, Support git address and directory on worker.
If it's in a subdirectory, We add # to support this (same as mlflow run) ,
for example https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native.
:param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000
:param experiment_name: MLflow experiment name, default is empty
:param parameters: MLflow project parameters, default is empty
:param version: MLflow project version, default is master
"""
mlflow_task_type = MLflowTaskType.MLFLOW_PROJECTS
mlflow_job_type = MLflowJobType.CUSTOM_PROJECT
_child_task_mlflow_attr = {
"mlflow_job_type",
"experiment_name",
"params",
"mlflow_project_repository",
"mlflow_project_version",
}
def __init__(
self,
name: str,
repository: str,
mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI,
experiment_name: Optional[str] = "",
parameters: Optional[str] = "",
version: Optional[str] = "master",
*args,
**kwargs
):
"""Init mlflow projects task."""
super().__init__(name, mlflow_tracking_uri, *args, **kwargs)
self.mlflow_project_repository = repository
self.experiment_name = experiment_name
self.params = parameters
self.mlflow_project_version = version
class MLFlowProjectsAutoML(BaseMLflow):
"""Task MLflow projects object, declare behavior for AutoML task to dolphinscheduler.
:param name: task name
:param data_path: data path of MLflow Project, Support git address and directory on worker.
:param automl_tool: The AutoML tool used, currently supports autosklearn and flaml.
:param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000
:param experiment_name: MLflow experiment name, default is empty
:param model_name: MLflow model name, default is empty
:param parameters: MLflow project parameters, default is empty
"""
mlflow_task_type = MLflowTaskType.MLFLOW_PROJECTS
mlflow_job_type = MLflowJobType.AUTOML
_child_task_mlflow_attr = {
"mlflow_job_type",
"experiment_name",
"model_name",
"register_model",
"data_path",
"params",
"automl_tool",
}
def __init__(
self,
name: str,
data_path: str,
automl_tool: Optional[str] = "flaml",
mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI,
experiment_name: Optional[str] = "",
model_name: Optional[str] = "",
parameters: Optional[str] = "",
*args,
**kwargs
):
"""Init mlflow projects task."""
super().__init__(name, mlflow_tracking_uri, *args, **kwargs)
self.data_path = data_path
self.experiment_name = experiment_name
self.model_name = model_name
self.params = parameters
self.automl_tool = automl_tool.lower()
self.register_model = bool(model_name)
class MLFlowProjectsBasicAlgorithm(BaseMLflow):
"""Task MLflow projects object, declare behavior for BasicAlgorithm task to dolphinscheduler.
:param name: task name
:param data_path: data path of MLflow Project, Support git address and directory on worker.
:param algorithm: The selected algorithm currently supports LR, SVM, LightGBM and XGboost
based on scikit-learn form.
:param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000
:param experiment_name: MLflow experiment name, default is empty
:param model_name: MLflow model name, default is empty
:param parameters: MLflow project parameters, default is empty
:param search_params: Whether to search the parameters, default is empty
"""
mlflow_job_type = MLflowJobType.BASIC_ALGORITHM
mlflow_task_type = MLflowTaskType.MLFLOW_PROJECTS
_child_task_mlflow_attr = {
"mlflow_job_type",
"experiment_name",
"model_name",
"register_model",
"data_path",
"params",
"algorithm",
"search_params",
}
def __init__(
self,
name: str,
data_path: str,
algorithm: Optional[str] = "lightgbm",
mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI,
experiment_name: Optional[str] = "",
model_name: Optional[str] = "",
parameters: Optional[str] = "",
search_params: Optional[str] = "",
*args,
**kwargs
):
"""Init mlflow projects task."""
super().__init__(name, mlflow_tracking_uri, *args, **kwargs)
self.data_path = data_path
self.experiment_name = experiment_name
self.model_name = model_name
self.params = parameters
self.algorithm = algorithm.lower()
self.search_params = search_params
self.register_model = bool(model_name)

211
dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_mlflow.py

@ -0,0 +1,211 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test Task MLflow."""
from copy import deepcopy
from unittest.mock import patch
from pydolphinscheduler.tasks.mlflow import (
MLflowDeployType,
MLflowJobType,
MLflowModels,
MLFlowProjectsAutoML,
MLFlowProjectsBasicAlgorithm,
MLFlowProjectsCustom,
MLflowTaskType,
)
CODE = 123
VERSION = 1
MLFLOW_TRACKING_URI = "http://127.0.0.1:5000"
EXPECT = {
"code": CODE,
"version": VERSION,
"description": None,
"delayTime": 0,
"taskType": "MLFLOW",
"taskParams": {
"resourceList": [],
"localParams": [],
"dependence": {},
"conditionResult": {"successNode": [""], "failedNode": [""]},
"waitStartTimeout": {},
},
"flag": "YES",
"taskPriority": "MEDIUM",
"workerGroup": "default",
"environmentCode": None,
"failRetryTimes": 0,
"failRetryInterval": 1,
"timeoutFlag": "CLOSE",
"timeoutNotifyStrategy": None,
"timeout": 0,
}
def test_mlflow_models_get_define():
"""Test task mlflow models function get_define."""
name = "mlflow_models"
model_uri = "models:/xgboost_native/Production"
port = 7001
cpu_limit = 2.0
memory_limit = "600M"
expect = deepcopy(EXPECT)
expect["name"] = name
task_params = expect["taskParams"]
task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI
task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_MODELS
task_params["deployType"] = MLflowDeployType.DOCKER_COMPOSE
task_params["deployModelKey"] = model_uri
task_params["deployPort"] = port
task_params["cpuLimit"] = cpu_limit
task_params["memoryLimit"] = memory_limit
with patch(
"pydolphinscheduler.core.task.Task.gen_code_and_version",
return_value=(CODE, VERSION),
):
task = MLflowModels(
name=name,
model_uri=model_uri,
mlflow_tracking_uri=MLFLOW_TRACKING_URI,
deploy_mode=MLflowDeployType.DOCKER_COMPOSE,
port=port,
cpu_limit=cpu_limit,
memory_limit=memory_limit,
)
assert task.get_define() == expect
def test_mlflow_project_custom_get_define():
"""Test task mlflow project custom function get_define."""
name = ("train_xgboost_native",)
repository = "https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native"
mlflow_tracking_uri = MLFLOW_TRACKING_URI
parameters = "-P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9"
experiment_name = "xgboost"
expect = deepcopy(EXPECT)
expect["name"] = name
task_params = expect["taskParams"]
task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI
task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_PROJECTS
task_params["mlflowJobType"] = MLflowJobType.CUSTOM_PROJECT
task_params["experimentName"] = experiment_name
task_params["params"] = parameters
task_params["mlflowProjectRepository"] = repository
task_params["mlflowProjectVersion"] = "dev"
with patch(
"pydolphinscheduler.core.task.Task.gen_code_and_version",
return_value=(CODE, VERSION),
):
task = MLFlowProjectsCustom(
name=name,
repository=repository,
mlflow_tracking_uri=mlflow_tracking_uri,
parameters=parameters,
experiment_name=experiment_name,
version="dev",
)
assert task.get_define() == expect
def test_mlflow_project_automl_get_define():
"""Test task mlflow project automl function get_define."""
name = ("train_automl",)
mlflow_tracking_uri = MLFLOW_TRACKING_URI
parameters = "time_budget=30;estimator_list=['lgbm']"
experiment_name = "automl_iris"
model_name = "iris_A"
automl_tool = "flaml"
data_path = "/data/examples/iris"
expect = deepcopy(EXPECT)
expect["name"] = name
task_params = expect["taskParams"]
task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI
task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_PROJECTS
task_params["mlflowJobType"] = MLflowJobType.AUTOML
task_params["experimentName"] = experiment_name
task_params["modelName"] = model_name
task_params["registerModel"] = bool(model_name)
task_params["dataPath"] = data_path
task_params["params"] = parameters
task_params["automlTool"] = automl_tool
with patch(
"pydolphinscheduler.core.task.Task.gen_code_and_version",
return_value=(CODE, VERSION),
):
task = MLFlowProjectsAutoML(
name=name,
mlflow_tracking_uri=mlflow_tracking_uri,
parameters=parameters,
experiment_name=experiment_name,
model_name=model_name,
automl_tool=automl_tool,
data_path=data_path,
)
assert task.get_define() == expect
def test_mlflow_project_basic_algorithm_get_define():
"""Test task mlflow project BasicAlgorithm function get_define."""
name = "train_basic_algorithm"
mlflow_tracking_uri = MLFLOW_TRACKING_URI
parameters = "n_estimators=200;learning_rate=0.2"
experiment_name = "basic_algorithm_iris"
model_name = "iris_B"
algorithm = "lightgbm"
data_path = "/data/examples/iris"
search_params = "max_depth=[5, 10];n_estimators=[100, 200]"
expect = deepcopy(EXPECT)
expect["name"] = name
task_params = expect["taskParams"]
task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI
task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_PROJECTS
task_params["mlflowJobType"] = MLflowJobType.BASIC_ALGORITHM
task_params["experimentName"] = experiment_name
task_params["modelName"] = model_name
task_params["registerModel"] = bool(model_name)
task_params["dataPath"] = data_path
task_params["params"] = parameters
task_params["algorithm"] = algorithm
task_params["searchParams"] = search_params
with patch(
"pydolphinscheduler.core.task.Task.gen_code_and_version",
return_value=(CODE, VERSION),
):
task = MLFlowProjectsBasicAlgorithm(
name=name,
mlflow_tracking_uri=mlflow_tracking_uri,
parameters=parameters,
experiment_name=experiment_name,
model_name=model_name,
algorithm=algorithm,
data_path=data_path,
search_params=search_params,
)
assert task.get_define() == expect
Loading…
Cancel
Save