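"""F5 model handler: wraps a Hugging Face text2text-generation pipeline around
google/flan-t5-small and exposes async generate/stream helpers.

Note: generate_response and stream_response are declared async for interface
reasons, but the underlying pipeline call itself is blocking.
"""
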
import logging

import torch
from transformers import pipeline

class F5ModelHandler:
    def __init__(self):
        logging.info("Initializing F5ModelHandler...")
        try:
            logging.info("Loading model 'google/flan-t5-small'...")
            self.model_name = "google/flan-t5-small"
            # Use pipeline for simpler model loading (fp16 on GPU, fp32 on CPU).
            self.generator = pipeline(
                "text2text-generation",
                model=self.model_name,
                device="cuda" if torch.cuda.is_available() else "cpu",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
            logging.info(f"Model loaded successfully on {self.generator.device}")
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise

    async def generate_response(self, prompt: str, max_length: int = 2048, _retries: int = 3) -> str:
        try:
            logging.info(f"Generating response for prompt: {prompt[:100]}...")
            # Generate with more focused parameters (beam search combined with sampling).
            response = self.generator(
                prompt,
                max_length=max_length,
                num_beams=5,
                temperature=0.7,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.2,
                length_penalty=1.0,
                do_sample=True,
                num_return_sequences=1,
            )[0]["generated_text"]

            # Clean up the response
            response = response.strip()

            # Ensure minimum content length; the retry budget keeps a
            # persistently short output from recursing forever.
            if len(response) < 100 and _retries > 0:
                logging.warning("Response too short, regenerating...")
                return await self.generate_response(prompt, max_length, _retries - 1)

            logging.info(f"Generated response successfully: {response[:100]}...")
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            raise

    async def stream_response(self, prompt: str, max_length: int = 1000):
        try:
            # Note: the text2text-generation pipeline never echoes the prompt,
            # and it does not accept return_full_text, so that flag is dropped.
            response = self.generator(
                prompt,
                max_length=max_length,
                num_beams=4,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )[0]["generated_text"]

            # Simulate streaming by yielding fixed-size chunks of the response.
            chunk_size = 20
            for i in range(0, len(response), chunk_size):
                chunk = response[i:i + chunk_size]
                yield chunk
        except Exception as e:
            logging.error(f"Error in stream_response: {str(e)}")
            raise

# Configure logging, then initialize a module-level model handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
f5_model = F5ModelHandler()
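

# A minimal usage sketch, not part of the original module: it assumes this file
# is run directly as a script and exercises both async helpers. The prompts are
# purely illustrative.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # One-shot generation.
        text = await f5_model.generate_response(
            "Explain in a few sentences why the sky is blue."
        )
        print(text)

        # Simulated streaming: stream_response is an async generator.
        async for chunk in f5_model.stream_response("Translate to French: Good morning."):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())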