import os import src.utils.utils as utils from src.utils.video_utils import create_video_from_intermediate_results import torch from torch import nn from torch.optim import Adam, LBFGS from torch.autograd import Variable class ContentLoss(nn.Module): def __init__(self, target): super(ContentLoss, self).__init__() self.target = target.detach() def forward(self, current): return nn.MSELoss(reduction='mean')(self.target, current) class StyleLoss(nn.Module): def __init__(self): super(StyleLoss, self).__init__() self.loss = 0.0 def forward(self, x, y): for gram_gt, gram_hat in zip(x, y): self.loss += torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0]) self.loss /= len(x) return self.loss class Build(nn.Module): def __init__( self, config, target_content_representation, target_style_representation, ): super(Build, self).__init__() self.current_set_of_feature_maps = None self.current_content_representation = None self.current_Style_representation = None self.config = config self.target_content_representation = target_content_representation self.target_style_representation = target_style_representation def forward(self, model, x): self.current_set_of_feature_maps = model(x) self.current_content_representation = self.current_set_of_feature_maps[ self.config.content_feature_maps_index].squeeze(axis=0) self.current_style_representation = [ utils.gram_matrix(x) for cnt, x in enumerate(self.current_set_of_feature_maps) if cnt in self.config.style_feature_maps_indices ] content_loss = ContentLoss(self.target_content_representation)( self.current_content_representation) style_loss = StyleLoss()( self.target_style_representation, self.current_style_representation) tv_loss = TotalVariationLoss(x)() return Loss()(content_loss, style_loss, tv_loss) class TotalVariationLoss(nn.Module): def __init__(self, y): super(TotalVariationLoss, self).__init__() self.first = torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) self.second = torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])) def forward(self): return self.first + self.second class Loss(nn.Module): def __init__(self): super(Loss, self).__init__() def forward(self, x, y, z): return utils.yamlGet("contentWeight") * x + utils.yamlGet("styleWeight") * y + utils.yamlGet("totalVariationWeight") * z def neural_style_transfer(): dump_path = os.path.join(os.path.dirname(__file__), "data/transfer") config = utils.Config() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") content_img, style_img, init_img = utils.Images().getImages(device) optimizing_img = Variable(init_img, requires_grad=True) output = list(utils.prepare_model(device)) neural_net = output[0] content_feature_maps_index_name = output[1] style_feature_maps_indices_names = output[2] config.content_feature_maps_index = content_feature_maps_index_name[0] config.style_feature_maps_indices = style_feature_maps_indices_names[0] content_img_set_of_feature_maps = neural_net(content_img) style_img_set_of_feature_maps = neural_net(style_img) target_content_representation = content_img_set_of_feature_maps[ config.content_feature_maps_index].squeeze(axis=0) target_style_representation = [ utils.gram_matrix(x) for cnt, x in enumerate(style_img_set_of_feature_maps) if cnt in config.style_feature_maps_indices ] if utils.yamlGet('optimizer') == 'Adam': optimizer = Adam((optimizing_img, ), lr=utils.yamlGet('learning_rate')) for cnt in range(utils.yamlGet("iterations")): total_loss = Build(config, target_content_representation, target_style_representation)(neural_net, optimizing_img) total_loss.backward() optimizer.step() optimizer.zero_grad() with torch.no_grad(): utils.save_optimizing_image(optimizing_img, dump_path, cnt) elif utils.yamlGet('optimizer') == 'LBFGS': optimizer = LBFGS((optimizing_img, ), max_iter=utils.yamlGet('iterations'), line_search_fn='strong_wolfe') def closure(): total_loss, _, _, _ = build_loss( neural_net, optimizing_img, target_content_representation, target_style_representation, config) total_loss.backward() optimizer.zero_grad() with torch.no_grad(): utils.save_optimizing_image(optimizing_img, dump_path, cnt) return total_loss for cnt in range(utils.yamlGet("iterations")): optimizer.step(closure) create_video_from_intermediate_results(dump_path) # some values of weights that worked for figures.jpg, vg_starry_night.jpg # (starting point for finding good images) # once you understand what each one does it gets really easy -> also see # README.md # lbfgs, content init -> (cw, sw, tv) = (1e5, 3e4, 1e0) # lbfgs, style init -> (cw, sw, tv) = (1e5, 1e1, 1e-1) # lbfgs, random init -> (cw, sw, tv) = (1e5, 1e3, 1e0) # adam, content init -> (cw, sw, tv, lr) = (1e5, 1e5, 1e-1, 1e1) # adam, style init -> (cw, sw, tv, lr) = (1e5, 1e2, 1e-1, 1e1) # adam, random init -> (cw, sw, tv, lr) = (1e5, 1e2, 1e-1, 1e1) # original NST Neural Style Transfer) algorithm (Gatys et al.) # results_path = neural_style_transfer() # create_video_from_intermediate_results(results_path)