zejunyang committed on
Commit
8749423
·
1 Parent(s): 0da4ece
Files changed (2) hide show
  1. src/audio2vid.py +172 -172
  2. src/vid2vid.py +1 -1
src/audio2vid.py CHANGED
@@ -4,31 +4,31 @@ from datetime import datetime
4
  from pathlib import Path
5
  import numpy as np
6
  import cv2
7
- import torch
8
- import spaces
9
  from scipy.spatial.transform import Rotation as R
10
  from scipy.interpolate import interp1d
11
 
12
- from diffusers import AutoencoderKL, DDIMScheduler
13
- from einops import repeat
14
- from omegaconf import OmegaConf
15
- from PIL import Image
16
- from torchvision import transforms
17
- from transformers import CLIPVisionModelWithProjection
18
 
19
 
20
- from src.models.pose_guider import PoseGuider
21
- from src.models.unet_2d_condition import UNet2DConditionModel
22
- from src.models.unet_3d import UNet3DConditionModel
23
- from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
24
- from src.utils.util import save_videos_grid
25
 
26
- from src.audio_models.model import Audio2MeshModel
27
- from src.utils.audio_util import prepare_audio_feature
28
  from src.utils.mp_utils import LMKExtractor
29
- from src.utils.draw_util import FaceMeshVisualizer
30
- from src.utils.pose_util import project_points
31
- from src.utils.crop_face_single import crop_face
32
 
33
 
34
  def matrix_to_euler_and_translation(matrix):
@@ -92,169 +92,169 @@ def get_headpose_temp(input_video):
92
 
93
  return pose_arr_smooth
94
 
95
- @spaces.GPU
96
- def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
97
- fps = 30
98
- cfg = 3.5
99
 
100
- config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
101
 
102
- if config.weight_dtype == "fp16":
103
- weight_dtype = torch.float16
104
- else:
105
- weight_dtype = torch.float32
106
 
107
- audio_infer_config = OmegaConf.load(config.audio_inference_config)
108
- # prepare model
109
- a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
110
- a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
111
- a2m_model.cuda().eval()
112
-
113
- vae = AutoencoderKL.from_pretrained(
114
- config.pretrained_vae_path,
115
- ).to("cuda", dtype=weight_dtype)
116
-
117
- reference_unet = UNet2DConditionModel.from_pretrained(
118
- config.pretrained_base_model_path,
119
- subfolder="unet",
120
- ).to(dtype=weight_dtype, device="cuda")
121
-
122
- inference_config_path = config.inference_config
123
- infer_config = OmegaConf.load(inference_config_path)
124
- denoising_unet = UNet3DConditionModel.from_pretrained_2d(
125
- config.pretrained_base_model_path,
126
- config.motion_module_path,
127
- subfolder="unet",
128
- unet_additional_kwargs=infer_config.unet_additional_kwargs,
129
- ).to(dtype=weight_dtype, device="cuda")
130
-
131
-
132
- pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
133
-
134
- image_enc = CLIPVisionModelWithProjection.from_pretrained(
135
- config.image_encoder_path
136
- ).to(dtype=weight_dtype, device="cuda")
137
-
138
- sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
139
- scheduler = DDIMScheduler(**sched_kwargs)
140
-
141
- generator = torch.manual_seed(seed)
142
-
143
- width, height = size, size
144
-
145
- # load pretrained weights
146
- denoising_unet.load_state_dict(
147
- torch.load(config.denoising_unet_path, map_location="cpu"),
148
- strict=False,
149
- )
150
- reference_unet.load_state_dict(
151
- torch.load(config.reference_unet_path, map_location="cpu"),
152
- )
153
- pose_guider.load_state_dict(
154
- torch.load(config.pose_guider_path, map_location="cpu"),
155
- )
156
-
157
- pipe = Pose2VideoPipeline(
158
- vae=vae,
159
- image_encoder=image_enc,
160
- reference_unet=reference_unet,
161
- denoising_unet=denoising_unet,
162
- pose_guider=pose_guider,
163
- scheduler=scheduler,
164
- )
165
- pipe = pipe.to("cuda", dtype=weight_dtype)
166
-
167
- date_str = datetime.now().strftime("%Y%m%d")
168
- time_str = datetime.now().strftime("%H%M")
169
- save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
170
-
171
- save_dir = Path(f"output/{date_str}/{save_dir_name}")
172
- save_dir.mkdir(exist_ok=True, parents=True)
173
-
174
- lmk_extractor = LMKExtractor()
175
- vis = FaceMeshVisualizer(forehead_edge=False)
176
-
177
- ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
178
- ref_image_np = crop_face(ref_image_np, lmk_extractor)
179
- if ref_image_np is None:
180
- return None, Image.fromarray(ref_img)
181
 
182
- ref_image_np = cv2.resize(ref_image_np, (size, size))
183
- ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
184
 
185
- face_result = lmk_extractor(ref_image_np)
186
- if face_result is None:
187
- return None, ref_image_pil
188
 
189
- lmks = face_result['lmks'].astype(np.float32)
190
- ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
191
 
