# Web file-viewer header captured together with this file (not Python code):
#   feizhengcong's picture
#   Upload 198 files
#   074c857
#   raw
#   history blame
#   20.8 kB
import os
import json
from IPython import display
import random
from torchvision.utils import make_grid
from einops import rearrange
import pandas as pd
import cv2
import numpy as np
from PIL import Image
import pathlib
import torchvision.transforms as T
from .generate import generate, add_noise
from .prompt import sanitize
from .animation import DeformAnimKeys, sample_from_cv2, sample_to_cv2, anim_frame_warp, vid2frames
from .depth import DepthModel
from .colors import maintain_colors
from .load_images import prepare_overlay_mask
def next_seed(args):
    """Advance ``args.seed`` according to ``args.seed_behavior`` and return it.

    * ``'iter'``  -- the seed is incremented by one.
    * ``'fixed'`` -- the seed is left untouched.
    * any other value -- a new seed is drawn uniformly from ``[0, 2**32 - 1]``.

    The new value is stored on ``args.seed`` (except for ``'fixed'``, where
    nothing is written) and also returned.
    """
    behavior = args.seed_behavior

    # a fixed seed never changes, so there is nothing to update
    if behavior == 'fixed':
        return args.seed

    if behavior == 'iter':
        args.seed += 1
        return args.seed

    # every other behaviour draws a fresh random 32-bit seed
    args.seed = random.randint(0, 2**32 - 1)
    return args.seed
def render_image_batch(args, prompts, root):
    """Generate, save and display a batch of still images for each prompt.

    For every prompt, ``args.n_batch`` batches are rendered; inside each batch
    ``generate(args, root)`` is called once per init image (or once with an
    empty init image when ``args.use_init`` is false).  Depending on the flags
    on ``args`` each resulting image is saved to ``args.outdir``, displayed,
    and/or collected into a per-prompt grid image.

    Args:
        args: run settings.  Mutated in place: ``prompts``, ``prompt``,
            ``clip_prompt``, ``init_image`` and ``seed`` are overwritten.
        prompts: sequence of prompt strings.
        root: passed through unchanged to ``generate``.

    Raises:
        FileNotFoundError: ``args.use_init`` is set but ``args.init_image`` is
            empty.
    """
    args.prompts = {k: f"{v:05d}" for v, k in enumerate(prompts)}

    # create output folder for the batch
    os.makedirs(args.outdir, exist_ok=True)
    if args.save_settings or args.save_samples:
        print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*")

    # save settings for the batch
    if args.save_settings:
        filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
        dictlist = dict(args.__dict__)
        # these entries are dropped before dumping; pop() tolerates callers
        # that never set them (del would raise KeyError)
        for key in ('master_args', 'root', 'get_output_folder'):
            dictlist.pop(key, None)
        with open(filename, "w+", encoding="utf-8") as f:
            json.dump(dictlist, f, ensure_ascii=False, indent=4)

    index = 0

    # collect the init images to iterate over
    if args.use_init:
        if args.init_image == "":
            raise FileNotFoundError("No path was given for init_image")
        if args.init_image.startswith(('http://', 'https://')):
            init_array = [args.init_image]
        elif not os.path.isfile(args.init_image):
            # not a file: treat the path as a directory of init images
            init_dir = args.init_image
            image_exts = (".png", ".jpg", ".jpeg")
            init_array = [
                os.path.join(init_dir, name)
                for name in sorted(os.listdir(init_dir))
                if os.path.splitext(name)[1].lower() in image_exts
            ]
            if not init_array:
                print(f"No png/jpg/jpeg images found in {init_dir}")
        else:
            init_array = [args.init_image]
    else:
        init_array = [""]

    # when doing large batches don't flood browser with images
    clear_between_batches = args.n_batch >= 32

    for iprompt, prompt in enumerate(prompts):
        args.prompt = prompt
        args.clip_prompt = prompt
        print(f"Prompt {iprompt+1} of {len(prompts)}")
        print(f"{args.prompt}")

        all_images = []

        for batch_index in range(args.n_batch):
            if clear_between_batches and batch_index % 32 == 0:
                display.clear_output(wait=True)
            print(f"Batch {batch_index+1} of {args.n_batch}")

            for init_image in init_array:  # iterates the init images
                args.init_image = init_image
                results = generate(args, root)
                for image in results:
                    if args.make_grid:
                        all_images.append(T.functional.pil_to_tensor(image))
                    if args.save_samples:
                        if args.filename_format == "{timestring}_{index}_{prompt}.png":
                            filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png"
                        else:
                            filename = f"{args.timestring}_{index:05}_{args.seed}.png"
                        image.save(os.path.join(args.outdir, filename))
                    if args.display_samples:
                        display.display(image)
                    index += 1
                # the seed advances once per generate() call
                args.seed = next_seed(args)

        if args.make_grid:
            if not all_images:
                # nothing was generated for this prompt, so there is no grid to build
                print("No images were generated; grid skipped.")
                continue
            # nrow is the number of images per grid row; never let it drop to 0
            # when there are fewer images than grid_rows
            images_per_row = max(1, int(len(all_images) / args.grid_rows))
            grid = make_grid(all_images, nrow=images_per_row)
            grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()
            filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png"
            grid_image = Image.fromarray(grid.astype(np.uint8))
            grid_image.save(os.path.join(args.outdir, filename))
            display.clear_output(wait=True)
            display.display(grid_image)
