Commit
·
2182d48
1
Parent(s):
eb03e62
fix
Browse files
app.py
CHANGED
@@ -127,22 +127,19 @@ def predict(image: Union[np.ndarray, None]) -> tuple[str, dict]:
|
|
127 |
return "Error: Failed to preprocess image", {}
|
128 |
|
129 |
with torch.no_grad():
|
130 |
-
# Move input to same device as model
|
131 |
input_tensor = input_tensor.to(DEVICE)
|
132 |
output = model(input_tensor)
|
133 |
-
# Apply softmax to get probabilities
|
134 |
probabilities = torch.nn.functional.softmax(output[0], dim=0)
|
135 |
|
136 |
# Get predictions and confidences
|
137 |
top_5_probs, top_5_indices = torch.topk(probabilities, k=5)
|
138 |
|
139 |
-
#
|
140 |
confidences = {
|
141 |
-
CLASS_NAMES[idx.item()]: float(
|
142 |
for prob, idx in zip(top_5_probs, top_5_indices)
|
143 |
}
|
144 |
|
145 |
-
# Get top prediction
|
146 |
predicted_class = CLASS_NAMES[top_5_indices[0].item()]
|
147 |
|
148 |
return predicted_class, confidences
|
@@ -166,23 +163,40 @@ def get_example_list() -> list:
|
|
166 |
|
167 |
# Create Gradio interface with error handling
|
168 |
try:
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
gr.
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
except Exception as e:
|
187 |
logger.error(f"Error creating Gradio interface: {str(e)}")
|
188 |
raise
|
|
|
127 |
return "Error: Failed to preprocess image", {}
|
128 |
|
129 |
with torch.no_grad():
|
|
|
130 |
input_tensor = input_tensor.to(DEVICE)
|
131 |
output = model(input_tensor)
|
|
|
132 |
probabilities = torch.nn.functional.softmax(output[0], dim=0)
|
133 |
|
134 |
# Get predictions and confidences
|
135 |
top_5_probs, top_5_indices = torch.topk(probabilities, k=5)
|
136 |
|
137 |
+
# Format confidences with exactly 2 decimal places
|
138 |
confidences = {
|
139 |
+
CLASS_NAMES[idx.item()]: "{:.2f}".format(float(prob.item() * 100))
|
140 |
for prob, idx in zip(top_5_probs, top_5_indices)
|
141 |
}
|
142 |
|
|
|
143 |
predicted_class = CLASS_NAMES[top_5_indices[0].item()]
|
144 |
|
145 |
return predicted_class, confidences
|
|
|
163 |
|
164 |
# Create Gradio interface with error handling
|
165 |
try:
|
166 |
+
with gr.Blocks(theme=gr.themes.Base()) as iface:
|
167 |
+
gr.Markdown("# Image Classification with ResNet50")
|
168 |
+
gr.Markdown("Upload an image to classify. The model will predict the class and show top 5 confidence scores.")
|
169 |
+
|
170 |
+
with gr.Row():
|
171 |
+
with gr.Column(scale=1):
|
172 |
+
input_image = gr.Image(type="numpy", label="Upload Image")
|
173 |
+
predict_btn = gr.Button("Predict")
|
174 |
+
|
175 |
+
with gr.Column(scale=1):
|
176 |
+
output_label = gr.Label(label="Predicted Class", num_top_classes=1)
|
177 |
+
confidence_label = gr.Label(label="Top 5 Predictions", num_top_classes=5)
|
178 |
+
|
179 |
+
# Add examples
|
180 |
+
gr.Examples(
|
181 |
+
examples=get_example_list(),
|
182 |
+
inputs=input_image,
|
183 |
+
outputs=[output_label, confidence_label],
|
184 |
+
fn=predict,
|
185 |
+
cache_examples=True
|
186 |
+
)
|
187 |
+
|
188 |
+
# Set up prediction event
|
189 |
+
predict_btn.click(
|
190 |
+
fn=predict,
|
191 |
+
inputs=input_image,
|
192 |
+
outputs=[output_label, confidence_label]
|
193 |
+
)
|
194 |
+
input_image.change(
|
195 |
+
fn=predict,
|
196 |
+
inputs=input_image,
|
197 |
+
outputs=[output_label, confidence_label]
|
198 |
+
)
|
199 |
+
|
200 |
except Exception as e:
|
201 |
logger.error(f"Error creating Gradio interface: {str(e)}")
|
202 |
raise
|