StoneSeller commited on
Commit
d06517d
·
verified ·
1 Parent(s): 757e51c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -10
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import subprocess
2
  import sys
 
3
 
4
  # Function to install or reinstall specific packages
5
  def install(package):
@@ -75,25 +76,38 @@ transform = transforms.Compose([
75
 
76
  # Prediction function
77
  def predict(image):
78
- # Verify input image is a PIL image
 
 
 
79
  if not isinstance(image, Image.Image):
80
- raise ValueError("Input must be a PIL Image.")
 
 
 
81
 
82
  # Transform and predict
83
- image = transform(image).unsqueeze(0) # Add batch dimension
84
- with torch.no_grad():
85
- outputs = model(image)
86
- probabilities = torch.softmax(outputs, dim=1).numpy()[0]
87
- classes = ["Rope", "Hammer", "Other"]
88
- return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
 
 
 
89
 
90
  # Gradio interface
91
  interface = gr.Interface(
92
  fn=predict,
93
- inputs=gr.Image(type="pil"), # Ensure input is a PIL image
94
- outputs=gr.Label(num_top_classes=3), # Display top 3 class probabilities
95
  title="Mechanical Tools Classifier",
96
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
 
 
 
 
97
  )
98
 
99
  # Launch the interface
 
1
  import subprocess
2
  import sys
3
+ import os
4
 
5
  # Function to install or reinstall specific packages
6
  def install(package):
 
76
 
77
  # Prediction function
78
  def predict(image):
79
+ if image is None:
80
+ raise ValueError("Please provide an image")
81
+
82
+ # Convert to PIL Image if necessary
83
  if not isinstance(image, Image.Image):
84
+ try:
85
+ image = Image.fromarray(image)
86
+ except Exception as e:
87
+ raise ValueError(f"Failed to convert input to PIL Image: {str(e)}")
88
 
89
  # Transform and predict
90
+ try:
91
+ image = transform(image).unsqueeze(0) # Add batch dimension
92
+ with torch.no_grad():
93
+ outputs = model(image)
94
+ probabilities = torch.softmax(outputs, dim=1).numpy()[0]
95
+ classes = ["Rope", "Hammer", "Other"]
96
+ return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
97
+ except Exception as e:
98
+ raise ValueError(f"Error during prediction: {str(e)}")
99
 
100
  # Gradio interface
101
  interface = gr.Interface(
102
  fn=predict,
103
+ inputs=gr.Image(), # Remove type="pil" constraint
104
+ outputs=gr.Label(num_top_classes=3),
105
  title="Mechanical Tools Classifier",
106
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
107
+ examples=[
108
+ ["example_rope.jpg"],
109
+ ["example_hammer.jpg"],
110
+ ] if os.path.exists("example_rope.jpg") else None # Optional examples
111
  )
112
 
113
  # Launch the interface