Spaces:
Runtime error
Runtime error
File size: 2,128 Bytes
a306ba7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import gradio as gr
import pandas as pd
import json
from collections import defaultdict
# Create tokenizer for biomed model
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
# Token-classification pipeline built once at import time and reused for every
# request. With the "simple" aggregation strategy, each result dict carries an
# "entity_group" key (consumed by group_by_entity and ner below).
pipe = pipeline(
    task="ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)
# Matplotlib for entity graph
import matplotlib.pyplot as plt
# Use the non-interactive Agg backend: figures are only rasterised for the web
# UI, so no GUI toolkit/window is needed.
plt.switch_backend(newbackend="Agg")
# Load examples from JSON: an iterable of records with "text" and "label" keys.
# EXAMPLES maps each example note text to its label. The texts populate the
# UI's example list, and ner() looks the label up for the "Rating" output.
# Read explicitly as UTF-8 (the JSON encoding) so the result does not depend on
# the platform's default locale encoding.
with open("examples.json", "r", encoding="utf-8") as f:
    example_json = json.load(f)
EXAMPLES = {x["text"]: x["label"] for x in example_json}
def group_by_entity(raw):
    """Tally aggregated NER results by their ``entity_group``.

    Args:
        raw: iterable of pipeline result dicts, each with an ``"entity_group"`` key.

    Returns:
        A ``defaultdict(int)`` mapping group name -> number of entities, with
        groups in first-seen order.
    """
    tally = defaultdict(int)
    groups = (entity["entity_group"] for entity in raw)
    for group in groups:
        tally[group] += 1
    return tally
def plot_to_figure(grouped):
    """Draw a bar chart of entity counts per entity group.

    Args:
        grouped: mapping of entity-group name -> count (as produced by
            ``group_by_entity``).

    Returns:
        The matplotlib ``Figure`` holding the chart. The figure has already been
        released from pyplot's figure manager, so repeated calls do not
        accumulate open figures. It can still be rendered/saved by the caller.
    """
    # Use the object-oriented API on this figure's own Axes instead of pyplot's
    # implicit "current figure" state.
    fig, ax = plt.subplots()
    ax.bar(x=list(grouped.keys()), height=list(grouped.values()))
    ax.margins(0.2)
    # Leave room beneath the axes for the vertically rotated group names.
    fig.subplots_adjust(bottom=0.4)
    ax.tick_params(axis="x", labelrotation=90)
    # pyplot keeps a reference to every figure it creates until it is closed.
    # Called once per request in a long-running server, that leaks memory.
    # Closing only unregisters the figure; savefig on it still works.
    plt.close(fig)
    return fig
def ner(text):
    """Run biomedical NER on *text* and build the four UI outputs.

    Args:
        text: note text from the input textbox.

    Returns:
        A 4-tuple in the order of the Interface outputs:

        - ner_content: ``{"text": text, "entities": [...]}`` for the
          HighlightedText component, one entry per aggregated entity with its
          group, word, score and ``start``/``end`` offsets.
        - meta: summary dict with per-group counts (``entity_counts``), the
          number of distinct groups (``entities``) and the total number of
          entities (``counts``).
        - label: the stored label when *text* is one of the bundled examples,
          otherwise ``"Unknown"``.
        - figure: bar chart of the per-group counts.
    """
    raw = pipe(text)
    ner_content = {
        "text": text,
        "entities": [
            {
                "entity": x["entity_group"],
                "word": x["word"],
                # The pipeline may return numpy scalar scores; cast to a
                # builtin float so the payload is JSON-serializable.
                "score": float(x["score"]),
                "start": x["start"],
                "end": x["end"],
            }
            for x in raw
        ],
    }
    grouped = group_by_entity(raw)
    figure = plot_to_figure(grouped)
    label = EXAMPLES.get(text, "Unknown")
    meta = {
        # Plain dict so the JSON output does not depend on defaultdict handling.
        "entity_counts": dict(grouped),
        # Mapping keys are already unique; no set() needed.
        "entities": len(grouped),
        "counts": sum(grouped.values()),
    }
    return ner_content, meta, label, figure
# Gradio renamed Interface's `allow_flagging` argument to `flagging_mode`, and
# newer releases no longer accept the old keyword. Use whichever name the
# installed version declares so flagging stays disabled either way.
import inspect

_FLAGGING_KWARG = (
    "flagging_mode"
    if "flagging_mode" in inspect.signature(gr.Interface.__init__).parameters
    else "allow_flagging"
)

interface = gr.Interface(
    ner,
    inputs=gr.Textbox(label="Note text", value=""),
    # Output order must match the tuple returned by ner().
    outputs=[
        gr.HighlightedText(label="NER", combine_adjacent=True),
        gr.JSON(label="Entity Counts"),
        gr.Label(label="Rating"),
        gr.Plot(label="Bar"),
    ],
    examples=list(EXAMPLES.keys()),
    **{_FLAGGING_KWARG: "never"},
)

# Only start the server when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    interface.launch()