flichote commited on
Commit
5dc8d86
1 Parent(s): bff58a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -9
app.py CHANGED
@@ -91,18 +91,38 @@
91
  # out=grad.Textbox(lines=10, label="Summary")
92
  # grad.Interface(summarize, inputs=txt, outputs=out).launch()
93
 
94
- from transformers import pipeline
95
- import gradio as grad
96
- zero_shot_classifier = pipeline("zero-shot-classification")
97
 
98
 
99
- def classify(text,labels):
100
- classifer_labels = labels.split(",")
101
- #["software", "politics", "love", "movies", "emergency", "advertisment","sports"]
102
- response = zero_shot_classifier(text,classifer_labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  return response
104
  txt=grad.Textbox(lines=1, label="English", placeholder="text to be classified")
105
- labels=grad.Textbox(lines=1, label="Labels", placeholder="comma separated labels")
106
- out=grad.Textbox(lines=1, label="Classification")
107
  grad.Interface(classify, inputs=[txt,labels], outputs=out).launch()
108
 
 
 
 
 
91
  # out=grad.Textbox(lines=10, label="Summary")
92
  # grad.Interface(summarize, inputs=txt, outputs=out).launch()
93
 
94
+ # from transformers import pipeline
95
+ # import gradio as grad
96
+ # zero_shot_classifier = pipeline("zero-shot-classification")
97
 
98
 
99
+ # def classify(text,labels):
100
+ # classifer_labels = labels.split(",")
101
+ # #["software", "politics", "love", "movies", "emergency", "advertisment","sports"]
102
+ # response = zero_shot_classifier(text,classifer_labels)
103
+ # return response
104
+ # txt=grad.Textbox(lines=1, label="English", placeholder="text to be classified")
105
+ # labels=grad.Textbox(lines=1, label="Labels", placeholder="comma separated labels")
106
+ # out=grad.Textbox(lines=1, label="Classification")
107
+ # grad.Interface(classify, inputs=[txt,labels], outputs=out).launch()
108
+
109
+ from transformers import BartForSequenceClassification, BartTokenizer
110
+ import gradio as grad
111
+ bart_tkn = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
112
+ mdl = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
113
+
114
+ def classify(text,label):
115
+ tkn_ids = bart_tkn.encode(text, label, return_tensors='pt')
116
+ tkn_lgts = mdl(tkn_ids)[0]
117
+ entail_contra_tkn_lgts = tkn_lgts[:,[0,2]]
118
+ probab = entail_contra_tkn_lgts.softmax(dim=1)
119
+ response = probab[:,1].item() * 100
120
  return response
121
  txt=grad.Textbox(lines=1, label="English", placeholder="text to be classified")
122
+ labels=grad.Textbox(lines=1, label="Label", placeholder="Input a Label")
123
+ out=grad.Textbox(lines=1, label="Probablity of label being true is")
124
  grad.Interface(classify, inputs=[txt,labels], outputs=out).launch()
125
 
126
+
127
+
128
+