# Copyright (c) SenseTime Research. All rights reserved.

"""Here we demo style-mixing results using a pretrained StyleGAN2 model.
Script reference: https://github.com/PDillis/stylegan2-fun"""
import os
import re
import sys
from typing import List

import click
import moviepy.editor
import numpy as np
import PIL.Image
import scipy.ndimage
import torch

import dnnlib
import legacy

os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
""" | |
Generate style mixing video. | |
Examples: | |
\b | |
python stylemixing_video.py --network=pretrained_models/stylegan_human_v2_1024.pkl --row-seed=3859 \\ | |
--col-seeds=3098,31759,3791 --col-styles=8-12 --trunc=0.8 --outdir=outputs/stylemixing_video | |
""" | |
def style_mixing_video(network_pkl: str,
                       # Seed of the source image style (row)
                       src_seed: List[int],
                       # Seeds of the destination image styles (columns)
                       dst_seeds: List[int],
                       # Styles to transfer from the source row to each column
                       col_styles: List[int],
                       truncation_psi: float = 1.0,
                       # True to show only the style-transferred result
                       only_stylemix: bool = False,
                       duration_sec: float = 10.0,
                       smoothing_sec: float = 1.0,
                       mp4_fps: int = 10,
                       mp4_codec: str = "libx264",
                       mp4_bitrate: str = "16M",
                       minibatch_size: int = 8,
                       noise_mode: str = 'const',
                       indent_range: int = 0,
                       outdir: str = 'outputs/stylemixing_video'):
    print('col_seeds: ', dst_seeds)

    # Calculate the number of frames:
    num_frames = int(np.rint(duration_sec * mp4_fps))

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        Gs = legacy.load_network_pkl(f)['G_ema'].to(device)
    print(Gs.num_ws, Gs.w_dim, Gs.img_resolution)
    max_style = int(2 * np.log2(Gs.img_resolution)) - 3
    assert max(col_styles) <= max_style, f"Maximum col-style allowed: {max_style}"
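    # Example: with img_resolution = 1024 a StyleGAN2 synthesis network has
    # num_ws = 2 * log2(1024) - 2 = 18 style layers (indices 0..17), so
    # max_style = 2 * 10 - 3 = 17 is the highest usable col-style index.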
    # Left column (source) latents:
    print('Generating Source W vectors...')
    src_shape = [num_frames, Gs.z_dim]
    src_z = np.random.RandomState(*src_seed).randn(*src_shape).astype(np.float32)  # [frames, component]
    # Smooth the latent walk over time with a Gaussian filter along the frame
    # axis, then rescale to unit RMS so the latents stay in-distribution:
    src_z = scipy.ndimage.gaussian_filter(src_z, [smoothing_sec * mp4_fps, 0], mode="wrap")
    src_z /= np.sqrt(np.mean(np.square(src_z)))
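    # The truncation trick below linearly interpolates each mapped latent toward
    # the mean latent w_avg: psi < 1 trades sample diversity for image quality,
    # while psi = 1 leaves the latents untouched.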
    # Map into the disentangled latent space W and apply the truncation trick:
    src_w = Gs.mapping(torch.from_numpy(src_z).to(device), None)
    w_avg = Gs.mapping.w_avg
    src_w = w_avg + (src_w - w_avg) * truncation_psi
    # Top row (destination) latents, fixed reference:
    print('Generating Destination W vectors...')
    dst_z = np.stack([np.random.RandomState(seed).randn(Gs.z_dim) for seed in dst_seeds])
    dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device), None)
    dst_w = w_avg + (dst_w - w_avg) * truncation_psi

    # Height and width of each image (the human model is 2:1):
    H = Gs.img_resolution       # 1024
    W = Gs.img_resolution // 2  # 512
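    # The synthesis network outputs NCHW float images in roughly [-1, 1];
    # permuting to NHWC and mapping x -> x * 127.5 + 128 converts them to
    # 8-bit RGB in [0, 255]: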
    # Generate ALL the source images:
    src_images = Gs.synthesis(src_w, noise_mode=noise_mode)
    src_images = (src_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    # Generate the column images:
    dst_images = Gs.synthesis(dst_w, noise_mode=noise_mode)
    dst_images = (dst_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    print('Generating full video (including source and destination images)')
    # Create the canvas where all generated images are pasted:
    canvas = PIL.Image.new("RGB", ((W - indent_range) * (len(dst_seeds) + 1),
                                   H * (len(src_seed) + 1)), "white")
    # Paste the fixed destination images along the top row; each dst_image is [H, W, 3]:
    for col, dst_image in enumerate(list(dst_images)):
        canvas.paste(PIL.Image.fromarray(dst_image.cpu().numpy(), "RGB"),
                     ((col + 1) * (W - indent_range), 0))
    # Aux function: frame generation for moviepy.
    def make_frame(t):
        # Get the frame number corresponding to time t:
        frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
        # Paste the source image for this frame into the lower-left corner:
        src_image = src_images[frame_idx]
        canvas.paste(PIL.Image.fromarray(src_image.cpu().numpy(), "RGB"),
                     (-indent_range, H))
        # Now, for each of the column images:
        for col, dst_image in enumerate(list(dst_images)):
            # Copy the pertinent destination latent; [num_ws, w_dim] -> [1, num_ws, w_dim].
            # The clone keeps the in-place edit below from touching dst_w itself:
            w_col = dst_w[col:col + 1].clone()
            # Replace the layers selected by col_styles with the source layers:
            w_col[:, col_styles] = src_w[frame_idx, col_styles]
            # Synthesize the style-mixed image:
            col_images = Gs.synthesis(w_col, noise_mode=noise_mode)
            col_images = (col_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            # Paste it in its respective spot:
            for row, image in enumerate(list(col_images)):
                canvas.paste(PIL.Image.fromarray(image.cpu().numpy(), "RGB"),
                             ((col + 1) * (W - indent_range), (row + 1) * H))
        return np.array(canvas)
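    # moviepy's VideoClip calls make_frame(t) with t in seconds over [0, duration]
    # and expects an HxWx3 uint8 RGB array for each frame.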
    # Generate the video using make_frame:
    print('Generating style-mixed video...')
    videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
    grid_size = [len(dst_seeds), len(src_seed)]
    mp4 = "{}x{}-style-mixing_{}_{}.mp4".format(*grid_size, min(col_styles), max(col_styles))
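    # For the example invocation above (--col-seeds=3098,31759,3791 --col-styles=8-12
    # with a single --row-seed), grid_size is [3, 1] and the clip is saved as
    # 3x1-style-mixing_8_12.mp4.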
    os.makedirs(outdir, exist_ok=True)
    videoclip.write_videofile(os.path.join(outdir, mp4),
                              fps=mp4_fps,
                              codec=mp4_codec,
                              bitrate=mp4_bitrate)

if __name__ == "__main__":
    style_mixing_video()  # pylint: disable=no-value-for-parameter