Spaces:
Runtime error
Runtime error
File size: 1,103 Bytes
914ad9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import gradio as gr
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors
# Load model and data
model = SentenceTransformer('models/ad_categorizer')
# JSON is UTF-8 by specification; pass the encoding explicitly so loading does
# not depend on the platform's default locale encoding.
with open('data/listings.json', encoding='utf-8') as f:
    # NOTE(review): each item is assumed to carry 'text', 'category' and
    # 'category_id' keys (all three are read in this file) — confirm schema.
    listings = json.load(f)
if not listings:
    # Fail fast with a clear message instead of a less specific error from
    # encoding / fitting the nearest-neighbour index on an empty corpus.
    raise ValueError("data/listings.json contains no listings to index")
# Prepare embeddings: one vector per stored listing, in the same order as
# `listings`, so an index returned by the search maps back into `listings`.
texts = [item['text'] for item in listings]
embeddings = model.encode(texts)
# NOTE(review): `categories` is not read anywhere else in the visible file;
# kept so any external importer of this module keeps working.
categories = [item['category'] for item in listings]
# Create search index; only the single closest listing is needed per query.
nn = NearestNeighbors(n_neighbors=1).fit(embeddings)
def categorize(text):
    """Map an ad listing to the category of its most similar stored listing.

    Args:
        text: Listing text entered in the Gradio "Ad Listing" textbox.

    Returns:
        A dict with the nearest stored listing's ``category`` and
        ``category_id``, plus that listing's own text under
        ``similar_listing``.
    """
    # kneighbors operates on a batch, so the single query vector is wrapped in
    # a one-element list; [1] picks the index array, [0][0] the top hit.
    nearest_row = nn.kneighbors([model.encode(text)])[1][0][0]
    neighbour = listings[nearest_row]
    return dict(
        category=neighbour['category'],
        category_id=neighbour['category_id'],
        similar_listing=neighbour['text'],
    )
# Gradio interface
# Read the example inputs with a context manager so the file handle is closed
# right away (an inline ``json.load(open(...))`` never closes it), and decode
# as UTF-8 explicitly rather than relying on the locale default.
with open('data/test_cases.json', encoding='utf-8') as examples_file:
    # NOTE(review): gr.Interface expects a list of samples with one value per
    # input component (a single Textbox here), e.g. ["ad text", ...] or
    # [["ad text"], ...] — confirm test_cases.json has that shape.
    examples = json.load(examples_file)
demo = gr.Interface(
    fn=categorize,
    inputs=gr.Textbox(label="Ad Listing"),
    outputs=gr.JSON(label="Prediction"),
    examples=examples,
)
# Launched at import time, as in the original script.
demo.launch()