DataRaptor commited on
Commit
6197f1f
1 Parent(s): 845eb37

Upload 4 files

Browse files
Files changed (3) hide show
  1. ModelClass.py +28 -19
  2. app.py +4 -3
  3. model_weights.pth +3 -0
ModelClass.py CHANGED
@@ -3,22 +3,25 @@ from torch import nn
3
  from torchvision import transforms, models
4
 
5
  class ActionClassifier(nn.Module):
6
- def __init__(self, ntargets):
7
  super().__init__()
8
- resnet = models.resnet50(pretrained=True, progress=True)
9
  modules = list(resnet.children())[:-1] # delete last layer
 
10
  self.resnet = nn.Sequential(*modules)
11
- for param in self.resnet.parameters():
12
  param.requires_grad = False
 
13
  self.fc = nn.Sequential(
14
  nn.Flatten(),
15
  nn.BatchNorm1d(resnet.fc.in_features),
16
- nn.Dropout(0.2),
17
- nn.Linear(resnet.fc.in_features, 256),
18
  nn.ReLU(),
19
- nn.BatchNorm1d(256),
20
- nn.Dropout(0.2),
21
- nn.Linear(256, ntargets)
 
22
  )
23
 
24
  def forward(self, x):
@@ -27,22 +30,28 @@ class ActionClassifier(nn.Module):
27
  return x
28
 
29
 
30
-
31
  def get_transform():
32
- transform = transforms.Compose([
33
- transforms.Resize([224, 244]),
34
- transforms.ToTensor(),
35
- # std multiply by 255 to convert img of [0, 255]
36
- # to img of [0, 1]
37
- transforms.Normalize((0.485, 0.456, 0.406),
38
- (0.229*255, 0.224*255, 0.225*255))]
39
- )
40
  return transform
41
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def get_model():
44
- model = ActionClassifier(15)
45
- model.load_state_dict(torch.load('./classifier_weights.pth', map_location=torch.device('cpu')))
46
  return model
47
 
48
 
 
3
  from torchvision import transforms, models
4
 
5
  class ActionClassifier(nn.Module):
6
+ def __init__(self, train_last_nlayer, hidden_size, dropout, ntargets):
7
  super().__init__()
8
+ resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT, progress=True)
9
  modules = list(resnet.children())[:-1] # delete last layer
10
+
11
  self.resnet = nn.Sequential(*modules)
12
+ for param in self.resnet[:-train_last_nlayer].parameters():
13
  param.requires_grad = False
14
+
15
  self.fc = nn.Sequential(
16
  nn.Flatten(),
17
  nn.BatchNorm1d(resnet.fc.in_features),
18
+ nn.Dropout(dropout),
19
+ nn.Linear(resnet.fc.in_features, hidden_size),
20
  nn.ReLU(),
21
+ nn.BatchNorm1d(hidden_size),
22
+ nn.Dropout(dropout),
23
+ nn.Linear(hidden_size, ntargets),
24
+ nn.Sigmoid()
25
  )
26
 
27
  def forward(self, x):
 
30
  return x
31
 
32
 
 
33
  def get_transform():
34
+ transform = transforms.Compose([
35
+ transforms.Resize([224, 244]),
36
+ models.ResNet50_Weights.DEFAULT.transforms()
37
+ ])
 
 
 
 
38
  return transform
39
 
40
+ # def get_transform():
41
+ # transform = transforms.Compose([
42
+ # transforms.Resize([224, 244]),
43
+ # transforms.ToTensor(),
44
+ # # std multiply by 255 to convert img of [0, 255]
45
+ # # to img of [0, 1]
46
+ # transforms.Normalize((0.485, 0.456, 0.406),
47
+ # (0.229*255, 0.224*255, 0.225*255))]
48
+ # )
49
+ # return transform
50
+
51
 
52
  def get_model():
53
+ model = ActionClassifier(0, 512, 0.2, 15)
54
+ model.load_state_dict(torch.load('./model_weights.pth', map_location=torch.device('cpu')))
55
  return model
56
 
57
 
app.py CHANGED
@@ -35,7 +35,7 @@ def infer(img):
35
 
36
 
37
  st.set_page_config(
38
- page_title="Whale Identification",
39
  page_icon="🧊",
40
  layout="centered",
41
  initial_sidebar_state="expanded",
@@ -86,7 +86,7 @@ hide_st_style = """
86
  header {visibility: hidden;}
87
  </style>
88
  """
89
- #st.markdown(hide_st_style, unsafe_allow_html=True)
90
 
91
 
92
 
@@ -129,10 +129,11 @@ def app():
129
 
130
  res = infer(image)
131
  prob = res.numpy()
132
- idx = np.argpartition(prob, -4)[-4:]
133
  right_column.markdown('#### Results')
134
 
135
  idx = list(idx)
 
136
  for i in idx:
137
 
138
  class_name = ModelClass.get_class(i).replace('_', ' ').capitalize()
 
35
 
36
 
37
  st.set_page_config(
38
+ page_title="ActionNet",
39
  page_icon="🧊",
40
  layout="centered",
41
  initial_sidebar_state="expanded",
 
86
  header {visibility: hidden;}
87
  </style>
88
  """
89
+ st.markdown(hide_st_style, unsafe_allow_html=True)
90
 
91
 
92
 
 
129
 
130
  res = infer(image)
131
  prob = res.numpy()
132
+ idx = np.argpartition(prob, -6)[-6:]
133
  right_column.markdown('#### Results')
134
 
135
  idx = list(idx)
136
+ idx.sort(key=lambda x: prob[x].astype(float), reverse=True)
137
  for i in idx:
138
 
139
  class_name = ModelClass.get_class(i).replace('_', ' ').capitalize()
model_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d300834d5794b294533827f8f7200c7f5fa29fb984fa17075ae0b87b8e4c7e6
3
+ size 98624253