import os
import sqlite3

import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline


class EmbeddingGenerator:
    """Embeds .txt documents into a SQLite store and answers queries with a
    retrieve-then-generate pipeline (sentence-transformers + a causal LM)."""

    def __init__(self, model_name="all-MiniLM-L6-v2", gen_model="distilgpt2", db_path="./embeddings.db"):
        self.model = SentenceTransformer(model_name)
        self.generator = pipeline("text-generation", model=gen_model)
        self.db_path = db_path
        self._initialize_db()
        print(f"Loaded embedding model: {model_name}")
        print(f"Loaded generative model: {gen_model}")

    def _initialize_db(self):
        # check_same_thread=False lets the Gradio worker thread reuse this connection.
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self.cursor = self.conn.cursor()
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                filename TEXT PRIMARY KEY,
                content TEXT,
                embedding BLOB
            )
        """)
        self.conn.commit()

    def generate_embedding(self, text):
        # SentenceTransformer returns a float32 numpy vector; the dtype matters
        # because load_embeddings() decodes the stored bytes as float32.
        return self.model.encode(text, convert_to_numpy=True)

    def ingest_files(self, directory):
        """Embed every .txt file in `directory` and upsert it into the database."""
        for filename in os.listdir(directory):
            if filename.endswith(".txt"):
                file_path = os.path.join(directory, filename)
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                embedding = self.generate_embedding(content)
                self._store_embedding(filename, content, embedding)

    def _store_embedding(self, filename, content, embedding):
        self.cursor.execute(
            "INSERT OR REPLACE INTO embeddings (filename, content, embedding) VALUES (?, ?, ?)",
            (filename, content, embedding.tobytes()),
        )
        self.conn.commit()

    def load_embeddings(self):
        """Return all stored documents as dicts with decoded float32 embeddings."""
        with sqlite3.connect(self.db_path, check_same_thread=False) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT filename, content, embedding FROM embeddings")
            rows = cursor.fetchall()
        return [
            {
                "filename": row[0],
                "content": row[1],
                "embedding": np.frombuffer(row[2], dtype=np.float32),
            }
            for row in rows
        ]

    def compute_similarity(self, query_embedding, document_embeddings):
        similarities = cosine_similarity([query_embedding], document_embeddings)[0]
        return similarities.tolist()

    def find_most_similar(self, query, top_k=5):
        """Rank stored documents by cosine similarity to the query."""
        query_embedding = self.generate_embedding(query)
        documents = self.load_embeddings()
        if not documents:
            return "No documents found in the database."
        document_embeddings = [doc["embedding"] for doc in documents]
        similarities = self.compute_similarity(query_embedding, document_embeddings)
        ranked_results = sorted(
            [
                # Keep only a 100-character snippet of each document as context.
                {"filename": doc["filename"], "content": doc["content"][:100], "similarity": sim}
                for doc, sim in zip(documents, similarities)
            ],
            key=lambda x: x["similarity"],
            reverse=True,
        )
        return ranked_results[:top_k]

    def generate_response(self, query, top_k_docs):
        """Generate an answer conditioned on the retrieved snippets."""
        context = " ".join(top_k_docs)
        input_text = f"Query: {query}\nContext: {context}\nAnswer:"
        # max_length counts prompt + generated tokens and must stay within the
        # model's context window (1024 for distilgpt2).
        response = self.generator(input_text, max_length=1000, num_return_sequences=1)
        return response[0]["generated_text"]

    def find_most_similar_and_generate(self, query, top_k=5):
        top_k_results = self.find_most_similar(query, top_k)
        if isinstance(top_k_results, str):  # no documents were found
            return top_k_results
        top_k_docs = [result["content"] for result in top_k_results]
        return self.generate_response(query, top_k_docs)


# Gradio Interface
def search_and_generate(query):
    results = embedding_generator.find_most_similar(query, top_k=5)
    if isinstance(results, str):  # handle case where no docs are found
        return {"message": results}, ""
    top_k_docs = [result["content"] for result in results]
    response = embedding_generator.generate_response(query, top_k_docs)
    # Keep only JSON-serializable fields for the gr.JSON output.
    results_dict = [{"filename": r["filename"], "similarity": r["similarity"]} for r in results]
    return results_dict, response


if __name__ == "__main__":
    # Initialize the embedding generator
    embedding_generator = EmbeddingGenerator()

    # Ingest files (if not already ingested)
    embedding_generator.ingest_files("./data-sets/aclImdb/train/")

    # Gradio app
    with gr.Blocks() as demo:
        gr.Markdown("# Search and Generate with Embedding-Based Retrieval")
        with gr.Row():
            query_input = gr.Textbox(label="Enter your query", placeholder="Ask something...")
            results_output = gr.JSON(label="Top Matches (with similarity scores)")
            response_output = gr.Textbox(label="Generated Response")
        search_btn = gr.Button("Search and Generate")
        search_btn.click(
            search_and_generate,
            inputs=[query_input],
            outputs=[results_output, response_output],
        )

    demo.launch()
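# --- Example: programmatic usage without the Gradio UI ---
# A minimal sketch, assuming documents have already been ingested into
# ./embeddings.db; the query string and top_k value are illustrative only.
# Run it from a separate script or REPL rather than after demo.launch().
#
#     generator = EmbeddingGenerator(db_path="./embeddings.db")
#     matches = generator.find_most_similar("a heartfelt coming-of-age story", top_k=3)
#     if isinstance(matches, str):  # "No documents found in the database."
#         print(matches)
#     else:
#         for m in matches:
#             print(f'{m["filename"]}: {m["similarity"]:.3f}')
#         print(generator.find_most_similar_and_generate("a heartfelt coming-of-age story"))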