File size: 1,493 Bytes
b71808f 19b3da3 1bc457e b71808f 19b3da3 1bc457e b71808f 19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import json
import os
from pathlib import Path
from typing import Any, Dict, List
from inference import model_fn, predict_fn
from internals.util.config import set_hf_cache_dir
from internals.util.model_downloader import BaseModelDownloader
class EndpointHandler:
    """Callable handler that resolves a model directory and delegates to ``inference``.

    On construction the handler picks the model directory (optionally
    redirected by an ``inference.json`` file found under ``path``) and passes
    it to ``model_fn``. Calling the instance forwards the payload to
    ``predict_fn``.
    """

    def __init__(self, path=""):
        """Resolve the model directory and initialise the model.

        Args:
            path: Directory that is used as the model directory by default and
                that may contain an ``inference.json`` file. When that file
                exists, its ``model_type`` key selects where the model comes
                from:

                * ``"huggingface"`` -- ``model_path`` is used as the model
                  directory as-is.
                * ``"s3"`` -- ``model_path["s3"]`` supplies ``base_url`` and
                  ``paths``; the files are downloaded into
                  ``~/.cache/base_model`` unless that directory already exists.

                Any other ``model_type`` leaves ``path`` as the model directory.

        Raises:
            KeyError: If ``inference.json`` selects a model type but lacks the
                keys that type requires.
        """
        set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
        self.model_dir = path

        # os.path.join instead of string concatenation: an empty ``path`` then
        # resolves relative to the working directory rather than to the
        # filesystem root, and path-like objects are accepted too.
        config_path = os.path.join(path, "inference.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)

            model_type = config.get("model_type")
            if model_type == "huggingface":
                self.model_dir = config["model_path"]
            elif model_type == "s3":
                self.model_dir = self._fetch_s3_model(config["model_path"]["s3"])

        # __init__ must return None; returning model_fn's result would raise
        # TypeError whenever model_fn returns a value, so the result is dropped.
        model_fn(self.model_dir)

    @staticmethod
    def _fetch_s3_model(s3_config):
        """Ensure the S3-hosted model is present locally and return its directory."""
        base_url = s3_config["base_url"]
        paths = s3_config["paths"]
        urls = [base_url + item for item in paths]

        out_dir = Path.home() / ".cache" / "base_model"
        # NOTE(review): the mere existence of the directory is treated as a
        # complete download; an interrupted download would not be retried.
        if out_dir.exists():
            print("Model already exist")
        else:
            print("Downloading model")
            BaseModelDownloader(urls, paths, out_dir).download()

        return str(out_dir)

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run a prediction for ``data`` by forwarding it to ``predict_fn``.

        ``None`` is passed as ``predict_fn``'s second argument; presumably the
        model state is held inside the ``inference`` module -- TODO confirm.
        """
        return predict_fn(data, None)
|