StoneSeller commited on
Commit
9c80e2d
·
verified ·
1 Parent(s): 63d3e14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -15
app.py CHANGED
@@ -31,7 +31,8 @@ for package, version in packages.items():
31
  print(f"Installing {package}...")
32
  install(f"{package}=={version}")
33
 
34
- # Import all required libraries
 
35
  import numpy as np
36
  import torch
37
  import torch.nn as nn
@@ -40,7 +41,7 @@ import torchvision.transforms as transforms
40
  from PIL import Image
41
  import gradio as gr
42
 
43
- # Define the model
44
  class ModifiedLargeNet(nn.Module):
45
  def __init__(self):
46
  super(ModifiedLargeNet, self).__init__()
@@ -59,26 +60,47 @@ class ModifiedLargeNet(nn.Module):
59
  x = self.fc2(x)
60
  return x
61
 
62
- # Load the trained model
63
- model = ModifiedLargeNet()
64
- model.load_state_dict(torch.load("modified_large_net.pt", map_location=torch.device("cpu")))
65
- model.eval()
 
 
 
 
 
 
66
 
67
- # Define image transformation pipeline
68
  transform = transforms.Compose([
69
  transforms.Resize((128, 128)),
70
  transforms.ToTensor(),
71
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
 
72
  ])
73
 
74
  def process_image(image):
75
  if image is None:
76
  return None
77
 
78
- # Convert to RGB if necessary
79
- if image.mode != 'RGB':
80
- image = image.convert('RGB')
81
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def predict(image):
84
  if image is None:
@@ -86,30 +108,37 @@ def predict(image):
86
 
87
  try:
88
  # Process the image
89
- processed_image = process_image(Image.fromarray(image))
90
  if processed_image is None:
91
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
92
 
93
  # Transform for model
94
  tensor_image = transform(processed_image).unsqueeze(0)
 
95
 
96
  # Make prediction
97
  with torch.no_grad():
98
  outputs = model(tensor_image)
 
 
99
  probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
 
100
 
101
  # Return results
102
  classes = ["Rope", "Hammer", "Other"]
103
- return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
 
 
104
 
105
  except Exception as e:
106
  print(f"Prediction error: {str(e)}")
 
107
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
108
 
109
  # Gradio interface
110
  interface = gr.Interface(
111
  fn=predict,
112
- inputs=gr.Image(type="numpy"),
113
  outputs=gr.Label(num_top_classes=3),
114
  title="Mechanical Tools Classifier",
115
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
 
31
  print(f"Installing {package}...")
32
  install(f"{package}=={version}")
33
 
34
+
35
+
36
  import numpy as np
37
  import torch
38
  import torch.nn as nn
 
41
  from PIL import Image
42
  import gradio as gr
43
 
44
+ # Define the model exactly as in training
45
  class ModifiedLargeNet(nn.Module):
46
  def __init__(self):
47
  super(ModifiedLargeNet, self).__init__()
 
60
  x = self.fc2(x)
61
  return x
62
 
63
+ # Load the trained model with error handling
64
+ try:
65
+ model = ModifiedLargeNet()
66
+ state_dict = torch.load("modified_large_net.pt", map_location=torch.device("cpu"))
67
+ model.load_state_dict(state_dict)
68
+ print("Model loaded successfully")
69
+ model.eval()
70
+ except Exception as e:
71
+ print(f"Error loading model: {str(e)}")
72
+ traceback.print_exc()
73
 
74
+ # Define image transformation pipeline to match training
75
  transform = transforms.Compose([
76
  transforms.Resize((128, 128)),
77
  transforms.ToTensor(),
78
+ # Using standard normalization as in training
79
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
80
  ])
81
 
82
  def process_image(image):
83
  if image is None:
84
  return None
85
 
86
+ try:
87
+ # Convert numpy array to PIL Image
88
+ if isinstance(image, np.ndarray):
89
+ image = Image.fromarray(image)
90
+
91
+ # Convert to RGB if necessary
92
+ if image.mode != 'RGB':
93
+ image = image.convert('RGB')
94
+
95
+ # Print debug information
96
+ print(f"Processed image size: {image.size}")
97
+ print(f"Processed image mode: {image.mode}")
98
+
99
+ return image
100
+ except Exception as e:
101
+ print(f"Error in process_image: {str(e)}")
102
+ traceback.print_exc()
103
+ return None
104
 
105
  def predict(image):
106
  if image is None:
 
108
 
109
  try:
110
  # Process the image
111
+ processed_image = process_image(image)
112
  if processed_image is None:
113
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
114
 
115
  # Transform for model
116
  tensor_image = transform(processed_image).unsqueeze(0)
117
+ print(f"Input tensor shape: {tensor_image.shape}")
118
 
119
  # Make prediction
120
  with torch.no_grad():
121
  outputs = model(tensor_image)
122
+ print(f"Raw outputs: {outputs}")
123
+
124
  probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
125
+ print(f"Probabilities: {probabilities}")
126
 
127
  # Return results
128
  classes = ["Rope", "Hammer", "Other"]
129
+ results = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
130
+ print(f"Final results: {results}")
131
+ return results
132
 
133
  except Exception as e:
134
  print(f"Prediction error: {str(e)}")
135
+ traceback.print_exc()
136
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
137
 
138
  # Gradio interface
139
  interface = gr.Interface(
140
  fn=predict,
141
+ inputs=gr.Image(), # Accept any image format
142
  outputs=gr.Label(num_top_classes=3),
143
  title="Mechanical Tools Classifier",
144
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",