File size: 4,879 Bytes
acc1c62
 
 
 
 
 
 
6598a12
d18aa6e
acc1c62
 
 
 
 
 
 
 
 
9513354
e8136b2
acc1c62
6598a12
 
 
acc1c62
 
 
9c9bf1a
 
3ab73a4
6598a12
 
acc1c62
 
 
 
b932aaf
 
9f1b256
acc1c62
9f1b256
acc1c62
 
 
6598a12
 
 
 
 
 
 
9c9bf1a
acc1c62
 
 
 
 
 
 
 
 
 
 
 
 
d18aa6e
acc1c62
 
 
 
d18aa6e
acc1c62
b0a2f69
d18aa6e
acc1c62
d18aa6e
b6c1514
6598a12
b6c1514
 
d18aa6e
 
 
acc1c62
 
d18aa6e
 
acc1c62
 
 
d18aa6e
 
 
 
acc1c62
6598a12
 
 
 
 
 
 
9f1b256
6598a12
9f1b256
 
 
 
 
 
 
 
 
acc1c62
 
 
 
d18aa6e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
import gradio as gr
import shap
import numpy as np
import scipy as sp
import torch
import transformers
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification

import matplotlib.pyplot as plt
import sys
import csv

# Raise the csv module's per-field size limit to sys.maxsize.
csv.field_size_limit(sys.maxsize)

# Select the first CUDA device when one is available, otherwise the CPU.
# NOTE(review): `device` is not referenced again in the visible code, so
# nothing below is explicitly moved onto it — confirm whether that is intended.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Tokenizer and sequence-classification model used for the ADR severity scores.
tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024")  
model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024")

# Build a pipeline object for predictions
# (return_all_scores=True asks for a score per label rather than only the top one).
pred = transformers.pipeline("text-classification", model=model, 
                             tokenizer=tokenizer, return_all_scores=True)

# SHAP explainer wrapped around the classification pipeline; used to produce
# the per-token attribution plot.
explainer = shap.Explainer(pred)

# Tokenizer and token-classification model used for biomedical entity tagging.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")

# NER pipeline; aggregation_strategy="simple" groups sub-word tokens into
# entity spans that carry 'entity_group', 'start' and 'end'.
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple") # pass device=0 if using gpu
#

# Highlight colour for each NER entity group that the UI knows how to display.
_ENTITY_COLORS = {
    'Severity': '#E63946',  # a vivid red
    'Sign_symptom': '#2A9D8F',  # a deep teal
    'Medication': '#457B9D',  # a dusky blue
    'Age': '#F4A261',  # a sandy orange
    'Sex': '#F4A261',  # same sandy orange for consistency with 'Age'
    'Diagnostic_procedure': '#9C6644',  # a brown
    'Biological_structure': '#BDB2FF',  # a light pastel purple
}


def _highlight_entities(text, entities):
    """Return *text* as HTML with the known entity spans wrapped in <mark>.

    Args:
        text: The string the entities were extracted from.
        entities: Iterable of dicts with 'entity_group', 'start' and 'end'
            keys, where 'start'/'end' are character offsets into *text*.

    Returns:
        An HTML string. All text taken from *text* is HTML-escaped; spans
        whose group has an entry in ``_ENTITY_COLORS`` are wrapped in a
        coloured ``<mark>`` element, all other text is emitted unmarked.
    """
    import html  # function-scope import keeps this block self-contained

    pieces = []
    cursor = 0
    for entity in entities:
        color = _ENTITY_COLORS.get(entity['entity_group'])
        if color is None:
            # The NER model may emit groups that have no colour assigned;
            # leave those spans as plain text instead of raising KeyError.
            continue
        start, end = entity['start'], entity['end']
        # Escape everything that originates from the user before it is
        # placed into the HTML output.
        pieces.append(html.escape(text[cursor:start]))
        # Slice the source text rather than using entity['word'], so the
        # highlighted span shows exactly what the user typed.
        pieces.append(
            f"<mark style='background-color:{color};'>"
            f"{html.escape(text[start:end])}</mark>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(x):
    """Score a text for adverse-reaction severity and build its explanations.

    Args:
        x: The input text (a string).

    Returns:
        A 3-tuple ``(scores, local_plot, htext)``:
          * ``scores`` - dict with the softmax probability at class index 1
            under "Severe Reaction" and at class index 0 under
            "Non-severe Reaction".
          * ``local_plot`` - the value returned by
            ``shap.plots.text(..., display=False)`` for this text.
          * ``htext`` - the input text as HTML with recognised entities
            highlighted.
    """
    # truncation=True asks the tokenizer to cut over-long input down to the
    # model's maximum length for this forward pass.
    encoded_input = tokenizer(x, return_tensors='pt', truncation=True)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded_input)[0][0]
    # Normalise over the class dimension explicitly (implicit dim is deprecated).
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()

    shap_values = explainer([str(x).lower()])
    local_plot = shap.plots.text(shap_values[0], display=False)

    htext = _highlight_entities(x, ner_pipe(x))

    scores = {
        "Severe Reaction": float(probs[1]),
        "Non-severe Reaction": float(probs[0]),
    }
    return scores, local_plot, htext

