# llama3-8b-meetingQA / handler.py
from utilities.setup import *  # repo-local setup helpers

import json
import os
from typing import Dict, List, Any

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        """Initialize the handler: read config and load the model up front."""
        print("Reading config")
        self.path = path
        self.HF_TOKEN = os.getenv("HF_TOKEN")  # available for authenticated Hub access if needed
        self.wd = os.getcwd()
        self.model_name = os.path.basename(self.wd)
        print("Loading model")
        self.model, self.tokenizer = self.load_model()
    def load_model(self):
        """Load the fine-tuned PEFT adapter (4-bit quantized) and its tokenizer."""
        model = AutoPeftModelForCausalLM.from_pretrained(
            self.path,
            load_in_4bit=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(self.path)
        return model, tokenizer
    def prompt_formatter(self, prompt):
        """Tokenize a prompt for generation.

        Prompts must already be formatted in Alpaca style before reaching
        this handler; this method only tokenizes and moves tensors to GPU.
        """
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
        return inputs, prompt
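
    # For reference, an Alpaca-style prompt looks roughly like the sketch
    # below. This is an assumption based on the docstring above; the exact
    # template used during fine-tuning may differ.
    #
    #   Below is an instruction that describes a task, paired with an input
    #   that provides further context. Write a response that appropriately
    #   completes the request.
    #
    #   ### Instruction:
    #   Answer the question about the meeting.
    #
    #   ### Input:
    #   <meeting transcript>
    #
    #   ### Response: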
    def infer(self, prompt, max_new_tokens=1000):  # TODO: add streaming capability
        """Tokenize the prompt, run generation, and decode the completion."""
        # Tokenize the (already formatted) prompt
        inputs, prompt_text = self.prompt_formatter(prompt)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
        # Decode, dropping special tokens from the returned text
        completion = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return completion
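
    # Sketch of a streaming variant (the TODO above), using transformers'
    # TextIteratorStreamer. Illustrative only and not wired into __call__;
    # the method name `infer_stream` is an assumption.
    def infer_stream(self, prompt, max_new_tokens=1000):
        """Yield decoded text chunks as they are generated (illustrative sketch)."""
        from threading import Thread
        from transformers import TextIteratorStreamer

        inputs, _ = self.prompt_formatter(prompt)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        # generate() blocks, so run it in a background thread and consume
        # decoded chunks from the streamer as they arrive.
        generation_kwargs = dict(
            **inputs, max_new_tokens=max_new_tokens, use_cache=True, streamer=streamer
        )
        Thread(target=self.model.generate, kwargs=generation_kwargs).start()
        for chunk in streamer:
            yield chunk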
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`str`): the Alpaca-formatted prompt
        Return:
            A :obj:`dict` that will be serialized and returned: either
            {"prediction": ...} on success or {"error": ...} on bad input.
        """
        request = data.get("inputs")  # .get() avoids a KeyError on malformed requests
        if request is not None:
            prediction = self.infer(request)
            return {"prediction": prediction}
        else:
            return {"error": "No input received."}