Spaces:
Sleeping
Sleeping
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM | |
import torch | |
from huggingface_hub import login | |
import os | |
import logging | |
# Authenticate with the Hugging Face Hub using the token from the environment.
# NOTE(review): when HF_TOKEN is unset this passes token=None, which presumably
# makes huggingface_hub fall back to an interactive login prompt — confirm that
# is acceptable for headless deployments.
login(token=os.environ.get('HF_TOKEN'))
class Model(torch.nn.Module):
    """Thin wrapper around a Hugging Face tokenizer plus a generation model.

    ``Model.number_of_models`` is a class-wide count of constructed instances.
    ``__model_list__`` holds model names (presumably the checkpoints this
    wrapper is meant to be used with).
    """

    number_of_models = 0
    __model_list__ = [
        "Qwen/Qwen2-1.5B-Instruct",
        "lmsys/vicuna-7b-v1.5",
        "google-t5/t5-large",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ]

    def __init__(self, model_name: str = "Qwen/Qwen2-1.5B-Instruct") -> None:
        """Load the tokenizer and model weights for ``model_name``.

        ``"google-t5/t5-large"`` is loaded as a sequence-to-sequence model;
        every other name is loaded as a causal LM. Weights are loaded in
        bfloat16 with ``device_map="auto"``.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.name = model_name
        logging.info("start loading model %s", self.name)
        model_cls = (
            AutoModelForSeq2SeqLM
            if model_name == "google-t5/t5-large"
            else AutoModelForCausalLM
        )
        self.model = model_cls.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto"
        )
        logging.info("Loaded model %s", self.name)
        self.model.eval()
        # Batched tokenization with padding=True raises when the tokenizer has
        # no pad token (presumably the case for some listed checkpoints, e.g.
        # the Llama/Mistral/Vicuna ones); reuse EOS as the pad token then.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Decoder-only models append new tokens at the right end of the
        # sequence, so prompts must be left-padded for batched generation.
        if not self.model.config.is_encoder_decoder:
            self.tokenizer.padding_side = "left"
        self.update()

    @classmethod
    def update(cls):
        """Increment the class-wide count of constructed models."""
        # Must be a classmethod: on an instance, ``+=`` would create an
        # instance attribute and leave the class counter untouched.
        cls.number_of_models += 1

    def return_mode_name(self):
        """Return the model name (original, misspelled accessor; kept for callers)."""
        return self.name

    def return_model_name(self):
        """Return the model name passed to the constructor."""
        return self.name

    def return_tokenizer(self):
        """Return the loaded tokenizer."""
        return self.tokenizer

    def return_model(self):
        """Return the loaded generation model."""
        return self.model

    def _encode(self, content_list):
        """Tokenize prompt(s) into a padded batch placed on the model's device."""
        return self.tokenizer(
            content_list, return_tensors="pt", padding=True, truncation=True
        ).to(self.model.device)

    def streaming(self, content_list, temp=0.001, max_length=500):
        """Generate continuations token by token, yielding text as it appears.

        Args:
            content_list: A prompt string or a list of prompt strings.
            temp: Sampling temperature (sampling is always enabled).
            max_length: Maximum number of new tokens generated per sequence.

        Yields:
            ``(index, text)`` pairs, one per still-active sequence per step.
            ``index`` is the int position of the prompt in the batch. ``text``
            is the text decoded since that sequence's previous yield; it may be
            ``""`` (e.g. for the EOS token, or while a multi-byte character is
            still incomplete). A sequence gets no further yields after it
            produces the EOS token.

        NOTE: each step calls ``generate`` on the full sequence with
        ``max_new_tokens=1`` (no KV-cache reuse across steps), so the total
        cost grows quadratically with the output length.
        """
        encoded = self._encode(content_list)
        input_ids = encoded.input_ids
        attention_mask = encoded.attention_mask
        batch_size = input_ids.shape[0]
        eos_id = self.tokenizer.eos_token_id
        is_seq2seq = self.model.config.is_encoder_decoder

        decoder_ids = None
        if is_seq2seq:
            # Encoder-decoder models grow a separate decoder sequence; the
            # encoder inputs stay fixed across steps.
            start_id = self.model.config.decoder_start_token_id
            if start_id is None:
                start_id = self.tokenizer.pad_token_id
            decoder_ids = torch.full(
                (batch_size, 1), start_id, dtype=torch.long, device=input_ids.device
            )

        # active[row] is the original prompt index of batch row ``row``.
        active = list(range(batch_size))
        # Per-prompt generated token ids and already-yielded text, used to
        # decode incrementally (single-token decode drops leading spaces and
        # splits multi-byte characters).
        token_history = [[] for _ in range(batch_size)]
        emitted = [""] * batch_size

        for _ in range(max_length):
            if not active:
                break
            gen_kwargs = {
                "attention_mask": attention_mask,
                "do_sample": True,
                "temperature": temp,
                "eos_token_id": eos_id,
                "pad_token_id": self.tokenizer.pad_token_id,
                "max_new_tokens": 1,  # one token per step
            }
            if is_seq2seq:
                gen_kwargs["decoder_input_ids"] = decoder_ids
            with torch.no_grad():
                sequences = self.model.generate(input_ids, **gen_kwargs)
            next_tokens = sequences[:, -1]

            keep_rows = []
            for row, seq_idx in enumerate(active):
                token_id = int(next_tokens[row])
                token_history[seq_idx].append(token_id)
                text = self.tokenizer.decode(
                    token_history[seq_idx], skip_special_tokens=True
                )
                if text.endswith("\ufffd"):
                    # Incomplete multi-byte character: wait for more tokens.
                    delta = ""
                else:
                    delta = text[len(emitted[seq_idx]):]
                    emitted[seq_idx] = text
                yield seq_idx, delta
                if token_id != eos_id:
                    keep_rows.append(row)

            # Carry the new token forward into the next step's inputs.
            if is_seq2seq:
                decoder_ids = sequences
            else:
                input_ids = sequences
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                    dim=-1,
                )

            # Drop finished rows; ``keep_rows`` are positions in the CURRENT
            # batch, so they index the current tensors directly.
            if len(keep_rows) != len(active):
                index = torch.tensor(keep_rows, dtype=torch.long, device=input_ids.device)
                input_ids = input_ids[index]
                attention_mask = attention_mask[index]
                if decoder_ids is not None:
                    decoder_ids = decoder_ids[index]
                active = [active[row] for row in keep_rows]

    def gen(self, content_list, temp=0.001, max_length=500):
        """Generate complete continuations in one call.

        Args:
            content_list: A prompt string or a list of prompt strings.
            temp: Sampling temperature (sampling is always enabled).
            max_length: Maximum number of new tokens generated per sequence.

        Returns:
            A list of decoded continuation strings, one per prompt, with
            special tokens removed.
        """
        encoded = self._encode(content_list)
        outputs = self.model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temp,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Decoder-only outputs start with the (padded) prompt, which must be
        # stripped; encoder-decoder outputs contain only generated tokens.
        if not self.model.config.is_encoder_decoder:
            outputs = outputs[:, encoded.input_ids.shape[1]:]
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)