Spaces:
Running
Running
File size: 853 Bytes
2319518 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch
from models.base import HFModel
class LLM(HFModel):
    """Greedy-decoding text generator built on the project's ``HFModel`` base.

    ``self.tokenizer`` and ``self.model`` are used below but are set up by
    ``HFModel.__init__``, which is not visible here. Presumably they are a
    Hugging Face tokenizer / generation-capable model pair; confirm against
    ``models.base``.
    """

    def __init__(self, model_path):
        """Initialise via ``HFModel`` using the checkpoint at *model_path*."""
        super().__init__(model_path)

    def generate(self, input_text, stop_words=None, max_new_tokens=512):
        """Greedily generate a continuation of *input_text*.

        Args:
            input_text: A prompt string, or a list of prompt strings. Only the
                continuation of the FIRST prompt is returned. The tokenizer is
                called without padding arguments, so a list of prompts with
                different token lengths cannot be turned into a tensor.
            stop_words: Optional iterable of stop strings (``None`` means no
                stop strings). The decoded output is cut right after the
                earliest-ending first occurrence of any of them; the stop
                string itself is kept. Empty strings are ignored.
            max_new_tokens: Upper bound on the number of newly generated tokens.

        Returns:
            The decoded continuation (prompt tokens excluded), decoded with
            ``skip_special_tokens=True``.

        Raises:
            ValueError: If *input_text* is an empty list.
        """
        if isinstance(input_text, str):
            input_text = [input_text]
        if not input_text:
            raise ValueError("input_text must contain at least one prompt")

        input_ids = self.tokenizer(input_text)['input_ids']
        input_ids = torch.tensor(input_ids, device=self.model.device)
        # No padding is applied, so every position is a real token. An explicit
        # all-ones mask stops generate() from guessing the mask from the pad
        # token id, which would hide real prompt tokens that share that id.
        attention_mask = torch.ones_like(input_ids)

        gen_kwargs = {'max_new_tokens': max_new_tokens, 'do_sample': False}
        outputs = self.model.generate(
            input_ids, attention_mask=attention_mask, **gen_kwargs
        )
        # Drop the prompt prefix so only newly generated tokens are decoded.
        new_tokens = outputs[0][input_ids.shape[1]:]
        output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return self._truncate_at_stop_words(output, stop_words)

    @staticmethod
    def _truncate_at_stop_words(text, stop_words):
        """Cut *text* after the earliest-ending first match of any stop word."""
        cut = len(text)
        for stop_str in stop_words or ():
            # An empty stop string would "match" at index 0 and erase the text.
            if not stop_str:
                continue
            idx = text.find(stop_str)
            if idx != -1:
                cut = min(cut, idx + len(stop_str))
        return text[:cut]
|