import os

import cv2 as cv
import numpy as np
import torch
import yaml
import PIL.Image as Image
from torchvision import transforms

from src.models.definitions.vgg_nets import Vgg16, Vgg19, Vgg16Experimental

# ImageNet channel means in the [0, 255] range (RGB order); a std of 1 keeps
# the normalization a pure mean shift.
IMAGENET_MEAN_255 = [123.675, 116.28, 103.53]
IMAGENET_STD_NEUTRAL = [1, 1, 1]

def load_image(img_path, target_shape=None):
    if not os.path.exists(img_path):
        raise FileNotFoundError(f'Path does not exist: {img_path}')
    img = cv.imread(img_path)[:, :, ::-1]  # BGR -> RGB
    if target_shape is not None:  # resize to the given height, keep aspect ratio
        current_height, current_width = img.shape[:2]
        new_height = target_shape
        new_width = int(current_width * (new_height / current_height))
        img = cv.resize(img, (new_width, new_height),
                        interpolation=cv.INTER_CUBIC)
    # this needs to happen after resizing - otherwise cv.resize would push
    # values outside of the [0, 1] range
    img = img.astype(np.float32)  # convert from uint8 to float32
    img /= 255.0  # get to [0, 1] range
    return img
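
# A minimal self-check for load_image; _demo_load_image and the temp file name
# are hypothetical additions for illustration, not part of the original module.
# It writes a dummy image to disk, reloads it, and verifies dtype, value range,
# and that the target height is honoured.
def _demo_load_image(tmp_path='_demo_input.jpg'):
    dummy = (np.random.rand(120, 160, 3) * 255).astype(np.uint8)
    cv.imwrite(tmp_path, dummy)
    img = load_image(tmp_path, target_shape=60)
    assert img.dtype == np.float32 and img.min() >= 0.0 and img.max() <= 1.0
    assert img.shape[0] == 60  # height 60, width scaled to 80
    os.remove(tmp_path)
    return img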

def getInitImage(content_img, style_img, device):
    if yamlGet("initImage") == 'White Noise Image':
        white_noise_img = np.random.uniform(
            -90., 90., content_img.shape).astype(np.float32)
        init_img = torch.from_numpy(white_noise_img).float().to(device)
    elif yamlGet("initImage") == 'Gaussian Noise Image':
        gaussian_noise_img = np.random.normal(
            loc=0, scale=90., size=content_img.shape).astype(np.float32)
        init_img = torch.from_numpy(gaussian_noise_img).float().to(device)
    elif yamlGet("initImage") == 'Content':
        init_img = content_img
    else:
        # init image must have the same dimensions as the content image -
        # this is a hard constraint: feature maps need to be of the same size
        # for the content image and the init image. style_img is already a
        # prepared tensor at this point, so resize it directly rather than
        # re-loading it from disk through prepare_img (which expects a path).
        init_img = torch.nn.functional.interpolate(
            style_img, size=tuple(content_img.shape[2:]), mode='bilinear',
            align_corners=False)
    return init_img

def prepare_img(img_path, target_shape, device):
    img = load_image(img_path, target_shape=target_shape)
    # normalize using ImageNet's mean
    # [0, 255] range worked much better for me than [0, 1] range
    # (even though PyTorch models were trained on the latter)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
        transforms.Normalize(mean=IMAGENET_MEAN_255, std=IMAGENET_STD_NEUTRAL)
    ])
    img = transform(img).to(device).unsqueeze(0)
    return img
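
# Sketch of the exact arithmetic prepare_img applies, on a synthetic array and
# with no file I/O; _demo_prepare_transform is a hypothetical helper added for
# illustration. With a neutral std the pipeline reduces to 255 * x minus the
# per-channel ImageNet mean.
def _demo_prepare_transform():
    fake = np.random.rand(32, 32, 3).astype(np.float32)  # HWC image in [0, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
        transforms.Normalize(mean=IMAGENET_MEAN_255, std=IMAGENET_STD_NEUTRAL)
    ])
    out = transform(fake)
    expected = torch.from_numpy(fake).permute(2, 0, 1) * 255 \
        - torch.tensor(IMAGENET_MEAN_255).view(3, 1, 1)
    assert torch.allclose(out, expected, atol=1e-4)
    return out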

def save_image(img, img_path):
    if len(img.shape) == 2:  # grayscale -> replicate into 3 channels
        img = np.stack((img, ) * 3, axis=-1)
    # [:, :, ::-1] converts RGB into BGR (OpenCV constraint)
    cv.imwrite(img_path, img[:, :, ::-1])

def save_optimizing_image(optimizing_img, dump_path, img_id):
    img_format = (4, '.jpg')  # zero-pad the id to 4 digits, save as .jpg
    saving_freq = yamlGet('reprSavFreq')
    out_img = optimizing_img.squeeze(axis=0).to('cpu').detach().numpy()
    # swap channel from 1st to 3rd position: (ch, h, w) -> (h, w, ch)
    out_img = np.moveaxis(out_img, 0, 2)
    if img_id == yamlGet('iterations') - 1 or \
            (saving_freq > 0 and img_id % saving_freq == 0):
        # always build a name here: the final iteration is saved even when
        # periodic saving is disabled (saving_freq == -1), so a None name
        # would crash os.path.join below
        out_img_name = str(img_id).zfill(img_format[0]) + img_format[1]
        dump_img = np.copy(out_img)
        # undo the mean shift applied in prepare_img before writing to disk
        dump_img += np.array(IMAGENET_MEAN_255).reshape((1, 1, 3))
        dump_img = np.clip(dump_img, 0, 255).astype('uint8')
        cv.imwrite(os.path.join(dump_path, out_img_name), dump_img[:, :, ::-1])
        print(f"{out_img_name} written to {dump_path}")
    # if should_display:
    #     plt.imshow(np.uint8(get_uint8_range(out_img)))
    #     plt.show()

def get_uint8_range(x):
    if isinstance(x, np.ndarray):
        # min-max rescale into [0, 255] (mutates x in place)
        x -= np.min(x)
        x /= np.max(x)
        x *= 255
        return x
    else:
        raise ValueError(f'Expected numpy array, got {type(x)}')
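
# Worked example of get_uint8_range (the helper below is a hypothetical
# addition): min-max rescaling maps [-1, 0, 3] onto [0, 63.75, 255]. Note the
# function mutates its argument in place, hence the copy.
def _demo_get_uint8_range():
    x = np.array([-1.0, 0.0, 3.0])
    assert np.allclose(get_uint8_range(x.copy()), [0.0, 63.75, 255.0])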

def prepare_model(device):
    model = yamlGet('model')
    if model == 'VGG16':
        model = Vgg16(requires_grad=False, show_progress=True)
    elif model == 'VGG16-Experimental':
        model = Vgg16Experimental(requires_grad=False, show_progress=True)
    elif model == 'VGG19':
        model = Vgg19(requires_grad=False, show_progress=True)
    else:
        raise ValueError(f'{model} not supported.')
    content_feature_maps_index = model.content_feature_maps_index
    style_feature_maps_indices = model.style_feature_maps_indices
    layer_names = list(model.layer_names.keys())
    content_fms_index_name = (content_feature_maps_index,
                              layer_names[content_feature_maps_index])
    style_fms_indices_names = (style_feature_maps_indices, layer_names)
    return (model.to(device).eval(), content_fms_index_name,
            style_fms_indices_names)

def yamlSet(key, value):
    with open('src/config.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config[key] = value
    with open('src/config.yaml', 'w') as f:
        yaml.dump(config, f, default_flow_style=False)


def yamlGet(key):
    with open('src/config.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        return config[key]

def save_numpy_array_as_jpg(array, name):
    image = Image.fromarray(array)
    path = "src/data/" + str(name) + '.jpg'
    image.save(path)
    return path

def gram_matrix(x, should_normalize=True):
    # x: (b, ch, h, w) feature maps -> (b, ch, ch) channel co-occurrence matrix
    (b, ch, h, w) = x.size()
    features = x.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t)
    if should_normalize:
        gram /= ch * h * w
    return gram
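
# Quick shape/value check for gram_matrix; _demo_gram_matrix is a hypothetical
# helper added for illustration. For feature maps of shape (b, ch, h, w) the
# Gram matrix has shape (b, ch, ch); for constant all-ones maps every inner
# product is h * w = 12, and normalizing by ch * h * w = 24 gives 0.5.
def _demo_gram_matrix():
    x = torch.ones(1, 2, 3, 4)
    g = gram_matrix(x)
    assert g.shape == (1, 2, 2)
    assert torch.allclose(g, torch.full((1, 2, 2), 0.5))
    return g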

def total_variation(y):
    # anisotropic total variation: the summed absolute differences between
    # horizontally and vertically neighbouring pixels
    return torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + \
        torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))
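
# Tiny sanity check for total_variation on a 2x2 checkerboard (hypothetical
# helper, added for illustration): two horizontal and two vertical unit jumps
# give a total variation of 4.
def _demo_total_variation():
    y = torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]])  # shape (1, 1, 2, 2)
    assert total_variation(y).item() == 4.0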

def getImageAndPath(device):
    if yamlGet('reconstruct') == 'Content':
        img_path = yamlGet('contentPath')
    elif yamlGet('reconstruct') == 'Style':
        img_path = yamlGet('stylePath')
    else:
        raise ValueError("reconstruct must be 'Content' or 'Style'")
    img = prepare_img(img_path, yamlGet('height'), device)
    return img, img_path

def getContentCurrentData(config):
    current_representation = config.current_set_of_feature_maps[
        config.content_feature_maps_index].squeeze(axis=0)
    loss = torch.nn.MSELoss(reduction='mean')(config.target_representation,
                                              current_representation)
    return loss, current_representation

def getStyleCurrentData(config):
    current_representation = [
        gram_matrix(x)
        for cnt, x in enumerate(config.current_set_of_feature_maps)
        if cnt in config.style_feature_maps_indices
    ]
    loss = 0.0
    for gram_gt, gram_hat in zip(config.target_style_representation,
                                 current_representation):
        loss += torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
    loss /= len(config.target_style_representation)
    return loss, current_representation
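
# Synthetic end-to-end check of the style loss path; _demo_style_loss and the
# random feature maps are hypothetical stand-ins for real VGG activations.
# When the current maps equal the targets, the loss is exactly zero.
def _demo_style_loss():
    cfg = Config()
    cfg.current_set_of_feature_maps = [torch.rand(1, 4, 5, 5) for _ in range(3)]
    cfg.style_feature_maps_indices = [0, 2]
    cfg.target_style_representation = [
        gram_matrix(x)
        for i, x in enumerate(cfg.current_set_of_feature_maps)
        if i in cfg.style_feature_maps_indices
    ]
    loss, _ = getStyleCurrentData(cfg)
    assert loss.item() == 0.0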

def getCurrentData(config):
    if yamlGet('reconstruct') == 'Content':
        return getContentCurrentData(config)
    elif yamlGet('reconstruct') == 'Style':
        return getStyleCurrentData(config)
    raise ValueError("reconstruct must be 'Content' or 'Style'")

def getLBFGSReconstructLoss(config, optimizing_img):
    loss = 0.0
    if yamlGet('reconstruct') == 'Content':
        loss = torch.nn.MSELoss(reduction='mean')(
            config.target_content_representation,
            config.neural_net(optimizing_img)[
                config.content_feature_maps_index].squeeze(axis=0))
    else:
        config.current_set_of_feature_maps = config.neural_net(optimizing_img)
        current_style_representation = [
            gram_matrix(fmaps)
            for i, fmaps in enumerate(config.current_set_of_feature_maps)
            if i in config.style_feature_maps_indices
        ]
        for gram_gt, gram_hat in zip(config.target_style_representation,
                                     current_style_representation):
            loss += (1 / len(config.target_style_representation)) * \
                torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
    return loss

class Config:
    """Mutable bag of state shared between the reconstruction helpers."""

    def __init__(self):
        self.target_representation = None
        self.target_content_representation = None
        self.target_style_representation = None
        self.content_feature_maps_index = None
        self.style_feature_maps_indices = None
        self.current_set_of_feature_maps = None
        self.current_representation = None
        self.neural_net = None
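
# Sketch of how a Config instance is typically wired up before calling the
# loss helpers above (illustrative only; names mirror prepare_model's return
# values):
#   config = Config()
#   config.neural_net, content_fms, style_fms = prepare_model(device)
#   config.content_feature_maps_index = content_fms[0]
#   config.style_feature_maps_indices = style_fms[0]
#   config.current_set_of_feature_maps = config.neural_net(init_img)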

class Images:
    def getImages(self, device):
        # prepare content/style once and reuse them for the init image
        content_img = self.__getContentImage(device)
        style_img = self.__getStyleImage(device)
        return [
            content_img,
            style_img,
            getInitImage(content_img, style_img, device),
        ]

    def __getContentImage(self, device):
        return prepare_img(yamlGet('contentPath'), yamlGet('height'), device)

    def __getStyleImage(self, device):
        return prepare_img(yamlGet('stylePath'), yamlGet('height'), device)
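
# Typical call site (device is whichever torch device the run targets):
#   content_img, style_img, init_img = Images().getImages(device)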

def clearDir():
    path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
    reconstructPath = os.path.join(path, "reconstruct")
    transferPath = os.path.join(path, "transfer")
    for transfer_file in os.scandir(transferPath):
        os.remove(transfer_file.path)
    for reconstruct_file in os.scandir(reconstructPath):
        os.remove(reconstruct_file.path)