akapoor committed on
Commit
0623078
1 Parent(s): d36a844

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +106 -0
  2. data.pt +3 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ import torch
5
+ import os
6
+ import torch.nn as nn
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from transformers import BertTokenizerFast as BertTokenizer, AutoModelForSequenceClassification, AutoTokenizer,AutoModel,BertModel, AdamW, get_linear_schedule_with_warmup
9
+ import pytorch_lightning as pl
10
+ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
11
+ from pytorch_lightning.loggers import TensorBoardLogger
12
+ import streamlit as st
13
+ import torchmetrics
14
# Resolve the checkpoint path relative to this file so the app works
# regardless of the working directory it is launched from.
pwd = os.path.dirname(__file__)
MODEL_PATH = os.path.join(pwd,"data.pt")
print(MODEL_PATH)  # debug: confirm which checkpoint file will be loaded

# Tokenizer backbone; must match the model name used inside MeshNetwork.
BERT_MODEL_NAME = 'albert-base-v1'
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
21
class MeshNetwork(pl.LightningModule):
    """LightningModule wrapping an ALBERT sequence classifier with 13 labels.

    The train/val/test steps share one private helper instead of repeating
    the forward/loss/accuracy logic three times, as the original did.
    """

    def __init__(self):
        super().__init__()
        # Pre-trained ALBERT backbone with a fresh 13-way classification head.
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            BERT_MODEL_NAME, num_labels=13, return_dict=True
        )
        self.criterion = F.cross_entropy

    def forward(self, input_ids, attention_mask):
        """Return raw classification logits for a tokenized batch."""
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return output.logits

    def _shared_step(self, batch):
        """Run one batch and return (loss, accuracy, logits, labels).

        Factored out of the three *_step methods, which previously
        duplicated this block verbatim.
        """
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        y = batch["labels"]
        y_hat = self.forward(input_ids, attention_mask)
        loss = self.criterion(y_hat, y)
        # softmax is monotonic, so this argmax equals argmax over raw
        # logits; kept as-is to mirror the original computation exactly.
        predictions = F.softmax(y_hat, dim=1).argmax(dim=1)
        acc = torchmetrics.functional.accuracy(predictions, y)
        return loss, acc, y_hat, y

    def training_step(self, batch, batch_idx):
        loss, acc, y_hat, y = self._shared_step(batch)
        self.log("train_acc", acc, on_step=False, prog_bar=True, on_epoch=True, logger=True)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True, logger=True)
        return {"loss": loss, "predictions": y_hat, "labels": y}

    def validation_step(self, batch, batch_idx):
        loss, acc, _, _ = self._shared_step(batch)
        self.log("val_acc", acc, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True, logger=True)

    def test_step(self, batch, batch_idx):
        loss, acc, _, _ = self._shared_step(batch)
        self.log("test_acc", acc, prog_bar=True, on_step=False, on_epoch=True, logger=True)
        self.log("test_loss", loss, prog_bar=True, on_epoch=True, logger=True)

    def configure_optimizers(self):
        """Plain Adam over all parameters (library-default learning rate)."""
        optimizer = torch.optim.Adam(params=self.parameters())
        return optimizer
68
+
69
+
70
+
71
st.title("MeSH Classify")
model = MeshNetwork()
with st.spinner("Loading model..."):
    # map_location keeps CPU-only deployment hosts working even if the
    # checkpoint was saved from a GPU training run.
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()  # disable dropout etc. for inference
print(model)  # debug: dump the architecture to the server log

st.success("Model loaded.")
user_input = st.text_input("Enter text to be classified.")
st.write("Check MeSH categories: [link](https://www.ncbi.nlm.nih.gov/mesh/1000048)")
st.markdown("***")


if st.button("Classify Text"):
    if user_input:
        # Tokenize/pad to a fixed 512-token window. max_length is now
        # explicit rather than relying on the tokenizer's implicit model
        # maximum, since the model expects (batch, 512) inputs.
        encoding = tokenizer.encode_plus(
            user_input,
            add_special_tokens=True,
            return_token_type_ids=False,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # encode_plus already returns (1, 512) tensors, so the original
        # flatten() + reshape(-1, 512) round-trip was a no-op and is gone.
        # no_grad: inference only — avoid building an autograd graph.
        with torch.no_grad():
            y_hat = model(
                input_ids=encoding["input_ids"],
                attention_mask=encoding["attention_mask"],
            )
        prob = F.softmax(y_hat, dim=1)
        probs = prob.detach().numpy()
        st.table(probs)  # per-class probabilities
        predictions = prob.argmax(dim=1)
        st.write(predictions.detach().numpy())  # predicted class index
data.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb19f344593f34c9cb45609eccd24895e37df8eddcdc477e59b92d44b41b43fe
3
+ size 46789201
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ numpy==1.22.4
2
+ pandas==1.4.1
3
+ pytorch_lightning==1.6.3
4
+ streamlit==1.10.0
5
+ torch==1.11.0
6
+ torchmetrics==0.8.2
7
+ transformers==4.20.1