Spaces:
Runtime error
Runtime error
import os | |
import gradio as gr | |
from gradio import FlaggingCallback | |
from gradio.components import IOComponent | |
from transformers import pipeline | |
from typing import List, Optional, Any | |
import argilla as rg | |
import os | |
# Token-classification ("ner") pipeline for the deprem-ml/deprem-ner model;
# create_record() below runs it and reuses its tokenizer.
nlp = pipeline(task="ner", model="deprem-ml/deprem-ner")

# Example input shown in the Gradio UI: one row with a single text field.
_EXAMPLE_TWEET = "Lütfen yardım Akevler mahallesi Rüzgar sokak Tuncay apartmanı zemin kat Antakya akrabalarım gâçük altında #hatay #Afad"
examples = [[_EXAMPLE_TWEET]]
def create_record(input_text):
    """Run the NER pipeline on *input_text* and wrap the result for Argilla.

    The text is also split into word tokens with the pipeline's own tokenizer,
    so the token list handed to Argilla comes from the same tokenizer that
    produced the predictions.

    Args:
        input_text: Raw text to annotate.

    Returns:
        An ``rg.TokenClassificationRecord`` holding the text, its word tokens
        and the predicted entities as
        ``(entity_group, start_char, end_char, score)`` tuples. The record is
        also printed to stdout.
    """
    # Predictions are requested with aggregation_strategy="first"; each
    # returned entity exposes entity_group/start/end/score.
    entity_spans = []
    for entity in nlp(input_text, aggregation_strategy="first"):
        entity_spans.append(
            (entity["entity_group"], entity["start"], entity["end"], entity["score"])
        )

    # Recover the surface text of every word the tokenizer identified.
    # word_ids() yields None for tokens that belong to no word (presumably
    # special tokens), so those are skipped.
    encoding = nlp.tokenizer(input_text)
    tokens = []
    for word_index in sorted({wid for wid in encoding.word_ids() if wid is not None}):
        span = encoding.word_to_chars(word_index)
        tokens.append(input_text[span.start:span.end])

    record = rg.TokenClassificationRecord(
        text=input_text,
        tokens=tokens,
        prediction=entity_spans,
        prediction_agent="deprem-ml/deprem-ner",
    )
    print(record)
    return record
class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that sends flagged inputs to an Argilla dataset.

    Each flagged input text is re-run through ``create_record`` and the
    resulting record is logged with ``rg.log``. The flag button the user
    pressed (and the username, when Gradio supplies one) is kept in the
    record's metadata.
    """

    def __init__(self, api_url, api_key, dataset_name):
        """Connect to the Argilla server and remember the target dataset.

        Args:
            api_url: Base URL of the Argilla server.
            api_key: API key passed to ``rg.init``.
            dataset_name: Name of the Argilla dataset records are logged to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Running count of samples logged by this callback; returned by flag().
        self.num_flagged = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        # Nothing to prepare locally: records are sent straight to Argilla,
        # so neither the components nor the flagging directory are used.
        pass

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to Argilla.

        Args:
            flag_data: Component values at flag time; element 0 is taken as
                the input text.
            flag_option: Label of the flag button the user chose, if any.
                Stored as ``metadata["flag_option"]``.
            flag_index: Not used.
            username: Stored as ``metadata["username"]`` when provided.

        Returns:
            Total number of samples flagged through this logger so far.
        """
        text = flag_data[0]
        record = create_record(text)

        # Without this, the user's verdict (e.g. "Correct"/"Incorrect") would
        # never reach Argilla and the flag would carry no information.
        extra = {}
        if flag_option is not None:
            extra["flag_option"] = flag_option
        if username is not None:
            extra["username"] = username
        if extra:
            record.metadata = {**(record.metadata or {}), **extra}

        rg.log(name=self.dataset_name, records=record)
        self.num_flagged += 1
        return self.num_flagged
# Flagged samples go to the "ner-flags" dataset on the Argilla Space; the
# API key is read from the TEAM_API_KEY environment variable.
flag_logger = ArgillaLogger(
    api_url="https://merve-argilla.hf.space",
    api_key=os.getenv("TEAM_API_KEY"),
    dataset_name="ner-flags",
)

# Interface loaded from the hosted model, with manual flagging and three
# flagging options routed to flag_logger.
demo = gr.Interface.load(
    "models/deprem-ml/deprem-ner",
    examples=examples,
    allow_flagging="manual",
    flagging_callback=flag_logger,
    flagging_options=["Correct", "Incorrect", "Ambiguous"],
)
demo.launch()