"""Custom handler for Hugging Face Inference Endpoints serving mosaicml/mpt-7b."""

from datetime import datetime
from typing import Any, Dict, List

import torch
import transformers


class EndpointHandler:

    def __init__(self, path=""):
        print(f"Hugging Face handler path: {path}")
        # Ignore the repository path supplied by the endpoint and always load MPT-7B from the Hub.
        path = "mosaicml/mpt-7b"
        # MPT's custom modeling code requires trust_remote_code; max_seq_len is applied as a config override.
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            max_seq_len=2048,
        )

        # MPT-7B was trained with the EleutherAI/gpt-neox-20b tokenizer.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        print("tokenizer created", datetime.now())

        self.generate_text = transformers.pipeline(
            task="text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            return_full_text=True,
            temperature=0.1,
            top_p=0.15,
            top_k=0,
            repetition_penalty=1.1,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        print("request payload:", data)
        # Inference Endpoints deliver the prompt under the "inputs" key.
        inputs = data.pop("inputs", data)
        print(inputs)
        # Generate from the supplied prompt.
        res = self.generate_text(inputs, max_length=60)
        return res
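

# Minimal local smoke test (a sketch, not part of the Inference Endpoints contract):
# instantiate the handler directly and send a payload shaped like the one the endpoint
# forwards, i.e. {"inputs": "<prompt>"}. The prompt below is only an example.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"inputs": "Explain to me the difference between nuclear fission and fusion."})
    print(result)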