File size: 4,361 Bytes
44189a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import cv2
import numpy as np

import torch

from matplotlib import cm
import matplotlib.pyplot as plt

import logging
# Module-level logger, looked up under the name 'root'.
logger = logging.getLogger('root')



def tensor_to_numpy(tensor_in):
    """ Convert a channel-first torch tensor into a channel-last numpy array.

        3-D input is permuted (C, H, W) -> (H, W, C); 4-D input is permuted
        (B, C, H, W) -> (B, H, W, C). The tensor is detached and moved to the
        CPU before conversion. ``None`` is passed through unchanged.

        Raises:
            Exception: if the input has a rank other than 3 or 4.
    """
    if tensor_in is None:
        return tensor_in

    # rank -> axis order that moves the channel axis to the last position
    channel_last_order = {
        3: (1, 2, 0),
        4: (0, 2, 3, 1),
    }
    rank = tensor_in.ndim
    if rank not in channel_last_order:
        raise Exception('invalid tensor size')

    on_cpu = tensor_in.detach().cpu()
    return on_cpu.permute(*channel_last_order[rank]).numpy()

# def unnormalize(img_in, img_stats={'mean': [0.485, 0.456, 0.406], 
#                                     'std': [0.229, 0.224, 0.225]}):
def unnormalize(img_in, img_stats={'mean': [0.5,0.5,0.5], 'std': [0.5,0.5,0.5]}):
    """ Undo per-channel normalization and return an 8-bit image.

        Each of the first three channels (last axis) is mapped back with
        ``x * std + mean``, scaled by 255 and converted to ``uint8``.

        Args:
            img_in: channel-last numpy array, or a torch tensor (converted
                with ``tensor_to_numpy``).
            img_stats: dict with 'mean' and 'std' sequences; only the first
                three entries of each are used. The default dict is only
                read, never modified.

        Returns:
            ``uint8`` numpy array with the same shape as the (converted) input.
            Channels beyond the third, if any, are left at zero.
    """
    if torch.is_tensor(img_in):
        img_in = tensor_to_numpy(img_in)
    img_in = np.asarray(img_in)

    # Do the arithmetic in floating point even for integer inputs; otherwise
    # the scaled values would be truncated before the final conversion.
    if np.issubdtype(img_in.dtype, np.floating):
        work_dtype = img_in.dtype
    else:
        work_dtype = np.float64

    std = np.asarray(img_stats['std'][:3], dtype=work_dtype)
    mean = np.asarray(img_stats['mean'][:3], dtype=work_dtype)

    img_out = np.zeros(img_in.shape, dtype=work_dtype)
    img_out[..., :3] = img_in[..., :3] * std + mean

    # Clip before the cast: values slightly outside the normalized range
    # would otherwise wrap around when converted to uint8.
    img_out = np.clip(img_out * 255.0, 0.0, 255.0).astype(np.uint8)
    return img_out

def normal_to_rgb(normal, normal_mask=None):
    """ surface normal map to RGB
        (used for visualization)

        NOTE: x, y, z are mapped to R, G, B
        NOTE: [-1, 1] are mapped to [0, 255]

        Args:
            normal: normal map with the vector components on the last axis
                (numpy array), or a torch tensor that is converted with
                ``tensor_to_numpy``.
            normal_mask: optional mask multiplied onto the RGB result; may be
                a numpy array or a torch tensor, independently of ``normal``.

        Returns:
            ``uint8`` numpy array of the same shape as the (converted) normal.
    """
    # Convert each input on its own so a numpy normal with a tensor mask
    # (or the reverse) is handled as well.
    if torch.is_tensor(normal):
        normal = tensor_to_numpy(normal)
    if torch.is_tensor(normal_mask):
        normal_mask = tensor_to_numpy(normal_mask)

    # Re-normalize to unit length; the floor avoids division by zero for
    # all-zero vectors (which then map to mid-gray).
    normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
    normal = normal / np.maximum(normal_norm, 1e-12)

    normal_rgb = (((normal + 1) * 0.5) * 255).astype(np.uint8)
    if normal_mask is not None:
        # A float/int mask would promote the product away from uint8;
        # cast back so the output stays an 8-bit [0, 255] image.
        normal_rgb = (normal_rgb * normal_mask).astype(np.uint8)     # (B, H, W, 3)
    return normal_rgb

