nielsr HF staff commited on
Commit
6f130b0
1 Parent(s): edac551

Output scores

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -20,14 +20,17 @@ def predict(image1, image2, text):
20
  outputs = model(input_ids=encoding_1.input_ids, pixel_values=encoding_1.pixel_values, pixel_values_2=encoding_2.pixel_values)
21
 
22
  logits = outputs.logits
23
- idx = logits.argmax(-1).item()
24
- predicted_answer = model.config.id2label[idx]
 
 
 
25
 
26
- return predicted_answer
27
 
28
  images = [gr.inputs.Image(type="pil"), gr.inputs.Image(type="pil")]
29
  text = gr.inputs.Textbox(lines=2, label="Sentence")
30
- answer = gr.outputs.Textbox(label="Predicted answer")
31
 
32
  example_sentence_1 = "The left image contains twice the number of dogs as the right image, and at least two dogs in total are standing."
33
  example_sentence_2 = "One image shows exactly two brown acorns in back-to-back caps on green foliage."
@@ -39,7 +42,7 @@ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2102.033
39
 
40
  interface = gr.Interface(fn=predict,
41
  inputs=images + [text],
42
- outputs=answer,
43
  examples=examples,
44
  title=title,
45
  description=description,
 
20
  outputs = model(input_ids=encoding_1.input_ids, pixel_values=encoding_1.pixel_values, pixel_values_2=encoding_2.pixel_values)
21
 
22
  logits = outputs.logits
23
+ probs = torch.nn.functional.softmax(logits, dim=1)
24
+
25
+ output = dict()
26
+ for label, id in model.config.label2id.items():
27
+ output[label] = probs[:,id].item()
28
 
29
+ return output
30
 
31
  images = [gr.inputs.Image(type="pil"), gr.inputs.Image(type="pil")]
32
  text = gr.inputs.Textbox(lines=2, label="Sentence")
33
+ label = gr.outputs.Label(num_top_classes=2)
34
 
35
  example_sentence_1 = "The left image contains twice the number of dogs as the right image, and at least two dogs in total are standing."
36
  example_sentence_2 = "One image shows exactly two brown acorns in back-to-back caps on green foliage."
 
42
 
43
  interface = gr.Interface(fn=predict,
44
  inputs=images + [text],
45
+ outputs=label,
46
  examples=examples,
47
  title=title,
48
  description=description,