192
- sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
193
- sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
194
- sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
195
-
196
- # inference
197
- pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
198
- pred = pred.squeeze().detach().cpu().numpy()
199
- pred = pred.reshape(pred.shape[0], -1, 3)
200
- pred = pred + face_result['lmks3d']
201
 
202
- if headpose_video is not None:
203
- pose_seq = get_headpose_temp(headpose_video)
204
- else:
205
- pose_seq = np.load(config['pose_temp'])
206
- mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
207
- cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
208
-
209
- # project 3D mesh to 2D landmark
210
- projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
211
-
212
- pose_images = []
213
- for i, verts in enumerate(projected_vertices):
214
- lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
215
- pose_images.append(lmk_img)
216
-
217
- pose_list = []
218
- pose_tensor_list = []
219
-
220
- pose_transform = transforms.Compose(
221
- [transforms.Resize((height, width)), transforms.ToTensor()]
222
- )
223
- args_L = len(pose_images) if length==0 or length > len(pose_images) else length
224
- args_L = min(args_L, 300)
225
- for pose_image_np in pose_images[: args_L]:
226
- pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
227
- pose_tensor_list.append(pose_transform(pose_image_pil))
228
- pose_image_np = cv2.resize(pose_image_np, (width, height))
229
- pose_list.append(pose_image_np)
230
 
231
- pose_list = np.array(pose_list)
232
 
233
- video_length = len(pose_tensor_list)
234
-
235
- video = pipe(
236
- ref_image_pil,
237
- pose_list,
238
- ref_pose,
239
- width,
240
- height,
241
- video_length,
242
- steps,
243
- cfg,
244
- generator=generator,
245
- ).videos
246
-
247
- save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
248
- save_videos_grid(
249
- video,
250
- save_path,
251
- n_rows=1,
252
- fps=fps,
253
- )
254
 
255
- stream = ffmpeg.input(save_path)
256
- audio = ffmpeg.input(input_audio)
257
- ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
258
- os.remove(save_path)
259
 
260
- return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
 
4
  from pathlib import Path
5
  import numpy as np
6
  import cv2
7
+ # import torch
8
+ # import spaces
9
  from scipy.spatial.transform import Rotation as R
10
  from scipy.interpolate import interp1d
11
 
12
+ # from diffusers import AutoencoderKL, DDIMScheduler
13
+ # from einops import repeat
14
+ # from omegaconf import OmegaConf
15
+ # from PIL import Image
16
+ # from torchvision import transforms
17
+ # from transformers import CLIPVisionModelWithProjection
18
 
19
 
20
+ # from src.models.pose_guider import PoseGuider
21
+ # from src.models.unet_2d_condition import UNet2DConditionModel
22
+ # from src.models.unet_3d import UNet3DConditionModel
23
+ # from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
24
+ # from src.utils.util import save_videos_grid
25
 
26
+ # from src.audio_models.model import Audio2MeshModel
27
+ # from src.utils.audio_util import prepare_audio_feature
28
  from src.utils.mp_utils import LMKExtractor
29
+ # from src.utils.draw_util import FaceMeshVisualizer
30
+ # from src.utils.pose_util import project_points
31
+ # from src.utils.crop_face_single import crop_face
32
 
33
 
34
  def matrix_to_euler_and_translation(matrix):
 
92
 
93
  return pose_arr_smooth
94
 
95
+ # @spaces.GPU(duration=150)
96
+ # def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
97
+ # fps = 30
98
+ # cfg = 3.5
99
 
100
+ # config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
101
 
102
+ # if config.weight_dtype == "fp16":
103
+ # weight_dtype = torch.float16
104
+ # else:
105
+ # weight_dtype = torch.float32
106
 
