diff --git a/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/main/java/org/apache/dolphinscheduler/plugin/task/mlflow/MlflowConstants.java b/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/main/java/org/apache/dolphinscheduler/plugin/task/mlflow/MlflowConstants.java index 8712081ad7..c2701fb543 100644 --- a/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/main/java/org/apache/dolphinscheduler/plugin/task/mlflow/MlflowConstants.java +++ b/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/main/java/org/apache/dolphinscheduler/plugin/task/mlflow/MlflowConstants.java @@ -30,11 +30,13 @@ public class MlflowConstants { public static final String PRESET_REPOSITORY = "https://github.com/apache/dolphinscheduler-mlflow"; + public static final String PRESET_PATH = "dolphinscheduler-mlflow"; + public static final String PRESET_REPOSITORY_VERSION = "main"; - public static final String PRESET_AUTOML_PROJECT = PRESET_REPOSITORY + "#Project-AutoML"; + public static final String PRESET_AUTOML_PROJECT = PRESET_PATH + "#Project-AutoML"; - public static final String PRESET_BASIC_ALGORITHM_PROJECT = PRESET_REPOSITORY + "#Project-BasicAlgorithm"; + public static final String PRESET_BASIC_ALGORITHM_PROJECT = PRESET_PATH + "#Project-BasicAlgorithm"; public static final String MLFLOW_TASK_TYPE_PROJECTS = "MLflow Projects"; @@ -62,27 +64,25 @@ public class MlflowConstants { public static final String SET_REPOSITORY = "repo=%s"; - public static final String MLFLOW_RUN_BASIC_ALGORITHM = "mlflow run $repo " + - "-P algorithm=%s " + - "-P data_path=$data_path " + - "-P params=\"%s\" " + - "-P search_params=\"%s\" " + - "-P model_name=\"%s\" " + - "--experiment-name=\"%s\" " + - "--version=main "; - - public static final String MLFLOW_RUN_AUTOML_PROJECT = "mlflow run $repo " + - "-P tool=%s " + - "-P data_path=$data_path " + - "-P params=\"%s\" " + - "-P model_name=\"%s\" " + - "--experiment-name=\"%s\" " + - "--version=main "; - - public static final String MLFLOW_RUN_CUSTOM_PROJECT = "mlflow run $repo " + - "%s " + - "--experiment-name=\"%s\" " + - "--version=\"%s\" "; + public static final String MLFLOW_RUN_BASIC_ALGORITHM = "mlflow run $repo " + + "-P algorithm=%s " + + "-P data_path=$data_path " + + "-P params=\"%s\" " + + "-P search_params=\"%s\" " + + "-P model_name=\"%s\" " + + "--experiment-name=\"%s\""; + + public static final String MLFLOW_RUN_AUTOML_PROJECT = "mlflow run $repo " + + "-P tool=%s " + + "-P data_path=$data_path " + + "-P params=\"%s\" " + + "-P model_name=\"%s\" " + + "--experiment-name=\"%s\""; + + public static final String MLFLOW_RUN_CUSTOM_PROJECT = "mlflow run $repo " + + "%s " + + "--experiment-name=\"%s\" " + + "--version=\"%s\" "; public static final String MLFLOW_MODELS_SERVE = "mlflow models serve -m %s --port %s -h 0.0.0.0"; @@ -90,17 +90,17 @@ public class MlflowConstants { public static final String DOCKER_RREMOVE_CONTAINER = "docker rm -f %s"; - public static final String DOCKER_RUN = "docker run -d --name=%s -p=%s:8080 " + - "--health-cmd \"curl --fail http://127.0.0.1:8080/ping || exit 1\" --health-interval 5s --health-retries 20" + - " %s"; + public static final String DOCKER_RUN = "docker run -d --name=%s -p=%s:8080 " + + "--health-cmd \"curl --fail http://127.0.0.1:8080/ping || exit 1\" --health-interval 5s --health-retries 20" + + " %s"; public static final String DOCKER_COMPOSE_RUN = "docker-compose up -d"; - public static final String SET_DOCKER_COMPOSE_ENV = "export DS_TASK_MLFLOW_IMAGE_NAME=%s\n" + - "export DS_TASK_MLFLOW_CONTAINER_NAME=%s\n" + - "export DS_TASK_MLFLOW_DEPLOY_PORT=%s\n" + - "export DS_TASK_MLFLOW_CPU_LIMIT=%s\n" + - "export DS_TASK_MLFLOW_MEMORY_LIMIT=%s"; + public static final String SET_DOCKER_COMPOSE_ENV = "export DS_TASK_MLFLOW_IMAGE_NAME=%s\n" + + "export DS_TASK_MLFLOW_CONTAINER_NAME=%s\n" + + "export DS_TASK_MLFLOW_DEPLOY_PORT=%s\n" + + "export DS_TASK_MLFLOW_CPU_LIMIT=%s\n" + + "export DS_TASK_MLFLOW_MEMORY_LIMIT=%s"; public static final String DOCKER_HEALTH_CHECK = "docker inspect --format \"{{json .State.Health.Status }}\" %s"; @@ -108,4 +108,6 @@ public class MlflowConstants { public static final int DOCKER_HEALTH_CHECK_TIMEOUT = 20; public static final int DOCKER_HEALTH_CHECK_INTERVAL = 5000; + + public static final String GIT_CLONE_REPO = "git clone %s %s"; } diff --git a/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/main/java/org/apache/dolphinscheduler/plugin/task/mlflow/MlflowTask.java b/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/main/java/org/apache/dolphinscheduler/plugin/task/mlflow/MlflowTask.java index 6c87162354..0f3df4f357 100644 --- a/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/main/java/org/apache/dolphinscheduler/plugin/task/mlflow/MlflowTask.java +++ b/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/main/java/org/apache/dolphinscheduler/plugin/task/mlflow/MlflowTask.java @@ -19,6 +19,7 @@ package org.apache.dolphinscheduler.plugin.task.mlflow; import static org.apache.dolphinscheduler.plugin.task.api.TaskConstants.EXIT_CODE_FAILURE; +import org.apache.dolphinscheduler.common.thread.ThreadUtils; import org.apache.dolphinscheduler.plugin.task.api.AbstractTaskExecutor; import org.apache.dolphinscheduler.plugin.task.api.ShellCommandExecutor; import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; @@ -26,14 +27,11 @@ import org.apache.dolphinscheduler.plugin.task.api.model.Property; import org.apache.dolphinscheduler.plugin.task.api.model.TaskResponse; import org.apache.dolphinscheduler.plugin.task.api.parameters.AbstractParameters; import org.apache.dolphinscheduler.plugin.task.api.parser.ParamUtils; -import org.apache.dolphinscheduler.plugin.task.api.utils.MapUtils; -import org.apache.dolphinscheduler.plugin.task.api.utils.OSUtils; import org.apache.dolphinscheduler.plugin.task.api.parser.ParameterUtils; +import org.apache.dolphinscheduler.plugin.task.api.utils.OSUtils; import org.apache.dolphinscheduler.spi.utils.JSONUtils; -import org.apache.dolphinscheduler.common.thread.ThreadUtils; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -50,12 +48,12 @@ public class MlflowTask extends AbstractTaskExecutor { /** * shell command executor */ - private ShellCommandExecutor shellCommandExecutor; + private final ShellCommandExecutor shellCommandExecutor; /** * taskExecutionContext */ - private TaskExecutionContext taskExecutionContext; + private final TaskExecutionContext taskExecutionContext; /** * constructor @@ -87,9 +85,9 @@ public class MlflowTask extends AbstractTaskExecutor { String command = buildCommand(); TaskResponse commandExecuteResult = shellCommandExecutor.run(command); int exitCode; - if (mlflowParameters.getIsDeployDocker()){ + if (mlflowParameters.getIsDeployDocker()) { exitCode = checkDockerHealth(); - }else { + } else { exitCode = commandExecuteResult.getExitStatusCode(); } setExitStatusCode(exitCode); @@ -136,19 +134,20 @@ public class MlflowTask extends AbstractTaskExecutor { if (mlflowParameters.getMlflowJobType().equals(MlflowConstants.JOB_TYPE_BASIC_ALGORITHM)) { args.add(String.format(MlflowConstants.SET_DATA_PATH, mlflowParameters.getDataPath())); args.add(String.format(MlflowConstants.SET_REPOSITORY, MlflowConstants.PRESET_BASIC_ALGORITHM_PROJECT)); - + args.add(String.format(MlflowConstants.GIT_CLONE_REPO, MlflowConstants.PRESET_REPOSITORY, MlflowConstants.PRESET_PATH)); runCommand = MlflowConstants.MLFLOW_RUN_BASIC_ALGORITHM; - runCommand = String.format(runCommand, mlflowParameters.getAlgorithm(), mlflowParameters.getParams(), mlflowParameters.getSearchParams(), mlflowParameters.getModelName(), mlflowParameters.getExperimentName()); + runCommand = String.format(runCommand, mlflowParameters.getAlgorithm(), mlflowParameters.getParams(), mlflowParameters.getSearchParams(), mlflowParameters.getModelName(), + mlflowParameters.getExperimentName()); } else if (mlflowParameters.getMlflowJobType().equals(MlflowConstants.JOB_TYPE_AUTOML)) { args.add(String.format(MlflowConstants.SET_DATA_PATH, mlflowParameters.getDataPath())); args.add(String.format(MlflowConstants.SET_REPOSITORY, MlflowConstants.PRESET_AUTOML_PROJECT)); + args.add(String.format(MlflowConstants.GIT_CLONE_REPO, MlflowConstants.PRESET_REPOSITORY, MlflowConstants.PRESET_PATH)); runCommand = MlflowConstants.MLFLOW_RUN_AUTOML_PROJECT; runCommand = String.format(runCommand, mlflowParameters.getAutomlTool(), mlflowParameters.getParams(), mlflowParameters.getModelName(), mlflowParameters.getExperimentName()); - } else if (mlflowParameters.getMlflowJobType().equals(MlflowConstants.JOB_TYPE_CUSTOM_PROJECT)) { args.add(String.format(MlflowConstants.SET_REPOSITORY, mlflowParameters.getMlflowProjectRepository())); @@ -166,10 +165,9 @@ public class MlflowTask extends AbstractTaskExecutor { protected String buildCommandForMlflowModels() { /** - * papermill [OPTIONS] NOTEBOOK_PATH [OUTPUT_PATH] + * build mlflow models command */ - Map paramsMap = getParamsMap(); List args = new ArrayList<>(); args.add(String.format(MlflowConstants.EXPORT_MLFLOW_TRACKING_URI_ENV, mlflowParameters.getMlflowTrackingUri())); @@ -211,7 +209,7 @@ public class MlflowTask extends AbstractTaskExecutor { logger.info("checking container healthy ... "); int exitCode = -1; String[] command = {"sh", "-c", String.format(MlflowConstants.DOCKER_HEALTH_CHECK, mlflowParameters.getContainerName())}; - for(int x = 0; x < MlflowConstants.DOCKER_HEALTH_CHECK_TIMEOUT; x = x+1) { + for (int x = 0; x < MlflowConstants.DOCKER_HEALTH_CHECK_TIMEOUT; x = x + 1) { String status; try { status = OSUtils.exeShell(command).replace("\n", "").replace("\"", ""); @@ -224,7 +222,7 @@ public class MlflowTask extends AbstractTaskExecutor { exitCode = 0; logger.info("container is healthy"); return exitCode; - }else { + } else { logger.info("The health check has been running for {} seconds", x * MlflowConstants.DOCKER_HEALTH_CHECK_INTERVAL / 1000); ThreadUtils.sleep(MlflowConstants.DOCKER_HEALTH_CHECK_INTERVAL); } @@ -234,7 +232,6 @@ public class MlflowTask extends AbstractTaskExecutor { return exitCode; } - @Override public AbstractParameters getParameters() { return mlflowParameters; diff --git a/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/test/java/org/apache/dolphinler/plugin/task/mlflow/MlflowTaskTest.java b/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/test/java/org/apache/dolphinler/plugin/task/mlflow/MlflowTaskTest.java index f985666006..ea29d3b4bd 100644 --- a/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/test/java/org/apache/dolphinler/plugin/task/mlflow/MlflowTaskTest.java +++ b/dolphinscheduler-task-plugin/dolphinscheduler-task-mlflow/src/test/java/org/apache/dolphinler/plugin/task/mlflow/MlflowTaskTest.java @@ -17,35 +17,34 @@ package org.apache.dolphinler.plugin.task.mlflow; -import java.util.Date; -import java.util.UUID; - import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext; import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContextCacheManager; -import org.apache.dolphinscheduler.plugin.task.api.utils.OSUtils; import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowConstants; import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowParameters; import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowTask; +import org.apache.dolphinscheduler.spi.utils.JSONUtils; +import org.apache.dolphinscheduler.spi.utils.PropertyUtils; + +import java.util.Date; +import java.util.UUID; + import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.Mockito; import org.powermock.api.mockito.PowerMockito; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.apache.dolphinscheduler.spi.utils.PropertyUtils; - -import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PowerMockIgnore; -import org.junit.runner.RunWith; -import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; -import org.apache.dolphinscheduler.spi.utils.JSONUtils; +import org.powermock.modules.junit4.PowerMockRunner; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @RunWith(PowerMockRunner.class) @PrepareForTest({ - JSONUtils.class, - PropertyUtils.class, + JSONUtils.class, + PropertyUtils.class, }) @PowerMockIgnore({"javax.*"}) @SuppressStaticInitializationFor("org.apache.dolphinscheduler.spi.utils.PropertyUtils") @@ -81,84 +80,85 @@ public class MlflowTaskTest { public void testInitBasicAlgorithmTask() { MlflowTask mlflowTask = initTask(createBasicAlgorithmParameters()); Assert.assertEquals(mlflowTask.buildCommand(), - "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + - "data_path=/data/iris.csv\n" + - "repo=https://github.com/apache/dolphinscheduler-mlflow#Project-BasicAlgorithm\n" + - "mlflow run $repo " + - "-P algorithm=xgboost " + - "-P data_path=$data_path " + - "-P params=\"n_estimators=100\" " + - "-P search_params=\"\" " + - "-P model_name=\"BasicAlgorithm\" " + - "--experiment-name=\"BasicAlgorithm\" " + - "--version=main "); + "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + + "data_path=/data/iris.csv\n" + + "repo=dolphinscheduler-mlflow#Project-BasicAlgorithm\n" + + "git clone https://github.com/apache/dolphinscheduler-mlflow dolphinscheduler-mlflow\n" + + "mlflow run $repo " + + "-P algorithm=xgboost " + + "-P data_path=$data_path " + + "-P params=\"n_estimators=100\" " + + "-P search_params=\"\" " + + "-P model_name=\"BasicAlgorithm\" " + + "--experiment-name=\"BasicAlgorithm\""); } @Test public void testInitAutoMLTask() { MlflowTask mlflowTask = initTask(createAutoMLParameters()); Assert.assertEquals(mlflowTask.buildCommand(), - "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + - "data_path=/data/iris.csv\n" + - "repo=https://github.com/apache/dolphinscheduler-mlflow#Project-AutoML\n" + - "mlflow run $repo " + - "-P tool=autosklearn " + - "-P data_path=$data_path " + - "-P params=\"time_left_for_this_task=30\" " + - "-P model_name=\"AutoML\" " + - "--experiment-name=\"AutoML\" " + - "--version=main "); + "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + + "data_path=/data/iris.csv\n" + + "repo=dolphinscheduler-mlflow#Project-AutoML\n" + + "git clone https://github.com/apache/dolphinscheduler-mlflow dolphinscheduler-mlflow\n" + + "mlflow run $repo " + + "-P tool=autosklearn " + + "-P data_path=$data_path " + + "-P params=\"time_left_for_this_task=30\" " + + "-P model_name=\"AutoML\" " + + "--experiment-name=\"AutoML\""); } @Test public void testInitCustomProjectTask() { MlflowTask mlflowTask = initTask(createCustomProjectParameters()); Assert.assertEquals(mlflowTask.buildCommand(), - "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + - "repo=https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native\n" + - "mlflow run $repo " + - "-P learning_rate=0.2 " + - "-P colsample_bytree=0.8 " + - "-P subsample=0.9 " + - "--experiment-name=\"custom_project\" " + - "--version=\"master\" "); + "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + + "repo=https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native\n" + + "mlflow run $repo " + + "-P learning_rate=0.2 " + + "-P colsample_bytree=0.8 " + + "-P subsample=0.9 " + + "--experiment-name=\"custom_project\" " + + "--version=\"master\" "); } @Test public void testModelsDeployMlflow() { MlflowTask mlflowTask = initTask(createModelDeplyMlflowParameters()); Assert.assertEquals(mlflowTask.buildCommand(), - "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + - "mlflow models serve -m models:/model/1 --port 7000 -h 0.0.0.0"); + "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + + "mlflow models serve -m models:/model/1 --port 7000 -h 0.0.0.0"); } @Test public void testModelsDeployDocker() { MlflowTask mlflowTask = initTask(createModelDeplyDockerParameters()); Assert.assertEquals(mlflowTask.buildCommand(), - "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + - "mlflow models build-docker -m models:/model/1 -n mlflow/model:1 --enable-mlserver\n" + - "docker rm -f ds-mlflow-model-1\n" + - "docker run -d --name=ds-mlflow-model-1 -p=7000:8080 " + - "--health-cmd \"curl --fail http://127.0.0.1:8080/ping || exit 1\" --health-interval 5s --health-retries 20 " + - "mlflow/model:1"); + "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + + "mlflow models build-docker -m models:/model/1 -n mlflow/model:1 --enable-mlserver\n" + + "docker rm -f ds-mlflow-model-1\n" + + "docker run -d --name=ds-mlflow-model-1 -p=7000:8080 " + + "--health-cmd \"curl --fail http://127.0.0.1:8080/ping || exit 1\" --health-interval 5s --health-retries 20 " + + "mlflow/model:1"); } @Test - public void testModelsDeployDockerCompose() throws Exception{ + public void testModelsDeployDockerCompose() throws Exception { MlflowTask mlflowTask = initTask(createModelDeplyDockerComposeParameters()); Assert.assertEquals(mlflowTask.buildCommand(), - "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + - "cp " + mlflowTask.getTemplatePath(MlflowConstants.TEMPLATE_DOCKER_COMPOSE) + - " /tmp/dolphinscheduler_test\n" + - "mlflow models build-docker -m models:/model/1 -n mlflow/model:1 --enable-mlserver\n" + - "docker rm -f ds-mlflow-model-1\n" + - "export DS_TASK_MLFLOW_IMAGE_NAME=mlflow/model:1\n" + - "export DS_TASK_MLFLOW_CONTAINER_NAME=ds-mlflow-model-1\n" + - "export DS_TASK_MLFLOW_DEPLOY_PORT=7000\n" + - "export DS_TASK_MLFLOW_CPU_LIMIT=0.5\n" + - "export DS_TASK_MLFLOW_MEMORY_LIMIT=200m\n" + - "docker-compose up -d"); + "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n" + + "cp " + + mlflowTask.getTemplatePath(MlflowConstants.TEMPLATE_DOCKER_COMPOSE) + + " /tmp/dolphinscheduler_test\n" + + "mlflow models build-docker -m models:/model/1 -n mlflow/model:1 --enable-mlserver\n" + + "docker rm -f ds-mlflow-model-1\n" + + "export DS_TASK_MLFLOW_IMAGE_NAME=mlflow/model:1\n" + + "export DS_TASK_MLFLOW_CONTAINER_NAME=ds-mlflow-model-1\n" + + "export DS_TASK_MLFLOW_DEPLOY_PORT=7000\n" + + "export DS_TASK_MLFLOW_CPU_LIMIT=0.5\n" + + "export DS_TASK_MLFLOW_MEMORY_LIMIT=200m\n" + + "docker-compose up -d"); } private MlflowTask initTask(MlflowParameters mlflowParameters) {