Spaces:
Sleeping
Sleeping
File size: 4,821 Bytes
acc1c62 6598a12 d18aa6e acc1c62 9513354 e8136b2 acc1c62 6598a12 a338a0a acc1c62 9c9bf1a 3ab73a4 6598a12 acc1c62 16a2843 acc1c62 b932aaf 9f1b256 acc1c62 9f1b256 acc1c62 d18aa6e acc1c62 16a2843 acc1c62 d18aa6e f0b4ad9 d136104 f0b4ad9 d136104 f0b4ad9 d136104 f0b4ad9 acc1c62 b0a2f69 d18aa6e acc1c62 d18aa6e 646850d b6c1514 d18aa6e acc1c62 d18aa6e acc1c62 d18aa6e acc1c62 6598a12 f0b4ad9 acc1c62 548049b 3034a2d d18aa6e a76acd6 f0b4ad9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import streamlit as st
import gradio as gr
import shap
import numpy as np
import scipy as sp
import torch
import transformers
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
import matplotlib.pyplot as plt
import sys
import csv
# Raise the csv module's per-field size limit as far as the platform allows.
# csv stores the limit in a C long; on platforms where that type is 32-bit
# (e.g. Windows) sys.maxsize raises OverflowError, so back off until accepted.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 10

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Severity classifier. It is left on the default (CPU) device because
# adr_predict calls `model` directly with CPU tensors from `tokenizer`.
tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024")
model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024")
# Pipeline wrapper around the same model, used as the SHAP explainer's
# prediction function; top_k=None asks the pipeline for every label's score.
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
)
explainer = shap.Explainer(pred)

# Biomedical NER model. aggregation_strategy="simple" groups sub-word tokens
# into entity spans. The pipeline is placed on `device` so it uses the GPU
# when one is available.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=device,
)

# Highlight colour per NER entity group; these are also the groups listed in
# the on-page legend.
entity_colors = {
    'Severity': '#E63946',  # a vivid red
    'Sign_symptom': '#2A9D8F',  # a deep teal
    'Medication': '#457B9D',  # a dusky blue
    'Age': '#F4A261',  # a sandy orange
    'Sex': '#F4A261',  # same sandy orange for consistency with 'Age'
    'Diagnostic_procedure': '#9C6644',  # a brown
    'Biological_structure': '#BDB2FF',  # a light pastel purple
}
# Highlight colour for NER entity groups that have no entry in entity_colors.
_DEFAULT_ENTITY_COLOR = "#CCCCCC"


def adr_predict(x):
    """Score *x* for adverse-reaction severity and build the explanation views.

    Args:
        x: Input text (``main`` passes it already lower-cased).

    Returns:
        A 3-tuple ``(scores, local_plot, htext)``:

        * ``scores`` -- dict with the softmax probabilities of the classifier's
          output index 1 ("Severe Reaction") and index 0 ("Non-severe Reaction");
        * ``local_plot`` -- HTML produced by ``shap.plots.text`` for the SHAP
          values of the lower-cased text;
        * ``htext`` -- HTML-escaped copy of *x* with every NER entity wrapped
          in a coloured ``<mark>``.
    """
    scores = _severity_scores(x)
    shap_values = explainer([str(x).lower()])
    local_plot = shap.plots.text(shap_values[0], display=False)
    htext = _highlight_entities(x, ner_pipe(x))
    return scores, local_plot, htext


def _severity_scores(text):
    """Return the classifier's softmax probabilities keyed by display label."""
    encoded_input = tokenizer(text, return_tensors="pt")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoded_input)[0][0]
    # Explicit dim: calling softmax without one is deprecated.
    probs = torch.nn.functional.softmax(logits, dim=-1).numpy()
    # NOTE(review): assumes output index 1 is the "severe" class in the
    # checkpoint's id2label mapping -- confirm against the model config.
    return {"Severe Reaction": float(probs[1]), "Non-severe Reaction": float(probs[0])}


