from llama_cpp import Llama
import uvicorn
from fastapi import FastAPI, HTTPException
import os
from dotenv import load_dotenv
from pydantic import BaseModel
import logging
import torch
from nltk.tokenize import sent_tokenize
from difflib import SequenceMatcher
import nltk
import spaces
import gradio as gr
from threading import Thread
# Download NLTK resources (sent_tokenize needs the Punkt data; newer NLTK
# releases also expect 'punkt_tab')
nltk.download('punkt')
nltk.download('punkt_tab')
# Load environment variables from .env file
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
# huggingface_hub reads HF_TOKEN when downloading gated/private repos, so expose
# the token under that name as well if it was provided.
if HUGGINGFACE_TOKEN:
    os.environ.setdefault("HF_TOKEN", HUGGINGFACE_TOKEN)
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
# Global data structure to hold models and configurations
global_data = {
    'models': {},
    'tokens': {
        'eos': '<|end_of_text|>',
        'pad': '<pad>',
        'padding': '<pad>',
        'unk': '<unk>',
        'bos': '<|begin_of_text|>',
        'sep': '<|sep|>',
        'cls': '<|cls|>',
        'mask': '<mask>',
        'eot': '<|eot_id|>',
        'eom': '<|eom_id|>',
        'lf': '<|0x0A|>'
    },
    'model_metadata': {},
    'tokenizers': {},
    'model_params': {},
}
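# The 'tokens' map loosely mirrors Llama 3.1-style special tokens; currently only
# the 'eos' entry is used (as a stop sequence in generate_model_response).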
# Model configurations
model_configs = [
    {
        "repo_id": "Ffftdtd5dtft/Meta-Llama-3.1-70B-Q2_K-GGUF",
        "filename": "meta-llama-3.1-70b-q2_k.gguf",
        "name": "meta-llama-3.1-70b",
        "seed": 42,
        "n_ctx": 1024
    }
]
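# Additional GGUF models could be served by appending more entries here, e.g.
# (hypothetical): {"repo_id": "some-user/some-model-GGUF",
# "filename": "some-model.q4_k_m.gguf", "name": "some-model", "seed": 42, "n_ctx": 2048}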
# Function to load a GGUF model with llama-cpp-python
def load_model(model_config):
    model_name = model_config['name']
    if model_name in global_data['models']:
        return global_data['models'][model_name]
    try:
        # llama.cpp decides CPU/GPU placement at load time: offload all layers to the
        # GPU when CUDA is available, otherwise keep everything on the CPU.
        n_gpu_layers = -1 if torch.cuda.is_available() else 0
        model = Llama.from_pretrained(
            repo_id=model_config['repo_id'],
            filename=model_config['filename'],
            verbose=True,
            seed=model_config.get('seed', 42),
            n_ctx=model_config.get('n_ctx', 1024),
            n_gpu_layers=n_gpu_layers
        )
        global_data['models'][model_name] = model
        logging.info(f"Model '{model_name}' loaded successfully.")
        return model
    except Exception as e:
        logging.critical(f"CRITICAL ERROR loading model '{model_name}': {e}", exc_info=True)
        return None
# Load all models at the start
for config in model_configs:
    load_model(config)
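# Note: models are loaded eagerly at import time, so the first request does not pay
# the download/initialization cost; a failed load is logged and the model is simply
# absent from global_data['models'].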
# Pydantic model to validate incoming requests
class ChatRequest(BaseModel):
    message: str
# Function to normalize input text
def normalize_input(input_text):
    return input_text.strip()
# Function to remove near-duplicate sentences from the generated text
def remove_duplicates(text, similarity_threshold=0.85):
    sentences = sent_tokenize(text)
    unique_sentences = []
    for sentence in sentences:
        is_duplicate = any(
            SequenceMatcher(None, sentence, prev_sentence).ratio() >= similarity_threshold
            for prev_sentence in unique_sentences
        )
        if not is_duplicate:
            unique_sentences.append(sentence)
    return " ".join(unique_sentences)
# Function to handle model response generation
@spaces.GPU(duration=0)
def generate_model_response(model, inputs, model_config):
    try:
        if model is None:
            return []
        stop_tokens = [global_data['tokens'].get('eos', '<|end_of_text|>')]
        try:
            # llama.cpp fixes CPU/GPU placement when the model is loaded (see
            # n_gpu_layers in load_model), so there is no per-call device switch;
            # max_tokens=512 is an arbitrary cap to keep responses bounded.
            response = model(inputs, stop=stop_tokens, max_tokens=512)
        except Exception as e:
            logging.warning(f"Generation failed ({e}); retrying once.")
            response = model(inputs, stop=stop_tokens, max_tokens=512)
        # Check that the response has the expected completion format
        if 'choices' not in response or len(response['choices']) == 0 or 'text' not in response['choices'][0]:
            logging.error(f"Invalid model response format: {response}")
            return [f"Error: Invalid model response format for '{model_config['name']}'."]
        return [remove_duplicates(response['choices'][0]['text'])]
    except Exception as e:
        logging.critical(f"Error in generate_model_response: {e}", exc_info=True)
        return []
# FastAPI app
app = FastAPI()
# FastAPI POST endpoint to handle chat requests
@app.post("/chat")
async def chat(request: ChatRequest):
    input_text = normalize_input(request.message)
    model_name = "meta-llama-3.1-70b"
    model_instance = global_data['models'].get(model_name)
    if model_instance is None:
        raise HTTPException(status_code=500, detail="Model not found")
    response = generate_model_response(model_instance, input_text, model_configs[0])
    return {"response": response[0] if response else "No response generated."}
# Gradio interface for model testing
def gradio_interface(input_text):
    model_name = "meta-llama-3.1-70b"
    model_instance = global_data['models'].get(model_name)
    if model_instance is None:
        return "Model not found"
    response = generate_model_response(model_instance, input_text, model_configs[0])
    return response[0] if response else "No response generated."
# Gradio interface setup; launched on port 7861 so it does not collide with the
# FastAPI app, which uvicorn serves on port 7860 below.
def start_gradio_interface():
    gr.Interface(fn=gradio_interface, inputs="text", outputs="text").launch(
        share=True, server_port=7861
    )
# Run Gradio in a separate thread
def start_gradio():
    gradio_thread = Thread(target=start_gradio_interface)
    gradio_thread.daemon = True  # Ensures the thread exits when the main program exits
    gradio_thread.start()
# Start the Gradio interface
start_gradio()
# Run FastAPI app using uvicorn
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)