|
from utilities.setup import * |
|
|
|
import json |
|
import os |
|
|
|
from typing import Dict, List, Any |
|
from peft import AutoPeftModelForCausalLM |
|
from transformers import AutoTokenizer |
|
|
|
|
|
class EndpointHandler():
    """Hugging Face Inference Endpoints handler for a PEFT (LoRA-adapter)
    causal language model.

    Loads the model and tokenizer once at construction, then serves
    generation requests through ``__call__``.
    """

    def __init__(self, path=""):
        """Initialize class. Load model of interest upon init.

        Args:
            path: Directory containing the PEFT adapter/model weights and
                tokenizer files. Passed straight to ``from_pretrained``.
        """
        print("Reading config")
        self.path = path
        # May be None if the env var is unset; not used directly below —
        # presumably consumed by the from_pretrained calls via the env.
        self.HF_TOKEN = os.getenv("HF_TOKEN")
        self.wd = os.getcwd()
        # NOTE(review): model_name is derived from the working directory but
        # is not referenced elsewhere in this file — confirm it is needed.
        self.model_name = os.path.basename(self.wd)

        print("loading model")
        self.model, self.tokenizer = self.load_model()

    def load_model(self):
        """Load unsloth model and tokenizer.

        Returns:
            Tuple ``(model, tokenizer)``: the 4-bit quantized PEFT model and
            its matching tokenizer, both loaded from ``self.path``.
        """
        model = AutoPeftModelForCausalLM.from_pretrained(
            self.path,
            load_in_4bit = True,
        )
        tokenizer = AutoTokenizer.from_pretrained(self.path)

        return model, tokenizer

    def prompt_formatter(self, prompt):
        """Tokenize the prompt and move the tensors to the GPU.

        NOTE(review): despite the docstring of the original ("alpaca style"),
        no template is applied here — the prompt is tokenized as-is, so
        callers must pre-format it themselves.

        Args:
            prompt: Raw prompt string.

        Returns:
            Tuple ``(inputs, prompt)``: tokenized tensors on CUDA and the
            unchanged prompt text.
        """
        # Hard-coded "cuda": this handler assumes a GPU-backed endpoint.
        inputs = self.tokenizer([prompt], return_tensors = "pt").to("cuda")

        return inputs, prompt

    def infer(self, prompt, max_new_tokens=1000):
        """Bringing it all together: tokenize, generate, decode.

        Args:
            prompt: Prompt text to complete.
            max_new_tokens: Cap on the number of generated tokens.

        Returns:
            List of decoded completion strings (batch of one). Special
            tokens are kept, matching the original behavior.
        """
        # prompt_formatter also returns the prompt text; not needed here.
        inputs, _ = self.prompt_formatter(prompt)
        outputs = self.model.generate(**inputs,
                                      max_new_tokens = max_new_tokens,
                                      use_cache=True)
        completion = self.tokenizer.batch_decode(outputs)

        return completion

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: `str`)
            kwargs
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # Fix: use .get() so a payload missing the "inputs" key returns the
        # error response instead of raising KeyError (the original indexed
        # data["inputs"] before checking for None).
        request = data.get("inputs")
        if request is not None:
            prediction = self.infer(request)
            return {"prediction": prediction}
        else:
            return [{"Error" : "no input received."}]