# Mod4_Team1 / app.py
# Last commit: "Update app.py" by jschwaller (6598a12, verified) - 4.88 kB
import csv
import html
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import shap
import streamlit as st
import torch
import transformers
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
# Lift the csv module's per-field size limit as far as the platform allows.
# Presumably needed so large outputs (SHAP / NER HTML) survive Gradio's
# CSV-backed example cache - confirm. ``sys.maxsize`` overflows the C ``long``
# the csv module uses on platforms where that type is 32-bit (e.g. 64-bit
# Windows), so fall back to the largest 32-bit value there.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)
# Use the first GPU when one is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fine-tuned ADR severity classifier. ``adr_predict`` feeds ``model`` CPU
# tensors directly, so the classifier is deliberately not moved to ``device``.
tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024")
model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024")

# Pipeline that scores every label for each input; SHAP explains this pipeline.
# ``top_k=None`` replaces the deprecated ``return_all_scores=True``.
pred = transformers.pipeline("text-classification", model=model,
                             tokenizer=tokenizer, top_k=None)
explainer = shap.Explainer(pred)

# Biomedical NER model; "simple" aggregation groups sub-word tokens into
# entity spans. It is only used through ``ner_pipe``, so it can run on ``device``.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer,
                    aggregation_strategy="simple", device=device)
# Highlight colour for each NER entity group shown in the UI legend. Entity
# groups that are not listed here are rendered as plain, unhighlighted text.
_ENTITY_COLORS = {
    'Severity': '#E63946',  # a vivid red
    'Sign_symptom': '#2A9D8F',  # a deep teal
    'Medication': '#457B9D',  # a dusky blue
    'Age': '#F4A261',  # a sandy orange
    'Sex': '#F4A261',  # same sandy orange for consistency with 'Age'
    'Diagnostic_procedure': '#9C6644',  # a brown
    'Biological_structure': '#BDB2FF',  # a light pastel purple
}


def _classify(text):
    """Return the ADR classifier's softmax probabilities for *text*.

    The result is a 1-D tensor indexed by label id; ``adr_predict`` reads
    index 1 as the severe class and index 0 as the non-severe class.
    """
    encoded_input = tokenizer(text, return_tensors='pt')
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        # First model output is the logits of the one-item batch; [0] drops the batch axis.
        logits = model(**encoded_input)[0][0]
    return torch.nn.functional.softmax(logits, dim=-1)


def _highlight_entities(text, entities):
    """Render *text* as HTML with recognised entities wrapped in coloured ``<mark>`` tags.

    Args:
        text: the string the NER pipeline was run on.
        entities: ``ner_pipe`` results - dicts with ``start``/``end`` character
            offsets into *text* and an ``entity_group`` label; assumed sorted
            by position and non-overlapping (TODO confirm if the aggregation
            strategy changes).

    All text taken from the input is HTML-escaped, so markup typed by the user
    is displayed literally instead of being rendered.
    """
    pieces = []
    prev_end = 0
    for entity in entities:
        color = _ENTITY_COLORS.get(entity['entity_group'])
        if color is None:
            # No legend colour for this group: leave the span inside the plain
            # gap emitted before the next highlighted entity.
            continue
        start, end = entity['start'], entity['end']
        pieces.append(html.escape(text[prev_end:start]))
        # Show the original text span rather than the tokenizer's decoded word.
        pieces.append(
            f"<mark style='background-color:{color};'>{html.escape(text[start:end])}</mark>"
        )
        prev_end = end
    pieces.append(html.escape(text[prev_end:]))
    return "".join(pieces)


def adr_predict(x):
    """Score *x* for adverse-drug-reaction severity and build its explanations.

    Args:
        x: input text (``main`` passes it already lower-cased).

    Returns:
        ``(label_scores, shap_html, ner_html)``:
        ``label_scores`` maps "Severe Reaction" / "Non-severe Reaction" to the
        classifier's softmax probabilities for labels 1 / 0; ``shap_html`` is
        the output of ``shap.plots.text(..., display=False)`` for the
        classification pipeline; ``ner_html`` is *x* (HTML-escaped) with
        biomedical entities highlighted.
    """
    probs = _classify(x).numpy()
    shap_values = explainer([str(x).lower()])
    local_plot = shap.plots.text(shap_values[0], display=False)
    htext = _highlight_entities(x, ner_pipe(x))
    return (
        {"Severe Reaction": float(probs[1]), "Non-severe Reaction": float(probs[0])},
        local_plot,
        htext,
    )
def main(prob1):
    """Gradio callback: lower-case the textbox contents and analyse them.

    Returns the ``(label scores, SHAP HTML, NER HTML)`` triple produced by
    ``adr_predict``, in the order of the UI's output components.
    """
    label_scores, shap_html, ner_html = adr_predict(str(prob1).lower())
    return label_scores, shap_html, ner_html
# Page heading (rendered by the UI as a Markdown H2).
title = "Welcome to **ADR Tracker**"
# Introductory text shown under the heading.
description1 = (
    "This app takes text (up to a few sentences) and predicts to what extent the text "
    "describes severe (or non-severe) adverse reactions to medications. "
    "Please do NOT use for medical diagnosis."
)
# Custom page styling passed to gr.Blocks(css=...).
# NOTE(review): the ``.textbox`` / ``.button`` rules only apply to components
# given those classes (e.g. via ``elem_classes``); none are in the layout as
# written - confirm whether they take effect.
css = """
body { font-family: 'Roboto', sans-serif; background-color: #333; color: #87CEEB; }
h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #87CEEB; } /* Ensuring that all text elements are consistently light blue */
.textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }
.button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }
"""
# Page layout: input box and button, prediction / SHAP / NER outputs, a colour
# legend for the NER highlights, and clickable examples.
with gr.Blocks(css=css) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here...")
    submit_btn = gr.Button("Analyze")
    with gr.Row():
        with gr.Column(visible=True):
            label = gr.Label(label="Predicted Label")
        with gr.Column(visible=True):
            local_plot = gr.HTML(label='Shap:')
            htext = gr.HTML(label="NER")
    # main returns (label scores, SHAP html, NER html) in this output order.
    submit_btn.click(
        main,
        [prob1],
        [label, local_plot, htext],
        api_name="adr"
    )
    gr.Markdown("### Legend")
    # Components attach to the innermost open layout context; gr.Row does not
    # accept child components as an argument, so create the legend inside it.
    # Keep these colours in sync with the entity colour map used by adr_predict.
    with gr.Row():
        legend = gr.HTML(value="<div style='margin-top: 20px;'><strong>Legend:</strong><br>" +
                         "<mark style='background-color:#E63946;'>Severity</mark> " +
                         "<mark style='background-color:#2A9D8F;'>Sign/Symptom</mark> " +
                         "<mark style='background-color:#457B9D;'>Medication</mark> " +
                         "<mark style='background-color:#F4A261;'>Age/Sex</mark> " +
                         "<mark style='background-color:#9C6644;'>Diagnostic Procedure</mark> " +
                         "<mark style='background-color:#BDB2FF;'>Biological Structure</mark></div>")
    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
    # cache_examples=True precomputes the example outputs by running main.
    gr.Examples([["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
                 ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]],
                [prob1], [label, local_plot, htext], main, cache_examples=True)