@ -17,6 +17,8 @@
// Package of the MLflow task plugin tests; spelled to match the
// org.apache.dolphinscheduler.plugin.task.mlflow imports used below.
package org.apache.dolphinscheduler.plugin.task.mlflow;
import static org.powermock.api.mockito.PowerMockito.when ;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContext ;
import org.apache.dolphinscheduler.plugin.task.api.TaskExecutionContextCacheManager ;
import org.apache.dolphinscheduler.plugin.task.mlflow.MlflowConstants ;
@ -76,21 +78,46 @@ public class MlflowTaskTest {
return taskExecutionContext ;
}
@Test
public void testGetPresetRepositoryData ( ) {
Assert . assertEquals ( "https://github.com/apache/dolphinscheduler-mlflow" , MlflowTask . getPresetRepository ( ) ) ;
Assert . assertEquals ( "main" , MlflowTask . getPresetRepositoryVersion ( ) ) ;
String definedRepository = "https://github.com/<MY-ID>/dolphinscheduler-mlflow" ;
when ( PropertyUtils . getString ( MlflowConstants . PRESET_REPOSITORY_KEY ) ) . thenAnswer ( invocation - > definedRepository ) ;
Assert . assertEquals ( definedRepository , MlflowTask . getPresetRepository ( ) ) ;
String definedRepositoryVersion = "dev" ;
when ( PropertyUtils . getString ( MlflowConstants . PRESET_REPOSITORY_VERSION_KEY ) ) . thenAnswer ( invocation - > definedRepositoryVersion ) ;
Assert . assertEquals ( definedRepositoryVersion , MlflowTask . getPresetRepositoryVersion ( ) ) ;
}
/**
 * Verifies {@code MlflowTask.getVersionString(version, repository)} for three repository
 * forms: an https URL and a {@code git@} address are expected to yield
 * {@code --version=<version>}, while a local filesystem path is expected to yield an empty
 * string.
 */
@Test
public void testGetVersionString() {
    final String httpsRepository = "https://github.com/apache/dolphinscheduler-mlflow";
    final String sshRepository = "git@github.com:apache/dolphinscheduler-mlflow.git";
    final String localRepository = "/tmp/dolphinscheduler-mlflow";

    // Remote repositories: the version flag is expected for both branch names.
    for (String remoteRepository : new String[] {httpsRepository, sshRepository}) {
        Assert.assertEquals("--version=main", MlflowTask.getVersionString("main", remoteRepository));
        Assert.assertEquals("--version=master", MlflowTask.getVersionString("master", remoteRepository));
    }

    // Local path: no version flag is expected, whatever the branch name.
    Assert.assertEquals("", MlflowTask.getVersionString("main", localRepository));
    Assert.assertEquals("", MlflowTask.getVersionString("master", localRepository));
}
/**
 * Checks the shell command returned by {@code buildCommand()} for the task built from
 * {@code createBasicAlgorithmParameters()}.
 */
