File size: 4,517 Bytes
acc1c62
 
 
 
 
 
 
d18aa6e
acc1c62
 
 
 
 
 
 
 
 
9513354
e8136b2
acc1c62
9f1b256
acc1c62
 
 
9c9bf1a
 
3ab73a4
9f1b256
acc1c62
 
 
 
b932aaf
 
9f1b256
acc1c62
9f1b256
acc1c62
 
 
9f1b256
 
 
 
 
 
 
9c9bf1a
acc1c62
 
 
 
 
 
 
 
 
 
 
 
 
d18aa6e
acc1c62
 
 
 
d18aa6e
acc1c62
b0a2f69
d18aa6e
acc1c62
d18aa6e
b6c1514
9f1b256
b6c1514
 
d18aa6e
 
 
acc1c62
 
d18aa6e
 
acc1c62
 
 
d18aa6e
 
 
 
acc1c62
9f1b256
 
 
 
 
 
 
 
acc1c62
 
 
d18aa6e
 
acc1c62
 
9f1b256
 
acc1c62
 
 
 
d18aa6e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import streamlit as st
import gradio as gr
import shap
import numpy as np
import scipy as sp
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification

import matplotlib.pyplot as plt
import sys
import csv

# Raise the csv module's per-field size limit to the largest value the
# platform's ``sys.maxsize`` allows.
csv.field_size_limit(sys.maxsize)

# Preferred torch device string. NOTE(review): nothing in the visible code
# moves the models or inputs onto this device -- confirm whether that is
# intentional.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Sequence-classification checkpoint used for the severe / non-severe score.
tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024")
model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024")

# Text-classification pipeline around the same model; it is the callable
# handed to SHAP below.
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)

# SHAP explainer built on top of the classification pipeline.
explainer = shap.Explainer(pred)

# Token-classification (NER) checkpoint used to highlight entities.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")

# Only the ``transformers`` module is imported by name in this file, so the
# pipeline factory has to be reached through it; a bare ``pipeline(...)``
# raises NameError at import time.
ner_pipe = transformers.pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)

def adr_predict(x):
    """Score *x*, explain the score with SHAP, and highlight NER entities.

    Args:
        x: The text to analyse.

    Returns:
        A 3-tuple ``(label_scores, local_plot, htext)``:

        * ``label_scores`` -- dict mapping ``"Severe Reaction"`` (softmax
          output index 1) and ``"Non-severe Reaction"`` (index 0) to floats.
        * ``local_plot`` -- whatever ``shap.plots.text(..., display=False)``
          returns for the first explanation (consumed as HTML by the caller).
        * ``htext`` -- *x* as HTML, with each recognised entity whose group
          has a configured colour wrapped in a ``<mark>`` tag.
    """
    # Local import keeps this block self-contained.
    from html import escape

    encoded_input = tokenizer(x, return_tensors='pt')
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(**encoded_input)
    logits = output[0][0].detach()
    # Explicit dim: calling softmax without it relies on deprecated
    # implicit-dimension behaviour and emits a warning.
    scores = torch.nn.functional.softmax(logits, dim=-1)

    shap_values = explainer([str(x).lower()])
    local_plot = shap.plots.text(shap_values[0], display=False)

    res = ner_pipe(x)
    entity_colors = {
        'Severity': '#E63946',
        'Sign_symptom': '#2A9D8F',
        'Medication': '#457B9D',
        'Age': '#F4A261',
        'Sex': '#F4A261',
        'Diagnostic_procedure': '#9C6644',
        'Biological_structure': '#BDB2FF',
    }

    pieces = []
    prev_end = 0
    for entity in res:
        color = entity_colors.get(entity['entity_group'])
        if color is None:
            # Entity group without a configured colour: leave the text
            # unmarked instead of raising KeyError. ``prev_end`` is not
            # advanced, so the span is emitted with the next plain slice.
            continue
        start = entity['start']
        end = entity['end']
        word = entity['word'].replace("##", "")
        # Escape everything that originates from the input so that it cannot
        # inject markup into the rendered HTML.
        pieces.append(escape(x[prev_end:start]))
        pieces.append(
            f"<mark style='background-color:{color};'>{escape(word)}</mark>"
        )
        prev_end = end
    pieces.append(escape(x[prev_end:]))
    htext = "".join(pieces)

    probabilities = scores.numpy()
    label_scores = {
        "Severe Reaction": float(probabilities[1]),
        "Non-severe Reaction": float(probabilities[0]),
    }
    return label_scores, local_plot, htext

def main(prob1):
    """Gradio callback: lower-case the input and run ``adr_predict`` on it.

    Args:
        prob1: Value of the input textbox; it is converted with ``str()``.

    Returns:
        A tuple holding the first three items returned by ``adr_predict``.
    """
    outputs = adr_predict(str(prob1).lower())
    return tuple(outputs[position] for position in range(3))

# Page heading; it is embedded in a Markdown "##" header when the UI is built.
title = "Welcome to **ADR Tracker**"
# Introductory paragraph shown under the heading, including the usage disclaimer.
description1 = "This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reaction to medications. Please do NOT use for medical diagnosis."

# Custom stylesheet passed to gr.Blocks(css=...).
# NOTE(review): no component in the visible layout sets elem_classes, so
# confirm that the .textbox / .button selectors match Gradio's generated markup.
css = """
body { font-family: 'Roboto', sans-serif; background-color: #333; color: #87CEEB; }
h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #87CEEB; }
.textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }
.button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }
"""

# Assemble the Gradio page: heading, input box, result widgets, colour
# legend, and clickable examples, all wired to ``main``.
with gr.Blocks(css=css) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here...")
    submit_btn = gr.Button("Analyze")

    # Results: predicted label on the left, SHAP plot and NER markup on the right.
    with gr.Row():
        with gr.Column(visible=True):
            label = gr.Label(label="Predicted Label")
        with gr.Column(visible=True):
            local_plot = gr.HTML(label='Shap:')
            htext = gr.HTML(label="NER")

    # Components are rendered where they are constructed, and layout
    # containers receive their children through the ``with`` context, not as
    # a positional list. The legend is therefore created inside its row
    # rather than being passed to ``gr.Row([...])`` afterwards.
    with gr.Row():
        legend = gr.HTML(value="<div style='margin-top: 20px;'><strong>Legend:</strong><br>" +
                                "<mark style='background-color:#E63946;'>Severity</mark> " +
                                "<mark style='background-color:#2A9D8F;'>Sign/Symptom</mark> " +
                                "<mark style='background-color:#457B9D;'>Medication</mark> " +
                                "<mark style='background-color:#F4A261;'>Age/Sex</mark> " +
                                "<mark style='background-color:#9C6644;'>Diagnostic Procedure</mark> " +
                                "<mark style='background-color:#BDB2FF;'>Biological Structure</mark></div>")

    # Clicking "Analyze" runs ``main`` on the textbox and fills the three outputs.
    submit_btn.click(
        main,
        [prob1],
        [label, local_plot, htext],
        api_name="adr"
    )

    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
        gr.Examples(
            examples=[
                ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
                ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
            ],
            inputs=[prob1],
            outputs=[label, local_plot, htext],
            fn=main,
            cache_examples=True,
        )

demo.launch()