Spaces:
Sleeping
Sleeping
added app file and back file
Browse files
app.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
from back import load_setup, predict
|
4 |
+
|
5 |
+
|
6 |
+
# Streamlit front end: collects a paper title and optional abstract, then shows
# the arXiv tags predicted by the helpers imported from `back`.

# NOTE(review): the original call was `load_setup(path_to_model=, path_to_vocab=)`
# with no values, which is a syntax error. The file names below are placeholders
# chosen during review -- TODO: point them at the real checkpoint and vocab file.
PATH_TO_MODEL = 'model.pt'
PATH_TO_VOCAB = 'vocab.txt'

# Page header; unsafe_allow_html is required for the inline <img> tag to render.
st.markdown(
    """# Hi, let's find some arxiv tags of papers! <img width=50px src='https://tahayassine.me/media/icon_huecf151563ab8d381649e162e44411916_11613_512x512_fill_lanczos_center_3.png'>""",
    unsafe_allow_html=True,
)

# Raw strings keep the LaTeX backslashes literal without relying on the invalid
# `\L` escape sequence; the text passed to Streamlit is unchanged.
title = st.text_area(r"$\LARGE \text{Title of the paper:}$", placeholder='')
abstract = st.text_area(r"$\LARGE \text{Abstract of the paper(if any):}$", placeholder='')

if st.button('Generate tags!'):
    # A title made only of whitespace is treated the same as an empty one.
    if not title.strip():
        st.write('# Please, enter title!')
    else:
        # Title and abstract are joined with a literal '[SEP]' marker.
        text = title + '[SEP]' + abstract
        model, tokenizer = load_setup(
            path_to_model=PATH_TO_MODEL,
            path_to_vocab=PATH_TO_VOCAB,
        )
        prediction = predict(model, tokenizer, text)
        st.markdown(f'## Predicted tag(s): {prediction}')
