"""Custom inference endpoint handler.

Resolves the model directory (optionally redirected through an
``inference.json`` config found in the handler's ``path``), loads the model
through ``inference.model_fn`` and serves requests through
``inference.predict_fn``.
"""
import json
import os
from pathlib import Path
from typing import Any, Dict, List

from inference import model_fn, predict_fn
from internals.util.config import set_hf_cache_dir
from internals.util.model_downloader import BaseModelDownloader


class EndpointHandler:
    """Callable endpoint wrapper around ``model_fn`` / ``predict_fn``."""

    def __init__(self, path=""):
        """Resolve the model directory and load the model.

        Args:
            path: Directory containing the model. If it holds an
                ``inference.json`` file, that config may redirect loading:
                ``"model_type": "huggingface"`` uses ``config["model_path"]``
                directly; ``"model_type": "s3"`` downloads the files listed in
                ``config["model_path"]["s3"]`` into ``~/.cache/base_model``.
        """
        set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")

        self.model_dir = path
        # os.path.join keeps the default path="" relative to the working
        # directory instead of producing the root path "/inference.json".
        config_path = os.path.join(path, "inference.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)

            model_type = config.get("model_type")
            if model_type == "huggingface":
                self.model_dir = config["model_path"]
            elif model_type == "s3":
                s3_config = config["model_path"]["s3"]
                base_url = s3_config["base_url"]
                urls = [base_url + item for item in s3_config["paths"]]
                out_dir = Path.home() / ".cache" / "base_model"
                # NOTE(review): only the directory's existence is checked, so
                # an interrupted earlier download would be treated as complete.
                if out_dir.exists():
                    print("Model already exist")
                else:
                    print("Downloading model")
                    BaseModelDownloader(
                        urls, s3_config["paths"], out_dir
                    ).download()
                self.model_dir = str(out_dir)

        # __init__ must return None; returning model_fn's result (as before)
        # raises TypeError whenever model_fn returns a non-None value.
        self.model = model_fn(self.model_dir)

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run ``predict_fn`` on one request payload.

        Args:
            data: Request payload, forwarded unchanged to ``predict_fn``.

        Returns:
            Whatever ``predict_fn`` returns for ``data``.
        """
        # self.model is whatever model_fn returned (None in the setups where
        # the previous implementation worked, so behavior there is unchanged).
        return predict_fn(data, self.model)