@Test
public void testInitBasicAlgorithmTask ( ) {
MlflowTask mlflowTask = initTask ( createBasicAlgorithmParameters ( ) ) ;
// NOTE(review): the actual value is passed first and the expected literal second, unlike
// testGetVersionString where the literal comes first — confirm the intended argument order.
Assert . assertEquals ( mlflowTask . buildCommand ( ) ,
"export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n"
+ "data_path=/data/iris.csv\n"
// NOTE(review): two different `repo=` assignments and a `git clone` line follow; this looks
// like the old and new sides of a diff interleaved — confirm which lines belong here.
+ "repo=dolphinscheduler-mlflow#Project-BasicAlgorithm\n"
+ "git clone https://github.com/apache/dolphinscheduler-mlflow dolphinscheduler-mlflow\n"
+ "repo=https://github.com/apache/dolphinscheduler-mlflow#Project-BasicAlgorithm\n"
+ "mlflow run $repo "
+ "-P algorithm=xgboost "
+ "-P data_path=$data_path "
+ "-P params=\"n_estimators=100\" "
+ "-P search_params=\"\" "
+ "-P model_name=\"BasicAlgorithm\" "
+ "--experiment-name=\"BasicAlgorithm\"" ) ;
// NOTE(review): the statement is already terminated on the line above, so the two
// concatenation lines below do not form a valid statement as written; presumably they are
// the replacement ending of the expected string (with the version flag) — confirm.
+ "--experiment-name=\"BasicAlgorithm\" "
+ "--version=main" ) ;
}
@Test
@ -99,19 +126,32 @@ public class MlflowTaskTest {
Assert . assertEquals ( mlflowTask . buildCommand ( ) ,
"export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n"
+ "data_path=/data/iris.csv\n"
+ "repo=dolphinscheduler-mlflow#Project-AutoML\n"
+ "git clone https://github.com/apache/dolphinscheduler-mlflow dolphinscheduler-mlflow\n"
+ "repo=https://github.com/apache/dolphinscheduler-mlflow#Project-AutoML\n"
+ "mlflow run $repo "
+ "-P tool=autosklearn "
+ "-P data_path=$data_path "
+ "-P params=\"time_left_for_this_task=30\" "
+ "-P model_name=\"AutoML\" "
+ "--experiment-name=\"AutoML\"" ) ;
+ "--experiment-name=\"AutoML\" "
+ "--version=main" ) ;
}
/**
 * Checks the shell command returned by {@code buildCommand()} for the task built from
 * {@code createCustomProjectParameters()} in three steps: as created, after setting the
 * project version to "dev", and after switching the repository to a local path.
 */
@Test
public void testInitCustomProjectTask ( ) {
MlflowTask mlflowTask = initTask ( createCustomProjectParameters ( ) ) ;
// No --version flag appears in this expected command; presumably mlflowProjectVersion is
// still empty at this point — TODO confirm against createCustomProjectParameters
Assert . assertEquals ( mlflowTask . buildCommand ( ) ,
"export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n"
+ "repo=https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native\n"
+ "mlflow run $repo "
+ "-P learning_rate=0.2 "
+ "-P colsample_bytree=0.8 "
+ "-P subsample=0.9 "
+ "--experiment-name=\"custom_project\"" ) ;
// Version will be set if repository is remote path
mlflowTask . getParameters ( ) . setMlflowProjectVersion ( "dev" ) ;
Assert . assertEquals ( mlflowTask . buildCommand ( ) ,
"export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n"
+ "repo=https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native\n"
// NOTE(review): the next line looks like a diff hunk header embedded in the method body;
// lines of the expected string may be missing between here and the following piece — confirm.
@ -120,7 +160,19 @@ public class MlflowTaskTest {
+ "-P colsample_bytree=0.8 "
+ "-P subsample=0.9 "
+ "--experiment-name=\"custom_project\" "
// NOTE(review): two alternative endings follow ("master" quoted, then dev unquoted), and the
// first already terminates the statement; presumably old and new sides of a diff — confirm.
+ "--version=\"master\" " ) ;
+ "--version=dev" ) ;
// Version will not be set if repository is local path
mlflowTask . getParameters ( ) . setMlflowProjectRepository ( "/tmp/dolphinscheduler-mlflow" ) ;
Assert . assertEquals ( mlflowTask . buildCommand ( ) ,
"export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n"
+ "repo=/tmp/dolphinscheduler-mlflow\n"
+ "mlflow run $repo "
+ "-P learning_rate=0.2 "
+ "-P colsample_bytree=0.8 "
+ "-P subsample=0.9 "
+ "--experiment-name=\"custom_project\"" ) ;
}
@Test
@ -143,24 +195,6 @@ public class MlflowTaskTest {
+ "mlflow/model:1" ) ;
}
/**
 * Checks the shell command returned by {@code buildCommand()} for the task built from
 * {@code createModelDeplyDockerComposeParameters()}: copying the docker-compose template,
 * building the model image, removing the previous container, exporting the
 * {@code DS_TASK_MLFLOW_*} variables and starting docker-compose.
 *
 * @throws Exception declared by the original signature; NOTE(review): presumably raised by
 *                   {@code getTemplatePath} — confirm
 */
