from llama_cpp import Llama
from concurrent.futures import ThreadPoolExecutor, as_completed
import uvicorn
from fastapi import FastAPI, HTTPException
import os
from dotenv import load_dotenv
from pydantic import BaseModel
import logging
import torch
from nltk.tokenize import sent_tokenize
from difflib import SequenceMatcher
import nltk
import spaces
from faker import Faker
import gradio as gr
from threading import Thread

# Download NLTK resources
nltk.download('punkt')
nltk.download('stopwords')

# Load environment variables from .env file
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if HUGGINGFACE_TOKEN:
    # huggingface_hub (used internally by Llama.from_pretrained) reads the
    # HF_TOKEN environment variable when downloading gated repositories.
    os.environ.setdefault("HF_TOKEN", HUGGINGFACE_TOKEN)

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')

fake = Faker()

# Global data structure to hold models and configurations
global_data = {
    'models': {},
    'tokens': {
        'eos': '<|end_of_text|>',
        'pad': '<pad>',
        'padding': '<pad>',
        'unk': '<unk>',
        'bos': '<|begin_of_text|>',
        'sep': '<|sep|>',
        'cls': '<|cls|>',
        'mask': '<mask>',
        'eot': '<|eot_id|>',
        'eom': '<|eom_id|>',
        'lf': '<|0x0A|>'
    },
    'model_metadata': {},
    'tokenizers': {},
    'model_params': {},
}

# Model configurations
model_configs = [
    {"repo_id": "Ffftdtd5dtft/Meta-Llama-3.1-70B-Q2_K-GGUF", "filename": "meta-llama-3.1-70b-q2_k.gguf", "name": "meta-llama-3.1-70b", "seed": 42, "n_ctx": 1024}
]
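
# Additional GGUF checkpoints could be registered by appending entries of the
# same shape; the repo_id and filename below are illustrative placeholders,
# not real repositories:
#
#   model_configs.append({"repo_id": "<user>/<model>-GGUF", "filename": "<model>.gguf",
#                         "name": "<short-name>", "seed": 42, "n_ctx": 1024})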

# Function to load a GGUF model with llama-cpp-python
def load_model(model_config):
    model_name = model_config['name']
    if model_name in global_data['models']:
        return global_data['models'][model_name]
    try:
        # llama-cpp-python selects the device at load time: offload all layers
        # to the GPU when CUDA is available, otherwise run entirely on the CPU.
        n_gpu_layers = -1 if torch.cuda.is_available() else 0

        # Initialize model (extra keyword arguments are forwarded to the Llama constructor)
        model = Llama.from_pretrained(
            repo_id=model_config['repo_id'],
            filename=model_config['filename'],
            n_gpu_layers=n_gpu_layers,
            seed=model_config.get('seed', 42),
            n_ctx=model_config.get('n_ctx', 1024),
            verbose=True
        )

        global_data['models'][model_name] = model
        logging.info(f"Model '{model_name}' loaded successfully.")
        return model
    except Exception as e:
        logging.critical(f"CRITICAL ERROR loading model '{model_name}': {e}", exc_info=True)
        return None

# Load all models at the start
for config in model_configs:
    load_model(config)

# Pydantic model to validate incoming requests
class ChatRequest(BaseModel):
    message: str

# Function to normalize input text
def normalize_input(input_text):
    return input_text.strip()

# Function to remove duplicate (near-identical) sentences
def remove_duplicates(text, similarity_threshold=0.85):
    sentences = sent_tokenize(text)
    unique_sentences = []
    for sentence in sentences:
        is_duplicate = any(
            SequenceMatcher(None, sentence, prev_sentence).ratio() >= similarity_threshold
            for prev_sentence in unique_sentences
        )
        if not is_duplicate:
            unique_sentences.append(sentence)
    return " ".join(unique_sentences)

# Function to handle model response generation (on GPU when available)
@spaces.GPU(duration=0)
def generate_model_response(model, inputs, model_config):
    try:
        if model is None:
            return []

        stop_tokens = [global_data['tokens'].get('eos', '<|end_of_text|>')]

        # llama-cpp-python fixes GPU offload at load time, so there is no
        # per-call device switch; if the GPU task fails, retry the same call.
        try:
            response = model(inputs, stop=stop_tokens)
        except (torch.cuda.OutOfMemoryError, gr.Error) as e:
            logging.warning(f"GPU task failed ({e}); retrying.")
            response = model(inputs, stop=stop_tokens)

        # Check if the response is valid
        if 'choices' not in response or len(response['choices']) == 0 or 'text' not in response['choices'][0]:
            logging.error(f"Invalid model response format: {response}")
            return [f"Error: Invalid model response format for '{model_config['name']}'."]

        return [remove_duplicates(response['choices'][0]['text'])]
    except Exception as e:
        logging.critical(f"Error in generate_model_response: {e}", exc_info=True)
        return []

# FastAPI app
app = FastAPI()

# FastAPI POST endpoint to handle chat requests
@app.post("/chat")
async def chat(request: ChatRequest):
    input_text = normalize_input(request.message)
    model_name = "meta-llama-3.1-70b"
    model_instance = global_data['models'].get(model_name, None)
    if model_instance is None:
        raise HTTPException(status_code=500, detail="Model not found")
    
    response = generate_model_response(model_instance, input_text, model_configs[0])
    return {"response": response[0] if response else "No response generated."}

# Gradio interface for model testing
def gradio_interface(input_text):
    model_name = "meta-llama-3.1-70b"
    model_instance = global_data['models'].get(model_name, None)
    if model_instance is None:
        return "Model not found"
    response = generate_model_response(model_instance, input_text, model_configs[0])
    return response[0] if response else "No response generated."

# Gradio interface setup. The Gradio app gets its own port (7861, an arbitrary
# free port) so it does not collide with the FastAPI server on 7860.
def start_gradio_interface():
    gr.Interface(fn=gradio_interface, inputs="text", outputs="text").launch(share=True, server_port=7861)

# Run Gradio in a separate thread
def start_gradio():
    gradio_thread = Thread(target=start_gradio_interface)
    gradio_thread.daemon = True  # Ensures the thread will exit when the main program exits
    gradio_thread.start()

# Start the Gradio interface
start_gradio()

# Run FastAPI app using uvicorn
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
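
# Note: running this script directly (e.g. `python app.py`, assuming the file
# is saved as app.py) starts the Gradio UI in a background daemon thread on
# port 7861 and then serves the FastAPI app with uvicorn on port 7860.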