# Hugging Face Spaces app (Space status when captured: Running)
import streamlit as st | |
import gradio as gr | |
import shap | |
import numpy as np | |
import scipy as sp | |
import torch | |
import transformers | |
from transformers import pipeline | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification | |
import matplotlib.pyplot as plt | |
import sys | |
import csv | |
# Raise the csv module's per-field size cap to sys.maxsize.
csv.field_size_limit(sys.maxsize)

# Name of the preferred torch device. It is computed here but not referenced
# by any other code visible in this file.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Hub ids of the two checkpoints this app loads.
_ADR_CHECKPOINT = "jschwaller/ADRv2024"
_NER_CHECKPOINT = "d4data/biomedical-ner-all"

# --- Adverse-reaction classifier -------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(_ADR_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_ADR_CHECKPOINT)

# Pipeline wrapper around the classifier; this is the callable handed to SHAP.
pred = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)
explainer = shap.Explainer(pred)

# --- Biomedical named-entity recognition -----------------------------------
ner_tokenizer = AutoTokenizer.from_pretrained(_NER_CHECKPOINT)
ner_model = AutoModelForTokenClassification.from_pretrained(_NER_CHECKPOINT)
ner_pipe = transformers.pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)
def adr_predict(x):
    """Classify *x*, explain the prediction, and highlight its entities.

    Args:
        x: Input text to analyse.

    Returns:
        A 3-tuple ``(scores, local_plot, htext)``:

        * ``scores`` -- dict with keys ``"Severe Reaction"`` (softmax
          probability at class index 1) and ``"Non-severe Reaction"``
          (class index 0).
        * ``local_plot`` -- whatever ``shap.plots.text(..., display=False)``
          returns for the first (only) explained sample.
        * ``htext`` -- HTML string of *x* in which every span reported by the
          NER pipeline is wrapped in a coloured ``<mark>`` element.
    """
    import html  # stdlib; imported locally so this block is self-contained

    # --- Classification ----------------------------------------------------
    encoded_input = tokenizer(x, return_tensors='pt')
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(**encoded_input)
    # output[0] holds the logits; [0] selects the single sample in the batch.
    # `dim` is given explicitly because the implicit-dim form is deprecated.
    scores = torch.nn.functional.softmax(output[0][0], dim=-1)

    # --- SHAP explanation --------------------------------------------------
    shap_values = explainer([str(x).lower()])
    local_plot = shap.plots.text(shap_values[0], display=False)

    # --- NER highlighting --------------------------------------------------
    res = ner_pipe(x)
    entity_colors = {
        'Severity': 'red',
        'Sign_symptom': 'green',
        'Medication': 'lightblue',
        'Age': 'yellow',
        'Sex': 'yellow',
        'Diagnostic_procedure': 'gray',
        'Biological_structure': 'silver',
    }
    # The NER model can emit entity groups that are not listed above; those
    # get a neutral colour instead of raising KeyError.
    fallback_color = 'khaki'

    pieces = []
    prev_end = 0
    for entity in res:
        start = entity['start']
        end = entity['end']
        color = entity_colors.get(entity['entity_group'], fallback_color)
        # Text between the previous entity and this one, escaped because the
        # result is rendered as raw HTML.
        pieces.append(html.escape(x[prev_end:start]))
        # Highlight the exact source span rather than the pipeline's
        # re-joined `word`, so the output always mirrors the input text.
        pieces.append(
            f"<mark style='background-color:{color};'>"
            f"{html.escape(x[start:end])}</mark>"
        )
        prev_end = end
    pieces.append(html.escape(x[prev_end:]))
    htext = "".join(pieces)

    return (
        {
            "Severe Reaction": float(scores[1]),
            "Non-severe Reaction": float(scores[0]),
        },
        local_plot,
        htext,
    )
def main(prob1):
    """Gradio callback: lower-case the input and run ``adr_predict`` on it.

    Args:
        prob1: Value of the input textbox; it is coerced with ``str`` first.

    Returns:
        The first three items of ``adr_predict``'s result, as a tuple.
    """
    normalized = str(prob1).lower()
    scores, shap_html, ner_html = adr_predict(normalized)
    return scores, shap_html, ner_html
# Heading text; the UI renders it inside a level-2 Markdown header.
title = "Welcome to **ADR Tracker**"

# Introductory paragraph shown under the heading.
description1 = (
    "This app takes text (up to a few sentences) and predicts to what extent "
    "the text describes severe (or non-severe) adverse reaction to medications. "
    "Please do NOT use for medical diagnosis."
)

# Custom stylesheet passed to gr.Blocks(css=...).
css = """
body { font-family: 'Roboto', sans-serif; background-color: #333; color: #87CEEB; }
h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #87CEEB; } /* Ensuring that all text elements are consistently light blue */
.textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }
.button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }
"""
# NOTE(review): the Row/Column nesting below was reconstructed from a copy of
# this file whose indentation had been stripped -- confirm it against the
# deployed layout.
with gr.Blocks(css=css) as demo:
    # Page header.
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")

    # Input area.
    prob1 = gr.Textbox(
        label="Enter Your Text Here:",
        lines=2,
        placeholder="Type it here...",
    )
    submit_btn = gr.Button("Analyze")

    # Output area: predicted label on one side, SHAP and NER HTML on the other.
    with gr.Row():
        with gr.Column(visible=True):
            label = gr.Label(label="Predicted Label")
        with gr.Column(visible=True):
            local_plot = gr.HTML(label='Shap:')
            htext = gr.HTML(label="NER")

    # Clicking "Analyze" runs `main` on the textbox and fills the three outputs.
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        api_name="adr",
    )

    # Clickable sample inputs; their outputs are cached via cache_examples.
    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
        example_inputs = [
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
        ]
        gr.Examples(
            examples=example_inputs,
            inputs=[prob1],
            outputs=[label, local_plot, htext],
            fn=main,
            cache_examples=True,
        )

demo.launch()