# Chris-lab/utils/model.py

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextStreamer
import torch
from huggingface_hub import login
import os
import logging
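
# Authenticate with the Hugging Face Hub so gated checkpoints (e.g. Meta-Llama) can be downloaded.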
login(token=os.getenv('HF_TOKEN'))

class Model(torch.nn.Module):
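    """Thin wrapper around a Hugging Face checkpoint exposing a simple gen() API."""
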
    number_of_models = 0
    __model_list__ = [
        "Qwen/Qwen2-1.5B-Instruct",
        "lmsys/vicuna-7b-v1.5",
        "google-t5/t5-large",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ]

    def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Batched generation pads the inputs, but several causal checkpoints
        # (Llama, Mistral) ship without a pad token; fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.name = model_name
        logging.info(f'Start loading model {self.name}')

        if model_name == "google-t5/t5-large":
            # T5 (and other encoder-decoder checkpoints) need the Seq2Seq head
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name, torch_dtype=torch.bfloat16, device_map="auto"
            )
        else:
            # Decoder-only (GPT-style) causal language models
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, torch_dtype=torch.bfloat16, device_map="auto"
            )
        logging.info(f'Loaded model {self.name}')
        self.update()
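
    # Track how many Model instances have been created in this process.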
    @classmethod
    def update(cls):
        cls.number_of_models += 1

    def return_model_name(self):
        return self.name

    def return_tokenizer(self):
        return self.tokenizer

    def return_model(self):
        # Fixed: self.pipeline was never defined; the loaded model lives in self.model
        return self.model

    def gen(self, content_list, temp=0.1, max_length=500, streaming=False):
        # Tokenize the batch, keeping the attention mask so padded positions are ignored
        inputs = self.tokenizer(
            content_list, return_tensors="pt", padding=True, truncation=True
        ).to(self.model.device)

        if streaming:
            # TextStreamer prints tokens to stdout as they arrive, but it only
            # supports batch size 1, so each prompt is generated separately.
            for i in range(len(content_list)):
                streamer = TextStreamer(self.tokenizer, skip_prompt=True)
                self.model.generate(
                    inputs.input_ids[i].unsqueeze(0),
                    attention_mask=inputs.attention_mask[i].unsqueeze(0),
                    max_new_tokens=max_length,  # max_length bounds new tokens, not total length
                    do_sample=True,
                    temperature=temp,
                    eos_token_id=self.tokenizer.eos_token_id,
                    streamer=streamer,
                )
        else:
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=temp,
                eos_token_id=self.tokenizer.eos_token_id,
            )
            return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
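
# A minimal usage sketch, assuming HF_TOKEN is set in the environment and the
# default Qwen checkpoint fits on the available device; gen() returns one
# decoded string per prompt (causal models echo the prompt in the output).
if __name__ == "__main__":
    model = Model()
    outputs = model.gen(["Summarize: The quick brown fox jumps over the lazy dog."])
    print(outputs[0])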