def main(prob1):
    """Gradio callback: lower-case the input and run ``adr_predict`` on it.

    Args:
        prob1: The value of the input textbox; it is converted with ``str()``.

    Returns:
        The three values produced by ``adr_predict``, in the same order.
    """
    lowered = str(prob1).lower()
    scores, shap_html, ner_html = adr_predict(lowered)
    return scores, shap_html, ner_html

# Heading text for the page (contains Markdown bold markup).
title = "Welcome to **ADR Tracker**"
# Introductory paragraph shown to the user, including the medical disclaimer.
description1 = "This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reaction to medications. Please do NOT use for medical diagnosis."

# Custom stylesheet for the app: dark page background with light-blue text,
# plus rules for `.textbox` and `.button` class selectors.
# NOTE(review): no component in the visible layout assigns the `textbox` or
# `button` class names — confirm these selectors actually match something.
css = """
body { font-family: 'Roboto', sans-serif; background-color: #333; color: #87CEEB; }
h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #87CEEB; } /* Ensuring that all text elements are consistently light blue */
.textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }
.button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }
"""

# Assemble the Gradio UI: a text input, an "Analyze" button, three outputs
# (label scores, SHAP plot, NER-highlighted text), a colour legend and two
# clickable examples.
with gr.Blocks(css=css) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here...")
    submit_btn = gr.Button("Analyze")

    # Outputs side by side: predicted label on the left, explanations on the right.
    with gr.Row():
        with gr.Column(visible=True):
            label = gr.Label(label="Predicted Label")
        with gr.Column(visible=True):
            local_plot = gr.HTML(label='Shap:')
            htext = gr.HTML(label="NER")

    # Run `main` on click and expose the same call under the "adr" API name.
    submit_btn.click(
        main,
        [prob1],
        [label, local_plot, htext],
        api_name="adr"
    )

    gr.Markdown("### Legend")
    # Layout containers are used as context managers; components created
    # inside the `with` body become the row's children. (Passing a component
    # list to gr.Row(...) positionally is not part of its constructor API.)
    with gr.Row():
        legend = gr.HTML(value="<div style='margin-top: 20px;'><strong>Legend:</strong><br>" +
                                "<mark style='background-color:#E63946;'>Severity</mark> " +
                                "<mark style='background-color:#2A9D8F;'>Sign/Symptom</mark> " +
                                "<mark style='background-color:#457B9D;'>Medication</mark> " +
                                "<mark style='background-color:#F4A261;'>Age/Sex</mark> " +
                                "<mark style='background-color:#9C6644;'>Diagnostic Procedure</mark> " +
                                "<mark style='background-color:#BDB2FF;'>Biological Structure</mark></div>")

    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
        # cache_examples=True makes Gradio run `main` on the examples ahead
        # of time and store the outputs.
        gr.Examples(
            examples=[
                ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
                ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
            ],
            inputs=[prob1],
            outputs=[label, local_plot, htext],
            fn=main,
            cache_examples=True,
        )