@Test
public void testModelsDeployDockerCompose() throws Exception {
    MlflowTask mlflowTask = initTask(createModelDeplyDockerComposeParameters());

    // Build the command before the expected string so the call order of the original
    // assertion (buildCommand first, then getTemplatePath) is kept.
    String actualCommand = mlflowTask.buildCommand();
    String expectedCommand = "export MLFLOW_TRACKING_URI=http://127.0.0.1:5000\n"
        + "cp "
        + mlflowTask.getTemplatePath(MlflowConstants.TEMPLATE_DOCKER_COMPOSE)
        + " /tmp/dolphinscheduler_test\n"
        + "mlflow models build-docker -m models:/model/1 -n mlflow/model:1 --enable-mlserver\n"
        + "docker rm -f ds-mlflow-model-1\n"
        + "export DS_TASK_MLFLOW_IMAGE_NAME=mlflow/model:1\n"
        + "export DS_TASK_MLFLOW_CONTAINER_NAME=ds-mlflow-model-1\n"
        + "export DS_TASK_MLFLOW_DEPLOY_PORT=7000\n"
        + "export DS_TASK_MLFLOW_CPU_LIMIT=0.5\n"
        + "export DS_TASK_MLFLOW_MEMORY_LIMIT=200m\n"
        + "docker-compose up -d";

    // Expected value first, actual second, so a failure message labels the two sides correctly.
    Assert.assertEquals(expectedCommand, actualCommand);
}
private MlflowTask initTask ( MlflowParameters mlflowParameters ) {
TaskExecutionContext taskExecutionContext = createContext ( mlflowParameters ) ;
MlflowTask mlflowTask = new MlflowTask ( taskExecutionContext ) ;
@ -174,11 +208,11 @@ public class MlflowTaskTest {
mlflowParameters . setMlflowTaskType ( MlflowConstants . MLFLOW_TASK_TYPE_PROJECTS ) ;
mlflowParameters . setMlflowJobType ( MlflowConstants . JOB_TYPE_BASIC_ALGORITHM ) ;
mlflowParameters . setAlgorithm ( "xgboost" ) ;
mlflowParameters . setDataPaths ( "/data/iris.csv" ) ;
mlflowParameters . setDataPath ( "/data/iris.csv" ) ;
mlflowParameters . setParams ( "n_estimators=100" ) ;
mlflowParameters . setExperimentNames ( "BasicAlgorithm" ) ;
mlflowParameters . setModelNames ( "BasicAlgorithm" ) ;
mlflowParameters . setMlflowTrackingUris ( "http://127.0.0.1:5000" ) ;
mlflowParameters . setExperimentName ( "BasicAlgorithm" ) ;
mlflowParameters . setModelName ( "BasicAlgorithm" ) ;
mlflowParameters . setMlflowTrackingUri ( "http://127.0.0.1:5000" ) ;
return mlflowParameters ;
}
@ -188,10 +222,10 @@ public class MlflowTaskTest {
mlflowParameters . setMlflowJobType ( MlflowConstants . JOB_TYPE_AUTOML ) ;
mlflowParameters . setAutomlTool ( "autosklearn" ) ;
mlflowParameters . setParams ( "time_left_for_this_task=30" ) ;
mlflowParameters . setDataPaths ( "/data/iris.csv" ) ;
mlflowParameters . setExperimentNames ( "AutoML" ) ;
mlflowParameters . setModelNames ( "AutoML" ) ;
mlflowParameters . setMlflowTrackingUris ( "http://127.0.0.1:5000" ) ;
mlflowParameters . setDataPath ( "/data/iris.csv" ) ;
mlflowParameters . setExperimentName ( "AutoML" ) ;
mlflowParameters . setModelName ( "AutoML" ) ;
mlflowParameters . setMlflowTrackingUri ( "http://127.0.0.1:5000" ) ;
return mlflowParameters ;
}
@ -199,8 +233,8 @@ public class MlflowTaskTest {
MlflowParameters mlflowParameters = new MlflowParameters ( ) ;
mlflowParameters . setMlflowTaskType ( MlflowConstants . MLFLOW_TASK_TYPE_PROJECTS ) ;
mlflowParameters . setMlflowJobType ( MlflowConstants . JOB_TYPE_CUSTOM_PROJECT ) ;
mlflowParameters . setMlflowTrackingUris ( "http://127.0.0.1:5000" ) ;
mlflowParameters . setExperimentNames ( "custom_project" ) ;
mlflowParameters . setMlflowTrackingUri ( "http://127.0.0.1:5000" ) ;
mlflowParameters . setExperimentName ( "custom_project" ) ;
mlflowParameters . setParams ( "-P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9" ) ;
mlflowParameters . setMlflowProjectRepository ( "https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native" ) ;
@ -211,7 +245,7 @@ public class MlflowTaskTest {
MlflowParameters mlflowParameters = new MlflowParameters ( ) ;
mlflowParameters . setMlflowTaskType ( MlflowConstants . MLFLOW_TASK_TYPE_MODELS ) ;
mlflowParameters . setDeployType ( MlflowConstants . MLFLOW_MODELS_DEPLOY_TYPE_MLFLOW ) ;
mlflowParameters . setMlflowTrackingUris ( "http://127.0.0.1:5000" ) ;
mlflowParameters . setMlflowTrackingUri ( "http://127.0.0.1:5000" ) ;
mlflowParameters . setDeployModelKey ( "models:/model/1" ) ;
mlflowParameters . setDeployPort ( "7000" ) ;
return mlflowParameters ;
@ -221,21 +255,9 @@ public class MlflowTaskTest {
MlflowParameters mlflowParameters = new MlflowParameters ( ) ;
mlflowParameters . setMlflowTaskType ( MlflowConstants . MLFLOW_TASK_TYPE_MODELS ) ;
mlflowParameters . setDeployType ( MlflowConstants . MLFLOW_MODELS_DEPLOY_TYPE_DOCKER ) ;
mlflowParameters . setMlflowTrackingUris ( "http://127.0.0.1:5000" ) ;
mlflowParameters . setDeployModelKey ( "models:/model/1" ) ;
mlflowParameters . setDeployPort ( "7000" ) ;
return mlflowParameters ;
}
/**
 * Builds the parameter fixture for the docker-compose model deployment test: task type
 * MODELS, deploy type DOCKER_COMPOSE, tracking URI, model key, port, and CPU/memory limits.
 *
 * @return a freshly populated {@code MlflowParameters} instance
 */
private MlflowParameters createModelDeplyDockerComposeParameters() {
    MlflowParameters dockerComposeParameters = new MlflowParameters();

    // What kind of task this is and how the model is deployed.
    dockerComposeParameters.setMlflowTaskType(MlflowConstants.MLFLOW_TASK_TYPE_MODELS);
    dockerComposeParameters.setDeployType(MlflowConstants.MLFLOW_MODELS_DEPLOY_TYPE_DOCKER_COMPOSE);

    // NOTE(review): both setMlflowTrackingUris and setMlflowTrackingUri are called with the
    // same value — confirm that both setters are intended here.
    dockerComposeParameters.setMlflowTrackingUris("http://127.0.0.1:5000");
    dockerComposeParameters.setMlflowTrackingUri("http://127.0.0.1:5000");

    // Model to deploy and the resources given to its container.
    dockerComposeParameters.setDeployModelKey("models:/model/1");
    dockerComposeParameters.setDeployPort("7000");
    dockerComposeParameters.setCpuLimit("0.5");
    dockerComposeParameters.setMemoryLimit("200m");

    return dockerComposeParameters;
}
}