"""Gradio app: build local FAISS semantic databases from uploaded text files
and chat with them through the hosted Qwen2.5-Max demo."""

import os
import logging
from typing import List, Optional, Tuple

import torch
import gradio as gr
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from gradio_client import Client
from tqdm import tqdm

# Configuration
DATABASE_DIR = "semantic_memory"
QWEN_API_URL = "Qwen/Qwen2.5-Max-Demo"  # Gradio API for Qwen2.5 chat
CHUNK_SIZE = 800
TOP_K_RESULTS = 150
SIMILARITY_THRESHOLD = 0.4
PASSWORD_HASH = "abc12345"  # Replace with a hashed password in production

# System prompt (French): the assistant must answer in French, using only
# the retrieved context, in a fixed structure
BASE_SYSTEM_PROMPT = """
Répondez en français selon ces règles :

1. Utilisez EXCLUSIVEMENT le contexte fourni
2. Structurez la réponse en :
   - Définition principale
   - Caractéristiques clés (3 points maximum)
   - Relations avec d'autres concepts
3. Si aucune information pertinente, indiquez-le clairement

Contexte :
{context}
"""

# Configure logging to both a file and stdout
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("mtc_chat.log"),
        logging.StreamHandler()
    ]
)


class LocalEmbeddings(Embeddings):
    """Local sentence-transformers embeddings."""

    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = []
        for text in tqdm(texts, desc="Creating embeddings"):
            embeddings.append(self.model.encode(text).tolist())
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        return self.model.encode(text).tolist()


def split_text_into_chunks(text: str) -> List[str]:
    """Split text into non-overlapping chunks of at most CHUNK_SIZE
    characters, cutting at sentence boundaries where possible."""
    chunks = []
    start = 0
    text_length = len(text)

    while start < text_length:
        end = min(start + CHUNK_SIZE, text_length)
        chunk = text[start:end]

        # Find the last complete sentence/paragraph boundary in the chunk
        last_punct = max(
            chunk.rfind('.'),
            chunk.rfind('!'),
            chunk.rfind('?'),
            chunk.rfind('\n\n')
        )

        # Only cut at the boundary if it lies past the midpoint of the
        # chunk, so we never emit tiny fragments (also covers the
        # "no punctuation found" case, since rfind returns -1)
        if last_punct > CHUNK_SIZE // 2:
            end = start + last_punct + 1

        chunks.append(text[start:end].strip())
        start = end  # advance past the emitted chunk (no overlap)

    return chunks


def initialize_vector_store(embeddings: Embeddings, db_name: str) -> Optional[FAISS]:
    """Load an existing FAISS vector store, or return None if it does not exist yet."""
    db_path = os.path.join(DATABASE_DIR, db_name)
    if os.path.exists(db_path):
        try:
            logging.info(f"Loading existing database: {db_name}")
            return FAISS.load_local(
                db_path,
                embeddings,
                allow_dangerous_deserialization=True
            )
        except Exception as e:
            logging.error(f"FAISS load error: {str(e)}")
            raise

    logging.info(f"Creating new vector database: {db_name}")
    os.makedirs(db_path, exist_ok=True)
    return None


def create_new_database(file_content: str, db_name: str, password: str,
                        progress=gr.Progress()) -> str:
    """Create a new FAISS database from an uploaded file."""
    if password != PASSWORD_HASH:
        return "Incorrect password. Database creation failed."
    if not file_content.strip():
        return "Uploaded file is empty. Database creation failed."
    if not db_name.isalnum():
        return "Database name must be alphanumeric. Database creation failed."

    try:
        db_path = os.path.join(DATABASE_DIR, db_name)
        if os.path.exists(db_path):
            return f"Database '{db_name}' already exists."

        # Split the text into chunks before embedding
        chunks = split_text_into_chunks(file_content)
        if not chunks:
            return "No valid chunks generated. Database creation failed."
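        # NOTE: the chunks are embedded one at a time below (rather than
        # batched through a single model.encode call, which could be faster)
        # so that per-chunk progress can be reported to the Gradio UI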
        logging.info(f"Creating {len(chunks)} chunks...")
        progress(0, desc="Starting embedding process...")

        # Create embeddings with progress tracking
        embeddings_list = []
        for i, chunk in enumerate(chunks):
            progress(i / len(chunks), desc=f"Embedding chunk {i+1}/{len(chunks)}")
            embeddings_list.append(embeddings.embed_query(chunk))

        # Build the FAISS index from precomputed (text, embedding) pairs
        vector_store = FAISS.from_embeddings(
            text_embeddings=list(zip(chunks, embeddings_list)),
            embedding=embeddings
        )
        vector_store.save_local(db_path)
        logging.info(f"Vector store '{db_name}' initialized successfully")
        return f"Database '{db_name}' created successfully."

    except Exception as e:
        logging.error(f"Database creation failed: {str(e)}")
        return f"Error creating database: {str(e)}"


def generate_response(user_input: str, db_name: str) -> Optional[str]:
    """Generate a response with Qwen2.5-Max, grounded in the selected database."""
    try:
        db_path = os.path.join(DATABASE_DIR, db_name)
        if not os.path.exists(db_path):
            return f"Database '{db_name}' does not exist."

        vector_store = FAISS.load_local(
            db_path,
            embeddings,
            allow_dangerous_deserialization=True
        )

        # Contextual search: over-fetch, then filter by distance
        docs_scores = vector_store.similarity_search_with_score(
            user_input,
            k=TOP_K_RESULTS * 3
        )

        # Keep only close matches (FAISS returns L2 distances, lower is better)
        filtered_docs = [
            (doc, score) for doc, score in docs_scores
            if score < SIMILARITY_THRESHOLD
        ]
        filtered_docs.sort(key=lambda x: x[1])

        if not filtered_docs:
            return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."

        best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]

        # Build the context block injected into the system prompt
        context = "\n".join(
            f"=== Source {i+1} ===\n{doc.page_content}\n"
            for i, doc in enumerate(best_docs)
        )

        # Call the hosted Qwen demo through its Gradio API
        client = Client(QWEN_API_URL, verbose=False)
        response = client.predict(
            query=user_input,
            history=[],
            system=BASE_SYSTEM_PROMPT.format(context=context),
            api_name="/model_chat"
        )

        # The endpoint returns a tuple whose second element is the chat
        # history; the answer is the assistant half of its last turn
        if isinstance(response, tuple) and len(response) >= 2:
            chat_history = response[1]
            if chat_history and len(chat_history[-1]) >= 2:
                return chat_history[-1][1]

        return "Réponse indisponible - Veuillez reformuler votre question."

    except Exception as e:
        logging.error(f"Generation error: {str(e)}", exc_info=True)
        return None


# Initialize the embedding model and its LangChain wrapper
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer(
    "cnmoro/snowflake-arctic-embed-m-v2.0-cpu",
    device=device,
    trust_remote_code=True
)
embeddings = LocalEmbeddings(model)

# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# Local Tech Knowledge Assistant")

    with gr.Tab("Create Database"):
        gr.Markdown("## Create a New FAISS Database")
        file_input = gr.File(label="Upload .txt File")
        db_name_input = gr.Textbox(label="Enter Desired Database Name (Alphanumeric Only)")
        password_input = gr.Textbox(label="Enter Password", type="password")
        create_output = gr.Textbox(label="Status")
        create_button = gr.Button("Create Database")

        def handle_create(file, db_name, password, progress=gr.Progress()):
            if not file or not db_name or not password:
                return "Please provide all required inputs."

            # Depending on the Gradio version, gr.File may pass either a
            # plain file path string or a tempfile-like object exposing a
            # .name attribute; accept both
            file_path = file if isinstance(file, str) else getattr(file, "name", None)
            if not file_path:
                return "Invalid file format. Please upload a .txt file."

            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    file_content = f.read()
            except Exception as e:
                return f"Error reading file: {str(e)}"
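            # Delegate to the shared creation routine; the progress handle
            # is forwarded so the embedding loop can update the UI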
            return create_new_database(file_content, db_name, password, progress)

        create_button.click(
            handle_create,
            inputs=[file_input, db_name_input, password_input],
            outputs=create_output
        )

    with gr.Tab("Chat with Database"):
        gr.Markdown("## Chat with Existing Databases")
        db_select = gr.Dropdown(choices=[], label="Select Database")
        chatbot = gr.Chatbot(height=500)
        msg = gr.Textbox(label="Votre question")
        clear = gr.ClearButton([msg, chatbot])

        def update_db_list():
            """List subdirectories of DATABASE_DIR, one per saved database."""
            if not os.path.exists(DATABASE_DIR):
                return []
            return [name for name in os.listdir(DATABASE_DIR)
                    if os.path.isdir(os.path.join(DATABASE_DIR, name))]

        def chat_response(message: str, db_name: str, history: List[Tuple[str, str]]):
            response = generate_response(message, db_name)
            return "", history + [(message, response or "Erreur de génération - Veuillez réessayer.")]

        msg.submit(
            chat_response,
            inputs=[msg, db_select, chatbot],
            outputs=[msg, chatbot],
            queue=True
        )

    # Refresh the database list every time the page is loaded; mutating
    # db_select.choices after construction would never reach the client
    app.load(
        lambda: gr.update(choices=update_db_list()),
        outputs=db_select
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)