# Provenance (copied from the Hugging Face Space page):
# author: KoonJamesZ — commit 648540e "Update app.py" (verified)
import gradio as gr
import pandas as pd
import faiss
import numpy as np
import os
from sentence_transformers import SentenceTransformer
# Sentence-embedding model used to turn search queries into vectors.
# trust_remote_code=True allows the model repository to run its own code.
model = SentenceTransformer(
    "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
    trust_remote_code=True,
)
# Truncate encoder inputs at 512 tokens.
model.max_seq_length = 512

# Searchable records. 'embeding_context' is presumably the text that was
# embedded into the FAISS index (TODO confirm against the index builder).
df = pd.read_json('White-Stride-Red-68.json')
df = df.assign(
    embeding_context=df['embeding_context'].astype(str).fillna('')
)
# NOTE(review): on pandas versions where astype(str) turns missing values into
# the literal strings 'nan'/'None', the fillna('') above has nothing left to
# fill, so such rows are kept by the filter below. Do not reorder these calls
# without rebuilding the index: search results are mapped back to rows of df
# by position, so dropping different rows would misalign every hit.
df = df.loc[df['embeding_context'].ne('')]

# Prebuilt FAISS vector index loaded from disk.
index = faiss.read_index('vector_store.index')
# Function to perform search and return all columns
def search_query(query_text):
num_records = 50
# Encode the input query text
embeddings_query = model.encode([query_text], prompt_name="query")
embeddings_query_np = np.array(embeddings_query).astype('float32')
# Search in FAISS index for nearest neighbors
distances, indices = index.search(embeddings_query_np, num_records)
# Get the top results based on FAISS indices
result_df = df.iloc[indices[0]].drop(columns=['embeding_context']).drop_duplicates().reset_index(drop=True)
return result_df
# Gradio interface function
def gradio_interface(query_text):
search_results = search_query(query_text)
return search_results
# Gradio UI: a title, a query box, a search button and a results table.
with gr.Blocks() as app:
    gr.Markdown("<h1>White Stride Red Search (GTE-Qwen2)</h1>")
    # Free-text query entered by the user
    search_input = gr.Textbox(label="Search Query", placeholder="Enter search text", interactive=True)
    # Explicit search trigger, rendered below the text box
    search_button = gr.Button("Search")
    # Table that receives the DataFrame returned by gradio_interface
    search_output = gr.DataFrame(label="Search Results")
    # Run the search on button click, and also when Enter is pressed in the box
    search_button.click(fn=gradio_interface, inputs=search_input, outputs=search_output)
    search_input.submit(fn=gradio_interface, inputs=search_input, outputs=search_output)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launch().
if __name__ == "__main__":
    app.launch()