"""Save, load, and list versioned models backed by a JSON metadata registry.

Each saved model is written to ``MODEL_DIR`` as ``model_<version_tag>.pkl``
and described by an entry in ``METADATA_FILE`` keyed by its version tag.
"""

import contextlib
import json
import os
from datetime import datetime
from typing import Optional

import joblib

MODEL_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../models_saved"))
METADATA_FILE = os.path.join(MODEL_DIR, "metadata.json")


def _load_metadata() -> dict:
    """Load the metadata registry from disk.

    Returns:
        The parsed registry, or an empty dict when no registry file exists.

    Raises:
        json.JSONDecodeError: If the registry file exists but is not valid JSON.
    """
    # EAFP: open directly instead of exists()-then-open, which is race-prone.
    try:
        with open(METADATA_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def _save_metadata(metadata: dict) -> None:
    """Persist the metadata registry to disk without risking the old copy.

    Args:
        metadata: Full registry mapping version tags to their entries.

    Raises:
        TypeError: If ``metadata`` contains a value that is not JSON-serializable.
    """
    # Serialize BEFORE touching the file: a serialization error must not
    # leave a truncated, unreadable registry behind.
    payload = json.dumps(metadata, indent=4)

    os.makedirs(MODEL_DIR, exist_ok=True)

    # Write to a sibling temp file, then swap it into place so readers never
    # observe a partially written registry.
    tmp_path = METADATA_FILE + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, METADATA_FILE)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error still propagates.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


def save_model(
    model,
    version_tag: str,
    hyperparams: Optional[dict] = None,
    metrics: Optional[dict] = None,
) -> None:
    """Save a trained model with a version tag and metadata.

    Args:
        model: Trained model object to persist with joblib.
        version_tag: e.g. 'v1_base', 'v2_tuned'. Used in the file name and
            as the registry key; an existing entry with the same tag is
            overwritten.
        hyperparams: Dict of hyperparameters used (optional).
        metrics: Dict of evaluation metrics (optional).

    Raises:
        TypeError: If ``hyperparams`` or ``metrics`` is not JSON-serializable.

    Example:
        save_model(rf, "v1_base",
                   hyperparams={"n_estimators": 100},
                   metrics={"accuracy": 0.91})
    """
    hyperparams = hyperparams or {}
    metrics = metrics or {}

    # Fail fast: reject unserializable metadata and unreadable registries
    # before writing (and possibly overwriting) any model file.
    json.dumps({"hyperparameters": hyperparams, "metrics": metrics})
    metadata = _load_metadata()

    os.makedirs(MODEL_DIR, exist_ok=True)

    filename = f"model_{version_tag}.pkl"
    path = os.path.join(MODEL_DIR, filename)
    joblib.dump(model, path)

    metadata[version_tag] = {
        "filename": filename,
        "training_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "hyperparameters": hyperparams,
        "metrics": metrics,
    }
    _save_metadata(metadata)

    print(f"[model_utils] Model saved as '{filename}'")
    print(f"[model_utils] Metadata updated for version '{version_tag}'")


def load_model(version_tag: str):
    """Load a saved model by version tag.

    Args:
        version_tag: e.g. 'v1_base', 'v2_tuned'.

    Returns:
        The object loaded by joblib from the registered model file.

    Raises:
        KeyError: If ``version_tag`` is not present in the metadata registry.
        FileNotFoundError: If the registered model file is missing on disk.
    """
    metadata = _load_metadata()

    if version_tag not in metadata:
        raise KeyError(f"[model_utils] Version '{version_tag}' not found in metadata registry.")

    filename = metadata[version_tag]["filename"]
    path = os.path.join(MODEL_DIR, filename)

    if not os.path.exists(path):
        raise FileNotFoundError(f"[model_utils] Model file not found at {path}")

    # SECURITY: joblib.load unpickles the file and can execute arbitrary code.
    # Only load model files that come from a trusted source.
    model = joblib.load(path)
    print(f"[model_utils] Loaded model version '{version_tag}' from {path}")
    return model


def list_versions() -> None:
    """Print all saved model versions with their training date and accuracy."""
    metadata = _load_metadata()

    if not metadata:
        print("[model_utils] No saved model versions found.")
        return

    rule = "=" * 50
    print(f"\n{rule}")
    print(f"{'VERSION':<15} {'DATE':<22} {'ACCURACY':<10}")
    print(rule)
    for version, info in metadata.items():
        # str() keeps the column formatting from raising on values such as
        # None that do not accept a '<10' format spec.
        accuracy = str(info.get("metrics", {}).get("accuracy", "N/A"))
        trained = str(info.get("training_date", "N/A"))
        print(f"{version:<15} {trained:<22} {accuracy:<10}")
    print(f"{rule}\n")


def rollback(version_tag: str):
    """Roll back to a specific model version by loading it.

    Args:
        version_tag: Version to roll back to.

    Returns:
        The object returned by ``load_model`` for that version.

    Raises:
        KeyError: If ``version_tag`` is not present in the metadata registry.
        FileNotFoundError: If the registered model file is missing on disk.
    """
    print(f"[model_utils] Rolling back to version '{version_tag}'...")
    return load_model(version_tag)