"""Gradio demo for Setswana news classification using the dsfsi/PuoBERTa-News model."""
import csv

import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

MODEL_URL = "https://huggingface.co/dsfsi/PuoBERTa-News"
WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"

# Load the PuoBERTa-News tokenizer and classification model once at startup.
tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-News")
model = AutoModelForSequenceClassification.from_pretrained("dsfsi/PuoBERTa-News")

# Build the classification pipeline once so it is not rebuilt on every request.
classifier = pipeline("text-classification", tokenizer=tokenizer, model=model, return_all_scores=True)

# Map the model's English label names to Setswana category names for display.
categories = {
    "arts_culture_entertainment_and_media": "Botsweretshi, setso, boitapoloso le bobegakgang",
    "crime_law_and_justice": "Bosenyi, molao le bosiamisi",
    "disaster_accident_and_emergency_incident": "Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso",
    "economy_business_and_finance": "Ikonomi, tsa kgwebo le tsa ditšhelete",
    "education": "Thuto",
    "environment": "Tikologo",
    "health": "Boitekanelo",
    "politics": "Dipolotiki",
    "religion_and_belief": "Bodumedi le tumelo",
    "society": "Setšhaba"
}

def prediction(news):
    """Classify one Setswana news article and return {category: probability}."""
    # return_all_scores=True yields a list of {label, score} dicts for each input text.
    preds = classifier(news)
    preds_dict = {categories.get(pred['label'], pred['label']): round(pred['score'], 4) for pred in preds[0]}
    return preds_dict
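
# Illustrative usage of prediction() (hypothetical text and scores, not real model output):
#   prediction("<some Setswana news text>")
#   -> {"Thuto": 0.87, "Dipolotiki": 0.05, ...}  # one probability per category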

def file_prediction(file):
    """Classify every non-empty article in an uploaded text or CSV file."""
    # Re-open the upload by its path: Gradio's File component may hand back a closed
    # tempfile wrapper, so reading the file object directly can fail.
    with open(file.name, 'r', encoding='utf-8') as f:
        if file.name.endswith('.csv'):
            # CSV: take the article text from the first column of each row.
            reader = csv.reader(f)
            news_list = [row[0] for row in reader if row]
        else:
            # Plain text: one article per line.
            news_list = f.read().splitlines()

    results = []
    for news in news_list:
        if news.strip():
            pred = prediction(news)
            results.append([news, pred])
    return results
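
# Illustrative inputs for file_prediction() (hypothetical file names and content):
#   news.txt - one Setswana article per line
#   news.csv - article text in the first column; any extra columns are ignored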

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):  
            pass
        with gr.Column(scale=4, min_width=1000): 
            gr.Image("logo_transparent_small.png", elem_id="logo", show_label=False, width=500)
            gr.Markdown("""
            <h1 style='text-align: center;'>Setswana News Classification</h1>
            <p style='text-align: center;'>This space provides a classification service for news in Setswana.</p>
            """)
        with gr.Column(scale=1):  
            pass
    
    with gr.Tabs():
        with gr.Tab("Text Input"):
            gr.Markdown(f"""
            Enter a Setswana news article to see its predicted category. <br>
            For this classification, the <a href='{MODEL_URL}' target='_blank'>PuoBERTa-News</a> model is used.
            """)
            inp_text = gr.Textbox(lines=10, label="Paste some Setswana news here")
            output_label = gr.Label(num_top_classes=5, label="News category probabilities")
            classify_button = gr.Button("Classify")
            classify_button.click(prediction, inputs=inp_text, outputs=output_label)

        with gr.Tab("File Upload"):
            gr.Markdown("""
            Upload a plain-text or CSV file with Setswana news articles: one article per line for text files;
            for CSV files, the first column should contain the news text.
            """)
            file_input = gr.File(label="Upload text or CSV file")
            file_output = gr.Dataframe(headers=["News Text", "Category Predictions"], label="Predictions from file")
            file_button = gr.Button("Classify File")
            file_button.click(file_prediction, inputs=file_input, outputs=file_output)

    gr.Markdown("""
    <div style='text-align: center;'>
        <a href='https://github.com/dsfsi/PuoBERTa-News' target='_blank'>GitHub</a> |
        <a href='https://docs.google.com/forms/d/e/1FAIpQLSf7S36dyAUPx2egmXbFpnTBuzoRulhL5Elu-N1eoMhaO7v10w/viewform' target='_blank'>Feedback Form</a>
    </div>
    """)

    with gr.Accordion("More Information", open=False):
        
        gr.Markdown("""
        <h4 style="text-align: center;">Authors</h4>
        <div style='text-align: center;'>
            Vukosi Marivate, Moseli Mots'Oehli, Valencia Wagner, Richard Lastrucci, Isheanesu Dzingirai
        </div>
        """)
        
        gr.Markdown("""
        <h4 style="text-align: center;">Citation</h4>
        <pre style="text-align: left; white-space: pre-wrap;">
        @inproceedings{marivate2023puoberta,
          title        = {PuoBERTa: Training and evaluation of a curated language model for Setswana},
          author       = {Vukosi Marivate and Moseli Mots'Oehli and Valencia Wagner and Richard Lastrucci and Isheanesu Dzingirai},
          year         = {2023},
          booktitle    = {Artificial Intelligence Research. SACAIR 2023. Communications in Computer and Information Science},
          url          = {https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17},
          keywords     = {NLP},
          preprint_url = {https://arxiv.org/abs/2310.09141},
          dataset_url  = {https://github.com/dsfsi/PuoBERTa},
          software_url = {https://huggingface.co/dsfsi/PuoBERTa}
        }
        </pre>
        """)
        
        gr.Markdown("""
        <h4 style="text-align: center;">DOI</h4>
        <div style='text-align: center;'>
            DOI: <a href="https://doi.org/10.1007/978-3-031-49002-6_17" target="_blank">10.1007/978-3-031-49002-6_17</a>
        </div>
        """)

demo.launch()