josedolot committed on
Commit 80f1cdc · 1 Parent(s): 8239fc8

Upload hybridnets_test_videos.py

Files changed (1)
  1. hybridnets_test_videos.py +150 -0
hybridnets_test_videos.py ADDED
@@ -0,0 +1,150 @@
+ import time
+ import torch
+ from torch.backends import cudnn
+ from backbone import HybridNetsBackbone
+ import cv2
+ import numpy as np
+ from glob import glob
+ from utils.utils import letterbox, scale_coords, postprocess, BBoxTransform, ClipBoxes, restricted_float, boolean_string
+ from utils.plot import STANDARD_COLORS, standard_to_bgr, get_index_label, plot_one_box
+ import os
+ from torchvision import transforms
+ import argparse
+
+ parser = argparse.ArgumentParser('HybridNets: End-to-End Perception Network - DatVu')
+ parser.add_argument('-c', '--compound_coef', type=int, default=3, help='Coefficient of efficientnet backbone')
+ parser.add_argument('--source', type=str, default='demo/video', help='The demo video folder')
+ parser.add_argument('--output', type=str, default='demo_result', help='Output folder')
+ parser.add_argument('-w', '--load_weights', type=str, default='weights/hybridnets.pth')
+ parser.add_argument('--nms_thresh', type=restricted_float, default='0.25')
+ parser.add_argument('--iou_thresh', type=restricted_float, default='0.3')
+ parser.add_argument('--cuda', type=boolean_string, default=True)
+ parser.add_argument('--float16', type=boolean_string, default=True, help="Use float16 for faster inference")
+ args = parser.parse_args()
+
+ compound_coef = args.compound_coef
+ source = args.source
+ if source.endswith("/"):
+     source = source[:-1]
+ output = args.output
+ if output.endswith("/"):
+     output = output[:-1]
+ weight = args.load_weights
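+ # NOTE: only the first .mp4 found in the source folder is processed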
+ video_src = glob(f'{source}/*.mp4')[0]
+ os.makedirs(output, exist_ok=True)
+ video_out = f'{output}/output.mp4'
+ input_imgs = []
+ shapes = []
+
+ # replace this part with your project's anchor config
+ anchor_ratios = [(0.62, 1.58), (1.0, 1.0), (1.58, 0.62)]
+ anchor_scales = [2 ** 0, 2 ** 0.70, 2 ** 1.32]
+
+ threshold = args.nms_thresh
+ iou_threshold = args.iou_thresh
+
+ use_cuda = args.cuda
+ use_float16 = args.float16
+ cudnn.fastest = True
+ cudnn.benchmark = True
+
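+ # the pretrained HybridNets checkpoint detects a single 'car' class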
+ obj_list = ['car']
+
+ color_list = standard_to_bgr(STANDARD_COLORS)
+ resized_shape = 640
+ normalize = transforms.Normalize(
+     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]  # ImageNet mean/std
+ )
+ transform = transforms.Compose([
+     transforms.ToTensor(),
+     normalize,
+ ])
+
+ model = HybridNetsBackbone(compound_coef=compound_coef, num_classes=len(obj_list),
+                            ratios=anchor_ratios, scales=anchor_scales, seg_classes=2)
+ try:
+     model.load_state_dict(torch.load(weight, map_location='cuda' if use_cuda else 'cpu'))
+ except RuntimeError:
+     # checkpoints that wrap the weights in a 'model' key land here
+     model.load_state_dict(torch.load(weight, map_location='cuda' if use_cuda else 'cpu')['model'])
+ model.requires_grad_(False)
+ model.eval()
+
+ if use_cuda:
+     model = model.cuda()
+ if use_float16:
+     model = model.half()
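+ # NOTE: float16 is intended for CUDA inference; on CPU it is slow or unsupported depending on the PyTorch build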
+ cap = cv2.VideoCapture(video_src)
+ # Define the codec and create VideoWriter object
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ out_stream = cv2.VideoWriter(video_out, fourcc, 30.0,
+                              (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
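+ # NOTE: the writer is hardcoded to 30 fps; cap.get(cv2.CAP_PROP_FPS) would match the source rate instead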
+ t1 = time.time()
+ frame_count = 0
+ while True:
+     ret, frame = cap.read()
+     if not ret:
+         break
+
+     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+     h0, w0 = frame.shape[:2]  # orig hw
+     r = resized_shape / max(h0, w0)  # resize image to img_size
+     input_img = cv2.resize(frame, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_AREA)
+     h, w = input_img.shape[:2]
+
+     (input_img, _, _), ratio, pad = letterbox((input_img, input_img.copy(), input_img.copy()), resized_shape, auto=True,
+                                               scaleup=False)
+
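+     # record the original size and the (ratio, pad) applied by letterbox, for scale_coords later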
+     shapes = ((h0, w0), ((h / h0, w / w0), pad))
+
+     if use_cuda:
+         x = transform(input_img).cuda()
+     else:
+         x = transform(input_img)
+
+     x = x.to(torch.float32 if not use_float16 else torch.float16)
+     x.unsqueeze_(0)
+     with torch.no_grad():
+         features, regression, classification, anchors, seg = model(x)
+
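+     # model(x) returns FPN features, box regression, class scores, the anchor set, and segmentation logits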
+     # crop the vertical letterbox padding from the segmentation logits
+     # (assumes a 16:9 source letterboxed from 640x360 to 640x384, i.e. 12 px top and bottom)
+     seg = seg[:, :, 12:372, :]
+     da_seg_mask = torch.nn.functional.interpolate(seg, size=[h0, w0], mode='nearest')
+     _, da_seg_mask = torch.max(da_seg_mask, 1)
+     da_seg_mask_ = da_seg_mask[0].squeeze().cpu().numpy().round()
+
+     color_area = np.zeros((da_seg_mask_.shape[0], da_seg_mask_.shape[1], 3), dtype=np.uint8)
+     color_area[da_seg_mask_ == 1] = [0, 255, 0]  # drivable area
+     color_area[da_seg_mask_ == 2] = [0, 0, 255]  # lane line
+     color_seg = color_area[..., ::-1]
+
+     # cv2.imwrite('seg_only_{}.jpg'.format(i), color_seg)
+
+     # blend the segmentation overlay into the frame at 50% opacity
+     color_mask = np.mean(color_seg, 2)
+     frame[color_mask != 0] = frame[color_mask != 0] * 0.5 + color_seg[color_mask != 0] * 0.5
+     frame = frame.astype(np.uint8)
+     # cv2.imwrite('seg_{}.jpg'.format(i), ori_img)
+
+     regressBoxes = BBoxTransform()
+     clipBoxes = ClipBoxes()
+     out = postprocess(x,
+                       anchors, regression, classification,
+                       regressBoxes, clipBoxes,
+                       threshold, iou_threshold)
+     out = out[0]
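+     # map boxes from the letterboxed input back to original-frame coordinates;
+     # shapes supplies the (ratio, pad) that scale_coords uses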
+     out['rois'] = scale_coords(frame.shape[:2], out['rois'], shapes[0], shapes[1])
+     for j in range(len(out['rois'])):
+         x1, y1, x2, y2 = out['rois'][j].astype(int)
+         obj = obj_list[out['class_ids'][j]]
+         score = float(out['scores'][j])
+         plot_one_box(frame, [x1, y1, x2, y2], label=obj, score=score,
+                      color=color_list[get_index_label(obj, obj_list)])
+     # frame is RGB here; convert back to BGR for OpenCV's writer
+     out_stream.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+     frame_count += 1
+
+ t2 = time.time()
+ print("frame: {}".format(frame_count))
+ print("second: {}".format(t2 - t1))
+ print("fps: {}".format(frame_count / (t2 - t1)))  # frames per second
+
+ cap.release()
+ out_stream.release()
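
A minimal sketch of how the script is invoked, assuming the layout its defaults imply (a checkpoint at weights/hybridnets.pth and at least one .mp4 under demo/video):

    python hybridnets_test_videos.py -c 3 --source demo/video --output demo_result -w weights/hybridnets.pth

The annotated video is written to demo_result/output.mp4, and the frame count, elapsed seconds, and fps are printed once the stream ends.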