import gradio as gr
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel


# Short model names offered in the UI, mapped to the checkpoint identifiers
# that are passed to ``from_pretrained``.
MODELS = {
    "rubert-tiny2": "cointegrated/rubert-tiny2",
    "sbert": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "LaBSE": "sentence-transformers/LaBSE",
    "ruRoberta": "sberbank-ai/ruRoberta-large",
}


# Prompt templates selectable in the UI. Each one is filled in with
# ``str.format(item=..., categories=...)``; only "strict" uses ``{categories}``.
PROMPT_TEMPLATES = {
    "basic": "Товар: {item}. Категория:",
    # Few-shot variant: two worked examples followed by the item to classify.
    "examples": (
        "Примеры:\n"
        "- Молоток → Инструменты\n"
        "- Морковь → Овощи\n"
        "Товар: {item} → "
    ),
    "strict": "Выбери категорию из [{categories}]. Товар: {item}. Категория:",
}


def get_embeddings(model, tokenizer, text) -> np.ndarray:
    """Embed *text* with *model* and return the first-token hidden state.

    Args:
        model: Callable invoked as ``model(**inputs)``; its result must expose
            ``last_hidden_state``.
        tokenizer: Callable invoked on *text* with padding, truncation to at
            most 512 tokens and ``return_tensors="pt"``.
        text: Input forwarded unchanged to *tokenizer*.

    Returns:
        A NumPy array holding ``last_hidden_state[:, 0]``, i.e. one row per
        input sequence taken at token position 0.
    """
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    # Inference only: skip building the autograd graph, which would otherwise
    # be allocated during the forward pass and immediately thrown away.
    with torch.no_grad():
        outputs = model(**inputs)
    first_token_state = outputs.last_hidden_state[:, 0]
    # ``.cpu()`` is a no-op for CPU tensors and makes ``.numpy()`` safe when
    # the model lives on another device.
    return first_token_state.detach().cpu().numpy()


# Loaded ``(tokenizer, model)`` pairs keyed by checkpoint id. Without this,
# every request would reload the checkpoint via ``from_pretrained``. The cache
# holds at most one entry per value in ``MODELS``.
_MODEL_CACHE = {}


def _load_model(model_name: str):
    """Return the ``(tokenizer, model)`` pair for a ``MODELS`` key, loading it once."""
    checkpoint = MODELS[model_name]
    if checkpoint not in _MODEL_CACHE:
        _MODEL_CACHE[checkpoint] = (
            AutoTokenizer.from_pretrained(checkpoint),
            AutoModel.from_pretrained(checkpoint),
        )
    return _MODEL_CACHE[checkpoint]


def classify(model_name: str, prompt_type: str, item: str, categories: str) -> str:
    """Pick the category whose embedding is closest to the prompted item.

    The item is inserted into the selected prompt template, the prompt and
    every category name are embedded with ``get_embeddings``, and the category
    with the highest cosine similarity to the prompt wins.

    Args:
        model_name: Key into ``MODELS``.
        prompt_type: Key into ``PROMPT_TEMPLATES``.
        item: Product name substituted for ``{item}`` in the template.
        categories: Comma-separated category names. Surrounding whitespace is
            stripped and blank entries (e.g. from a trailing comma) are ignored.

    Returns:
        ``"<category> (<similarity>)"`` with the similarity formatted to two
        decimal places.

    Raises:
        ValueError: If *categories* contains no non-blank entry.
        KeyError: If *model_name* or *prompt_type* is not a known key.
    """
    # Parse once, and drop blanks so an empty string is never embedded or
    # returned as the winning label.
    labels = [label.strip() for label in categories.split(",")]
    labels = [label for label in labels if label]
    if not labels:
        raise ValueError("Укажите хотя бы одну непустую категорию.")

    tokenizer, model = _load_model(model_name)

    prompt = PROMPT_TEMPLATES[prompt_type].format(
        item=item,
        categories=", ".join(labels),
    )

    item_embedding = get_embeddings(model, tokenizer, prompt)
    category_embeddings = [
        get_embeddings(model, tokenizer, label) for label in labels
    ]

    # Row 0 is the only row: a single prompt compared against all categories.
    similarities = cosine_similarity(item_embedding, np.vstack(category_embeddings))[0]
    best_idx = int(np.argmax(similarities))

    return f"{labels[best_idx]} ({similarities[best_idx]:.2f})"


# Input widgets, listed in the positional order of ``classify``'s parameters:
# model name, prompt template, item, comma-separated categories.
input_widgets = [
    gr.Dropdown(list(MODELS), label="Модель"),
    gr.Dropdown(list(PROMPT_TEMPLATES), label="Шаблон промпта"),
    gr.Textbox(label="Товар"),
    gr.Textbox(label="Категории", value="Инструменты, Овощи, Техника"),
]

# Build the UI around ``classify`` and start serving it straight away.
demo = gr.Interface(fn=classify, inputs=input_widgets, outputs=gr.Textbox())
demo.launch()