josedolot committed on
Commit 8239fc8 · 1 Parent(s): 128de2f

Upload hybridnets_test.py

Files changed (1)
  1. hybridnets_test.py +193 -0
hybridnets_test.py ADDED
@@ -0,0 +1,193 @@
import time
import torch
from torch.backends import cudnn
from backbone import HybridNetsBackbone
import cv2
import numpy as np
from glob import glob
from utils.utils import letterbox, scale_coords, postprocess, BBoxTransform, ClipBoxes, restricted_float, boolean_string
from utils.plot import STANDARD_COLORS, standard_to_bgr, get_index_label, plot_one_box
import os
from torchvision import transforms
import argparse

parser = argparse.ArgumentParser('HybridNets: End-to-End Perception Network - DatVu')
parser.add_argument('-c', '--compound_coef', type=int, default=3, help='Coefficient of efficientnet backbone')
parser.add_argument('--source', type=str, default='demo/image', help='The demo image folder')
parser.add_argument('--output', type=str, default='demo_result', help='Output folder')
parser.add_argument('-w', '--load_weights', type=str, default='weights/hybridnets.pth')
parser.add_argument('--nms_thresh', type=restricted_float, default='0.25')
parser.add_argument('--iou_thresh', type=restricted_float, default='0.3')
parser.add_argument('--imshow', type=boolean_string, default=False, help="Show result onscreen (unusable on colab, jupyter...)")
parser.add_argument('--imwrite', type=boolean_string, default=True, help="Write result to output folder")
parser.add_argument('--show_det', type=boolean_string, default=False, help="Output detection result exclusively")
parser.add_argument('--show_seg', type=boolean_string, default=False, help="Output segmentation result exclusively")
parser.add_argument('--cuda', type=boolean_string, default=True)
parser.add_argument('--float16', type=boolean_string, default=True, help="Use float16 for faster inference")
args = parser.parse_args()
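# Example invocation (defaults shown; paths are illustrative, adjust to your checkout):
#   python hybridnets_test.py -c 3 --source demo/image --output demo_result -w weights/hybridnets.pth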

compound_coef = args.compound_coef
source = args.source
if source.endswith("/"):
    source = source[:-1]
output = args.output
if output.endswith("/"):
    output = output[:-1]
weight = args.load_weights
img_path = glob(f'{source}/*.jpg') + glob(f'{source}/*.png')
# img_path = [img_path[0]]  # demo with 1 image
input_imgs = []
shapes = []
det_only_imgs = []

# replace this part with your project's anchor config
anchor_ratios = [(0.62, 1.58), (1.0, 1.0), (1.58, 0.62)]
anchor_scales = [2 ** 0, 2 ** 0.70, 2 ** 1.32]
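# Note: these ratios/scales must match the anchor configuration the checkpoint
# was trained with; mismatched anchors silently degrade detection quality.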

threshold = args.nms_thresh
iou_threshold = args.iou_thresh
imshow = args.imshow
imwrite = args.imwrite
show_det = args.show_det
show_seg = args.show_seg
os.makedirs(output, exist_ok=True)

use_cuda = args.cuda
use_float16 = args.float16
cudnn.fastest = True
cudnn.benchmark = True

obj_list = ['car']

color_list = standard_to_bgr(STANDARD_COLORS)
ori_imgs = [cv2.imread(i, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION) for i in img_path]
ori_imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in ori_imgs]
# cv2.imwrite('ori.jpg', ori_imgs[0])
# cv2.imwrite('normalized.jpg', normalized_imgs[0]*255)
resized_shape = 640
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
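# Standard ImageNet mean/std, as expected by the ImageNet-pretrained
# EfficientNet backbone.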
transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
for ori_img in ori_imgs:
    h0, w0 = ori_img.shape[:2]  # orig hw
    r = resized_shape / max(h0, w0)  # resize image to img_size
    input_img = cv2.resize(ori_img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_AREA)
    h, w = input_img.shape[:2]

    (input_img, _, _), ratio, pad = letterbox((input_img, input_img.copy(), input_img.copy()), resized_shape,
                                              auto=True, scaleup=False)

    input_imgs.append(input_img)
    # cv2.imwrite('input.jpg', input_img * 255)
    shapes.append(((h0, w0), ((h / h0, w / w0), pad)))  # for COCO mAP rescaling
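    # Each entry: ((orig_h, orig_w), ((gain_h, gain_w), letterbox_pad));
    # scale_coords() uses this below to map boxes back to the original image.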

if use_cuda:
    x = torch.stack([transform(fi).cuda() for fi in input_imgs], 0)
else:
    x = torch.stack([transform(fi) for fi in input_imgs], 0)

x = x.to(torch.float32 if not use_float16 else torch.float16)
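# Caveat: float16 is only sensible together with CUDA; half-precision
# inference on CPU is unsupported or very slow for conv-heavy models.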
# print(x.shape)
model = HybridNetsBackbone(compound_coef=compound_coef, num_classes=len(obj_list),
                           ratios=anchor_ratios, scales=anchor_scales, seg_classes=2)
try:
    model.load_state_dict(torch.load(weight, map_location='cuda' if use_cuda else 'cpu'))
except (KeyError, RuntimeError):
    # some checkpoints wrap the state dict under a 'model' key
    model.load_state_dict(torch.load(weight, map_location='cuda' if use_cuda else 'cpu')['model'])
model.requires_grad_(False)
model.eval()

if use_cuda:
    model = model.cuda()
if use_float16:
    model = model.half()

with torch.no_grad():
    features, regression, classification, anchors, seg = model(x)

    seg = seg[:, :, 12:372, :]
    da_seg_mask = torch.nn.functional.interpolate(seg, size=[720, 1280], mode='nearest')
    _, da_seg_mask = torch.max(da_seg_mask, 1)
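    # The 12:372 crop and the [720, 1280] target assume 1280x720 inputs
    # letterboxed to 640x384 (12 px of padding at top and bottom);
    # other source resolutions need different values here.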
    for i in range(da_seg_mask.size(0)):
        # print(i)
        da_seg_mask_ = da_seg_mask[i].squeeze().cpu().numpy().round()
        color_area = np.zeros((da_seg_mask_.shape[0], da_seg_mask_.shape[1], 3), dtype=np.uint8)
        color_area[da_seg_mask_ == 1] = [0, 255, 0]
        color_area[da_seg_mask_ == 2] = [0, 0, 255]
        color_seg = color_area[..., ::-1]
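        # Classes 1 and 2 are the two foreground segmentation classes
        # (drivable area and lane line in HybridNets); 0 is background.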
        # cv2.imwrite('seg_only_{}.jpg'.format(i), color_seg)

        color_mask = np.mean(color_seg, 2)
        # prepare to show det on 2 different imgs
        # (with and without seg) -> (full and det_only)
        det_only_imgs.append(ori_imgs[i].copy())
        seg_img = ori_imgs[i]
        seg_img[color_mask != 0] = seg_img[color_mask != 0] * 0.5 + color_seg[color_mask != 0] * 0.5
        seg_img = seg_img.astype(np.uint8)
        if show_seg:
            cv2.imwrite(f'{output}/{i}_seg.jpg', cv2.cvtColor(seg_img, cv2.COLOR_RGB2BGR))

    regressBoxes = BBoxTransform()
    clipBoxes = ClipBoxes()
    out = postprocess(x,
                      anchors, regression, classification,
                      regressBoxes, clipBoxes,
                      threshold, iou_threshold)
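    # postprocess() decodes the regressed anchor boxes, clips them to the
    # image, and applies score thresholding + NMS; out[i] carries 'rois',
    # 'class_ids' and 'scores' for image i.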

    for i in range(len(ori_imgs)):
        out[i]['rois'] = scale_coords(ori_imgs[i].shape[:2], out[i]['rois'], shapes[i][0], shapes[i][1])
        for j in range(len(out[i]['rois'])):
            x1, y1, x2, y2 = out[i]['rois'][j].astype(int)
            obj = obj_list[out[i]['class_ids'][j]]
            score = float(out[i]['scores'][j])
            plot_one_box(ori_imgs[i], [x1, y1, x2, y2], label=obj, score=score,
                         color=color_list[get_index_label(obj, obj_list)])
            if show_det:
                plot_one_box(det_only_imgs[i], [x1, y1, x2, y2], label=obj, score=score,
                             color=color_list[get_index_label(obj, obj_list)])

        if show_det:
            cv2.imwrite(f'{output}/{i}_det.jpg', cv2.cvtColor(det_only_imgs[i], cv2.COLOR_RGB2BGR))

        if imshow:
            # ori_imgs are RGB; convert to BGR so cv2.imshow displays correct colors
            cv2.imshow('img', cv2.cvtColor(ori_imgs[i], cv2.COLOR_RGB2BGR))
            cv2.waitKey(0)

        if imwrite:
            cv2.imwrite(f'{output}/{i}.jpg', cv2.cvtColor(ori_imgs[i], cv2.COLOR_RGB2BGR))

# exit()
print('running speed test...')
with torch.no_grad():
    print('test1: model inferring and postprocessing')
    print('inferring 1 image 10 times...')
    x = x[0, ...]
    x.unsqueeze_(0)
    if use_cuda:
        torch.cuda.synchronize()  # CUDA launches are asynchronous; sync for accurate timing
    t1 = time.time()
    for _ in range(10):
        _, regression, classification, anchors, segmentation = model(x)

        out = postprocess(x,
                          anchors, regression, classification,
                          regressBoxes, clipBoxes,
                          threshold, iou_threshold)

    if use_cuda:
        torch.cuda.synchronize()
    t2 = time.time()
    tact_time = (t2 - t1) / 10
    print(f'{tact_time} seconds, {1 / tact_time} FPS, @batch_size 1')

    # extreme FPS test: model inference only, no postprocessing
    print('test2: model inferring only')
    print('inferring a batch of 32 images 10 times...')
    x = torch.cat([x] * 32, 0)  # build the batch outside the timed region
    if use_cuda:
        torch.cuda.synchronize()
    t1 = time.time()
    for _ in range(10):
        _, regression, classification, anchors, segmentation = model(x)

    if use_cuda:
        torch.cuda.synchronize()
    t2 = time.time()
    tact_time = (t2 - t1) / 10
    print(f'{tact_time} seconds, {32 / tact_time} FPS, @batch_size 32')