def _highlight_entities(text, entities):
    """Return *text* as HTML with every entity span wrapped in a coloured <mark>.

    *entities* is the ``ner_pipe`` output: dicts carrying ``start``/``end``
    character offsets into *text* and an ``entity_group`` name.
    """
    import html

    pieces = []
    prev_end = 0
    for entity in entities:
        start, end = entity["start"], entity["end"]
        # The NER model emits more groups than entity_colors lists; fall back
        # to a neutral colour instead of raising KeyError.
        color = entity_colors.get(entity["entity_group"], _DEFAULT_ENTITY_COLOR)
        # Escape user text before inserting it into HTML, and show the span's
        # original characters rather than the pipeline's re-decoded "word".
        pieces.append(html.escape(text[prev_end:start]))
        pieces.append(
            f"<mark style='background-color:{color};'>{html.escape(text[start:end])}</mark>"
        )
        prev_end = end
    pieces.append(html.escape(text[prev_end:]))
    return "".join(pieces)
def main(prob1):
    """Gradio callback: lower-case the raw input and run ``adr_predict`` on it.

    Returns the first three ``adr_predict`` outputs -- label scores, SHAP HTML
    and NER HTML -- in the order the output components are wired.
    """
    lowered = str(prob1).lower()
    label_scores, shap_html, ner_html = adr_predict(lowered)[:3]
    return label_scores, shap_html, ner_html
# Static HTML legend: one coloured swatch per entry in entity_colors, with
# each swatch showing its entity-group name.
legend_html = (
    """
<div style='margin-top: 20px; color: white;'> <!-- Ensure the legend text is white for visibility -->
<h3>NER Legend</h3>
<ul style='list-style-type:none; padding-left: 0;'> <!-- Remove padding from the list -->
"""
    + "".join(
        f"<li><span style='color: white; background-color: {swatch}; padding: 5px 10px; margin-right: 5px; border-radius: 5px;'>{group}</span></li>"
        for group, swatch in entity_colors.items()
    )
    + "</ul></div>"
)
# Gradio HTML component holding the legend. It is created outside any
# gr.Blocks context, so a layout must call .render() on it to display it.
ner_legend = gr.HTML(value=legend_html)
# Page heading (rendered as Markdown) and short usage disclaimer.
title = "Welcome to **ADR Tracker**"
description1 = (
    "This app takes text (up to a few sentences) and predicts to what extent "
    "the text describes severe (or non-severe) adverse reaction to medications. "
    "Please do NOT use for medical diagnosis."
)
# Custom CSS passed to gr.Blocks, one rule per line.
# NOTE(review): the embedded CSS comment says "light blue", but #163E64 is a
# dark navy blue -- and it sits on a #333 body background. Confirm the intent.
css = "\n".join([
    "",
    "body { font-family: 'Roboto', sans-serif; background-color: #333; color: #163E64; }",
    "h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #163E64; } /* Ensuring that all text elements are consistently light blue */",
    ".textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }",
    ".button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }",
    "",
])
# Page layout: input box + button, prediction outputs, NER legend, examples.
with gr.Blocks(css=css) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here...")
    submit_btn = gr.Button("Analyze")
    with gr.Row():
        with gr.Column(visible=True):
            label = gr.Label(label="Predicted Label")
        with gr.Column(visible=True):
            local_plot = gr.HTML(label='Shap:')
            htext = gr.HTML(label="NER")
    # main returns (label scores, SHAP HTML, NER HTML) in this output order.
    submit_btn.click(
        main,
        [prob1],
        [label, local_plot, htext],
        api_name="adr",
    )
    # Display the NER legend below the buttons. ner_legend was created outside
    # this Blocks context, so a bare reference to it does nothing; it must be
    # rendered explicitly to be attached to the layout.
    ner_legend.render()
    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
    # cache_examples=True runs `main` on these inputs ahead of time.
    gr.Examples(
        [["A 35 year-old female had suicidal ideation after taking Prednisone."],
         ["A 23 year-old male had minor nausea after taking Acetaminophen."]],
        [prob1],
        [label, local_plot, htext],
        main,
        cache_examples=True,
    )
demo.launch()
|