fffiloni committed on
Commit 6cd5133 · verified · 1 Parent(s): 2a5a725

Create hf_gradio_app.py

Files changed (1)
  1. hf_gradio_app.py +178 -0
hf_gradio_app.py ADDED
@@ -0,0 +1,178 @@
import os, random, time
from huggingface_hub import snapshot_download

# Download models
os.makedirs("checkpoints", exist_ok=True)

# List of subdirectories to create inside "checkpoints"
subfolders = [
    "vae",
    "wav2vec2",
    "emotion2vec_plus_large"
]

# Create each subdirectory
for subfolder in subfolders:
    os.makedirs(os.path.join("checkpoints", subfolder), exist_ok=True)

snapshot_download(
    repo_id="memoavatar/memo",
    local_dir="./checkpoints"
)

snapshot_download(
    repo_id="stabilityai/sd-vae-ft-mse",
    local_dir="./checkpoints/vae"
)

snapshot_download(
    repo_id="facebook/wav2vec2-base-960h",
    local_dir="./checkpoints/wav2vec2"
)

snapshot_download(
    repo_id="emotion2vec/emotion2vec_plus_large",
    local_dir="./checkpoints/emotion2vec_plus_large"
)
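# Core dependencies and the MEMO model components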
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from tqdm import tqdm

from memo.models.audio_proj import AudioProjModel
from memo.models.image_proj import ImageProjModel
from memo.models.unet_2d_condition import UNet2DConditionModel
from memo.models.unet_3d import UNet3DConditionModel
from memo.pipelines.video_pipeline import VideoPipeline
from memo.utils.audio_utils import extract_audio_emotion_labels, preprocess_audio, resample_audio
from memo.utils.vision_utils import preprocess_image, tensor_to_video
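# Load all model components once at startup and wrap them into the MEMO video pipeline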
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
weight_dtype = torch.bfloat16

with torch.inference_mode():
    vae = AutoencoderKL.from_pretrained("./checkpoints/vae").to(device=device, dtype=weight_dtype)
    reference_net = UNet2DConditionModel.from_pretrained("./checkpoints", subfolder="reference_net", use_safetensors=True)
    diffusion_net = UNet3DConditionModel.from_pretrained("./checkpoints", subfolder="diffusion_net", use_safetensors=True)
    image_proj = ImageProjModel.from_pretrained("./checkpoints", subfolder="image_proj", use_safetensors=True)
    audio_proj = AudioProjModel.from_pretrained("./checkpoints", subfolder="audio_proj", use_safetensors=True)

    vae.requires_grad_(False).eval()
    reference_net.requires_grad_(False).eval()
    diffusion_net.requires_grad_(False).eval()
    image_proj.requires_grad_(False).eval()
    audio_proj.requires_grad_(False).eval()
    reference_net.enable_xformers_memory_efficient_attention()
    diffusion_net.enable_xformers_memory_efficient_attention()

    noise_scheduler = FlowMatchEulerDiscreteScheduler()
    pipeline = VideoPipeline(vae=vae, reference_net=reference_net, diffusion_net=diffusion_net, scheduler=noise_scheduler, image_proj=image_proj)
    pipeline.to(device=device, dtype=weight_dtype)
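# Full generation pass: turn one reference image plus one audio clip into a talking-head video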
@torch.inference_mode()
def generate(input_video, input_audio, seed):
    # Fixed inference settings
    resolution = 512
    num_generated_frames_per_clip = 16
    fps = 30
    num_init_past_frames = 2
    num_past_frames = 16
    inference_steps = 20
    cfg_scale = 3.5

    # A seed of 0 means "pick a random seed"
    if seed == 0:
        random.seed(int(time.time()))
        seed = random.randint(0, 18446744073709551615)
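    # Preprocess the reference image and the driving audio (resampling, wav2vec2 features, emotion labels)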
    generator = torch.manual_seed(seed)
    img_size = (resolution, resolution)
    pixel_values, face_emb = preprocess_image(face_analysis_model="./checkpoints/misc/face_analysis", image_path=input_video, image_size=resolution)

    output_dir = "./outputs"
    os.makedirs(output_dir, exist_ok=True)
    cache_dir = os.path.join(output_dir, "audio_preprocess")
    os.makedirs(cache_dir, exist_ok=True)
    input_audio = resample_audio(input_audio, os.path.join(cache_dir, f"{os.path.basename(input_audio).split('.')[0]}-16k.wav"))

    audio_emb, audio_length = preprocess_audio(
        wav_path=input_audio,
        num_generated_frames_per_clip=num_generated_frames_per_clip,
        fps=fps,
        wav2vec_model="./checkpoints/wav2vec2",
        vocal_separator_model="./checkpoints/misc/vocal_separator/Kim_Vocal_2.onnx",
        cache_dir=cache_dir,
        device=device,
    )
    audio_emotion, num_emotion_classes = extract_audio_emotion_labels(
        model="./checkpoints",
        wav_path=input_audio,
        emotion2vec_model="./checkpoints/emotion2vec_plus_large",
        audio_length=audio_length,
        device=device,
    )
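    # Generate the video clip by clip; each clip is conditioned on the reference image
    # and on the last frames of the previously generated clip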
    video_frames = []
    num_clips = audio_emb.shape[0] // num_generated_frames_per_clip
    for t in tqdm(range(num_clips), desc="Generating video clips"):
        if len(video_frames) == 0:
            past_frames = pixel_values.repeat(num_init_past_frames, 1, 1, 1)
            past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device)
            pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0)
        else:
            past_frames = video_frames[-1][0]
            past_frames = past_frames.permute(1, 0, 2, 3)
            past_frames = past_frames[0 - num_past_frames :]
            past_frames = past_frames * 2.0 - 1.0
            past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device)
            pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0)

        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
        audio_tensor = (
            audio_emb[t * num_generated_frames_per_clip : min((t + 1) * num_generated_frames_per_clip, audio_emb.shape[0])]
            .unsqueeze(0)
            .to(device=audio_proj.device, dtype=audio_proj.dtype)
        )
        audio_tensor = audio_proj(audio_tensor)
        audio_emotion_tensor = audio_emotion[t * num_generated_frames_per_clip : min((t + 1) * num_generated_frames_per_clip, audio_emb.shape[0])]

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            audio_emotion=audio_emotion_tensor,
            emotion_class_num=num_emotion_classes,
            face_emb=face_emb,
            width=img_size[0],
            height=img_size[1],
            video_length=num_generated_frames_per_clip,
            num_inference_steps=inference_steps,
            guidance_scale=cfg_scale,
            generator=generator,
        )
        video_frames.append(pipeline_output.videos)
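    # Stitch the clips along the frame axis, trim to the audio length, and mux the audio in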
    video_frames = torch.cat(video_frames, dim=2)
    video_frames = video_frames.squeeze(0)
    video_frames = video_frames[:, :audio_length]

    video_path = os.path.join(output_dir, f"memo-{seed}.mp4")
    tensor_to_video(video_frames, video_path, input_audio, fps=fps)

    return video_path
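# Minimal Gradio UI: reference image and audio in, generated video out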
import gradio as gr

with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column():
        gr.Markdown("# MEMO")

        with gr.Row():
            with gr.Column():
                input_video = gr.Image(label="Upload Input Image", type="filepath")
                input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
                seed = gr.Number(label="Seed (0 for Random)", value=0, precision=0)
            with gr.Column():
                video_output = gr.Video(label="Generated Video")
                generate_button = gr.Button("Generate")

        generate_button.click(
            fn=generate,
            inputs=[input_video, input_audio, seed],
            outputs=[video_output],
        )

demo.queue().launch(share=False, show_api=False, show_error=True)