# Spaces status: Runtime error
import gradio as gr | |
import json | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from sklearn.neighbors import NearestNeighbors | |
# Load the sentence-embedding model and the reference listings.
model = SentenceTransformer('models/ad_categorizer')
# JSON is UTF-8 text; pass the encoding explicitly so decoding does not
# depend on the host's locale default.
with open('data/listings.json', encoding='utf-8') as f:
    listings = json.load(f)
if not listings:
    # Fitting a nearest-neighbour index on an empty corpus fails with an
    # opaque error; fail fast with a message that points at the data file.
    raise ValueError("data/listings.json contains no listings")

# Embed every reference listing once at startup; each query is later
# matched against these vectors.
# NOTE(review): assumes every listing is a dict with at least 'text' and
# 'category' keys (a missing key raises KeyError here) -- confirm schema.
texts = [item['text'] for item in listings]
embeddings = model.encode(texts)
# Not used in this section; kept in case other code references it.
categories = [item['category'] for item in listings]

# 1-nearest-neighbour index over the listing embeddings, using
# scikit-learn's default metric (Minkowski p=2, i.e. Euclidean).
# NOTE(review): sentence embeddings are often compared with cosine
# similarity; consider metric='cosine' if that matches the model's training.
nn = NearestNeighbors(n_neighbors=1).fit(embeddings)
def categorize(text):
    """Return the category of the reference listing closest to *text*.

    The query is embedded with the module-level ``model``, and the single
    nearest reference embedding is looked up in the ``nn`` index. The
    result is a dict containing that listing's ``category`` and
    ``category_id``, plus its text under ``similar_listing``.

    Raises:
        KeyError: if the matched listing lacks one of the read keys.
    """
    vector = model.encode(text)
    _distances, neighbour_ids = nn.kneighbors([vector])
    match = listings[neighbour_ids[0][0]]

    # (output key, key in the matched listing), in output order.
    field_map = (
        ("category", "category"),
        ("category_id", "category_id"),
        ("similar_listing", "text"),
    )
    return {out_key: match[src_key] for out_key, src_key in field_map}
# Gradio UI: a single text box in, the categorize() result rendered as JSON.
# Read the example inputs inside a context manager so the file handle is
# closed deterministically, and decode the JSON as UTF-8.
# NOTE(review): presumably a list of example inputs in the shape
# gr.Interface(examples=...) expects -- confirm the file's format.
with open('data/test_cases.json', encoding='utf-8') as f:
    _examples = json.load(f)

demo = gr.Interface(
    fn=categorize,
    inputs=gr.Textbox(label="Ad Listing"),
    outputs=gr.JSON(label="Prediction"),
    examples=_examples,
)
demo.launch()