Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions src/models/model_utils.py
Original file line number Diff line number Diff line change
@@ -1 +1,115 @@
import joblib
import os
import json
from datetime import datetime

# Directory containing this module; used as the anchor so the storage
# location does not depend on the process's current working directory.
_MODULE_DIR = os.path.dirname(__file__)

# Absolute path of the folder that holds the serialized models
# ("models_saved", two levels above this module's directory).
MODEL_DIR = os.path.abspath(
    os.path.join(_MODULE_DIR, "../../models_saved")
)

# JSON registry, stored next to the models, mapping version tags to metadata.
METADATA_FILE = os.path.join(MODEL_DIR, "metadata.json")


def _load_metadata() -> dict:
    """Load the metadata registry stored at METADATA_FILE.

    Returns:
        The parsed registry, or an empty dict when the registry file does
        not exist yet.

    Raises:
        json.JSONDecodeError: If the file exists but does not contain valid
            JSON. This is deliberately not swallowed: returning an empty
            dict would let a later save overwrite the damaged registry and
            silently drop every existing entry.
    """
    # Open directly rather than checking os.path.exists() first: the file
    # could disappear between the check and the open.
    try:
        # Explicit encoding so the result does not depend on the platform's
        # default locale encoding.
        with open(METADATA_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def _save_metadata(metadata: dict) -> None:
    """Persist the metadata registry to METADATA_FILE.

    The registry is first written to a temporary sibling file and then moved
    over METADATA_FILE with os.replace(). Opening METADATA_FILE itself in
    "w" mode would truncate it before serialization starts, so a value that
    json cannot serialize (or a crash mid-write) would leave a corrupt
    registry behind.

    Args:
        metadata: Full registry to write; must be JSON-serializable.

    Raises:
        TypeError: If metadata contains a value json cannot serialize. The
            existing registry file is left untouched in that case.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)

    tmp_path = f"{METADATA_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=4)
        # Swap the finished file into place in a single step.
        os.replace(tmp_path, METADATA_FILE)
    finally:
        # After a successful replace the temp file no longer exists; this
        # only removes the leftover from a failed write.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_model(model, version_tag: str, hyperparams: dict = None, metrics: dict = None):
    """
    Persist a trained model under a version tag and register its metadata.

    The model is serialized with joblib into MODEL_DIR as
    'model_<version_tag>.pkl'. The registry entry for version_tag is then
    created, or replaced if the tag already exists, and written back to disk.

    Args:
        model: Trained model object to serialize (e.g. an sklearn estimator)
        version_tag: Label identifying this version, e.g. 'v1_base', 'v2_tuned'
        hyperparams: Hyperparameters used for training; stored as {} when
            omitted or empty
        metrics: Evaluation metrics; stored as {} when omitted or empty

    Example:
        save_model(rf, "v1_base",
                   hyperparams={"n_estimators": 100},
                   metrics={"accuracy": 0.91})
    """
    os.makedirs(MODEL_DIR, exist_ok=True)

    # Serialize the model itself first; the registry is only touched after
    # the model file has been written.
    model_file = f"model_{version_tag}.pkl"
    joblib.dump(model, os.path.join(MODEL_DIR, model_file))

    registry = _load_metadata()
    entry = {
        "filename": model_file,
        "training_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "hyperparameters": hyperparams if hyperparams else {},
        "metrics": metrics if metrics else {},
    }
    registry[version_tag] = entry
    _save_metadata(registry)

    print(f"[model_utils] Model saved as '{model_file}'")
    print(f"[model_utils] Metadata updated for version '{version_tag}'")


def load_model(version_tag: str):
    """
    Load a previously saved model by its version tag.

    The tag is looked up in the metadata registry to find the model's
    filename inside MODEL_DIR, and the file is deserialized with joblib.

    NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    only load model files from a trusted source.

    Args:
        version_tag: Label of the version to load, e.g. 'v1_base', 'v2_tuned'

    Returns:
        The deserialized model object

    Raises:
        KeyError: If version_tag has no entry in the metadata registry
        FileNotFoundError: If the registered model file is missing on disk
    """
    registry = _load_metadata()

    if version_tag not in registry:
        raise KeyError(f"[model_utils] Version '{version_tag}' not found in metadata registry.")

    model_path = os.path.join(MODEL_DIR, registry[version_tag]["filename"])

    if os.path.exists(model_path):
        loaded = joblib.load(model_path)
        print(f"[model_utils] Loaded model version '{version_tag}' from {model_path}")
        return loaded

    # The registry knows the version but its file is gone from disk.
    raise FileNotFoundError(f"[model_utils] Model file not found at {model_path}")


def list_versions():
    """
    Print a table of all saved model versions with date and accuracy.

    Reads the metadata registry and prints one row per version. Prints a
    short notice instead when the registry is empty. A missing
    'training_date', 'metrics' or 'accuracy' value is shown as 'N/A'.

    Returns:
        None; the table is written to stdout.
    """
    metadata = _load_metadata()

    if not metadata:
        print("[model_utils] No saved model versions found.")
        return

    rule = "=" * 50
    print(f"\n{rule}")
    print(f"{'VERSION':<15} {'DATE':<22} {'ACCURACY':<10}")
    print(rule)
    for version, info in metadata.items():
        # Tolerate incomplete entries instead of raising KeyError.
        metrics = info.get("metrics") or {}
        training_date = info.get("training_date", "N/A")
        accuracy = metrics.get("accuracy", "N/A")
        # Convert to str before padding: values loaded from JSON may be
        # None, a list or a dict, and those types reject the '<10' format
        # spec with TypeError.
        print(f"{version:<15} {str(training_date):<22} {str(accuracy):<10}")
    print(f"{rule}\n")


def rollback(version_tag: str):
    """
    Roll back to a given model version by loading it.

    Prints a notice and then delegates to load_model(); nothing on disk is
    modified.

    Args:
        version_tag: Label of the version to roll back to

    Returns:
        The model object returned by load_model() for that version
    """
    print(f"[model_utils] Rolling back to version '{version_tag}'...")
    restored = load_model(version_tag)
    return restored