StoneSeller commited on
Commit
d16040b
·
verified ·
1 Parent(s): 7545949

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -25
app.py CHANGED
@@ -71,10 +71,11 @@ except Exception as e:
71
  print(f"Error loading model: {str(e)}")
72
  traceback.print_exc()
73
 
 
74
  transform = transforms.Compose([
75
  transforms.Resize((128, 128)),
76
- transforms.Lambda(lambda x: x.convert('RGB')),
77
- transforms.ToTensor(),
78
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
79
  ])
80
 
@@ -83,17 +84,14 @@ def process_image(image):
83
  return None
84
 
85
  try:
86
- # numpy array PIL Image로 변환
87
  if isinstance(image, np.ndarray):
88
  image = Image.fromarray(image.astype('uint8'))
89
 
90
- # 이미지가 RGB 아니면 변환
91
  if image.mode != 'RGB':
92
  image = image.convert('RGB')
93
 
94
- # 이미지 크기 조정
95
- image = image.resize((128, 128), Image.Resampling.LANCZOS)
96
-
97
  print(f"Processed image size: {image.size}")
98
  print(f"Processed image mode: {image.mode}")
99
 
@@ -108,32 +106,21 @@ def predict(image):
108
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
109
 
110
  try:
111
- # 이미지 처리
112
  processed_image = process_image(image)
113
  if processed_image is None:
114
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
115
 
116
- # PIL Image를 텐서로 변환
117
  try:
118
- # PIL Image를 numpy array로 변환
119
- img_array = np.array(processed_image)
120
- # numpy array를 torch tensor로 변환
121
- tensor_image = torch.from_numpy(img_array.transpose((2, 0, 1))).float() / 255.0
122
- # 정규화
123
- tensor_image = transforms.Normalize(
124
- mean=[0.485, 0.456, 0.406],
125
- std=[0.229, 0.224, 0.225]
126
- )(tensor_image)
127
- # 배치 차원 추가
128
- tensor_image = tensor_image.unsqueeze(0)
129
-
130
  print(f"Input tensor shape: {tensor_image.shape}")
131
  except Exception as e:
132
  print(f"Error in tensor conversion: {str(e)}")
133
  traceback.print_exc()
134
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
135
 
136
- # 예측 수행
137
  with torch.no_grad():
138
  outputs = model(tensor_image)
139
  print(f"Raw outputs: {outputs}")
@@ -141,7 +128,7 @@ def predict(image):
141
  probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
142
  print(f"Probabilities: {probabilities}")
143
 
144
- # 결과 반환
145
  classes = ["Rope", "Hammer", "Other"]
146
  results = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
147
  print(f"Final results: {results}")
@@ -152,7 +139,7 @@ def predict(image):
152
  traceback.print_exc()
153
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
154
 
155
- # Gradio 인터페이스
156
  interface = gr.Interface(
157
  fn=predict,
158
  inputs=gr.Image(),
@@ -161,6 +148,6 @@ interface = gr.Interface(
161
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
162
  )
163
 
164
- # 인터페이스 실행
165
  if __name__ == "__main__":
166
  interface.launch()
 
71
  print(f"Error loading model: {str(e)}")
72
  traceback.print_exc()
73
 
74
+ # Define image transformation pipeline
75
  transform = transforms.Compose([
76
  transforms.Resize((128, 128)),
77
+ transforms.PILToTensor(), # Changed from ToTensor()
78
+ transforms.ConvertImageDtype(torch.float32),
79
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
80
  ])
81
 
 
84
  return None
85
 
86
  try:
87
+ # Convert numpy array to PIL Image
88
  if isinstance(image, np.ndarray):
89
  image = Image.fromarray(image.astype('uint8'))
90
 
91
+ # Convert to RGB if necessary
92
  if image.mode != 'RGB':
93
  image = image.convert('RGB')
94
 
 
 
 
95
  print(f"Processed image size: {image.size}")
96
  print(f"Processed image mode: {image.mode}")
97
 
 
106
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
107
 
108
  try:
109
+ # Process the image
110
  processed_image = process_image(image)
111
  if processed_image is None:
112
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
113
 
114
+ # Transform image to tensor using torchvision transforms
115
  try:
116
+ tensor_image = transform(processed_image).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
117
  print(f"Input tensor shape: {tensor_image.shape}")
118
  except Exception as e:
119
  print(f"Error in tensor conversion: {str(e)}")
120
  traceback.print_exc()
121
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
122
 
123
+ # Make prediction
124
  with torch.no_grad():
125
  outputs = model(tensor_image)
126
  print(f"Raw outputs: {outputs}")
 
128
  probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
129
  print(f"Probabilities: {probabilities}")
130
 
131
+ # Return results
132
  classes = ["Rope", "Hammer", "Other"]
133
  results = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
134
  print(f"Final results: {results}")
 
139
  traceback.print_exc()
140
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
141
 
142
+ # Gradio interface
143
  interface = gr.Interface(
144
  fn=predict,
145
  inputs=gr.Image(),
 
148
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
149
  )
150
 
151
+ # Launch the interface
152
  if __name__ == "__main__":
153
  interface.launch()