|
back.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from typing import TypeVar
|
6 |
+
from transformers import DistilBertTokenizer
|
7 |
+
|
8 |
+
|
9 |
+
# Placeholder type variables used only in the annotations of the helpers below.
ModelType = TypeVar('ModelType')
TokenizerType = TypeVar('TokenizerType')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Class index -> arXiv category tag. The indices are contiguous (0..125), so the
# mapping is built by enumerating the tags in index order.
label_mapping = dict(enumerate((
    # 0-15
    'adap-org', 'astro-ph', 'astro-ph.CO', 'astro-ph.EP',
    'astro-ph.GA', 'astro-ph.IM', 'astro-ph.SR', 'cmp-lg',
    'cond-mat', 'cond-mat.dis-nn', 'cond-mat.mtrl-sci', 'cond-mat.other',
    'cond-mat.soft', 'cond-mat.stat-mech', 'cond-mat.supr-con', 'cs.AI',
    # 16-51
    'cs.AR', 'cs.CC', 'cs.CE', 'cs.CG', 'cs.CL', 'cs.CR',
    'cs.CV', 'cs.CY', 'cs.DB', 'cs.DC', 'cs.DL', 'cs.DM',
    'cs.DS', 'cs.ET', 'cs.FL', 'cs.GL', 'cs.GR', 'cs.GT',
    'cs.HC', 'cs.IR', 'cs.IT', 'cs.LG', 'cs.LO', 'cs.MA',
    'cs.MM', 'cs.MS', 'cs.NA', 'cs.NE', 'cs.NI', 'cs.OH',
    'cs.OS', 'cs.PF', 'cs.PL', 'cs.RO', 'cs.SC', 'cs.SD',
    # 52-91
    'cs.SE', 'cs.SI', 'cs.SY', 'econ.EM', 'eess.AS',
    'eess.IV', 'eess.SP', 'gr-qc', 'hep-ex', 'hep-lat',
    'hep-ph', 'hep-th', 'math.AG', 'math.AP', 'math.AT',
    'math.CA', 'math.CO', 'math.CT', 'math.DG', 'math.DS',
    'math.FA', 'math.GM', 'math.GN', 'math.GR', 'math.GT',
    'math.HO', 'math.LO', 'math.MG', 'math.NA', 'math.NT',
    'math.OC', 'math.PR', 'math.RA', 'math.RT', 'math.ST',
    'nlin.AO', 'nlin.CD', 'nlin.CG', 'nlin.PS', 'nucl-th',
    # 92-103
    'physics.ao-ph', 'physics.bio-ph', 'physics.chem-ph',
    'physics.class-ph', 'physics.comp-ph', 'physics.data-an',
    'physics.gen-ph', 'physics.geo-ph', 'physics.hist-ph',
    'physics.ins-det', 'physics.med-ph', 'physics.optics',
    # 104-125
    'physics.soc-ph', 'q-bio.BM', 'q-bio.CB', 'q-bio.GN',
    'q-bio.MN', 'q-bio.NC', 'q-bio.PE', 'q-bio.QM',
    'q-bio.TO', 'q-fin.CP', 'q-fin.EC', 'q-fin.GN',
    'q-fin.PM', 'q-fin.RM', 'q-fin.ST', 'q-fin.TR',
    'quant-ph', 'stat.AP', 'stat.CO', 'stat.ME',
    'stat.ML', 'stat.OT',
)))
|
43 |
+
|
44 |
+
|
45 |
+
@st.cache_resource
def load_setup(path_to_model: str, path_to_vocab: str) -> tuple[ModelType, TokenizerType]:
    """Load the serialized model and build the tokenizer used for prediction.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    reuse the returned objects instead of reloading them on every call.

    Args:
        path_to_model: Path to a file readable by ``torch.load``.
        path_to_vocab: Vocabulary file passed to ``DistilBertTokenizer``.

    Returns:
        A ``(model, tokenizer)`` pair; the model is mapped onto ``device``.
    """
    # SECURITY: torch.load unpickles the file and can execute arbitrary code,
    # so only checkpoints from a trusted source must be loaded here.
    # NOTE(review): recent PyTorch releases changed the `weights_only` default
    # of torch.load; loading a fully pickled model object may need it set
    # explicitly -- confirm against the deployed torch version.
    loaded_model = torch.load(path_to_model, map_location=device)

    # Switch dropout / batch-norm layers to inference behaviour. Guarded so that
    # a checkpoint holding something other than an nn.Module is returned as-is.
    if isinstance(loaded_model, torch.nn.Module):
        loaded_model.eval()

    loaded_tokenizer = DistilBertTokenizer(path_to_vocab)
    return loaded_model, loaded_tokenizer
|
50 |
+
|
51 |
+
|
52 |
+
@st.cache_data
def predict(model: ModelType, tokenizer: TokenizerType, input_text: str, max_length: int = 512) -> str:
    """Return the most probable arXiv tags for ``input_text``.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask)``.
        tokenizer: Object providing ``encode_plus``.
        input_text: Text to classify.
        max_length: Length the token sequence is padded / truncated to.

    Returns:
        Up to five tags from ``label_mapping`` joined by ``', '``, ordered
        from most to least probable.
    """
    # NOTE(review): st.cache_data hashes every argument, including `model` and
    # `tokenizer`. If that proves slow or fails, Streamlit's documented remedy
    # is an underscore-prefixed parameter name; the signature is left untouched
    # here so existing callers keep working.
    inputs = tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
    )

    # NOTE(review): no batch dimension is added, so the model receives 1-D
    # tensors -- presumably the loaded model expects that; verify.
    ids = torch.tensor(inputs['input_ids']).to(device, dtype=torch.long)
    mask = torch.tensor(inputs['attention_mask']).to(device, dtype=torch.long)

    with torch.no_grad():
        output_for_sentence = model(ids, mask).squeeze()

    # An explicit dim avoids the deprecated implicit-dimension softmax.
    preds = torch.nn.functional.softmax(output_for_sentence, dim=-1).cpu()

    # torch.topk returns indices already sorted by descending probability;
    # k is capped so a model with fewer than five outputs does not raise.
    top_k = min(5, preds.numel())
    top_indices = torch.topk(preds, top_k).indices.tolist()
    return ', '.join(label_mapping[index] for index in top_indices)
|