import datetime import os import sys import uuid import warnings import cv2 import gradio as gr import numpy as np import spaces import torch import torchvision from huggingface_hub import snapshot_download from PIL import Image from scipy.interpolate import PchipInterpolator sys.path.insert(0, os.getcwd()) from gradio_demo.utils_drag import * from models_diffusers.controlnet_svd import ControlNetSVDModel from models_diffusers.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel from pipelines.pipeline_stable_video_diffusion_interp_control import StableVideoDiffusionInterpControlPipeline print("gr file", gr.__file__) os.makedirs("checkpoints", exist_ok=True) snapshot_download( "wwen1997/framer_512x320", local_dir="checkpoints/framer_512x320", ) snapshot_download( "stabilityai/stable-video-diffusion-img2vid-xt", local_dir="checkpoints/stable-video-diffusion-img2vid-xt", ) model_id = "checkpoints/framer_512x320" device = "cuda" dtype = torch.float16 OUTPUT_DIR = "gradio_demo/outputs" HEIGHT = 320 WIDTH = 512 MODEL_LENGTH = 14 USE_SIFT = False unet = UNetSpatioTemporalConditionModel.from_pretrained( os.path.join(model_id, "unet"), torch_dtype=torch.float16, low_cpu_mem_usage=True, custom_resume=True, ) unet = unet.to(device, dtype) controlnet = ControlNetSVDModel.from_pretrained( os.path.join(model_id, "controlnet"), ) controlnet = controlnet.to(device, dtype) pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained( "checkpoints/stable-video-diffusion-img2vid-xt", unet=unet, controlnet=controlnet, low_cpu_mem_usage=False, torch_dtype=torch.float16, variant="fp16", local_files_only=True, ) pipe.to(device) def interpolate_trajectory(points, n_points): x = [point[0] for point in points] y = [point[1] for point in points] t = np.linspace(0, 1, len(points)) # fx = interp1d(t, x, kind='cubic') # fy = interp1d(t, y, kind='cubic') fx = PchipInterpolator(t, x) fy = PchipInterpolator(t, y) new_t = np.linspace(0, 1, n_points) new_x = fx(new_t) new_y = fy(new_t) new_points = list(zip(new_x, new_y)) return new_points def gen_gaussian_heatmap(imgSize=200): circle_img = np.zeros((imgSize, imgSize), np.float32) circle_mask = cv2.circle(circle_img, (imgSize // 2, imgSize // 2), imgSize // 2, 1, -1) isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32) for i in range(imgSize): for j in range(imgSize): isotropicGrayscaleImage[i, j] = ( 1 / 2 / np.pi / (40**2) * np.exp(-1 / 2 * ((i - imgSize / 2) ** 2 / (40**2) + (j - imgSize / 2) ** 2 / (40**2))) ) isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32) isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage) * 255).astype(np.uint8) return isotropicGrayscaleImage def get_vis_image( target_size=(512, 512), points=None, side=20, num_frames=14, # original_size=(512 , 512), args="", first_frame=None, is_mask = False, model_id=None, ): # images = [] vis_images = [] heatmap = gen_gaussian_heatmap() trajectory_list = [] radius_list = [] for index, point in enumerate(points): trajectories = [[int(i[0]), int(i[1])] for i in point] trajectory_list.append(trajectories) radius = 20 radius_list.append(radius) if len(trajectory_list) == 0: vis_images = [Image.fromarray(np.zeros(target_size, np.uint8)) for _ in range(num_frames)] return vis_images for idxx, point in enumerate(trajectory_list[0]): new_img = np.zeros(target_size, np.uint8) vis_img = new_img.copy() # ids_embedding = torch.zeros((target_size[0], target_size[1], 320)) if idxx >= num_frames: break # for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)): for cc, (trajectory, radius) in enumerate(zip(trajectory_list, radius_list)): center_coordinate = trajectory[idxx] trajectory_ = trajectory[:idxx] side = min(radius, 50) y1 = max(center_coordinate[1] - side, 0) y2 = min(center_coordinate[1] + side, target_size[0] - 1) x1 = max(center_coordinate[0] - side, 0) x2 = min(center_coordinate[0] + side, target_size[1] - 1) if x2 - x1 > 3 and y2 - y1 > 3: need_map = cv2.resize(heatmap, (x2 - x1, y2 - y1)) new_img[y1:y2, x1:x2] = need_map.copy() if cc >= 0: vis_img[y1:y2, x1:x2] = need_map.copy() if len(trajectory_) == 1: vis_img[trajectory_[0][1], trajectory_[0][0]] = 255 else: for itt in range(len(trajectory_) - 1): cv2.line( vis_img, (trajectory_[itt][0], trajectory_[itt][1]), (trajectory_[itt + 1][0], trajectory_[itt + 1][1]), (255, 255, 255), 3, ) img = new_img # Ensure all images are in RGB format if len(img.shape) == 2: # Grayscale image img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) vis_img = cv2.cvtColor(vis_img, cv2.COLOR_GRAY2RGB) elif len(img.shape) == 3 and img.shape[2] == 3: # Color image in BGR format img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB) # Convert the numpy array to a PIL image # pil_img = Image.fromarray(img) # images.append(pil_img) vis_images.append(Image.fromarray(vis_img)) return vis_images def frames_to_video(frames_folder, output_video_path, fps=7): frame_files = os.listdir(frames_folder) # sort the frame files by their names frame_files = sorted(frame_files, key=lambda x: int(x.split(".")[0])) video = [] for frame_file in frame_files: frame_path = os.path.join(frames_folder, frame_file) frame = torchvision.io.read_image(frame_path) video.append(frame) video = torch.stack(video) video = rearrange(video, "T C H W -> T H W C") torchvision.io.write_video(output_video_path, video, fps=fps) def save_gifs_side_by_side( batch_output, validation_control_images, output_folder, target_size=(512, 512), duration=200, point_tracks=None, ): flattened_batch_output = batch_output def create_gif(image_list, gif_path, duration=100): pil_images = [validate_and_convert_image(img, target_size=target_size) for img in image_list] pil_images = [img for img in pil_images if img is not None] if pil_images: pil_images[0].save(gif_path, save_all=True, append_images=pil_images[1:], loop=0, duration=duration) # also save all the pil_images tmp_folder = gif_path.replace(".gif", "") print(tmp_folder) ensure_dirname(tmp_folder) tmp_frame_list = [] for idx, pil_image in enumerate(pil_images): tmp_frame_path = os.path.join(tmp_folder, f"{idx}.png") pil_image.save(tmp_frame_path) tmp_frame_list.append(tmp_frame_path) # also save as mp4 output_video_path = gif_path.replace(".gif", ".mp4") frames_to_video(tmp_folder, output_video_path, fps=7) # Creating GIFs for each image list timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") gif_paths = [] for idx, image_list in enumerate([validation_control_images, flattened_batch_output]): gif_path = os.path.join(output_folder.replace("vis_gif.gif", ""), f"temp_{idx}_{timestamp}.gif") create_gif(image_list, gif_path) gif_paths.append(gif_path) # also save the point_tracks assert point_tracks is not None point_tracks_path = gif_path.replace(".gif", ".npy") np.save(point_tracks_path, point_tracks.cpu().numpy()) # Function to combine GIFs side by side def combine_gifs_side_by_side(gif_paths, output_path): print(gif_paths) gifs = [Image.open(gif) for gif in gif_paths] # Assuming all gifs have the same frame count and duration frames = [] for frame_idx in range(gifs[-1].n_frames): combined_frame = None for gif in gifs: if frame_idx >= gif.n_frames: gif.seek(gif.n_frames - 1) else: gif.seek(frame_idx) if combined_frame is None: combined_frame = gif.copy() else: combined_frame = get_concat_h(combined_frame, gif.copy(), gap=10) frames.append(combined_frame) if output_path.endswith(".mp4"): video = [torchvision.transforms.functional.pil_to_tensor(frame) for frame in frames] video = torch.stack(video) video = rearrange(video, "T C H W -> T H W C") torchvision.io.write_video(output_path, video, fps=7) print(f"Saved video to {output_path}") else: frames[0].save(output_path, save_all=True, append_images=frames[1:], loop=0, duration=duration) # Helper function to concatenate images horizontally def get_concat_h(im1, im2, gap=10): # # img first, heatmap second # im1, im2 = im2, im1 dst = Image.new("RGB", (im1.width + im2.width + gap, max(im1.height, im2.height)), (255, 255, 255)) dst.paste(im1, (0, 0)) dst.paste(im2, (im1.width + gap, 0)) return dst # Helper function to concatenate images vertically def get_concat_v(im1, im2): dst = Image.new("RGB", (max(im1.width, im2.width), im1.height + im2.height)) dst.paste(im1, (0, 0)) dst.paste(im2, (0, im1.height)) return dst # Combine the GIFs into a single file combined_gif_path = output_folder combine_gifs_side_by_side(gif_paths, combined_gif_path) combined_gif_path_v = gif_path.replace(".gif", "_v.mp4") ensure_dirname(combined_gif_path_v.replace(".mp4", "")) combine_gifs_side_by_side(gif_paths, combined_gif_path_v) # # Clean up temporary GIFs # for gif_path in gif_paths: # os.remove(gif_path) return combined_gif_path # Define functions def validate_and_convert_image(image, target_size=(512, 512)): if image is None: print("Encountered a None image") return None if isinstance(image, torch.Tensor): # Convert PyTorch tensor to PIL Image if image.ndim == 3 and image.shape[0] in [1, 3]: # Check for CxHxW format if image.shape[0] == 1: # Convert single-channel grayscale to RGB image = image.repeat(3, 1, 1) image = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy() image = Image.fromarray(image) else: print(f"Invalid image tensor shape: {image.shape}") return None elif isinstance(image, Image.Image): # Resize PIL Image image = image.resize(target_size) else: print("Image is not a PIL Image or a PyTorch tensor") return None return image def reset_states(first_frame_path, last_frame_path, tracking_points): first_frame_path = None last_frame_path = None tracking_points = [] return first_frame_path, last_frame_path, tracking_points def preprocess_image(image): image_pil = image2pil(image.name) raw_w, raw_h = image_pil.size # resize_ratio = max(512 / raw_w, 320 / raw_h) # image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR) # image_pil = transforms.CenterCrop((320, 512))(image_pil.convert('RGB')) image_pil = image_pil.resize((512, 320), Image.BILINEAR) first_frame_path = os.path.join(OUTPUT_DIR, f"first_frame_{str(uuid.uuid4())[:4]}.png") image_pil.save(first_frame_path) return first_frame_path, first_frame_path, [] def preprocess_image_end(image_end): image_end_pil = image2pil(image_end.name) raw_w, raw_h = image_end_pil.size # resize_ratio = max(512 / raw_w, 320 / raw_h) # image_end_pil = image_end_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR) # image_end_pil = transforms.CenterCrop((320, 512))(image_end_pil.convert('RGB')) image_end_pil = image_end_pil.resize((512, 320), Image.BILINEAR) last_frame_path = os.path.join(OUTPUT_DIR, f"last_frame_{str(uuid.uuid4())[:4]}.png") image_end_pil.save(last_frame_path) return last_frame_path, last_frame_path, [] def add_drag(tracking_points): if not tracking_points or tracking_points[-1]: tracking_points.append([]) return tracking_points def delete_last_drag(tracking_points, first_frame_path, last_frame_path): if tracking_points: tracking_points.pop() transparent_background = Image.open(first_frame_path).convert("RGBA") transparent_background_end = Image.open(last_frame_path).convert("RGBA") w, h = transparent_background.size transparent_layer = np.zeros((h, w, 4)) for track in tracking_points: if len(track) > 1: for i in range(len(track) - 1): start_point = track[i] end_point = track[i + 1] vx = end_point[0] - start_point[0] vy = end_point[1] - start_point[1] arrow_length = np.sqrt(vx**2 + vy**2) if i == len(track) - 2: cv2.arrowedLine( transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length, ) else: cv2.line( transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, ) else: cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1) transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer) return tracking_points, trajectory_map, trajectory_map_end def delete_last_step(tracking_points, first_frame_path, last_frame_path): if tracking_points and tracking_points[-1]: tracking_points[-1].pop() transparent_background = Image.open(first_frame_path).convert("RGBA") transparent_background_end = Image.open(last_frame_path).convert("RGBA") w, h = transparent_background.size transparent_layer = np.zeros((h, w, 4)) for track in tracking_points: if not track: continue if len(track) > 1: for i in range(len(track) - 1): start_point = track[i] end_point = track[i + 1] vx = end_point[0] - start_point[0] vy = end_point[1] - start_point[1] arrow_length = np.sqrt(vx**2 + vy**2) if i == len(track) - 2: cv2.arrowedLine( transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length, ) else: cv2.line( transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, ) else: cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1) transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer) return tracking_points, trajectory_map, trajectory_map_end def add_tracking_points( tracking_points, first_frame_path, last_frame_path, evt: gr.SelectData ): # SelectData is a subclass of EventData print(f"You selected {evt.value} at {evt.index} from {evt.target}") if not tracking_points: tracking_points = [[]] tracking_points[-1].append(evt.index) transparent_background = Image.open(first_frame_path).convert("RGBA") transparent_background_end = Image.open(last_frame_path).convert("RGBA") w, h = transparent_background.size transparent_layer = 0 for idx, track in enumerate(tracking_points): # mask = cv2.imread( # os.path.join(OUTPUT_DIR, f"mask_{idx+1}.jpg") # ) mask = np.zeros((320, 512, 3)) color = color_list[idx + 1] transparent_layer = mask[:, :, 0].reshape(h, w, 1) * color.reshape(1, 1, -1) + transparent_layer if len(track) > 1: for i in range(len(track) - 1): start_point = track[i] end_point = track[i + 1] vx = end_point[0] - start_point[0] vy = end_point[1] - start_point[1] arrow_length = np.sqrt(vx**2 + vy**2) if i == len(track) - 2: cv2.arrowedLine( transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length, ) else: cv2.line( transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, ) else: cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1) transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) alpha_coef = 0.99 im2_data = transparent_layer.getdata() new_im2_data = [(r, g, b, int(a * alpha_coef)) for r, g, b, a in im2_data] transparent_layer.putdata(new_im2_data) trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer) return tracking_points, trajectory_map, trajectory_map_end @spaces.GPU def run(first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id): original_width, original_height = 512, 320 # TODO # load_image image = Image.open(first_frame_path).convert("RGB") width, height = image.size image = image.resize((WIDTH, HEIGHT)) image_end = Image.open(last_frame_path).convert("RGB") image_end = image_end.resize((WIDTH, HEIGHT)) input_all_points = tracking_points sift_track_update = False anchor_points_flag = None if (len(input_all_points) == 0) and USE_SIFT: sift_track_update = True controlnet_cond_scale = 0.5 from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory from models_diffusers.sift_match import sift_match output_file_sift = os.path.join(OUTPUT_DIR, "sift.png") # (f, topk, 2), f=2 (before interpolation) pred_tracks = sift_match( image, image_end, thr=0.5, topk=5, method="random", output_path=output_file_sift, ) if pred_tracks is not None: # interpolate the tracks, following draganything gradio demo pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=MODEL_LENGTH) anchor_points_flag = torch.zeros((MODEL_LENGTH, pred_tracks.shape[1])).to(pred_tracks.device) anchor_points_flag[0] = 1 anchor_points_flag[-1] = 1 pred_tracks = pred_tracks.permute(1, 0, 2) # (num_points, num_frames, 2) else: resized_all_points = [ tuple([tuple([int(e1[0] * WIDTH / original_width), int(e1[1] * HEIGHT / original_height)]) for e1 in e]) for e in input_all_points ] # a list of num_tracks tuples, each tuple contains a track with several points, represented as (x, y) # in image w & h scale for idx, splited_track in enumerate(resized_all_points): if len(splited_track) == 0: warnings.warn("running without point trajectory control") continue if len(splited_track) == 1: # stationary point displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1]) splited_track = tuple([splited_track[0], displacement_point]) # interpolate the track splited_track = interpolate_trajectory(splited_track, MODEL_LENGTH) splited_track = splited_track[:MODEL_LENGTH] resized_all_points[idx] = splited_track pred_tracks = torch.tensor(resized_all_points) # (num_points, num_frames, 2) vis_images = get_vis_image( target_size=(HEIGHT, WIDTH), points=pred_tracks, num_frames=MODEL_LENGTH, ) if len(pred_tracks.shape) != 3: print("pred_tracks.shape", pred_tracks.shape) with_control = False controlnet_cond_scale = 0.0 else: with_control = True pred_tracks = pred_tracks.permute(1, 0, 2).to(device, dtype) # (num_frames, num_points, 2) point_embedding = None video_frames = pipe( image, image_end, # trajectory control with_control=with_control, point_tracks=pred_tracks, point_embedding=point_embedding, with_id_feature=False, controlnet_cond_scale=controlnet_cond_scale, # others num_frames=14, width=width, height=height, # decode_chunk_size=8, # generator=generator, motion_bucket_id=motion_bucket_id, fps=7, num_inference_steps=30, # track sift_track_update=sift_track_update, anchor_points_flag=anchor_points_flag, ).frames[0] vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images] vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images] vis_images = [Image.fromarray(img) for img in vis_images] # video_frames = [img for sublist in video_frames for img in sublist] val_save_dir = os.path.join(OUTPUT_DIR, "vis_gif.gif") save_gifs_side_by_side( video_frames, vis_images[:MODEL_LENGTH], val_save_dir, target_size=(WIDTH, HEIGHT), duration=110, point_tracks=pred_tracks, ) return val_save_dir if __name__ == "__main__": ensure_dirname(OUTPUT_DIR) color_list = [] for i in range(20): color = np.concatenate([np.random.random(4) * 255], axis=0) color_list.append(color) with gr.Blocks() as demo: gr.Markdown("""