def render_animation(args, anim_args, animation_prompts, root):
    """Render an animation frame by frame and write the frames to ``args.outdir``.

    Each iteration of the main loop warps the previous sample with
    ``anim_frame_warp``, optionally colour-matches it, scales it by the
    contrast schedule, adds noise, and uses the result as the init sample for
    ``generate``.  When ``anim_args.diffusion_cadence`` is greater than 1 (and
    the mode is not 'Video Input') only every ``turbo_steps``-th frame is
    diffused and the frames in between are blended from the two surrounding
    diffused frames.

    Args:
        args: run settings; mutated in place (``prompts``, ``timestring``,
            ``n_samples``, ``use_init``, ``init_sample``, ``strength``,
            ``prompt``, ``clip_prompt``, ``init_image``, ``mask_file``,
            ``init_sample_raw``, ``mask_sample`` and ``seed`` may be written).
        anim_args: animation settings; ``save_depth_maps`` is forced to False
            when no depth model is loaded.
        animation_prompts: mapping of key-frame index to prompt; keys are
            converted with ``int()``.
        root: provides ``device``, ``models_path`` and ``half_precision`` and
            is passed on to ``generate``.
    """
    # animations use key framed prompts
    args.prompts = animation_prompts
    # expand key frame strings to values
    keys = DeformAnimKeys(anim_args)
    # resume animation
    start_frame = 0
    if anim_args.resume_from_timestring:
        # count the entries in outdir whose name, up to the first underscore,
        # equals the timestring being resumed
        for tmp in os.listdir(args.outdir):
            if tmp.split("_")[0] == anim_args.resume_timestring:
                start_frame += 1
        # presumably discounts one non-frame file (e.g. a settings file) -- TODO confirm
        start_frame = start_frame - 1
    # create output folder for the batch
    os.makedirs(args.outdir, exist_ok=True)
    print(f"Saving animation frames to {args.outdir}")
    # save settings for the batch
    # NOTE: the settings dump below is disabled; it is wrapped in a bare string
    # literal, so it is never executed.
    '''
settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
with open(settings_filename, "w+", encoding="utf-8") as f:
s = {**dict(args.__dict__), **dict(anim_args.__dict__)}
#DGSpitzer: run.py adds these three parameters
del s['master_args']
del s['opt']
del s['root']
del s['get_output_folder']
#print(s)
json.dump(s, f, ensure_ascii=False, indent=4)
'''
    # resume from timestring
    if anim_args.resume_from_timestring:
        args.timestring = anim_args.resume_timestring
    # expand prompts out to per-frame
    prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)])
    for i, prompt in animation_prompts.items():
        prompt_series[int(i)] = prompt
    # every frame takes the prompt of the closest earlier key frame; frames
    # before the first key frame take the first prompt
    prompt_series = prompt_series.ffill().bfill()
    # check for video inits
    using_vid_init = anim_args.animation_mode == 'Video Input'
    # load depth model for 3D
    predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps
    if predict_depths:
        depth_model = DepthModel(root.device)
        depth_model.load_midas(root.models_path)
        if anim_args.midas_weight < 1.0:
            depth_model.load_adabins(root.models_path)
    else:
        depth_model = None
        anim_args.save_depth_maps = False
    # state for interpolating between diffusion steps
    # (video input always diffuses every frame, i.e. turbo_steps == 1)
    turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence)
    turbo_prev_image, turbo_prev_frame_idx = None, 0
    turbo_next_image, turbo_next_frame_idx = None, 0
    # resume animation
    prev_sample = None
    color_match_sample = None
    if anim_args.resume_from_timestring:
        last_frame = start_frame-1
        if turbo_steps > 1:
            # round down to a multiple of turbo_steps
            last_frame -= last_frame%turbo_steps
        path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png")
        # NOTE(review): the result of cv2.imread is not checked before use
        img = cv2.imread(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        prev_sample = sample_from_cv2(img)
        if anim_args.color_coherence != 'None':
            color_match_sample = img
        if turbo_steps > 1:
            turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame
            turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
            start_frame = last_frame+turbo_steps
    args.n_samples = 1
    frame_idx = start_frame
    while frame_idx < anim_args.max_frames:
        print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}")
        noise = keys.noise_schedule_series[frame_idx]
        strength = keys.strength_schedule_series[frame_idx]
        contrast = keys.contrast_schedule_series[frame_idx]
        depth = None
        # emit in-between frames
        if turbo_steps > 1:
            tween_frame_start_idx = max(0, frame_idx-turbo_steps)
            for tween_frame_idx in range(tween_frame_start_idx, frame_idx):
                # blend weight; reaches 1.0 on the last in-between frame
                tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx)
                print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}")
                # only warp an image forward if it has not yet been advanced to this frame
                advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx
                advance_next = tween_frame_idx > turbo_next_frame_idx
                if depth_model is not None:
                    assert(turbo_next_image is not None)
                    depth = depth_model.predict(turbo_next_image, anim_args)
                if advance_prev:
                    turbo_prev_image, _ = anim_frame_warp(turbo_prev_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device)
                if advance_next:
                    turbo_next_image, _ = anim_frame_warp(turbo_next_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device)
                # Transformed raw image before color coherence and noise. Used for mask overlay
                if args.use_mask and args.overlay_mask:
                    # Apply transforms to the original image
                    # NOTE(review): this warp is keyed on frame_idx while the turbo images
                    # above use tween_frame_idx -- confirm this is intended
                    init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
                    if root.half_precision:
                        args.init_sample_raw = sample_from_cv2(init_image_raw).half().to(root.device)
                    else:
                        args.init_sample_raw = sample_from_cv2(init_image_raw).to(root.device)
                #Transform the mask image
                if args.use_mask:
                    if args.mask_sample is None:
                        args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape)
                    # Transform the mask
                    mask_image, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
                    if root.half_precision:
                        args.mask_sample = sample_from_cv2(mask_image).half().to(root.device)
                    else:
                        args.mask_sample = sample_from_cv2(mask_image).to(root.device)
                turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx
                # linear blend between the two surrounding diffused frames
                if turbo_prev_image is not None and tween < 1.0:
                    img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween
                else:
                    img = turbo_next_image
                filename = f"{args.timestring}_{tween_frame_idx:05}.png"
                cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR))
                if anim_args.save_depth_maps:
                    depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth)
            # continue from the (warped) most recent diffused frame
            if turbo_next_image is not None:
                prev_sample = sample_from_cv2(turbo_next_image)
        # apply transforms to previous frame
        if prev_sample is not None:
            prev_img, depth = anim_frame_warp(prev_sample, args, anim_args, keys, frame_idx, depth_model, depth=None, device=root.device)
            # Transformed raw image before color coherence and noise. Used for mask overlay
            if args.use_mask and args.overlay_mask:
                # Apply transforms to the original image
                init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
                if root.half_precision:
                    args.init_sample_raw = sample_from_cv2(init_image_raw).half().to(root.device)
                else:
                    args.init_sample_raw = sample_from_cv2(init_image_raw).to(root.device)
            #Transform the mask image
            if args.use_mask:
                if args.mask_sample is None:
                    args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape)
                # Transform the mask
                mask_sample, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
                if root.half_precision:
                    args.mask_sample = sample_from_cv2(mask_sample).half().to(root.device)
                else:
                    args.mask_sample = sample_from_cv2(mask_sample).to(root.device)
            # apply color matching
            if anim_args.color_coherence != 'None':
                if color_match_sample is None:
                    # the first warped frame becomes the colour reference
                    color_match_sample = prev_img.copy()
                else:
                    prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)
            # apply scaling
            contrast_sample = prev_img * contrast
            # apply frame noising
            noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)
            # use transformed previous frame as init for current
            args.use_init = True
            if root.half_precision:
                args.init_sample = noised_sample.half().to(root.device)
            else:
                args.init_sample = noised_sample.to(root.device)
            # clamp the scheduled strength into [0.0, 1.0]
            args.strength = max(0.0, min(1.0, strength))
        # grab prompt for current frame
        args.prompt = prompt_series[frame_idx]
        args.clip_prompt = args.prompt
        print(f"{args.prompt} {args.seed}")
        if not using_vid_init:
            print(f"Angle: {keys.angle_series[frame_idx]} Zoom: {keys.zoom_series[frame_idx]}")
            print(f"Tx: {keys.translation_x_series[frame_idx]} Ty: {keys.translation_y_series[frame_idx]} Tz: {keys.translation_z_series[frame_idx]}")
            print(f"Rx: {keys.rotation_3d_x_series[frame_idx]} Ry: {keys.rotation_3d_y_series[frame_idx]} Rz: {keys.rotation_3d_z_series[frame_idx]}")
        # grab init image for current frame
        if using_vid_init:
            # input frame files are addressed 1-based (frame_idx+1)
            init_frame = os.path.join(args.outdir, 'inputframes', f"{frame_idx+1:05}.jpg")
            print(f"Using video init frame {init_frame}")
            args.init_image = init_frame
            if anim_args.use_mask_video:
                mask_frame = os.path.join(args.outdir, 'maskframes', f"{frame_idx+1:05}.jpg")
                args.mask_file = mask_frame
        # sample the diffusion model
        sample, image = generate(args, root, frame_idx, return_latent=False, return_sample=True)
        # First image sample used for masking
        if not using_vid_init:
            prev_sample = sample
            if args.use_mask and args.overlay_mask:
                if args.init_sample_raw is None:
                    args.init_sample_raw = sample
        if turbo_steps > 1:
            # the diffused frame is not saved here; frame files are written by
            # the in-between loop above
            turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
            turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx
            frame_idx += turbo_steps
        else:
            filename = f"{args.timestring}_{frame_idx:05}.png"
            image.save(os.path.join(args.outdir, filename))
            if anim_args.save_depth_maps:
                depth = depth_model.predict(sample_to_cv2(sample), anim_args)
                depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth)
            frame_idx += 1
        display.clear_output(wait=True)
        display.display(image)
        args.seed = next_seed(args)
