Spaces:
Running
Running
Switched qa model
Browse files- app.py +19 -11
- gradio_cached_examples/18/log.csv +4 -4
app.py
CHANGED
@@ -1,25 +1,30 @@
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
|
|
4 |
import torch
|
5 |
|
6 |
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
|
|
7 |
te_tokenizer = AutoTokenizer.from_pretrained('MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli')
|
8 |
te_model = AutoModelForSequenceClassification.from_pretrained('MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli').to(device)
|
9 |
-
|
10 |
-
|
|
|
11 |
|
12 |
def predict(context, intent, multi_class):
|
|
|
13 |
input_text = "What is the opposite of " + intent + "?"
|
14 |
input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
15 |
opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0], skip_special_tokens=True)
|
16 |
-
input_text = "What object is
|
17 |
-
|
18 |
-
object_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0], skip_special_tokens=True)
|
19 |
batch = ['The ' + object_output + ' is ' + intent, 'The ' + object_output + ' is ' + opposite_output, 'The ' + object_output + ' is neither ' + intent + ' nor ' + opposite_output]
|
20 |
outputs = []
|
21 |
normal = 0
|
22 |
|
|
|
|
|
23 |
for i, hypothesis in enumerate(batch):
|
24 |
input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
|
25 |
|
@@ -33,6 +38,7 @@ def predict(context, intent, multi_class):
|
|
33 |
else:
|
34 |
probs = torch.exp(logits)
|
35 |
outputs.append(probs)
|
|
|
36 |
# calculate the stochastic vector for it being neither the positive or negative class
|
37 |
perfect_prob = outputs[2]
|
38 |
# -> [entailment, contradiction] for perfect
|
@@ -40,15 +46,18 @@ def predict(context, intent, multi_class):
|
|
40 |
# -> [entailment, neutral, contradiction] for positive
|
41 |
outputs[1] = outputs[1].flip(dims=[0])
|
42 |
|
|
|
|
|
|
|
43 |
# combine the negative and positive class by summing by the opposite of the negative class
|
44 |
-
aggregated = (outputs[0]
|
|
|
|
|
45 |
|
46 |
# multiplying vectors
|
47 |
-
aggregated[1] = aggregated[1]
|
48 |
aggregated[0] = aggregated[0] * perfect_prob[1]
|
49 |
aggregated[2] = aggregated[2] * perfect_prob[1]
|
50 |
-
|
51 |
-
aggregated = torch.sqrt(aggregated)
|
52 |
|
53 |
# multiple true classes
|
54 |
if (multi_class):
|
@@ -58,9 +67,8 @@ def predict(context, intent, multi_class):
|
|
58 |
else:
|
59 |
aggregated = aggregated.softmax(dim=0)
|
60 |
normal = normal.softmax(dim=0)
|
61 |
-
aggregated = aggregated.tolist()
|
62 |
return {"agree": aggregated[0], "neutral": aggregated[1], "disagree": aggregated[2]}, {"agree": normal[0], "neutral": normal[1], "disagree": normal[2]}
|
63 |
-
examples = [["These are
|
64 |
|
65 |
gradio_app = gr.Interface(
|
66 |
predict,
|
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
4 |
+
from transformers import pipeline
|
5 |
import torch
|
6 |
|
7 |
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
8 |
+
|
9 |
te_tokenizer = AutoTokenizer.from_pretrained('MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli')
|
10 |
te_model = AutoModelForSequenceClassification.from_pretrained('MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli').to(device)
|
11 |
+
qa_pipeline = pipeline("question-answering", model='distilbert/distilbert-base-cased-distilled-squad')
|
12 |
+
qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
13 |
+
qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", device_map="auto")
|
14 |
|
15 |
def predict(context, intent, multi_class):
|
16 |
+
print(context, intent)
|
17 |
input_text = "What is the opposite of " + intent + "?"
|
18 |
input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
19 |
opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0], skip_special_tokens=True)
|
20 |
+
input_text = "What object/thing is being described in the entire sentence?"
|
21 |
+
object_output = qa_pipeline(question=input_text, context=context, max_answer_len=2)['answer']
|
|
|
22 |
batch = ['The ' + object_output + ' is ' + intent, 'The ' + object_output + ' is ' + opposite_output, 'The ' + object_output + ' is neither ' + intent + ' nor ' + opposite_output]
|
23 |
outputs = []
|
24 |
normal = 0
|
25 |
|
26 |
+
print(batch)
|
27 |
+
|
28 |
for i, hypothesis in enumerate(batch):
|
29 |
input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
|
30 |
|
|
|
38 |
else:
|
39 |
probs = torch.exp(logits)
|
40 |
outputs.append(probs)
|
41 |
+
|
42 |
# calculate the stochastic vector for it being neither the positive or negative class
|
43 |
perfect_prob = outputs[2]
|
44 |
# -> [entailment, contradiction] for perfect
|
|
|
46 |
# -> [entailment, neutral, contradiction] for positive
|
47 |
outputs[1] = outputs[1].flip(dims=[0])
|
48 |
|
49 |
+
print(outputs)
|
50 |
+
print(perfect_prob)
|
51 |
+
|
52 |
# combine the negative and positive class by summing by the opposite of the negative class
|
53 |
+
aggregated = (outputs[0]+outputs[1])/2
|
54 |
+
|
55 |
+
print(aggregated)
|
56 |
|
57 |
# multiplying vectors
|
58 |
+
aggregated[1] = aggregated[1] + perfect_prob[0]
|
59 |
aggregated[0] = aggregated[0] * perfect_prob[1]
|
60 |
aggregated[2] = aggregated[2] * perfect_prob[1]
|
|
|
|
|
61 |
|
62 |
# multiple true classes
|
63 |
if (multi_class):
|
|
|
67 |
else:
|
68 |
aggregated = aggregated.softmax(dim=0)
|
69 |
normal = normal.softmax(dim=0)
|
|
|
70 |
return {"agree": aggregated[0], "neutral": aggregated[1], "disagree": aggregated[2]}, {"agree": normal[0], "neutral": normal[1], "disagree": normal[2]}
|
71 |
+
examples = [["These are so warm and comfortable. I’m 5’7”, 140 lbs, size 6-8 and Medium is a great fit. They wash and dry nicely too. The jogger style is the only style I can wear in this brand - the others are way too long so I had to return.", "long"], ["I feel strongly about politics in the US", "long"], ["The pants are long", "long"], ["The pants are slightly long", "long"]]
|
72 |
|
73 |
gradio_app = gr.Interface(
|
74 |
predict,
|
gradio_cached_examples/18/log.csv
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
With Postprocessing,Without Postprocessing,flag,username,timestamp
|
2 |
-
"{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":0.
|
3 |
-
"{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":0
|
4 |
-
"{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0
|
5 |
-
"{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.
|
|
|
1 |
With Postprocessing,Without Postprocessing,flag,username,timestamp
|
2 |
+
"{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":0.7777004837989807},{""label"":""disagree"",""confidence"":0.11701762676239014},{""label"":""agree"",""confidence"":0.10528186708688736}]}","{""label"":""disagree"",""confidences"":[{""label"":""disagree"",""confidence"":0.9934408068656921},{""label"":""neutral"",""confidence"":0.006133753340691328},{""label"":""agree"",""confidence"":0.0004254970990587026}]}",,,2024-03-18 01:16:39.407608
|
3 |
+
"{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":1.0},{""label"":""agree"",""confidence"":5.289930995508298e-31},{""label"":""disagree"",""confidence"":4.9565565658354345e-31}]}","{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":0.9979484677314758},{""label"":""agree"",""confidence"":0.0014055748470127583},{""label"":""disagree"",""confidence"":0.0006459908909164369}]}",,,2024-03-18 01:16:40.307454
|
4 |
+
"{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":1.0},{""label"":""neutral"",""confidence"":0.0},{""label"":""disagree"",""confidence"":0.0}]}","{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.9909017086029053},{""label"":""neutral"",""confidence"":0.008000608533620834},{""label"":""disagree"",""confidence"":0.0010977860074490309}]}",,,2024-03-18 01:16:41.177826
|
5 |
+
"{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.673342764377594},{""label"":""neutral"",""confidence"":0.2709066867828369},{""label"":""disagree"",""confidence"":0.0557505264878273}]}","{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.8074565529823303},{""label"":""neutral"",""confidence"":0.1722831279039383},{""label"":""disagree"",""confidence"":0.02026035077869892}]}",,,2024-03-18 01:16:41.822843
|