surkovvv's picture
Update back.py
ddd86a3
raw
history blame
3.6 kB
import streamlit as st
import numpy as np
import torch
from typing import TypeVar, Tuple
from transformers import DistilBertTokenizer, DistilBertModel
from model import DistillBERTClass
# Placeholder type variables used only in the annotations below; the concrete
# model and tokenizer classes are not constrained here.
ModelType = TypeVar('ModelType')
TokenizerType = TypeVar('TokenizerType')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Maps the integer class index produced by the classifier to its tag string.
# The position of each label in the tuple below IS its class index (0-based),
# so the order must not be changed.
# NOTE(review): the labels look like arXiv category identifiers - confirm.
label_mapping = dict(enumerate((
    'adap-org',
    'astro-ph', 'astro-ph.CO', 'astro-ph.EP', 'astro-ph.GA', 'astro-ph.IM',
    'astro-ph.SR',
    'cmp-lg',
    'cond-mat', 'cond-mat.dis-nn', 'cond-mat.mtrl-sci', 'cond-mat.other',
    'cond-mat.soft', 'cond-mat.stat-mech', 'cond-mat.supr-con',
    'cs.AI', 'cs.AR', 'cs.CC', 'cs.CE', 'cs.CG', 'cs.CL', 'cs.CR', 'cs.CV',
    'cs.CY', 'cs.DB', 'cs.DC', 'cs.DL', 'cs.DM', 'cs.DS', 'cs.ET', 'cs.FL',
    'cs.GL', 'cs.GR', 'cs.GT', 'cs.HC', 'cs.IR', 'cs.IT', 'cs.LG', 'cs.LO',
    'cs.MA', 'cs.MM', 'cs.MS', 'cs.NA', 'cs.NE', 'cs.NI', 'cs.OH', 'cs.OS',
    'cs.PF', 'cs.PL', 'cs.RO', 'cs.SC', 'cs.SD', 'cs.SE', 'cs.SI', 'cs.SY',
    'econ.EM',
    'eess.AS', 'eess.IV', 'eess.SP',
    'gr-qc',
    'hep-ex', 'hep-lat', 'hep-ph', 'hep-th',
    'math.AG', 'math.AP', 'math.AT', 'math.CA', 'math.CO', 'math.CT',
    'math.DG', 'math.DS', 'math.FA', 'math.GM', 'math.GN', 'math.GR',
    'math.GT', 'math.HO', 'math.LO', 'math.MG', 'math.NA', 'math.NT',
    'math.OC', 'math.PR', 'math.RA', 'math.RT', 'math.ST',
    'nlin.AO', 'nlin.CD', 'nlin.CG', 'nlin.PS',
    'nucl-th',
    'physics.ao-ph', 'physics.bio-ph', 'physics.chem-ph', 'physics.class-ph',
    'physics.comp-ph', 'physics.data-an', 'physics.gen-ph', 'physics.geo-ph',
    'physics.hist-ph', 'physics.ins-det', 'physics.med-ph', 'physics.optics',
    'physics.soc-ph',
    'q-bio.BM', 'q-bio.CB', 'q-bio.GN', 'q-bio.MN', 'q-bio.NC', 'q-bio.PE',
    'q-bio.QM', 'q-bio.TO',
    'q-fin.CP', 'q-fin.EC', 'q-fin.GN', 'q-fin.PM', 'q-fin.RM', 'q-fin.ST',
    'q-fin.TR',
    'quant-ph',
    'stat.AP', 'stat.CO', 'stat.ME', 'stat.ML', 'stat.OT',
)))
def load_setup(path_to_model: str, path_to_vocab: str) -> Tuple[ModelType, TokenizerType]:
    """Load the serialized model and build the matching tokenizer.

    Args:
        path_to_model: File handed to ``torch.load``; its tensors are mapped
            onto the module-level ``device``.
        path_to_vocab: Vocabulary file passed to ``DistilBertTokenizer``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE(review): torch.load deserializes via pickle, so this should only be
    # pointed at checkpoints from a trusted source.
    model = torch.load(path_to_model, map_location=device)
    tokenizer = DistilBertTokenizer(path_to_vocab)
    return model, tokenizer
def predict(model: ModelType, tokenizer: TokenizerType, input_text: str, max_length: int = 512) -> str:
    """Return the most probable tags for ``input_text`` as one string.

    The text is encoded with ``tokenizer.encode_plus`` (special tokens added,
    padded and truncated to ``max_length``), the resulting ids and attention
    mask are moved to the module-level ``device`` and passed to ``model``.
    The model output is squeezed, turned into probabilities with a softmax,
    and the indices of the highest probabilities are translated to tag names
    through ``label_mapping``.

    Args:
        model: Callable invoked as ``model(ids, mask)``. It is used in
            whatever train/eval mode it is currently in; this function does
            not switch the mode.
        tokenizer: Object exposing ``encode_plus`` and returning a mapping
            with ``'input_ids'`` and ``'attention_mask'``.
        input_text: Text to classify.
        max_length: Length every input is padded/truncated to.

    Returns:
        Up to five tag names joined by ``', '``, ordered from most to least
        probable. Fewer than five are returned only when the model emits
        fewer than five scores.

    Raises:
        KeyError: If a predicted index is missing from ``label_mapping``.
    """
    encoded = tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
    )
    ids = torch.tensor(encoded['input_ids']).to(device, dtype=torch.long)
    mask = torch.tensor(encoded['attention_mask']).to(device, dtype=torch.long)

    with torch.no_grad():
        # NOTE(review): assumes squeeze() leaves a 1-D vector of class scores
        # for a single input - confirm against the model's forward().
        logits = model(ids, mask).squeeze()
        # Explicit dim: the implicit-dimension form of softmax is deprecated
        # and emits a warning on every call.
        probs = torch.nn.functional.softmax(logits, dim=-1).cpu()

    # torch.topk already returns indices sorted by descending probability, so
    # no numpy round-trip (argpartition + argsort + flip) is needed. Clamp k so
    # a model with fewer than five outputs does not raise.
    top_k = min(5, probs.shape[-1])
    top_indices = torch.topk(probs, k=top_k).indices
    return ', '.join(label_mapping[index] for index in top_indices.tolist())