Zhdantim's picture
Update app.py
283aad9 verified
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
from torch.nn import Softmax
from torch import sort
@st.cache_resource # кэширование
def load_model():
return AutoModelForSequenceClassification.from_pretrained('zhdantim/mydeberta-v3-small').eval()
@st.cache_resource # кэширование
def load_id2classes():
return pd.read_csv('classes.tsv', sep='\t', index_col=0).to_dict()['classes']
@st.cache_resource
def load_tokenizer():
return AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
model = load_model()
id2classes = load_id2classes()
tokenizer = load_tokenizer()
def get_top_classes(text):
tokenized_text = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
logits = model(**tokenized_text).logits.detach()
probs = Softmax()(logits)
probs_sorted, indices = sort(probs, descending=True)
k = 1
while sum(probs_sorted[0, :k]) < 0.95:
k += 1
return [id2classes[idx.item()] for idx in indices[0, :k]], probs_sorted[0, :k]
st.title("Простой классификатор статей")
title = st.text_input(label="Введите название статьи (обязательно)", value="Type Here ...")
abstract = st.text_input(label="Введите abstract", value="Type Here ...")
if st.button('Submit'):
if title.title() != 'Type Here ...':
if abstract.title() != 'Type Here ...':
text = title.title() + '\n' + abstract.title()
else:
text = title.title()
top_classes, probs = get_top_classes(text)
for p, cls in zip(probs, top_classes):
st.success(f'Статья относится к {cls} с вероятностью {p}')
else:
st.error('Введите название статьи')