|
import os |
|
import json |
|
from IPython import display |
|
import random |
|
from torchvision.utils import make_grid |
|
from einops import rearrange |
|
import pandas as pd |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image |
|
import pathlib |
|
import torchvision.transforms as T |
|
|
|
from .generate import generate, add_noise |
|
from .prompt import sanitize |
|
from .animation import DeformAnimKeys, sample_from_cv2, sample_to_cv2, anim_frame_warp, vid2frames |
|
from .depth import DepthModel |
|
from .colors import maintain_colors |
|
from .load_images import prepare_overlay_mask |
|
|
|
def next_seed(args):
    """Advance ``args.seed`` in place according to ``args.seed_behavior``.

    Behaviors:
        'iter'  -- increment the current seed by one.
        'fixed' -- leave the seed untouched.
        other   -- draw a fresh random 32-bit seed.

    Returns the (possibly updated) seed value.
    """
    behavior = args.seed_behavior
    if behavior == 'iter':
        args.seed = args.seed + 1
    elif behavior != 'fixed':
        # any behavior other than 'iter'/'fixed' means randomize each call
        args.seed = random.randint(0, 2**32 - 1)
    return args.seed
|
|
|
def render_image_batch(args, prompts, root):
    """Render a plain (non-animated) batch of images, one or more per prompt.

    Mutates ``args`` in place (prompt, init_image, seed, prompts map).
    For every prompt and every batch iteration, runs ``generate`` once per
    init image, optionally saving, displaying, and collecting images for a
    final grid per prompt.

    Args:
        args: Namespace of generation settings (outdir, seed, n_batch, ...).
        prompts: Iterable of prompt strings.
        root: Runtime context passed through to ``generate``.
    """
    # NOTE: mapping is prompt -> zero-padded index string (k is the prompt,
    # v the enumerate index); duplicate prompts would collapse to one key.
    args.prompts = {k: f"{v:05d}" for v, k in enumerate(prompts)}

    os.makedirs(args.outdir, exist_ok=True)
    if args.save_settings or args.save_samples:
        print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*")

    # Dump all args as JSON, minus non-serializable entries injected by run.py.
    if args.save_settings:
        filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
        with open(filename, "w+", encoding="utf-8") as f:
            dictlist = dict(args.__dict__)
            del dictlist['master_args']
            del dictlist['root']
            del dictlist['get_output_folder']
            json.dump(dictlist, f, ensure_ascii=False, indent=4)

    # Global image counter used in output filenames across all prompts/batches.
    index = 0

    # Resolve init image(s): a URL, a single file, or every image in a folder.
    init_array = []
    if args.use_init:
        if args.init_image == "":
            raise FileNotFoundError("No path was given for init_image")
        if args.init_image.startswith('http://') or args.init_image.startswith('https://'):
            init_array.append(args.init_image)
        elif not os.path.isfile(args.init_image):
            # treat a non-file path as a directory of init images
            if args.init_image[-1] != "/":
                args.init_image += "/"
            for image in sorted(os.listdir(args.init_image)):
                if image.split(".")[-1] in ("png", "jpg", "jpeg"):
                    init_array.append(args.init_image + image)
        else:
            init_array.append(args.init_image)
    else:
        # single empty entry so the generation loop below runs once
        init_array = [""]

    # Only bother clearing notebook output for large batches.
    clear_between_batches = args.n_batch >= 32

    for iprompt, prompt in enumerate(prompts):
        args.prompt = prompt
        args.clip_prompt = prompt
        print(f"Prompt {iprompt+1} of {len(prompts)}")
        print(f"{args.prompt}")

        # tensors collected for the optional per-prompt grid
        all_images = []

        for batch_index in range(args.n_batch):
            if clear_between_batches and batch_index % 32 == 0:
                display.clear_output(wait=True)
            print(f"Batch {batch_index+1} of {args.n_batch}")

            for image in init_array:
                args.init_image = image
                results = generate(args, root)
                # NOTE: inner loop reuses the name `image` for the PIL results,
                # shadowing the init-path loop variable above.
                for image in results:
                    if args.make_grid:
                        all_images.append(T.functional.pil_to_tensor(image))
                    if args.save_samples:
                        # default format embeds a sanitized prompt; otherwise use the seed
                        if args.filename_format == "{timestring}_{index}_{prompt}.png":
                            filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png"
                        else:
                            filename = f"{args.timestring}_{index:05}_{args.seed}.png"
                        image.save(os.path.join(args.outdir, filename))
                    if args.display_samples:
                        display.display(image)
                    index += 1
                args.seed = next_seed(args)

        # Assemble and save a grid of every image produced for this prompt.
        if args.make_grid:
            grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows))
            grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()
            filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png"
            grid_image = Image.fromarray(grid.astype(np.uint8))
            grid_image.save(os.path.join(args.outdir, filename))
            display.clear_output(wait=True)
            display.display(grid_image)
