Zhdantim's picture
Update app.py
283aad9 verified
raw
history blame
1.86 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
from torch.nn import Softmax
from torch import sort
@st.cache_resource  # caching
def load_model():
    """Load the 'zhdantim/mydeberta-v3-small' sequence classifier in eval mode."""
    classifier = AutoModelForSequenceClassification.from_pretrained(
        'zhdantim/mydeberta-v3-small'
    )
    # .eval() switches the module to evaluation mode and returns the module.
    return classifier.eval()
@st.cache_resource  # caching
def load_id2classes():
    """Read 'classes.tsv' and return its 'classes' column as a dict.

    The file is parsed as tab-separated with its first column as the index,
    so the returned dict maps each index value to the matching 'classes' entry.
    """
    table = pd.read_csv('classes.tsv', sep='\t', index_col=0)
    columns = table.to_dict()
    return columns['classes']
@st.cache_resource  # caching
def load_tokenizer():
    """Load the tokenizer published as 'microsoft/deberta-v3-small'."""
    checkpoint = 'microsoft/deberta-v3-small'
    return AutoTokenizer.from_pretrained(checkpoint)
# Module-level resources used by get_top_classes(); each loader is wrapped in
# st.cache_resource.
model = load_model()  # sequence-classification model in eval mode
id2classes = load_id2classes()  # dict built from the 'classes' column of classes.tsv
tokenizer = load_tokenizer()  # tokenizer used to encode the input text
def get_top_classes(text, threshold=0.95):
    """Return the most probable classes that together reach ``threshold``.

    The text is tokenized, scored by the module-level ``model``, and the class
    probabilities are sorted in descending order. The shortest prefix whose
    cumulative probability is at least ``threshold`` is returned.

    Args:
        text: Text to classify.
        threshold: Cumulative probability the returned classes must cover.
            Defaults to 0.95.

    Returns:
        A tuple ``(names, probs)``. ``names`` is a list of entries looked up
        in ``id2classes``, ordered from most to least probable; ``probs`` is
        a 1-D tensor holding the matching probabilities.
    """
    # Local import: the file only imports `sort` from torch at module level.
    from torch import no_grad

    encoded = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # Inference only, so skip building the autograd graph entirely.
    with no_grad():
        logits = model(**encoded).logits

    # Explicit class axis: Softmax() without `dim` relies on a deprecated
    # implicit choice and emits a warning.
    probs = Softmax(dim=-1)(logits)
    probs_sorted, indices = sort(probs, descending=True)

    # Only the first row is used, as in the original indexing with [0, :k].
    row_probs = probs_sorted[0]
    row_indices = indices[0]

    # Prefix sums are non-decreasing, so the number of prefixes still below
    # the threshold, plus one, is the smallest k that reaches it.
    cumulative = row_probs.cumsum(dim=0)
    k = int((cumulative < threshold).sum()) + 1
    # Clamp so a threshold the total never reaches returns every class
    # instead of running past the end.
    k = min(k, row_probs.numel())

    names = [id2classes[class_id] for class_id in row_indices[:k].tolist()]
    return names, row_probs[:k]
# Default text pre-filled in both input fields; an untouched field still
# contains it, so it is treated the same as "nothing entered".
_PLACEHOLDER = "Type Here ..."

st.title("Простой классификатор статей")
title = st.text_input(label="Введите название статьи (обязательно)", value=_PLACEHOLDER)
abstract = st.text_input(label="Введите abstract", value=_PLACEHOLDER)

if st.button('Submit'):
    # Use the text as typed. The original passed it through str.title(),
    # which rewrites capitalization ("BERT" -> "Bert", "don't" -> "Don'T")
    # before classification.
    title_text = title.strip()
    abstract_text = abstract.strip()

    # The title is required: reject an empty, whitespace-only, or untouched field.
    if title_text and title_text != _PLACEHOLDER:
        parts = [title_text]
        # The abstract is optional; append it only when the user supplied one.
        if abstract_text and abstract_text != _PLACEHOLDER:
            parts.append(abstract_text)
        text = '\n'.join(parts)

        top_classes, probs = get_top_classes(text)
        for p, cls in zip(probs, top_classes):
            st.success(f'Статья относится к {cls} с вероятностью {p}')
    else:
        st.error('Введите название статьи')