import gradio as gr
import shap
import torch
import transformers
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification

device = "cuda:0" if torch.cuda.is_available() else "cpu"

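# Sequence-classification model fine-tuned to grade adverse-drug-reaction severity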
tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024")  
model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024")

# Build a pipeline object for predictions
pred = transformers.pipeline("text-classification", model=model, 
                             tokenizer=tokenizer, top_k=None)

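# SHAP explainer built on top of the prediction pipeline; shap.plots.text() below
# turns its per-token attributions into an HTML snippet for display in the app.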
explainer = shap.Explainer(pred)

ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")

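# Biomedical NER pipeline; aggregation_strategy="simple" merges word pieces into
# whole entity spans with character-level start/end offsets used for highlighting.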
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")  # pass device=0 if using a GPU

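# Highlight colors for the entity groups surfaced in the NER view; any other group
# falls back to a neutral grey in adr_predict().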
entity_colors = {
    'Severity': '#E63946',  # a vivid red
    'Sign_symptom': '#2A9D8F',  # a deep teal
    'Medication': '#457B9D',  # a dusky blue
    'Age': '#F4A261',  # a sandy orange
    'Sex': '#F4A261',  # same sandy orange for consistency with 'Age'
    'Diagnostic_procedure': '#9C6644',  # a brown
    'Biological_structure': '#BDB2FF',  # a light pastel purple
}

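# Run severity classification, SHAP explanation, and NER highlighting for one input string.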
def adr_predict(x):
    # Severity classification: softmax over the class logits
    encoded_input = tokenizer(x, return_tensors='pt')
    output = model(**encoded_input)
    scores = output.logits[0].detach()
    scores = torch.nn.functional.softmax(scores, dim=-1)

    # SHAP explanation rendered as an HTML snippet
    shap_values = explainer([str(x).lower()])
    local_plot = shap.plots.text(shap_values[0], display=False)

    # Named-entity highlighting: wrap each recognized span in a colored <mark> tag
    res = ner_pipe(x)

    htext = ""
    prev_end = 0
    for entity in res:
        start = entity['start']
        end = entity['end']
        word = entity['word'].replace("##", "")
        # Fall back to a neutral grey for entity groups without a dedicated color
        color = entity_colors.get(entity['entity_group'], '#CCCCCC')

        htext += f"{x[prev_end:start]}<mark style='background-color:{color};'>{word}</mark>"
        prev_end = end
    htext += x[prev_end:]

    return {"Severe Reaction": float(scores.numpy()[1]), "Non-severe Reaction": float(scores.numpy()[0])}, local_plot, htext


def main(prob1):
    text = str(prob1).lower()
    obj = adr_predict(text)
    return obj[0], obj[1], obj[2]
    
# Define HTML for the legend
legend_html = """
<div style='margin-top: 20px; color: white;'>  <!-- Ensure the legend text is white for visibility -->
    <h3>NER Legend</h3>
    <ul style='list-style-type:none; padding-left: 0;'>  <!-- Remove padding from the list -->
"""

for entity, color in entity_colors.items():
    legend_html += f"<li><span style='color: white; background-color: {color}; padding: 5px 10px; margin-right: 5px; border-radius: 5px;'>{entity}</span></li>"

legend_html += "</ul></div>"

# Create a Gradio HTML component to display the legend
ner_legend = gr.HTML(value=legend_html)
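# The legend is attached to the page later, inside the Blocks context, via ner_legend.render().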

title = "Welcome to **ADR Tracker**"
description1 = "This app takes text (up to a few sentences) and predicts to what extent it describes a severe (or non-severe) adverse reaction to medications. Please do NOT use it for medical diagnosis."

css = """
body { font-family: 'Roboto', sans-serif; background-color: #333; color: #163E64; }
h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #163E64; } /* Keep all text elements in the same dark blue */
.textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }
.button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here...")
    submit_btn = gr.Button("Analyze")

    with gr.Row():
        with gr.Column(visible=True):
            label = gr.Label(label="Predicted Label")
        with gr.Column(visible=True):
            local_plot = gr.HTML(label="SHAP")
            htext = gr.HTML(label="NER")

    submit_btn.click(
        main,
        [prob1],
        [label, local_plot, htext],
        api_name="adr"
    )

    # Display the NER legend below the results
    ner_legend.render()

    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
        gr.Examples([["A 35-year-old female had suicidal ideation after taking Prednisone."],
                     ["A 23-year-old male had minor nausea after taking Acetaminophen."]],
                    [prob1], [label, local_plot, htext], main, cache_examples=True)

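# launch() can also take options such as share=True or server_name="0.0.0.0" when
# hosting outside Hugging Face Spaces.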
demo.launch()