# f5_model_final / app/services/flan_t5_service.py
# Author: EL GHAFRAOUI AYOUB
# (commit 6f14d8b, 744 bytes — header reconstructed from scraped page residue)
import asyncio

from transformers import T5ForConditionalGeneration, T5Tokenizer
class FlanT5Service:
    """Thin service wrapper around google/flan-t5-base for prompt-to-text generation."""

    def __init__(self):
        # Downloading/loading the weights is slow and blocking; construct this
        # service once at application startup, not per request.
        self.model_name = "google/flan-t5-base"
        self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)

    async def generate_response(self, prompt: str, max_length: int = 512) -> str:
        """Generate a beam-search completion for ``prompt``.

        Args:
            prompt: Input text; tokenized and truncated to 512 input tokens.
            max_length: Maximum length (in tokens) of the generated sequence.

        Returns:
            The generated text, decoded with special tokens stripped.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)

        def _generate():
            # FIX: the original also passed temperature=0.7 / top_p=0.9, but
            # without do_sample=True transformers ignores both (and warns) —
            # beam search alone determined the output, so they are dropped.
            return self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=4,
            )

        # FIX: model.generate is a long blocking call; run it in a worker
        # thread so this coroutine does not stall the event loop.
        outputs = await asyncio.to_thread(_generate)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)