File size: 6,674 Bytes
0320907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from lpips import LPIPS
from PIL import Image
from torchvision.transforms import Normalize


def show_images_horizontally(
    list_of_files: np.array, output_file: Optional[str] = None, interact: bool = False
) -> None:
    """
    Visualize the list of images horizontally and either show or save the figure.

    Each image gets its own subplot in a single row. Every subplot is given
    8 inches of width; the figure height follows from the overall
    height-to-width ratio of the images.

    Args:
        list_of_files: The images to plot, as a sequence of arrays whose first
            two axes are (H, W), e.g. an array with shape (N, H, W, C).
        output_file: The output file path to save the figure to. Required when
            ``interact`` is False.
        interact: If True, display the figure with ``plt.show()`` instead of
            saving it.

    Raises:
        ValueError: If no images are given, or if ``interact`` is False and
            ``output_file`` is None.
    """
    number_of_files = len(list_of_files)
    if number_of_files == 0:
        raise ValueError("The image list is empty.")
    if not interact and output_file is None:
        raise ValueError("output_file must be provided when interact is False.")

    # Axis 0 of every image is its height, axis 1 its width.
    heights = [image.shape[0] for image in list_of_files]
    widths = [image.shape[1] for image in list_of_files]

    fig_width = 8.0  # inches, per image
    fig_height = fig_width * sum(heights) / sum(widths)

    # squeeze=False keeps `axs` two-dimensional even for a single image, so the
    # row can always be iterated.
    fig, axs = plt.subplots(
        1,
        number_of_files,
        figsize=(fig_width * number_of_files, fig_height),
        squeeze=False,
    )
    for ax, image in zip(axs[0], list_of_files):
        ax.imshow(image)
        ax.axis("off")
    # Lay out the subplots once they actually contain the images.
    fig.tight_layout()

    if interact:
        plt.show()
    else:
        fig.savefig(output_file, bbox_inches="tight", pad_inches=0.25)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def image_grids(images, rows=None, cols=None):
    """
    Paste a list of PIL images into a single RGB grid image, row by row.

    The cell size is taken from the first image; every image is pasted at the
    top-left corner of its cell without resizing.

    Args:
        images: Non-empty list of PIL images.
        rows: Number of grid rows. If omitted, it is derived from ``cols`` so
            that all images fit.
        cols: Number of grid columns. If omitted, it is derived from ``rows``
            when that is given, otherwise it is ``int(sqrt(len(images)))``.

    Returns:
        grid_image: A new RGB image of size (cols * width, rows * height).

    Raises:
        ValueError: If ``images`` is empty or ``rows``/``cols`` is not positive.
    """
    if not images:
        raise ValueError("The image list is empty.")
    if rows is not None and rows < 1:
        raise ValueError("rows must be a positive integer.")
    if cols is not None and cols < 1:
        raise ValueError("cols must be a positive integer.")

    n_images = len(images)
    if cols is None:
        if rows is None:
            cols = int(n_images**0.5)
        else:
            # Only rows was requested: pick enough columns to hold every image.
            cols = (n_images + rows - 1) // rows
    if rows is None:
        rows = (n_images + cols - 1) // cols

    width, height = images[0].size
    grid_image = Image.new("RGB", (cols * width, rows * height))

    # NOTE(review): when both rows and cols are passed explicitly and
    # rows * cols < len(images), the surplus images are pasted at offsets
    # outside the canvas — confirm callers never rely on seeing them.
    for i, image in enumerate(images):
        row, col = divmod(i, cols)
        grid_image.paste(image, (col * width, row * height))

    return grid_image


def save_image(image: np.array, file_name: str) -> None:
    """
    Write an image array to disk through PIL.

    Args:
        image: The input image as numpy array with shape (H, W, C).
        file_name: The path the image is saved to.
    """
    # Wrap the array in a PIL image and write it out in one step.
    Image.fromarray(image).save(file_name)


