llm-arch / src /models.py
alfraser's picture
Fixed a bug where the response from the LLM was being trimmed of content.
1215037
raw
history blame
2.52 kB
import json
import os
import requests
from typing import List
from src.common import config_dir, hf_api_token
class HFLlamaChatModel:
    """Registry and client for Llama-style chat models served through the
    Hugging Face Inference API.

    Model definitions (display name, HF repo id, description) are loaded
    lazily from ``config/models.json`` the first time any of the lookup
    classmethods is used.
    """

    # Lazily-populated list of HFLlamaChatModel instances; None until
    # load_configs() has run.
    models = None

    @classmethod
    def load_configs(cls):
        """(Re)load model definitions from models.json into cls.models.

        Resets the registry, then appends one entry per config record,
        skipping duplicate names within the config file itself.
        """
        config_file = os.path.join(config_dir, "models.json")
        with open(config_file, "r") as f:
            configs = json.load(f)['models']
        cls.models = []
        for cfg in configs:
            # for_name() dedupes repeated names in the config file
            if cls.for_name(cfg['name']) is None:
                cls.models.append(HFLlamaChatModel(cfg['name'], cfg['id'], cfg['description']))

    @classmethod
    def for_name(cls, name: str):
        """Return the model whose display name matches, or None."""
        if cls.models is None:
            cls.load_configs()
        for m in cls.models:
            if m.name == name:
                return m

    @classmethod
    def for_model(cls, model: str):
        """Return the model whose HF repo id matches, or None."""
        if cls.models is None:
            cls.load_configs()
        for m in cls.models:
            if m.id == model:
                return m

    @classmethod
    def available_models(cls) -> List[str]:
        """Return the display names of all configured models."""
        if cls.models is None:
            cls.load_configs()
        return [m.name for m in cls.models]

    def __init__(self, name: str, id: str, description: str):
        """Store the display name, HF repo id and description."""
        self.name = name
        self.id = id
        self.description = description

    def __call__(self,
                 query: str,
                 auth_token: str = None,
                 system_prompt: str = None,
                 max_new_tokens: int = 256,  # was annotated str; value is an int
                 temperature: float = 1.0):
        """Send a single-turn chat query to the HF Inference API.

        :param query: the user message.
        :param auth_token: HF API token; looked up via hf_api_token() if None.
        :param system_prompt: system instruction; a generic default if None.
        :param max_new_tokens: generation cap passed to the API.
        :param temperature: sampling temperature passed to the API.
        :return: the generated text, stripped of surrounding whitespace.
        :raises ValueError: on any non-200 response from the API.
        """
        if auth_token is None:
            auth_token = hf_api_token()  # Attempt look up if not provided
        headers = {"Authorization": f"Bearer {auth_token}"}
        api_url = f"https://api-inference.huggingface.co/models/{self.id}"
        if system_prompt is None:
            system_prompt = "You are a helpful assistant."
        # Bug fix: the Llama-2 chat template closes the system block with
        # <</SYS>>, not a second <<SYS>>.
        query_input = f"[INST] <<SYS>> {system_prompt} <</SYS>> {query} [/INST] "
        query_payload = {
            "inputs": query_input,
            "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature}
        }
        # NOTE(review): no timeout is set, so a hung endpoint blocks forever —
        # consider requests.post(..., timeout=...) if callers can tolerate it.
        response = requests.post(api_url, headers=headers, json=query_payload)
        if response.status_code == 200:
            resp_json = response.json()
            llm_text = resp_json[0]['generated_text'].strip()
            return llm_text
        else:
            error_detail = f"Error from hugging face code: {response.status_code}: {response.reason} ({response.content})"
            raise ValueError(error_detail)