|
"""Gradio app that showcases Scandinavian zero-shot text classification models.""" |
|
|
|
from typing import Dict, Tuple |
|
import gradio as gr |
|
from gradio.components import Dropdown, Textbox, Button, Label, Markdown |
|
from types import MethodType |
|
from gradio.layouts.column import Column |
|
from gradio.layouts.row import Row |
|
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer |
|
from luga import language as detect_language |
|
import torch |
|
import re |
|
import os |
|
|
|
|
|
def main():
    """Load the zero-shot classification pipeline and launch the Gradio demo.

    Side effects:
        - Sets the ``TOKENIZERS_PARALLELISM`` environment variable to "false".
        - Binds the module-level globals ``classifier``, ``model`` and
          ``tokenizer``; ``classification`` reads ``classifier`` when the
          submit button is clicked.
        - Builds the Blocks interface and calls ``demo.launch``.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # The pipeline is stored in module-level globals so that the
    # `classification` event handler can reach it.
    global classifier, model, tokenizer
    model_id = "alexandrainst/scandi-nli-large"
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model.eval()
    classifier = pipeline(
        "zero-shot-classification", model=model, tokenizer=tokenizer
    )
    # Override the pipeline's `get_inference_context` so that it returns
    # `torch.no_grad`.
    classifier.get_inference_context = MethodType(
        lambda self: torch.no_grad, classifier
    )

    # Per task: (Danish template, Danish labels, Swedish template, Swedish
    # labels, Norwegian template, Norwegian labels). The order matches the
    # `outputs` list of `dropdown.change` below.
    task_configs: Dict[str, Tuple[str, str, str, str, str, str]] = {
        "Sentiment classification": (
            "Dette eksempel er {}.",
            "positivt, negativt, neutralt",
            "Detta exempel är {}.",
            "positivt, negativt, neutralt",
            "Dette eksemplet er {}.",
            "positivt, negativt, nøytralt",
        ),
        "News topic classification": (
            "Denne nyhedsartikel handler primært om {}.",
            "krig, politik, uddannelse, sundhed, økonomi, mode, sport",
            "Den här nyhetsartikeln handlar främst om {}.",
            "krig, politik, utbildning, hälsa, ekonomi, mode, sport",
            "Denne nyhetsartikkelen handler først og fremst om {}.",
            "krig, politikk, utdanning, helse, økonomi, mote, sport",
        ),
        "Spam detection": (
            "Denne e-mail ligner {}.",
            "en spam e-mail, ikke en spam e-mail",
            "Det här e-postmeddelandet ser {}.",
            "ut som ett skräppostmeddelande, inte ut som ett skräppostmeddelande",
            "Denne e-posten ser {}.",
            "ut som en spam-e-post, ikke ut som en spam-e-post",
        ),
        "Product feedback detection": (
            "Denne kommentar er {}.",
            "en anmeldelse af et produkt, ikke en anmeldelse af et produkt",
            "Den här kommentaren är {}.",
            "en recension av en produkt, inte en recension av en produkt",
            "Denne kommentaren er {}.",
            "en anmeldelse av et produkt, ikke en anmeldelse av et produkt",
        ),
        "Define your own task!": (
            "Dette eksempel er {}.",
            "",
            "Detta exempel är {}.",
            "",
            "Dette eksemplet er {}.",
            "",
        ),
    }

    def set_task_setup(task: str) -> Tuple[str, str, str, str, str, str]:
        """Return the templates and candidate labels configured for `task`."""
        return task_configs[task]

    # The initial dropdown value and textbox contents come from `task_configs`
    # so they cannot drift apart from the presets.
    default_task = "Sentiment classification"
    (
        da_template_default,
        da_labels_default,
        sv_template_default,
        sv_labels_default,
        no_template_default,
        no_labels_default,
    ) = task_configs[default_task]

    with gr.Blocks() as demo:

        # Title and description
        Markdown("# Scandinavian Zero-shot Text Classification")
        Markdown("""
        Classify text in Danish, Swedish or Norwegian into categories, without
        finetuning on any training data!

        Select one of the tasks from the dropdown menu on the left, and try
        entering some input text (in Danish, Swedish or Norwegian) in the input
        text box and press submit, to see the model in action! The labels are
        generated by putting in each candidate label into the hypothesis template,
        and then running the classifier on each label separately. Feel free to
        change the "hypothesis template" and "candidate labels" on the left as you
        please as well, and try to come up with your own tasks too 😊

        _Also, be patient, as this demo is running on a CPU!_
        """)

        with Row():

            # Left column: task selection and per-language configuration
            with Column():

                dropdown = Dropdown(
                    label="Task",
                    choices=list(task_configs),
                    value=default_task,
                )

                with Row(variant="compact"):
                    da_hypothesis_template = Textbox(
                        label="Danish hypothesis template",
                        value=da_template_default,
                    )
                    da_candidate_labels = Textbox(
                        label="Danish candidate labels (comma separated)",
                        value=da_labels_default,
                    )

                with Row(variant="compact"):
                    sv_hypothesis_template = Textbox(
                        label="Swedish hypothesis template",
                        value=sv_template_default,
                    )
                    sv_candidate_labels = Textbox(
                        label="Swedish candidate labels (comma separated)",
                        value=sv_labels_default,
                    )

                with Row(variant="compact"):
                    no_hypothesis_template = Textbox(
                        label="Norwegian hypothesis template",
                        value=no_template_default,
                    )
                    no_candidate_labels = Textbox(
                        label="Norwegian candidate labels (comma separated)",
                        value=no_labels_default,
                    )

                # Selecting a task overwrites all six textboxes with its preset.
                dropdown.change(
                    fn=set_task_setup,
                    inputs=dropdown,
                    outputs=[
                        da_hypothesis_template,
                        da_candidate_labels,
                        sv_hypothesis_template,
                        sv_candidate_labels,
                        no_hypothesis_template,
                        no_candidate_labels,
                    ],
                )

            # Right column: input text, buttons and result
            with Column():

                input_textbox = Textbox(
                    label="Input text", value="Jeg er helt vild med fodbolden 😊"
                )

                with Row():
                    clear_btn = Button(value="Clear")
                    submit_btn = Button(value="Submit", variant="primary")

                # The clear button empties the input textbox.
                clear_btn.click(
                    fn=lambda _: "", inputs=input_textbox, outputs=input_textbox
                )

                # NOTE(review): the nesting of this result column inside the
                # right-hand column is assumed -- confirm against the
                # intended layout.
                with Column():

                    output_textbox = Label(label="Result")

                    # The submit button runs `classification` on the input
                    # text with the current templates and labels.
                    submit_btn.click(
                        fn=classification,
                        inputs=[
                            input_textbox,
                            da_hypothesis_template,
                            da_candidate_labels,
                            sv_hypothesis_template,
                            sv_candidate_labels,
                            no_hypothesis_template,
                            no_candidate_labels,
                        ],
                        outputs=output_textbox,
                    )

    # Run the app
    demo.launch(width=.5, ssr_mode=False)
|
|
|
|
|
def classification(
    doc: str,
    da_hypothesis_template: str,
    da_candidate_labels: str,
    sv_hypothesis_template: str,
    sv_candidate_labels: str,
    no_hypothesis_template: str,
    no_candidate_labels: str,
) -> Dict[str, float]:
    """Classify text into categories.

    The language of `doc` is detected first. Swedish ("sv") and Norwegian
    ("no") use their own template and labels; every other detected language
    falls back to the Danish ones.

    Args:
        doc (str):
            Text to classify.
        da_hypothesis_template (str):
            Template for the hypothesis to be used for Danish classification.
        da_candidate_labels (str):
            Comma-separated list of candidate labels for Danish classification.
        sv_hypothesis_template (str):
            Template for the hypothesis to be used for Swedish classification.
        sv_candidate_labels (str):
            Comma-separated list of candidate labels for Swedish classification.
        no_hypothesis_template (str):
            Template for the hypothesis to be used for Norwegian classification.
        no_candidate_labels (str):
            Comma-separated list of candidate labels for Norwegian classification.

    Returns:
        dict of str to float:
            Mapping from each label in the classifier's result to its score.

    Raises:
        gr.Error:
            If `doc` is empty or only whitespace, or if no non-empty candidate
            label is given for the detected language.
    """
    # Nothing to classify: tell the user instead of running the model.
    if not doc or not doc.strip():
        raise gr.Error("Please enter some text to classify.")

    # Newlines are replaced by spaces before language detection
    # (presumably the detector expects single-line input -- TODO confirm).
    language = detect_language(doc.replace("\n", " ")).name

    if language == "sv":
        hypothesis_template = sv_hypothesis_template
        raw_labels = sv_candidate_labels
    elif language == "no":
        hypothesis_template = no_hypothesis_template
        raw_labels = no_candidate_labels
    else:
        # Danish is the fallback for any other detected language.
        hypothesis_template = da_hypothesis_template
        raw_labels = da_candidate_labels

    # Strip whitespace around each label and drop empty entries, so that
    # input such as "a , b,", "a,,b" or "" never yields blank labels.
    stripped_labels = (label.strip() for label in raw_labels.split(","))
    candidate_labels = [label for label in stripped_labels if label]
    if not candidate_labels:
        raise gr.Error(
            "Please provide at least one candidate label "
            f"(detected language: {language})."
        )

    result = classifier(
        doc,
        candidate_labels=candidate_labels,
        hypothesis_template=hypothesis_template,
    )

    # Echo the raw pipeline output to stdout.
    print(result)

    return dict(zip(result["labels"], result["scores"]))
|
|
|
|
|
# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|