File size: 1,698 Bytes
8fe5582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Utility helpers for downloading registered models from MLflow."""
from __future__ import annotations

import json
from pathlib import Path
from typing import Literal

from loguru import logger


def download_model(
    model_name: str,
    model_stage: Literal["staging", "production"],
    model_dir: str | Path = "model",
) -> Path:
    """Download the latest version of a registered model in a stage from MLflow.

    The artifacts are stored under ``model_dir/<artifact name>``. The artifact
    name is the last path component of the registered version's source URI.
    After a download, the model's MLflow metadata is written to
    ``metadata.json`` inside the downloaded directory. That file is written
    last, so its presence marks a completed earlier download. Such a download
    is reused instead of fetching again.

    Args:
        model_name: Name of the registered model.
        model_stage: Registry stage to look up ("staging" or "production").
        model_dir: Local directory under which the model is stored.

    Returns:
        Path to the local model directory.

    Raises:
        ValueError: If the registry does not return exactly one version for
            ``model_name`` in ``model_stage``.
    """
    import mlflow.artifacts
    import mlflow.models
    from mlflow.client import MlflowClient

    logger.info(f"Looking for model {model_name}/{model_stage}")

    model_dir = Path(model_dir)

    client = MlflowClient()
    model_versions = client.get_latest_versions(model_name, stages=[model_stage])
    if not model_versions:
        raise ValueError(f"No model version for {model_name}/{model_stage}")
    if len(model_versions) > 1:
        raise ValueError(
            f"Expected one model version for {model_name}/{model_stage}, "
            f"found {len(model_versions)}"
        )

    artifact_uri = model_versions[0].source
    model_version = model_versions[0].version

    logger.info(f"Found version {model_version} for {model_name}/{model_stage}")

    # Strip a trailing slash so ".../model/" still yields "model". Otherwise the
    # name is empty and model_path collapses to model_dir itself.
    artifact_name = artifact_uri.rstrip("/").split("/")[-1]
    model_path = model_dir / artifact_name

    # metadata.json is written only after a successful download and metadata
    # lookup. A directory without it is treated as an interrupted download and
    # is fetched again.
    # NOTE(review): the cached copy is keyed by artifact name only, not by
    # model_version. A newer version with the same artifact name will not be
    # re-downloaded. Consider versioning the directory layout.
    if (model_path / "metadata.json").exists():
        logger.info(f"Found model in {model_path}, skipping download")
        return model_path

    model_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Downloading artifacts {artifact_uri} to {model_dir}")
    downloaded_path = mlflow.artifacts.download_artifacts(
        artifact_uri, dst_path=str(model_dir)
    )
    logger.info(f"Successfully downloaded {model_name}")

    # Use the path MLflow reports, not the computed one, since it is authoritative.
    model_info = mlflow.models.get_model_info(downloaded_path)
    metadata = model_info.metadata
    metadata_path = Path(downloaded_path) / "metadata.json"
    logger.info(f"Saving metadata to {metadata_path}")
    with open(metadata_path, "w", encoding="utf-8") as file:
        json.dump(metadata, file)

    return Path(downloaded_path)