Spaces:
Running
Running
update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,9 @@ import torchvision.transforms as transforms
|
|
9 |
import matplotlib.pyplot as plt
|
10 |
import timm
|
11 |
|
|
|
12 |
class BaseModel(nn.Module):
|
|
|
13 |
def predict(self, x: torch.Tensor) -> torch.Tensor:
|
14 |
with torch.no_grad():
|
15 |
logits = self(x)
|
@@ -20,6 +22,7 @@ class BaseModel(nn.Module):
|
|
20 |
|
21 |
|
22 |
class CNNModel(BaseModel):
|
|
|
23 |
def __init__(self, num_classes: int, input_size: int = 224):
|
24 |
super(CNNModel, self).__init__()
|
25 |
|
@@ -55,6 +58,21 @@ class CNNModel(BaseModel):
|
|
55 |
nn.Linear(256, num_classes)
|
56 |
)
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
59 |
x = self.conv_layers(x)
|
60 |
return self.classifier(x)
|
@@ -64,6 +82,7 @@ class CNNModel(BaseModel):
|
|
64 |
|
65 |
|
66 |
class EfficientNetModel(BaseModel):
|
|
|
67 |
def __init__(
|
68 |
self,
|
69 |
num_classes: int,
|
@@ -98,6 +117,7 @@ class EfficientNetModel(BaseModel):
|
|
98 |
|
99 |
class AnimalClassifierApp:
|
100 |
def __init__(self):
|
|
|
101 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
102 |
self.labels = ["bird", "cat", "dog", "horse"]
|
103 |
|
@@ -115,12 +135,12 @@ class AnimalClassifierApp:
|
|
115 |
print("Warning: No models found in checkpoints directory!")
|
116 |
|
117 |
def load_models(self):
|
|
|
118 |
models = {}
|
119 |
|
120 |
-
# Load EfficientNet
|
121 |
try:
|
122 |
efficientnet = EfficientNetModel(num_classes=len(self.labels))
|
123 |
-
efficientnet_path = "efficientnet_best_model.pth"
|
124 |
if os.path.exists(efficientnet_path):
|
125 |
checkpoint = torch.load(efficientnet_path, map_location=self.device, weights_only=True)
|
126 |
state_dict = checkpoint.get('model_state_dict', checkpoint)
|
@@ -131,10 +151,9 @@ class AnimalClassifierApp:
|
|
131 |
except Exception as e:
|
132 |
print(f"Error loading EfficientNet model: {str(e)}")
|
133 |
|
134 |
-
# Load CNN
|
135 |
try:
|
136 |
cnn = CNNModel(num_classes=len(self.labels))
|
137 |
-
cnn_path = "cnn_best_model.pth"
|
138 |
if os.path.exists(cnn_path):
|
139 |
checkpoint = torch.load(cnn_path, map_location=self.device, weights_only=True)
|
140 |
state_dict = checkpoint.get('model_state_dict', checkpoint)
|
@@ -149,8 +168,9 @@ class AnimalClassifierApp:
|
|
149 |
|
150 |
def predict(self, image: Image.Image):
|
151 |
if not self.models:
|
152 |
-
return "No trained models found. Please train the models first."
|
153 |
|
|
|
154 |
img_tensor = self.transform(image).unsqueeze(0).to(self.device)
|
155 |
|
156 |
results = {}
|
@@ -209,10 +229,12 @@ class AnimalClassifierApp:
|
|
209 |
description="Upload an image of an animal to see predictions from both EfficientNet and CNN models."
|
210 |
)
|
211 |
|
|
|
212 |
def main():
|
213 |
app = AnimalClassifierApp()
|
214 |
interface = app.create_interface()
|
215 |
interface.launch()
|
216 |
|
|
|
217 |
if __name__ == "__main__":
|
218 |
main()
|
|
|
9 |
import matplotlib.pyplot as plt
|
10 |
import timm
|
11 |
|
12 |
+
|
13 |
class BaseModel(nn.Module):
|
14 |
+
|
15 |
def predict(self, x: torch.Tensor) -> torch.Tensor:
|
16 |
with torch.no_grad():
|
17 |
logits = self(x)
|
|
|
22 |
|
23 |
|
24 |
class CNNModel(BaseModel):
|
25 |
+
|
26 |
def __init__(self, num_classes: int, input_size: int = 224):
|
27 |
super(CNNModel, self).__init__()
|
28 |
|
|
|
58 |
nn.Linear(256, num_classes)
|
59 |
)
|
60 |
|
61 |
+
self._initialize_weights()
|
62 |
+
|
63 |
+
def _initialize_weights(self):
|
64 |
+
for m in self.modules():
|
65 |
+
if isinstance(m, nn.Conv2d):
|
66 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
67 |
+
if m.bias is not None:
|
68 |
+
nn.init.constant_(m.bias, 0)
|
69 |
+
elif isinstance(m, nn.BatchNorm2d):
|
70 |
+
nn.init.constant_(m.weight, 1)
|
71 |
+
nn.init.constant_(m.bias, 0)
|
72 |
+
elif isinstance(m, nn.Linear):
|
73 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
74 |
+
nn.init.constant_(m.bias, 0)
|
75 |
+
|
76 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
77 |
x = self.conv_layers(x)
|
78 |
return self.classifier(x)
|
|
|
82 |
|
83 |
|
84 |
class EfficientNetModel(BaseModel):
|
85 |
+
|
86 |
def __init__(
|
87 |
self,
|
88 |
num_classes: int,
|
|
|
117 |
|
118 |
class AnimalClassifierApp:
|
119 |
def __init__(self):
|
120 |
+
"""Initialize the application."""
|
121 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
122 |
self.labels = ["bird", "cat", "dog", "horse"]
|
123 |
|
|
|
135 |
print("Warning: No models found in checkpoints directory!")
|
136 |
|
137 |
def load_models(self):
|
138 |
+
"""Load both trained models."""
|
139 |
models = {}
|
140 |
|
|
|
141 |
try:
|
142 |
efficientnet = EfficientNetModel(num_classes=len(self.labels))
|
143 |
+
efficientnet_path = os.path.join("checkpoints", "efficientnet", "efficientnet_best_model.pth")
|
144 |
if os.path.exists(efficientnet_path):
|
145 |
checkpoint = torch.load(efficientnet_path, map_location=self.device, weights_only=True)
|
146 |
state_dict = checkpoint.get('model_state_dict', checkpoint)
|
|
|
151 |
except Exception as e:
|
152 |
print(f"Error loading EfficientNet model: {str(e)}")
|
153 |
|
|
|
154 |
try:
|
155 |
cnn = CNNModel(num_classes=len(self.labels))
|
156 |
+
cnn_path = os.path.join("checkpoints", "cnn", "cnn_best_model.pth")
|
157 |
if os.path.exists(cnn_path):
|
158 |
checkpoint = torch.load(cnn_path, map_location=self.device, weights_only=True)
|
159 |
state_dict = checkpoint.get('model_state_dict', checkpoint)
|
|
|
168 |
|
169 |
def predict(self, image: Image.Image):
|
170 |
if not self.models:
|
171 |
+
return ["No trained models found. Please train the models first.", ""]
|
172 |
|
173 |
+
# Preprocess image
|
174 |
img_tensor = self.transform(image).unsqueeze(0).to(self.device)
|
175 |
|
176 |
results = {}
|
|
|
229 |
description="Upload an image of an animal to see predictions from both EfficientNet and CNN models."
|
230 |
)
|
231 |
|
232 |
+
|
233 |
def main():
|
234 |
app = AnimalClassifierApp()
|
235 |
interface = app.create_interface()
|
236 |
interface.launch()
|
237 |
|
238 |
+
|
239 |
if __name__ == "__main__":
|
240 |
main()
|