File size: 555 Bytes
b71808f
 
 
19b3da3
 
 
fd5252e
 
19b3da3
 
 
 
1bc457e
b71808f
 
 
19b3da3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import json
import os
from pathlib import Path
from typing import Any, Dict, List

from inference import model_fn, predict_fn
from internals.util.config import set_hf_cache_dir, set_model_config
from internals.util.model_loader import load_model_from_config


class EndpointHandler:
    """Inference-endpoint handler that delegates to the ``inference`` module.

    On construction it sets the Hugging Face cache directory to
    ``~/.cache/hf_cache`` and loads the model with ``model_fn``. Each call
    forwards the request payload and the loaded model to ``predict_fn``.
    """

    def __init__(self, path=""):
        """Configure the HF cache directory and load the model.

        Args:
            path: Directory handed to ``model_fn``; presumably the model
                repository root supplied by the serving runtime (TODO confirm).
        """
        set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
        self.model_dir = path
        # ``__init__`` must return None -- returning a non-None value raises
        # TypeError at construction time -- so keep the loaded model on the
        # instance rather than returning it.
        self.model = model_fn(self.model_dir)

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run inference on ``data`` with the model loaded in ``__init__``.

        Args:
            data: Request payload, forwarded unchanged to ``predict_fn``.

        Returns:
            Whatever ``predict_fn`` returns; annotated here as
            ``List[List[Dict[str, float]]]`` (TODO confirm against
            ``predict_fn``).
        """
        # Pass the model obtained from ``model_fn`` instead of a hard-coded
        # None, following the model_fn/predict_fn(input, model) contract.
        return predict_fn(data, self.model)