Last active
June 22, 2025 12:02
-
-
Save nkapila6/23eb57356e3557350b34490332cca256 to your computer and use it in GitHub Desktop.
MLflow wrapper around the PyTorch Solver.py for Assignment 2 of the CS7643 Deep Learning class (OMSCS).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on 2024-09-28 19:39:23 Saturday | |
| @author: Nikhil Kapila | |
| """ | |
| import mlflow | |
| import os | |
| from mlflow.models import infer_signature | |
| import matplotlib.pyplot as plt | |
| import yaml | |
| import argparse | |
| import torch | |
| from solver import Solver | |
| import pickle | |
| import mlflow | |
def _per_class_accuracy(solver):
    """Return the diagonal of ``solver.best_cm`` as a plain Python list."""
    # Move to CPU before converting: ``.detach().numpy()`` raises for tensors
    # that live on a CUDA device.
    return solver.best_cm.diag().detach().cpu().tolist()


def _log_loss_plot(solver):
    """Plot the training/validation loss curves and log them as an artifact."""
    import tempfile  # local import: only needed for the scratch directory

    # Draw on a private figure so an already-open pyplot figure is neither
    # drawn over nor closed by this function.
    fig, ax = plt.subplots()
    try:
        ax.plot(solver.loss_train, label='Training loss')
        ax.plot(solver.loss_val, label='Validation loss')
        ax.set_xlabel('epochs')
        ax.set_ylabel('loss')
        ax.set_xticks(list(range(len(solver.loss_train))))
        ax.legend()
        # Write into a temporary directory so nothing is left behind (and no
        # unrelated ``loss_plot.png`` in the working directory is clobbered)
        # even when saving or uploading fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            plot_path = os.path.join(tmp_dir, 'loss_plot.png')
            fig.savefig(plot_path, dpi=300)
            mlflow.log_artifact(plot_path)
    finally:
        plt.close(fig)


def log_run(solver, uri, experiment_name, **kwargs):
    """Log the results held by a trained ``solver`` to MLflow as one run.

    Args:
        solver: Object exposing ``best_cm`` (tensor whose diagonal is logged
            per class), ``loss_train`` / ``loss_val`` (sequences of losses),
            ``best_model`` (provides ``state_dict()``), ``path_to_save``
            (file path the weights are written to) and ``best`` (scalar
            tensor logged as ``"Accuracy"``).
        uri: MLflow tracking URI to log to.
        experiment_name: Name of the MLflow experiment to log under.
        **kwargs: Logged verbatim as the run's parameters.
    """
    mlflow.set_tracking_uri(uri)
    mlflow.set_experiment(experiment_name)
    print(uri)

    with mlflow.start_run():
        # Parameters: everything the caller passed through.
        mlflow.log_params(kwargs)

        # Per-class values from the confusion-matrix diagonal, both as
        # individual metrics and as one list-valued parameter.
        per_class_acc = _per_class_accuracy(solver)
        for class_idx, acc in enumerate(per_class_acc):
            mlflow.log_metric(f"accuracy_class_{class_idx}", acc)
        mlflow.log_param("per_class_acc", per_class_acc)

        # Loss curves, one step per recorded entry.
        loss_pairs = zip(solver.loss_train, solver.loss_val)
        for epoch, (train_loss, val_loss) in enumerate(loss_pairs):
            mlflow.log_metric("loss_train", train_loss, step=epoch)
            mlflow.log_metric("loss_val", val_loss, step=epoch)

        # Final losses; skipped when no loss was recorded instead of failing
        # with IndexError on an empty history.
        if solver.loss_val:
            mlflow.log_metric("Val Loss", solver.loss_val[-1])
        if solver.loss_train:
            mlflow.log_metric("Train Loss", solver.loss_train[-1])

        # Persist the best weights locally and attach them to the run.
        torch.save(solver.best_model.state_dict(), solver.path_to_save)
        mlflow.log_artifact(solver.path_to_save)

        mlflow.log_metric("Accuracy", solver.best.item())

        _log_loss_plot(solver)
def main(path, mlflow_tracking_uri, mlflow_exp, databricks=False, **kwargs):
    """Train a ``Solver`` and log the run to MLflow.

    The run is always logged to a local store under ``path``; when
    ``databricks`` is true it is logged a second time to
    ``mlflow_tracking_uri``.

    Args:
        path: Base directory; the local tracking store is
            ``{path}/mlruns/{mlflow_exp}``.
        mlflow_tracking_uri: Tracking URI used for the additional remote
            logging pass (only read when ``databricks`` is true).
        mlflow_exp: MLflow experiment name.
        databricks: When true, call ``mlflow.login()`` and log the run to
            ``mlflow_tracking_uri`` as well.
        **kwargs: Forwarded to ``Solver(...)`` and logged as run parameters.

    Returns:
        The trained ``Solver`` instance.
    """
    local_uri = f"{path}/mlruns/{mlflow_exp}"

    solver = Solver(**kwargs)
    solver.train()

    log_run(solver, uri=local_uri, experiment_name=mlflow_exp, **kwargs)

    if databricks:
        mlflow.login()
        # log_run() sets the tracking URI and experiment itself, so they are
        # not configured separately here.
        log_run(solver, uri=mlflow_tracking_uri,
                experiment_name=mlflow_exp, **kwargs)

    return solver
if __name__ == '__main__':
    # main() has three required positional parameters, so the previous bare
    # ``main()`` call always failed with TypeError. Collect them from the
    # command line instead.
    parser = argparse.ArgumentParser(
        description='Train a Solver and log the run to MLflow.')
    parser.add_argument('--path', required=True,
                        help='base directory for the local mlruns store')
    parser.add_argument('--experiment', required=True,
                        help='MLflow experiment name')
    parser.add_argument('--tracking-uri', default=None,
                        help='remote MLflow tracking URI '
                             '(required with --databricks)')
    parser.add_argument('--databricks', action='store_true',
                        help='also log the run to the remote tracking URI')
    parser.add_argument('--config', default=None,
                        help='YAML file with keyword arguments for Solver')
    args = parser.parse_args()

    if args.databricks and args.tracking_uri is None:
        parser.error('--tracking-uri is required when --databricks is set')

    solver_kwargs = {}
    if args.config is not None:
        # NOTE(review): assumes the YAML top level is a flat mapping of
        # Solver keyword arguments - confirm against the config layout.
        with open(args.config, 'r', encoding='utf-8') as config_file:
            solver_kwargs = yaml.safe_load(config_file) or {}

    main(args.path, args.tracking_uri, args.experiment,
         databricks=args.databricks, **solver_kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment