""" input: json file with video, audio, motion paths output: igraph object with nodes containing video, audio, motion, position, velocity, axis_angle, previous, next, frame, fps preprocess: 1. assume you have a video for one speaker in folder, listed in -- video_a.mp4 -- video_b.mp4 run process_video.py to extract frames and audio """ import os import json import smplx import torch import igraph import numpy as np import subprocess import utils.rotation_conversions as rc from moviepy.editor import VideoClip, AudioFileClip from tqdm import tqdm import imageio import tempfile import argparse import time SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__)) def get_motion_reps_tensor(motion_tensor, smplx_model, pose_fps=30, device="cuda"): bs, n, _ = motion_tensor.shape motion_tensor = motion_tensor.float().to(device) motion_tensor_reshaped = motion_tensor.reshape(bs * n, 165) output = smplx_model( betas=torch.zeros(bs * n, 300, device=device), transl=torch.zeros(bs * n, 3, device=device), expression=torch.zeros(bs * n, 100, device=device), jaw_pose=torch.zeros(bs * n, 3, device=device), global_orient=torch.zeros(bs * n, 3, device=device), body_pose=motion_tensor_reshaped[:, 3 : 21 * 3 + 3], left_hand_pose=motion_tensor_reshaped[:, 25 * 3 : 40 * 3], right_hand_pose=motion_tensor_reshaped[:, 40 * 3 : 55 * 3], return_joints=True, leye_pose=torch.zeros(bs * n, 3, device=device), reye_pose=torch.zeros(bs * n, 3, device=device), ) joints = output["joints"].reshape(bs, n, 127, 3)[:, :, :55, :] dt = 1 / pose_fps init_vel = (joints[:, 1:2] - joints[:, 0:1]) / dt middle_vel = (joints[:, 2:] - joints[:, :-2]) / (2 * dt) final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt vel = torch.cat([init_vel, middle_vel, final_vel], dim=1) position = joints rot_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(bs, n, 55, 3)) rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(bs, n, 55, 6) init_vel_ang = (motion_tensor[:, 1:2] - motion_tensor[:, 0:1]) / dt middle_vel_ang = (motion_tensor[:, 2:] - motion_tensor[:, :-2]) / (2 * dt) final_vel_ang = (motion_tensor[:, -1:] - motion_tensor[:, -2:-1]) / dt angular_velocity = torch.cat([init_vel_ang, middle_vel_ang, final_vel_ang], dim=1).reshape(bs, n, 55, 3) rep15d = torch.cat([position, vel, rot6d, angular_velocity], dim=3).reshape(bs, n, 55 * 15) return { "position": position, "velocity": vel, "rotation": rot6d, "axis_angle": motion_tensor, "angular_velocity": angular_velocity, "rep15d": rep15d, } def get_motion_reps(motion, smplx_model, pose_fps=30): gt_motion_tensor = motion["poses"] n = gt_motion_tensor.shape[0] bs = 1 gt_motion_tensor = torch.from_numpy(gt_motion_tensor).float().to(device).unsqueeze(0) gt_motion_tensor_reshaped = gt_motion_tensor.reshape(bs * n, 165) output = smplx_model( betas=torch.zeros(bs * n, 300).to(device), transl=torch.zeros(bs * n, 3).to(device), expression=torch.zeros(bs * n, 100).to(device), jaw_pose=torch.zeros(bs * n, 3).to(device), global_orient=torch.zeros(bs * n, 3).to(device), body_pose=gt_motion_tensor_reshaped[:, 3 : 21 * 3 + 3], left_hand_pose=gt_motion_tensor_reshaped[:, 25 * 3 : 40 * 3], right_hand_pose=gt_motion_tensor_reshaped[:, 40 * 3 : 55 * 3], return_joints=True, leye_pose=torch.zeros(bs * n, 3).to(device), reye_pose=torch.zeros(bs * n, 3).to(device), ) joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :] dt = 1 / pose_fps init_vel = (joints[1:2] - joints[0:1]) / dt middle_vel = (joints[2:] - joints[:-2]) / (2 * dt) final_vel = (joints[-1:] - joints[-2:-1]) / dt vel = 
def get_motion_reps(motion, smplx_model, pose_fps=30):
    gt_motion_tensor = motion["poses"]
    n = gt_motion_tensor.shape[0]
    bs = 1
    # `device` is the module-level global set in the __main__ block below.
    gt_motion_tensor = torch.from_numpy(gt_motion_tensor).float().to(device).unsqueeze(0)
    gt_motion_tensor_reshaped = gt_motion_tensor.reshape(bs * n, 165)

    output = smplx_model(
        betas=torch.zeros(bs * n, 300).to(device),
        transl=torch.zeros(bs * n, 3).to(device),
        expression=torch.zeros(bs * n, 100).to(device),
        jaw_pose=torch.zeros(bs * n, 3).to(device),
        global_orient=torch.zeros(bs * n, 3).to(device),
        body_pose=gt_motion_tensor_reshaped[:, 3 : 21 * 3 + 3],
        left_hand_pose=gt_motion_tensor_reshaped[:, 25 * 3 : 40 * 3],
        right_hand_pose=gt_motion_tensor_reshaped[:, 40 * 3 : 55 * 3],
        return_joints=True,
        leye_pose=torch.zeros(bs * n, 3).to(device),
        reye_pose=torch.zeros(bs * n, 3).to(device),
    )

    joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :]
    dt = 1 / pose_fps
    init_vel = (joints[1:2] - joints[0:1]) / dt
    middle_vel = (joints[2:] - joints[:-2]) / (2 * dt)
    final_vel = (joints[-1:] - joints[-2:-1]) / dt
    vel = np.concatenate([init_vel, middle_vel, final_vel], axis=0)

    position = joints
    rot_matrices = rc.axis_angle_to_matrix(gt_motion_tensor.reshape(1, n, 55, 3))[0]
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy()

    init_vel = (motion["poses"][1:2] - motion["poses"][0:1]) / dt
    middle_vel = (motion["poses"][2:] - motion["poses"][:-2]) / (2 * dt)
    final_vel = (motion["poses"][-1:] - motion["poses"][-2:-1]) / dt
    angular_velocity = np.concatenate([init_vel, middle_vel, final_vel], axis=0).reshape(n, 55, 3)

    rep15d = np.concatenate([position, vel, rot6d, angular_velocity], axis=2).reshape(n, 55 * 15)

    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": motion["poses"],
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
        "trans": motion["trans"],
    }


def create_graph(json_path, smplx_model):
    fps = 30
    data_meta = json.load(open(json_path, "r"))
    graph = igraph.Graph(directed=True)
    global_i = 0

    for data_item in data_meta:
        video_path = os.path.join(data_item["video_path"], data_item["video_id"] + ".mp4")
        # audio_path = os.path.join(data_item['audio_path'], data_item['video_id'] + ".wav")
        motion_path = os.path.join(data_item["motion_path"], data_item["video_id"] + ".npz")
        video_id = data_item.get("video_id", "")

        motion = np.load(motion_path, allow_pickle=True)
        motion_reps = get_motion_reps(motion, smplx_model)
        position = motion_reps["position"]
        velocity = motion_reps["velocity"]
        trans = motion_reps["trans"]
        axis_angle = motion_reps["axis_angle"]
        # audio, sr = librosa.load(audio_path, sr=None)
        # audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

        reader = imageio.get_reader(video_path)
        all_frames = [frame for frame in reader]
        video_frames = np.array(all_frames)

        # Keep only the frames for which both video and motion are available.
        min_frames = min(len(video_frames), position.shape[0])
        position = position[:min_frames]
        velocity = velocity[:min_frames]
        video_frames = video_frames[:min_frames]

        for i in tqdm(range(min_frames)):
            if i == 0:
                previous = -1
                next_node = global_i + 1
            elif i == min_frames - 1:
                previous = global_i - 1
                next_node = -1
            else:
                previous = global_i - 1
                next_node = global_i + 1
            graph.add_vertex(
                idx=global_i,
                name=video_id,
                motion=motion_reps,
                position=position[i],
                velocity=velocity[i],
                axis_angle=axis_angle[i],
                trans=trans[i],
                # audio=audio[],
                video=video_frames[i],
                previous=previous,
                next=next_node,
                frame=i,
                fps=fps,
            )
            global_i += 1
    return graph
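
# Example of the metadata JSON consumed by create_graph() above, inferred from the keys it
# reads (a sketch; the concrete paths and ids are placeholders):
#
# [
#     {
#         "video_path": "datasets/show_oliver_test/videos",
#         "video_id": "speaker_a_clip_000",
#         "motion_path": "datasets/show_oliver_test/smplx"
#     }
# ]
#
# Each <motion_path>/<video_id>.npz is expected to contain at least "poses" with shape
# (n_frames, 165) and "trans" with shape (n_frames, 3).
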
def create_edges(graph):
    adaptive_length = [-4, -3, -2, -1, 1, 2, 3, 4]
    for i, node in enumerate(graph.vs):
        current_position = node["position"]
        current_velocity = node["velocity"]
        current_trans = node["trans"]

        # Per-joint thresholds, averaged over the temporal neighbours of this node.
        avg_position = np.zeros(current_position.shape[0])
        avg_velocity = np.zeros(current_position.shape[0])
        avg_trans = 0
        count = 0
        for node_offset in adaptive_length:
            idx = i + node_offset
            if idx < 0 or idx >= len(graph.vs):
                continue
            if node_offset < 0:
                if graph.vs[idx]["next"] == -1:
                    continue
            else:
                if graph.vs[idx]["previous"] == -1:
                    continue
            # add check
            other_node = graph.vs[idx]
            other_position = other_node["position"]
            other_velocity = other_node["velocity"]
            other_trans = other_node["trans"]
            avg_position += np.linalg.norm(current_position - other_position, axis=1)
            avg_velocity += np.linalg.norm(current_velocity - other_velocity, axis=1)
            avg_trans += np.linalg.norm(current_trans - other_trans, axis=0)
            count += 1

        if count == 0:
            continue
        threshold_position = avg_position / count
        threshold_velocity = avg_velocity / count
        threshold_trans = avg_trans / count

        for j, other_node in enumerate(graph.vs):
            if i == j:
                continue
            if j == node["previous"] or j == node["next"]:
                graph.add_edge(i, j, is_continue=1)
                continue
            other_position = other_node["position"]
            other_velocity = other_node["velocity"]
            other_trans = other_node["trans"]
            position_similarity = np.linalg.norm(current_position - other_position, axis=1)
            velocity_similarity = np.linalg.norm(current_velocity - other_velocity, axis=1)
            trans_similarity = np.linalg.norm(current_trans - other_trans, axis=0)
            # Connect two non-consecutive frames when the root translation is close enough and
            # at least 45 of the 55 joints fall within the per-joint position and velocity thresholds.
            if trans_similarity < threshold_trans:
                if np.sum(position_similarity < threshold_position) >= 45 and np.sum(velocity_similarity < threshold_velocity) >= 45:
                    graph.add_edge(i, j, is_continue=0)

    print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    in_degrees = graph.indegree()
    out_degrees = graph.outdegree()
    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
    # igraph.plot(graph, target="/content/test.png", bbox=(1000, 1000), vertex_size=10)
    return graph


def random_walk(graph, walk_length, start_node=None):
    if start_node is None:
        start_node = np.random.choice(graph.vs)
    walk = [start_node]
    is_continue = [1]
    for _ in range(walk_length):
        current_node = walk[-1]
        neighbor_indices = graph.neighbors(current_node.index, mode="OUT")
        if not neighbor_indices:
            break
        next_idx = np.random.choice(neighbor_indices)
        edge_id = graph.get_eid(current_node.index, next_idx)
        is_cont = graph.es[edge_id]["is_continue"]
        walk.append(graph.vs[next_idx])
        is_continue.append(is_cont)
    return walk, is_continue


def path_visualization(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    all_frames = [node["video"] for node in path]
    average_dis_continue = 1 - sum(is_continue) / len(is_continue)
    if verbose_continue:
        print("average_dis_continue:", average_dis_continue)

    fps = graph.vs[0]["fps"]
    duration = len(all_frames) / fps

    def make_frame(t):
        idx = min(int(t * fps), len(all_frames) - 1)
        return all_frames[idx]

    video_only_path = f"/tmp/video_only_{time.time()}.mp4"  # Temporary file
    video_clip = VideoClip(make_frame, duration=duration)
    video_clip.write_videofile(video_only_path, codec="libx264", fps=fps, audio=False)

    # Optionally, ensure audio and video durations match
    if audio_path is not None:
        audio_clip = AudioFileClip(audio_path)
        video_duration = video_clip.duration
        audio_duration = audio_clip.duration
        if audio_duration > video_duration:
            # Trim the audio
            trimmed_audio_path = "trimmed_audio.aac"
            audio_clip = audio_clip.subclip(0, video_duration)
            audio_clip.write_audiofile(trimmed_audio_path)
            audio_input = trimmed_audio_path
        else:
            audio_input = audio_path

        # Use FFmpeg to combine video and audio
        ffmpeg_command = [
            "ffmpeg",
            "-y",
            "-i",
            video_only_path,
            "-i",
            audio_input,
            "-c:v",
            "copy",
            "-c:a",
            "aac",
            "-strict",
            "experimental",
            save_path,
        ]
        subprocess.check_call(ffmpeg_command)

        # Clean up temporary files if necessary
        os.remove(video_only_path)
        if audio_input != audio_path:
            os.remove(audio_input)

    if return_motion:
        all_motion = [node["axis_angle"] for node in path]
        all_motion = np.stack(all_motion, 0)
        return all_motion
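
# Example (sketch): sampling a path and rendering it, mirroring the __main__ block below.
# `graph` is assumed to be the output of create_edges(create_graph(...)); is_continue[k] == 1
# means step k follows the original video (a previous/next edge), 0 means a transition edge.
#
#   walk, is_continue = random_walk(graph, walk_length=100)
#   motion = path_visualization(graph, walk, is_continue, "./test.mp4",
#                               verbose_continue=True, audio_path=None, return_motion=True)
#   # motion has shape (len(walk), 165): one axis-angle pose vector per visited node.
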
def generate_transition_video(frame_start_path, frame_end_path, output_video_path):
    # Paths to the FiLM frame-interpolation model and inference script
    model_path = os.path.join(SCRIPT_PATH, "frame-interpolation-pytorch/film_net_fp32.pt")
    inference_script = os.path.join(SCRIPT_PATH, "frame-interpolation-pytorch/inference.py")

    # Build the command to run the inference script
    command = [
        "python",
        inference_script,
        model_path,
        frame_start_path,
        frame_end_path,
        "--save_path",
        output_video_path,
        "--gpu",
        "--frames",
        "3",
        "--fps",
        "30",
    ]

    # Run the command
    try:
        subprocess.run(command, check=True)
        print(f"Generated transition video saved at {output_video_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error occurred while generating transition video: {e}")


def path_visualization_v2(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    """
    This is for the Hugging Face demo, for fast interpolation.
    Our paper uses a diffusion-based interpolation method.
    """
    all_frames = [node["video"] for node in path]
    average_dis_continue = 1 - sum(is_continue) / len(is_continue)
    if verbose_continue:
        print("average_dis_continue:", average_dis_continue)
    duration = len(all_frames) / graph.vs[0]["fps"]

    # First loop: confirm where blending is needed
    discontinuity_indices = [i for i, cont in enumerate(is_continue) if cont == 0]

    # Identify blending positions without overlapping
    blend_positions = []
    processed_frames = set()
    for i in discontinuity_indices:
        # Define the frames for blending: i-2 to i+2
        start_idx = i - 2
        end_idx = i + 2

        # Check index boundaries
        if start_idx < 0 or end_idx >= len(all_frames):
            continue  # Skip if indices are out of bounds

        # Check for overlapping frames
        overlap = any(idx in processed_frames for idx in range(i - 1, i + 2))
        if overlap:
            continue  # Skip if frames have been processed

        # Mark frames as processed
        processed_frames.update(range(i - 1, i + 2))
        blend_positions.append(i)

    # Second loop: perform blending
    temp_dir = tempfile.mkdtemp(prefix="blending_frames_")
    for i in tqdm(blend_positions):
        start_frame_idx = i - 2
        end_frame_idx = i + 2
        frame_start = all_frames[start_frame_idx]
        frame_end = all_frames[end_frame_idx]
        frame_start_path = os.path.join(temp_dir, f"frame_{start_frame_idx}.png")
        frame_end_path = os.path.join(temp_dir, f"frame_{end_frame_idx}.png")

        # Save the start and end frames as images
        imageio.imwrite(frame_start_path, frame_start)
        imageio.imwrite(frame_end_path, frame_end)

        # Call FiLM to generate the transition video
        generated_video_path = os.path.join(temp_dir, f"generated_{start_frame_idx}_{end_frame_idx}.mp4")
        generate_transition_video(frame_start_path, frame_end_path, generated_video_path)

        # Read the generated video frames
        reader = imageio.get_reader(generated_video_path)
        generated_frames = [frame for frame in reader]
        reader.close()

        # Replace the middle three frames (i-1, i, i+1) in all_frames
        total_generated_frames = len(generated_frames)
        if total_generated_frames < 5:
            print(f"Generated video has insufficient frames ({total_generated_frames}). Skipping blending at position {i}.")
            continue
        middle_start = 1  # Start index for the middle 3 frames
        middle_frames = generated_frames[middle_start : middle_start + 3]
        for idx, frame_idx in enumerate(range(i - 1, i + 2)):
            all_frames[frame_idx] = middle_frames[idx]

    # Create the video clip
    def make_frame(t):
        idx = min(int(t * graph.vs[0]["fps"]), len(all_frames) - 1)
        return all_frames[idx]

    video_clip = VideoClip(make_frame, duration=duration)
    if audio_path is not None:
        audio_clip = AudioFileClip(audio_path)
        video_clip = video_clip.set_audio(audio_clip)
    video_clip.write_videofile(save_path, codec="libx264", fps=graph.vs[0]["fps"], audio_codec="aac")

    if return_motion:
        all_motion = [node["axis_angle"] for node in path]
        all_motion = np.stack(all_motion, 0)
        return all_motion
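
# Example (sketch): rendering the same walk with FiLM-based blending at discontinuities.
# Assumes the frame-interpolation-pytorch checkout and film_net_fp32.pt referenced in
# generate_transition_video() are present next to this script.
#
#   walk, is_continue = random_walk(graph, 100)
#   path_visualization_v2(graph, walk, is_continue, "./test_blended.mp4",
#                         verbose_continue=True, audio_path=None)
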
def graph_pruning(graph):
    # Restrict the graph to its largest strongly connected component.
    ascc = graph.clusters(mode="STRONG")
    lascc = ascc.giant()
    print(f"before nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    print(f"after nodes: {len(lascc.vs)}, edges: {len(lascc.es)}")
    in_degrees = lascc.indegree()
    out_degrees = lascc.outdegree()
    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
    return lascc


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--json_save_path", type=str, default="")
    parser.add_argument("--graph_save_path", type=str, default="")
    args = parser.parse_args()

    json_path = args.json_save_path
    print("json_path", json_path)
    graph_path = args.graph_save_path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    smplx_model = (
        smplx.create(
            os.path.join(SCRIPT_PATH, "emage/smplx_models/"),
            model_type="smplx",
            gender="NEUTRAL_2020",
            use_face_contour=False,
            num_betas=300,
            num_expression_coeffs=100,
            ext="npz",
            use_pca=False,
        )
        .to(device)
        .eval()
    )

    # single_test
    # graph = create_graph('/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/Abortion_Laws_-_Last_Week_Tonight_with_John_Oliver_HBO-DRauXXz6t0Y.webm.json')
    graph = create_graph(json_path, smplx_model)
    graph = create_edges(graph)

    # pool_path = "/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/show-oliver-test.pkl"
    # graph = igraph.Graph.Read_Pickle(fname=pool_path)
    # graph = igraph.Graph.Read_Pickle(fname="/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/test.pkl")

    walk, is_continue = random_walk(graph, 100)
    motion = path_visualization(graph, walk, is_continue, "./test.mp4", audio_path=None, verbose_continue=True, return_motion=True)

    graph.write_pickle(fname=graph_path)
    graph = graph_pruning(graph)

    # show-oliver
    # json_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/"
    # pre_node_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/cached_graph/show_oliver_test/"
    # for json_file in tqdm(os.listdir(json_path)):
    #     graph = create_graph(os.path.join(json_path, json_file))
    #     graph = create_edges(graph)
    #     if not len(graph.vs) >= 1500:
    #         print(f"skip: {len(graph.vs)}", json_file)
    #     graph.write_pickle(fname=os.path.join(pre_node_path, json_file.split(".")[0] + ".pkl"))
    #     print(f"Graph saved at {json_file.split('.')[0]}.pkl")
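
# Example invocation (sketch; the script name and dataset paths are placeholders):
#
#   python create_graph.py \
#       --json_save_path datasets/data_json/show_oliver_test/speaker_a.json \
#       --graph_save_path datasets/cached_graph/show_oliver_test/speaker_a.pkl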