File size: 3,667 Bytes
44189a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from PIL import Image
import matplotlib
import numpy as np

from PIL import Image

import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize

def concatenate_images(*image_lists):
    """
    Paste several rows of images onto one RGB canvas.

    Each positional argument is one row: a sequence of PIL images laid out
    left to right. Rows are stacked top to bottom in argument order and are
    left-aligned. The canvas is as wide as the widest row; regions not covered
    by any image keep the initial fill of `Image.new`.

    Args:
        *image_lists:
            One sequence of `PIL.Image.Image` per row. Empty rows after the
            first are skipped and take up no vertical space.

    Returns:
        `PIL.Image.Image`: New RGB image containing all rows.

    Raises:
        ValueError: If no row is given or the first row is empty.
    """
    if not image_lists or not image_lists[0]:
        raise ValueError("At least one non-empty image list must be provided")

    # Drop empty rows up front so every remaining row has a matching height
    # entry; indexing heights by the position in `image_lists` would go out of
    # sync (or out of range) as soon as an empty row is present.
    rows = [image_list for image_list in image_lists if image_list]

    # Use the tallest image of each row so that a row with mixed heights is
    # not overlapped by the next one. For rows of equal-height images this is
    # the same as taking the first image's height.
    row_heights = [max(img.height for img in row) for row in rows]
    canvas_width = max(sum(img.width for img in row) for row in rows)

    new_image = Image.new('RGB', (canvas_width, sum(row_heights)))

    y_offset = 0
    for row, row_height in zip(rows, row_heights):
        x_offset = 0
        for img in row:
            new_image.paste(img, (x_offset, y_offset))
            x_offset += img.width
        y_offset += row_height  # Move down to the top of the next row

    return new_image


def colorize_depth_map(depth, mask=None):
    """
    Render a depth map as an RGB image using matplotlib's "Spectral" colormap.

    The depth values are min-max normalized to [0, 1] before the colormap is
    applied, so colors are relative to the range of this particular map.

    Args:
        depth:
            Depth values; the result is sliced as (h, w, 3), so a 2D array is
            expected (presumably a numpy array - confirm against callers).
        mask:
            Optional mask selecting the pixels to keep; all other pixels are
            set to zero (black). May be a `torch.Tensor` or anything
            `np.asarray` accepts. It is used directly as an index into the
            (h, w, 3) color array, so it is presumably a boolean (h, w) mask.

    Returns:
        `PIL.Image.Image`: Colorized depth map.
    """
    cm = matplotlib.colormaps["Spectral"]

    # Normalize to [0, 1]. A constant depth map has zero range; dividing by it
    # would fill the array with NaN, so map such input to 0 instead.
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range == 0:
        # Must be floating point: the colormap treats integer input as
        # lookup-table indices rather than values in [0, 1].
        depth = np.zeros_like(depth, dtype=np.float64)
    else:
        depth = (depth - depth_min) / depth_range

    # Colorize and drop the alpha channel.
    img_colored_np = cm(depth, bytes=False)[:, :, 0:3]  # (h,w,3)
    depth_colored = (img_colored_np * 255).astype(np.uint8)

    if mask is None:
        return Image.fromarray(depth_colored)

    # Convert the mask once; detach/cpu make this work for tensors that
    # require grad or live on another device as well.
    if isinstance(mask, torch.Tensor):
        mask_np = mask.detach().cpu().numpy()
    else:
        mask_np = np.asarray(mask)

    masked_image = np.zeros_like(depth_colored)
    masked_image[mask_np] = depth_colored[mask_np]
    return Image.fromarray(masked_image)


def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Resize image so that its longer edge equals the given length while
    keeping the aspect ratio.

    The same scale factor is applied to both edges. Images whose longer edge
    is shorter than `max_edge_resolution` are scaled up, not left unchanged.

    Args:
        img (`torch.Tensor`):
            Image tensor to be resized. Expected shape: [B, C, H, W]
        max_edge_resolution (`int`):
            Target length of the longer edge (pixel).
        resample_method (`torchvision.transforms.InterpolationMode`):
            Resampling method used to resize images.

    Returns:
        `torch.Tensor`: Resized image.
    """
    assert 4 == img.dim(), f"Invalid input shape {img.shape}"

    original_height, original_width = img.shape[-2:]
    scale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    # int() truncates, so for very elongated images the short edge could
    # become 0, which `resize` rejects; keep every edge at least 1 pixel.
    new_width = max(1, int(original_width * scale_factor))
    new_height = max(1, int(original_height * scale_factor))

    resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
    return resized_img


def get_tv_resample_method(method_str: str) -> InterpolationMode:
    """
    Translate a resampling method name into a torchvision `InterpolationMode`.

    Args:
        method_str (`str`):
            One of "bilinear", "bicubic", "nearest" or "nearest-exact".
            Both "nearest" and "nearest-exact" map to
            `InterpolationMode.NEAREST_EXACT`.

    Returns:
        `InterpolationMode`: The matching torchvision interpolation mode.

    Raises:
        ValueError: If `method_str` is not one of the supported names.
    """
    resample_method_dict = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST_EXACT,
        "nearest-exact": InterpolationMode.NEAREST_EXACT,
    }
    resample_method = resample_method_dict.get(method_str)
    if resample_method is None:
        # Report the name the caller passed in; the lookup result is always
        # None on this path and would make the message useless.
        raise ValueError(f"Unknown resampling method: {method_str}")
    return resample_method