* [Feature][Task Plugin] Add mlflow task plugin (#9725) * [Feature][Task Plugin] Add mlflow task plugin UI (#9725) * [Feature][Task Plugin] fix license header (#9725) * [Feature][Task Plugin] fix license header (#9725) * [Feature][Task Plugin] revert unnecessary * [Feature][Task Plugin] add auto ml to mlflow task plugin * [DOC] add mlflow document * [DOC] fix mlflow docs imgs * [DOC] fix dead link localhost:5000 * [DOC] fix dead link localhost:5000 * [DOC] remove dead link localhost:5000 * Update docs/docs/en/guide/task/mlflow.md * Update docs/docs/zh/guide/task/mlflow.md * [DOC] format ui code * [DOC] remove dead link localhost:5000 * [Feature][Task Plugin] revert unnecessary * fix some nits * Move the personal repository to the public repository. * Run the command directly instead of saving the file * fix paramsMap initialize * revert unnecessary Co-authored-by: Jiajie Zhong <zhongjiajie955@gmail.com>3.1.0-release
@ -0,0 +1,117 @@ |
|||||||
|
# MLflow Node |
||||||
|
|
||||||
|
## Overview |
||||||
|
|
||||||
|
[MLflow](https://mlflow.org) is an excellent open source platform to manage the ML lifecycle, including experimentation, |
||||||
|
reproducibility, deployment, and a central model registry. |
||||||
|
|
||||||
|
Mlflow task is used to perform mlflow project tasks, which include basic algorithmic and autoML capabilities ( |
||||||
|
User-defined MLFlow project task execution will be supported in the near future) |
||||||
|
|
||||||
|
## Create Task |
||||||
|
|
||||||
|
- Click `Project -> Management-Project -> Name-Workflow Definition`, and click the "Create Workflow" button to enter the |
||||||
|
DAG editing page. |
||||||
|
- Drag from the toolbar <img src="/img/tasks/icons/mlflow.png" width="15"/> task node to canvas. |
||||||
|
|
||||||
|
## Task Parameter |
||||||
|
|
||||||
|
- DolphinScheduler common parameters |
||||||
|
- **Node name**: The node name in a workflow definition is unique. |
||||||
|
- **Run flag**: Identifies whether this node schedules normally, if it does not need to execute, select |
||||||
|
the `prohibition execution`. |
||||||
|
- **Descriptive information**: Describe the function of the node. |
||||||
|
- **Task priority**: When the number of worker threads is insufficient, execute in the order of priority from high |
||||||
|
to low, and tasks with the same priority will execute in a first-in first-out order. |
||||||
|
- **Worker grouping**: Assign tasks to the machines of the worker group to execute. If `Default` is selected, |
||||||
|
randomly select a worker machine for execution. |
||||||
|
- **Environment Name**: Configure the environment name in which run the script. |
||||||
|
- **Times of failed retry attempts**: The number of times the task failed to resubmit. |
||||||
|
- **Failed retry interval**: The time interval (unit minute) for resubmitting the task after a failed task. |
||||||
|
- **Delayed execution time**: The time (unit minute) that a task delays in execution. |
||||||
|
- **Timeout alarm**: Check the timeout alarm and timeout failure. When the task runs exceed the "timeout", an alarm |
||||||
|
email will send and the task execution will fail. |
||||||
|
- **Custom parameter**: It is a local user-defined parameter for mlflow, and will replace the content |
||||||
|
with `${variable}` in the script. |
||||||
|
- **Predecessor task**: Selecting a predecessor task for the current task, will set the selected predecessor task as |
||||||
|
upstream of the current task. |
||||||
|
|
||||||
|
- MLflow task specific parameters |
||||||
|
- **mlflow server tracking uri** :MLflow server uri, default http://localhost:5000. |
||||||
|
- **experiment name** :The experiment in which the task is running, if none, is created. |
||||||
|
- **register model** :Register the model or not. If register is selected, the following parameters are expanded. |
||||||
|
- **model name** : The registered model name is added to the original model version and registered as |
||||||
|
Production. |
||||||
|
- **job type** : The type of task to run, currently including the underlying algorithm and AutoML. (User-defined |
||||||
|
MLFlow project task execution will be supported in the near future) |
||||||
|
- BasicAlgorithm specific parameters |
||||||
|
- **algorithm** :The selected algorithm currently supports `LR`, `SVM`, `LightGBM` and `XGboost` based |
||||||
|
on [scikit-learn](https://scikit-learn.org/) form. |
||||||
|
- **Parameter search space** : Parameter search space when running the corresponding algorithm, which can be |
||||||
|
empty. For example, the parameter `max_depth=[5, 10];n_estimators=[100, 200]` for lightgbm 。The convention |
||||||
|
will be passed with '; 'shards each parameter, using the name before the equal sign as the parameter name, |
||||||
|
and using the name after the equal sign to get the corresponding parameter value through `python eval()`. |
||||||
|
- AutoML specific parameters |
||||||
|
- **AutoML tool** : The AutoML tool used, currently |
||||||
|
supports [autosklearn](https://github.com/automl/auto-sklearn) |
||||||
|
and [flaml](https://github.com/microsoft/FLAML) |
||||||
|
- Parameters common to BasicAlgorithm and AutoML |
||||||
|
- **data path** : The absolute path of the file or folder. Ends with .csv for file or contain train.csv and |
||||||
|
test.csv for folder(In the suggested way, users should build their own test sets for model evaluation)。 |
||||||
|
- **parameters** : Parameter when initializing the algorithm/AutoML model, which can be empty. For example |
||||||
|
parameters `"time_budget=30;estimator_list=['lgbm']"` for flaml 。The convention will be passed with '; 'shards |
||||||
|
each parameter, using the name before the equal sign as the parameter name, and using the name after the equal |
||||||
|
sign to get the corresponding parameter value through `python eval()`. |
||||||
|
- BasicAlgorithm |
||||||
|
- [lr](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression) |
||||||
|
- [SVM](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html?highlight=svc#sklearn.svm.SVC) |
||||||
|
- [lightgbm](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html#lightgbm.LGBMClassifier) |
||||||
|
- [xgboost](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBClassifier) |
||||||
|
- AutoML |
||||||
|
- [flaml](https://microsoft.github.io/FLAML/docs/reference/automl#automl-objects) |
||||||
|
- [autosklearn](https://automl.github.io/auto-sklearn/master/api.html) |
||||||
|
|
||||||
|
## Task Example |
||||||
|
|
||||||
|
### Preparation |
||||||
|
|
||||||
|
#### Conda env |
||||||
|
|
||||||
|
You need to enter the admin account to configure a conda environment variable(Please |
||||||
|
install [anaconda](https://docs.continuum.io/anaconda/install/) |
||||||
|
or [miniconda](https://docs.conda.io/en/latest/miniconda.html#installing ) in advance ) |
||||||
|
|
||||||
|
![mlflow-conda-env](/img/tasks/demo/mlflow-conda-env.png) |
||||||
|
|
||||||
|
Note During the configuration task, select the conda environment created above. Otherwise, the program cannot find the |
||||||
|
Conda environment. |
||||||
|
|
||||||
|
![mlflow-set-conda-env](/img/tasks/demo/mlflow-set-conda-env.png) |
||||||
|
|
||||||
|
#### Start the mlflow service |
||||||
|
|
||||||
|
Make sure you have installed MLflow, using 'PIP Install MLFlow'. |
||||||
|
|
||||||
|
Create a folder where you want to save your experiments and models and start mlFlow service. |
||||||
|
|
||||||
|
```sh |
||||||
|
mkdir mlflow |
||||||
|
cd mlflow |
||||||
|
mlflow server -h 0.0.0.0 -p 5000 --serve-artifacts --backend-store-uri sqlite:///mlflow.db |
||||||
|
``` |
||||||
|
|
||||||
|
After running, an MLflow service is started |
||||||
|
|
||||||
|
### Run BasicAlgorithm task |
||||||
|
|
||||||
|
The following example shows how to create an MLflow BasicAlgorithm task. |
||||||
|
|
||||||
|
![mlflow-basic-algorithm](/img/tasks/demo/mlflow-basic-algorithm.png) |
||||||
|
|
||||||
|
After this, you can visit the MLFlow service (`http://localhost:5000`) page to view the experiments and models. |
||||||
|
|
||||||
|
![mlflow-server](/img/tasks/demo/mlflow-server.png) |
||||||
|
|
||||||
|
### Run AutoML task |
||||||
|
|
||||||
|
![mlflow-automl](/img/tasks/demo/mlflow-automl.png) |
@ -0,0 +1,97 @@ |
|||||||
|
# MLflow节点 |
||||||
|
|
||||||
|
## 综述 |
||||||
|
|
||||||
|
[MLflow](https://mlflow.org) 是一个MLops领域一个优秀的开源项目, 用于管理机器学习的生命周期,包括实验、可再现性、部署和中心模型注册。 |
||||||
|
|
||||||
|
MLflow 任务用于执行 MLflow Project 任务,其中包含了阈值的基础算法能力与AutoML能力(将在不久将来支持用户自定义的mlflow project任务执行)。 |
||||||
|
|
||||||
|
## 创建任务 |
||||||
|
|
||||||
|
- 点击项目管理-项目名称-工作流定义,点击“创建工作流”按钮,进入 DAG 编辑页面; |
||||||
|
- 拖动工具栏的 <img src="/img/tasks/icons/mlflow.png" width="15"/> 任务节点到画板中。 |
||||||
|
|
||||||
|
## 任务参数 |
||||||
|
|
||||||
|
- DS通用参数 |
||||||
|
- **节点名称** :设置任务的名称。一个工作流定义中的节点名称是唯一的。 |
||||||
|
- **运行标志** :标识这个节点是否能正常调度,如果不需要执行,可以打开禁止执行开关。 |
||||||
|
- **描述** :描述该节点的功能。 |
||||||
|
- **任务优先级** :worker 线程数不足时,根据优先级从高到低依次执行,优先级一样时根据先进先出原则执行。 |
||||||
|
- **Worker 分组** :任务分配给 worker 组的机器执行,选择 Default,会随机选择一台 worker 机执行。 |
||||||
|
- **环境名称** :配置运行脚本的环境。 |
||||||
|
- **失败重试次数** :任务失败重新提交的次数。 |
||||||
|
- **失败重试间隔** :任务失败重新提交任务的时间间隔,以分钟为单位。 |
||||||
|
- **延迟执行时间** :任务延迟执行的时间,以分钟为单位。 |
||||||
|
- **超时告警** :勾选超时告警、超时失败,当任务超过"超时时长"后,会发送告警邮件并且任务执行失败。 |
||||||
|
- **自定义参数** :是 mlflow 局部的用户自定义参数,会替换脚本中以 ${变量} 的内容 |
||||||
|
- **前置任务** :选择当前任务的前置任务,会将被选择的前置任务设置为当前任务的上游。 |
||||||
|
|
||||||
|
- MLflow任务特定参数 |
||||||
|
- **mlflow server tracking uri** :MLflow server 的连接, 默认 http://localhost:5000。 |
||||||
|
- **实验名称** :任务运行时所在的实验,若无则创建。 |
||||||
|
- **注册模型** :是否注册模型,若选择注册,则会展开以下参数。 |
||||||
|
- **注册的模型名称** : 注册的模型名称,会在原来的基础上加上一个模型版本,并注册为Production。 |
||||||
|
- **任务类型** : 运行的任务类型,目前包括基础算法与AutoML, 后续将会支持用户自定义的ML Project。 |
||||||
|
- 基础算法下的特有参数 |
||||||
|
- **算法** :选择的算法,目前基于 [scikit-learn](https://scikit-learn.org/) 形式支持 `lr`, `svm`, `lightgbm`, `xgboost`. |
||||||
|
- **参数搜索空间** : 运行对应算法的参数搜索空间, 可为空。如针对lightgbm 的 `max_depth=[5, 10];n_estimators=[100, 200]` |
||||||
|
则会进行对应搜索。约定传入后会以`;`切分各个参数,等号前的名字作为参数名,等号后的名字将以python eval执行得到对应的参数值 |
||||||
|
- AutoML下的参数下的特有参数 |
||||||
|
- **AutoML工具** : 使用的AutoML工具,目前支持 [autosklearn](https://github.com/automl/auto-sklearn) |
||||||
|
, [flaml](https://github.com/microsoft/FLAML) |
||||||
|
- BasicAlgorithm 和 AutoML共有参数 |
||||||
|
- **数据路径** : 文件/文件夹的绝对路径, 若文件需以.csv结尾(自动切分训练集与测试集), 文件夹需包含train.csv和test.csv(建议方式,用户应自行构建测试集用于模型评估)。 |
||||||
|
- **参数** : 初始化模型/AutoML训练器时的参数,可为空, 如针对 flaml 设置`"time_budget=30;estimator_list=['lgbm']"`。约定传入后会以`;` |
||||||
|
切分各个参数,等号前的名字作为参数名,等号后的名字将以python eval执行得到对应的参数值。详细的参数列表如下: |
||||||
|
- BasicAlgorithm |
||||||
|
- [lr](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression) |
||||||
|
- [SVM](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html?highlight=svc#sklearn.svm.SVC) |
||||||
|
- [lightgbm](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html#lightgbm.LGBMClassifier) |
||||||
|
- [xgboost](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBClassifier) |
||||||
|
- AutoML |
||||||
|
- [flaml](https://microsoft.github.io/FLAML/docs/reference/automl#automl-objects) |
||||||
|
- [autosklearn](https://automl.github.io/auto-sklearn/master/api.html) |
||||||
|
|
||||||
|
## 任务样例 |
||||||
|
|
||||||
|
### 前置准备 |
||||||
|
|
||||||
|
#### conda 环境配置 |
||||||
|
|
||||||
|
你需要进入admin账户配置一个conda环境变量(请提前[安装anaconda](https://docs.continuum.io/anaconda/install/) |
||||||
|
或者[安装miniconda](https://docs.conda.io/en/latest/miniconda.html#installing) ) |
||||||
|
|
||||||
|
![mlflow-conda-env](/img/tasks/demo/mlflow-conda-env.png) |
||||||
|
|
||||||
|
后续注意配置任务时,环境选择上面创建的conda环境,否则程序会找不到conda环境 |
||||||
|
|
||||||
|
![mlflow-set-conda-env](/img/tasks/demo/mlflow-set-conda-env.png) |
||||||
|
|
||||||
|
#### mlflow service 启动 |
||||||
|
|
||||||
|
确保你已经安装mlflow,可以使用`pip install mlflow`进行安装 |
||||||
|
|
||||||
|
在你想保存实验和模型的地方建立一个文件夹,然后启动 mlflow service |
||||||
|
|
||||||
|
``` |
||||||
|
mkdir mlflow |
||||||
|
cd mlflow |
||||||
|
mlflow server -h 0.0.0.0 -p 5000 --serve-artifacts --backend-store-uri sqlite:///mlflow.db |
||||||
|
``` |
||||||
|
|
||||||
|
运行后会启动一个mlflow服务 |
||||||
|
|
||||||
|
### 执行 基础算法 任务 |
||||||
|
|
||||||
|
以下实例展示了如何创建 mlflow 基础算法任务 |
||||||
|
|
||||||
|
![mlflow-basic-algorithm](/img/tasks/demo/mlflow-basic-algorithm.png) |
||||||
|
|
||||||
|
执行完后可以通过访问 mlflow service (`http://localhost:5000`) 页面查看实验与模型 |
||||||
|
|
||||||
|
![mlflow-server](/img/tasks/demo/mlflow-server.png) |
||||||
|
|
||||||
|
### 执行 AutoML 任务 |
||||||
|
|
||||||
|
![mlflow-automl](/img/tasks/demo/mlflow-automl.png) |
After Width: | Height: | Size: 85 KiB |
After Width: | Height: | Size: 137 KiB |
After Width: | Height: | Size: 252 KiB |
After Width: | Height: | Size: 195 KiB |
After Width: | Height: | Size: 137 KiB |
After Width: | Height: | Size: 111 KiB |
@ -0,0 +1,46 @@ |
|||||||
|
<?xml version="1.0" encoding="UTF-8"?> |
||||||
|
<!-- |
||||||
|
~ Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
~ contributor license agreements. See the NOTICE file distributed with |
||||||
|
~ this work for additional information regarding copyright ownership. |
||||||
|
~ The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
~ (the "License"); you may not use this file except in compliance with |
||||||
|
~ the License. You may obtain a copy of the License at |
||||||
|
~ |
||||||
|
~ http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
~ |
||||||
|
~ Unless required by applicable law or agreed to in writing, software |
||||||
|
~ distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
~ See the License for the specific language governing permissions and |
||||||
|
~ limitations under the License. |
||||||
|
--> |
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0" |
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |
||||||
|
<parent> |
||||||
|
<artifactId>dolphinscheduler-task-plugin</artifactId> |
||||||
|
<groupId>org.apache.dolphinscheduler</groupId> |
||||||
|
<version>dev-SNAPSHOT</version> |
||||||
|
</parent> |
||||||
|
<modelVersion>4.0.0</modelVersion> |
||||||
|
|
||||||
|
<artifactId>dolphinscheduler-task-mlflow</artifactId> |
||||||
|
<packaging>jar</packaging> |
||||||
|
|
||||||
|
<dependencies> |
||||||
|
<dependency> |
||||||
|
<groupId>org.apache.dolphinscheduler</groupId> |
||||||
|
<artifactId>dolphinscheduler-spi</artifactId> |
||||||
|
<scope>provided</scope> |
||||||
|
</dependency> |
||||||
|
<dependency> |
||||||
|
<groupId>org.apache.dolphinscheduler</groupId> |
||||||
|
<artifactId>dolphinscheduler-task-api</artifactId> |
||||||
|
</dependency> |
||||||
|
<dependency> |
||||||
|
<groupId>org.apache.commons</groupId> |
||||||
|
<artifactId>commons-collections4</artifactId> |
||||||
|
</dependency> |
||||||
|
</dependencies> |
||||||
|
</project> |
@ -0,0 +1,42 @@ |
|||||||
|
/* |
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
* contributor license agreements. See the NOTICE file distributed with |
||||||
|
* this work for additional information regarding copyright ownership. |
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
* (the "License"); you may not use this file except in compliance with |
||||||
|
* the License. You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
|
||||||
|
package org.apache.dolphinscheduler.plugin.task.mlflow; |
||||||
|
|
||||||
|
public class MlflowConstants { |
||||||
|
private MlflowConstants() { |
||||||
|
throw new IllegalStateException("Utility class"); |
||||||
|
} |
||||||
|
|
||||||
|
public static final String JOB_TYPE_AUTOML = "AutoML"; |
||||||
|
|
||||||
|
public static final String JOB_TYPE_BASIC_ALGORITHM = "BasicAlgorithm"; |
||||||
|
|
||||||
|
public static final String PRESET_REPOSITORY = "https://github.com/apache/dolphinscheduler-mlflow"; |
||||||
|
|
||||||
|
public static final String PRESET_REPOSITORY_VERSION = "main"; |
||||||
|
|
||||||
|
public static final String PRESET_AUTOML_PROJECT = PRESET_REPOSITORY + "#Project-AutoML"; |
||||||
|
|
||||||
|
public static final String PRESET_BASIC_ALGORITHM_PROJECT = PRESET_REPOSITORY + "#Project-BasicAlgorithm"; |
||||||
|
|
||||||
|
public static final String RUN_PROJECT_BASIC_ALGORITHM_SCRIPT = "run_mlflow_basic_algorithm_project.sh"; |
||||||
|
|
||||||
|
public static final String RUN_PROJECT_AUTOML_SCRIPT = "run_mlflow_automl_project.sh"; |
||||||
|
|
||||||
|
|
||||||
|
} |
@ -0,0 +1,191 @@ |
|||||||
|
/* |
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
* contributor license agreements. See the NOTICE file distributed with |
||||||
|
* this work for additional information regarding copyright ownership. |
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
* (the "License"); you may not use this file except in compliance with |
||||||
|
* the License. You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
|
||||||
|
package org.apache.dolphinscheduler.plugin.task.mlflow; |
||||||
|
|
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.parameters.AbstractParameters; |
||||||
|
|
||||||
|
import java.util.HashMap; |
||||||
|
|
||||||
|
public class MlflowParameters extends AbstractParameters { |
||||||
|
|
||||||
|
/** |
||||||
|
* common parameters |
||||||
|
*/ |
||||||
|
|
||||||
|
private String params = ""; |
||||||
|
|
||||||
|
private String mlflowJobType = "BasicAlgorithm"; |
||||||
|
|
||||||
|
/** |
||||||
|
* AutoML parameters |
||||||
|
*/ |
||||||
|
private String automlTool = "FLAML"; |
||||||
|
|
||||||
|
|
||||||
|
/** |
||||||
|
* basic algorithm parameters |
||||||
|
*/ |
||||||
|
|
||||||
|
private String algorithm = "lightgbm"; |
||||||
|
|
||||||
|
private String searchParams = ""; |
||||||
|
|
||||||
|
private String dataPath; |
||||||
|
|
||||||
|
/** |
||||||
|
* mlflow parameters |
||||||
|
*/ |
||||||
|
|
||||||
|
private String experimentName; |
||||||
|
|
||||||
|
private String modelName = ""; |
||||||
|
|
||||||
|
private String mlflowTrackingUri = "http://127.0.0.1:5000"; |
||||||
|
|
||||||
|
|
||||||
|
public void setAlgorithm(String algorithm) { |
||||||
|
this.algorithm = algorithm; |
||||||
|
} |
||||||
|
|
||||||
|
public String getAlgorithm() { |
||||||
|
return algorithm; |
||||||
|
} |
||||||
|
|
||||||
|
public void setParams(String params) { |
||||||
|
this.params = params; |
||||||
|
} |
||||||
|
|
||||||
|
public String getParams() { |
||||||
|
return params; |
||||||
|
} |
||||||
|
|
||||||
|
public void setSearchParams(String searchParams) { |
||||||
|
this.searchParams = searchParams; |
||||||
|
} |
||||||
|
|
||||||
|
public String getSearchParams() { |
||||||
|
return searchParams; |
||||||
|
} |
||||||
|
|
||||||
|
public void setDataPaths(String dataPath) { |
||||||
|
this.dataPath = dataPath; |
||||||
|
} |
||||||
|
|
||||||
|
public String getDataPath() { |
||||||
|
return dataPath; |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
public void setExperimentNames(String experimentName) { |
||||||
|
this.experimentName = experimentName; |
||||||
|
} |
||||||
|
|
||||||
|
public String getExperimentName() { |
||||||
|
return experimentName; |
||||||
|
} |
||||||
|
|
||||||
|
public void setModelNames(String modelName) { |
||||||
|
this.modelName = modelName; |
||||||
|
} |
||||||
|
|
||||||
|
public String getModelName() { |
||||||
|
return modelName; |
||||||
|
} |
||||||
|
|
||||||
|
public void setMlflowTrackingUris(String mlflowTrackingUri) { |
||||||
|
this.mlflowTrackingUri = mlflowTrackingUri; |
||||||
|
} |
||||||
|
|
||||||
|
public String getMlflowTrackingUri() { |
||||||
|
return mlflowTrackingUri; |
||||||
|
} |
||||||
|
|
||||||
|
public void setMlflowJobType(String mlflowJobType) { |
||||||
|
this.mlflowJobType = mlflowJobType; |
||||||
|
} |
||||||
|
|
||||||
|
public String getMlflowJobType() { |
||||||
|
return mlflowJobType; |
||||||
|
} |
||||||
|
|
||||||
|
public void setAutomlTool(String automlTool) { |
||||||
|
this.automlTool = automlTool; |
||||||
|
} |
||||||
|
|
||||||
|
public String getAutomlTool() { |
||||||
|
return automlTool; |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public boolean checkParameters() { |
||||||
|
|
||||||
|
Boolean checkResult = experimentName != null && mlflowTrackingUri != null; |
||||||
|
if (mlflowJobType.equals(MlflowConstants.JOB_TYPE_BASIC_ALGORITHM)) { |
||||||
|
checkResult &= dataPath != null; |
||||||
|
} else if (mlflowJobType.equals(MlflowConstants.JOB_TYPE_AUTOML)) { |
||||||
|
checkResult &= dataPath != null; |
||||||
|
checkResult &= automlTool != null; |
||||||
|
} else { |
||||||
|
} |
||||||
|
return checkResult; |
||||||
|
} |
||||||
|
|
||||||
|
public HashMap<String, String> getParamsMap() { |
||||||
|
|
||||||
|
HashMap<String, String> paramsMap = new HashMap<String, String>(); |
||||||
|
paramsMap.put("params", params); |
||||||
|
paramsMap.put("data_path", dataPath); |
||||||
|
paramsMap.put("experiment_name", experimentName); |
||||||
|
paramsMap.put("model_name", modelName); |
||||||
|
paramsMap.put("MLFLOW_TRACKING_URI", mlflowTrackingUri); |
||||||
|
if (mlflowJobType.equals(MlflowConstants.JOB_TYPE_BASIC_ALGORITHM)) { |
||||||
|
addParamsMapForBasicAlgorithm(paramsMap); |
||||||
|
} else if (mlflowJobType.equals(MlflowConstants.JOB_TYPE_AUTOML)) { |
||||||
|
getParamsMapForAutoML(paramsMap); |
||||||
|
} else { |
||||||
|
} |
||||||
|
return paramsMap; |
||||||
|
} |
||||||
|
|
||||||
|
private void addParamsMapForBasicAlgorithm(HashMap<String, String> paramsMap) { |
||||||
|
paramsMap.put("algorithm", algorithm); |
||||||
|
paramsMap.put("search_params", searchParams); |
||||||
|
paramsMap.put("repo", MlflowConstants.PRESET_BASIC_ALGORITHM_PROJECT); |
||||||
|
paramsMap.put("repo_version", MlflowConstants.PRESET_REPOSITORY_VERSION); |
||||||
|
} |
||||||
|
|
||||||
|
private void getParamsMapForAutoML(HashMap<String, String> paramsMap) { |
||||||
|
paramsMap.put("automl_tool", automlTool); |
||||||
|
paramsMap.put("repo", MlflowConstants.PRESET_AUTOML_PROJECT); |
||||||
|
paramsMap.put("repo_version", MlflowConstants.PRESET_REPOSITORY_VERSION); |
||||||
|
} |
||||||
|
|
||||||
|
public String getScriptPath() { |
||||||
|
String projectScript; |
||||||
|
if (mlflowJobType.equals(MlflowConstants.JOB_TYPE_BASIC_ALGORITHM)) { |
||||||
|
projectScript = MlflowConstants.RUN_PROJECT_BASIC_ALGORITHM_SCRIPT; |
||||||
|
} else if (mlflowJobType.equals(MlflowConstants.JOB_TYPE_AUTOML)) { |
||||||
|
projectScript = MlflowConstants.RUN_PROJECT_AUTOML_SCRIPT; |
||||||
|
} else { |
||||||
|
throw new IllegalArgumentException(); |
||||||
|
} |
||||||
|
String scriptPath = MlflowTask.class.getClassLoader().getResource(projectScript).getPath(); |
||||||
|
return scriptPath; |
||||||
|
} |
||||||
|
|
||||||
|
}; |
@ -0,0 +1,139 @@ |
|||||||
|
/* |
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
* contributor license agreements. See the NOTICE file distributed with |
||||||
|
* this work for additional information regarding copyright ownership. |
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
* (the "License"); you may not use this file except in compliance with |
||||||
|
* the License. You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
|
||||||
|
package org.apache.dolphinscheduler.plugin.task.mlflow; |
||||||
|
|
||||||
|
import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_FAILURE; |
||||||
|
|
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.AbstractTaskExecutor; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.ShellCommandExecutor; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.model.TaskResponse; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.parameters.AbstractParameters; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.parser.ParameterUtils; |
||||||
|
import org.apache.dolphinscheduler.spi.utils.JSONUtils; |
||||||
|
|
||||||
|
import java.nio.file.Files; |
||||||
|
import java.nio.file.Path; |
||||||
|
|
||||||
|
import java.io.*; |
||||||
|
import java.nio.file.Paths; |
||||||
|
|
||||||
|
/** |
||||||
|
* shell task |
||||||
|
*/ |
||||||
|
public class MlflowTask extends AbstractTaskExecutor { |
||||||
|
|
||||||
|
/** |
||||||
|
* shell parameters |
||||||
|
*/ |
||||||
|
private MlflowParameters mlflowParameters; |
||||||
|
|
||||||
|
/** |
||||||
|
* shell command executor |
||||||
|
*/ |
||||||
|
private ShellCommandExecutor shellCommandExecutor; |
||||||
|
|
||||||
|
/** |
||||||
|
* taskExecutionContext |
||||||
|
*/ |
||||||
|
private TaskExecutionContext taskExecutionContext; |
||||||
|
|
||||||
|
/** |
||||||
|
* constructor |
||||||
|
* |
||||||
|
* @param taskExecutionContext taskExecutionContext |
||||||
|
*/ |
||||||
|
public MlflowTask(TaskExecutionContext taskExecutionContext) { |
||||||
|
super(taskExecutionContext); |
||||||
|
|
||||||
|
this.taskExecutionContext = taskExecutionContext; |
||||||
|
this.shellCommandExecutor = new ShellCommandExecutor(this::logHandle, |
||||||
|
taskExecutionContext, |
||||||
|
logger); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void init() { |
||||||
|
logger.info("shell task params {}", taskExecutionContext.getTaskParams()); |
||||||
|
|
||||||
|
mlflowParameters = JSONUtils.parseObject(taskExecutionContext.getTaskParams(), MlflowParameters.class); |
||||||
|
|
||||||
|
if (!mlflowParameters.checkParameters()) { |
||||||
|
throw new RuntimeException("shell task params is not valid"); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void handle() throws Exception { |
||||||
|
try { |
||||||
|
// construct process
|
||||||
|
String command = buildCommand(); |
||||||
|
TaskResponse commandExecuteResult = shellCommandExecutor.run(command); |
||||||
|
setExitStatusCode(commandExecuteResult.getExitStatusCode()); |
||||||
|
setAppIds(commandExecuteResult.getAppIds()); |
||||||
|
setProcessId(commandExecuteResult.getProcessId()); |
||||||
|
mlflowParameters.dealOutParam(shellCommandExecutor.getVarPool()); |
||||||
|
} catch (Exception e) { |
||||||
|
logger.error("shell task error", e); |
||||||
|
setExitStatusCode(EXIT_CODE_FAILURE); |
||||||
|
throw e; |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void cancelApplication(boolean cancelApplication) throws Exception { |
||||||
|
// cancel process
|
||||||
|
shellCommandExecutor.cancelApplication(); |
||||||
|
} |
||||||
|
|
||||||
|
/** |
||||||
|
* create command |
||||||
|
* |
||||||
|
* @return file name |
||||||
|
* @throws Exception exception |
||||||
|
*/ |
||||||
|
private String buildCommand() throws Exception { |
||||||
|
|
||||||
|
/** |
||||||
|
* load script template from resource folder |
||||||
|
*/ |
||||||
|
String script = loadRunScript(mlflowParameters.getScriptPath()); |
||||||
|
script = parseScript(script); |
||||||
|
|
||||||
|
logger.info("raw script : \n{}", script); |
||||||
|
logger.info("task execute path : {}", taskExecutionContext.getExecutePath()); |
||||||
|
|
||||||
|
return script; |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public AbstractParameters getParameters() { |
||||||
|
return mlflowParameters; |
||||||
|
} |
||||||
|
|
||||||
|
private String parseScript(String script) { |
||||||
|
return ParameterUtils.convertParameterPlaceholders(script, mlflowParameters.getParamsMap()); |
||||||
|
} |
||||||
|
|
||||||
|
public static String loadRunScript(String scriptPath) throws IOException { |
||||||
|
Path path = Paths.get(scriptPath); |
||||||
|
byte[] data = Files.readAllBytes(path); |
||||||
|
String result = new String(data); |
||||||
|
return result; |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,50 @@ |
|||||||
|
/* |
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
* contributor license agreements. See the NOTICE file distributed with |
||||||
|
* this work for additional information regarding copyright ownership. |
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
* (the "License"); you may not use this file except in compliance with |
||||||
|
* the License. You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
|
||||||
|
package org.apache.dolphinscheduler.plugin.task.mlflow; |
||||||
|
|
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.TaskChannel; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.parameters.AbstractParameters; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.parameters.ParametersNode; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.parameters.resource.ResourceParametersHelper; |
||||||
|
import org.apache.dolphinscheduler.spi.utils.JSONUtils; |
||||||
|
|
||||||
|
|
||||||
|
public class MlflowTaskChannel implements TaskChannel { |
||||||
|
|
||||||
|
@Override |
||||||
|
public void cancelApplication(boolean status) { |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public MlflowTask createTask(TaskExecutionContext taskRequest) { |
||||||
|
return new MlflowTask(taskRequest); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public AbstractParameters parseParameters(ParametersNode parametersNode) { |
||||||
|
return JSONUtils.parseObject(parametersNode.getTaskParams(), MlflowParameters.class); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public ResourceParametersHelper getResources(String parameters) { |
||||||
|
return null; |
||||||
|
} |
||||||
|
|
||||||
|
} |
@ -0,0 +1,57 @@ |
|||||||
|
/* |
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
* contributor license agreements. See the NOTICE file distributed with |
||||||
|
* this work for additional information regarding copyright ownership. |
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
* (the "License"); you may not use this file except in compliance with |
||||||
|
* the License. You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
|
||||||
|
package org.apache.dolphinscheduler.plugin.task.mlflow; |
||||||
|
|
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.TaskChannel; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.TaskChannelFactory; |
||||||
|
import org.apache.dolphinscheduler.spi.params.base.ParamsOptions; |
||||||
|
import org.apache.dolphinscheduler.spi.params.base.PluginParams; |
||||||
|
import org.apache.dolphinscheduler.spi.params.base.Validate; |
||||||
|
import org.apache.dolphinscheduler.spi.params.input.InputParam; |
||||||
|
import org.apache.dolphinscheduler.spi.params.radio.RadioParam; |
||||||
|
|
||||||
|
import java.util.ArrayList; |
||||||
|
import java.util.List; |
||||||
|
|
||||||
|
import com.google.auto.service.AutoService; |
||||||
|
|
||||||
|
@AutoService(TaskChannelFactory.class) |
||||||
|
public class MlflowTaskChannelFactory implements TaskChannelFactory { |
||||||
|
@Override |
||||||
|
public TaskChannel create() { |
||||||
|
return new MlflowTaskChannel(); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public String getName() { |
||||||
|
return "MLFLOW"; |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public List<PluginParams> getParams() { |
||||||
|
List<PluginParams> paramsList = new ArrayList<>(); |
||||||
|
|
||||||
|
InputParam nodeName = InputParam.newBuilder("name", "$t('Node name')").addValidate(Validate.newBuilder().setRequired(true).build()).build(); |
||||||
|
|
||||||
|
RadioParam runFlag = RadioParam.newBuilder("runFlag", "RUN_FLAG").addParamsOptions(new ParamsOptions("NORMAL", "NORMAL", false)).addParamsOptions(new ParamsOptions("FORBIDDEN", "FORBIDDEN", false)).build(); |
||||||
|
|
||||||
|
paramsList.add(nodeName); |
||||||
|
paramsList.add(runFlag); |
||||||
|
return paramsList; |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,25 @@ |
|||||||
|
#!/bin/bash |
||||||
|
# |
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
# contributor license agreements. See the NOTICE file distributed with |
||||||
|
# this work for additional information regarding copyright ownership. |
||||||
|
# The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
# (the "License"); you may not use this file except in compliance with |
||||||
|
# the License. You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
# |
||||||
|
|
||||||
|
data_path=${data_path} |
||||||
|
export MLFLOW_TRACKING_URI=${MLFLOW_TRACKING_URI} |
||||||
|
echo $data_path |
||||||
|
repo=${repo} |
||||||
|
mlflow run $repo -P tool=${automl_tool} -P data_path=$data_path -P params="${params}" -P model_name="${model_name}" --experiment-name="${experiment_name}" --version="${repo_version}" |
||||||
|
|
||||||
|
echo "training finish" |
@ -0,0 +1,25 @@ |
|||||||
|
#!/bin/bash |
||||||
|
# |
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
# contributor license agreements. See the NOTICE file distributed with |
||||||
|
# this work for additional information regarding copyright ownership. |
||||||
|
# The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
# (the "License"); you may not use this file except in compliance with |
||||||
|
# the License. You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
# |
||||||
|
|
||||||
|
data_path=${data_path} |
||||||
|
export MLFLOW_TRACKING_URI=${MLFLOW_TRACKING_URI} |
||||||
|
echo $data_path |
||||||
|
repo=${repo} |
||||||
|
mlflow run $repo -P algorithm=${algorithm} -P data_path=$data_path -P params="${params}" -P search_params="${search_params}" -P model_name="${model_name}" --experiment-name="${experiment_name}" --version="${repo_version}" |
||||||
|
|
||||||
|
echo "training finish" |
@ -0,0 +1,129 @@ |
|||||||
|
/* |
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
* contributor license agreements. See the NOTICE file distributed with |
||||||
|
* this work for additional information regarding copyright ownership. |
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
* (the "License"); you may not use this file except in compliance with |
||||||
|
* the License. You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
|
||||||
|
package org.apache.dolphinler.plugin.task.mlflow; |
||||||
|
|
||||||
|
import java.util.Date; |
||||||
|
import java.util.UUID; |
||||||
|
|
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContextCacheManager; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowParameters; |
||||||
|
import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowTask; |
||||||
|
import org.junit.Assert; |
||||||
|
import org.junit.Before; |
||||||
|
import org.junit.Test; |
||||||
|
import org.mockito.Mockito; |
||||||
|
import org.powermock.api.mockito.PowerMockito; |
||||||
|
import org.slf4j.Logger; |
||||||
|
import org.slf4j.LoggerFactory; |
||||||
|
import org.apache.dolphinscheduler.spi.utils.PropertyUtils; |
||||||
|
|
||||||
|
import org.powermock.core.classloader.annotations.PrepareForTest; |
||||||
|
import org.powermock.core.classloader.annotations.PowerMockIgnore; |
||||||
|
import org.junit.runner.RunWith; |
||||||
|
import org.powermock.modules.junit4.PowerMockRunner; |
||||||
|
import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; |
||||||
|
import org.apache.dolphinscheduler.spi.utils.JSONUtils; |
||||||
|
|
||||||
|
|
||||||
|
@RunWith(PowerMockRunner.class) |
||||||
|
@PrepareForTest({ |
||||||
|
JSONUtils.class, |
||||||
|
PropertyUtils.class, |
||||||
|
}) |
||||||
|
@PowerMockIgnore({"javax.*"}) |
||||||
|
@SuppressStaticInitializationFor("org.apache.dolphinscheduler.spi.utils.PropertyUtils") |
||||||
|
public class MlflowTaskTest { |
||||||
|
private static final Logger logger = LoggerFactory.getLogger(MlflowTask.class); |
||||||
|
|
||||||
|
@Before |
||||||
|
public void before() throws Exception { |
||||||
|
PowerMockito.mockStatic(PropertyUtils.class); |
||||||
|
} |
||||||
|
|
||||||
|
public TaskExecutionContext createContext(MlflowParameters mlflowParameters){ |
||||||
|
String parameters = JSONUtils.toJsonString(mlflowParameters); |
||||||
|
TaskExecutionContext taskExecutionContext = Mockito.mock(TaskExecutionContext.class); |
||||||
|
Mockito.when(taskExecutionContext.getTaskParams()).thenReturn(parameters); |
||||||
|
Mockito.when(taskExecutionContext.getTaskLogName()).thenReturn("MLflowTest"); |
||||||
|
Mockito.when(taskExecutionContext.getExecutePath()).thenReturn("/tmp/dolphinscheduler_test"); |
||||||
|
Mockito.when(taskExecutionContext.getTaskAppId()).thenReturn(UUID.randomUUID().toString()); |
||||||
|
Mockito.when(taskExecutionContext.getTenantCode()).thenReturn("root"); |
||||||
|
Mockito.when(taskExecutionContext.getStartTime()).thenReturn(new Date()); |
||||||
|
Mockito.when(taskExecutionContext.getTaskTimeout()).thenReturn(10000); |
||||||
|
Mockito.when(taskExecutionContext.getLogPath()).thenReturn("/tmp/dolphinscheduler_test/log"); |
||||||
|
Mockito.when(taskExecutionContext.getEnvironmentConfig()).thenReturn("export PATH=$HOME/anaconda3/bin:$PATH"); |
||||||
|
|
||||||
|
String userName = System.getenv().get("USER"); |
||||||
|
Mockito.when(taskExecutionContext.getTenantCode()).thenReturn(userName); |
||||||
|
|
||||||
|
TaskExecutionContextCacheManager.cacheTaskExecutionContext(taskExecutionContext); |
||||||
|
return taskExecutionContext; |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void testInitBasicAlgorithmTask() |
||||||
|
throws Exception { |
||||||
|
try { |
||||||
|
MlflowParameters mlflowParameters = createBasicAlgorithmParameters(); |
||||||
|
TaskExecutionContext taskExecutionContext = createContext(mlflowParameters); |
||||||
|
MlflowTask mlflowTask = new MlflowTask(taskExecutionContext); |
||||||
|
mlflowTask.init(); |
||||||
|
mlflowTask.getParameters().setVarPool(taskExecutionContext.getVarPool()); |
||||||
|
} catch (Exception e) { |
||||||
|
Assert.fail(e.getMessage()); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void testInitAutoMLTask() |
||||||
|
throws Exception { |
||||||
|
try { |
||||||
|
MlflowParameters mlflowParameters = createAutoMLParameters(); |
||||||
|
TaskExecutionContext taskExecutionContext = createContext(mlflowParameters); |
||||||
|
MlflowTask mlflowTask = new MlflowTask(taskExecutionContext); |
||||||
|
mlflowTask.init(); |
||||||
|
mlflowTask.getParameters().setVarPool(taskExecutionContext.getVarPool()); |
||||||
|
} catch (Exception e) { |
||||||
|
Assert.fail(e.getMessage()); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
private MlflowParameters createBasicAlgorithmParameters() { |
||||||
|
MlflowParameters mlflowParameters = new MlflowParameters(); |
||||||
|
mlflowParameters.setMlflowJobType("BasicAlgorithm"); |
||||||
|
mlflowParameters.setAlgorithm("xgboost"); |
||||||
|
mlflowParameters.setDataPaths("xxxxxxxxxx"); |
||||||
|
mlflowParameters.setExperimentNames("asbbb"); |
||||||
|
mlflowParameters.setMlflowTrackingUris("http://127.0.0.1:5000"); |
||||||
|
return mlflowParameters; |
||||||
|
} |
||||||
|
|
||||||
|
private MlflowParameters createAutoMLParameters() { |
||||||
|
MlflowParameters mlflowParameters = new MlflowParameters(); |
||||||
|
mlflowParameters.setMlflowJobType("AutoML"); |
||||||
|
mlflowParameters.setAutomlTool("autosklearn"); |
||||||
|
mlflowParameters.setParams("time_left_for_this_task=30"); |
||||||
|
mlflowParameters.setDataPaths("xxxxxxxxxxx"); |
||||||
|
mlflowParameters.setExperimentNames("asbbb"); |
||||||
|
mlflowParameters.setModelNames("asbbb"); |
||||||
|
mlflowParameters.setMlflowTrackingUris("http://127.0.0.1:5000"); |
||||||
|
return mlflowParameters; |
||||||
|
} |
||||||
|
|
||||||
|
} |
After Width: | Height: | Size: 30 KiB |
After Width: | Height: | Size: 111 KiB |
@ -0,0 +1,184 @@ |
|||||||
|
/* |
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
* contributor license agreements. See the NOTICE file distributed with |
||||||
|
* this work for additional information regarding copyright ownership. |
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
* (the "License"); you may not use this file except in compliance with |
||||||
|
* the License. You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
import { useI18n } from 'vue-i18n' |
||||||
|
import { useCustomParams } from '.' |
||||||
|
import type { IJsonItem } from '../types' |
||||||
|
import { computed } from 'vue' |
||||||
|
|
||||||
|
export const MLFLOW_JOB_TYPE = [ |
||||||
|
{ |
||||||
|
label: 'BasicAlgorithm', |
||||||
|
value: 'BasicAlgorithm' |
||||||
|
}, |
||||||
|
{ |
||||||
|
label: 'AutoML', |
||||||
|
value: 'AutoML' |
||||||
|
} |
||||||
|
] |
||||||
|
export const ALGORITHM = [ |
||||||
|
{ |
||||||
|
label: 'svm', |
||||||
|
value: 'svm' |
||||||
|
}, |
||||||
|
{ |
||||||
|
label: 'lr', |
||||||
|
value: 'lr' |
||||||
|
}, |
||||||
|
{ |
||||||
|
label: 'lightgbm', |
||||||
|
value: 'lightgbm' |
||||||
|
}, |
||||||
|
{ |
||||||
|
label: 'xgboost', |
||||||
|
value: 'xgboost' |
||||||
|
} |
||||||
|
] |
||||||
|
export const AutoMLTOOL = [ |
||||||
|
{ |
||||||
|
label: 'autosklearn', |
||||||
|
value: 'autosklearn' |
||||||
|
}, |
||||||
|
{ |
||||||
|
label: 'flaml', |
||||||
|
value: 'flaml' |
||||||
|
} |
||||||
|
] |
||||||
|
|
||||||
|
export function useMlflow(model: { [field: string]: any }): IJsonItem[] { |
||||||
|
const { t } = useI18n() |
||||||
|
const registerModelSpan = computed(() => (model.registerModel ? 12 : 24)) |
||||||
|
const modelNameSpan = computed(() => (model.registerModel ? 12 : 0)) |
||||||
|
const algorithmSpan = computed(() => |
||||||
|
model.mlflowJobType === 'BasicAlgorithm' ? 12 : 0 |
||||||
|
) |
||||||
|
const automlToolSpan = computed(() => |
||||||
|
model.mlflowJobType === 'AutoML' ? 12 : 0 |
||||||
|
) |
||||||
|
const searchParamsSpan = computed(() => |
||||||
|
model.mlflowJobType === 'BasicAlgorithm' ? 24 : 0 |
||||||
|
) |
||||||
|
|
||||||
|
return [ |
||||||
|
{ |
||||||
|
type: 'input', |
||||||
|
field: 'mlflowTrackingUri', |
||||||
|
name: t('project.node.mlflow_mlflowTrackingUri'), |
||||||
|
span: 12, |
||||||
|
props: { |
||||||
|
placeholder: t('project.node.mlflow_mlflowTrackingUri_tips') |
||||||
|
}, |
||||||
|
validate: { |
||||||
|
trigger: ['input', 'blur'], |
||||||
|
required: false, |
||||||
|
validator(validate: any, value: string) { |
||||||
|
if (!value) { |
||||||
|
return new Error( |
||||||
|
t('project.node.mlflow_mlflowTrackingUri_error_tips') |
||||||
|
) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
}, |
||||||
|
{ |
||||||
|
type: 'input', |
||||||
|
field: 'experimentName', |
||||||
|
name: t('project.node.mlflow_experimentName'), |
||||||
|
span: 12, |
||||||
|
props: { |
||||||
|
placeholder: t('project.node.mlflow_experimentName_tips') |
||||||
|
}, |
||||||
|
validate: { |
||||||
|
trigger: ['input', 'blur'], |
||||||
|
required: false |
||||||
|
} |
||||||
|
}, |
||||||
|
{ |
||||||
|
type: 'switch', |
||||||
|
field: 'registerModel', |
||||||
|
name: t('project.node.mlflow_registerModel'), |
||||||
|
span: registerModelSpan |
||||||
|
}, |
||||||
|
{ |
||||||
|
type: 'input', |
||||||
|
field: 'modelName', |
||||||
|
name: t('project.node.mlflow_modelName'), |
||||||
|
span: modelNameSpan, |
||||||
|
props: { |
||||||
|
placeholder: t('project.node.mlflow_modelName_tips') |
||||||
|
}, |
||||||
|
validate: { |
||||||
|
trigger: ['input', 'blur'], |
||||||
|
required: false |
||||||
|
} |
||||||
|
}, |
||||||
|
{ |
||||||
|
type: 'select', |
||||||
|
field: 'mlflowJobType', |
||||||
|
name: t('project.node.mlflow_jobType'), |
||||||
|
span: 12, |
||||||
|
options: MLFLOW_JOB_TYPE |
||||||
|
}, |
||||||
|
{ |
||||||
|
type: 'select', |
||||||
|
field: 'algorithm', |
||||||
|
name: t('project.node.mlflow_algorithm'), |
||||||
|
span: algorithmSpan, |
||||||
|
options: ALGORITHM |
||||||
|
}, |
||||||
|
{ |
||||||
|
type: 'select', |
||||||
|
field: 'automlTool', |
||||||
|
name: t('project.node.mlflow_automlTool'), |
||||||
|
span: automlToolSpan, |
||||||
|
options: AutoMLTOOL |
||||||
|
}, |
||||||
|
{ |
||||||
|
type: 'input', |
||||||
|
field: 'dataPath', |
||||||
|
name: t('project.node.mlflow_dataPath'), |
||||||
|
props: { |
||||||
|
placeholder: t('project.node.mlflow_dataPath_tips') |
||||||
|
} |
||||||
|
}, |
||||||
|
{ |
||||||
|
type: 'input', |
||||||
|
field: 'params', |
||||||
|
name: t('project.node.mlflow_params'), |
||||||
|
props: { |
||||||
|
placeholder: t('project.node.mlflow_params_tips') |
||||||
|
}, |
||||||
|
validate: { |
||||||
|
trigger: ['input', 'blur'], |
||||||
|
required: false |
||||||
|
} |
||||||
|
}, |
||||||
|
{ |
||||||
|
type: 'input', |
||||||
|
field: 'searchParams', |
||||||
|
name: t('project.node.mlflow_searchParams'), |
||||||
|
props: { |
||||||
|
placeholder: t('project.node.mlflow_searchParams_tips') |
||||||
|
}, |
||||||
|
span: searchParamsSpan, |
||||||
|
validate: { |
||||||
|
trigger: ['input', 'blur'], |
||||||
|
required: false |
||||||
|
} |
||||||
|
}, |
||||||
|
...useCustomParams({ model, field: 'localParams', isSimple: false }) |
||||||
|
] |
||||||
|
} |
@ -0,0 +1,84 @@ |
|||||||
|
/* |
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
* contributor license agreements. See the NOTICE file distributed with |
||||||
|
* this work for additional information regarding copyright ownership. |
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
* (the "License"); you may not use this file except in compliance with |
||||||
|
* the License. You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
|
||||||
|
import { reactive } from 'vue' |
||||||
|
import * as Fields from '../fields/index' |
||||||
|
import type { IJsonItem, INodeData, ITaskData } from '../types' |
||||||
|
|
||||||
|
export function useMlflow({ |
||||||
|
projectCode, |
||||||
|
from = 0, |
||||||
|
readonly, |
||||||
|
data |
||||||
|
}: { |
||||||
|
projectCode: number |
||||||
|
from?: number |
||||||
|
readonly?: boolean |
||||||
|
data?: ITaskData |
||||||
|
}) { |
||||||
|
const model = reactive({ |
||||||
|
name: '', |
||||||
|
taskType: 'MLFLOW', |
||||||
|
flag: 'YES', |
||||||
|
description: '', |
||||||
|
timeoutFlag: false, |
||||||
|
localParams: [], |
||||||
|
environmentCode: null, |
||||||
|
failRetryInterval: 1, |
||||||
|
failRetryTimes: 0, |
||||||
|
workerGroup: 'default', |
||||||
|
mlflowTrackingUri: 'http://127.0.0.1:5000', |
||||||
|
algorithm: 'svm', |
||||||
|
mlflowJobType: 'AutoML', |
||||||
|
automlTool: 'flaml', |
||||||
|
delayTime: 0, |
||||||
|
timeout: 30 |
||||||
|
} as INodeData) |
||||||
|
|
||||||
|
let extra: IJsonItem[] = [] |
||||||
|
if (from === 1) { |
||||||
|
extra = [ |
||||||
|
Fields.useTaskType(model, readonly), |
||||||
|
Fields.useProcessName({ |
||||||
|
model, |
||||||
|
projectCode, |
||||||
|
isCreate: !data?.id, |
||||||
|
from, |
||||||
|
processName: data?.processName |
||||||
|
}) |
||||||
|
] |
||||||
|
} |
||||||
|
|
||||||
|
return { |
||||||
|
json: [ |
||||||
|
Fields.useName(from), |
||||||
|
...extra, |
||||||
|
Fields.useRunFlag(), |
||||||
|
Fields.useDescription(), |
||||||
|
Fields.useTaskPriority(), |
||||||
|
Fields.useWorkerGroup(), |
||||||
|
Fields.useEnvironmentName(model, !model.id), |
||||||
|
...Fields.useTaskGroup(model, projectCode), |
||||||
|
...Fields.useFailed(), |
||||||
|
Fields.useDelayTime(model), |
||||||
|
...Fields.useTimeoutAlarm(model), |
||||||
|
...Fields.useMlflow(model), |
||||||
|
Fields.usePreTasks() |
||||||
|
] as IJsonItem[], |
||||||
|
model |
||||||
|
} |
||||||
|
} |