Ubuntu commited on
Commit
0128aee
·
1 Parent(s): c6827c1

[add] backbone and detector

Browse files
model/backbone/densenet.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, resnet_fpn_backbone, BackboneWithFPN
3
+
4
+ class DenseNet(nn.Module):
5
+ def __init__(self, num_classes): # @Thuan: Add argumentation in __init__
6
+ super(DenseNet, self).__init__()
7
+ # @Thuan: Initiation model here
8
+
9
+ # End of model initiation
10
+
11
+ def forward(self, x):
12
+ # @Thuan: Feedforward step here
13
+
14
+ return x # End of feedforward function
15
+
model/detector/fasterRCNN.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ from torchvision.models.detection import FasterRCNN
3
+ from torchvision.models.detection.rpn import AnchorGenerator
4
+
5
+ def fasterRCNN(backbone, class_num):
6
+ anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
7
+ aspect_ratios=((0.5, 1.0, 2.0),))
8
+
9
+ roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
10
+ output_size=7,
11
+ sampling_ratio=2)
12
+
13
+ # put the pieces together inside a FasterRCNN model
14
+ model = FasterRCNN(backbone,
15
+ num_classes=class_num,
16
+ rpn_anchor_generator=anchor_generator,
17
+ box_roi_pool=roi_pooler)
18
+ return model