import os

import torch
from torch import nn
from torch.optim import Adam, LBFGS
from torch.autograd import Variable

import src.utils.utils as utils
from src.utils.video_utils import create_video_from_intermediate_results

class ContentLoss(nn.Module):
    """MSE between the target content feature map and the current one."""

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # Detach the target so it is treated as a fixed constant: gradients
        # should only flow through the image being optimized.
        self.target = target.detach()

    def forward(self, current):
        return nn.MSELoss(reduction='mean')(self.target, current)
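
# Usage sketch (illustrative shapes, not part of the original module): the
# loss is zero exactly when the current feature map matches the detached
# target.
# >>> target = torch.randn(512, 32, 32)
# >>> ContentLoss(target)(target.clone())
# tensor(0.)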

class StyleLoss(nn.Module):
    """Mean of per-layer MSE losses between target and current Gram matrices."""

    def forward(self, x, y):
        # x and y are lists of batched Gram matrices (one per style layer);
        # [0] selects the single image in the batch. Accumulating into a
        # local variable (rather than instance state) keeps repeated forward
        # calls independent.
        loss = 0.0
        for gram_gt, gram_hat in zip(x, y):
            loss += nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
        return loss / len(x)
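
# For reference: utils.gram_matrix is imported from src.utils.utils and not
# defined here. A typical implementation (a sketch under that assumption,
# not necessarily the repo's exact code) flattens each feature map and takes
# channel-wise inner products, normalizing by the tensor size:
def _gram_matrix_sketch(x, should_normalize=True):
    b, ch, h, w = x.size()
    features = x.view(b, ch, h * w)                # flatten spatial dims
    gram = features.bmm(features.transpose(1, 2))  # (b, ch, ch) correlations
    if should_normalize:
        gram /= ch * h * w
    return gram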

class Build(nn.Module):
    """Runs the model on the current image and assembles the total loss."""

    def __init__(
        self,
        config,
        target_content_representation,
        target_style_representation,
    ):
        super(Build, self).__init__()
        self.current_set_of_feature_maps = None
        self.current_content_representation = None
        self.current_style_representation = None
        self.config = config
        self.target_content_representation = target_content_representation
        self.target_style_representation = target_style_representation

    def forward(self, model, x):
        self.current_set_of_feature_maps = model(x)
        self.current_content_representation = self.current_set_of_feature_maps[
            self.config.content_feature_maps_index].squeeze(0)
        self.current_style_representation = [
            utils.gram_matrix(feature_maps)
            for cnt, feature_maps in enumerate(self.current_set_of_feature_maps)
            if cnt in self.config.style_feature_maps_indices
        ]
        content_loss = ContentLoss(self.target_content_representation)(
            self.current_content_representation)
        style_loss = StyleLoss()(
            self.target_style_representation,
            self.current_style_representation)
        tv_loss = TotalVariationLoss()(x)
        return Loss()(content_loss, style_loss, tv_loss)

class TotalVariationLoss(nn.Module):
    """Anisotropic total variation: the sum of absolute differences between
    horizontally and vertically adjacent pixels (encourages smoothness)."""

    def forward(self, y):
        horizontal = torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:]))
        vertical = torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))
        return horizontal + vertical
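
# Quick sanity check (illustrative): a constant image has zero total
# variation, while any spatial structure increases it.
# >>> TotalVariationLoss()(torch.zeros(1, 3, 8, 8))
# tensor(0.)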

class Loss(nn.Module):
    """Weighted sum of content, style, and total-variation losses."""

    def forward(self, x, y, z):
        return (utils.yamlGet("contentWeight") * x
                + utils.yamlGet("styleWeight") * y
                + utils.yamlGet("totalVariationWeight") * z)
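
# In the notation of Gatys et al. this is
#   L_total = alpha * L_content + beta * L_style + gamma * L_tv,
# with the three weights read from the YAML config (keys assumed from the
# calls above: contentWeight, styleWeight, totalVariationWeight).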

def neural_style_transfer():
    dump_path = os.path.join(os.path.dirname(__file__), "data/transfer")
    config = utils.Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    content_img, style_img, init_img = utils.Images().getImages(device)
    # Variable is a no-op wrapper in modern PyTorch; requires_grad=True makes
    # the initial image the leaf tensor that the optimizer updates.
    optimizing_img = Variable(init_img, requires_grad=True)
    neural_net, content_feature_maps_index_name, \
        style_feature_maps_indices_names = utils.prepare_model(device)
    config.content_feature_maps_index = content_feature_maps_index_name[0]
    config.style_feature_maps_indices = style_feature_maps_indices_names[0]

    # Target representations are extracted once from the content and style
    # images and stay fixed throughout the optimization.
    content_img_set_of_feature_maps = neural_net(content_img)
    style_img_set_of_feature_maps = neural_net(style_img)
    target_content_representation = content_img_set_of_feature_maps[
        config.content_feature_maps_index].squeeze(0)
    target_style_representation = [
        utils.gram_matrix(feature_maps)
        for cnt, feature_maps in enumerate(style_img_set_of_feature_maps)
        if cnt in config.style_feature_maps_indices
    ]
    if utils.yamlGet('optimizer') == 'Adam':
        optimizer = Adam((optimizing_img,), lr=utils.yamlGet('learning_rate'))
        for cnt in range(utils.yamlGet("iterations")):
            optimizer.zero_grad()
            total_loss = Build(config, target_content_representation,
                               target_style_representation)(neural_net,
                                                            optimizing_img)
            total_loss.backward()
            optimizer.step()
            with torch.no_grad():
                utils.save_optimizing_image(optimizing_img, dump_path, cnt)
    elif utils.yamlGet('optimizer') == 'LBFGS':
        optimizer = LBFGS((optimizing_img,),
                          max_iter=utils.yamlGet('iterations'),
                          line_search_fn='strong_wolfe')
        cnt = 0

        def closure():
            nonlocal cnt
            optimizer.zero_grad()
            total_loss = Build(config, target_content_representation,
                               target_style_representation)(neural_net,
                                                            optimizing_img)
            total_loss.backward()
            with torch.no_grad():
                utils.save_optimizing_image(optimizing_img, dump_path, cnt)
            cnt += 1
            return total_loss

        # LBFGS performs up to max_iter loss evaluations inside a single
        # step() call, so one call runs the whole optimization.
        optimizer.step(closure)
    create_video_from_intermediate_results(dump_path)

# Some weight values that worked for figures.jpg and vg_starry_night.jpg
# (a starting point for finding good images). Once you understand what each
# one does it gets really easy; also see README.md.
# lbfgs, content init -> (cw, sw, tv) = (1e5, 3e4, 1e0)
# lbfgs, style init   -> (cw, sw, tv) = (1e5, 1e1, 1e-1)
# lbfgs, random init  -> (cw, sw, tv) = (1e5, 1e3, 1e0)
# adam, content init  -> (cw, sw, tv, lr) = (1e5, 1e5, 1e-1, 1e1)
# adam, style init    -> (cw, sw, tv, lr) = (1e5, 1e2, 1e-1, 1e1)
# adam, random init   -> (cw, sw, tv, lr) = (1e5, 1e2, 1e-1, 1e1)
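
# A minimal YAML config matching the yamlGet keys used above might look like
# this (an assumed layout; the actual config file ships with the repo):
#
#   optimizer: LBFGS            # or Adam
#   learning_rate: 1e1          # used by Adam only
#   iterations: 500
#   contentWeight: 1e5
#   styleWeight: 3e4
#   totalVariationWeight: 1e0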

# Original NST (Neural Style Transfer) algorithm (Gatys et al.).
# neural_style_transfer() already saves intermediate frames and builds the
# video via create_video_from_intermediate_results.
if __name__ == "__main__":
    neural_style_transfer()