import os import json from IPython import display import random from torchvision.utils import make_grid from einops import rearrange import pandas as pd import cv2 import numpy as np from PIL import Image import pathlib import torchvision.transforms as T from .generate import generate, add_noise from .prompt import sanitize from .animation import DeformAnimKeys, sample_from_cv2, sample_to_cv2, anim_frame_warp, vid2frames from .depth import DepthModel from .colors import maintain_colors from .load_images import prepare_overlay_mask def next_seed(args): if args.seed_behavior == 'iter': args.seed += 1 elif args.seed_behavior == 'fixed': pass # always keep seed the same else: args.seed = random.randint(0, 2**32 - 1) return args.seed def render_image_batch(args, prompts, root): args.prompts = {k: f"{v:05d}" for v, k in enumerate(prompts)} # create output folder for the batch os.makedirs(args.outdir, exist_ok=True) if args.save_settings or args.save_samples: print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*") # save settings for the batch if args.save_settings: filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") with open(filename, "w+", encoding="utf-8") as f: dictlist = dict(args.__dict__) del dictlist['master_args'] del dictlist['root'] del dictlist['get_output_folder'] json.dump(dictlist, f, ensure_ascii=False, indent=4) index = 0 # function for init image batching init_array = [] if args.use_init: if args.init_image == "": raise FileNotFoundError("No path was given for init_image") if args.init_image.startswith('http://') or args.init_image.startswith('https://'): init_array.append(args.init_image) elif not os.path.isfile(args.init_image): if args.init_image[-1] != "/": # avoids path error by adding / to end if not there args.init_image += "/" for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array if image.split(".")[-1] in ("png", "jpg", "jpeg"): init_array.append(args.init_image + image) else: init_array.append(args.init_image) else: init_array = [""] # when doing large batches don't flood browser with images clear_between_batches = args.n_batch >= 32 for iprompt, prompt in enumerate(prompts): args.prompt = prompt args.clip_prompt = prompt print(f"Prompt {iprompt+1} of {len(prompts)}") print(f"{args.prompt}") all_images = [] for batch_index in range(args.n_batch): if clear_between_batches and batch_index % 32 == 0: display.clear_output(wait=True) print(f"Batch {batch_index+1} of {args.n_batch}") for image in init_array: # iterates the init images args.init_image = image results = generate(args, root) for image in results: if args.make_grid: all_images.append(T.functional.pil_to_tensor(image)) if args.save_samples: if args.filename_format == "{timestring}_{index}_{prompt}.png": filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png" else: filename = f"{args.timestring}_{index:05}_{args.seed}.png" image.save(os.path.join(args.outdir, filename)) if args.display_samples: display.display(image) index += 1 args.seed = next_seed(args) #print(len(all_images)) if args.make_grid: grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows)) grid = rearrange(grid, 'c h w -> h w c').cpu().numpy() filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png" grid_image = Image.fromarray(grid.astype(np.uint8)) grid_image.save(os.path.join(args.outdir, filename)) display.clear_output(wait=True) display.display(grid_image) def render_animation(args, anim_args, animation_prompts, root): # animations use key framed prompts args.prompts = animation_prompts # expand key frame strings to values keys = DeformAnimKeys(anim_args) # resume animation start_frame = 0 if anim_args.resume_from_timestring: for tmp in os.listdir(args.outdir): if tmp.split("_")[0] == anim_args.resume_timestring: start_frame += 1 start_frame = start_frame - 1 # create output folder for the batch os.makedirs(args.outdir, exist_ok=True) print(f"Saving animation frames to {args.outdir}") # save settings for the batch ''' settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") with open(settings_filename, "w+", encoding="utf-8") as f: s = {**dict(args.__dict__), **dict(anim_args.__dict__)} #DGSpitzer: run.py adds these three parameters del s['master_args'] del s['opt'] del s['root'] del s['get_output_folder'] #print(s) json.dump(s, f, ensure_ascii=False, indent=4) ''' # resume from timestring if anim_args.resume_from_timestring: args.timestring = anim_args.resume_timestring # expand prompts out to per-frame prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)]) for i, prompt in animation_prompts.items(): prompt_series[int(i)] = prompt prompt_series = prompt_series.ffill().bfill() # check for video inits using_vid_init = anim_args.animation_mode == 'Video Input' # load depth model for 3D predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps if predict_depths: depth_model = DepthModel(root.device) depth_model.load_midas(root.models_path) if anim_args.midas_weight < 1.0: depth_model.load_adabins(root.models_path) else: depth_model = None anim_args.save_depth_maps = False # state for interpolating between diffusion steps turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence) turbo_prev_image, turbo_prev_frame_idx = None, 0 turbo_next_image, turbo_next_frame_idx = None, 0 # resume animation prev_sample = None color_match_sample = None if anim_args.resume_from_timestring: last_frame = start_frame-1 if turbo_steps > 1: last_frame -= last_frame%turbo_steps path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png") img = cv2.imread(path) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) prev_sample = sample_from_cv2(img) if anim_args.color_coherence != 'None': color_match_sample = img if turbo_steps > 1: turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx start_frame = last_frame+turbo_steps args.n_samples = 1 frame_idx = start_frame while frame_idx < anim_args.max_frames: print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}") noise = keys.noise_schedule_series[frame_idx] strength = keys.strength_schedule_series[frame_idx] contrast = keys.contrast_schedule_series[frame_idx] depth = None # emit in-between frames if turbo_steps > 1: tween_frame_start_idx = max(0, frame_idx-turbo_steps) for tween_frame_idx in range(tween_frame_start_idx, frame_idx): tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx) print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}") advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx advance_next = tween_frame_idx > turbo_next_frame_idx if depth_model is not None: assert(turbo_next_image is not None) depth = depth_model.predict(turbo_next_image, anim_args) if advance_prev: turbo_prev_image, _ = anim_frame_warp(turbo_prev_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device) if advance_next: turbo_next_image, _ = anim_frame_warp(turbo_next_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device) # Transformed raw image before color coherence and noise. Used for mask overlay if args.use_mask and args.overlay_mask: # Apply transforms to the original image init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device) if root.half_precision: args.init_sample_raw = sample_from_cv2(init_image_raw).half().to(root.device) else: args.init_sample_raw = sample_from_cv2(init_image_raw).to(root.device) #Transform the mask image if args.use_mask: if args.mask_sample is None: args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape) # Transform the mask mask_image, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device) if root.half_precision: args.mask_sample = sample_from_cv2(mask_image).half().to(root.device) else: args.mask_sample = sample_from_cv2(mask_image).to(root.device) turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx if turbo_prev_image is not None and tween < 1.0: img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween else: img = turbo_next_image filename = f"{args.timestring}_{tween_frame_idx:05}.png" cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR)) if anim_args.save_depth_maps: depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth) if turbo_next_image is not None: prev_sample = sample_from_cv2(turbo_next_image) # apply transforms to previous frame if prev_sample is not None: prev_img, depth = anim_frame_warp(prev_sample, args, anim_args, keys, frame_idx, depth_model, depth=None, device=root.device) # Transformed raw image before color coherence and noise. Used for mask overlay if args.use_mask and args.overlay_mask: # Apply transforms to the original image init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device) if root.half_precision: args.init_sample_raw = sample_from_cv2(init_image_raw).half().to(root.device) else: args.init_sample_raw = sample_from_cv2(init_image_raw).to(root.device) #Transform the mask image if args.use_mask: if args.mask_sample is None: args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape) # Transform the mask mask_sample, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device) if root.half_precision: args.mask_sample = sample_from_cv2(mask_sample).half().to(root.device) else: args.mask_sample = sample_from_cv2(mask_sample).to(root.device) # apply color matching if anim_args.color_coherence != 'None': if color_match_sample is None: color_match_sample = prev_img.copy() else: prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence) # apply scaling contrast_sample = prev_img * contrast # apply frame noising noised_sample = add_noise(sample_from_cv2(contrast_sample), noise) # use transformed previous frame as init for current args.use_init = True if root.half_precision: args.init_sample = noised_sample.half().to(root.device) else: args.init_sample = noised_sample.to(root.device) args.strength = max(0.0, min(1.0, strength)) # grab prompt for current frame args.prompt = prompt_series[frame_idx] args.clip_prompt = args.prompt print(f"{args.prompt} {args.seed}") if not using_vid_init: print(f"Angle: {keys.angle_series[frame_idx]} Zoom: {keys.zoom_series[frame_idx]}") print(f"Tx: {keys.translation_x_series[frame_idx]} Ty: {keys.translation_y_series[frame_idx]} Tz: {keys.translation_z_series[frame_idx]}") print(f"Rx: {keys.rotation_3d_x_series[frame_idx]} Ry: {keys.rotation_3d_y_series[frame_idx]} Rz: {keys.rotation_3d_z_series[frame_idx]}") # grab init image for current frame if using_vid_init: init_frame = os.path.join(args.outdir, 'inputframes', f"{frame_idx+1:05}.jpg") print(f"Using video init frame {init_frame}") args.init_image = init_frame if anim_args.use_mask_video: mask_frame = os.path.join(args.outdir, 'maskframes', f"{frame_idx+1:05}.jpg") args.mask_file = mask_frame # sample the diffusion model sample, image = generate(args, root, frame_idx, return_latent=False, return_sample=True) # First image sample used for masking if not using_vid_init: prev_sample = sample if args.use_mask and args.overlay_mask: if args.init_sample_raw is None: args.init_sample_raw = sample if turbo_steps > 1: turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx frame_idx += turbo_steps else: filename = f"{args.timestring}_{frame_idx:05}.png" image.save(os.path.join(args.outdir, filename)) if anim_args.save_depth_maps: depth = depth_model.predict(sample_to_cv2(sample), anim_args) depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth) frame_idx += 1 display.clear_output(wait=True) display.display(image) args.seed = next_seed(args) def render_input_video(args, anim_args, animation_prompts, root): # create a folder for the video input frames to live in video_in_frame_path = os.path.join(args.outdir, 'inputframes') os.makedirs(video_in_frame_path, exist_ok=True) # save the video frames from input video print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...") vid2frames(anim_args.video_init_path, video_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames) # determine max frames from length of input frames anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')]) args.use_init = True print(f"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}") if anim_args.use_mask_video: # create a folder for the mask video input frames to live in mask_in_frame_path = os.path.join(args.outdir, 'maskframes') os.makedirs(mask_in_frame_path, exist_ok=True) # save the video frames from mask video print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...") vid2frames(anim_args.video_mask_path, mask_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames) args.use_mask = True args.overlay_mask = True render_animation(args, anim_args, animation_prompts, root) def render_interpolation(args, anim_args, animation_prompts, root): # animations use key framed prompts args.prompts = animation_prompts # create output folder for the batch os.makedirs(args.outdir, exist_ok=True) print(f"Saving animation frames to {args.outdir}") # save settings for the batch settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") with open(settings_filename, "w+", encoding="utf-8") as f: s = {**dict(args.__dict__), **dict(anim_args.__dict__)} del s['master_args'] del s['opt'] del s['root'] del s['get_output_folder'] json.dump(s, f, ensure_ascii=False, indent=4) # Interpolation Settings args.n_samples = 1 args.seed_behavior = 'fixed' # force fix seed at the moment bc only 1 seed is available prompts_c_s = [] # cache all the text embeddings print(f"Preparing for interpolation of the following...") for i, prompt in animation_prompts.items(): args.prompt = prompt args.clip_prompt = args.prompt # sample the diffusion model results = generate(args, root, return_c=True) c, image = results[0], results[1] prompts_c_s.append(c) # display.clear_output(wait=True) display.display(image) args.seed = next_seed(args) display.clear_output(wait=True) print(f"Interpolation start...") frame_idx = 0 if anim_args.interpolate_key_frames: for i in range(len(prompts_c_s)-1): dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0] if dist_frames <= 0: print("key frames duplicated or reversed. interpolation skipped.") return else: for j in range(dist_frames): # interpolate the text embedding prompt1_c = prompts_c_s[i] prompt2_c = prompts_c_s[i+1] args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames)) # sample the diffusion model results = generate(args, root) image = results[0] filename = f"{args.timestring}_{frame_idx:05}.png" image.save(os.path.join(args.outdir, filename)) frame_idx += 1 display.clear_output(wait=True) display.display(image) args.seed = next_seed(args) else: for i in range(len(prompts_c_s)-1): for j in range(anim_args.interpolate_x_frames+1): # interpolate the text embedding prompt1_c = prompts_c_s[i] prompt2_c = prompts_c_s[i+1] args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1))) # sample the diffusion model results = generate(args, root) image = results[0] filename = f"{args.timestring}_{frame_idx:05}.png" image.save(os.path.join(args.outdir, filename)) frame_idx += 1 display.clear_output(wait=True) display.display(image) args.seed = next_seed(args) # generate the last prompt args.init_c = prompts_c_s[-1] results = generate(args, root) image = results[0] filename = f"{args.timestring}_{frame_idx:05}.png" image.save(os.path.join(args.outdir, filename)) display.clear_output(wait=True) display.display(image) args.seed = next_seed(args) #clear init_c args.init_c = None