enesmanan commited on
Commit
d13bccc
·
verified ·
1 Parent(s): fe1195b

update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -5
app.py CHANGED
@@ -9,7 +9,9 @@ import torchvision.transforms as transforms
9
  import matplotlib.pyplot as plt
10
  import timm
11
 
 
12
  class BaseModel(nn.Module):
 
13
  def predict(self, x: torch.Tensor) -> torch.Tensor:
14
  with torch.no_grad():
15
  logits = self(x)
@@ -20,6 +22,7 @@ class BaseModel(nn.Module):
20
 
21
 
22
  class CNNModel(BaseModel):
 
23
  def __init__(self, num_classes: int, input_size: int = 224):
24
  super(CNNModel, self).__init__()
25
 
@@ -55,6 +58,21 @@ class CNNModel(BaseModel):
55
  nn.Linear(256, num_classes)
56
  )
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def forward(self, x: torch.Tensor) -> torch.Tensor:
59
  x = self.conv_layers(x)
60
  return self.classifier(x)
@@ -64,6 +82,7 @@ class CNNModel(BaseModel):
64
 
65
 
66
  class EfficientNetModel(BaseModel):
 
67
  def __init__(
68
  self,
69
  num_classes: int,
@@ -98,6 +117,7 @@ class EfficientNetModel(BaseModel):
98
 
99
  class AnimalClassifierApp:
100
  def __init__(self):
 
101
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102
  self.labels = ["bird", "cat", "dog", "horse"]
103
 
@@ -115,12 +135,12 @@ class AnimalClassifierApp:
115
  print("Warning: No models found in checkpoints directory!")
116
 
117
  def load_models(self):
 
118
  models = {}
119
 
120
- # Load EfficientNet
121
  try:
122
  efficientnet = EfficientNetModel(num_classes=len(self.labels))
123
- efficientnet_path = "efficientnet_best_model.pth"
124
  if os.path.exists(efficientnet_path):
125
  checkpoint = torch.load(efficientnet_path, map_location=self.device, weights_only=True)
126
  state_dict = checkpoint.get('model_state_dict', checkpoint)
@@ -131,10 +151,9 @@ class AnimalClassifierApp:
131
  except Exception as e:
132
  print(f"Error loading EfficientNet model: {str(e)}")
133
 
134
- # Load CNN
135
  try:
136
  cnn = CNNModel(num_classes=len(self.labels))
137
- cnn_path = "cnn_best_model.pth"
138
  if os.path.exists(cnn_path):
139
  checkpoint = torch.load(cnn_path, map_location=self.device, weights_only=True)
140
  state_dict = checkpoint.get('model_state_dict', checkpoint)
@@ -149,8 +168,9 @@ class AnimalClassifierApp:
149
 
150
  def predict(self, image: Image.Image):
151
  if not self.models:
152
- return "No trained models found. Please train the models first."
153
 
 
154
  img_tensor = self.transform(image).unsqueeze(0).to(self.device)
155
 
156
  results = {}
@@ -209,10 +229,12 @@ class AnimalClassifierApp:
209
  description="Upload an image of an animal to see predictions from both EfficientNet and CNN models."
210
  )
211
 
 
212
  def main():
213
  app = AnimalClassifierApp()
214
  interface = app.create_interface()
215
  interface.launch()
216
 
 
217
  if __name__ == "__main__":
218
  main()
 
9
  import matplotlib.pyplot as plt
10
  import timm
11
 
12
+
13
  class BaseModel(nn.Module):
14
+
15
  def predict(self, x: torch.Tensor) -> torch.Tensor:
16
  with torch.no_grad():
17
  logits = self(x)
 
22
 
23
 
24
  class CNNModel(BaseModel):
25
+
26
  def __init__(self, num_classes: int, input_size: int = 224):
27
  super(CNNModel, self).__init__()
28
 
 
58
  nn.Linear(256, num_classes)
59
  )
60
 
61
+ self._initialize_weights()
62
+
63
+ def _initialize_weights(self):
64
+ for m in self.modules():
65
+ if isinstance(m, nn.Conv2d):
66
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
67
+ if m.bias is not None:
68
+ nn.init.constant_(m.bias, 0)
69
+ elif isinstance(m, nn.BatchNorm2d):
70
+ nn.init.constant_(m.weight, 1)
71
+ nn.init.constant_(m.bias, 0)
72
+ elif isinstance(m, nn.Linear):
73
+ nn.init.normal_(m.weight, 0, 0.01)
74
+ nn.init.constant_(m.bias, 0)
75
+
76
  def forward(self, x: torch.Tensor) -> torch.Tensor:
77
  x = self.conv_layers(x)
78
  return self.classifier(x)
 
82
 
83
 
84
  class EfficientNetModel(BaseModel):
85
+
86
  def __init__(
87
  self,
88
  num_classes: int,
 
117
 
118
  class AnimalClassifierApp:
119
  def __init__(self):
120
+ """Initialize the application."""
121
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
122
  self.labels = ["bird", "cat", "dog", "horse"]
123
 
 
135
  print("Warning: No models found in checkpoints directory!")
136
 
137
  def load_models(self):
138
+ """Load both trained models."""
139
  models = {}
140
 
 
141
  try:
142
  efficientnet = EfficientNetModel(num_classes=len(self.labels))
143
+ efficientnet_path = os.path.join("checkpoints", "efficientnet", "efficientnet_best_model.pth")
144
  if os.path.exists(efficientnet_path):
145
  checkpoint = torch.load(efficientnet_path, map_location=self.device, weights_only=True)
146
  state_dict = checkpoint.get('model_state_dict', checkpoint)
 
151
  except Exception as e:
152
  print(f"Error loading EfficientNet model: {str(e)}")
153
 
 
154
  try:
155
  cnn = CNNModel(num_classes=len(self.labels))
156
+ cnn_path = os.path.join("checkpoints", "cnn", "cnn_best_model.pth")
157
  if os.path.exists(cnn_path):
158
  checkpoint = torch.load(cnn_path, map_location=self.device, weights_only=True)
159
  state_dict = checkpoint.get('model_state_dict', checkpoint)
 
168
 
169
  def predict(self, image: Image.Image):
170
  if not self.models:
171
+ return ["No trained models found. Please train the models first.", ""]
172
 
173
+ # Preprocess image
174
  img_tensor = self.transform(image).unsqueeze(0).to(self.device)
175
 
176
  results = {}
 
229
  description="Upload an image of an animal to see predictions from both EfficientNet and CNN models."
230
  )
231
 
232
+
233
  def main():
234
  app = AnimalClassifierApp()
235
  interface = app.create_interface()
236
  interface.launch()
237
 
238
+
239
  if __name__ == "__main__":
240
  main()