mpt-7b / handler.py
from typing import Dict, List, Any
import transformers
import torch
from datetime import datetime

class EndpointHandler:
    def __init__(self, path=""):
        print(f"Hugging Face handler path {path}")
        # Load the MPT-7B weights from the Hub (alternative checkpoint: 'mosaicml/mpt-7b-instruct').
        path = 'mosaicml/mpt-7b'
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            max_seq_len=2048
        )
        # MPT-7B was trained with the GPT-NeoX tokenizer.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
        print("tokenizer created", datetime.now())
        self.generate_text = transformers.pipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            task='text-generation',
            return_full_text=True,
            temperature=0.1,
            top_p=0.15,
            top_k=0,
            # max_new_tokens=64,
            repetition_penalty=1.1
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        print("received request data", data)
        # Inference Endpoints deliver the prompt under the "inputs" key.
        inputs = data.pop("inputs", data)
        print(inputs)
        res = self.generate_text(inputs, max_length=60)
        return res
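
A minimal sketch of driving the handler locally before deployment. The payload shape (a dict with an "inputs" key) mirrors the __call__ signature above; the __main__ guard and the sample prompt are illustrative assumptions, not part of the deployed endpoint, where the Inference Endpoints toolkit constructs and calls the handler itself.

if __name__ == "__main__":
    # Hypothetical local smoke test; on a real Inference Endpoint the toolkit
    # instantiates EndpointHandler and invokes it once per incoming request.
    handler = EndpointHandler()
    payload = {"inputs": "Explain to me the difference between nuclear fission and fusion."}
    result = handler(payload)
    print(result)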