import json

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, DistilBertForSequenceClassification

CHECKPOINT_PATH = "checkpoints/epoch_8.pt"
LABELS_PATH = "checkpoints/labels_info.json"
BASE_MODEL = "distilbert-base-cased"

# Mapping of label index -> human-readable arxiv.org category name.
with open(LABELS_PATH, "r") as f:
    LABELS = json.load(f)


@st.cache_resource
def load_model():
    """Load the tokenizer and the fine-tuned classifier once per session."""
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Same base architecture that was fine-tuned; the classification head
    # size must match the number of labels stored in the checkpoint.
    model = DistilBertForSequenceClassification.from_pretrained(
        BASE_MODEL, num_labels=len(LABELS)
    )
    state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()

st.title("Classify scientific articles by title and abstract")
st.write(
    "Enter the title and abstract of an article to predict its topic "
    "within the arxiv.org taxonomy."
)

title = st.text_input("Article title:")
abstract = st.text_area("Abstract:")

if st.button("Classify"):
    if not title and not abstract:
        st.warning("Enter at least the article title.")
    else:
        text = title if not abstract else f"{title} {abstract}"
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=256
        )
        with torch.no_grad():
            outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=1).squeeze()

        # Pair each label name with its probability; relies on LABELS
        # preserving the index order used during training.
        label_probs = [
            (label, prob.item()) for label, prob in zip(LABELS.values(), probs)
        ]
        # Sort descending so we can take the smallest set of labels
        # whose cumulative probability reaches 95%.
        label_probs.sort(key=lambda x: x[1], reverse=True)

        cumulative = 0.0
        top_labels = []
        for label, prob in label_probs:
            cumulative += prob
            top_labels.append((label, prob))
            if cumulative >= 0.95:
                break

        # Output
        st.subheader("Most likely topics (cumulative ≥ 95%):")
        for label, prob in top_labels:
            st.write(f"**{label}** — {prob * 100:.2f}%")
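# ---------------------------------------------------------------------------
# Assumed checkpoint contract (illustrative sketch, not part of the app).
# The code above assumes labels_info.json is a JSON object whose values, in
# insertion order, line up with the model's output indices (e.g.
# {"0": "cs.LG", "1": "cs.CV", ...}) and that epoch_8.pt holds a raw
# state_dict rather than a full pickled model. A minimal training-side save
# that would satisfy this contract might look like the snippet below, where
# `label_names` is a hypothetical list of arxiv category names used during
# fine-tuning:
#
#     torch.save(model.state_dict(), "checkpoints/epoch_8.pt")
#     with open("checkpoints/labels_info.json", "w") as f:
#         json.dump({str(i): name for i, name in enumerate(label_names)}, f)
#
# To launch the UI, run the script through Streamlit, e.g.
# `streamlit run <this_file>.py`.
# ---------------------------------------------------------------------------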