letrunglinh commited on
Commit
3eda55a
·
1 Parent(s): c419dfc

Upload 2 files

Browse files
Files changed (2) hide show
  1. last_0603_92.pth +3 -0
  2. model_fasterrcnn.py +77 -0
last_0603_92.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dc05f2a61959c1a0e91e1381e461638b4c5ff872f07d0fc383c5425b0e2871b
3
+ size 107986969
model_fasterrcnn.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ # import torch.nn.utils.prune as prune
4
+
5
+
6
+ import torchvision.models as models
7
+ import torchvision
8
+ # from torchsummary import summary
9
+
10
+ class MobileNetV2FeatureExtractor(nn.Module):
11
+ def __init__(self):
12
+ super(MobileNetV2FeatureExtractor, self).__init__()
13
+ self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=False)
14
+
15
+ for param in self.model.parameters():
16
+ param.requires_grad = True
17
+ self.model = self.model.backbone
18
+
19
+ def forward(self, x):
20
+ return self.model(x)
21
+ class GlobalAvgPool2D(nn.Module):
22
+ def __init__(self):
23
+ super(GlobalAvgPool2D, self).__init__()
24
+
25
+ def forward(self, x):
26
+ tensor = x['0']
27
+ return torch.mean(tensor.view(tensor.size(0), tensor.size(1), -1), dim=2)
28
+
29
+
30
+ class LDRNet_fasterrcnn(nn.Module):
31
+ def __init__(self, points_size=100, classification_list=[1]):
32
+ super(LDRNet_fasterrcnn, self).__init__()
33
+
34
+ self.points_size = points_size
35
+ self.classification_list = classification_list
36
+
37
+ self.backbone = MobileNetV2FeatureExtractor()
38
+ if len(classification_list) > 0:
39
+ class_size = sum(self.classification_list)
40
+ else:
41
+ class_size = 0
42
+ self.global_pool = GlobalAvgPool2D()
43
+ # self.dropout = nn.Dropout(p=0.3)
44
+ self.corner = nn.Linear(256, 8)
45
+ self.border = nn.Linear(256, (points_size - 4) * 2)
46
+ self.cls = nn.Linear(256, class_size + len(self.classification_list))
47
+
48
+ def forward(self, x):
49
+ x = self.backbone(x)
50
+ x = self.global_pool(x)
51
+ # x = self.dropout(x)
52
+ corner_output = self.corner(x)
53
+ border_output = self.border(x)
54
+ cls_output = self.cls(x)
55
+ return corner_output, border_output, cls_output
56
+ if __name__ == "__main__":
57
+ import torch
58
+ # from torchsummary import summary
59
+ xx = torch.zeros((1, 3, 224, 224))
60
+ model = LDRNet_fasterrcnn()
61
+ print(model)
62
+ y = model(xx)
63
+ for name, module in model.named_modules():
64
+ if isinstance(module, torch.nn.Conv2d):
65
+ prune.l1_unstructured(module, name='weight', amount=0.2)
66
+ elif isinstance(module, torch.nn.Linear):
67
+ prune.l1_unstructured(module, name='weight', amount=0.4)
68
+ # print(y[0].detach().numpy()[0])
69
+ # summary(model,input_size=(3, 224, 224))
70
+ total_params = sum(p.numel() for p in model.parameters())
71
+ total_trainable_params = sum(
72
+ p.numel() for p in model.parameters() if p.requires_grad
73
+ )
74
+ print(f"[INFO]: {total_params:,} total parameters.")
75
+ print(f"[INFO]: {total_trainable_params:,} trainable parameters.")
76
+
77
+