import json
import os
from typing import List, Optional

import requests

from src.common import config_dir, hf_api_token


class HFLlamaChatModel:
    models = None

    @classmethod
    def load_configs(cls):
        """Load the model definitions from models.json in config_dir into cls.models."""
        config_file = os.path.join(config_dir, "models.json")
        with open(config_file, "r") as f:
            configs = json.load(f)['models']
        cls.models = []
        for cfg in configs:
            if cls.for_name(cfg['name']) is None:
                cls.models.append(HFLlamaChatModel(cfg['name'], cfg['id'], cfg['description']))
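
    # Expected shape of models.json, inferred from the fields read above
    # (a 'models' list whose entries carry 'name', 'id' and 'description');
    # the entries shown here are illustrative, not the repo's actual config:
    #
    # {
    #   "models": [
    #     {
    #       "name": "Llama 2 7B Chat",
    #       "id": "meta-llama/Llama-2-7b-chat-hf",
    #       "description": "Meta's 7B-parameter chat-tuned Llama 2 model"
    #     }
    #   ]
    # }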

    @classmethod
    def for_name(cls, name: str):
        """Return the configured model with the given display name, or None."""
        if cls.models is None:
            cls.load_configs()
        for m in cls.models:
            if m.name == name:
                return m
        return None

    @classmethod
    def for_model(cls, model: str):
        """Return the configured model with the given Hugging Face model id, or None."""
        if cls.models is None:
            cls.load_configs()
        for m in cls.models:
            if m.id == model:
                return m
        return None

    @classmethod
    def available_models(cls) -> List[str]:
        """Return the display names of all configured models."""
        if cls.models is None:
            cls.load_configs()
        return [m.name for m in cls.models]

    def __init__(self, name: str, id: str, description: str):
        self.name = name
        self.id = id
        self.description = description

    def __call__(self,
                 query: str,
                 auth_token: Optional[str] = None,
                 system_prompt: Optional[str] = None,
                 max_new_tokens: int = 256,
                 temperature: float = 1.0):
        if auth_token is None:
            auth_token = hf_api_token()  # Attempt lookup if not provided
        headers = {"Authorization": f"Bearer {auth_token}"}
        api_url = f"https://api-inference.huggingface.co/models/{self.id}"
        if system_prompt is None:
            system_prompt = "You are a helpful assistant."
        # Wrap the query in the Llama 2 chat prompt format
        query_input = f"[INST] <<SYS>> {system_prompt} <</SYS>> {query} [/INST] "
        query_payload = {
            "inputs": query_input,
            "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature}
        }
        response = requests.post(api_url, headers=headers, json=query_payload)
        if response.status_code == 200:
            resp_json = response.json()
            llm_text = resp_json[0]['generated_text'].strip()
            return llm_text
        else:
            error_detail = (f"Error from Hugging Face API: {response.status_code}: "
                            f"{response.reason} ({response.content})")
            raise ValueError(error_detail)
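

# A minimal usage sketch, not part of the original module: it assumes that
# models.json defines an entry named "Llama 2 7B Chat" (a hypothetical name)
# and that hf_api_token() can resolve a valid Hugging Face API token.
if __name__ == "__main__":
    print("Available models:", HFLlamaChatModel.available_models())
    model = HFLlamaChatModel.for_name("Llama 2 7B Chat")
    if model is not None:
        answer = model(
            "What is the capital of France?",
            system_prompt="Answer in a single short sentence.",
            max_new_tokens=64,
            temperature=0.7,
        )
        print(answer)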