AravindKumarRajendran commited on
Commit
2182d48
·
1 Parent(s): eb03e62
Files changed (1) hide show
  1. app.py +36 -22
app.py CHANGED
@@ -127,22 +127,19 @@ def predict(image: Union[np.ndarray, None]) -> tuple[str, dict]:
127
  return "Error: Failed to preprocess image", {}
128
 
129
  with torch.no_grad():
130
- # Move input to same device as model
131
  input_tensor = input_tensor.to(DEVICE)
132
  output = model(input_tensor)
133
- # Apply softmax to get probabilities
134
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
135
 
136
  # Get predictions and confidences
137
  top_5_probs, top_5_indices = torch.topk(probabilities, k=5)
138
 
139
- # Convert to percentages and round to 2 decimal places
140
  confidences = {
141
- CLASS_NAMES[idx.item()]: float(round(prob.item() * 100, 2))
142
  for prob, idx in zip(top_5_probs, top_5_indices)
143
  }
144
 
145
- # Get top prediction
146
  predicted_class = CLASS_NAMES[top_5_indices[0].item()]
147
 
148
  return predicted_class, confidences
@@ -166,23 +163,40 @@ def get_example_list() -> list:
166
 
167
  # Create Gradio interface with error handling
168
  try:
169
- iface = gr.Interface(
170
- fn=predict,
171
- inputs=gr.Image(type="numpy", label="Upload Image"),
172
- outputs=[
173
- gr.Label(label="Predicted Class", num_top_classes=1),
174
- gr.Label(label="Top 5 Predictions", num_top_classes=5)
175
- ],
176
- title="Image Classification with ResNet50",
177
- description=(
178
- "Upload an image to classify:\n"
179
- "The model will predict the class and show top 5 confidence scores."
180
- ),
181
- examples=get_example_list(),
182
- cache_examples=True,
183
- theme=gr.themes.Base()
184
- )
185
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  except Exception as e:
187
  logger.error(f"Error creating Gradio interface: {str(e)}")
188
  raise
 
127
  return "Error: Failed to preprocess image", {}
128
 
129
  with torch.no_grad():
 
130
  input_tensor = input_tensor.to(DEVICE)
131
  output = model(input_tensor)
 
132
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
133
 
134
  # Get predictions and confidences
135
  top_5_probs, top_5_indices = torch.topk(probabilities, k=5)
136
 
137
+ # Format confidences with exactly 2 decimal places
138
  confidences = {
139
+ CLASS_NAMES[idx.item()]: "{:.2f}".format(float(prob.item() * 100))
140
  for prob, idx in zip(top_5_probs, top_5_indices)
141
  }
142
 
 
143
  predicted_class = CLASS_NAMES[top_5_indices[0].item()]
144
 
145
  return predicted_class, confidences
 
163
 
164
  # Create Gradio interface with error handling
165
  try:
166
+ with gr.Blocks(theme=gr.themes.Base()) as iface:
167
+ gr.Markdown("# Image Classification with ResNet50")
168
+ gr.Markdown("Upload an image to classify. The model will predict the class and show top 5 confidence scores.")
169
+
170
+ with gr.Row():
171
+ with gr.Column(scale=1):
172
+ input_image = gr.Image(type="numpy", label="Upload Image")
173
+ predict_btn = gr.Button("Predict")
174
+
175
+ with gr.Column(scale=1):
176
+ output_label = gr.Label(label="Predicted Class", num_top_classes=1)
177
+ confidence_label = gr.Label(label="Top 5 Predictions", num_top_classes=5)
178
+
179
+ # Add examples
180
+ gr.Examples(
181
+ examples=get_example_list(),
182
+ inputs=input_image,
183
+ outputs=[output_label, confidence_label],
184
+ fn=predict,
185
+ cache_examples=True
186
+ )
187
+
188
+ # Set up prediction event
189
+ predict_btn.click(
190
+ fn=predict,
191
+ inputs=input_image,
192
+ outputs=[output_label, confidence_label]
193
+ )
194
+ input_image.change(
195
+ fn=predict,
196
+ inputs=input_image,
197
+ outputs=[output_label, confidence_label]
198
+ )
199
+
200
  except Exception as e:
201
  logger.error(f"Error creating Gradio interface: {str(e)}")
202
  raise