107
+ # audio_infer_config = OmegaConf.load(config.audio_inference_config)
108
+ # # prepare model
109
+ # a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
110
+ # a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
111
+ # a2m_model.cuda().eval()
112
+
113
+ # vae = AutoencoderKL.from_pretrained(
114
+ # config.pretrained_vae_path,
115
+ # ).to("cuda", dtype=weight_dtype)
116
+
117
+ # reference_unet = UNet2DConditionModel.from_pretrained(
118
+ # config.pretrained_base_model_path,
119
+ # subfolder="unet",
120
+ # ).to(dtype=weight_dtype, device="cuda")
121
+
122
+ # inference_config_path = config.inference_config
123
+ # infer_config = OmegaConf.load(inference_config_path)
124
+ # denoising_unet = UNet3DConditionModel.from_pretrained_2d(
125
+ # config.pretrained_base_model_path,
126
+ # config.motion_module_path,
127
+ # subfolder="unet",
128
+ # unet_additional_kwargs=infer_config.unet_additional_kwargs,
129
+ # ).to(dtype=weight_dtype, device="cuda")
130
+
131
+
132
+ # pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
133
+
134
+ # image_enc = CLIPVisionModelWithProjection.from_pretrained(
135
+ # config.image_encoder_path
136
+ # ).to(dtype=weight_dtype, device="cuda")
137
+
138
+ # sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
139
+ # scheduler = DDIMScheduler(**sched_kwargs)
140
+
141
+ # generator = torch.manual_seed(seed)
142
+
143
+ # width, height = size, size
144
+
145
+ # # load pretrained weights
146
+ # denoising_unet.load_state_dict(
147
+ # torch.load(config.denoising_unet_path, map_location="cpu"),
148
+ # strict=False,
149
+ # )
150
+ # reference_unet.load_state_dict(
151
+ # torch.load(config.reference_unet_path, map_location="cpu"),
152
+ # )
153
+ # pose_guider.load_state_dict(
154
+ # torch.load(config.pose_guider_path, map_location="cpu"),
155
+ # )
156
+
157
+ # pipe = Pose2VideoPipeline(
158
+ # vae=vae,
159
+ # image_encoder=image_enc,
160
+ # reference_unet=reference_unet,
161
+ # denoising_unet=denoising_unet,
162
+ # pose_guider=pose_guider,
163
+ # scheduler=scheduler,
164
+ # )
165
+ # pipe = pipe.to("cuda", dtype=weight_dtype)
166
+
167
+ # date_str = datetime.now().strftime("%Y%m%d")
168
+ # time_str = datetime.now().strftime("%H%M")
169
+ # save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
170
+
171
+ # save_dir = Path(f"output/{date_str}/{save_dir_name}")
172
+ # save_dir.mkdir(exist_ok=True, parents=True)
173
+
174
+ # lmk_extractor = LMKExtractor()
175
+ # vis = FaceMeshVisualizer(forehead_edge=False)
176
+
177
+ # ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
178
+ # ref_image_np = crop_face(ref_image_np, lmk_extractor)
179
+ # if ref_image_np is None:
180
+ # return None, Image.fromarray(ref_img)
181
 
182
+ # ref_image_np = cv2.resize(ref_image_np, (size, size))
183
+ # ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
184
 
185
+ # face_result = lmk_extractor(ref_image_np)
186
+ # if face_result is None:
187
+ # return None, ref_image_pil
188
 
189
+ # lmks = face_result['lmks'].astype(np.float32)
190
+ # ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
191
 
192
+ # sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
193
+ # sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
194
+ # sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
195
+
196
+ # # inference
197
+ # pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
198
+ # pred = pred.squeeze().detach().cpu().numpy()
199
+ # pred = pred.reshape(pred.shape[0], -1, 3)
200
+ # pred = pred + face_result['lmks3d']
201
 
202
+ # if headpose_video is not None:
203
+ # pose_seq = get_headpose_temp(headpose_video)
204
+ # else:
205
+ # pose_seq = np.load(config['pose_temp'])
206
+ # mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
207
+ # cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
208
+
209
+ # # project 3D mesh to 2D landmark
210
+ # projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
211
+
212
+ # pose_images = []
213
+ # for i, verts in enumerate(projected_vertices):
214
+ # lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
215
+ # pose_images.append(lmk_img)
216
+
217
+ # pose_list = []
218
+ # pose_tensor_list = []
219
+
220
+ # pose_transform = transforms.Compose(
221
+ # [transforms.Resize((height, width)), transforms.ToTensor()]
222
+ # )
223
+ # args_L = len(pose_images) if length==0 or length > len(pose_images) else length
224
+ # args_L = min(args_L, 300)
225
+ # for pose_image_np in pose_images[: args_L]:
226
+ # pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
227
+ # pose_tensor_list.append(pose_transform(pose_image_pil))
228
+ # pose_image_np = cv2.resize(pose_image_np, (width, height))
229
+ # pose_list.append(pose_image_np)
230
 
231
+ # pose_list = np.array(pose_list)
232
 
233
+ # video_length = len(pose_tensor_list)
234
+
235
+ # video = pipe(
236
+ # ref_image_pil,
237
+ # pose_list,
238
+ # ref_pose,
239
+ # width,
240
+ # height,
241
+ # video_length,
242
+ # steps,
243
+ # cfg,
244
+ # generator=generator,
245
+ # ).videos
246
+
247
+ # save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
248
+ # save_videos_grid(
249
+ # video,
250
+ # save_path,
251
+ # n_rows=1,
252
+ # fps=fps,
253
+ # )
254
 
255
+ # stream = ffmpeg.input(save_path)
256
+ # audio = ffmpeg.input(input_audio)
257
+ # ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
258
+ # os.remove(save_path)
259
 
260
+ # return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
src/vid2vid.py CHANGED
@@ -26,7 +26,7 @@ from src.utils.pose_util import project_points_with_trans, matrix_to_euler_and_t
26
  from src.audio2vid import smooth_pose_seq
27
  from src.utils.crop_face_single import crop_face
28
 
29
- @spaces.GPU
30
  def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
31
  cfg = 3.5
32
 
 
26
  from src.audio2vid import smooth_pose_seq
27
  from src.utils.crop_face_single import crop_face
28
 
29
+ # @spaces.GPU(duration=150)
30
  def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
31
  cfg = 3.5
32