File size: 2,680 Bytes
8d015d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daac507
8d015d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch


def make_color_wheel():
	"""
	Build the Middlebury optical-flow color wheel.

	The wheel is made of six consecutive hue transitions
	(red -> yellow -> green -> cyan -> blue -> magenta -> red). Within each
	transition one RGB channel is held at 255 while a second channel ramps
	linearly up or down; the third channel stays 0.

	:return: tensor of shape (3, 55); row c is RGB channel c and column j is
		the j-th hue sample, with values in [0, 255].
	"""
	# (segment length, channel held at 255, ramping channel, ramp goes down?)
	segments = (
		(15, 0, 1, False),  # RY: red -> yellow, green rises
		(6, 1, 0, True),    # YG: yellow -> green, red falls
		(4, 1, 2, False),   # GC: green -> cyan, blue rises
		(11, 2, 1, True),   # CB: cyan -> blue, green falls
		(13, 2, 0, False),  # BM: blue -> magenta, red rises
		(6, 0, 2, True),    # MR: magenta -> red, blue falls
	)
	total = sum(length for length, _, _, _ in segments)
	wheel = torch.zeros([3, total])

	start = 0
	for length, held, ramped, falling in segments:
		stop = start + length
		ramp = torch.floor(255 * torch.arange(0, length) / length)
		wheel[held, start:stop] = 255
		wheel[ramped, start:stop] = 255 - ramp if falling else ramp
		start = stop

	return wheel


# Module-level color wheel lookup table, built once at import time and read by
# compute_color.
colorwheel = make_color_wheel()


def flow2img(flow_data: torch.Tensor):
	"""
	Convert a batch of optical-flow fields into color images.

	:param flow_data: 4-D flow tensor; channel 0 (dim 1) is taken as the
		horizontal component u and channel 1 as the vertical component v.
		The tensor is not modified.
	:return: tensor of shape (batch, 3, H, W) with values in [0, 1]. Pixels
		whose u or v magnitude exceeds the unknown-flow threshold are black.
	"""
	# Clone so that zeroing unknown-flow pixels below does not write through
	# the slice views into the caller's tensor.
	u = flow_data[:, 0:1, :, :].clone()
	v = flow_data[:, 1:2, :, :].clone()

	# Components with a magnitude above this are treated as "unknown" flow.
	UNKNOW_FLOW_THRESHOLD = 1e7
	idx_unknown = (torch.abs(u) > UNKNOW_FLOW_THRESHOLD) | (torch.abs(v) > UNKNOW_FLOW_THRESHOLD)
	u[idx_unknown] = 0
	v[idx_unknown] = 0
	idx_unknown = idx_unknown.repeat(1, 3, 1, 1)

	# Normalise by the largest flow magnitude in the whole batch. NaN pixels are
	# ignored here (compute_color paints them black); otherwise Python's
	# max(-1, nan) would return -1 and flip every flow direction.
	rad = torch.sqrt(u ** 2 + v ** 2)
	maxrad = torch.nan_to_num(rad, nan=0.0).max().item()
	# eps belongs in the denominator so an all-zero field does not divide 0 / 0.
	eps = torch.finfo(torch.float64).eps
	u = u / (maxrad + eps)
	v = v / (maxrad + eps)

	img = compute_color(u, v)

	img[idx_unknown] = 0

	return img / 255.


def compute_color(u, v, wheel=None):
	"""
	Map flow components to colors by looking up the flow angle on a color wheel.

	The flow angle selects (and linearly interpolates between) two wheel
	columns. Pixels with magnitude <= 1 are blended toward white by
	(1 - magnitude); pixels with magnitude > 1 are dimmed to 75%.

	:param u: horizontal optical flow, a 4-D tensor (batch, 1, H, W). It is
		not modified.
	:param v: vertical optical flow, same shape as ``u``. It is not modified.
	:param wheel: optional (3, ncols) RGB color wheel with values in
		[0, 255]; defaults to the module-level ``colorwheel``.
	:return: float tensor of shape (batch, 3, H, W) on ``u``'s device,
		holding integer color values in [0, 255]. Pixels where u or v is
		NaN are black.
	"""
	if wheel is None:
		wheel = colorwheel
	# Keep the lookup table on the flow's device so the gather below works for
	# both CPU and GPU inputs.
	wheel = wheel.to(u.device)
	ncols = wheel.shape[1]

	B, _, H, W = u.shape
	img = torch.zeros((B, 3, H, W), device=u.device)

	# Treat NaNs as zero flow without mutating the caller's tensors; those
	# pixels are forced to black when writing the output.
	nan_idx = torch.isnan(u) | torch.isnan(v)
	u = u.masked_fill(nan_idx, 0)
	v = v.masked_fill(nan_idx, 0)

	rad = torch.sqrt(u ** 2 + v ** 2)

	# Angle in [-1, 1] (units of pi) -> fractional 1-based wheel position in [1, ncols].
	a = torch.arctan2(-v, -u) / torch.pi
	fk = (a + 1) / 2 * (ncols - 1) + 1
	k0 = torch.floor(fk).to(int)
	k1 = k0 + 1
	k1[k1 == ncols + 1] = 1  # wrap past the last column back to the first
	f = fk - k0

	# Loop-invariant masks shared by every channel.
	in_range = rad <= 1
	valid = ~nan_idx

	for i in range(wheel.shape[0]):
		channel = wheel[i, :]
		col0 = channel[k0 - 1] / 255
		col1 = channel[k1 - 1] / 255
		col = (1 - f) * col0 + f * col1
		col = torch.where(in_range, 1 - rad * (1 - col), col * 0.75)
		img[:, i:i+1, :, :] = torch.floor(255 * col * valid).to(torch.uint8)

	return img