Update README.md
README.md
CHANGED
@@ -8,3 +8,61 @@ widget:
- text: 0.05 ML aflibercept 40 MG/ML Injection is contraindicated with [MASK].
- text: The most important ingredient in Excedrin is [MASK].
---

## Overview

This repository contains the bert_base_uncased_rxnorm_babbage model, a [BERT-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) model continually pretrained with masked language modeling on drugs, diseases, and their relationships from RxNorm.
We hypothesize that this augmentation can boost the model's understanding of medical terminology and context.

It was pretrained on a corpus of approximately 8.8 million tokens synthesized from drug and disease relations harvested from RxNorm. A few examples are shown below.
```plaintext
ferrous fumarate 191 MG is contraindicated with Hemochromatosis.
24 HR metoprolol succinate 50 MG Extended Release Oral Capsule [Kapspargo] contains the ingredient Metoprolol.
Genvoya has the established pharmacologic class Cytochrome P450 3A Inhibitor.
cefprozil 250 MG Oral Tablet may be used to treat Haemophilus Infections.
mecobalamin 1 MG Sublingual Tablet contains the ingredient Vitamin B 12.
```

The dataset is hosted at [this commit](https://github.com/Su-informatics-lab/drug_disease_graph/blob/3a598cb9d55ffbb52d2f16e61eafff4dfefaf5b1/rxnorm.txt).
Note that this is the babbage version of the corpus, which uses *all* drug and disease relations.
Do not confuse it with the ada version, which uses only a fraction of the relationships (see [the repo](https://github.com/Su-informatics-lab/drug_disease_graph/tree/main) for more information).
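
To inspect the corpus, it can be fetched directly from GitHub. A minimal sketch, assuming the raw-file URL that corresponds to the commit link above:

```python
# a quick look at the corpus; the raw URL below is inferred from the
# commit link above and is an assumption, not part of this repository
import urllib.request

RAW_URL = (
    "https://raw.githubusercontent.com/Su-informatics-lab/drug_disease_graph/"
    "3a598cb9d55ffbb52d2f16e61eafff4dfefaf5b1/rxnorm.txt"
)

with urllib.request.urlopen(RAW_URL) as resp:
    sentences = resp.read().decode("utf-8").splitlines()

print(f"{len(sentences)} relation sentences")
print("\n".join(sentences[:5]))  # first few synthesized examples
```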

## Training

15% of the tokens were masked for prediction.
The model was trained on this data for *20* epochs.
Training ran on 4 A40 (48 GB) GPUs using Python 3.8, with dependencies matched to those specified in [requirements.txt](https://github.com/Su-informatics-lab/rxnorm_gatortron/blob/1f15ad349056e22089118519becf1392df084701/requirements.txt).
We used a batch size of 16 and a learning rate of 5e-5.
See the full configuration on [GitHub](https://github.com/Su-informatics-lab/rxnorm_gatortron/blob/main/runs_mlm_bert_base_uncased_rxnorm_babbage.sh) and training curves on [WandB](https://wandb.ai/hainingwang/continual_pretraining_gatortron/runs/1abzfvb9).
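
For reference, the setup above corresponds roughly to the following Hugging Face `Trainer` sketch. This is not the actual training script (see the shell script linked above); the data file path, output directory, and per-device batch size interpretation are assumptions:

```python
# a minimal sketch of the training recipe described above, not the actual
# script; "rxnorm.txt" and "out/" are illustrative placeholders
from datasets import load_dataset
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-uncased")

# one synthesized relation sentence per line
dataset = load_dataset("text", data_files={"train": "rxnorm.txt"})["train"]
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
    batched=True,
    remove_columns=["text"],
)

# dynamically mask 15% of tokens for prediction
collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15)

args = TrainingArguments(
    output_dir="out",
    num_train_epochs=20,
    per_device_train_batch_size=16,  # assumed per device; see the linked script
    learning_rate=5e-5,
)

Trainer(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=dataset,
).train()
```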


## Usage

You can use this model for masked language modeling tasks to predict missing words in a given text.
The example below shows how to get started.

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

# load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("Su-informatics-lab/bert_base_uncased_rxnorm_babbage")
model = AutoModelForMaskedLM.from_pretrained("Su-informatics-lab/bert_base_uncased_rxnorm_babbage")

# prepare the input
text = "0.05 ML aflibercept 40 MG/ML Injection is contraindicated with [MASK]."
inputs = tokenizer(text, return_tensors="pt")

# get model predictions
with torch.no_grad():
    outputs = model(**inputs)

# decode the prediction at the [MASK] position
predictions = outputs.logits
masked_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
predicted_token = tokenizer.decode(predicted_token_id)
print(predicted_token)
```
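
Alternatively, the `fill-mask` pipeline wraps these steps and returns the top-scoring candidates directly:

```python
from transformers import pipeline

# the pipeline handles tokenization, inference, and decoding in one call
fill_mask = pipeline("fill-mask", model="Su-informatics-lab/bert_base_uncased_rxnorm_babbage")

for candidate in fill_mask("The most important ingredient in Excedrin is [MASK].", top_k=5):
    print(candidate["token_str"], round(candidate["score"], 4))
```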

## License

Apache 2.0.