import json
import os
from pathlib import Path
from typing import Any, Dict, List

from inference import model_fn, predict_fn
from internals.util.config import set_hf_cache_dir
from internals.util.model_downloader import BaseModelDownloader


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Keep Hugging Face downloads in a dedicated cache directory.
        set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
        self.model_dir = path
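
        # Expected inference.json shapes, a sketch inferred from the two
        # branches below (only these keys are read here):
        #   {"model_type": "huggingface", "model_path": "<hf model id or dir>"}
        #   {"model_type": "s3",
        #    "model_path": {"s3": {"base_url": "...", "paths": ["...", "..."]}}}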
        config_path = os.path.join(path, "inference.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
            if config.get("model_type") == "huggingface":
                # The config points at a Hugging Face model id or local path.
                self.model_dir = config["model_path"]
            elif config.get("model_type") == "s3":
                # The model files live in S3 and must be downloaded first.
                s3_config = config["model_path"]["s3"]
                base_url = s3_config["base_url"]

                urls = [base_url + item for item in s3_config["paths"]]
                out_dir = Path.home() / ".cache" / "base_model"
                if out_dir.exists():
                    print("Model already exists")
                else:
                    print("Downloading model")
                    BaseModelDownloader(
                        urls, s3_config["paths"], out_dir
                    ).download()

                self.model_dir = str(out_dir)

        # __init__ must return None, so store the loaded model on the
        # instance instead of returning it; __call__ reuses it below.
        self.model = model_fn(self.model_dir)

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        return predict_fn(data, self.model)
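

# A minimal local smoke test, included as a sketch: it assumes predict_fn
# accepts a plain dict payload and that the (hypothetical) path
# "/opt/ml/model" holds the model files or an inference.json.
if __name__ == "__main__":
    handler = EndpointHandler(path="/opt/ml/model")
    predictions = handler({"inputs": "example input"})
    print(predictions)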