"""Batch style transfer over a folder of content images.

Every file found in ``--content-imgs`` is stylised with a file from
``--style-imgs`` (style files are reused cyclically when there are fewer of
them than content files) and the result is written to ``--save-folder`` under
the content file's own name.
"""

# NOTE: os, numpy, PIL.Image and tqdm are not referenced below; the imports
# are retained because other tooling may rely on them being present.
import os
import copy
import random
import argparse
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm

from lib.style_transfer_utils import (
    tensor2pil,
    load_style_transfer_model,
    run_style_transfer,
    style_content_image_loader,
)


def parse_arguments():
    """Parse the command-line options of the script.

    Returns:
        argparse.Namespace: parsed options. Dashes in option names become
        underscores (``--content-imgs`` -> ``content_imgs``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--content-imgs", type=str, help="Path to the content images.", required=True
    )
    parser.add_argument(
        "--style-imgs", type=str, help="Path to the style images.", required=True
    )
    parser.add_argument(
        "--save-folder",
        type=str,
        help="Path to the save the generated images.",
        required=True,
    )
    parser.add_argument(
        "--vgg", type=str, help="Path to the pretrained VGG model.", required=True
    )
    parser.add_argument("--cuda", action="store_true", help="use cuda.")
    # NOTE(review): --ext is accepted but main() never reads it; outputs keep
    # the content image's own extension. Kept for CLI compatibility.
    parser.add_argument(
        "--ext", type=str, default="stl", help="extension for generated image."
    )
    parser.add_argument(
        "--min-step", type=int, default=100, help="minimum iteration steps"
    )
    parser.add_argument(
        "--max-step", type=int, default=200, help="maximum iteration steps"
    )
    parser.add_argument(
        "--style-weight", type=float, default=100000, help="weight for style loss"
    )
    parser.add_argument(
        "--content-weight", type=float, default=2, help="weight for content loss"
    )
    return parser.parse_args()


def transfer_style(
    cnn_path,
    cimg,
    simg,
    min_step=100,
    max_step=200,
    style_weight=100000,
    content_weight=2,
    device="cpu",
):
    """Stylise one content image with one style image.

    Args:
        cnn_path: passed as ``pretrained`` to ``load_style_transfer_model``.
        cimg: content image, forwarded to ``style_content_image_loader``.
        simg: style image, forwarded to ``style_content_image_loader``.
        min_step: lower bound (inclusive) of the random number of steps.
        max_step: upper bound (inclusive) of the random number of steps.
        style_weight: forwarded to ``run_style_transfer``.
        content_weight: forwarded to ``run_style_transfer``.
        device: device the optimised input tensor is moved to; also forwarded
            to ``run_style_transfer``.

    Returns:
        The value of ``tensor2pil`` applied to the first element of the
        optimisation output (presumably a PIL image; the caller invokes
        ``.save`` on it).

    Raises:
        ValueError: from ``random.randint`` if ``min_step > max_step``.
    """
    cnn = load_style_transfer_model(pretrained=cnn_path)
    content_img, style_img = style_content_image_loader(cimg, simg)

    # The optimisation starts from an independent copy of the content image.
    input_img = copy.deepcopy(content_img).to(device, torch.float)

    # The number of steps is drawn uniformly per call from [min_step, max_step].
    num_steps = random.randint(min_step, max_step)

    output = run_style_transfer(
        cnn,
        content_img,
        style_img,
        input_img,
        num_steps=num_steps,
        style_weight=style_weight,
        content_weight=content_weight,
        device=device,
    )
    return tensor2pil(output[0].detach().cpu())


def main():
    """Run style transfer for every content image and save the results.

    Raises:
        FileNotFoundError: if there are content images but the style folder
            yields no entries.
        ValueError: if a content file name has no extension, since the output
            format cannot be derived from it.
    """
    args = parse_arguments()

    # CUDA is used only when requested AND available; otherwise run on CPU.
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    content_images = sorted(Path(args.content_imgs).glob("*"))
    style_images = sorted(Path(args.style_imgs).glob("*"))

    # Without this guard the cyclic index below would divide by zero.
    if content_images and not style_images:
        raise FileNotFoundError(f"No style images found in {args.style_imgs}")

    save_folder = Path(args.save_folder)
    if not save_folder.exists():
        print(f"Creating {args.save_folder}")
        # exist_ok closes the window between the existence check and creation.
        save_folder.mkdir(parents=True, exist_ok=True)

    for i, cimg in enumerate(content_images):
        # Fail before the expensive optimisation rather than at save time.
        if not cimg.suffix:
            raise ValueError(
                f"Content image {cimg} has no file extension; "
                "cannot determine the output format."
            )

        # Style images are reused cyclically.
        simg = style_images[i % len(style_images)]
        output_img = transfer_style(
            cnn_path=args.vgg,
            cimg=cimg,
            simg=simg,
            min_step=args.min_step,
            max_step=args.max_step,
            style_weight=args.style_weight,
            content_weight=args.content_weight,
            device=device,
        )

        # Reuse the full original file name so names containing several dots
        # (e.g. "photo.v2.jpg") are handled instead of failing on unpacking.
        output_img.save(save_folder / cimg.name)


if __name__ == "__main__":
    main()