File size: 2,724 Bytes
a7299bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# A reimplemented version in public environments by Xiao Fu and Mu Hu

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
sys.path.append("..")

from dataloader.mix_loader import MixDataset
from torch.utils.data import DataLoader
from dataloader import transforms
import os


# Get Dataset Here
def prepare_dataset(data_dir=None,
                    batch_size=1,
                    test_batch=1,
                    datathread=4,
                    logger=None):
    """Build the shuffled training DataLoader plus a small config dict.

    Args:
        data_dir: forwarded to ``MixDataset(data_dir=...)``.
        batch_size: batch size of the training loader.
        test_batch: accepted for interface compatibility; not used by this
            function.
        datathread: number of DataLoader worker processes. The environment
            variable ``datathread``, when set, takes precedence and is
            parsed with ``int()``.
        logger: optional logger; when given, the worker count is reported
            through ``logger.info``.

    Returns:
        ``(train_loader, dataset_config_dict)`` where the dict holds
        ``'num_batches_per_epoch'`` (``len(train_loader)``) and
        ``'img_size'`` (the ``(height, width)`` pair reported by
        ``train_dataset.get_img_size()``).
    """
    train_dataset = MixDataset(data_dir=data_dir)
    img_height, img_width = train_dataset.get_img_size()

    # The environment variable overrides the keyword argument when present.
    env_workers = os.environ.get('datathread')
    num_workers = datathread if env_workers is None else int(env_workers)

    if logger is not None:
        logger.info("Use %d processes to load data..." % num_workers)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    # Summary of the loader for the caller (e.g. scheduler / logging setup).
    dataset_config_dict = {
        'num_batches_per_epoch': len(train_loader),
        'img_size': (img_height, img_width),
    }

    return train_loader, dataset_config_dict

def depth_scale_shift_normalization(depth):
    """Normalize every sample of a depth batch to ``[-1, 1]`` per sample.

    For each sample along the first axis, the 2nd and 98th percentiles are
    computed over index 0 of the second axis (flattened over the remaining
    two axes). The whole sample is then mapped so that the 2nd percentile
    lands on -1 and the 98th percentile on +1, and the result is clipped to
    ``[-1, 1]``.

    Args:
        depth: 4-D torch tensor (indexed as ``depth[:, 0, :, :]``); the
            first axis is treated as the batch axis. May live on any device
            and may require grad.

    Returns:
        A tensor with the same shape, dtype and device as ``depth`` holding
        the normalized values. The percentile bounds are treated as
        constants, so gradients (if any) flow only through ``depth`` itself.
    """
    bsz = depth.shape[0]

    # detach() is required: Tensor.numpy() raises on tensors that require
    # grad. The percentiles are statistics only, so no gradient is lost.
    depth_ = depth[:, 0, :, :].detach().reshape(bsz, -1).cpu().numpy()

    # One call for both percentiles; result has shape (2, bsz).
    low, high = np.percentile(a=depth_, q=[2, 98], axis=1)

    # .to(depth) matches dtype and device; the trailing axes make the
    # per-sample bounds broadcast against the 4-D input.
    min_value = torch.from_numpy(np.ascontiguousarray(low)).to(depth)[..., None, None, None]
    max_value = torch.from_numpy(np.ascontiguousarray(high)).to(depth)[..., None, None, None]

    # 1e-5 guards against division by zero for constant-valued samples.
    normalized_depth = ((depth - min_value) / (max_value - min_value + 1e-5) - 0.5) * 2
    normalized_depth = torch.clip(normalized_depth, -1., 1.)

    return normalized_depth



def resize_max_res_tensor(input_tensor, mode, recom_resolution=768):
    """Rescale a 3-channel image batch with one uniform scale factor.

    The factor is the smaller of ``recom_resolution / H`` and
    ``recom_resolution / W``, where ``H, W`` are the sizes of the last two
    axes, so the aspect ratio is kept.

    Args:
        input_tensor: tensor whose second axis has size 3 (asserted) and
            whose shape from axis 2 on unpacks into exactly two sizes.
        mode: ``'normal'`` selects nearest-neighbour interpolation; any
            other value selects bilinear interpolation with
            ``align_corners=False``. For ``'depth'`` the resized values are
            additionally divided by the scale factor.
        recom_resolution: target resolution used to derive the scale factor.

    Returns:
        The resized tensor (value-rescaled as described when
        ``mode == 'depth'``).
    """
    assert input_tensor.shape[1]==3
    height, width = input_tensor.shape[2:]
    scale = min(recom_resolution / height, recom_resolution / width)

    # Pick the interpolation settings once, then resize in a single call.
    if mode == 'normal':
        interp_kwargs = {'mode': 'nearest'}
    else:
        interp_kwargs = {'mode': 'bilinear', 'align_corners': False}

    resized = F.interpolate(input_tensor, scale_factor=scale, **interp_kwargs)

    if mode != 'depth':
        return resized
    return resized / scale