versaggi's picture
Upload 4 files
a306ba7
import gradio as gr
import pandas as pd
import json
from collections import defaultdict
# Create tokenizer for biomed model
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
# Hugging Face checkpoint shared by the tokenizer and the model.
_BIOMED_NER_CHECKPOINT = "d4data/biomedical-ner-all"
tokenizer = AutoTokenizer.from_pretrained(_BIOMED_NER_CHECKPOINT)
model = AutoModelForTokenClassification.from_pretrained(_BIOMED_NER_CHECKPOINT)
# aggregation_strategy="simple" makes the pipeline return grouped entity
# spans (dicts with an "entity_group" key), which the code below relies on.
pipe = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)
# Matplotlib draws the per-entity bar chart. Switch to the non-interactive
# Agg backend so figures are rendered off-screen (no GUI window is needed).
import matplotlib.pyplot as plt
plt.switch_backend(newbackend="Agg")
# Load examples from JSON: a list of records with "text" and "label" keys.
# EXAMPLES maps each example note's text to its label; ner() looks the
# submitted text up here to produce the "Rating" output.
# Decode explicitly as UTF-8 (the standard JSON encoding) instead of the
# platform-dependent locale default, so non-ASCII note text loads correctly
# on every OS.
# NOTE: the relative path is resolved against the current working directory.
with open("examples.json", "r", encoding="utf-8") as f:
    example_json = json.load(f)
EXAMPLES = {x["text"]: x["label"] for x in example_json}
def group_by_entity(raw):
    """Tally pipeline predictions by their ``"entity_group"`` value.

    Args:
        raw: Iterable of prediction dicts, each carrying an
            ``"entity_group"`` key.

    Returns:
        A ``defaultdict(int)`` mapping each group name to how many
        predictions belong to it, with keys in first-seen order.
    """
    tally = defaultdict(int)
    for group in (prediction["entity_group"] for prediction in raw):
        tally[group] += 1
    return tally
def plot_to_figure(grouped):
    """Draw a bar chart of entity counts and return the matplotlib figure.

    Args:
        grouped: Mapping of entity-group name -> count (as produced by
            ``group_by_entity``). Each key becomes one bar, in iteration
            order; an empty mapping yields an empty chart.

    Returns:
        A ``matplotlib.figure.Figure`` holding the chart, suitable for a
        ``gr.Plot`` output.
    """
    # Build the figure directly rather than through ``plt.figure()``.
    # pyplot keeps every figure it creates alive until ``plt.close`` is
    # called, so a long-running app would leak one figure per request.
    # pyplot's implicit "current figure" is also global state, which is
    # unsafe if requests are handled concurrently.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.add_subplot()
    ax.bar(list(grouped.keys()), list(grouped.values()))
    ax.margins(0.2)
    # Reserve space at the bottom for the vertically rotated category labels.
    fig.subplots_adjust(bottom=0.4)
    ax.tick_params(axis="x", labelrotation=90)
    return fig
def ner(text):
    """Run biomedical NER on *text* and build the four Gradio outputs.

    Returns a tuple ``(highlighted, summary, rating, chart)``:
      * highlighted: ``{"text": text, "entities": [...]}``, where each
        entity dict holds ``entity``, ``word``, ``score``, ``start`` and
        ``end`` copied from the pipeline output;
      * summary: per-group entity counts, the number of distinct groups,
        and the total number of entities;
      * rating: the label stored for *text* in ``EXAMPLES``, or
        ``"Unknown"`` when *text* is not one of the examples;
      * chart: the bar chart produced by ``plot_to_figure``.
    """
    predictions = pipe(text)

    def _as_highlight(pred):
        # Expose the aggregated "entity_group" under the "entity" key;
        # the remaining fields are passed through unchanged.
        return {
            "entity": pred["entity_group"],
            "word": pred["word"],
            "score": pred["score"],
            "start": pred["start"],
            "end": pred["end"],
        }

    highlighted = {
        "text": text,
        "entities": list(map(_as_highlight, predictions)),
    }
    counts = group_by_entity(predictions)
    chart = plot_to_figure(counts)
    rating = EXAMPLES.get(text, "Unknown")
    summary = {
        "entity_counts": counts,
        "entities": len(counts),  # dict keys are already distinct
        "counts": sum(counts.values()),
    }
    return highlighted, summary, rating, chart
# Wire ner() into a Gradio UI: one free-text input and four outputs, in the
# same order as ner()'s return tuple (highlighted entities, entity-count
# JSON, example label, bar chart). The example note texts are offered as
# clickable inputs.
interface = gr.Interface(
    fn=ner,
    inputs=gr.Textbox(value="", label="Note text"),
    outputs=[
        gr.HighlightedText(combine_adjacent=True, label="NER"),
        gr.JSON(label="Entity Counts"),
        gr.Label(label="Rating"),
        gr.Plot(label="Bar"),
    ],
    examples=[*EXAMPLES],
    allow_flagging="never",
)
interface.launch()