Spaces:
Runtime error
Runtime error
File size: 2,276 Bytes
03ff95e 1cedc15 03ff95e 1cedc15 03ff95e 6bd075e 03ff95e aa8112b e6f3be0 03ff95e 495bb87 03ff95e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os
import gradio as gr
from gradio import FlaggingCallback
from gradio.components import IOComponent
from transformers import pipeline
from typing import List, Optional, Any
import argilla as rg
import os
# Named-entity-recognition pipeline backed by the deprem-ml/deprem-ner model.
nlp = pipeline(task="ner", model="deprem-ml/deprem-ner")

# Sample inputs offered in the Gradio UI; each inner list supplies one value
# per input component.
examples = [
    [
        "Lütfen yardım Akevler mahallesi Rüzgar sokak Tuncay apartmanı zemin kat Antakya akrabalarım gâçük altında #hatay #Afad"
    ],
]
def create_record(input_text):
    """Run the NER pipeline on *input_text* and wrap the output as an Argilla record.

    Args:
        input_text: Raw text to tag.

    Returns:
        An ``rg.TokenClassificationRecord`` whose ``tokens`` are the word
        pieces of *input_text* as split by the pipeline's tokenizer, and whose
        ``prediction`` holds one ``(entity_group, start, end, score)`` tuple per
        entity returned by the pipeline. The record is also printed.
    """
    # Entities aggregated to word level using the "first" strategy.
    entities = nlp(input_text, aggregation_strategy="first")
    predicted_spans = [
        (entity["entity_group"], entity["start"], entity["end"], entity["score"])
        for entity in entities
    ]

    # Rebuild word tokens from the tokenizer's word alignment, dropping
    # positions that map to no word (word id None, e.g. special tokens).
    encoding = nlp.tokenizer(input_text)
    word_indices = sorted({idx for idx in encoding.word_ids() if idx is not None})
    char_spans = [encoding.word_to_chars(idx) for idx in word_indices]
    tokens = [input_text[span.start:span.end] for span in char_spans]

    record = rg.TokenClassificationRecord(
        text=input_text,
        tokens=tokens,
        prediction=predicted_spans,
        prediction_agent="deprem-ml/deprem-ner",
    )
    print(record)
    return record
class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that logs flagged inputs to an Argilla dataset.

    Each flagged input text is re-run through ``create_record`` and the
    resulting record is logged to the dataset ``dataset_name``.
    """

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: Base URL of the Argilla server, passed to ``rg.init``.
            api_key: API key passed to ``rg.init``.
            dataset_name: Name of the Argilla dataset records are logged to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Running count of flagged samples, returned from ``flag`` as the
        # ``-> int`` signature promises.
        self._num_flagged = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """Do nothing: records go to Argilla, so no local flagging dir is used."""

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log the flagged input text to Argilla.

        Args:
            flag_data: Component values at flag time; only ``flag_data[0]``
                (the input text) is used.
            flag_option: The flagging option the user picked, if any. It is
                stored in the record's metadata under ``"flag_option"``.
            flag_index: Unused.
            username: Unused.

        Returns:
            The total number of samples flagged through this logger so far.
        """
        text = flag_data[0]
        record = create_record(text)
        if flag_option is not None:
            # Previously the chosen option was discarded, which made the
            # configured flagging options meaningless in the logged dataset.
            record.metadata = {**(record.metadata or {}), "flag_option": flag_option}
        rg.log(name=self.dataset_name, records=record)
        self._num_flagged += 1
        return self._num_flagged
# Flags raised in the UI are forwarded to the "ner-flags" Argilla dataset.
argilla_logger = ArgillaLogger(
    api_url="https://merve-argilla.hf.space",
    api_key=os.getenv("TEAM_API_KEY"),
    dataset_name="ner-flags",
)

# Interface loaded from the hosted model, with manual flagging enabled.
demo = gr.Interface.load(
    "models/deprem-ml/deprem-ner",
    examples=examples,
    allow_flagging="manual",
    flagging_callback=argilla_logger,
    flagging_options=["Correct", "Incorrect", "Ambiguous"],
)
demo.launch()