def load_and_process_images(load_dir: str) -> list:
    """
    Load the ``.jpg`` images of a directory, in numeric file-name order.

    Only files ending in ``.jpg`` are read; other directory entries are
    ignored. Files are ordered by the integer before the first ``.`` in their
    name (``2.jpg`` comes before ``10.jpg``).

    Args:
        load_dir: The directory to load the images from.

    Returns:
        images: A list with one float array per image, pixel values divided
            by 255.0.

    Raises:
        ValueError: If a ``.jpg`` file name does not start with an integer.
    """
    print(load_dir)
    # Filter before sorting so unrelated files (hidden files, other formats)
    # never reach the integer sort key.
    filenames = sorted(
        (name for name in os.listdir(load_dir) if name.endswith(".jpg")),
        key=lambda name: int(name.split(".")[0]),
    )

    images = []
    for filename in filenames:
        # The context manager closes the underlying file once the pixel data
        # has been copied into the array.
        with Image.open(os.path.join(load_dir, filename)) as img:
            images.append(np.asarray(img) / 255.0)
    return images


def compute_lpips(images: np.array, lpips_model: LPIPS) -> np.array:
    """
    Compute the LPIPS distance between each pair of adjacent input images.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        distances: Array of N-1 distances; entry ``i`` is the LPIPS between
            image ``i`` and image ``i + 1``. Empty when fewer than two images
            are given.
    """
    # Run on whatever device the model's parameters live on.
    device = next(lpips_model.parameters()).device

    # Collapse a list of arrays into one ndarray first: building a tensor
    # directly from a list of ndarrays is very slow.
    if not torch.is_tensor(images):
        images = np.asarray(images)
    batch = torch.as_tensor(images, dtype=torch.float32, device=device)
    batch = torch.permute(batch, (0, 3, 1, 2))  # (N, H, W, C) -> (N, C, H, W)

    # NOTE(review): these are the ImageNet mean/std; the lpips package documents
    # inputs scaled to [-1, 1] — confirm this normalization is intended before
    # comparing numbers with other LPIPS implementations.
    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    batch = normalize(batch)

    # Pure inference: no autograd graph is needed for the distances.
    distances = []
    with torch.no_grad():
        for img1, img2 in zip(batch[:-1], batch[1:]):
            loss = lpips_model(img1.unsqueeze(0), img2.unsqueeze(0))
            distances.append(loss.item())
    return np.array(distances)


def compute_gini(distances: np.array) -> float:
    """
    Compute the Gini index of the input distances.

    The index is the mean absolute difference between all pairs of values,
    divided by twice the mean:
    ``sum_ij |d_i - d_j| / (2 * n**2 * mean(d))``. It is evaluated with the
    equivalent closed form over the sorted values, which avoids the O(n^2)
    pairwise loop.

    Args:
        distances: The input distances as numpy array (or any sequence of
            numbers). Intended for non-negative values.

    Returns:
        gini: The Gini index of the input distances. 0.0 is returned for fewer
            than two values and when the values sum to zero (the ratio is
            undefined there, and all-zero distances are perfectly equal).
    """
    values = np.sort(np.asarray(distances, dtype=float).ravel())
    n = values.size
    if n < 2:
        return 0.0  # Gini index is 0 for less than two elements

    total = values.sum()
    if total == 0:
        # Avoid dividing by a zero mean.
        return 0.0

    # For ascending values d_1..d_n:
    #   sum_ij |d_i - d_j| = 2 * sum_i (2*i - n - 1) * d_i
    ranks = np.arange(1, n + 1)
    weighted_sum = np.sum((2 * ranks - n - 1) * values)
    return float(weighted_sum / (n * total))


def compute_smoothness_and_consistency(images: np.array, lpips_model: LPIPS) -> tuple:
    """
    Summarize the LPIPS distances between consecutive images.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        smoothness: One minus the Gini index of the consecutive LPIPS distances.
        consistency: The mean of the consecutive LPIPS distances.
        max_inception_distance: The maximum of the consecutive LPIPS distances.
    """
    # Distances between each image and its successor.
    pairwise = compute_lpips(images, lpips_model)
    return (
        1 - compute_gini(pairwise),
        np.mean(pairwise),
        np.max(pairwise),
    )


def separate_source_and_interpolated_images(images: np.array) -> tuple:
    """
    Split an image sequence into its two endpoints and the frames in between.

    The first and last images are treated as the sources; everything between
    them is treated as interpolated output.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).

    Returns:
        source: The first and last image stacked into a numpy array with
            shape (2, H, W, C).
        interpolation: The slice ``images[1:-1]``, i.e. the N-2 images between
            the two sources.

    Raises:
        ValueError: If fewer than two images are given.
    """
    # Guard: both a start and an end image are required.
    if not len(images) >= 2:
        raise ValueError("The input array should have at least two elements.")

    first, last = images[0], images[-1]
    endpoints = np.array([first, last])
    return endpoints, images[1:-1]