"""Semantic retrieval service over a medical corpus.

Exposes a Flask ``/retrieve`` endpoint and a Gradio test UI. Both go through
the same in-process search helper, backed by a FlagEmbedding encoder and an
in-memory FAISS inner-product index.
"""

import os
import time
from functools import lru_cache

import faiss
import gradio as gr
import numpy as np
from datasets import load_dataset
from FlagEmbedding import FlagModel
from flask import Flask, jsonify, request

# Initialize components
app = Flask(__name__)
model = None
index = None
corpus = None

# Number of hits returned per query when the client does not send "topk".
_DEFAULT_TOPK = 3


def initialize_components():
    """Lazily build the encoder, the corpus and the FAISS index.

    Each global is populated only while it is still ``None``, so repeated
    calls are cheap once everything is loaded.
    """
    global model, index, corpus

    # Load the embedding model.
    if model is None:
        model = FlagModel(
            "BAAI/bge-large-en-v1.5",
            query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
            use_fp16=True,
        )

    # Load corpus from the Hugging Face dataset as "<id>\t<contents>" strings.
    if corpus is None:
        dataset = load_dataset("awinml/medrag_corpus_sampled", split='train')
        corpus = [f"{row['id']}\t{row['contents']}" for row in dataset]

    # Create the FAISS index in memory.
    if index is None:
        passages = [doc.split('\t', 1)[1] for doc in corpus]
        embeddings = np.ascontiguousarray(model.encode(passages), dtype=np.float32)
        # Build into a local and publish only after add() succeeded, so a
        # failure cannot leave an empty index installed in the global.
        new_index = faiss.IndexFlatIP(embeddings.shape[1])
        new_index.add(embeddings)
        index = new_index


def _search(queries, topk):
    """Return, for each query, a list of up to ``topk`` scored documents.

    Args:
        queries: list of query strings.
        topk: positive integer, maximum number of hits per query.

    Returns:
        A list parallel to ``queries``; each element is a list of
        ``{"document": {"id", "contents"}, "score"}`` dicts.
    """
    initialize_components()

    if not queries:
        return []

    # Never ask FAISS for more neighbours than there are vectors: it pads the
    # missing slots with label -1, which would index corpus[-1].
    k = min(topk, index.ntotal)
    if k < 1:
        return [[] for _ in queries]

    query_embeddings = np.ascontiguousarray(
        np.atleast_2d(model.encode_queries(queries)), dtype=np.float32
    )
    scores, indices = index.search(query_embeddings, k)

    results = []
    for score_row, index_row in zip(scores, indices):
        hits = []
        for score, doc_idx in zip(score_row, index_row):
            if doc_idx < 0:  # padding label, no real neighbour
                continue
            doc_id, content = corpus[int(doc_idx)].split('\t', 1)
            hits.append({
                "document": {"id": doc_id, "contents": content},
                "score": float(score),
            })
        results.append(hits)
    return results


@app.route("/retrieve", methods=["POST"])
def retrieve():
    """Handle ``POST /retrieve``.

    Expects a JSON object with ``queries`` (a string or a list of strings)
    and an optional positive integer ``topk`` (default 3). A
    ``return_scores`` field may be sent but is not consulted: scores are
    always included in the response.

    Returns:
        ``{"result": [...], "time": "<seconds>s"}`` on success, or
        ``{"error": ...}`` with status 400 for a malformed request.
    """
    start_time = time.time()

    # Validate request. silent=True turns a non-JSON body into None so it
    # falls into the 400 branch below instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "queries" not in data:
        return jsonify({"error": "Missing 'queries' parameter"}), 400

    queries = data["queries"]
    if isinstance(queries, str):
        queries = [queries]
    if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
        return jsonify({"error": "'queries' must be a string or a list of strings"}), 400

    topk = data.get("topk", _DEFAULT_TOPK)
    # bool is a subclass of int; reject it explicitly.
    if isinstance(topk, bool) or not isinstance(topk, int) or topk < 1:
        return jsonify({"error": "'topk' must be a positive integer"}), 400

    results = _search(queries, topk)

    return jsonify({
        "result": results,
        "time": f"{time.time() - start_time:.2f}s"
    })


# Gradio UI for testing
def gradio_interface(query, topk):
    """Run one query through the retriever and return its list of hits.

    Searches in-process rather than over HTTP: this script only launches the
    Gradio server, so there is no ``/retrieve`` route to call on its port.
    """
    # The slider may deliver a float; the search needs an int.
    return _search([query], int(topk))[0]


# Start server
if __name__ == "__main__":
    # First-time initialization
    initialize_components()

    # Create Gradio interface
    iface = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Textbox(label="Medical Query", placeholder="Enter your medical question..."),
            gr.Slider(1, 10, value=3, step=1, label="Top Results")
        ],
        outputs=gr.JSON(label="Retrieval Results"),
        title="Medical Retrieval System",
        description="Search across medical literature using AI-powered semantic search"
    )

    # Only the Gradio server is launched here. To expose POST /retrieve, serve
    # the Flask `app` object separately with a WSGI server.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)