enesmanan committed on
Commit
fe1195b
·
verified ·
1 Parent(s): 7c9319c

update app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -151
app.py CHANGED
@@ -1,152 +1,218 @@
1
- import os
2
- import gradio as gr
3
- import torch
4
- import torch.nn.functional as F
5
- import numpy as np
6
- from PIL import Image
7
- import torchvision.transforms as transforms
8
- import matplotlib.pyplot as plt
9
-
10
- from models.model import EfficientNetModel, CNNModel
11
-
12
- class AnimalClassifierApp:
13
- def __init__(self):
14
- """Initialize the application."""
15
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- self.labels = ["bird", "cat", "dog", "horse"]
17
-
18
- # Image preprocessing
19
- self.transform = transforms.Compose([
20
- transforms.Resize((224, 224)),
21
- transforms.ToTensor(),
22
- transforms.Normalize(
23
- mean=[0.485, 0.456, 0.406],
24
- std=[0.229, 0.224, 0.225]
25
- )
26
- ])
27
-
28
- # Load models
29
- self.models = self.load_models()
30
- if not self.models:
31
- print("Warning: No models found in checkpoints directory!")
32
-
33
- def load_models(self):
34
- """Load both trained models."""
35
- models = {}
36
-
37
- # Load EfficientNet
38
- try:
39
- efficientnet = EfficientNetModel(num_classes=len(self.labels))
40
- efficientnet_path = os.path.join("checkpoints", "efficientnet", "efficientnet_best_model.pth")
41
- if os.path.exists(efficientnet_path):
42
- checkpoint = torch.load(efficientnet_path, map_location=self.device, weights_only=True)
43
- state_dict = checkpoint.get('model_state_dict', checkpoint)
44
- efficientnet.load_state_dict(state_dict, strict=False)
45
- efficientnet.eval()
46
- models['EfficientNet'] = efficientnet
47
- print("Successfully loaded EfficientNet model")
48
- except Exception as e:
49
- print(f"Error loading EfficientNet model: {str(e)}")
50
-
51
- # Load CNN
52
- try:
53
- cnn = CNNModel(num_classes=len(self.labels))
54
- cnn_path = os.path.join("checkpoints", "cnn", "cnn_best_model.pth")
55
- if os.path.exists(cnn_path):
56
- checkpoint = torch.load(cnn_path, map_location=self.device, weights_only=True)
57
- state_dict = checkpoint.get('model_state_dict', checkpoint)
58
- cnn.load_state_dict(state_dict, strict=False)
59
- cnn.eval()
60
- models['CNN'] = cnn
61
- print("Successfully loaded CNN model")
62
- except Exception as e:
63
- print(f"Error loading CNN model: {str(e)}")
64
-
65
- return models
66
-
67
- def predict(self, image: Image.Image):
68
- """Make predictions with both models and create comparison visualizations."""
69
- if not self.models:
70
- return "No trained models found. Please train the models first."
71
-
72
- # Preprocess image
73
- img_tensor = self.transform(image).unsqueeze(0).to(self.device)
74
-
75
- # Get predictions from both models
76
- results = {}
77
- probabilities = {}
78
- for model_name, model in self.models.items():
79
- with torch.no_grad():
80
- output = model(img_tensor)
81
- probs = F.softmax(output, dim=1).squeeze().cpu().numpy()
82
- probabilities[model_name] = probs
83
-
84
- # Get top prediction
85
- pred_idx = np.argmax(probs)
86
- pred_label = self.labels[pred_idx]
87
- pred_prob = probs[pred_idx]
88
- results[model_name] = (pred_label, pred_prob)
89
-
90
- # Create comparison plot
91
- fig = plt.figure(figsize=(12, 5))
92
-
93
- # Plot for EfficientNet
94
- if 'EfficientNet' in probabilities:
95
- plt.subplot(1, 2, 1)
96
- plt.bar(self.labels, probabilities['EfficientNet'], color='skyblue')
97
- plt.title('EfficientNet Predictions')
98
- plt.ylim(0, 1)
99
- plt.xticks(rotation=45)
100
- plt.ylabel('Probability')
101
-
102
- # Plot for CNN
103
- if 'CNN' in probabilities:
104
- plt.subplot(1, 2, 2)
105
- plt.bar(self.labels, probabilities['CNN'], color='lightcoral')
106
- plt.title('CNN Predictions')
107
- plt.ylim(0, 1)
108
- plt.xticks(rotation=45)
109
- plt.ylabel('Probability')
110
-
111
- plt.tight_layout()
112
-
113
- # Create results text
114
- text_results = "Model Predictions:\n\n"
115
- for model_name, (label, prob) in results.items():
116
- text_results += f"{model_name}:\n"
117
- text_results += f"Top prediction: {label} ({prob:.2%})\n"
118
- text_results += "All probabilities:\n"
119
- for label, prob in zip(self.labels, probabilities[model_name]):
120
- text_results += f" {label}: {prob:.2%}\n"
121
- text_results += "\n"
122
-
123
- return [
124
- fig, # Probability plots
125
- text_results # Detailed text results
126
- ]
127
-
128
- def create_interface(self):
129
- """Create Gradio interface."""
130
- return gr.Interface(
131
- fn=self.predict,
132
- inputs=gr.Image(type="pil"),
133
- outputs=[
134
- gr.Plot(label="Prediction Probabilities"),
135
- gr.Textbox(label="Detailed Results", lines=10)
136
- ],
137
- title="Animal Classifier - Model Comparison",
138
- description="Upload an image of an animal to see predictions from both EfficientNet and CNN models."
139
- )
140
-
141
- def main():
142
- """Run the web application."""
143
- app = AnimalClassifierApp()
144
- interface = app.create_interface()
145
- interface.launch(
146
- server_name="0.0.0.0",
147
- server_port=7860,
148
- share=True
149
- )
150
-
151
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  main()
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torchvision.transforms as transforms
9
+ import matplotlib.pyplot as plt
10
+ import timm
11
+
12
+ class BaseModel(nn.Module):
13
+ def predict(self, x: torch.Tensor) -> torch.Tensor:
14
+ with torch.no_grad():
15
+ logits = self(x)
16
+ return F.softmax(logits, dim=1)
17
+
18
+ def get_num_classes(self) -> int:
19
+ raise NotImplementedError
20
+
21
+
22
+ class CNNModel(BaseModel):
23
+ def __init__(self, num_classes: int, input_size: int = 224):
24
+ super(CNNModel, self).__init__()
25
+
26
+ self.conv_layers = nn.Sequential(
27
+ # First block: 32 filters
28
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
29
+ nn.BatchNorm2d(32),
30
+ nn.ReLU(),
31
+ nn.MaxPool2d(2),
32
+
33
+ # Second block: 64 filters
34
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
35
+ nn.BatchNorm2d(64),
36
+ nn.ReLU(),
37
+ nn.MaxPool2d(2),
38
+
39
+ # Third block: 128 filters
40
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
41
+ nn.BatchNorm2d(128),
42
+ nn.ReLU(),
43
+ nn.MaxPool2d(2),
44
+
45
+ # Global Average Pooling
46
+ nn.AdaptiveAvgPool2d(1)
47
+ )
48
+
49
+ self.classifier = nn.Sequential(
50
+ nn.Flatten(),
51
+ nn.Dropout(0.5),
52
+ nn.Linear(128, 256),
53
+ nn.ReLU(),
54
+ nn.Dropout(0.3),
55
+ nn.Linear(256, num_classes)
56
+ )
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ x = self.conv_layers(x)
60
+ return self.classifier(x)
61
+
62
+ def get_num_classes(self) -> int:
63
+ return self.classifier[-1].out_features
64
+
65
+
66
+ class EfficientNetModel(BaseModel):
67
+ def __init__(
68
+ self,
69
+ num_classes: int,
70
+ model_name: str = "efficientnet_b0",
71
+ pretrained: bool = True
72
+ ):
73
+ super(EfficientNetModel, self).__init__()
74
+
75
+ self.base_model = timm.create_model(
76
+ model_name,
77
+ pretrained=pretrained,
78
+ num_classes=0
79
+ )
80
+
81
+ with torch.no_grad():
82
+ dummy_input = torch.randn(1, 3, 224, 224)
83
+ features = self.base_model(dummy_input)
84
+ feature_dim = features.shape[1]
85
+
86
+ self.classifier = nn.Sequential(
87
+ nn.Dropout(0.2),
88
+ nn.Linear(feature_dim, num_classes)
89
+ )
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ features = self.base_model(x)
93
+ return self.classifier(features)
94
+
95
+ def get_num_classes(self) -> int:
96
+ return self.classifier[-1].out_features
97
+
98
+
99
+ class AnimalClassifierApp:
100
+ def __init__(self):
101
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102
+ self.labels = ["bird", "cat", "dog", "horse"]
103
+
104
+ self.transform = transforms.Compose([
105
+ transforms.Resize((224, 224)),
106
+ transforms.ToTensor(),
107
+ transforms.Normalize(
108
+ mean=[0.485, 0.456, 0.406],
109
+ std=[0.229, 0.224, 0.225]
110
+ )
111
+ ])
112
+
113
+ self.models = self.load_models()
114
+ if not self.models:
115
+ print("Warning: No models found in checkpoints directory!")
116
+
117
+ def load_models(self):
118
+ models = {}
119
+
120
+ # Load EfficientNet
121
+ try:
122
+ efficientnet = EfficientNetModel(num_classes=len(self.labels))
123
+ efficientnet_path = "efficientnet_best_model.pth"
124
+ if os.path.exists(efficientnet_path):
125
+ checkpoint = torch.load(efficientnet_path, map_location=self.device, weights_only=True)
126
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
127
+ efficientnet.load_state_dict(state_dict, strict=False)
128
+ efficientnet.eval()
129
+ models['EfficientNet'] = efficientnet
130
+ print("Successfully loaded EfficientNet model")
131
+ except Exception as e:
132
+ print(f"Error loading EfficientNet model: {str(e)}")
133
+
134
+ # Load CNN
135
+ try:
136
+ cnn = CNNModel(num_classes=len(self.labels))
137
+ cnn_path = "cnn_best_model.pth"
138
+ if os.path.exists(cnn_path):
139
+ checkpoint = torch.load(cnn_path, map_location=self.device, weights_only=True)
140
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
141
+ cnn.load_state_dict(state_dict, strict=False)
142
+ cnn.eval()
143
+ models['CNN'] = cnn
144
+ print("Successfully loaded CNN model")
145
+ except Exception as e:
146
+ print(f"Error loading CNN model: {str(e)}")
147
+
148
+ return models
149
+
150
+ def predict(self, image: Image.Image):
151
+ if not self.models:
152
+ return "No trained models found. Please train the models first."
153
+
154
+ img_tensor = self.transform(image).unsqueeze(0).to(self.device)
155
+
156
+ results = {}
157
+ probabilities = {}
158
+ for model_name, model in self.models.items():
159
+ with torch.no_grad():
160
+ output = model(img_tensor)
161
+ probs = F.softmax(output, dim=1).squeeze().cpu().numpy()
162
+ probabilities[model_name] = probs
163
+
164
+ pred_idx = np.argmax(probs)
165
+ pred_label = self.labels[pred_idx]
166
+ pred_prob = probs[pred_idx]
167
+ results[model_name] = (pred_label, pred_prob)
168
+
169
+ fig = plt.figure(figsize=(12, 5))
170
+
171
+ if 'EfficientNet' in probabilities:
172
+ plt.subplot(1, 2, 1)
173
+ plt.bar(self.labels, probabilities['EfficientNet'], color='skyblue')
174
+ plt.title('EfficientNet Predictions')
175
+ plt.ylim(0, 1)
176
+ plt.xticks(rotation=45)
177
+ plt.ylabel('Probability')
178
+
179
+ if 'CNN' in probabilities:
180
+ plt.subplot(1, 2, 2)
181
+ plt.bar(self.labels, probabilities['CNN'], color='lightcoral')
182
+ plt.title('CNN Predictions')
183
+ plt.ylim(0, 1)
184
+ plt.xticks(rotation=45)
185
+ plt.ylabel('Probability')
186
+
187
+ plt.tight_layout()
188
+
189
+ text_results = "Model Predictions:\n\n"
190
+ for model_name, (label, prob) in results.items():
191
+ text_results += f"{model_name}:\n"
192
+ text_results += f"Top prediction: {label} ({prob:.2%})\n"
193
+ text_results += "All probabilities:\n"
194
+ for label, prob in zip(self.labels, probabilities[model_name]):
195
+ text_results += f" {label}: {prob:.2%}\n"
196
+ text_results += "\n"
197
+
198
+ return [fig, text_results]
199
+
200
+ def create_interface(self):
201
+ return gr.Interface(
202
+ fn=self.predict,
203
+ inputs=gr.Image(type="pil"),
204
+ outputs=[
205
+ gr.Plot(label="Prediction Probabilities"),
206
+ gr.Textbox(label="Detailed Results", lines=10)
207
+ ],
208
+ title="Animal Classifier - Model Comparison",
209
+ description="Upload an image of an animal to see predictions from both EfficientNet and CNN models."
210
+ )
211
+
212
+ def main():
213
+ app = AnimalClassifierApp()
214
+ interface = app.create_interface()
215
+ interface.launch()
216
+
217
+ if __name__ == "__main__":
218
  main()