# Spaces: Runtime error
import json
from collections import defaultdict

import gradio as gr
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification

# Hugging Face model id of the biomedical NER checkpoint; both the tokenizer
# and the token-classification model are loaded from it.
_MODEL_ID = "d4data/biomedical-ner-all"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForTokenClassification.from_pretrained(_MODEL_ID)

# aggregation_strategy="simple" makes the pipeline merge sub-word tokens into
# whole entity spans; each span carries the "entity_group" key read by ner().
pipe = pipeline(
    task="ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)
# Matplotlib draws the entity-count bar chart.
from matplotlib import pyplot as plt

# Select the non-interactive Agg backend: charts are handed to Gradio's
# gr.Plot component as images rather than shown in a desktop window.
plt.switch_backend(newbackend="Agg")
# Load the example notes shipped with the app. examples.json is expected to
# hold a list of objects, each with a "text" key (the note offered as a
# clickable example) and a "label" key (the value ner() reports for exactly
# that text).
# An explicit UTF-8 encoding keeps decoding independent of the host locale,
# since clinical notes may contain non-ASCII characters.
with open("examples.json", "r", encoding="utf-8") as f:
    example_json = json.load(f)

# Map example text -> label; built after the file is closed.
EXAMPLES = {x["text"]: x["label"] for x in example_json}
def group_by_entity(raw):
    """Count how many NER spans fall into each entity group.

    Args:
        raw: iterable of entity dicts as returned by the aggregated NER
            pipeline; only the ``"entity_group"`` key of each item is read.

    Returns:
        A ``defaultdict(int)`` mapping each entity-group name to its number
        of occurrences, keys in first-seen order. Empty input yields an
        empty mapping.
    """
    out = defaultdict(int)
    for ent in raw:
        out[ent["entity_group"]] += 1
    return out
def plot_to_figure(grouped):
    """Draw a bar chart of entity counts.

    Args:
        grouped: mapping of entity-group name -> count (as produced by
            ``group_by_entity``).

    Returns:
        A ``matplotlib.figure.Figure`` with one bar per key of *grouped*,
        suitable for a ``gr.Plot`` output.
    """
    # Build the Figure directly instead of through pyplot. pyplot keeps a
    # global reference to every figure it creates, so the previous
    # plt.figure() call leaked one figure per request because it was never
    # closed. Its implicit "current axes" state is also shared by any
    # handlers Gradio may run concurrently.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    ax.bar(x=list(grouped.keys()), height=list(grouped.values()))
    ax.margins(0.2)
    # Reserve space below the axes for the vertical category labels.
    fig.subplots_adjust(bottom=0.4)
    ax.tick_params(axis="x", labelrotation=90)
    return fig
def ner(text):
    """Run biomedical NER on *text* and build the four Gradio outputs.

    Args:
        text: free-text note entered in the UI or picked from the examples.

    Returns:
        A 4-tuple matching the Interface outputs, in order:
        - a ``{"text", "entities"}`` dict for ``gr.HighlightedText``;
        - a summary dict (per-group counts, number of distinct groups, total
          number of entity spans) for ``gr.JSON``;
        - the label stored for *text* in ``EXAMPLES``, or ``"Unknown"`` when
          *text* is not one of the examples, for ``gr.Label``;
        - a bar-chart figure of the per-group counts for ``gr.Plot``.
    """
    raw = pipe(text)
    ner_content = {
        "text": text,
        "entities": [
            {
                "entity": x["entity_group"],
                "word": x["word"],
                # float() turns a possible NumPy scalar from the pipeline into
                # a plain Python float so the payload stays JSON-serializable.
                "score": float(x["score"]),
                "start": x["start"],
                "end": x["end"],
            }
            for x in raw
        ],
    }
    grouped = group_by_entity(raw)
    figure = plot_to_figure(grouped)
    label = EXAMPLES.get(text, "Unknown")
    meta = {
        # Plain dict so the JSON output is not a live defaultdict.
        "entity_counts": dict(grouped),
        # Dict keys are already unique: the key count is the number of
        # distinct entity groups.
        "entities": len(grouped),
        "counts": sum(grouped.values()),
    }
    return (ner_content, meta, label, figure)
# Gradio UI: one free-text input; outputs in the same order as the tuple
# returned by ner() (highlighted entities, count summary, example label,
# bar chart).
interface = gr.Interface(
    ner,
    inputs=gr.Textbox(label="Note text", value=""),
    outputs=[
        gr.HighlightedText(label="NER", combine_adjacent=True),
        gr.JSON(label="Entity Counts"),
        gr.Label(label="Rating"),
        gr.Plot(label="Bar"),
    ],
    examples=list(EXAMPLES),
    # NOTE(review): newer Gradio releases replace `allow_flagging` with
    # `flagging_mode`; confirm this keyword against the pinned Gradio version.
    allow_flagging="never",
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or tooling) does not block on launching the app.
if __name__ == "__main__":
    interface.launch()