|
from typing import Dict, List, Any |
|
import transformers |
|
import torch |
|
from datetime import datetime |
|
from transformers import StoppingCriteria, StoppingCriteriaList |
|
from transformers.utils import logging |
|
|
|
# Raise the transformers library logger to INFO so the handler's own
# `logger.info(...)` calls below are actually emitted.
logging.set_verbosity(logging.INFO)

# Shared module-level logger; it reuses the "transformers" logger name so
# handler messages appear alongside the library's own output.
logger = logging.get_logger("transformers")
|
|
|
|
|
class _StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the newest token is a stop token."""

    def __init__(self, stop_token_ids):
        """Remember the token ids that should end generation.

        Args:
            stop_token_ids: Iterable of token ids; generation stops as soon
                as the most recently generated token is one of them.
        """
        super().__init__()
        # A frozenset gives a single membership test instead of a Python
        # loop of tensor comparisons.
        self._stop_token_ids = frozenset(stop_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id.

        Only row 0 of ``input_ids`` is inspected, so with batched inputs the
        decision is driven by the first sequence alone.
        """
        return int(input_ids[0][-1]) in self._stop_token_ids


class EndpointHandler():
    """Custom inference handler wrapping a text-generation pipeline.

    The constructor loads the causal LM and tokenizer named by the class
    constants below and builds a ``transformers`` text-generation pipeline;
    calling the instance runs that pipeline on the request payload.
    """

    # Model and tokenizer are fetched by these fixed ids, not from the
    # ``path`` argument handed to ``__init__``.
    MODEL_ID = 'mosaicml/mpt-7b-instruct'
    TOKENIZER_ID = 'EleutherAI/gpt-neox-20b'

    def __init__(self, path=""):
        """Load model and tokenizer and assemble the generation pipeline.

        Args:
            path: Location supplied by the caller. It is only logged; the
                model is always loaded from ``MODEL_ID``.
        """
        logger.info("111111111111111111111111111")
        logger.info(f"Hugging face handler path {path}")

        # NOTE: trust_remote_code=True lets the model repository supply its
        # own modelling code, which is executed locally while loading.
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            self.MODEL_ID,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            max_seq_len=32000,
        )

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.TOKENIZER_ID)
        print("tokenizer created ", datetime.now())

        # Stop generating as soon as the end-of-text marker is produced.
        stop_token_ids = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
        stopping_criteria = StoppingCriteriaList([_StopOnTokens(stop_token_ids)])

        # NOTE(review): no device is passed and no `do_sample` flag is set;
        # confirm the pipeline lands on the intended device and that the
        # temperature/top_p/top_k values take effect as intended.
        self.generate_text = transformers.pipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            stopping_criteria=stopping_criteria,
            task='text-generation',
            return_full_text=True,
            temperature=0.1,
            top_p=0.15,
            top_k=0,
            max_new_tokens=2048,
            repetition_penalty=1.1,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run text generation for one request payload.

        Args:
            data: Request body. The prompt is read from ``data["inputs"]``;
                when that key is absent the whole ``data`` object is passed
                to the pipeline instead. An optional ``data["parameters"]``
                mapping is forwarded to the pipeline as per-call keyword
                arguments.

        Returns:
            Whatever the text-generation pipeline returns for the inputs.
        """
        logger.info(f"iiinnnnnnnnnn {data}")

        # Read without popping so the caller's dict is left untouched.
        inputs = data.get("inputs", data)
        logger.info(f"iiinnnnnnnnnnbbbbbb {inputs}")

        parameters = data.get("parameters")
        if parameters:
            # Per-request overrides of the generation settings.
            return self.generate_text(inputs, **parameters)
        return self.generate_text(inputs)
|
|