def render_input_video(args, anim_args, animation_prompts, root):
    """Extract frames from the input (and optional mask) video, then animate.

    Frames of ``anim_args.video_init_path`` are exported with ``vid2frames``
    into ``<outdir>/inputframes``; ``anim_args.max_frames`` is set to the
    number of ``*.jpg`` files found there and ``args.use_init`` is switched
    on.  When ``anim_args.use_mask_video`` is set, the mask video is exported
    to ``<outdir>/maskframes`` and ``args.use_mask`` / ``args.overlay_mask``
    are enabled.  Rendering itself is delegated to ``render_animation``.
    """
    # the input video frames live in their own sub-folder of the output dir
    frames_dir = os.path.join(args.outdir, 'inputframes')
    os.makedirs(frames_dir, exist_ok=True)

    # dump the frames of the input video to disk
    print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {frames_dir}...")
    vid2frames(
        anim_args.video_init_path,
        frames_dir,
        anim_args.extract_nth_frame,
        anim_args.overwrite_extracted_frames,
    )

    # the animation is as long as the number of extracted frames
    extracted_frames = list(pathlib.Path(frames_dir).glob('*.jpg'))
    anim_args.max_frames = len(extracted_frames)
    args.use_init = True
    print(f"Loading {anim_args.max_frames} input frames from {frames_dir} and saving video frames to {args.outdir}")

    if anim_args.use_mask_video:
        # the mask video frames get a sub-folder of their own
        mask_dir = os.path.join(args.outdir, 'maskframes')
        os.makedirs(mask_dir, exist_ok=True)

        # dump the frames of the mask video to disk
        print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_dir}...")
        vid2frames(
            anim_args.video_mask_path,
            mask_dir,
            anim_args.extract_nth_frame,
            anim_args.overwrite_extracted_frames,
        )
        args.use_mask = True
        args.overlay_mask = True

    render_animation(args, anim_args, animation_prompts, root)