|
|
|
|
|
def render_animation(args, anim_args, animation_prompts, root):
    """Render a 2D/3D/Video-Input animation, writing one PNG per frame.

    Core loop: each diffused frame is warped forward (``anim_frame_warp``),
    optionally color-matched, contrast-scaled and noised, then fed back as
    the init sample for the next diffusion step. With diffusion cadence
    ("turbo") > 1, in-between frames are synthesized by warping and
    cross-fading the two surrounding diffused frames instead of diffusing.

    Mutates ``args`` and ``anim_args`` in place (seed, prompt, init/mask
    samples, timestring on resume, etc.).
    """
    # expose the keyframed prompts on args for downstream consumers
    args.prompts = animation_prompts

    # per-frame interpolated schedules (angle, zoom, noise, strength, ...)
    keys = DeformAnimKeys(anim_args)

    # When resuming, count frames already written for the resume timestring
    # to determine where to pick up (minus one for the last, possibly partial).
    start_frame = 0
    if anim_args.resume_from_timestring:
        for tmp in os.listdir(args.outdir):
            if tmp.split("_")[0] == anim_args.resume_timestring:
                start_frame += 1
        start_frame = start_frame - 1

    os.makedirs(args.outdir, exist_ok=True)
    print(f"Saving animation frames to {args.outdir}")

    # Settings export intentionally disabled in this fork (kept for reference).
    '''
    settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
    with open(settings_filename, "w+", encoding="utf-8") as f:
        s = {**dict(args.__dict__), **dict(anim_args.__dict__)}
        #DGSpitzer: run.py adds these three parameters
        del s['master_args']
        del s['opt']
        del s['root']
        del s['get_output_folder']
        #print(s)
        json.dump(s, f, ensure_ascii=False, indent=4)
    '''

    # resume must reuse the original run's timestring so filenames line up
    if anim_args.resume_from_timestring:
        args.timestring = anim_args.resume_timestring

    # Expand keyframed prompts to one prompt per frame (forward- then back-fill).
    prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)])
    for i, prompt in animation_prompts.items():
        prompt_series[int(i)] = prompt
    prompt_series = prompt_series.ffill().bfill()

    using_vid_init = anim_args.animation_mode == 'Video Input'

    # Load depth model(s) only when 3D warping or depth-map export needs them.
    predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps
    if predict_depths:
        depth_model = DepthModel(root.device)
        depth_model.load_midas(root.models_path)
        if anim_args.midas_weight < 1.0:
            # blended depth estimate: AdaBins contributes (1 - midas_weight)
            depth_model.load_adabins(root.models_path)
    else:
        depth_model = None
        anim_args.save_depth_maps = False

    # Diffusion cadence: diffuse every turbo_steps-th frame; in-betweens are
    # warped/blended ("turbo") frames. Video Input mode always diffuses every frame.
    turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence)
    turbo_prev_image, turbo_prev_frame_idx = None, 0
    turbo_next_image, turbo_next_frame_idx = None, 0

    prev_sample = None
    color_match_sample = None
    if anim_args.resume_from_timestring:
        # Reload the last fully-diffused frame from disk to seed the resumed run.
        last_frame = start_frame-1
        if turbo_steps > 1:
            # snap back to the most recent cadence-aligned (diffused) frame
            last_frame -= last_frame%turbo_steps
        path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png")
        img = cv2.imread(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        prev_sample = sample_from_cv2(img)
        if anim_args.color_coherence != 'None':
            color_match_sample = img
        if turbo_steps > 1:
            turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame
            turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
            start_frame = last_frame+turbo_steps

    args.n_samples = 1
    frame_idx = start_frame
    while frame_idx < anim_args.max_frames:
        print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}")
        noise = keys.noise_schedule_series[frame_idx]
        strength = keys.strength_schedule_series[frame_idx]
        contrast = keys.contrast_schedule_series[frame_idx]
        depth = None

        # Emit in-between frames by warping and cross-fading the two
        # surrounding diffused frames (turbo_prev_image / turbo_next_image).
        if turbo_steps > 1:
            tween_frame_start_idx = max(0, frame_idx-turbo_steps)
            for tween_frame_idx in range(tween_frame_start_idx, frame_idx):
                # blend weight: 0 -> previous diffused frame, 1 -> next one
                tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx)
                print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}")

                # only warp an image forward once per tween frame index
                advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx
                advance_next = tween_frame_idx > turbo_next_frame_idx

                if depth_model is not None:
                    assert(turbo_next_image is not None)
                    depth = depth_model.predict(turbo_next_image, anim_args)

                if advance_prev:
                    turbo_prev_image, _ = anim_frame_warp(turbo_prev_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device)
                if advance_next:
                    turbo_next_image, _ = anim_frame_warp(turbo_next_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device)

                # Keep the overlay's raw init sample aligned with the motion.
                # NOTE(review): warps are keyed on frame_idx here, not
                # tween_frame_idx — presumably intentional; verify upstream.
                if args.use_mask and args.overlay_mask:
                    init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
                    if root.half_precision:
                        args.init_sample_raw = sample_from_cv2(init_image_raw).half().to(root.device)
                    else:
                        args.init_sample_raw = sample_from_cv2(init_image_raw).to(root.device)

                # Warp the mask alongside the frames so it tracks the scene.
                if args.use_mask:
                    if args.mask_sample is None:
                        args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape)
                    mask_image, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
                    if root.half_precision:
                        args.mask_sample = sample_from_cv2(mask_image).half().to(root.device)
                    else:
                        args.mask_sample = sample_from_cv2(mask_image).to(root.device)

                turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx

                # linear cross-fade between the warped neighbors
                if turbo_prev_image is not None and tween < 1.0:
                    img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween
                else:
                    img = turbo_next_image

                filename = f"{args.timestring}_{tween_frame_idx:05}.png"
                # cv2 expects BGR on write; frames are held in RGB
                cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR))
                if anim_args.save_depth_maps:
                    depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth)
            if turbo_next_image is not None:
                prev_sample = sample_from_cv2(turbo_next_image)

        # Transform the previous frame into the init sample for this diffusion.
        if prev_sample is not None:
            prev_img, depth = anim_frame_warp(prev_sample, args, anim_args, keys, frame_idx, depth_model, depth=None, device=root.device)

            # keep the overlay's raw init sample in step with the warp
            if args.use_mask and args.overlay_mask:
                init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
                if root.half_precision:
                    args.init_sample_raw = sample_from_cv2(init_image_raw).half().to(root.device)
                else:
                    args.init_sample_raw = sample_from_cv2(init_image_raw).to(root.device)

            # warp the mask the same way
            if args.use_mask:
                if args.mask_sample is None:
                    args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape)
                mask_sample, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
                if root.half_precision:
                    args.mask_sample = sample_from_cv2(mask_sample).half().to(root.device)
                else:
                    args.mask_sample = sample_from_cv2(mask_sample).to(root.device)

            # Color coherence: lock the palette to the first frame's histogram.
            if anim_args.color_coherence != 'None':
                if color_match_sample is None:
                    color_match_sample = prev_img.copy()
                else:
                    prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)

            # apply scheduled contrast, then frame noise
            contrast_sample = prev_img * contrast
            noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)

            # use the transformed previous frame as init for the current one
            args.use_init = True
            if root.half_precision:
                args.init_sample = noised_sample.half().to(root.device)
            else:
                args.init_sample = noised_sample.to(root.device)
            args.strength = max(0.0, min(1.0, strength))

        # grab the prompt scheduled for this frame
        args.prompt = prompt_series[frame_idx]
        args.clip_prompt = args.prompt
        print(f"{args.prompt} {args.seed}")
        if not using_vid_init:
            print(f"Angle: {keys.angle_series[frame_idx]} Zoom: {keys.zoom_series[frame_idx]}")
            print(f"Tx: {keys.translation_x_series[frame_idx]} Ty: {keys.translation_y_series[frame_idx]} Tz: {keys.translation_z_series[frame_idx]}")
            print(f"Rx: {keys.rotation_3d_x_series[frame_idx]} Ry: {keys.rotation_3d_y_series[frame_idx]} Rz: {keys.rotation_3d_z_series[frame_idx]}")

        # Video Input mode: each frame's init (and optional mask) comes from
        # the pre-extracted frame files (1-based numbering on disk).
        if using_vid_init:
            init_frame = os.path.join(args.outdir, 'inputframes', f"{frame_idx+1:05}.jpg")
            print(f"Using video init frame {init_frame}")
            args.init_image = init_frame
            if anim_args.use_mask_video:
                mask_frame = os.path.join(args.outdir, 'maskframes', f"{frame_idx+1:05}.jpg")
                args.mask_file = mask_frame

        # run the diffusion step for this frame
        sample, image = generate(args, root, frame_idx, return_latent=False, return_sample=True)

        if not using_vid_init:
            prev_sample = sample

        if args.use_mask and args.overlay_mask:
            # first generated sample seeds the raw overlay reference
            if args.init_sample_raw is None:
                args.init_sample_raw = sample

        if turbo_steps > 1:
            # shift the turbo window forward; the PNG for this frame is
            # written later by the tween pass, not here
            turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
            turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx
            frame_idx += turbo_steps
        else:
            filename = f"{args.timestring}_{frame_idx:05}.png"
            image.save(os.path.join(args.outdir, filename))
            if anim_args.save_depth_maps:
                depth = depth_model.predict(sample_to_cv2(sample), anim_args)
                depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth)
            frame_idx += 1

        display.clear_output(wait=True)
        display.display(image)

        args.seed = next_seed(args)
