File size: 2,680 Bytes
8d015d4 daac507 8d015d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
def make_color_wheel():
    """
    Generate the color wheel according to the Middlebury color code.

    The wheel is built from six consecutive hue segments (RY, YG, GC, CB,
    BM, MR). Within each segment one RGB channel is held at 255 while a
    second channel ramps linearly (floored) either up from 0 or down from
    255; the third channel stays 0.

    :return: float tensor of shape (3, 55); row 0/1/2 is the R/G/B value
        (0..255) of each wheel entry
    """
    # (segment length, channel held at 255, ramping channel, ramp rises?)
    segments = (
        (15, 0, 1, True),    # RY: red full, green rises
        (6, 1, 0, False),    # YG: green full, red falls
        (4, 1, 2, True),     # GC: green full, blue rises
        (11, 2, 1, False),   # CB: blue full, green falls
        (13, 2, 0, True),    # BM: blue full, red rises
        (6, 0, 2, False),    # MR: red full, blue falls
    )
    ncols = sum(length for length, _, _, _ in segments)
    colorwheel = torch.zeros([3, ncols])

    col = 0
    for length, full_channel, ramp_channel, rising in segments:
        ramp = torch.floor(255 * torch.arange(0, length) / length)
        colorwheel[full_channel, col:col + length] = 255
        colorwheel[ramp_channel, col:col + length] = ramp if rising else 255 - ramp
        col += length
    return colorwheel
# Module-level color wheel, built once at import time and read by
# compute_color. It is created on the default device; the transfer to the
# GPU is left disabled in the trailing comment.
colorwheel = make_color_wheel()  # .cuda()
def flow2img(flow_data: torch.Tensor):
    """
    Convert a batch of optical flow fields into color images.

    Components whose magnitude exceeds UNKNOW_FLOW_THRESHOLD are treated as
    unknown: they are zeroed before normalisation and painted black in the
    result. The flow is normalised by the largest radius in the batch before
    being colorised by compute_color.

    :param flow_data: flow tensor indexed as [batch, channel, row, col];
        channel 0 is read as the horizontal component and channel 1 as the
        vertical component. The tensor is not modified.
    :return: color image tensor returned by compute_color, divided by 255
    """
    UNKNOW_FLOW_THRESHOLD = 1e7

    u = flow_data[:, 0:1, :, :]
    v = flow_data[:, 1:2, :, :]

    idx_unknown = (
        (torch.abs(u) > UNKNOW_FLOW_THRESHOLD)
        | (torch.abs(v) > UNKNOW_FLOW_THRESHOLD)
    )
    # u and v are views into flow_data; masked_fill returns new tensors so
    # the caller's flow is left untouched.
    u = u.masked_fill(idx_unknown, 0)
    v = v.masked_fill(idx_unknown, 0)

    rad = torch.sqrt(u ** 2 + v ** 2)
    # Ignore NaN radii when looking for the maximum: torch.max would return
    # NaN, max(-1, nan) would then yield -1 and flip the sign of every pixel.
    maxrad = max(-1, torch.max(torch.nan_to_num(rad, nan=0.0)).item())

    # eps belongs in the denominator so that an all-zero flow divides by a
    # tiny positive number instead of producing 0 / 0 = NaN.
    scale = maxrad + torch.finfo(torch.float64).eps
    img = compute_color(u / scale, v / scale)

    img[idx_unknown.repeat(1, 3, 1, 1)] = 0
    return img / 255.
def compute_color(u, v):
    """
    Compute the optical flow color map.

    The flow direction selects (with linear interpolation) an entry of the
    module-level ``colorwheel``; the flow radius controls saturation for
    radii <= 1, and colors with radius > 1 are dimmed by a factor 0.75.
    Pixels where either component is NaN come out black.

    :param u: horizontal optical flow, unpacked as (B, C, H, W); the
        per-channel write below expects C == 1. Not modified.
    :param v: vertical optical flow, same shape as ``u``. Not modified.
    :return: float tensor of shape (B, 3, H, W) on the same device as ``u``
        holding integer-valued color levels in 0..255
    """
    B, _, H, W = u.shape
    # Follow the input's device rather than assuming a GPU is present.
    device = u.device
    img = torch.zeros((B, 3, H, W), device=device)

    NAN_idx = torch.isnan(u) | torch.isnan(v)
    # masked_fill builds new tensors, so the caller's u / v keep their NaNs.
    u = u.masked_fill(NAN_idx, 0)
    v = v.masked_fill(NAN_idx, 0)
    valid = ~NAN_idx

    # The wheel must live on the same device as the index tensors.
    wheel = colorwheel.to(device)
    ncols = wheel.shape[1]

    rad = torch.sqrt(u ** 2 + v ** 2)
    a = torch.arctan2(-v, -u) / torch.pi          # angle scaled to [-1, 1]
    fk = (a + 1) / 2 * (ncols - 1) + 1            # fractional 1-based wheel index
    k0 = torch.floor(fk).to(int)
    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1                       # wrap past the last entry
    f = fk - k0                                   # interpolation weight

    inside = rad <= 1                             # loop-invariant mask
    for i in range(wheel.shape[0]):
        channel = wheel[i, :]
        col0 = channel[k0 - 1] / 255
        col1 = channel[k1 - 1] / 255
        col = (1 - f) * col0 + f * col1
        # Radius <= 1: blend towards white as the radius shrinks.
        # Radius > 1: out of range, dim the color.
        col = torch.where(inside, 1 - rad * (1 - col), col * 0.75)
        img[:, i:i + 1, :, :] = torch.floor(255 * col * valid).to(torch.uint8)
    return img
|