from typing import List, Optional
from pydantic import BaseModel
import torch
import logging
from transformers import pipeline
class F5ModelHandler:
    """Wraps a text2text-generation pipeline around google/flan-t5-small."""

    def __init__(self):
        logging.info("Initializing F5ModelHandler...")
        try:
            logging.info("Loading model 'google/flan-t5-small'...")
            self.model_name = "google/flan-t5-small"
            # Use pipeline for simpler model loading
            self.generator = pipeline(
                "text2text-generation",
                model=self.model_name,
                device="cuda" if torch.cuda.is_available() else "cpu",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
            logging.info(f"Model loaded successfully on {self.generator.device}")
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise
    async def generate_response(self, prompt: str, max_length: int = 2048, _retries: int = 0) -> str:
        try:
            logging.info(f"Generating response for prompt: {prompt[:100]}...")
            # Generate with more focused parameters
            response = self.generator(
                prompt,
                max_length=max_length,
                num_beams=5,
                temperature=0.7,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.2,
                length_penalty=1.0,
                do_sample=True,
                num_return_sequences=1
            )[0]['generated_text']
            # Clean up the response
            response = response.strip()
            # Ensure minimum content length; cap retries to avoid unbounded recursion
            if len(response) < 100 and _retries < 3:
                logging.warning("Response too short, regenerating...")
                return await self.generate_response(prompt, max_length, _retries + 1)
            logging.info(f"Generated response successfully: {response[:100]}...")
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            raise
    async def stream_response(self, prompt: str, max_length: int = 1000):
        try:
            # Note: return_full_text belongs to the text-generation pipeline and is not
            # accepted by text2text-generation, so it is omitted here; seq2seq output
            # never includes the prompt anyway.
            response = self.generator(
                prompt,
                max_length=max_length,
                num_beams=4,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )[0]['generated_text']
            # Simulate streaming by yielding chunks of the response
            chunk_size = 20
            for i in range(0, len(response), chunk_size):
                chunk = response[i:i + chunk_size]
                yield chunk
        except Exception as e:
            logging.error(f"Error in stream_response: {str(e)}")
            raise
# Configure logging before instantiating the handler so startup messages are captured
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Initialize the model handler
f5_model = F5ModelHandler()
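
# A minimal usage sketch, assuming this module is run directly rather than imported
# by a server; the prompts and printing below are illustrative only and are not part
# of the original handler.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # Await a full generation, then consume the simulated stream chunk by chunk.
        answer = await f5_model.generate_response(
            "Summarize: The quick brown fox jumps over the lazy dog."
        )
        print(answer)
        async for chunk in f5_model.stream_response("Translate to German: Good morning"):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())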