|
|
|
def render_input_video(args, anim_args, animation_prompts, root):
    """Extract frames from an input video (and optional mask video), then
    run :func:`render_animation` over them in Video Input mode.

    Side effects: creates ``inputframes``/``maskframes`` under ``args.outdir``
    and mutates ``args``/``anim_args`` (max_frames, use_init, mask flags).
    """
    frames_dir = os.path.join(args.outdir, 'inputframes')
    os.makedirs(frames_dir, exist_ok=True)

    # dump every Nth frame of the source video to disk as JPEGs
    print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {frames_dir}...")
    vid2frames(anim_args.video_init_path, frames_dir, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames)

    # the animation length is however many frames were actually extracted
    anim_args.max_frames = sum(1 for _ in pathlib.Path(frames_dir).glob('*.jpg'))
    args.use_init = True
    print(f"Loading {anim_args.max_frames} input frames from {frames_dir} and saving video frames to {args.outdir}")

    if anim_args.use_mask_video:
        # same extraction for the per-frame mask video
        mask_dir = os.path.join(args.outdir, 'maskframes')
        os.makedirs(mask_dir, exist_ok=True)

        print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_dir}...")
        vid2frames(anim_args.video_mask_path, mask_dir, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames)
        args.use_mask = True
        args.overlay_mask = True

    render_animation(args, anim_args, animation_prompts, root)
