import streamlit as st import time # model part import json import torch from torch import nn from transformers import AutoTokenizer, AutoModelForSequenceClassification with open('categories_with_names.json', 'r') as f: cat_with_names = json.load(f) with open('categories_from_model.json', 'r') as f: categories_from_model = json.load(f) @st.cache_resource def load_models_and_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv") model_titles = AutoModelForSequenceClassification.from_pretrained( "powerful_model_titles/checkpoint-13472", num_labels=len(categories_from_model), problem_type="multi_label_classification" ) model_titles.eval() model_abstracts = AutoModelForSequenceClassification.from_pretrained( "powerful_model_abstracts/checkpoint-13472", num_labels=len(categories_from_model), problem_type="multi_label_classification" ) model_abstracts.eval() return model_titles, model_abstracts, tokenizer model_titles, model_abstracts, tokenizer = load_models_and_tokenizer() def categorize_text(title: str | None = None, abstract: str | None = None, progress_bar = None): if title is None and abstract is None: raise ValueError('title is None and abstract is None') models_to_run = 2 if (title is not None and abstract is not None) else 1 proba_title = None if title is not None: progresses = (10, 30) if models_to_run == 2 else (20, 60) my_bar.progress(progresses[0], text='computing titles') input_tok = tokenizer(title, return_tensors='pt') with torch.no_grad(): logits = model_titles(**input_tok)['logits'] proba_title = torch.sigmoid(logits)[0] my_bar.progress(progresses[1], text='computed titles') proba_abstract = None if abstract is not None: progresses = (40, 70) if models_to_run == 2 else (20, 60) my_bar.progress(progresses[0], text='computing abstracts') input_tok = tokenizer(abstract, return_tensors='pt') with torch.no_grad(): logits = model_abstracts(**input_tok)['logits'] proba_abstract = torch.sigmoid(logits)[0] my_bar.progress(progresses[0], text='computed abstracts') if title is None: proba = proba_abstract elif abstract is None: proba = proba_title else: proba = proba_title * 0.1 + proba_abstract * 0.9 progresses = (80, 90) if models_to_run == 2 else (70, 90) my_bar.progress(progresses[0], text='computed proba') sorted_proba, indices = torch.sort(proba, descending=True) my_bar.progress(progresses[1], text='sorted proba') to_take = 1 while sorted_proba[:to_take].sum() < 0.95 and to_take < len(categories_from_model): to_take += 1 output = [(cat_with_names[categories_from_model[index]], proba[index].item()) for index in indices[:to_take]] my_bar.progress(100, text='generated output') return output # front part st.markdown("