File size: 1,493 Bytes
b71808f
 
 
19b3da3
 
 
1bc457e
b71808f
19b3da3
 
 
 
1bc457e
b71808f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19b3da3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import json
import os
from pathlib import Path
from typing import Any, Dict, List

from inference import model_fn, predict_fn
from internals.util.config import set_hf_cache_dir
from internals.util.model_downloader import BaseModelDownloader


class EndpointHandler:
    """Inference endpoint entry point (Hugging Face custom-handler contract).

    On construction, resolves where the model weights live — the local
    ``path``, a Hugging Face model id, or a set of S3 objects downloaded to
    a local cache — then loads the model via ``model_fn``. Calling the
    instance runs ``predict_fn`` on the payload.
    """

    def __init__(self, path=""):
        # Route HF downloads into a dedicated cache directory.
        set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
        self.model_dir = path

        # An optional inference.json next to the model can redirect loading.
        # Expected schema (inferred from usage below — confirm with callers):
        #   {"model_type": "huggingface", "model_path": "<hf model id>"}
        #   {"model_type": "s3", "model_path": {"s3": {"base_url": ..., "paths": [...]}}}
        if os.path.exists(path + "/inference.json"):
            with open(path + "/inference.json", "r") as f:
                config = json.load(f)
            model_type = config.get("model_type")
            if model_type == "huggingface":
                # model_path is a Hugging Face model id; model_fn resolves it.
                self.model_dir = config["model_path"]
            elif model_type == "s3":
                s3_config = config["model_path"]["s3"]
                base_url = s3_config["base_url"]

                urls = [base_url + item for item in s3_config["paths"]]
                out_dir = Path.home() / ".cache" / "base_model"
                # NOTE(review): existence of the directory is treated as a
                # complete download — a partial download is not re-fetched.
                if out_dir.exists():
                    print("Model already exist")
                else:
                    print("Downloading model")
                    BaseModelDownloader(
                        urls, s3_config["paths"], out_dir
                    ).download()

                self.model_dir = str(out_dir)

        # BUG FIX: the original `return model_fn(...)` raises
        # "TypeError: __init__() should return None" whenever model_fn
        # returns a non-None model object. Load eagerly and keep the
        # result on the instance instead of returning it.
        self.model = model_fn(self.model_dir)

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run inference on *data*; second arg mirrors the original call."""
        return predict_fn(data, None)