|
|
|
def render_interpolation(args, anim_args, animation_prompts, root):
    """Render frames interpolating between prompt conditionings.

    First generates one image per prompt (collecting each conditioning
    tensor), then renders intermediate frames by linearly interpolating
    between consecutive conditionings — either spaced by the prompts' key
    frame indices or by a fixed ``interpolate_x_frames`` count.

    Mutates ``args`` in place (seed_behavior forced to 'fixed', init_c, ...).
    """
    # expose the keyframed prompts on args for downstream consumers
    args.prompts = animation_prompts

    os.makedirs(args.outdir, exist_ok=True)
    print(f"Saving animation frames to {args.outdir}")

    # Dump combined args/anim_args as JSON, minus entries injected by run.py.
    settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
    with open(settings_filename, "w+", encoding="utf-8") as f:
        s = {**dict(args.__dict__), **dict(anim_args.__dict__)}
        del s['master_args']
        del s['opt']
        del s['root']
        del s['get_output_folder']
        json.dump(s, f, ensure_ascii=False, indent=4)

    args.n_samples = 1
    # keep the seed constant so only the conditioning varies between frames
    args.seed_behavior = 'fixed'
    prompts_c_s = []  # conditioning tensors, one per keyframe prompt

    print(f"Preparing for interpolation of the following...")

    # Generate one image per prompt and collect its conditioning tensor.
    for i, prompt in animation_prompts.items():
        args.prompt = prompt
        args.clip_prompt = args.prompt

        results = generate(args, root, return_c=True)
        c, image = results[0], results[1]
        prompts_c_s.append(c)

        display.display(image)

        args.seed = next_seed(args)

    display.clear_output(wait=True)
    print(f"Interpolation start...")

    frame_idx = 0

    if anim_args.interpolate_key_frames:
        # spacing between frames comes from the prompts' key frame indices
        for i in range(len(prompts_c_s)-1):
            dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0]
            if dist_frames <= 0:
                # non-increasing key frames make interpolation ill-defined
                print("key frames duplicated or reversed. interpolation skipped.")
                return
            else:
                for j in range(dist_frames):
                    # linear blend: c1 + (c2 - c1) * j/dist_frames
                    prompt1_c = prompts_c_s[i]
                    prompt2_c = prompts_c_s[i+1]
                    args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames))

                    results = generate(args, root)
                    image = results[0]

                    filename = f"{args.timestring}_{frame_idx:05}.png"
                    image.save(os.path.join(args.outdir, filename))
                    frame_idx += 1

                    display.clear_output(wait=True)
                    display.display(image)

                    args.seed = next_seed(args)

    else:
        # fixed number of in-between frames for every prompt pair
        for i in range(len(prompts_c_s)-1):
            for j in range(anim_args.interpolate_x_frames+1):
                # linear blend: c1 + (c2 - c1) * j/(interpolate_x_frames+1)
                prompt1_c = prompts_c_s[i]
                prompt2_c = prompts_c_s[i+1]
                args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))

                results = generate(args, root)
                image = results[0]

                filename = f"{args.timestring}_{frame_idx:05}.png"
                image.save(os.path.join(args.outdir, filename))
                frame_idx += 1

                display.clear_output(wait=True)
                display.display(image)

                args.seed = next_seed(args)

    # render the final prompt's conditioning as the closing frame
    args.init_c = prompts_c_s[-1]
    results = generate(args, root)
    image = results[0]
    filename = f"{args.timestring}_{frame_idx:05}.png"
    image.save(os.path.join(args.outdir, filename))

    display.clear_output(wait=True)
    display.display(image)
    args.seed = next_seed(args)

    # clear the conditioning override so later runs are unaffected
    args.init_c = None
|
|