Skip to content

Commit 3206603

Browse files
committed
test_regressiom_model succesfully run
1 parent 2990991 commit 3206603

File tree

2 files changed

+5
-12
lines changed

2 files changed

+5
-12
lines changed

main.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,8 @@ def go(config: DictConfig):
117117
f"{config['main']['components_repository']}/test_regression_model",
118118
"main",
119119
parameters={
120-
"trainval_artifact": "trainval_data.csv:latest",
121-
"val_size": config['modeling']['val_size'],
122-
"random_seed": config['modeling']['random_seed'],
123-
"stratify_by": config['modeling']['stratify_by'],
124-
'rf_config': rf_config,
125-
'max_tfidf_features': config['modeling']['max_tfidf_features'],
126-
'output_artifact': 'random_forest_export'
120+
"mlflow_model": "random_forest_export:latest",
121+
"test_dataset": "test_data.csv:latest"
127122
},
128123
)
129124

src/train_random_forest/run.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,12 @@ def go(args):
9494
# Save model package in the MLFlow sklearn format
9595
if os.path.exists("random_forest_dir"):
9696
shutil.rmtree("random_forest_dir")
97-
9897
######################################
9998
# Save the sk_pipe pipeline as a mlflow.sklearn model in the directory "random_forest_dir"
10099
# HINT: use mlflow.sklearn.save_model
101100
# YOUR CODE HERE
102101
######################################
103-
mlflow.sklearn.save_model(sk_pipe, 'random_forest_dir')
102+
mlflow.sklearn.save_model(sk_pipe, "random_forest_dir")
104103

105104
logger.info("Saving model") ######################################
106105
# Upload the model we just exported to W&B
@@ -113,10 +112,9 @@ def go(args):
113112
sklearn_artifact = wandb.Artifact(
114113
args.output_artifact,
115114
type="model_export",
116-
description="Sklearn Trained model",
117-
metadata=args.rf_config
115+
description="Sklearn Trained model"
118116
)
119-
sklearn_artifact.add_dir('random_forest_dir')
117+
sklearn_artifact.add_dir("random_forest_dir")
120118
run.log_artifact(sklearn_artifact)
121119
logger.info("Model Uploaded")
122120
# Plot feature importance

0 commit comments

Comments
 (0)