# Spaces:
# Runtime error
# Runtime error
# Copyright (c) SenseTime Research. All rights reserved.
import math

import cv2
import numpy as np
import torch
from torchvision import transforms
def visual(output, out_path):
    """Write the first image of a batch to ``out_path`` with OpenCV.

    Values are shifted from [-1, 1] to [0, 1] via ``(output + 1) / 2`` and
    clamped, a single-channel batch is replicated to three channels, and the
    result is scaled to 8-bit before being saved.

    Args:
        output: Tensor indexed as (batch, channel, height, width); only
            ``output[0]`` is written. Presumably RGB in [-1, 1] -- confirm
            against the caller.
        out_path: Destination path handed to ``cv2.imwrite``.
    """
    image = torch.clamp((output + 1) / 2, 0, 1)
    if image.shape[1] == 1:
        # Replicate the single channel so a three-channel image is written.
        image = torch.cat([image, image, image], 1)
    image = image[0].detach().cpu().permute(1, 2, 0).numpy()
    image = (image * 255).astype(np.uint8)
    # Reverse the channel axis (OpenCV stores colour images as BGR) and
    # materialise the reversed view so OpenCV receives an ordinary
    # contiguous buffer rather than a negatively-strided one.
    image = np.ascontiguousarray(image[:, :, ::-1])
    cv2.imwrite(out_path, image)
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Return the learning rate at progress ``t`` with warm-up and cosine decay.

    The schedule multiplies ``initial_lr`` by two factors:

    * a cosine ramp-down, ``0.5 - 0.5 * cos(pi * min(1, (1 - t) / rampdown))``,
      which is 1 until the last ``rampdown`` fraction of progress and then
      falls to 0 at ``t == 1``;
    * a linear ramp-up, ``min(1, t / rampup)``, which rises from 0 at
      ``t == 0`` to 1 at ``t == rampup``.

    Args:
        t: Progress value; the formulas are written for 0 <= t <= 1.
        initial_lr: Peak learning rate.
        rampdown: Fraction of progress spent ramping down. ``0`` disables the
            ramp-down instead of raising ``ZeroDivisionError``.
        rampup: Fraction of progress spent ramping up. ``0`` disables the
            ramp-up instead of raising ``ZeroDivisionError``.

    Returns:
        The scaled learning rate.
    """
    # A zero-length ramp is treated as "no ramp" rather than dividing by zero.
    down = min(1, (1 - t) / rampdown) if rampdown != 0 else 1
    up = min(1, t / rampup) if rampup != 0 else 1

    lr_ramp = 0.5 - 0.5 * math.cos(down * math.pi)
    lr_ramp = lr_ramp * up
    return initial_lr * lr_ramp
def latent_noise(latent, strength):
    """Return ``latent`` plus standard-normal noise scaled by ``strength``.

    Args:
        latent: Tensor to perturb; it is not modified.
        strength: Multiplier applied to the sampled noise.

    Returns:
        A new tensor ``latent + randn_like(latent) * strength``.
    """
    return latent + torch.randn_like(latent) * strength
def noise_regularize_(noises):
    """Return a penalty on the spatial self-correlation of each noise map.

    For every tensor in ``noises`` the squared mean of the product between the
    map and a copy of itself rolled by one position is accumulated, once along
    dim 3 and once along dim 2. The map is then averaged over 2x2 blocks and
    the measurement is repeated, until the side length read from dim 2 is 8 or
    less.

    Args:
        noises: Iterable of 4-D tensors. The reshape uses ``shape[2]`` for
            both spatial axes, so the maps are assumed to be square with a
            side that halves evenly at every level -- confirm against the
            caller.

    Returns:
        The accumulated loss: the int ``0`` when ``noises`` is empty,
        otherwise a 0-dim tensor.
    """
    loss = 0
    for noise in noises:
        size = noise.shape[2]
        while True:
            corr_w = (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
            corr_h = (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
            loss = loss + corr_w + corr_h

            if size <= 8:
                break

            # Split each spatial axis into (half, 2) and average over the
            # 2-sized axes: a 2x2 mean downsample.
            half = size // 2
            noise = noise.reshape([-1, 1, half, 2, half, 2]).mean([3, 5])
            size = half
    return loss
def noise_normalize_(noises):
    """Standardise every tensor in ``noises`` in place.

    Each tensor has its overall mean subtracted and is divided by its overall
    standard deviation. The update is applied through ``.data``, so it is not
    recorded by autograd.

    Args:
        noises: Iterable of tensors; each one is modified in place.

    Returns:
        None.
    """
    for noise in noises:
        mean = noise.mean()
        std = noise.std()
        noise.data.add_(-mean).div_(std)
def tensor_to_numpy(x):
    """Convert the first image of a batch to a channel-last uint8 array.

    ``x[0]`` is permuted from (C, H, W) to (H, W, C), clamped to [-1, 1],
    mapped to [0, 255] with ``(x + 1) * 127.5`` and cast to ``uint8``. The
    cast truncates the fractional part.

    Args:
        x: Tensor indexed as (batch, channel, height, width); only ``x[0]``
            is converted.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with shape (H, W, C).
    """
    image = x[0].permute(1, 2, 0)
    image = torch.clamp(image, -1, 1)
    image = (image + 1) * 127.5
    return image.cpu().detach().numpy().astype(np.uint8)
def numpy_to_tensor(x):
    """Convert a channel-last image array to a batched float tensor.

    Values are mapped with ``(x / 255 - 0.5) * 2`` (so 0 -> -1 and 255 -> 1),
    a leading batch axis is added and the axes are permuted from
    (1, H, W, C) to (1, C, H, W).

    Args:
        x: ``numpy.ndarray`` indexed as (height, width, channel); presumably
            8-bit image data in [0, 255] -- confirm against the caller.

    Returns:
        ``float32`` tensor of shape (1, C, H, W), placed on the CUDA device
        when one is available and on the CPU otherwise.
    """
    scaled = (x / 255 - 0.5) * 2
    tensor = torch.from_numpy(scaled).unsqueeze(0).permute(0, 3, 1, 2)
    # The unconditional .cuda() call raised on hosts without a GPU; fall back
    # to the CPU there while keeping the CUDA placement when it is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return tensor.to(device).float()
def tensor_to_pil(x):
    """Convert an image tensor to a PIL image.

    The tensor is clamped to [-1, 1], mapped to [0, 255] with
    ``(x + 1) * 127.5``, cast to ``uint8`` (truncating, as ``tensor_to_numpy``
    does) and handed to ``torchvision.transforms.ToPILImage``.

    Args:
        x: Image tensor; a leading axis of size 1 is squeezed away before
            conversion. Presumably (1, C, H, W) in [-1, 1] -- confirm against
            the caller.

    Returns:
        The ``PIL.Image.Image`` produced by ``ToPILImage``.
    """
    x = torch.clamp(x, -1, 1)
    # ToPILImage rescales floating-point tensors by 255 before the byte cast
    # (it expects floats in [0, 1]), so passing floats already in [0, 255]
    # overflows. Converting to uint8 first makes it use the values as they are.
    x = ((x + 1) * 127.5).to(torch.uint8)
    return transforms.ToPILImage()(x.squeeze(0))