File size: 2,269 Bytes
03ff95e
1cedc15
03ff95e
 
1cedc15
03ff95e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495bb87
03ff95e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import gradio as gr
from gradio import FlaggingCallback
from gradio.components import IOComponent

from transformers import pipeline

from typing import List, Optional, Any

import argilla as rg

# Token-classification ("ner") pipeline for the deprem-ml NER model; create_record
# uses it (and its tokenizer) to rebuild predictions for flagged inputs.
nlp = pipeline(task="ner", model="deprem-ml/deprem-ner")

# Example inputs passed to the Gradio interface (one inner list per example).
examples = [
    [
        "Lütfen yardım Akevler mahallesi Rüzgar sokak Tuncay apartmanı zemin kat Antakya akrabalarım gâçük altında #hatay #Afad"
    ],
]

def create_record(input_text):
    """Run the NER pipeline on ``input_text`` and wrap the result for Argilla.

    Args:
        input_text: Raw text to tag with the ``nlp`` pipeline.

    Returns:
        An ``rg.TokenClassificationRecord`` holding the text, its word tokens
        (as split by the model's tokenizer) and the predicted entity spans as
        ``(label, start_char, end_char, score)`` tuples.
    """
    # Word-level entity groups, using transformers' "first" aggregation strategy.
    predictions = nlp(input_text, aggregation_strategy="first")

    # (entity, start_char, end_char, score) tuples. The pipeline's scores may be
    # NumPy scalars; cast to a built-in float so the record serializes cleanly.
    prediction = [
        (pred["entity_group"], pred["start"], pred["end"], float(pred["score"]))
        for pred in predictions
    ]

    # Argilla needs the text as a list of tokens. Reuse the model tokenizer's
    # notion of "words" so token boundaries line up with the predicted spans.
    batch_encoding = nlp.tokenizer(input_text)
    # word_ids() maps each sub-token to its word index; None marks special tokens.
    word_ids = sorted(set(batch_encoding.word_ids()) - {None})
    char_spans = (batch_encoding.word_to_chars(word_id) for word_id in word_ids)
    words = [input_text[span.start:span.end] for span in char_spans]

    record = rg.TokenClassificationRecord(
        text=input_text,
        tokens=words,
        prediction=prediction,
        prediction_agent="deprem-ml/deprem-ner",
    )
    # Echo the record to stdout before returning it.
    print(record)
    return record

class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that logs flagged inputs to an Argilla dataset.

    Each flag re-runs the NER pipeline on the flagged text (via
    ``create_record``) and logs the resulting token-classification record,
    tagged with the flag option the user chose, to ``dataset_name``.
    """

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: Base URL of the Argilla server.
            api_key: API key used to authenticate against the server.
            dataset_name: Name of the Argilla dataset flagged records go to.
        """
        super().__init__()
        # Sets up the argilla client connection that rg.log uses in flag().
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Running total of samples flagged through this callback.
        self._num_flagged = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """Do nothing: records are sent to Argilla, not written to ``flagging_dir``."""
        pass

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to Argilla.

        Args:
            flag_data: Component values for the flagged sample; ``flag_data[0]``
                is the input text. Later entries (presumably the interface
                outputs — TODO confirm) are not used.
            flag_option: Flag button the user clicked (e.g. "Correct"); stored
                in the record's metadata under ``"flag_option"`` when given.
            flag_index: Unused.
            username: Unused.

        Returns:
            Total number of samples flagged through this callback so far.
        """
        text = flag_data[0]
        record = create_record(text)
        if flag_option is not None:
            # Previously the user's verdict was dropped; keep it so flagged
            # records can be filtered by option in Argilla.
            record.metadata = {**(record.metadata or {}), "flag_option": flag_option}
        rg.log(name=self.dataset_name, records=record)
        # Honour the declared ``-> int`` contract instead of returning None.
        self._num_flagged += 1
        return self._num_flagged


        
# Build the demo from the Hub model and route manual flags to Argilla.
# Connection settings are read from the environment (e.g. Space secrets) so the
# API key need not live in source; the previous hard-coded values remain the
# fallback for backward compatibility.
demo = gr.Interface.load(
    "models/deprem-ml/deprem-ner",
    examples=examples,
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url=os.environ.get(
            "ARGILLA_API_URL", "https://dvilasuero-argilla-template-1-3.hf.space"
        ),
        api_key=os.environ.get("ARGILLA_API_KEY", "team.apikey"),
        dataset_name="ner-flags",
    ),
    flagging_options=["Correct", "Incorrect", "Ambiguous"],
)
demo.launch()