File size: 2,456 Bytes
bb0f5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
from PIL import Image
import wandb
from PTI.configs import global_config
import torch
import matplotlib.pyplot as plt


def log_image_from_w(w, G, name):
    """Render latent ``w`` with ``G`` and log the picture to wandb.

    The image is logged under the key ``name`` with the caption
    ``"current inversion <name>"`` at step ``global_config.training_step``.
    """
    rendered = Image.fromarray(get_image_from_w(w, G))
    wandb_picture = wandb.Image(rendered, caption=f"current inversion {name}")
    payload = {f"{name}": [wandb_picture]}
    wandb.log(payload, step=global_config.training_step)


def log_images_from_w(ws, G, names):
    """Log one wandb image per ``(name, latent)`` pair.

    ``names`` and ``ws`` are walked in lockstep with ``zip``, so iteration
    stops at the shorter of the two. Each latent is moved to
    ``global_config.device`` before it is rendered and logged.
    """
    for label, latent in zip(names, ws):
        log_image_from_w(latent.to(global_config.device), G, label)


def plot_image_from_w(w, G):
    """Render latent ``w`` with ``G`` and display it in a matplotlib window."""
    rendered = get_image_from_w(w, G)
    plt.imshow(Image.fromarray(rendered))
    plt.show()


def plot_image(img):
    """Display the first image of a batched tensor with matplotlib.

    ``img`` is permuted with ``(0, 2, 3, 1)`` (presumably NCHW -> NHWC;
    confirm against callers), rescaled by ``* 127.5 + 128``, clamped to
    ``[0, 255]`` and cast to ``uint8`` before only element ``0`` of the
    batch is shown.
    """
    channels_last = img.permute(0, 2, 3, 1)
    scaled = (channels_last * 127.5 + 128).clamp(0, 255)
    as_bytes = scaled.to(torch.uint8).detach().cpu().numpy()
    first = as_bytes[0]
    plt.imshow(Image.fromarray(first))
    plt.show()


def save_image(name, method_type, results_dir, image, run_id):
    """Save ``image`` as ``<results_dir>/<method_type>_<name>_<run_id>.jpg``.

    ``image`` only needs to expose a ``save(path)`` method (e.g. a PIL image).
    """
    target_path = f'{results_dir}/{method_type}_{name}_{run_id}.jpg'
    image.save(target_path)


def save_w(w, G, name, method_type, results_dir, run_id=''):
    """Render latent ``w`` with ``G`` and write it to disk through ``save_image``.

    Args:
        w: Latent code passed to ``get_image_from_w``.
        G: Generator passed to ``get_image_from_w``.
        name: Image name used in the output file name.
        method_type: Method label used in the output file name.
        results_dir: Directory the JPEG is written into.
        run_id: Run identifier forwarded to ``save_image``. Defaults to an
            empty string, which gives a file name ending in ``_.jpg``.
    """
    im = get_image_from_w(w, G)
    im = Image.fromarray(im, mode='RGB')
    # save_image takes run_id as a required fifth argument; the previous
    # four-argument call raised TypeError on every invocation.
    save_image(name, method_type, results_dir, im, run_id)


def save_concat_image(base_dir, image_latents, new_inv_image_latent, new_G,
                      old_G,
                      file_name,
                      extra_image=None):
    """Save a side-by-side strip of rendered latents as ``<base_dir>/<file_name>.jpg``.

    Panel order, left to right: ``extra_image`` (only when it is not None),
    every latent in ``image_latents`` rendered with ``old_G``, and finally
    ``new_inv_image_latent`` rendered with ``new_G``.
    """
    panels = [] if extra_image is None else [extra_image]
    panels.extend(get_image_from_w(latent, old_G) for latent in image_latents)
    panels.append(get_image_from_w(new_inv_image_latent, new_G))

    strip = create_alongside_images(panels)
    strip.save(f'{base_dir}/{file_name}.jpg')


def save_single_image(base_dir, image_latent, G, file_name):
    """Render ``image_latent`` with ``G`` and save it as ``<base_dir>/<file_name>.jpg``."""
    pixels = get_image_from_w(image_latent, G)
    picture = Image.fromarray(pixels, mode='RGB')
    destination = f'{base_dir}/{file_name}.jpg'
    picture.save(destination)


def create_alongside_images(images):
    """Join ``images`` along axis 1 and return the result as an RGB PIL image.

    Each item is converted with ``np.array`` first, so all of them must agree
    on every dimension except axis 1 for the concatenation to succeed.
    """
    arrays = []
    for picture in images:
        arrays.append(np.array(picture))
    joined = np.concatenate(arrays, axis=1)
    return Image.fromarray(joined, mode='RGB')


def get_image_from_w(w, G):
    """Synthesize an image from latent ``w`` and return it as a uint8 array.

    A latent with at most two dimensions gets a leading axis added before it
    is passed to ``G.synthesis(..., noise_mode='const')``. The output is
    permuted with ``(0, 2, 3, 1)`` (presumably NCHW -> NHWC; confirm against
    the generator), rescaled by ``* 127.5 + 128``, clamped to ``[0, 255]``
    and cast to ``uint8``. Only the first element of the batch is returned,
    as a NumPy array.
    """
    latent = w.unsqueeze(0) if len(w.size()) <= 2 else w
    with torch.no_grad():
        synthesized = G.synthesis(latent, noise_mode='const')
        channels_last = synthesized.permute(0, 2, 3, 1)
        scaled = (channels_last * 127.5 + 128).clamp(0, 255)
        batch = scaled.to(torch.uint8).detach().cpu().numpy()
    return batch[0]