import streamlit as st import gradio as gr import shap import numpy as np import scipy as sp import torch import transformers from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification import matplotlib.pyplot as plt import sys import csv csv.field_size_limit(sys.maxsize) device = "cuda:0" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024") model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024") pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True) explainer = shap.Explainer(pred) ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all") ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all") ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple") def adr_predict(x): encoded_input = tokenizer(x, return_tensors='pt') output = model(**encoded_input) scores = output[0][0].detach() scores = torch.nn.functional.softmax(scores) shap_values = explainer([str(x).lower()]) local_plot = shap.plots.text(shap_values[0], display=False) res = ner_pipe(x) entity_colors = { 'Severity': '#E63946', 'Sign_symptom': '#2A9D8F', 'Medication': '#457B9D', 'Age': '#F4A261', 'Sex': '#F4A261', 'Diagnostic_procedure': '#9C6644', 'Biological_structure': '#BDB2FF', } htext = "" prev_end = 0 for entity in res: start = entity['start'] end = entity['end'] word = entity['word'].replace("##", "") color = entity_colors[entity['entity_group']] htext += f"{x[prev_end:start]}{word}" prev_end = end htext += x[prev_end:] return {"Severe Reaction": float(scores.numpy()[1]), "Non-severe Reaction": float(scores.numpy()[0])}, local_plot, htext def main(prob1): text = str(prob1).lower() obj = adr_predict(text) return obj[0], obj[1], obj[2] title = "Welcome to **ADR Tracker**" description1 = "This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reaction to medications. Please do NOT use for medical diagnosis." css = """ body { font-family: 'Roboto', sans-serif; background-color: #333; color: #87CEEB; } h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #87CEEB; } .textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; } .button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; } """ with gr.Blocks(css=css) as demo: gr.Markdown(f"## {title}") gr.Markdown(description1) gr.Markdown("---") prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here...") submit_btn = gr.Button("Analyze") with gr.Row(): with gr.Column(visible=True): label = gr.Label(label="Predicted Label") with gr.Column(visible=True): local_plot = gr.HTML(label='Shap:') htext = gr.HTML(label="NER") legend = gr.HTML(value="