# ad_categorizer / predict.py
# Committed by win2win: 914ad9a "Create predict.py" (verified)
import gradio as gr
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors
# Load the sentence-embedding model and the labelled reference listings.
model = SentenceTransformer('models/ad_categorizer')
# Read as UTF-8 explicitly: listing text may contain non-ASCII characters,
# and the platform default encoding is not guaranteed to be UTF-8.
with open('data/listings.json', encoding='utf-8') as f:
    listings = json.load(f)

# Embed every reference listing once at startup. Row i of `embeddings`
# corresponds to listings[i]; categorize() relies on that alignment when it
# maps a neighbour index back into `listings`.
texts = [item['text'] for item in listings]
embeddings = model.encode(texts)
# Not read anywhere else in this file; kept as a module-level name in case
# other code imports it.
categories = [item['category'] for item in listings]

# 1-nearest-neighbour search index over the reference embeddings
# (scikit-learn's default metric is Minkowski with p=2, i.e. Euclidean).
nn = NearestNeighbors(n_neighbors=1).fit(embeddings)
def categorize(text):
    """Label *text* with the category of its closest reference listing.

    The query is embedded with the same model used for the reference
    listings, and the single nearest reference embedding is looked up in
    the module-level ``nn`` index.

    Args:
        text: Listing text to classify (a string, as supplied by the
            Gradio textbox).

    Returns:
        A dict with ``"category"`` and ``"category_id"`` copied from the
        nearest reference listing, plus ``"similar_listing"`` holding that
        listing's text.

    Raises:
        KeyError: if the matched listing record lacks a ``'category'``,
            ``'category_id'`` or ``'text'`` field (assuming records are dicts).
    """
    query_vec = model.encode(text)
    # kneighbors expects a 2-D (n_queries, dim) input, hence the list wrapper;
    # the distances are not needed.
    _distances, neighbor_idx = nn.kneighbors([query_vec])
    nearest = listings[neighbor_idx[0][0]]

    # (output key, field in the matched listing record), in output order.
    field_map = (
        ("category", 'category'),
        ("category_id", 'category_id'),
        ("similar_listing", 'text'),
    )
    return {out_key: nearest[src_key] for out_key, src_key in field_map}
# Gradio web UI: one text box in, the categorize() result dict rendered as JSON out.
# Load the example inputs inside a context manager so the file handle is
# closed deterministically, and read them as UTF-8 regardless of platform locale.
with open('data/test_cases.json', encoding='utf-8') as f:
    # NOTE(review): presumably a list of example inputs in the shape Gradio's
    # `examples` parameter accepts — confirm against the data file.
    examples = json.load(f)

demo = gr.Interface(
    fn=categorize,
    inputs=gr.Textbox(label="Ad Listing"),
    outputs=gr.JSON(label="Prediction"),
    examples=examples,
)
# NOTE(review): this runs at import time, so importing this module starts the
# server; consider an `if __name__ == "__main__":` guard if it is ever imported.
demo.launch()