lombardata commited on
Commit
2ce87ee
·
verified ·
1 Parent(s): 6cbf3f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -66,12 +66,13 @@ def predict(image, slider_threshold=0.5, fixed_thresholds=None):
66
  # Create a dictionary of label scores based on the slider threshold
67
  slider_results = {id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities) if prob > slider_threshold}
68
 
69
- # If fixed thresholds are provided, create a dictionary of label scores based on the fixed thresholds
70
- fixed_threshold_results = None
71
  if fixed_thresholds is not None:
72
- fixed_threshold_results = {id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities) if prob > fixed_thresholds[id2label[str(i)]]}
 
73
 
74
- return slider_results, fixed_threshold_results
75
 
76
  def predict_wrapper(image, slider_threshold=0.5):
77
  # Download thresholds from the model repository
@@ -94,7 +95,7 @@ iface = gr.Interface(
94
  inputs=[gr.components.Image(type="pil"), gr.components.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Threshold")],
95
  outputs=[
96
  gr.components.Label(label="Slider Threshold Predictions"),
97
- gr.components.Label(label="Fixed Thresholds Predictions")
98
  ],
99
  title=title,
100
  description=description,
 
66
  # Create a dictionary of label scores based on the slider threshold
67
  slider_results = {id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities) if prob > slider_threshold}
68
 
69
+ # If fixed thresholds are provided, format the labels into a string
70
+ fixed_threshold_labels_str = None
71
  if fixed_thresholds is not None:
72
+ fixed_threshold_labels = [id2label[str(i)] for i, prob in enumerate(probabilities) if prob > fixed_thresholds[id2label[str(i)]]]
73
+ fixed_threshold_labels_str = ", ".join(fixed_threshold_labels)
74
 
75
+ return slider_results, fixed_threshold_labels_str
76
 
77
  def predict_wrapper(image, slider_threshold=0.5):
78
  # Download thresholds from the model repository
 
95
  inputs=[gr.components.Image(type="pil"), gr.components.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Threshold")],
96
  outputs=[
97
  gr.components.Label(label="Slider Threshold Predictions"),
98
+ gr.components.Textbox(label="Fixed Threshold Labels")
99
  ],
100
  title=title,
101
  description=description,