import logging

import torch
from transformers import pipeline


class F5ModelHandler:
    def __init__(self):
        logging.info("Initializing F5ModelHandler...")
        try:
            logging.info("Loading model 'google/flan-t5-small'...")
            self.model_name = "google/flan-t5-small"
            # Use the pipeline helper for simpler model loading; run on GPU in
            # half precision when available, otherwise fall back to CPU/fp32.
            self.generator = pipeline(
                "text2text-generation",
                model=self.model_name,
                device="cuda" if torch.cuda.is_available() else "cpu",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
            logging.info(f"Model loaded successfully on {self.generator.device}")
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise

    async def generate_response(self, prompt: str, max_length: int = 2048, retries: int = 3) -> str:
        # Note: the pipeline call itself is synchronous and blocking; the async
        # signature only lets this slot into an async caller (e.g. a web handler).
        try:
            logging.info(f"Generating response for prompt: {prompt[:100]}...")
            # Beam-search multinomial sampling: beams narrow the search while
            # the sampling knobs (temperature/top_p/top_k) keep output varied.
            response = self.generator(
                prompt,
                max_length=max_length,
                num_beams=5,
                temperature=0.7,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.2,
                length_penalty=1.0,
                do_sample=True,
                num_return_sequences=1,
            )[0]["generated_text"]

            # Clean up the response.
            response = response.strip()

            # Ensure minimum content length; the retry budget bounds the
            # recursion so a persistently short output cannot loop forever.
            if len(response) < 100 and retries > 0:
                logging.warning("Response too short, regenerating...")
                return await self.generate_response(prompt, max_length, retries - 1)

            logging.info(f"Generated response successfully: {response[:100]}...")
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            raise

    async def stream_response(self, prompt: str, max_length: int = 1000):
        try:
            # text2text-generation never echoes the prompt, so no
            # return_full_text flag is needed (that option belongs to the
            # text-generation pipeline and would be rejected here).
            response = self.generator(
                prompt,
                max_length=max_length,
                num_beams=4,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )[0]["generated_text"]

            # Simulate streaming by yielding fixed-size chunks of the finished
            # response; the model itself generates the full text in one call.
            chunk_size = 20
            for i in range(0, len(response), chunk_size):
                yield response[i:i + chunk_size]
        except Exception as e:
            logging.error(f"Error in stream_response: {str(e)}")
            raise


# Configure logging, then create the module-level handler (this loads the model).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
f5_model = F5ModelHandler()
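
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hypothetical driver showing how the handler above could be
# exercised from an asyncio entry point. The prompts here are made up; only
# `f5_model.generate_response` and `f5_model.stream_response` come from the
# module itself.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # One-shot generation (retries internally if the output is very short).
        answer = await f5_model.generate_response(
            "Summarize: The quick brown fox jumps over the lazy dog."
        )
        print(answer)

        # Chunked pseudo-streaming: stream_response is an async generator,
        # so it is consumed with `async for`.
        async for chunk in f5_model.stream_response(
            "Translate English to German: Hello, world!"
        ):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())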