import streamlit as st
import time

# model part
import json
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# `categories_from_model` maps a model output index to a category code;
# `cat_with_names` maps that code to the name shown to the user.
with open('categories_with_names.json', 'r', encoding='utf-8') as f:
    cat_with_names = json.load(f)
with open('categories_from_model.json', 'r', encoding='utf-8') as f:
    categories_from_model = json.load(f)


@st.cache_resource
def load_models_and_tokenizer():
    """Load the shared tokenizer and the title / abstract classifiers.

    Decorated with ``st.cache_resource`` so the models are loaded once and
    reused across Streamlit reruns instead of on every interaction.

    Returns:
        ``(model_titles, model_abstracts, tokenizer)``; both models are put in
        eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv")
    model_titles = AutoModelForSequenceClassification.from_pretrained(
        "powerful_model_titles/checkpoint-13472",
        num_labels=len(categories_from_model),
        problem_type="multi_label_classification",
    )
    model_titles.eval()
    model_abstracts = AutoModelForSequenceClassification.from_pretrained(
        "powerful_model_abstracts/checkpoint-13472",
        num_labels=len(categories_from_model),
        problem_type="multi_label_classification",
    )
    model_abstracts.eval()
    return model_titles, model_abstracts, tokenizer


model_titles, model_abstracts, tokenizer = load_models_and_tokenizer()


def _predict_proba(model, text: str) -> torch.Tensor:
    """Return *model*'s per-category sigmoid probabilities for *text*."""
    # truncation=True cuts input to the tokenizer's model_max_length; without
    # it, long abstracts can exceed the model's maximum sequence length and
    # fail in the forward pass.
    input_tok = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        logits = model(**input_tok)['logits']
    # Multi-label head: independent sigmoid per category; take the single
    # batch row.
    return torch.sigmoid(logits)[0]


def categorize_text(title: str | None = None, abstract: str | None = None, progress_bar=None):
    """Classify a paper by its title and/or abstract.

    Args:
        title: Paper title, or ``None`` to skip the title model.
        abstract: Paper abstract, or ``None`` to skip the abstract model.
        progress_bar: Optional ``st.progress`` element that is updated as the
            computation advances; ``None`` disables progress reporting.

    Returns:
        List of ``(category_name, probability)`` pairs, most probable first.
        It holds the shortest prefix whose probabilities sum to at least 0.95,
        or every category if that total is never reached.

    Raises:
        ValueError: if both ``title`` and ``abstract`` are ``None``.
    """
    if title is None and abstract is None:
        raise ValueError('title is None and abstract is None')

    def report(value, text):
        # Use the caller-supplied bar (if any) rather than a module global.
        if progress_bar is not None:
            progress_bar.progress(value, text=text)

    both = title is not None and abstract is not None

    proba_title = None
    if title is not None:
        start, end = (10, 30) if both else (20, 60)
        report(start, 'computing titles')
        proba_title = _predict_proba(model_titles, title)
        report(end, 'computed titles')

    proba_abstract = None
    if abstract is not None:
        start, end = (40, 70) if both else (20, 60)
        report(start, 'computing abstracts')
        proba_abstract = _predict_proba(model_abstracts, abstract)
        report(end, 'computed abstracts')

    if proba_title is None:
        proba = proba_abstract
    elif proba_abstract is None:
        proba = proba_title
    else:
        # Weighted blend that favours the abstract model.
        proba = proba_title * 0.1 + proba_abstract * 0.9

    computed_at, sorted_at = (80, 90) if both else (70, 90)
    report(computed_at, 'computed proba')
    sorted_proba, indices = torch.sort(proba, descending=True)
    report(sorted_at, 'sorted proba')

    # Grow the top-k list until the kept probabilities sum to >= 0.95.
    to_take = 1
    while sorted_proba[:to_take].sum() < 0.95 and to_take < len(categories_from_model):
        to_take += 1

    output = [
        (cat_with_names[categories_from_model[index]], proba[index].item())
        for index in indices[:to_take].tolist()
    ]
    report(100, 'generated output')
    return output


# front part
# NOTE(review): unsafe_allow_html=True suggests this heading originally
# carried HTML markup that was lost from the source; only the text survived.
st.markdown("\n\nClassify your paper!\n\n", unsafe_allow_html=True)

# `title` / `abstract` hold the last submitted text; the *_input_key entries
# back the text_input widgets; `model_type` mirrors the multiselect.
for _key, _default in (
    ("title", ""),
    ("abstract", ""),
    ("title_input_key", ""),
    ("abstract_input_key", ""),
    ("model_type", []),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default


def input_error():
    """Return a user-facing message describing invalid input, or '' if valid."""
    selected = st.session_state.model_type
    if not selected:
        return 'you have to select title or abstract'
    if 'Title' in selected and not st.session_state.title.strip():
        return 'Title is empty'
    if 'Abstract' in selected and not st.session_state.abstract.strip():
        return 'Abstract is empty'
    return ''


def clear_input():
    """Submit-button callback: snapshot the typed text, then clear the inputs.

    The submitted text is copied into ``st.session_state.title`` and
    ``abstract``. If that input is valid, the selected text widgets are then
    emptied. Invalid input is left in place so the user can correct it.
    """
    # Read the widget-backed keys, which Streamlit updates before running the
    # callback, and copy the text verbatim (no case changes).
    st.session_state.title = st.session_state.title_input_key
    st.session_state.abstract = st.session_state.abstract_input_key
    if not input_error():
        if "Title" in st.session_state.model_type:
            st.session_state.title_input_key = ""
        if "Abstract" in st.session_state.model_type:
            st.session_state.abstract_input_key = ""


title = st.text_input(r"$\textsf{\Large Title}$", key="title_input_key")
abstract = st.text_input(r"$\textsf{\Large Abstract}$", key="abstract_input_key")
model_type = st.multiselect(
    r"$\textsf{\large Classify by:}$",
    ['Title', 'Abstract'],
)
st.session_state.model_type = model_type

if st.button('Submit', on_click=clear_input):
    error = input_error()
    if error:
        st.error(error)
    else:
        model_input = {}
        if 'Title' in st.session_state.model_type:
            model_input['title'] = st.session_state.title
        if 'Abstract' in st.session_state.model_type:
            model_input['abstract'] = st.session_state.abstract

        my_bar = st.progress(0, text='starting model')
        model_result = categorize_text(**model_input, progress_bar=my_bar)
        # NOTE(review): as with the page heading, HTML markup may have been
        # lost here; only the text survived.
        st.markdown("\n\nClassification completed!\n\n", unsafe_allow_html=True)

        # Most probable category, rendered largest.
        cat, proba = model_result[0]
        st.write(r"$\textsf{\Large " + f'{cat}: {round(100*proba)}' + r"\%}$")

        # Categories at >= 10% get their own line; smaller ones share one line.
        small_categories = []
        for cat, proba in model_result[1:]:
            if proba < 0.1:
                small_categories.append(f'{cat}: {round(100*proba, 1)}' + r"\%")
            else:
                st.write(r"$\textsf{\large " + f'{cat}: {round(100*proba)}' + r"\%}$")
        if small_categories:
            st.write(', '.join(small_categories))