import os

import cv2 as cv
import numpy as np
import torch
import yaml
import PIL.Image as Image
from torchvision import transforms

from src.models.definitions.vgg_nets import Vgg16, Vgg19, Vgg16Experimental

IMAGENET_MEAN_255 = [123.675, 116.28, 103.53]
IMAGENET_STD_NEUTRAL = [1, 1, 1]


def load_image(img_path, target_shape=None):
    if not os.path.exists(img_path):
        raise FileNotFoundError(f'Path does not exist: {img_path}')
    img = cv.imread(img_path)[:, :, ::-1]  # BGR -> RGB

    if target_shape is not None:  # resize section
        if isinstance(target_shape, int):  # scalar height -> rescale, preserving aspect ratio
            current_height, current_width = img.shape[:2]
            new_height = target_shape
            new_width = int(current_width * (new_height / current_height))
            img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC)
        else:  # (height, width) pair -> set both dimensions exactly
            img = cv.resize(img, (int(target_shape[1]), int(target_shape[0])), interpolation=cv.INTER_CUBIC)

    # this needs to go after resizing - otherwise cv.resize would push values outside of the [0, 1] range
    img = img.astype(np.float32)  # convert from uint8 to float32
    img /= 255.0  # get to [0, 1] range
    return img


def getInitImage(content_img, style_img, device):
    if yamlGet('initImage') == 'White Noise Image':
        white_noise_img = np.random.uniform(-90., 90., content_img.shape).astype(np.float32)
        init_img = torch.from_numpy(white_noise_img).float().to(device)
    elif yamlGet('initImage') == 'Gaussian Noise Image':
        gaussian_noise_img = np.random.normal(loc=0, scale=90., size=content_img.shape).astype(np.float32)
        init_img = torch.from_numpy(gaussian_noise_img).float().to(device)
    elif yamlGet('initImage') == 'Content':
        init_img = content_img
    else:
        # the init image has to have the same dimensions as the content image - this is a hard constraint:
        # feature maps need to be of the same size for the content image and the init image,
        # so reload the style image from disk, resized to the content image's (height, width)
        init_img = prepare_img(yamlGet('stylePath'), np.asarray(content_img.shape[2:]), device)
    return init_img


def prepare_img(img_path, target_shape, device):
    img = load_image(img_path, target_shape=target_shape)

    # normalize using ImageNet's mean
    # the [0, 255] range worked much better here than [0, 1] (even though PyTorch models were trained on the latter)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
        transforms.Normalize(mean=IMAGENET_MEAN_255, std=IMAGENET_STD_NEUTRAL),
    ])
    img = transform(img).to(device).unsqueeze(0)
    return img


def save_image(img, img_path):
    if len(img.shape) == 2:  # grayscale -> replicate into 3 channels
        img = np.stack((img,) * 3, axis=-1)
    cv.imwrite(img_path, img[:, :, ::-1])  # [:, :, ::-1] converts RGB into BGR (OpenCV constraint)
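
# Usage sketch (hypothetical paths; any image file works):
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   content = prepare_img('src/data/content.jpg', 400, device)
#   content.shape  # -> torch.Size([1, 3, 400, W]); unsqueeze(0) added the batch dim
#
# Note that prepare_img leaves pixels in the mean-subtracted [0, 255] range, so before
# anything is written back to disk IMAGENET_MEAN_255 has to be added back and the result
# clipped - which is what save_optimizing_image below does.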
def save_optimizing_image(optimizing_img, dump_path, img_id):
    img_format = (4, '.jpg')  # (zero-pad width, file extension)
    saving_freq = yamlGet('reprSavFreq')
    out_img = optimizing_img.squeeze(axis=0).to('cpu').detach().numpy()
    out_img = np.moveaxis(out_img, 0, 2)  # swap channel from 1st to 3rd position: ch, _, _ -> _, _, ch

    if img_id == yamlGet('iterations') - 1 or \
            (saving_freq > 0 and img_id % saving_freq == 0):
        out_img_name = str(img_id).zfill(img_format[0]) + img_format[1]
        dump_img = np.copy(out_img)
        dump_img += np.array(IMAGENET_MEAN_255).reshape((1, 1, 3))  # undo the mean subtraction from prepare_img
        dump_img = np.clip(dump_img, 0, 255).astype('uint8')
        cv.imwrite(os.path.join(dump_path, out_img_name), dump_img[:, :, ::-1])
        print(f'{out_img_name} written to {dump_path}')
        # if should_display:
        #     plt.imshow(np.uint8(get_uint8_range(out_img)))
        #     plt.show()


def get_uint8_range(x):
    if isinstance(x, np.ndarray):
        x -= np.min(x)
        x /= np.max(x)
        x *= 255
        return x
    else:
        raise ValueError(f'Expected numpy array got {type(x)}')


def prepare_model(device):
    model = yamlGet('model')
    if model == 'VGG16':
        model = Vgg16(requires_grad=False, show_progress=True)
    elif model == 'VGG16-Experimental':
        model = Vgg16Experimental(requires_grad=False, show_progress=True)
    elif model == 'VGG19':
        model = Vgg19(requires_grad=False, show_progress=True)
    else:
        raise ValueError(f'{model} not supported.')

    content_feature_maps_index = model.content_feature_maps_index
    style_feature_maps_indices = model.style_feature_maps_indices
    layer_names = list(model.layer_names.keys())

    content_fms_index_name = (content_feature_maps_index, layer_names[content_feature_maps_index])
    style_fms_indices_names = (style_feature_maps_indices, layer_names)
    return model.to(device).eval(), content_fms_index_name, style_fms_indices_names


def yamlSet(key, value):
    with open('src/config.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config[key] = value
    with open('src/config.yaml', 'w') as f:
        yaml.dump(config, f, default_flow_style=False)


def yamlGet(key):
    with open('src/config.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config[key]


def save_numpy_array_as_jpg(array, name):
    image = Image.fromarray(array)
    path = 'src/data/' + str(name) + '.jpg'
    image.save(path)
    return path


def gram_matrix(x, should_normalize=True):
    (b, ch, h, w) = x.size()
    features = x.view(b, ch, w * h)  # flatten each feature map into a row vector
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t)
    if should_normalize:
        gram /= ch * h * w
    return gram


def total_variation(y):
    # sum of absolute differences between horizontally and vertically adjacent pixels
    return torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + \
        torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))


def getImageAndPath(device):
    if yamlGet('reconstruct') == 'Content':
        img_path = yamlGet('contentPath')
    elif yamlGet('reconstruct') == 'Style':
        img_path = yamlGet('stylePath')
    else:
        raise ValueError(f"Unsupported reconstruct target: {yamlGet('reconstruct')}")
    img = prepare_img(img_path, yamlGet('height'), device)
    return img, img_path


def getContentCurrentData(config):
    current_representation = config.current_set_of_feature_maps[
        config.content_feature_maps_index].squeeze(axis=0)
    loss = torch.nn.MSELoss(reduction='mean')(config.target_representation, current_representation)
    return loss, current_representation


def getStyleCurrentData(config):
    current_representation = [
        gram_matrix(x) for cnt, x in enumerate(config.current_set_of_feature_maps)
        if cnt in config.style_feature_maps_indices
    ]
    loss = 0.0
    for gram_gt, gram_hat in zip(config.target_style_representation, current_representation):
        loss += torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
    loss /= len(config.target_style_representation)
    return loss, current_representation
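
# Style-loss sketch (illustrative): for the feature maps at each chosen layer l, the loss
# above compares Gram matrices of the optimized image against the style targets and
# averages over layers, i.e. loss = (1/L) * sum_l MSE_sum(G(target_l), G(current_l)).
# A hypothetical shape walk-through, assuming a conv layer with 64 channels at 56x56:
#
#   fmaps = torch.randn(1, 64, 56, 56)   # (batch, ch, h, w)
#   G = gram_matrix(fmaps)               # -> torch.Size([1, 64, 64]), normalized by ch*h*w
#
# G[b, i, j] is the dot product of the flattened i-th and j-th feature maps, so G captures
# which channels co-activate (texture statistics) while discarding spatial layout.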
def getCurrentData(config):
    if yamlGet('reconstruct') == 'Content':
        return getContentCurrentData(config)
    elif yamlGet('reconstruct') == 'Style':
        return getStyleCurrentData(config)


def getLBFGSReconstructLoss(config, optimizing_img):
    loss = 0.0
    if yamlGet('reconstruct') == 'Content':
        loss = torch.nn.MSELoss(reduction='mean')(
            config.target_content_representation,
            config.neural_net(optimizing_img)[config.content_feature_maps_index].squeeze(axis=0))
    else:
        config.current_set_of_feature_maps = config.neural_net(optimizing_img)
        current_style_representation = [
            gram_matrix(fmaps) for i, fmaps in enumerate(config.current_set_of_feature_maps)
            if i in config.style_feature_maps_indices
        ]
        for gram_gt, gram_hat in zip(config.target_style_representation, current_style_representation):
            loss += (1 / len(config.target_style_representation)) * \
                torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
    return loss


class Config:
    def __init__(self):
        # placeholders populated during optimization; 0 just marks "not set yet"
        self.target_representation = 0
        self.target_content_representation = 0
        self.target_style_representation = 0
        self.content_feature_maps_index = 0
        self.style_feature_maps_indices = 0
        self.current_set_of_feature_maps = 0
        self.current_representation = 0
        self.neural_net = 0


class Images:
    def getImages(self, device):
        return [
            self.__getContentImage(device),
            self.__getStyleImage(device),
            self.__getInitImage(device),
        ]

    def __getContentImage(self, device):
        return prepare_img(yamlGet('contentPath'), yamlGet('height'), device)

    def __getStyleImage(self, device):
        return prepare_img(yamlGet('stylePath'), yamlGet('height'), device)

    def __getInitImage(self, device):
        return getInitImage(self.__getContentImage(device),
                            self.__getStyleImage(device), device)


def clearDir():
    path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
    reconstructPath = os.path.join(path, 'reconstruct')
    transferPath = os.path.join(path, 'transfer')
    for transfer_file in os.scandir(transferPath):
        os.remove(transfer_file)
    for reconstruct_file in os.scandir(reconstructPath):
        os.remove(reconstruct_file)
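
if __name__ == '__main__':
    # Hypothetical smoke test (not part of the pipeline): exercises the two pure-tensor
    # helpers above on random data; needs no config.yaml, images, or VGG weights.
    fmaps = torch.randn(1, 64, 56, 56)  # (batch, channels, height, width)
    print(gram_matrix(fmaps).shape)     # -> torch.Size([1, 64, 64])
    print(total_variation(fmaps))       # non-negative scalar tensor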