import logging

import torch
from transformers import pipeline

class F5ModelHandler:
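    """Thin wrapper around a Hugging Face text2text-generation pipeline for google/flan-t5-small."""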
    def __init__(self):
        logging.info("Initializing F5ModelHandler...")
        try:
            logging.info("Loading model 'google/flan-t5-small'...")
            self.model_name = "google/flan-t5-small"
            # Use pipeline for simpler model loading
            self.generator = pipeline(
                "text2text-generation",
                model=self.model_name,
                device="cuda" if torch.cuda.is_available() else "cpu",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )
            logging.info(f"Model loaded successfully on {self.generator.device}")
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise
        
    async def generate_response(self, prompt: str, max_length: int = 2048, retries: int = 2) -> str:
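        """Generate a single response for `prompt`, retrying a bounded number of times if the output is very short."""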
        try:
            logging.info(f"Generating response for prompt: {prompt[:100]}...")
            
            # Generate with more focused parameters
            response = self.generator(
                prompt,
                max_length=max_length,
                num_beams=5,
                temperature=0.7,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.2,
                length_penalty=1.0,
                do_sample=True,
                num_return_sequences=1
            )[0]['generated_text']
            
            # Clean up the response
            response = response.strip()
            
            # Retry a bounded number of times if the output is too short,
            # rather than recursing without limit
            if len(response) < 100 and retries > 0:
                logging.warning("Response too short, regenerating...")
                return await self.generate_response(prompt, max_length, retries - 1)
            
            logging.info(f"Generated response successfully: {response[:100]}...")
            return response
            
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            raise

    async def stream_response(self, prompt: str, max_length: int = 1000):
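        """Generate a full response, then yield it in fixed-size chunks to simulate streaming."""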
        try:
            response = self.generator(
                prompt,
                max_length=max_length,
                num_beams=4,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )[0]['generated_text']
            
            # Simulate streaming by yielding chunks of the response
            chunk_size = 20
            for i in range(0, len(response), chunk_size):
                chunk = response[i:i + chunk_size]
                yield chunk

        except Exception as e:
            logging.error(f"Error in stream_response: {str(e)}")
            raise

# Configure logging and create a shared model handler instance
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

f5_model = F5ModelHandler()
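
# Example usage sketch (not part of the original module): shows one way to call
# the async methods from a script with asyncio. The helper name and prompts
# below are illustrative placeholders, not defined elsewhere in this codebase.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # Single-shot generation
        answer = await f5_model.generate_response(
            "Summarize: The module wraps a Flan-T5 pipeline behind two async methods."
        )
        print(answer)

        # Simulated streaming: chunks arrive as fixed-size slices of the full output
        async for chunk in f5_model.stream_response("Translate English to German: Hello, world!"):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())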