Skip to content

Instantly share code, notes, and snippets.

@nkapila6
Last active June 22, 2025 12:02
Show Gist options
  • Select an option

  • Save nkapila6/23eb57356e3557350b34490332cca256 to your computer and use it in GitHub Desktop.

Select an option

Save nkapila6/23eb57356e3557350b34490332cca256 to your computer and use it in GitHub Desktop.
MLflow wrapper around the PyTorch Solver (solver.py) for Assignment 2 of the CS7643 Deep Learning course (OMSCS)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on 2024-09-28 19:39:23 Saturday
@author: Nikhil Kapila
"""
import mlflow
import os
from mlflow.models import infer_signature
import matplotlib.pyplot as plt
import yaml
import argparse
import torch
from solver import Solver
import pickle
import mlflow
def log_run(solver, uri, experiment_name, **kwargs):
    """Log one trained solver to an MLflow tracking store as a single run.

    Logged items: the given keyword arguments (as params), the diagonal of
    ``solver.best_cm`` (per-class metrics plus one list param), the per-epoch
    training/validation losses, the final losses, the best model's
    ``state_dict`` (as an artifact), ``solver.best`` (as "Accuracy") and a
    loss-curve plot (as the artifact ``loss_plot.png``).

    Args:
        solver: Trained solver object. This function reads its ``best_cm``,
            ``loss_train``, ``loss_val``, ``best_model``, ``path_to_save``
            and ``best`` attributes.
        uri: MLflow tracking URI to log to.
        experiment_name: Name of the MLflow experiment that receives the run.
        **kwargs: Logged verbatim as the run's parameters.
    """
    import tempfile  # local import: only needed for the temporary plot file

    mlflow.set_tracking_uri(uri)
    mlflow.set_experiment(experiment_name)
    print(uri)

    with mlflow.start_run():
        # Log the full configuration dict.
        mlflow.log_params(kwargs)

        # Diagonal of solver.best_cm, logged as per-class accuracy.
        # .cpu() keeps this working when the tensor lives on a GPU, and
        # Tensor.tolist() makes the NumPy round trip unnecessary.
        per_class_acc = solver.best_cm.diag().detach().cpu().tolist()
        for class_idx, acc in enumerate(per_class_acc):
            mlflow.log_metric(f"accuracy_class_{class_idx}", acc)
        mlflow.log_param("per_class_acc", per_class_acc)

        # Log each loss curve over its own length so that a mismatch between
        # the two lists can neither raise IndexError nor drop epochs.
        for epoch, loss in enumerate(solver.loss_train):
            mlflow.log_metric("loss_train", loss, step=epoch)
        for epoch, loss in enumerate(solver.loss_val):
            mlflow.log_metric("loss_val", loss, step=epoch)

        # Final-epoch losses; skipped when no epoch was recorded.
        if len(solver.loss_val):
            mlflow.log_metric("Val Loss", solver.loss_val[-1])
        if len(solver.loss_train):
            mlflow.log_metric("Train Loss", solver.loss_train[-1])

        # Save the best model's weights and attach them to the run.
        torch.save(solver.best_model.state_dict(), solver.path_to_save)
        mlflow.log_artifact(solver.path_to_save)

        # Best score reported by the solver.
        mlflow.log_metric("Accuracy", solver.best.item())

        # Loss curves. A dedicated figure avoids drawing onto whatever pyplot
        # figure happens to be current; the temporary directory guarantees the
        # PNG is removed even if saving or logging fails.
        fig, ax = plt.subplots()
        try:
            ax.plot(solver.loss_train, label='Training loss')
            ax.plot(solver.loss_val, label='Validation loss')
            ax.set_xlabel('epochs')
            ax.set_ylabel('loss')
            ax.set_xticks(range(len(solver.loss_train)))
            ax.legend()
            with tempfile.TemporaryDirectory() as tmp_dir:
                # Same basename as before, so the artifact name is unchanged.
                plot_path = os.path.join(tmp_dir, 'loss_plot.png')
                fig.savefig(plot_path, dpi=300)
                mlflow.log_artifact(plot_path)
        finally:
            plt.close(fig)
def main(path, mlflow_tracking_uri, mlflow_exp, databricks=False, **kwargs):
    """Train a Solver and log the result to MLflow.

    The run is always logged to a local tracking store at
    ``{path}/mlruns/{mlflow_exp}``. When ``databricks`` is true it is logged a
    second time to ``mlflow_tracking_uri`` after ``mlflow.login()``.

    Args:
        path: Base directory under which the local ``mlruns`` store lives.
        mlflow_tracking_uri: Tracking URI used for the second, remote logging
            pass; only used when ``databricks`` is true.
        mlflow_exp: MLflow experiment name, also used as the last component
            of the local tracking path.
        databricks: Whether to additionally log to ``mlflow_tracking_uri``.
        **kwargs: Passed to ``Solver(...)`` and logged as run parameters.

    Returns:
        The trained ``Solver`` instance.
    """
    solver = Solver(**kwargs)
    solver.train()

    # Local logging pass.
    local_uri = f"{path}/mlruns/{mlflow_exp}"
    log_run(solver, uri=local_uri, experiment_name=mlflow_exp, **kwargs)

    if databricks:
        # log_run sets the tracking URI and experiment itself, so only the
        # login step is needed here.
        mlflow.login()
        log_run(solver, uri=mlflow_tracking_uri, experiment_name=mlflow_exp, **kwargs)

    return solver
if __name__ == '__main__':
    # main() has three required positional parameters, so they are collected
    # from the command line instead of calling main() with no arguments.
    parser = argparse.ArgumentParser(
        description="Train a Solver and log the run to MLflow.")
    parser.add_argument('--path', required=True,
                        help="base directory for the local mlruns store")
    parser.add_argument('--mlflow-exp', required=True,
                        help="MLflow experiment name")
    parser.add_argument('--mlflow-tracking-uri', default=None,
                        help="remote tracking URI (required with --databricks)")
    parser.add_argument('--databricks', action='store_true',
                        help="also log the run to the remote tracking URI")
    # NOTE(review): assumes the YAML file is a flat mapping of Solver keyword
    # arguments - confirm against the Solver constructor.
    parser.add_argument('--config', default=None,
                        help="YAML file with keyword arguments for Solver")
    args = parser.parse_args()

    if args.databricks and args.mlflow_tracking_uri is None:
        parser.error("--mlflow-tracking-uri is required with --databricks")

    solver_kwargs = {}
    if args.config is not None:
        with open(args.config, encoding='utf-8') as config_file:
            solver_kwargs = yaml.safe_load(config_file) or {}
        if not isinstance(solver_kwargs, dict):
            parser.error("--config must contain a YAML mapping")

    main(args.path, args.mlflow_tracking_uri, args.mlflow_exp,
         databricks=args.databricks, **solver_kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment