youj2005 commited on
Commit
8cb81df
·
1 Parent(s): 36b3b29

Switched qa model

Browse files
Files changed (2) hide show
  1. app.py +19 -11
  2. gradio_cached_examples/18/log.csv +4 -4
app.py CHANGED
@@ -1,25 +1,30 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  from transformers import T5Tokenizer, T5ForConditionalGeneration
 
4
  import torch
5
 
6
  device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
 
7
  te_tokenizer = AutoTokenizer.from_pretrained('MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli')
8
  te_model = AutoModelForSequenceClassification.from_pretrained('MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli').to(device)
9
- qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
10
- qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
 
11
 
12
  def predict(context, intent, multi_class):
 
13
  input_text = "What is the opposite of " + intent + "?"
14
  input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
15
  opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0], skip_special_tokens=True)
16
- input_text = "What object is the following describing: " + context
17
- input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
18
- object_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0], skip_special_tokens=True)
19
  batch = ['The ' + object_output + ' is ' + intent, 'The ' + object_output + ' is ' + opposite_output, 'The ' + object_output + ' is neither ' + intent + ' nor ' + opposite_output]
20
  outputs = []
21
  normal = 0
22
 
 
 
23
  for i, hypothesis in enumerate(batch):
24
  input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
25
 
@@ -33,6 +38,7 @@ def predict(context, intent, multi_class):
33
  else:
34
  probs = torch.exp(logits)
35
  outputs.append(probs)
 
36
  # calculate the stochastic vector for it being neither the positive or negative class
37
  perfect_prob = outputs[2]
38
  # -> [entailment, contradiction] for perfect
@@ -40,15 +46,18 @@ def predict(context, intent, multi_class):
40
  # -> [entailment, neutral, contradiction] for positive
41
  outputs[1] = outputs[1].flip(dims=[0])
42
 
 
 
 
43
  # combine the negative and positive class by summing by the opposite of the negative class
44
- aggregated = (outputs[0] + outputs[1])/2
 
 
45
 
46
  # multiplying vectors
47
- aggregated[1] = aggregated[1] * perfect_prob[0]
48
  aggregated[0] = aggregated[0] * perfect_prob[1]
49
  aggregated[2] = aggregated[2] * perfect_prob[1]
50
-
51
- aggregated = torch.sqrt(aggregated)
52
 
53
  # multiple true classes
54
  if (multi_class):
@@ -58,9 +67,8 @@ def predict(context, intent, multi_class):
58
  else:
59
  aggregated = aggregated.softmax(dim=0)
60
  normal = normal.softmax(dim=0)
61
- aggregated = aggregated.tolist()
62
  return {"agree": aggregated[0], "neutral": aggregated[1], "disagree": aggregated[2]}, {"agree": normal[0], "neutral": normal[1], "disagree": normal[2]}
63
- examples = [["These are my absolute favorite cargos in my closet. I’m 5’7 and they’re actually long enough for me. I’m 165lbs and ordered an M & it fits nice and loose just how I wanted it. The adjustable waist band is awesome!", "long"], ["I feel strongly about politics in the US", "long"], ["The pants are long", "long"], ["The pants are slightly long", "long"]]
64
 
65
  gradio_app = gr.Interface(
66
  predict,
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  from transformers import T5Tokenizer, T5ForConditionalGeneration
4
+ from transformers import pipeline
5
  import torch
6
 
7
  device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
8
+
9
  te_tokenizer = AutoTokenizer.from_pretrained('MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli')
10
  te_model = AutoModelForSequenceClassification.from_pretrained('MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli').to(device)
11
+ qa_pipeline = pipeline("question-answering", model='distilbert/distilbert-base-cased-distilled-squad')
12
+ qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
13
+ qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", device_map="auto")
14
 
15
  def predict(context, intent, multi_class):
16
+ print(context, intent)
17
  input_text = "What is the opposite of " + intent + "?"
18
  input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
19
  opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0], skip_special_tokens=True)
20
+ input_text = "What object/thing is being described in the entire sentence?"
21
+ object_output = qa_pipeline(question=input_text, context=context, max_answer_len=2)['answer']
 
22
  batch = ['The ' + object_output + ' is ' + intent, 'The ' + object_output + ' is ' + opposite_output, 'The ' + object_output + ' is neither ' + intent + ' nor ' + opposite_output]
23
  outputs = []
24
  normal = 0
25
 
26
+ print(batch)
27
+
28
  for i, hypothesis in enumerate(batch):
29
  input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
30
 
 
38
  else:
39
  probs = torch.exp(logits)
40
  outputs.append(probs)
41
+
42
  # calculate the stochastic vector for it being neither the positive or negative class
43
  perfect_prob = outputs[2]
44
  # -> [entailment, contradiction] for perfect
 
46
  # -> [entailment, neutral, contradiction] for positive
47
  outputs[1] = outputs[1].flip(dims=[0])
48
 
49
+ print(outputs)
50
+ print(perfect_prob)
51
+
52
  # combine the negative and positive class by summing by the opposite of the negative class
53
+ aggregated = (outputs[0]+outputs[1])/2
54
+
55
+ print(aggregated)
56
 
57
  # multiplying vectors
58
+ aggregated[1] = aggregated[1] + perfect_prob[0]
59
  aggregated[0] = aggregated[0] * perfect_prob[1]
60
  aggregated[2] = aggregated[2] * perfect_prob[1]
 
 
61
 
62
  # multiple true classes
63
  if (multi_class):
 
67
  else:
68
  aggregated = aggregated.softmax(dim=0)
69
  normal = normal.softmax(dim=0)
 
70
  return {"agree": aggregated[0], "neutral": aggregated[1], "disagree": aggregated[2]}, {"agree": normal[0], "neutral": normal[1], "disagree": normal[2]}
71
+ examples = [["These are so warm and comfortable. I’m 5’7”, 140 lbs, size 6-8 and Medium is a great fit. They wash and dry nicely too. The jogger style is the only style I can wear in this brand - the others are way too long so I had to return.", "long"], ["I feel strongly about politics in the US", "long"], ["The pants are long", "long"], ["The pants are slightly long", "long"]]
72
 
73
  gradio_app = gr.Interface(
74
  predict,
gradio_cached_examples/18/log.csv CHANGED
@@ -1,5 +1,5 @@
1
  With Postprocessing,Without Postprocessing,flag,username,timestamp
2
- "{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":0.3989219665527344},{""label"":""agree"",""confidence"":0.3555052578449249},{""label"":""disagree"",""confidence"":0.24557280540466309}]}","{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.9896149635314941},{""label"":""neutral"",""confidence"":0.008607939817011356},{""label"":""disagree"",""confidence"":0.0017770910635590553}]}",,,2024-03-14 17:33:32.082554
3
- "{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":0.9888092875480652},{""label"":""agree"",""confidence"":0.0059028444811701775},{""label"":""disagree"",""confidence"":0.005287905689328909}]}","{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":0.9971976280212402},{""label"":""agree"",""confidence"":0.0019434979185461998},{""label"":""disagree"",""confidence"":0.0008588095079176128}]}",,,2024-03-14 17:33:43.894521
4
- "{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.999862790107727},{""label"":""disagree"",""confidence"":0.00007605748396599665},{""label"":""neutral"",""confidence"":0.00006114941061241552}]}","{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.9909017086029053},{""label"":""neutral"",""confidence"":0.008000608533620834},{""label"":""disagree"",""confidence"":0.0010977860074490309}]}",,,2024-03-14 17:33:56.295829
5
- "{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.5945301651954651},{""label"":""neutral"",""confidence"":0.26899591088294983},{""label"":""disagree"",""confidence"":0.1364738941192627}]}","{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.8074565529823303},{""label"":""neutral"",""confidence"":0.1722831279039383},{""label"":""disagree"",""confidence"":0.02026035077869892}]}",,,2024-03-14 17:34:11.315778
 
1
  With Postprocessing,Without Postprocessing,flag,username,timestamp
2
+ "{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":0.7777004837989807},{""label"":""disagree"",""confidence"":0.11701762676239014},{""label"":""agree"",""confidence"":0.10528186708688736}]}","{""label"":""disagree"",""confidences"":[{""label"":""disagree"",""confidence"":0.9934408068656921},{""label"":""neutral"",""confidence"":0.006133753340691328},{""label"":""agree"",""confidence"":0.0004254970990587026}]}",,,2024-03-18 01:16:39.407608
3
+ "{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":1.0},{""label"":""agree"",""confidence"":5.289930995508298e-31},{""label"":""disagree"",""confidence"":4.9565565658354345e-31}]}","{""label"":""neutral"",""confidences"":[{""label"":""neutral"",""confidence"":0.9979484677314758},{""label"":""agree"",""confidence"":0.0014055748470127583},{""label"":""disagree"",""confidence"":0.0006459908909164369}]}",,,2024-03-18 01:16:40.307454
4
+ "{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":1.0},{""label"":""neutral"",""confidence"":0.0},{""label"":""disagree"",""confidence"":0.0}]}","{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.9909017086029053},{""label"":""neutral"",""confidence"":0.008000608533620834},{""label"":""disagree"",""confidence"":0.0010977860074490309}]}",,,2024-03-18 01:16:41.177826
5
+ "{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.673342764377594},{""label"":""neutral"",""confidence"":0.2709066867828369},{""label"":""disagree"",""confidence"":0.0557505264878273}]}","{""label"":""agree"",""confidences"":[{""label"":""agree"",""confidence"":0.8074565529823303},{""label"":""neutral"",""confidence"":0.1722831279039383},{""label"":""disagree"",""confidence"":0.02026035077869892}]}",,,2024-03-18 01:16:41.822843