def render_interpolation(args, anim_args, animation_prompts, root):
    """Render frames that interpolate between the embeddings of key-frame prompts.

    Every prompt in ``animation_prompts`` is sampled once with
    ``generate(..., return_c=True)`` to obtain its conditioning ``c``.  For
    each pair of consecutive prompts the conditioning is then linearly
    interpolated, assigned to ``args.init_c`` and rendered.  The number of
    steps per pair is the key-frame distance when
    ``anim_args.interpolate_key_frames`` is set, otherwise
    ``anim_args.interpolate_x_frames + 1``.  A final frame is rendered from
    the last prompt's conditioning.  Frames are saved to ``args.outdir`` as
    ``<timestring>_<index>.png``.

    Args:
        args: run settings; mutated in place (``prompts``, ``n_samples``,
            ``seed_behavior`` (forced to ``'fixed'``), ``prompt``,
            ``clip_prompt``, ``init_c`` and ``seed``).  ``init_c`` is always
            reset to ``None`` once interpolation has started, including on
            the early-return path and on errors.
        anim_args: animation settings.
        animation_prompts: mapping of key-frame index to prompt, in key-frame
            order.  Keys are converted with ``int()``.
        root: passed through unchanged to ``generate``.
    """
    # animations use key framed prompts
    args.prompts = animation_prompts

    # create output folder for the batch
    os.makedirs(args.outdir, exist_ok=True)
    print(f"Saving animation frames to {args.outdir}")

    # save settings for the batch
    settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
    s = {**dict(args.__dict__), **dict(anim_args.__dict__)}
    # these entries are dropped before dumping; pop() tolerates callers that
    # never set them (del would raise KeyError)
    for key in ('master_args', 'opt', 'root', 'get_output_folder'):
        s.pop(key, None)
    with open(settings_filename, "w+", encoding="utf-8") as f:
        json.dump(s, f, ensure_ascii=False, indent=4)

    # Interpolation Settings
    args.n_samples = 1
    args.seed_behavior = 'fixed'  # force fix seed at the moment bc only 1 seed is available

    prompts_c_s = []  # cache all the text embeddings
    print("Preparing for interpolation of the following...")
    for prompt in animation_prompts.values():
        args.prompt = prompt
        args.clip_prompt = args.prompt

        # sample the diffusion model
        results = generate(args, root, return_c=True)
        c, image = results[0], results[1]
        prompts_c_s.append(c)

        display.display(image)
        args.seed = next_seed(args)

    display.clear_output(wait=True)
    print("Interpolation start...")

    if not prompts_c_s:
        # without any prompt there is nothing to interpolate or render
        print("No animation prompts given. interpolation skipped.")
        return

    # key-frame indices, computed once; int() also accepts string keys
    key_frames = [int(k) for k in animation_prompts]
    use_key_frames = anim_args.interpolate_key_frames

    frame_idx = 0
    try:
        for i, (prompt1_c, prompt2_c) in enumerate(zip(prompts_c_s, prompts_c_s[1:])):
            if use_key_frames:
                dist_frames = key_frames[i + 1] - key_frames[i]
                if dist_frames <= 0:
                    print("key frames duplicated or reversed. interpolation skipped.")
                    return
            else:
                dist_frames = anim_args.interpolate_x_frames + 1

            for j in range(dist_frames):
                # interpolate the text embedding
                args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j / dist_frames))
                _render_interpolation_frame(args, root, frame_idx)
                frame_idx += 1

        # generate the last prompt
        args.init_c = prompts_c_s[-1]
        _render_interpolation_frame(args, root, frame_idx)
    finally:
        # clear init_c so it cannot leak into later generate() calls
        args.init_c = None


def _render_interpolation_frame(args, root, frame_idx):
    """Sample one frame from ``args.init_c``, save and display it, then advance the seed."""
    # sample the diffusion model
    results = generate(args, root)
    image = results[0]
    filename = f"{args.timestring}_{frame_idx:05}.png"
    image.save(os.path.join(args.outdir, filename))

    display.clear_output(wait=True)
    display.display(image)
    args.seed = next_seed(args)