|
import gradio as gr |
|
from transformers import pipeline |
|
|
|
|
|
# Zero-shot classification pipeline using the facebook/bart-large-mnli
# checkpoint. It is created once, at import time, and the module-level name is
# what classify_text calls for every request.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)
|
|
|
def _parse_classes(classes):
    """Split a comma-separated string into a clean, de-duplicated label list.

    Whitespace around each label is stripped, empty entries (blank input,
    stray or trailing commas) are dropped, and repeated labels are kept only
    once, in first-seen order.
    """
    if not classes:
        return []
    stripped = (part.strip() for part in classes.split(","))
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(label for label in stripped if label))


def classify_text(line_item, classes):
    """Score a line item against user-supplied candidate classes.

    Args:
        line_item: Free text to classify.
        classes: Candidate labels as a single comma-separated string.

    Returns:
        A dict mapping each label to its score rounded to 4 decimals, in the
        order the pipeline returns them (the {label: score} mapping rendered
        by the gr.Label output). Returns an empty dict when the text is blank
        or no non-empty label remains after parsing, instead of letting the
        pipeline raise.
    """
    class_list = _parse_classes(classes)
    if not line_item or not line_item.strip() or not class_list:
        return {}

    # multi_label=True scores each label independently, so one item may rank
    # high for several classes at once. It replaces the deprecated
    # `multi_class` keyword the pipeline used to accept.
    results = classifier(line_item, class_list, multi_label=True)

    return {
        label: round(score, 4)
        for label, score in zip(results["labels"], results["scores"])
    }
|
|
|
|
|
# Web UI: two text inputs (the item to classify and its comma-separated
# candidate classes) are passed, in that order, to classify_text. Its
# {label: score} result is shown by a gr.Label output with
# num_top_classes=None.
interface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter Line Item Here...", label="Line Item"),
        gr.Textbox(placeholder="Enter Classes Here, Separated by Commas", label="Classes"),
    ],
    outputs=gr.Label(num_top_classes=None, label="Class Probability Scores"),
    title="Bad Stuff, But So Good.",
    description="A zero-shot classification app using facebook/bart-large-mnli model to classify text into given categories.",
    # Each example row is [line item, classes], matching the input order above.
    examples=[
        ["wijn glas x3 $18", "tobacco,alcohol"],
        ["Stoofvlees $25", "tobacco,alcohol"],
        ["Marlboro 10$", "tobacco,alcohol"],
    ],
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    interface.launch()
|
|