yucelgumus61 commited on
Commit
acb1b3c
·
verified ·
1 Parent(s): af6f2d4

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +20 -0
model.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torch.nn as nn
3
+
4
+
5
+ class ExampleModel(nn.Module):
6
+ def __init__(self, num_classes):
7
+ super(ExampleModel, self).__init__()
8
+ self.base_model = timm.create_model("efficientnet_b0", pretrained=True)
9
+ self.features = nn.Sequential(*list(self.base_model.children())[:-1])
10
+
11
+ network_out_size = 1280
12
+
13
+ self.classifier = nn.Sequential(
14
+ nn.Flatten(), nn.Linear(network_out_size, num_classes)
15
+ )
16
+
17
+ def forward(self, x):
18
+ x = self.features(x)
19
+ output = self.classifier(x)
20
+ return output