Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import pandas as pd | |
from torch.nn import Softmax | |
from torch import sort | |
# кэширование | |
def load_model(): | |
return AutoModelForSequenceClassification.from_pretrained('zhdantim/mydeberta-v3-small').eval() | |
# кэширование | |
def load_id2classes(): | |
return pd.read_csv('classes.tsv', sep='\t', index_col=0).to_dict()['classes'] | |
def load_tokenizer(): | |
return AutoTokenizer.from_pretrained('microsoft/deberta-v3-small') | |
model = load_model() | |
id2classes = load_id2classes() | |
tokenizer = load_tokenizer() | |
def get_top_classes(text): | |
tokenized_text = tokenizer(text, padding=True, truncation=True, return_tensors="pt") | |
logits = model(**tokenized_text).logits.detach() | |
probs = Softmax()(logits) | |
probs_sorted, indices = sort(probs, descending=True) | |
k = 1 | |
while sum(probs_sorted[0, :k]) < 0.95: | |
k += 1 | |
return [id2classes[idx.item()] for idx in indices[0, :k]], probs_sorted[0, :k] | |
st.title("Простой классификатор статей") | |
title = st.text_input(label="Введите название статьи (обязательно)", value="Type Here ...") | |
abstract = st.text_input(label="Введите abstract", value="Type Here ...") | |
if st.button('Submit'): | |
if title.title() != 'Type Here ...': | |
if abstract.title() != 'Type Here ...': | |
text = title.title() + '\n' + abstract.title() | |
else: | |
text = title.title() | |
top_classes, probs = get_top_classes(text) | |
for p, cls in zip(probs, top_classes): | |
st.success(f'Статья относится к {cls} с вероятностью {p}') | |
else: | |
st.error('Введите название статьи') | |