def kappa_to_alpha(pred_kappa, to_numpy=True):
    """ Convert confidence kappa into the uncertainty alpha, in degrees.
        Assuming AngMF distribution (introduced in https://arxiv.org/abs/2109.09881)

        Args:
            pred_kappa: kappa values as a numpy array or torch tensor.
            to_numpy: when True, a tensor input is first converted with
                ``tensor_to_numpy``; when False, a tensor input is processed
                with torch ops and a tensor is returned.

        Returns:
            alpha in degrees, in the same container type used for the
            computation (numpy or torch).
    """
    if torch.is_tensor(pred_kappa) and to_numpy:
        pred_kappa = tensor_to_numpy(pred_kappa)

    # Pick the backend once; the formula itself is identical for both.
    use_torch = torch.is_tensor(pred_kappa)
    exp_fn = torch.exp if use_torch else np.exp
    rad_to_deg = torch.rad2deg if use_torch else np.degrees

    decay = exp_fn(- pred_kappa * np.pi)
    rational_term = (2 * pred_kappa) / ((pred_kappa ** 2.0) + 1)
    tail_term = (decay * np.pi) / (1 + decay)

    return rad_to_deg(rational_term + tail_term)


def visualize_normal(target_dir, prefixs, img, pred_norm, pred_kappa,
                        gt_norm, gt_norm_mask, pred_error, num_vis=-1):
    """ Save visualization PNGs for a batch of surface-normal predictions.

        For each visualized sample ``i`` the following files are written
        into ``target_dir`` (named after ``prefixs[i]``):
            <prefix>_img.png         un-normalized input image
            <prefix>_norm.png        predicted normals as RGB
            <prefix>_pred_alpha.png  uncertainty from kappa (if pred_kappa given)
            <prefix>_gt.png          ground-truth normals (if gt_norm given)
            <prefix>_pred_error.png  masked error map (if gt_norm and pred_error given)

        Args:
            target_dir: output directory (must already exist).
            prefixs: per-sample file-name prefixes.
            img, pred_norm: batched channel-first tensors.
            pred_kappa, gt_norm, gt_norm_mask, pred_error: batched
                channel-first tensors, or None to skip the related outputs.
            num_vis: number of samples to save; -1 means all prefixes.
                Values larger than ``len(prefixs)`` are clamped.
    """
    # Upper end of the color scale used for the alpha and error maps.
    error_max = 60.0

    img = tensor_to_numpy(img)                      # (B, H, W, 3)
    pred_norm = tensor_to_numpy(pred_norm)          # (B, H, W, 3)
    pred_kappa = tensor_to_numpy(pred_kappa)        # (B, H, W, 1)
    gt_norm = tensor_to_numpy(gt_norm)              # (B, H, W, 3)
    gt_norm_mask = tensor_to_numpy(gt_norm_mask)    # (B, H, W, 1)
    pred_error = tensor_to_numpy(pred_error)        # (B, H, W, 1)

    # Clamp so a fixed num_vis cannot index past the available prefixes
    # (e.g. when the batch is smaller than the requested count).
    num_vis = len(prefixs) if num_vis == -1 else min(num_vis, len(prefixs))
    for i in range(num_vis):
        prefix = prefixs[i]

        # img
        img_ = unnormalize(img[i, ...])
        target_path = '%s/%s_img.png' % (target_dir, prefix)
        plt.imsave(target_path, img_)

        # pred_norm
        target_path = '%s/%s_norm.png' % (target_dir, prefix)
        plt.imsave(target_path, normal_to_rgb(pred_norm[i, ...]))

        # pred_kappa
        if pred_kappa is not None:
            pred_alpha = kappa_to_alpha(pred_kappa[i, :, :, 0])
            target_path = '%s/%s_pred_alpha.png' % (target_dir, prefix)
            plt.imsave(target_path, pred_alpha, vmin=0.0, vmax=error_max, cmap='jet')

        # gt_norm, pred_error
        if gt_norm is not None:
            mask_i = None if gt_norm_mask is None else gt_norm_mask[i, ...]
            target_path = '%s/%s_gt.png' % (target_dir, prefix)
            plt.imsave(target_path, normal_to_rgb(gt_norm[i, ...], mask_i))

            # The error map is optional; without a mask it is saved unmasked.
            if pred_error is not None:
                E = pred_error[i, :, :, 0]
                if mask_i is not None:
                    E = E * mask_i[:, :, 0]
                target_path = '%s/%s_pred_error.png' % (target_dir, prefix)
                plt.imsave(target_path, E, vmin=0, vmax=error_max, cmap='jet')