from llama_cpp import Llama
from concurrent.futures import ThreadPoolExecutor, as_completed
import uvicorn
from fastapi import FastAPI, HTTPException
import os
from dotenv import load_dotenv
from pydantic import BaseModel
import logging
import torch
from nltk.tokenize import sent_tokenize
from difflib import SequenceMatcher
import nltk
import spaces
from faker import Faker
import gradio as gr
from threading import Thread
# Download NLTK resources ('punkt_tab' is also needed by sent_tokenize on newer NLTK releases)
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')
# Load environment variables from .env file
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if HUGGINGFACE_TOKEN:
    # huggingface_hub (used by Llama.from_pretrained) picks the token up from HF_TOKEN
    os.environ["HF_TOKEN"] = HUGGINGFACE_TOKEN
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
fake = Faker()
# Global data structure to hold models and configurations
global_data = {
    'models': {},
    'tokens': {
        'eos': '<|end_of_text|>',
        'pad': '<pad>',
        'padding': '<pad>',
        'unk': '<unk>',
        'bos': '<|begin_of_text|>',
        'sep': '<|sep|>',
        'cls': '<|cls|>',
        'mask': '<mask>',
        'eot': '<|eot_id|>',
        'eom': '<|eom_id|>',
        'lf': '<|0x0A|>'
    },
    'model_metadata': {},
    'tokenizers': {},
    'model_params': {},
}
# Model configurations
model_configs = [
    {"repo_id": "Ffftdtd5dtft/Meta-Llama-3.1-70B-Q2_K-GGUF", "filename": "meta-llama-3.1-70b-q2_k.gguf", "name": "meta-llama-3.1-70b", "seed": 42, "n_ctx": 1024}
]
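# Further GGUF models can be registered by appending entries of the same shape,
# e.g. (hypothetical repo_id/filename, not part of the original configuration):
# {"repo_id": "SomeOrg/Some-Model-GGUF", "filename": "some-model-q4_k_m.gguf", "name": "some-model", "seed": 42, "n_ctx": 2048}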
# Function to load a model with llama-cpp-python
def load_model(model_config):
    model_name = model_config['name']
    if model_name not in global_data['models']:
        try:
            # Explicitly check if GPU (CUDA) is available and offload all layers to it;
            # llama-cpp-python fixes the device at load time via n_gpu_layers.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            n_gpu_layers = -1 if device == "cuda" else 0
            # Download the GGUF file from the Hub and initialize the model;
            # seed and n_ctx are passed straight to the Llama constructor.
            model = Llama.from_pretrained(
                repo_id=model_config['repo_id'],
                filename=model_config['filename'],
                verbose=True,
                seed=model_config.get('seed', 42),
                n_ctx=model_config.get('n_ctx', 1024),
                n_gpu_layers=n_gpu_layers
            )
            global_data['models'][model_name] = model
            logging.info(f"Model '{model_name}' loaded successfully.")
            return model
        except Exception as e:
            logging.critical(f"CRITICAL ERROR loading model '{model_name}': {e}", exc_info=True)
            return None
    return global_data['models'][model_name]
# Load all models at the start
for config in model_configs:
    load_model(config)
# Pydantic model to validate incoming requests
class ChatRequest(BaseModel):
    message: str
# Function to normalize input text
def normalize_input(input_text):
    return input_text.strip()
# Function to remove duplicate sentences
def remove_duplicates(text, similarity_threshold=0.85):
    sentences = sent_tokenize(text)
    unique_sentences = []
    for sentence in sentences:
        is_duplicate = False
        for prev_sentence in unique_sentences:
            similarity = SequenceMatcher(None, sentence, prev_sentence).ratio()
            if similarity >= similarity_threshold:
                is_duplicate = True
                break
        if not is_duplicate:
            unique_sentences.append(sentence)
    return " ".join(unique_sentences)
# Function to handle model response generation
def generate_model_response(model, inputs, model_config):
    try:
        if model is None:
            return []
        stop_tokens = [global_data['tokens'].get('eos', '<|end_of_text|>')]
        # llama-cpp-python fixes GPU offload when the model is loaded, so a failed
        # call cannot simply be retried on the CPU here; log the error and return.
        try:
            # max_tokens is raised from the library default (16) so replies are not truncated
            response = model(inputs, stop=stop_tokens, max_tokens=512)
        except torch.cuda.OutOfMemoryError:
            logging.error("GPU memory exceeded while generating the response.")
            return [f"Error: GPU out of memory for '{model_config['name']}'."]
        except Exception as e:
            logging.error(f"Model call failed: {e}", exc_info=True)
            return [f"Error: generation failed for '{model_config['name']}'."]
        # Check if the response is valid
        if 'choices' not in response or len(response['choices']) == 0 or 'text' not in response['choices'][0]:
            logging.error(f"Invalid model response format: {response}")
            return [f"Error: Invalid model response format for '{model_config['name']}'."]
        return [remove_duplicates(response['choices'][0]['text'])]
    except Exception as e:
        logging.critical(f"Error in generate_model_response: {e}", exc_info=True)
        return []
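# For reference, the format check above expects the llama-cpp-python completion
# shape, roughly (illustrative values only):
#   {"choices": [{"text": "generated text", "finish_reason": "stop", ...}], ...}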
# FastAPI app | |
app = FastAPI() | |
# FastAPI POST endpoint to handle chat requests | |
async def chat(request: ChatRequest): | |
input_text = normalize_input(request.message) | |
model_name = "meta-llama-3.1-70b" | |
model_instance = global_data['models'].get(model_name, None) | |
if model_instance is None: | |
raise HTTPException(status_code=500, detail="Model not found") | |
response = generate_model_response(model_instance, input_text, model_configs[0]) | |
return {"response": response[0] if response else "No response generated."} | |
# Gradio interface for model testing
def gradio_interface(input_text):
    model_name = "meta-llama-3.1-70b"
    model_instance = global_data['models'].get(model_name, None)
    if model_instance is None:
        return "Model not found"
    response = generate_model_response(model_instance, input_text, model_configs[0])
    return response[0] if response else "No response generated."
# Gradio interface setup; Gradio defaults to port 7860, which uvicorn uses below,
# so the interface is launched on 7861 to avoid the clash (assumed fix).
def start_gradio_interface():
    gr.Interface(fn=gradio_interface, inputs="text", outputs="text").launch(share=True, server_port=7861)
# Run Gradio in a separate thread
def start_gradio():
    gradio_thread = Thread(target=start_gradio_interface)
    gradio_thread.daemon = True  # Ensures the thread will exit when the main program exits
    gradio_thread.start()
# Start the Gradio interface
start_gradio()
# Run FastAPI app using uvicorn
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
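# Alternative sketch (not in the original code): instead of running Gradio in a
# background thread on its own port, the interface could be mounted onto the
# FastAPI app so both share port 7860:
#   demo = gr.Interface(fn=gradio_interface, inputs="text", outputs="text")
#   app = gr.mount_gradio_app(app, demo, path="/ui")
#   uvicorn.run(app, host="0.0.0.0", port=7860)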