Upload 7 files
- scripts/app.py +51 -0
- scripts/data_preprocess.py +191 -0
- scripts/extract_meta_info_stage1.py +106 -0
- scripts/extract_meta_info_stage2.py +192 -0
- scripts/inference.py +376 -0
- scripts/train_stage1.py +793 -0
- scripts/train_stage2.py +991 -0
scripts/app.py
ADDED
@@ -0,0 +1,51 @@
"""
This script provides a Gradio web UI for running inference.

It takes an image and an audio clip, and lets you configure the inference
variables such as cfg_scale, pose_weight, face_weight, lip_weight, etc.

Usage:
    This script can be run from the command line with the following command:

    python scripts/app.py
"""
import argparse

import gradio as gr
from inference import inference_process


def predict(image, audio, pose_weight, face_weight, lip_weight, face_expand_ratio, progress=gr.Progress(track_tqdm=True)):
    """
    Build the inference config from the UI inputs and run the inference process.
    """
    _ = progress
    config = {
        'source_image': image,
        'driving_audio': audio,
        'pose_weight': pose_weight,
        'face_weight': face_weight,
        'lip_weight': lip_weight,
        'face_expand_ratio': face_expand_ratio,
        'config': 'configs/inference/default.yaml',
        'checkpoint': None,
        'output': ".cache/output.mp4"
    }
    args = argparse.Namespace()
    for key, value in config.items():
        setattr(args, key, value)
    return inference_process(args)


app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="source image (no webp)", type="filepath", format="jpeg"),
        gr.Audio(label="source audio", type="filepath"),
        gr.Number(label="pose weight", value=1.0),
        gr.Number(label="face weight", value=1.0),
        gr.Number(label="lip weight", value=1.0),
        gr.Number(label="face expand ratio", value=1.2),
    ],
    outputs=[gr.Video()],
)
app.launch()
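Note: the Gradio callback above is only a thin wrapper around inference_process. A minimal sketch of driving the same entry point without the UI, assuming the default config and checkpoint layout referenced by configs/inference/default.yaml are in place (the input paths below are placeholders):

    import argparse

    # As in app.py; run from inside scripts/, or adjust the import to your package layout.
    from inference import inference_process

    # Mirror the dict built inside predict(); every path here is illustrative.
    args = argparse.Namespace(
        source_image="examples/source.jpg",
        driving_audio="examples/driving.wav",
        pose_weight=1.0,
        face_weight=1.0,
        lip_weight=1.0,
        face_expand_ratio=1.2,
        config="configs/inference/default.yaml",
        checkpoint=None,
        output=".cache/output.mp4",
    )
    output_path = inference_process(args)  # returns the path of the rendered video
    print(output_path)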
scripts/data_preprocess.py
ADDED
@@ -0,0 +1,191 @@
# pylint: disable=W1203,W0718
"""
This module is used to process videos to prepare data for training. It utilizes various libraries and models
to perform tasks such as video frame extraction, audio extraction, face mask generation, and face embedding extraction.
The script takes command-line arguments to specify the input and output directories, the processing step,
the level of parallelism, and the rank for distributed processing.

Usage:
    python -m scripts.data_preprocess --input_dir /path/to/video_dir --step 1 --parallelism 4 --rank 0

Example:
    python -m scripts.data_preprocess -i data/videos -o data/output -s 1 -p 4 -r 0
"""
import argparse
import logging
import os
from pathlib import Path
from typing import List

import cv2
import torch
from tqdm import tqdm

from hallo.datasets.audio_processor import AudioProcessor
from hallo.datasets.image_processor import ImageProcessorForDataProcessing
from hallo.utils.util import convert_video_to_images, extract_audio_from_videos

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')


def setup_directories(video_path: Path) -> dict:
    """
    Setup directories for storing processed files.

    Args:
        video_path (Path): Path to the video file.

    Returns:
        dict: A dictionary containing paths for various directories.
    """
    base_dir = video_path.parent.parent
    dirs = {
        "face_mask": base_dir / "face_mask",
        "sep_pose_mask": base_dir / "sep_pose_mask",
        "sep_face_mask": base_dir / "sep_face_mask",
        "sep_lip_mask": base_dir / "sep_lip_mask",
        "face_emb": base_dir / "face_emb",
        "audio_emb": base_dir / "audio_emb"
    }

    for path in dirs.values():
        path.mkdir(parents=True, exist_ok=True)

    return dirs


def process_single_video(video_path: Path,
                         output_dir: Path,
                         image_processor: ImageProcessorForDataProcessing,
                         audio_processor: AudioProcessor,
                         step: int) -> None:
    """
    Process a single video file.

    Args:
        video_path (Path): Path to the video file.
        output_dir (Path): Directory to save the output.
        image_processor (ImageProcessorForDataProcessing): Image processor object.
        audio_processor (AudioProcessor): Audio processor object.
        step (int): Processing step; 1 extracts frames, audio and masks, 2 extracts face and audio embeddings.
    """
    assert video_path.exists(), f"Video path {video_path} does not exist"
    dirs = setup_directories(video_path)
    logging.info(f"Processing video: {video_path}")

    try:
        if step == 1:
            images_output_dir = output_dir / 'images' / video_path.stem
            images_output_dir.mkdir(parents=True, exist_ok=True)
            images_output_dir = convert_video_to_images(
                video_path, images_output_dir)
            logging.info(f"Images saved to: {images_output_dir}")

            audio_output_dir = output_dir / 'audios'
            audio_output_dir.mkdir(parents=True, exist_ok=True)
            audio_output_path = audio_output_dir / f'{video_path.stem}.wav'
            audio_output_path = extract_audio_from_videos(
                video_path, audio_output_path)
            logging.info(f"Audio extracted to: {audio_output_path}")

            face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = image_processor.preprocess(
                images_output_dir)
            cv2.imwrite(
                str(dirs["face_mask"] / f"{video_path.stem}.png"), face_mask)
            cv2.imwrite(str(dirs["sep_pose_mask"] /
                            f"{video_path.stem}.png"), sep_pose_mask)
            cv2.imwrite(str(dirs["sep_face_mask"] /
                            f"{video_path.stem}.png"), sep_face_mask)
            cv2.imwrite(str(dirs["sep_lip_mask"] /
                            f"{video_path.stem}.png"), sep_lip_mask)
        else:
            images_dir = output_dir / "images" / video_path.stem
            audio_path = output_dir / "audios" / f"{video_path.stem}.wav"
            _, face_emb, _, _, _ = image_processor.preprocess(images_dir)
            torch.save(face_emb, str(
                dirs["face_emb"] / f"{video_path.stem}.pt"))
            audio_emb, _ = audio_processor.preprocess(audio_path)
            torch.save(audio_emb, str(
                dirs["audio_emb"] / f"{video_path.stem}.pt"))
    except Exception as e:
        logging.error(f"Failed to process video {video_path}: {e}")


def process_all_videos(input_video_list: List[Path], output_dir: Path, step: int) -> None:
    """
    Process all videos in the input list.

    Args:
        input_video_list (List[Path]): List of video paths to process.
        output_dir (Path): Directory to save the output.
        step (int): Processing step (1 or 2).
    """
    face_analysis_model_path = "pretrained_models/face_analysis"
    landmark_model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
    audio_separator_model_file = "pretrained_models/audio_separator/Kim_Vocal_2.onnx"
    wav2vec_model_path = 'pretrained_models/wav2vec/wav2vec2-base-960h'

    audio_processor = AudioProcessor(
        16000,
        25,
        wav2vec_model_path,
        False,
        os.path.dirname(audio_separator_model_file),
        os.path.basename(audio_separator_model_file),
        os.path.join(output_dir, "vocals"),
    ) if step == 2 else None

    image_processor = ImageProcessorForDataProcessing(
        face_analysis_model_path, landmark_model_path, step)

    for video_path in tqdm(input_video_list, desc="Processing videos"):
        process_single_video(video_path, output_dir,
                             image_processor, audio_processor, step)


def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]:
    """
    Get paths of videos to process, partitioned for parallel processing.

    Args:
        source_dir (Path): Source directory containing videos.
        parallelism (int): Level of parallelism.
        rank (int): Rank for distributed processing.

    Returns:
        List[Path]: List of video paths to process.
    """
    video_paths = [item for item in sorted(
        source_dir.iterdir()) if item.is_file() and item.suffix == '.mp4']
    return [video_paths[i] for i in range(len(video_paths)) if i % parallelism == rank]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Process videos to prepare data for training. Run this script twice, first with --step 1 and then with --step 2."
    )
    parser.add_argument("-i", "--input_dir", type=Path,
                        required=True, help="Directory containing videos")
    parser.add_argument("-o", "--output_dir", type=Path,
                        help="Directory to save results, default is parent dir of input dir")
    parser.add_argument("-s", "--step", type=int, default=1,
                        help="Specify data processing step 1 or 2; run step 1 and step 2 sequentially")
    parser.add_argument("-p", "--parallelism", default=1,
                        type=int, help="Level of parallelism")
    parser.add_argument("-r", "--rank", default=0, type=int,
                        help="Rank for distributed processing")

    args = parser.parse_args()

    if args.output_dir is None:
        args.output_dir = args.input_dir.parent

    video_path_list = get_video_paths(
        args.input_dir, args.parallelism, args.rank)

    if not video_path_list:
        logging.warning("No videos to process.")
    else:
        process_all_videos(video_path_list, args.output_dir, args.step)
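Note: the --parallelism / --rank pair simply splits the sorted video list by index modulo rank, so several workers can run side by side. A small sketch of a hypothetical driver (not part of this diff) that launches every rank of step 1 from one machine:

    import subprocess

    PARALLELISM = 4  # illustrative worker count; each rank handles every 4th video

    procs = [
        subprocess.Popen([
            "python", "-m", "scripts.data_preprocess",
            "-i", "data/videos",
            "-s", "1",
            "-p", str(PARALLELISM),
            "-r", str(rank),
        ])
        for rank in range(PARALLELISM)
    ]
    for proc in procs:
        proc.wait()  # step 2 (-s 2) should only start after every rank finishes step 1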
scripts/extract_meta_info_stage1.py
ADDED
@@ -0,0 +1,106 @@
# pylint: disable=R0801
"""
This module is used to extract meta information from video directories.

It takes in two command-line arguments: `root_path` and `dataset_name`. The `root_path`
specifies the path to the video directory, while the `dataset_name` specifies the name
of the dataset. The module then collects all the video folder paths, and for each video
folder, it checks if a mask path and a face embedding path exist. If they do, it appends
a dictionary containing the image path, mask path, and face embedding path to a list.

Finally, the module writes the list of dictionaries to a JSON file with the filename
constructed using the `dataset_name`.

Usage:
    python scripts/extract_meta_info_stage1.py --root_path /path/to/video_dir --dataset_name hdtf
"""

import argparse
import json
import os
from pathlib import Path

import torch


def collect_video_folder_paths(root_path: Path) -> list:
    """
    Collect all video folder paths from the root path.

    Args:
        root_path (Path): The root directory containing video folders.

    Returns:
        list: List of video folder paths.
    """
    return [frames_dir.resolve() for frames_dir in root_path.iterdir() if frames_dir.is_dir()]


def construct_meta_info(frames_dir_path: Path) -> dict:
    """
    Construct meta information for a given frames directory.

    Args:
        frames_dir_path (Path): The path to the frames directory.

    Returns:
        dict: A dictionary containing the meta information for the frames directory, or None if the required files do not exist.
    """
    mask_path = str(frames_dir_path).replace("images", "face_mask") + ".png"
    face_emb_path = str(frames_dir_path).replace("images", "face_emb") + ".pt"

    if not os.path.exists(mask_path):
        print(f"Mask path not found: {mask_path}")
        return None

    if torch.load(face_emb_path) is None:
        print(f"Face emb is None: {face_emb_path}")
        return None

    return {
        "image_path": str(frames_dir_path),
        "mask_path": mask_path,
        "face_emb": face_emb_path,
    }


def main():
    """
    Main function to extract meta info for training.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--root_path", type=str,
                        required=True, help="Root path of the video directories")
    parser.add_argument("-n", "--dataset_name", type=str,
                        required=True, help="Name of the dataset")
    parser.add_argument("--meta_info_name", type=str,
                        help="Name of the meta information file")

    args = parser.parse_args()

    if args.meta_info_name is None:
        args.meta_info_name = args.dataset_name

    image_dir = Path(args.root_path) / "images"
    output_dir = Path("./data")
    output_dir.mkdir(exist_ok=True)

    # Collect all video folder paths
    frames_dir_paths = collect_video_folder_paths(image_dir)

    meta_infos = []
    for frames_dir_path in frames_dir_paths:
        meta_info = construct_meta_info(frames_dir_path)
        if meta_info:
            meta_infos.append(meta_info)

    output_file = output_dir / f"{args.meta_info_name}_stage1.json"
    with output_file.open("w", encoding="utf-8") as f:
        json.dump(meta_infos, f, indent=4)

    print(f"Final data count: {len(meta_infos)}")


if __name__ == "__main__":
    main()
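Note: each entry written to data/<dataset_name>_stage1.json carries three keys: image_path, mask_path, and face_emb. A short sketch of consuming that file (the hdtf name comes from the usage example above; swap in your own dataset name):

    import json

    # Path follows the f"{meta_info_name}_stage1.json" pattern used by main().
    with open("data/hdtf_stage1.json", encoding="utf-8") as f:
        meta_infos = json.load(f)

    print(len(meta_infos), "samples")
    sample = meta_infos[0]
    print(sample["image_path"])  # frames directory for one video
    print(sample["mask_path"])   # face mask PNG
    print(sample["face_emb"])    # face embedding .pt file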
scripts/extract_meta_info_stage2.py
ADDED
@@ -0,0 +1,192 @@
# pylint: disable=R0801
"""
This module is used to extract meta information from video files and store it in a JSON file.

The script takes command line arguments to specify the root path of the video files,
the dataset name, and the name of the meta information file. It then generates a list of
dictionaries containing the meta information for each video file and writes it to a JSON
file with the specified name.

The meta information for each clip includes the path to the video file, the face mask path,
the separate pose/face/lip mask paths, the face embedding path, the audio path, and the
audio (vocals) embedding path.

The script checks that every required file exists, and that the video frame count matches
the audio embedding length within a three-frame tolerance, before adding an entry to the list.

Usage:
    python scripts/extract_meta_info_stage2.py --root_path <root_path> --dataset_name <dataset_name> --meta_info_name <meta_info_name>

Example:
    python scripts/extract_meta_info_stage2.py --root_path data/videos_25fps --dataset_name my_dataset --meta_info_name my_meta_info
"""

import argparse
import json
import os
from pathlib import Path

import torch
from decord import VideoReader, cpu
from tqdm import tqdm


def get_video_paths(root_path: Path, extensions: list) -> list:
    """
    Get a list of video paths from the root path with the specified extensions.

    Args:
        root_path (Path): The root directory containing video files.
        extensions (list): List of file extensions to include.

    Returns:
        list: List of video file paths.
    """
    return [str(path.resolve()) for path in root_path.iterdir() if path.suffix in extensions]


def file_exists(file_path: str) -> bool:
    """
    Check if a file exists.

    Args:
        file_path (str): The path to the file.

    Returns:
        bool: True if the file exists, False otherwise.
    """
    return os.path.exists(file_path)


def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ext: str) -> str:
    """
    Construct a new path by replacing the base directory and extension in the original path.

    Args:
        video_path (str): The original video path.
        base_dir (str): The base directory to be replaced.
        new_dir (str): The new directory to replace the base directory.
        new_ext (str): The new file extension.

    Returns:
        str: The constructed path.
    """
    return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext)


def extract_meta_info(video_path: str) -> dict:
    """
    Extract meta information for a given video file.

    Args:
        video_path (str): The path to the video file.

    Returns:
        dict: A dictionary containing the meta information for the video.
    """
    mask_path = construct_paths(
        video_path, "videos", "face_mask", ".png")
    sep_mask_border = construct_paths(
        video_path, "videos", "sep_pose_mask", ".png")
    sep_mask_face = construct_paths(
        video_path, "videos", "sep_face_mask", ".png")
    sep_mask_lip = construct_paths(
        video_path, "videos", "sep_lip_mask", ".png")
    face_emb_path = construct_paths(
        video_path, "videos", "face_emb", ".pt")
    audio_path = construct_paths(video_path, "videos", "audios", ".wav")
    vocal_emb_base_all = construct_paths(
        video_path, "videos", "audio_emb", ".pt")

    assert_flag = True

    if not file_exists(mask_path):
        print(f"Mask path not found: {mask_path}")
        assert_flag = False
    if not file_exists(sep_mask_border):
        print(f"Separate mask border not found: {sep_mask_border}")
        assert_flag = False
    if not file_exists(sep_mask_face):
        print(f"Separate mask face not found: {sep_mask_face}")
        assert_flag = False
    if not file_exists(sep_mask_lip):
        print(f"Separate mask lip not found: {sep_mask_lip}")
        assert_flag = False
    if not file_exists(face_emb_path):
        print(f"Face embedding path not found: {face_emb_path}")
        assert_flag = False
    if not file_exists(audio_path):
        print(f"Audio path not found: {audio_path}")
        assert_flag = False
    if not file_exists(vocal_emb_base_all):
        print(f"Vocal embedding base all not found: {vocal_emb_base_all}")
        assert_flag = False

    video_frames = VideoReader(video_path, ctx=cpu(0))
    audio_emb = torch.load(vocal_emb_base_all)
    if abs(len(video_frames) - audio_emb.shape[0]) > 3:
        print(f"Frame count mismatch for video: {video_path}")
        assert_flag = False

    face_emb = torch.load(face_emb_path)
    if face_emb is None:
        print(f"Face embedding is None for video: {video_path}")
        assert_flag = False

    del video_frames, audio_emb

    if assert_flag:
        return {
            "video_path": str(video_path),
            "mask_path": mask_path,
            "sep_mask_border": sep_mask_border,
            "sep_mask_face": sep_mask_face,
            "sep_mask_lip": sep_mask_lip,
            "face_emb_path": face_emb_path,
            "audio_path": audio_path,
            "vocals_emb_base_all": vocal_emb_base_all,
        }
    return None


def main():
    """
    Main function to extract meta info for training.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--root_path", type=str,
                        required=True, help="Root path of the video files")
    parser.add_argument("-n", "--dataset_name", type=str,
                        required=True, help="Name of the dataset")
    parser.add_argument("--meta_info_name", type=str,
                        help="Name of the meta information file")

    args = parser.parse_args()

    if args.meta_info_name is None:
        args.meta_info_name = args.dataset_name

    video_dir = Path(args.root_path) / "videos"
    video_paths = get_video_paths(video_dir, [".mp4"])

    meta_infos = []

    for video_path in tqdm(video_paths, desc="Extracting meta info"):
        meta_info = extract_meta_info(video_path)
        if meta_info:
            meta_infos.append(meta_info)

    print(f"Final data count: {len(meta_infos)}")

    output_file = Path(f"./data/{args.meta_info_name}_stage2.json")
    output_file.parent.mkdir(parents=True, exist_ok=True)

    with output_file.open("w", encoding="utf-8") as f:
        json.dump(meta_infos, f, indent=4)


if __name__ == "__main__":
    main()
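Note: the frame-count check above accepts up to a three-frame difference between the decoded video and its audio embedding. A standalone sketch of the same consistency check for a single clip, assuming the directory layout produced by scripts/data_preprocess.py (the paths below are placeholders):

    import torch
    from decord import VideoReader, cpu

    video_path = "data/videos/clip_0001.mp4"      # illustrative paths
    audio_emb_path = "data/audio_emb/clip_0001.pt"

    num_frames = len(VideoReader(video_path, ctx=cpu(0)))
    num_audio_steps = torch.load(audio_emb_path).shape[0]

    # Mirrors the tolerance used in extract_meta_info(): small offsets are accepted.
    if abs(num_frames - num_audio_steps) > 3:
        print(f"Frame/audio mismatch: {num_frames} frames vs {num_audio_steps} audio steps")
    else:
        print("Clip is consistent")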
scripts/inference.py
ADDED
@@ -0,0 +1,376 @@
# pylint: disable=E1101
# scripts/inference.py

"""
This script contains the main inference pipeline for processing audio and image inputs to generate a video output.

The script imports necessary packages and classes, defines a neural network model,
and contains functions for processing audio embeddings and performing inference.

The main inference process is outlined in the following steps:
1. Initialize the configuration.
2. Set up runtime variables.
3. Prepare the input data for inference (source image, face mask, and face embeddings).
4. Process the audio embeddings.
5. Build and freeze the model and scheduler.
6. Run the inference loop and save the result.

Usage:
    This script can be run from the command line with the following arguments:
    - config: Path to the inference configuration YAML (defaults to configs/inference/default.yaml).
    - source_image: Path to the source image.
    - driving_audio: Path to the driving audio file.
    - output: Path to save the output video.
    - pose_weight, face_weight, lip_weight: Motion weights for pose, face and lip.
    - face_expand_ratio: Expansion ratio of the face region.
    - audio_ckpt_dir (alias --checkpoint): Directory containing the trained checkpoint (net.pth).

Example:
    python scripts/inference.py --source_image image.jpg --driving_audio audio.wav --output output.mp4
"""

import argparse
import os

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from torch import nn

from hallo.animate.face_animate import FaceAnimatePipeline
from hallo.datasets.audio_processor import AudioProcessor
from hallo.datasets.image_processor import ImageProcessor
from hallo.models.audio_proj import AudioProjModel
from hallo.models.face_locator import FaceLocator
from hallo.models.image_proj import ImageProjModel
from hallo.models.unet_2d_condition import UNet2DConditionModel
from hallo.models.unet_3d import UNet3DConditionModel
from hallo.utils.config import filter_non_none
from hallo.utils.util import tensor_to_video


class Net(nn.Module):
    """
    The Net class combines all the necessary modules for the inference process.

    Args:
        reference_unet (UNet2DConditionModel): The UNet2DConditionModel used as a reference for inference.
        denoising_unet (UNet3DConditionModel): The UNet3DConditionModel used to denoise the video latents.
        face_locator (FaceLocator): The FaceLocator model used to locate the face in the input image.
        imageproj (nn.Module): The model used to project the source image face embedding.
        audioproj (nn.Module): The model used to project the audio embeddings.
    """
    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        face_locator: FaceLocator,
        imageproj,
        audioproj,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.face_locator = face_locator
        self.imageproj = imageproj
        self.audioproj = audioproj

    def forward(self,):
        """
        Empty function to override the abstract method of nn.Module.
        """

    def get_modules(self):
        """
        Simple method to avoid the too-few-public-methods pylint error.
        """
        return {
            "reference_unet": self.reference_unet,
            "denoising_unet": self.denoising_unet,
            "face_locator": self.face_locator,
            "imageproj": self.imageproj,
            "audioproj": self.audioproj,
        }


def process_audio_emb(audio_emb):
    """
    Give each frame's audio embedding a five-step temporal context window.

    Parameters:
        audio_emb (torch.Tensor): The audio embedding tensor to process.

    Returns:
        torch.Tensor: Tensor where each frame is stacked with its two neighbours on either side.
    """
    concatenated_tensors = []

    for i in range(audio_emb.shape[0]):
        vectors_to_concat = [
            audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
        concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))

    audio_emb = torch.stack(concatenated_tensors, dim=0)

    return audio_emb


def inference_process(args: argparse.Namespace):
    """
    Perform inference processing.

    Args:
        args (argparse.Namespace): Command-line arguments.

    This function initializes the configuration for the inference process. It sets up the necessary
    modules and variables to prepare for the upcoming inference steps.
    """
    # 1. init config
    cli_args = filter_non_none(vars(args))
    config = OmegaConf.load(args.config)
    config = OmegaConf.merge(config, cli_args)
    source_image_path = config.source_image
    driving_audio_path = config.driving_audio
    save_path = config.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]

    # 2. runtime variables
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif config.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif config.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        weight_dtype = torch.float32

    # 3. prepare inference data
    # 3.1 prepare source image, face mask, face embeddings
    img_size = (config.data.source_image.width,
                config.data.source_image.height)
    clip_length = config.data.n_sample_frames
    face_analysis_model_path = config.face_analysis.model_path
    with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
        source_image_pixels, \
            source_image_face_region, \
            source_image_face_emb, \
            source_image_full_mask, \
            source_image_face_mask, \
            source_image_lip_mask = image_processor.preprocess(
                source_image_path, save_path, config.face_expand_ratio)

    # 3.2 prepare audio embeddings
    sample_rate = config.data.driving_audio.sample_rate
    assert sample_rate == 16000, "audio sample rate must be 16000"
    fps = config.data.export_video.fps
    wav2vec_model_path = config.wav2vec.model_path
    wav2vec_only_last_features = config.wav2vec.features == "last"
    audio_separator_model_file = config.audio_separator.model_path
    with AudioProcessor(
        sample_rate,
        fps,
        wav2vec_model_path,
        wav2vec_only_last_features,
        os.path.dirname(audio_separator_model_file),
        os.path.basename(audio_separator_model_file),
        os.path.join(save_path, "audio_preprocess")
    ) as audio_processor:
        audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length)

    # 4. build modules
    sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
    if config.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})

    vae = AutoencoderKL.from_pretrained(config.vae.model_path)
    reference_unet = UNet2DConditionModel.from_pretrained(
        config.base_model_path, subfolder="unet")
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            config.unet_additional_kwargs),
        use_landmark=False,
    )
    face_locator = FaceLocator(conditioning_embedding_channels=320)
    image_proj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    )

    audio_proj = AudioProjModel(
        seq_len=5,
        blocks=12,  # use 12 layers' hidden states of wav2vec
        channels=768,  # audio embedding channel
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
    ).to(device=device, dtype=weight_dtype)

    audio_ckpt_dir = config.audio_ckpt_dir

    # Freeze
    vae.requires_grad_(False)
    image_proj.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    face_locator.requires_grad_(False)
    audio_proj.requires_grad_(False)

    reference_unet.enable_gradient_checkpointing()
    denoising_unet.enable_gradient_checkpointing()

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        image_proj,
        audio_proj,
    )

    m, u = net.load_state_dict(
        torch.load(
            os.path.join(audio_ckpt_dir, "net.pth"),
            map_location="cpu",
        ),
    )
    assert len(m) == 0 and len(u) == 0, "Fail to load correct checkpoint."
    print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth"))

    # 5. inference
    pipeline = FaceAnimatePipeline(
        vae=vae,
        reference_unet=net.reference_unet,
        denoising_unet=net.denoising_unet,
        face_locator=net.face_locator,
        scheduler=val_noise_scheduler,
        image_proj=net.imageproj,
    )
    pipeline.to(device=device, dtype=weight_dtype)

    audio_emb = process_audio_emb(audio_emb)

    source_image_pixels = source_image_pixels.unsqueeze(0)
    source_image_face_region = source_image_face_region.unsqueeze(0)
    source_image_face_emb = source_image_face_emb.reshape(1, -1)
    source_image_face_emb = torch.tensor(source_image_face_emb)

    source_image_full_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_full_mask
    ]
    source_image_face_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_face_mask
    ]
    source_image_lip_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_lip_mask
    ]

    times = audio_emb.shape[0] // clip_length

    tensor_result = []

    generator = torch.manual_seed(42)

    for t in range(times):
        print(f"[{t+1}/{times}]")

        if len(tensor_result) == 0:
            # The first iteration
            motion_zeros = source_image_pixels.repeat(
                config.data.n_motion_frames, 1, 1, 1)
            motion_zeros = motion_zeros.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_zeros], dim=0)  # concat the ref image and the first motion frames
        else:
            motion_frames = tensor_result[-1][0]
            motion_frames = motion_frames.permute(1, 0, 2, 3)
            motion_frames = motion_frames[0 - config.data.n_motion_frames:]
            motion_frames = motion_frames * 2.0 - 1.0
            motion_frames = motion_frames.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_frames], dim=0)  # concat the ref image and the motion frames

        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)

        audio_tensor = audio_emb[
            t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
        ]
        audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.to(
            device=net.audioproj.device, dtype=net.audioproj.dtype)
        audio_tensor = net.audioproj(audio_tensor)

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            face_emb=source_image_face_emb,
            face_mask=source_image_face_region,
            pixel_values_full_mask=source_image_full_mask,
            pixel_values_face_mask=source_image_face_mask,
            pixel_values_lip_mask=source_image_lip_mask,
            width=img_size[0],
            height=img_size[1],
            video_length=clip_length,
            num_inference_steps=config.inference_steps,
            guidance_scale=config.cfg_scale,
            generator=generator,
            motion_scale=motion_scale,
        )

        tensor_result.append(pipeline_output.videos)

    tensor_result = torch.cat(tensor_result, dim=2)
    tensor_result = tensor_result.squeeze(0)
    tensor_result = tensor_result[:, :audio_length]

    output_file = config.output
    # save the result after all iterations
    tensor_to_video(tensor_result, output_file, driving_audio_path)
    return output_file


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-c", "--config", default="configs/inference/default.yaml")
    parser.add_argument("--source_image", type=str, required=False,
                        help="source image")
    parser.add_argument("--driving_audio", type=str, required=False,
                        help="driving audio")
    parser.add_argument(
        "--output", type=str, help="output video file name", default=".cache/output.mp4")
    parser.add_argument(
        "--pose_weight", type=float, help="weight of pose", required=False)
    parser.add_argument(
        "--face_weight", type=float, help="weight of face", required=False)
    parser.add_argument(
        "--lip_weight", type=float, help="weight of lip", required=False)
    parser.add_argument(
        "--face_expand_ratio", type=float, help="face region", required=False)
    parser.add_argument(
        "--audio_ckpt_dir", "--checkpoint", type=str, help="specific checkpoint dir", required=False)

    command_line_args = parser.parse_args()

    inference_process(command_line_args)
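Note: process_audio_emb gives every frame a five-step context window: each row is stacked with its two neighbours on either side, clamped at the clip boundaries, which matches the seq_len=5 expected by AudioProjModel. A tiny worked example with a dummy one-dimensional embedding per frame:

    import torch

    from scripts.inference import process_audio_emb  # adjust the import path to your layout

    dummy = torch.arange(4, dtype=torch.float32).unsqueeze(-1)  # 4 "frames", embedding dim 1
    windowed = process_audio_emb(dummy)

    print(windowed.shape)   # torch.Size([4, 5, 1])
    print(windowed[:, :, 0])
    # tensor([[0., 0., 0., 1., 2.],
    #         [0., 0., 1., 2., 3.],
    #         [0., 1., 2., 3., 3.],
    #         [1., 2., 3., 3., 3.]])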
scripts/train_stage1.py
ADDED
@@ -0,0 +1,793 @@
1 |
+
# pylint: disable=E1101,C0415,W0718,R0801
|
2 |
+
# scripts/train_stage1.py
|
3 |
+
"""
|
4 |
+
This is the main training script for stage 1 of the project.
|
5 |
+
It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration.
|
6 |
+
|
7 |
+
The script includes the following classes and functions:
|
8 |
+
|
9 |
+
1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings,
|
10 |
+
and face masks as input and returns the denoised latents.
|
11 |
+
3. log_validation: A function that logs the validation information using the given VAE, image encoder,
|
12 |
+
network, scheduler, accelerator, width, height, and configuration.
|
13 |
+
4. train_stage1_process: A function that processes the training stage 1 using the given configuration.
|
14 |
+
|
15 |
+
The script also includes the necessary imports and a brief description of the purpose of the file.
|
16 |
+
"""
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import copy
|
20 |
+
import logging
|
21 |
+
import math
|
22 |
+
import os
|
23 |
+
import random
|
24 |
+
import warnings
|
25 |
+
from datetime import datetime
|
26 |
+
|
27 |
+
import cv2
|
28 |
+
import diffusers
|
29 |
+
import mlflow
|
30 |
+
import numpy as np
|
31 |
+
import torch
|
32 |
+
import torch.nn.functional as F
|
33 |
+
import torch.utils.checkpoint
|
34 |
+
import transformers
|
35 |
+
from accelerate import Accelerator
|
36 |
+
from accelerate.logging import get_logger
|
37 |
+
from accelerate.utils import DistributedDataParallelKwargs
|
38 |
+
from diffusers import AutoencoderKL, DDIMScheduler
|
39 |
+
from diffusers.optimization import get_scheduler
|
40 |
+
from diffusers.utils import check_min_version
|
41 |
+
from diffusers.utils.import_utils import is_xformers_available
|
42 |
+
from insightface.app import FaceAnalysis
|
43 |
+
from omegaconf import OmegaConf
|
44 |
+
from PIL import Image
|
45 |
+
from torch import nn
|
46 |
+
from tqdm.auto import tqdm
|
47 |
+
|
48 |
+
from hallo.animate.face_animate_static import StaticPipeline
|
49 |
+
from hallo.datasets.mask_image import FaceMaskDataset
|
50 |
+
from hallo.models.face_locator import FaceLocator
|
51 |
+
from hallo.models.image_proj import ImageProjModel
|
52 |
+
from hallo.models.mutual_self_attention import ReferenceAttentionControl
|
53 |
+
from hallo.models.unet_2d_condition import UNet2DConditionModel
|
54 |
+
from hallo.models.unet_3d import UNet3DConditionModel
|
55 |
+
from hallo.utils.util import (compute_snr, delete_additional_ckpt,
|
56 |
+
import_filename, init_output_dir,
|
57 |
+
load_checkpoint, move_final_checkpoint,
|
58 |
+
save_checkpoint, seed_everything)
|
59 |
+
|
60 |
+
warnings.filterwarnings("ignore")
|
61 |
+
|
62 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
63 |
+
check_min_version("0.10.0.dev0")
|
64 |
+
|
65 |
+
logger = get_logger(__name__, log_level="INFO")
|
66 |
+
|
67 |
+
|
68 |
+
class Net(nn.Module):
|
69 |
+
"""
|
70 |
+
The Net class defines a neural network model that combines a reference UNet2DConditionModel,
|
71 |
+
a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation.
|
75 |
+
denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation.
|
76 |
+
face_locator (FaceLocator): The face locator model used for face animation.
|
77 |
+
reference_control_writer: The reference control writer component.
|
78 |
+
reference_control_reader: The reference control reader component.
|
79 |
+
imageproj: The image projection model.
|
80 |
+
|
81 |
+
Forward method:
|
82 |
+
noisy_latents (torch.Tensor): The noisy latents tensor.
|
83 |
+
timesteps (torch.Tensor): The timesteps tensor.
|
84 |
+
ref_image_latents (torch.Tensor): The reference image latents tensor.
|
85 |
+
face_emb (torch.Tensor): The face embeddings tensor.
|
86 |
+
face_mask (torch.Tensor): The face mask tensor.
|
87 |
+
uncond_fwd (bool): A flag indicating whether to perform unconditional forward pass.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
torch.Tensor: The output tensor of the neural network model.
|
91 |
+
"""
|
92 |
+
|
93 |
+
def __init__(
|
94 |
+
self,
|
95 |
+
reference_unet: UNet2DConditionModel,
|
96 |
+
denoising_unet: UNet3DConditionModel,
|
97 |
+
face_locator: FaceLocator,
|
98 |
+
reference_control_writer: ReferenceAttentionControl,
|
99 |
+
reference_control_reader: ReferenceAttentionControl,
|
100 |
+
imageproj: ImageProjModel,
|
101 |
+
):
|
102 |
+
super().__init__()
|
103 |
+
self.reference_unet = reference_unet
|
104 |
+
self.denoising_unet = denoising_unet
|
105 |
+
self.face_locator = face_locator
|
106 |
+
self.reference_control_writer = reference_control_writer
|
107 |
+
self.reference_control_reader = reference_control_reader
|
108 |
+
self.imageproj = imageproj
|
109 |
+
|
110 |
+
def forward(
|
111 |
+
self,
|
112 |
+
noisy_latents,
|
113 |
+
timesteps,
|
114 |
+
ref_image_latents,
|
115 |
+
face_emb,
|
116 |
+
face_mask,
|
117 |
+
uncond_fwd: bool = False,
|
118 |
+
):
|
119 |
+
"""
|
120 |
+
Forward pass of the model.
|
121 |
+
Args:
|
122 |
+
self (Net): The model instance.
|
123 |
+
noisy_latents (torch.Tensor): Noisy latents.
|
124 |
+
timesteps (torch.Tensor): Timesteps.
|
125 |
+
ref_image_latents (torch.Tensor): Reference image latents.
|
126 |
+
face_emb (torch.Tensor): Face embedding.
|
127 |
+
face_mask (torch.Tensor): Face mask.
|
128 |
+
uncond_fwd (bool, optional): Unconditional forward pass. Defaults to False.
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
torch.Tensor: Model prediction.
|
132 |
+
"""
|
133 |
+
|
134 |
+
face_emb = self.imageproj(face_emb)
|
135 |
+
face_mask = face_mask.to(device="cuda")
|
136 |
+
face_mask_feature = self.face_locator(face_mask)
|
137 |
+
|
138 |
+
if not uncond_fwd:
|
139 |
+
ref_timesteps = torch.zeros_like(timesteps)
|
140 |
+
self.reference_unet(
|
141 |
+
ref_image_latents,
|
142 |
+
ref_timesteps,
|
143 |
+
encoder_hidden_states=face_emb,
|
144 |
+
return_dict=False,
|
145 |
+
)
|
146 |
+
self.reference_control_reader.update(self.reference_control_writer)
|
147 |
+
model_pred = self.denoising_unet(
|
148 |
+
noisy_latents,
|
149 |
+
timesteps,
|
150 |
+
mask_cond_fea=face_mask_feature,
|
151 |
+
encoder_hidden_states=face_emb,
|
152 |
+
).sample
|
153 |
+
|
154 |
+
return model_pred
|
155 |
+
|
156 |
+
|
157 |
+
def get_noise_scheduler(cfg: argparse.Namespace):
|
158 |
+
"""
|
159 |
+
Create noise scheduler for training
|
160 |
+
|
161 |
+
Args:
|
162 |
+
cfg (omegaconf.dictconfig.DictConfig): Configuration object.
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
train noise scheduler and val noise scheduler
|
166 |
+
"""
|
167 |
+
sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
|
168 |
+
if cfg.enable_zero_snr:
|
169 |
+
sched_kwargs.update(
|
170 |
+
rescale_betas_zero_snr=True,
|
171 |
+
timestep_spacing="trailing",
|
172 |
+
prediction_type="v_prediction",
|
173 |
+
)
|
174 |
+
val_noise_scheduler = DDIMScheduler(**sched_kwargs)
|
175 |
+
sched_kwargs.update({"beta_schedule": "scaled_linear"})
|
176 |
+
train_noise_scheduler = DDIMScheduler(**sched_kwargs)
|
177 |
+
|
178 |
+
return train_noise_scheduler, val_noise_scheduler
|
179 |
+
|
180 |
+
|
181 |
+
def log_validation(
|
182 |
+
vae,
|
183 |
+
net,
|
184 |
+
scheduler,
|
185 |
+
accelerator,
|
186 |
+
width,
|
187 |
+
height,
|
188 |
+
imageproj,
|
189 |
+
cfg,
|
190 |
+
save_dir,
|
191 |
+
global_step,
|
192 |
+
face_analysis_model_path,
|
193 |
+
):
|
194 |
+
"""
|
195 |
+
Log validation generation image.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
vae (nn.Module): Variational Autoencoder model.
|
199 |
+
net (Net): Main model.
|
200 |
+
scheduler (diffusers.SchedulerMixin): Noise scheduler.
|
201 |
+
accelerator (accelerate.Accelerator): Accelerator for training.
|
202 |
+
width (int): Width of the input images.
|
203 |
+
height (int): Height of the input images.
|
204 |
+
imageproj (nn.Module): Image projection model.
|
205 |
+
cfg (omegaconf.dictconfig.DictConfig): Configuration object.
|
206 |
+
save_dir (str): directory path to save log result.
|
207 |
+
global_step (int): Global step number.
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
None
|
211 |
+
"""
|
212 |
+
logger.info("Running validation... ")
|
213 |
+
|
214 |
+
ori_net = accelerator.unwrap_model(net)
|
215 |
+
ori_net = copy.deepcopy(ori_net)
|
216 |
+
reference_unet = ori_net.reference_unet
|
217 |
+
denoising_unet = ori_net.denoising_unet
|
218 |
+
face_locator = ori_net.face_locator
|
219 |
+
|
220 |
+
generator = torch.manual_seed(42)
|
221 |
+
image_enc = FaceAnalysis(
|
222 |
+
name="",
|
223 |
+
root=face_analysis_model_path,
|
224 |
+
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
    )
    image_enc.prepare(ctx_id=0, det_size=(640, 640))

    pipe = StaticPipeline(
        vae=vae,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        face_locator=face_locator,
        scheduler=scheduler,
        imageproj=imageproj,
    )

    pil_images = []
    for ref_image_path, mask_image_path in zip(cfg.ref_image_paths, cfg.mask_image_paths):
        # for mask_image_path in mask_image_paths:
        mask_name = os.path.splitext(
            os.path.basename(mask_image_path))[0]
        ref_name = os.path.splitext(
            os.path.basename(ref_image_path))[0]
        ref_image_pil = Image.open(ref_image_path).convert("RGB")
        mask_image_pil = Image.open(mask_image_path).convert("RGB")

        # Prepare face embeds
        face_info = image_enc.get(
            cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR))
        face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (
            x['bbox'][3] - x['bbox'][1]))[-1]  # only use the maximum face
        face_emb = torch.tensor(face_info['embedding'])
        face_emb = face_emb.to(
            imageproj.device, imageproj.dtype)

        image = pipe(
            ref_image_pil,
            mask_image_pil,
            width,
            height,
            20,
            3.5,
            face_emb,
            generator=generator,
        ).images
        image = image[0, :, 0].permute(1, 2, 0).cpu().numpy()  # (3, 512, 512)
        res_image_pil = Image.fromarray((image * 255).astype(np.uint8))
        # Save ref_image, src_image and the generated_image
        w, h = res_image_pil.size
        canvas = Image.new("RGB", (w * 3, h), "white")
        ref_image_pil = ref_image_pil.resize((w, h))
        mask_image_pil = mask_image_pil.resize((w, h))
        canvas.paste(ref_image_pil, (0, 0))
        canvas.paste(mask_image_pil, (w, 0))
        canvas.paste(res_image_pil, (w * 2, 0))

        out_file = os.path.join(
            save_dir, f"{global_step:06d}-{ref_name}_{mask_name}.jpg"
        )
        canvas.save(out_file)

    del pipe
    del ori_net
    torch.cuda.empty_cache()

    return pil_images


def train_stage1_process(cfg: argparse.Namespace) -> None:
    """
    Trains the model using the given configuration (cfg).

    Args:
        cfg (dict): The configuration dictionary containing the parameters for training.

    Notes:
        - This function trains the model using the given configuration.
        - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler.
        - The training progress is logged and tracked using the accelerator.
        - The trained model is saved after the training is completed.
    """
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,
        log_with="mlflow",
        project_dir="./mlruns",
        kwargs_handlers=[kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if cfg.seed is not None:
        seed_everything(cfg.seed)

    # create output dir for training
    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    checkpoint_dir = os.path.join(save_dir, "checkpoints")
    module_dir = os.path.join(save_dir, "modules")
    validation_dir = os.path.join(save_dir, "validation")

    if accelerator.is_main_process:
        init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir])

    accelerator.wait_for_everyone()

    # create model
    if cfg.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif cfg.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif cfg.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(
            f"Do not support weight dtype: {cfg.weight_dtype} during training"
        )

    # create model
    vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
        "cuda", dtype=weight_dtype
    )
    reference_unet = UNet2DConditionModel.from_pretrained(
        cfg.base_model_path,
        subfolder="unet",
    ).to(device="cuda", dtype=weight_dtype)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        "",
        subfolder="unet",
        unet_additional_kwargs={
            "use_motion_module": False,
            "unet_use_temporal_attention": False,
        },
        use_landmark=False
    ).to(device="cuda", dtype=weight_dtype)
    imageproj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    ).to(device="cuda", dtype=weight_dtype)
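
    # Optionally warm-start the face locator from pretrained weights; cfg.face_state_dict_path
    # is assumed to point to a torch-saved state_dict for the FaceLocator module.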
    if cfg.face_locator_pretrained:
        face_locator = FaceLocator(
            conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
        ).to(device="cuda", dtype=weight_dtype)
        miss, _ = face_locator.load_state_dict(
            torch.load(cfg.face_state_dict_path, map_location="cpu"), strict=False)
        logger.info(f"Missing key for face locator: {len(miss)}")
    else:
        face_locator = FaceLocator(
            conditioning_embedding_channels=320,
        ).to(device="cuda", dtype=weight_dtype)
    # Freeze
    vae.requires_grad_(False)
    denoising_unet.requires_grad_(True)
    reference_unet.requires_grad_(True)
    imageproj.requires_grad_(True)
    face_locator.requires_grad_(True)

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    reference_control_reader = ReferenceAttentionControl(
        denoising_unet,
        do_classifier_free_guidance=False,
        mode="read",
        fusion_blocks="full",
    )

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        reference_control_writer,
        reference_control_reader,
        imageproj,
    ).to(dtype=weight_dtype)

    # get noise scheduler
    train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg)

    # init optimizer
    if cfg.solver.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            reference_unet.enable_xformers_memory_efficient_attention()
            denoising_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if cfg.solver.gradient_checkpointing:
        reference_unet.enable_gradient_checkpointing()
        denoising_unet.enable_gradient_checkpointing()
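
    # With scale_lr enabled, the base learning rate is scaled linearly with the effective
    # batch size (per-device batch size * gradient accumulation steps * number of processes),
    # following the usual linear scaling rule for large-batch training.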
    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    # Initialize the optimizer
    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as exc:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            ) from exc

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    trainable_params = list(
        filter(lambda p: p.requires_grad, net.parameters()))
    optimizer = optimizer_cls(
        trainable_params,
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    # init scheduler
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps
        * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps
        * cfg.solver.gradient_accumulation_steps,
    )

    # get data loader
    train_dataset = FaceMaskDataset(
        img_size=(cfg.data.train_width, cfg.data.train_height),
        data_meta_paths=cfg.data.meta_paths,
        sample_margin=cfg.data.sample_margin,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
    )

    # Prepare everything with our `accelerator`.
    (
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    )
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / cfg.solver.gradient_accumulation_steps
    )
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            cfg.exp_name,
            init_kwargs={"mlflow": {"run_name": run_time}},
        )
        # dump config file
        mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")

        logger.info(f"save config to {save_dir}")
        OmegaConf.save(
            cfg, os.path.join(save_dir, "config.yaml")
        )
    # Train!
    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f"  Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
    )
    logger.info(f"  Total optimization steps = {cfg.solver.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # load checkpoint
    # Potentially load in the weights and states from a previous save
    if cfg.resume_from_checkpoint:
        logger.info(f"Loading checkpoint from {checkpoint_dir}")
        global_step = load_checkpoint(cfg, checkpoint_dir, accelerator)
        first_epoch = global_step // num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_main_process,
    )
    progress_bar.set_description("Steps")
    net.train()
    for _ in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        for _, batch in enumerate(train_dataloader):
            with accelerator.accumulate(net):
                # Convert videos to latent space
                pixel_values = batch["img"].to(weight_dtype)
                with torch.no_grad():
                    latents = vae.encode(pixel_values).latent_dist.sample()
                    latents = latents.unsqueeze(2)  # (b, c, 1, h, w)
                    latents = latents * 0.18215
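                    # 0.18215 is the standard scaling factor for the Stable Diffusion VAE latents;
                    # it keeps the latent magnitudes roughly unit-variance before noise is added.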

                noise = torch.randn_like(latents)
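                # The optional noise offset adds a small constant noise term per (batch, channel),
                # a common trick that helps the model reproduce very dark and very bright images.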
                if cfg.noise_offset > 0.0:
                    noise += cfg.noise_offset * torch.randn(
                        (noise.shape[0], noise.shape[1], 1, 1, 1),
                        device=noise.device,
                    )

                bsz = latents.shape[0]
                # Sample a random timestep for each video
                timesteps = torch.randint(
                    0,
                    train_noise_scheduler.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                )
                timesteps = timesteps.long()

                face_mask_img = batch["tgt_mask"]
                face_mask_img = face_mask_img.unsqueeze(
                    2)
                face_mask_img = face_mask_img.to(weight_dtype)
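
                # Classifier-free guidance training: with probability cfg.uncond_ratio the face
                # embedding condition is replaced by zeros below, so the model also learns an
                # unconditional branch that can be used for guidance at inference time.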
                uncond_fwd = random.random() < cfg.uncond_ratio
                face_emb_list = []
                ref_image_list = []
                for _, (ref_img, face_emb) in enumerate(
                    zip(batch["ref_img"], batch["face_emb"])
                ):
                    if uncond_fwd:
                        face_emb_list.append(torch.zeros_like(face_emb))
                    else:
                        face_emb_list.append(face_emb)
                    ref_image_list.append(ref_img)

                with torch.no_grad():
                    ref_img = torch.stack(ref_image_list, dim=0).to(
                        dtype=vae.dtype, device=vae.device
                    )
                    ref_image_latents = vae.encode(
                        ref_img
                    ).latent_dist.sample()
                    ref_image_latents = ref_image_latents * 0.18215

                face_emb = torch.stack(face_emb_list, dim=0).to(
                    dtype=imageproj.dtype, device=imageproj.device
                )

                # add noise
                noisy_latents = train_noise_scheduler.add_noise(
                    latents, noise, timesteps
                )

                # Get the target for loss depending on the prediction type
                if train_noise_scheduler.prediction_type == "epsilon":
                    target = noise
                elif train_noise_scheduler.prediction_type == "v_prediction":
                    target = train_noise_scheduler.get_velocity(
                        latents, noise, timesteps
                    )
                else:
                    raise ValueError(
                        f"Unknown prediction type {train_noise_scheduler.prediction_type}"
                    )
                model_pred = net(
                    noisy_latents,
                    timesteps,
                    ref_image_latents,
                    face_emb,
                    face_mask_img,
                    uncond_fwd,
                )

                if cfg.snr_gamma == 0:
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="mean"
                    )
                else:
                    snr = compute_snr(train_noise_scheduler, timesteps)
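                    # Min-SNR weighting (Hang et al., 2023): each timestep's MSE is weighted by
                    # min(SNR, snr_gamma) / SNR, which down-weights easy low-noise timesteps and
                    # stabilizes training.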
                    if train_noise_scheduler.config.prediction_type == "v_prediction":
                        # Velocity objective requires that we add one to SNR values before we divide by them.
                        snr = snr + 1
                    mse_loss_weights = (
                        torch.stack(
                            [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
                        ).min(dim=1)[0]
                        / snr
                    )
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="none"
                    )
                    loss = (
                        loss.mean(dim=list(range(1, len(loss.shape))))
                        * mse_loss_weights
                    )
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(
                    loss.repeat(cfg.data.train_bs)).mean()
                train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        trainable_params,
                        cfg.solver.max_grad_norm,
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                reference_control_reader.clear()
                reference_control_writer.clear()
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0
                if global_step % cfg.checkpointing_steps == 0 or global_step == cfg.solver.max_train_steps:
                    accelerator.wait_for_everyone()
                    save_path = os.path.join(
                        checkpoint_dir, f"checkpoint-{global_step}")
                    if accelerator.is_main_process:
                        delete_additional_ckpt(checkpoint_dir, 3)
                    accelerator.save_state(save_path)
                    accelerator.wait_for_everyone()
                    unwrap_net = accelerator.unwrap_model(net)
                    if accelerator.is_main_process:
                        save_checkpoint(
                            unwrap_net.reference_unet,
                            module_dir,
                            "reference_unet",
                            global_step,
                            total_limit=3,
                        )
                        save_checkpoint(
                            unwrap_net.imageproj,
                            module_dir,
                            "imageproj",
                            global_step,
                            total_limit=3,
                        )
                        save_checkpoint(
                            unwrap_net.denoising_unet,
                            module_dir,
                            "denoising_unet",
                            global_step,
                            total_limit=3,
                        )
                        save_checkpoint(
                            unwrap_net.face_locator,
                            module_dir,
                            "face_locator",
                            global_step,
                            total_limit=3,
                        )

                if global_step % cfg.val.validation_steps == 0 or global_step == 1:
                    if accelerator.is_main_process:
                        generator = torch.Generator(device=accelerator.device)
                        generator.manual_seed(cfg.seed)
                        log_validation(
                            vae=vae,
                            net=net,
                            scheduler=val_noise_scheduler,
                            accelerator=accelerator,
                            width=cfg.data.train_width,
                            height=cfg.data.train_height,
                            imageproj=imageproj,
                            cfg=cfg,
                            save_dir=validation_dir,
                            global_step=global_step,
                            face_analysis_model_path=cfg.face_analysis_model_path
                        )

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                # process final module weight for stage2
                if accelerator.is_main_process:
                    move_final_checkpoint(save_dir, module_dir, "reference_unet")
                    move_final_checkpoint(save_dir, module_dir, "imageproj")
                    move_final_checkpoint(save_dir, module_dir, "denoising_unet")
                    move_final_checkpoint(save_dir, module_dir, "face_locator")
                break

    accelerator.wait_for_everyone()
    accelerator.end_training()


def load_config(config_path: str) -> dict:
    """
    Loads the configuration file.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        dict: The configuration dictionary.
    """

    if config_path.endswith(".yaml"):
        return OmegaConf.load(config_path)
    if config_path.endswith(".py"):
        return import_filename(config_path).cfg
    raise ValueError("Unsupported format for config file")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str,
                        default="./configs/train/stage1.yaml")
    args = parser.parse_args()

    try:
        config = load_config(args.config)
        train_stage1_process(config)
    except Exception as e:
        logging.error("Failed to execute the training process: %s", e)
scripts/train_stage2.py
ADDED
@@ -0,0 +1,991 @@
# pylint: disable=E1101,C0415,W0718,R0801
# scripts/train_stage2.py
"""
This is the main training script for stage 2 of the project.
It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration.

The script includes the following classes and functions:

1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings,
   and face masks as input and returns the denoised latents.
2. get_attention_mask: A function that rearranges the mask tensors to the required format.
3. get_noise_scheduler: A function that creates and returns the noise schedulers for training and validation.
4. process_audio_emb: A function that processes the audio embeddings to concatenate with other tensors.
5. log_validation: A function that logs the validation information using the given VAE, image encoder,
   network, scheduler, accelerator, width, height, and configuration.
6. train_stage2_process: A function that processes the training stage 2 using the given configuration.
7. load_config: A function that loads the configuration file from the given path.

The script also includes the necessary imports and a brief description of the purpose of the file.
"""

import argparse
import copy
import logging
import math
import os
import random
import time
import warnings
from datetime import datetime
from typing import List, Tuple

import diffusers
import mlflow
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
from omegaconf import OmegaConf
from torch import nn
from tqdm.auto import tqdm

from hallo.animate.face_animate import FaceAnimatePipeline
from hallo.datasets.audio_processor import AudioProcessor
from hallo.datasets.image_processor import ImageProcessor
from hallo.datasets.talk_video import TalkingVideoDataset
from hallo.models.audio_proj import AudioProjModel
from hallo.models.face_locator import FaceLocator
from hallo.models.image_proj import ImageProjModel
from hallo.models.mutual_self_attention import ReferenceAttentionControl
from hallo.models.unet_2d_condition import UNet2DConditionModel
from hallo.models.unet_3d import UNet3DConditionModel
from hallo.utils.util import (compute_snr, delete_additional_ckpt,
                              import_filename, init_output_dir,
                              load_checkpoint, save_checkpoint,
                              seed_everything, tensor_to_video)

warnings.filterwarnings("ignore")

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")


class Net(nn.Module):
    """
    The Net class defines a neural network model that combines a reference UNet2DConditionModel,
    a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image.

    Args:
        reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation.
        denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation.
        face_locator (FaceLocator): The face locator model used for face animation.
        reference_control_writer: The reference control writer component.
        reference_control_reader: The reference control reader component.
        imageproj: The image projection model.
        audioproj: The audio projection model.

    Forward method:
        noisy_latents (torch.Tensor): The noisy latents tensor.
        timesteps (torch.Tensor): The timesteps tensor.
        ref_image_latents (torch.Tensor): The reference image latents tensor.
        face_emb (torch.Tensor): The face embeddings tensor.
        audio_emb (torch.Tensor): The audio embeddings tensor.
        mask (torch.Tensor): Hard face mask for face locator.
        full_mask (torch.Tensor): Pose Mask.
        face_mask (torch.Tensor): Face Mask
        lip_mask (torch.Tensor): Lip Mask
        uncond_img_fwd (bool): A flag indicating whether to perform reference image unconditional forward pass.
        uncond_audio_fwd (bool): A flag indicating whether to perform audio unconditional forward pass.

    Returns:
        torch.Tensor: The output tensor of the neural network model.
    """
    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        face_locator: FaceLocator,
        reference_control_writer,
        reference_control_reader,
        imageproj,
        audioproj,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.face_locator = face_locator
        self.reference_control_writer = reference_control_writer
        self.reference_control_reader = reference_control_reader
        self.imageproj = imageproj
        self.audioproj = audioproj

    def forward(
        self,
        noisy_latents: torch.Tensor,
        timesteps: torch.Tensor,
        ref_image_latents: torch.Tensor,
        face_emb: torch.Tensor,
        audio_emb: torch.Tensor,
        mask: torch.Tensor,
        full_mask: torch.Tensor,
        face_mask: torch.Tensor,
        lip_mask: torch.Tensor,
        uncond_img_fwd: bool = False,
        uncond_audio_fwd: bool = False,
    ):
        """
        simple docstring to prevent pylint error
        """
        face_emb = self.imageproj(face_emb)
        mask = mask.to(device="cuda")
        mask_feature = self.face_locator(mask)
        audio_emb = audio_emb.to(
            device=self.audioproj.device, dtype=self.audioproj.dtype)
        audio_emb = self.audioproj(audio_emb)

        # condition forward
        if not uncond_img_fwd:
            ref_timesteps = torch.zeros_like(timesteps)
            ref_timesteps = repeat(
                ref_timesteps,
                "b -> (repeat b)",
                repeat=ref_image_latents.size(0) // ref_timesteps.size(0),
            )
            self.reference_unet(
                ref_image_latents,
                ref_timesteps,
                encoder_hidden_states=face_emb,
                return_dict=False,
            )
            self.reference_control_reader.update(self.reference_control_writer)

        if uncond_audio_fwd:
            audio_emb = torch.zeros_like(audio_emb).to(
                device=audio_emb.device, dtype=audio_emb.dtype
            )

        model_pred = self.denoising_unet(
            noisy_latents,
            timesteps,
            mask_cond_fea=mask_feature,
            encoder_hidden_states=face_emb,
            audio_embedding=audio_emb,
            full_mask=full_mask,
            face_mask=face_mask,
            lip_mask=lip_mask
        ).sample

        return model_pred


def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
    """
    Rearrange the mask tensors to the required format.

    Args:
        mask (torch.Tensor): The input mask tensor.
        weight_dtype (torch.dtype): The data type for the mask tensor.

    Returns:
        torch.Tensor: The rearranged mask tensor.
    """
    if isinstance(mask, List):
        _mask = []
        for m in mask:
            _mask.append(
                rearrange(m, "b f 1 h w -> (b f) (h w)").to(weight_dtype))
        return _mask
    mask = rearrange(mask, "b f 1 h w -> (b f) (h w)").to(weight_dtype)
    return mask


def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler, DDIMScheduler]:
    """
    Create noise scheduler for training.

    Args:
        cfg (argparse.Namespace): Configuration object.

    Returns:
        Tuple[DDIMScheduler, DDIMScheduler]: Train noise scheduler and validation noise scheduler.
    """

    sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
    if cfg.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
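        # Zero terminal SNR (Lin et al., 2024): rescale betas so the last timestep is pure noise,
        # sample timesteps with "trailing" spacing, and switch to v-prediction, removing the
        # train/inference mismatch of schedules that never reach SNR = 0.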
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})
    train_noise_scheduler = DDIMScheduler(**sched_kwargs)

    return train_noise_scheduler, val_noise_scheduler


def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor:
    """
    Process the audio embedding to concatenate with other tensors.

    Parameters:
        audio_emb (torch.Tensor): The audio embedding tensor to process.

    Returns:
        concatenated_tensors (List[torch.Tensor]): The concatenated tensor list.
    """
    concatenated_tensors = []
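
    # For every frame i, stack the embeddings of frames i-2 .. i+2 (clamped at the clip
    # boundaries), giving each frame a 5-step audio context window:
    # (num_frames, ...) -> (num_frames, 5, ...).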
    for i in range(audio_emb.shape[0]):
        vectors_to_concat = [
            audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
        concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))

    audio_emb = torch.stack(concatenated_tensors, dim=0)

    return audio_emb


def log_validation(
    accelerator: Accelerator,
    vae: AutoencoderKL,
    net: Net,
    scheduler: DDIMScheduler,
    width: int,
    height: int,
    clip_length: int = 24,
    generator: torch.Generator = None,
    cfg: dict = None,
    save_dir: str = None,
    global_step: int = 0,
    times: int = None,
    face_analysis_model_path: str = "",
) -> None:
    """
    Log validation video during the training process.

    Args:
        accelerator (Accelerator): The accelerator for distributed training.
        vae (AutoencoderKL): The autoencoder model.
        net (Net): The main neural network model.
        scheduler (DDIMScheduler): The scheduler for noise.
        width (int): The width of the input images.
        height (int): The height of the input images.
        clip_length (int): The length of the video clips. Defaults to 24.
        generator (torch.Generator): The random number generator. Defaults to None.
        cfg (dict): The configuration dictionary. Defaults to None.
        save_dir (str): The directory to save validation results. Defaults to None.
        global_step (int): The current global step in training. Defaults to 0.
        times (int): The number of inference times. Defaults to None.
        face_analysis_model_path (str): The path to the face analysis model. Defaults to "".

    Returns:
        torch.Tensor: The tensor result of the validation.
    """
    ori_net = accelerator.unwrap_model(net)
    reference_unet = ori_net.reference_unet
    denoising_unet = ori_net.denoising_unet
    face_locator = ori_net.face_locator
    imageproj = ori_net.imageproj
    audioproj = ori_net.audioproj

    generator = torch.manual_seed(42)
    tmp_denoising_unet = copy.deepcopy(denoising_unet)

    pipeline = FaceAnimatePipeline(
        vae=vae,
        reference_unet=reference_unet,
        denoising_unet=tmp_denoising_unet,
        face_locator=face_locator,
        image_proj=imageproj,
        scheduler=scheduler,
    )
    pipeline = pipeline.to("cuda")

    image_processor = ImageProcessor((width, height), face_analysis_model_path)
    audio_processor = AudioProcessor(
        cfg.data.sample_rate,
        cfg.data.fps,
        cfg.wav2vec_config.model_path,
        cfg.wav2vec_config.features == "last",
        os.path.dirname(cfg.audio_separator.model_path),
        os.path.basename(cfg.audio_separator.model_path),
        os.path.join(save_dir, '.cache', "audio_preprocess")
    )

    for idx, ref_img_path in enumerate(cfg.ref_img_path):
        audio_path = cfg.audio_path[idx]
        source_image_pixels, \
        source_image_face_region, \
        source_image_face_emb, \
        source_image_full_mask, \
        source_image_face_mask, \
        source_image_lip_mask = image_processor.preprocess(
            ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio)
        audio_emb, audio_length = audio_processor.preprocess(
            audio_path, clip_length)

        audio_emb = process_audio_emb(audio_emb)

        source_image_pixels = source_image_pixels.unsqueeze(0)
        source_image_face_region = source_image_face_region.unsqueeze(0)
        source_image_face_emb = source_image_face_emb.reshape(1, -1)
        source_image_face_emb = torch.tensor(source_image_face_emb)

        source_image_full_mask = [
            (mask.repeat(clip_length, 1))
            for mask in source_image_full_mask
        ]
        source_image_face_mask = [
            (mask.repeat(clip_length, 1))
            for mask in source_image_face_mask
        ]
        source_image_lip_mask = [
            (mask.repeat(clip_length, 1))
            for mask in source_image_lip_mask
        ]

        times = audio_emb.shape[0] // clip_length
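        # Long audio is generated clip by clip: every clip after the first is conditioned on the
        # last n_motion_frames frames of the previous clip (zeros for the first clip), which keeps
        # the motion coherent across clip boundaries.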
        tensor_result = []
        generator = torch.manual_seed(42)
        for t in range(times):
            print(f"[{t+1}/{times}]")

            if len(tensor_result) == 0:
                # The first iteration
                motion_zeros = source_image_pixels.repeat(
                    cfg.data.n_motion_frames, 1, 1, 1)
                motion_zeros = motion_zeros.to(
                    dtype=source_image_pixels.dtype, device=source_image_pixels.device)
                pixel_values_ref_img = torch.cat(
                    [source_image_pixels, motion_zeros], dim=0)  # concat the ref image and the first motion frames
            else:
                motion_frames = tensor_result[-1][0]
                motion_frames = motion_frames.permute(1, 0, 2, 3)
                motion_frames = motion_frames[0 - cfg.data.n_motion_frames:]
                motion_frames = motion_frames * 2.0 - 1.0
                motion_frames = motion_frames.to(
                    dtype=source_image_pixels.dtype, device=source_image_pixels.device)
                pixel_values_ref_img = torch.cat(
                    [source_image_pixels, motion_frames], dim=0)  # concat the ref image and the motion frames

            pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)

            audio_tensor = audio_emb[
                t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
            ]
            audio_tensor = audio_tensor.unsqueeze(0)
            audio_tensor = audio_tensor.to(
                device=audioproj.device, dtype=audioproj.dtype)
            audio_tensor = audioproj(audio_tensor)

            pipeline_output = pipeline(
                ref_image=pixel_values_ref_img,
                audio_tensor=audio_tensor,
                face_emb=source_image_face_emb,
                face_mask=source_image_face_region,
                pixel_values_full_mask=source_image_full_mask,
                pixel_values_face_mask=source_image_face_mask,
                pixel_values_lip_mask=source_image_lip_mask,
                width=cfg.data.train_width,
                height=cfg.data.train_height,
                video_length=clip_length,
                num_inference_steps=cfg.inference_steps,
                guidance_scale=cfg.cfg_scale,
                generator=generator,
            )

            tensor_result.append(pipeline_output.videos)

        tensor_result = torch.cat(tensor_result, dim=2)
        tensor_result = tensor_result.squeeze(0)
        tensor_result = tensor_result[:, :audio_length]
        audio_name = os.path.basename(audio_path).split('.')[0]
        ref_name = os.path.basename(ref_img_path).split('.')[0]
        output_file = os.path.join(save_dir, f"{global_step}_{ref_name}_{audio_name}.mp4")
        # save the result after all iteration
        tensor_to_video(tensor_result, output_file, audio_path)


    # clean up
    del tmp_denoising_unet
    del pipeline
    del image_processor
    del audio_processor
    torch.cuda.empty_cache()

    return tensor_result


def train_stage2_process(cfg: argparse.Namespace) -> None:
    """
    Trains the model using the given configuration (cfg).

    Args:
        cfg (dict): The configuration dictionary containing the parameters for training.

    Notes:
        - This function trains the model using the given configuration.
        - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler.
        - The training progress is logged and tracked using the accelerator.
        - The trained model is saved after the training is completed.
    """
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,
        log_with="mlflow",
        project_dir="./mlruns",
        kwargs_handlers=[kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if cfg.seed is not None:
        seed_everything(cfg.seed)

    # create output dir for training
    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    checkpoint_dir = os.path.join(save_dir, "checkpoints")
    module_dir = os.path.join(save_dir, "modules")
    validation_dir = os.path.join(save_dir, "validation")
    if accelerator.is_main_process:
        init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir])

    accelerator.wait_for_everyone()

    if cfg.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif cfg.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif cfg.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(
            f"Do not support weight dtype: {cfg.weight_dtype} during training"
        )

    # Create Models
    vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
        "cuda", dtype=weight_dtype
    )
    reference_unet = UNet2DConditionModel.from_pretrained(
        cfg.base_model_path,
        subfolder="unet",
    ).to(device="cuda", dtype=weight_dtype)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        cfg.mm_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            cfg.unet_additional_kwargs),
        use_landmark=False
    ).to(device="cuda", dtype=weight_dtype)
    imageproj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    ).to(device="cuda", dtype=weight_dtype)
    face_locator = FaceLocator(
        conditioning_embedding_channels=320,
    ).to(device="cuda", dtype=weight_dtype)
    audioproj = AudioProjModel(
        seq_len=5,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
    ).to(device="cuda", dtype=weight_dtype)

    # load module weight from stage 1
    stage1_ckpt_dir = cfg.stage1_ckpt_dir
    denoising_unet.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, "denoising_unet.pth"),
            map_location="cpu",
        ),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, "reference_unet.pth"),
            map_location="cpu",
        ),
        strict=False,
    )
    face_locator.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, "face_locator.pth"),
            map_location="cpu",
        ),
        strict=False,
    )
    imageproj.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, "imageproj.pth"),
            map_location="cpu",
        ),
        strict=False,
    )

    # Freeze
    vae.requires_grad_(False)
    imageproj.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    face_locator.requires_grad_(False)
    audioproj.requires_grad_(True)
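    # Stage 2 trains only the temporal pieces: the audio projection above and, below, the motion
    # modules of the denoising UNet selected by cfg.trainable_para. The VAE, reference UNet,
    # face locator and image projection keep their stage-1 weights frozen.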

    # Set motion module learnable
    trainable_modules = cfg.trainable_para
    for name, module in denoising_unet.named_modules():
        if any(trainable_mod in name for trainable_mod in trainable_modules):
            for params in module.parameters():
                params.requires_grad_(True)

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    reference_control_reader = ReferenceAttentionControl(
        denoising_unet,
        do_classifier_free_guidance=False,
        mode="read",
        fusion_blocks="full",
    )

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        reference_control_writer,
        reference_control_reader,
        imageproj,
        audioproj,
    ).to(dtype=weight_dtype)

    # get noise scheduler
    train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg)

    if cfg.solver.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            reference_unet.enable_xformers_memory_efficient_attention()
            denoising_unet.enable_xformers_memory_efficient_attention()

        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if cfg.solver.gradient_checkpointing:
        reference_unet.enable_gradient_checkpointing()
        denoising_unet.enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    # Initialize the optimizer
    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as exc:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            ) from exc
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    trainable_params = list(
        filter(lambda p: p.requires_grad, net.parameters()))
    logger.info(f"Total trainable params {len(trainable_params)}")
    optimizer = optimizer_cls(
        trainable_params,
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    # Scheduler
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps
        * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps
        * cfg.solver.gradient_accumulation_steps,
    )

    # get data loader
    train_dataset = TalkingVideoDataset(
        img_size=(cfg.data.train_width, cfg.data.train_height),
        sample_rate=cfg.data.sample_rate,
        n_sample_frames=cfg.data.n_sample_frames,
        n_motion_frames=cfg.data.n_motion_frames,
        audio_margin=cfg.data.audio_margin,
        data_meta_paths=cfg.data.train_meta_paths,
        wav2vec_cfg=cfg.wav2vec_config,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=16
    )

    # Prepare everything with our `accelerator`.
    (
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / cfg.solver.gradient_accumulation_steps
    )
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            exp_name,
            init_kwargs={"mlflow": {"run_name": run_time}},
        )
        # dump config file
        mlflow.log_dict(
            OmegaConf.to_container(
                cfg), "config.yaml"
        )
        logger.info(f"save config to {save_dir}")
        OmegaConf.save(
            cfg, os.path.join(save_dir, "config.yaml")
        )

    # Train!
    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f"  Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
    )
    logger.info(f"  Total optimization steps = {cfg.solver.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # # Potentially load in the weights and states from a previous save
    if cfg.resume_from_checkpoint:
        logger.info(f"Loading checkpoint from {checkpoint_dir}")
        global_step = load_checkpoint(cfg, checkpoint_dir, accelerator)
        first_epoch = global_step // num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    for _ in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        t_data_start = time.time()
        for _, batch in enumerate(train_dataloader):
            t_data = time.time() - t_data_start
            with accelerator.accumulate(net):
                # Convert videos to latent space
                pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype)

                pixel_values_face_mask = batch["pixel_values_face_mask"]
                pixel_values_face_mask = get_attention_mask(
                    pixel_values_face_mask, weight_dtype
                )
                pixel_values_lip_mask = batch["pixel_values_lip_mask"]
                pixel_values_lip_mask = get_attention_mask(
                    pixel_values_lip_mask, weight_dtype
                )
                pixel_values_full_mask = batch["pixel_values_full_mask"]
                pixel_values_full_mask = get_attention_mask(
                    pixel_values_full_mask, weight_dtype
                )

                with torch.no_grad():
                    video_length = pixel_values_vid.shape[1]
                    pixel_values_vid = rearrange(
                        pixel_values_vid, "b f c h w -> (b f) c h w"
                    )
                    latents = vae.encode(pixel_values_vid).latent_dist.sample()
                    latents = rearrange(
                        latents, "(b f) c h w -> b c f h w", f=video_length
                    )
                    latents = latents * 0.18215

                noise = torch.randn_like(latents)
                if cfg.noise_offset > 0:
                    noise += cfg.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], 1, 1, 1),
                        device=latents.device,
                    )

                bsz = latents.shape[0]
                # Sample a random timestep for each video
                timesteps = torch.randint(
                    0,
                    train_noise_scheduler.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                )
                timesteps = timesteps.long()

                # mask for face locator
                pixel_values_mask = (
                    batch["pixel_values_mask"].unsqueeze(
                        1).to(dtype=weight_dtype)
                )
                pixel_values_mask = repeat(
                    pixel_values_mask,
                    "b f c h w -> b (repeat f) c h w",
                    repeat=video_length,
                )
                pixel_values_mask = pixel_values_mask.transpose(
                    1, 2)

                uncond_img_fwd = random.random() < cfg.uncond_img_ratio
                uncond_audio_fwd = random.random() < cfg.uncond_audio_ratio

                start_frame = random.random() < cfg.start_ratio
                pixel_values_ref_img = batch["pixel_values_ref_img"].to(
                    dtype=weight_dtype
                )
                # initialize the motion frames as zero maps
                if start_frame:
                    pixel_values_ref_img[:, 1:] = 0.0
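                # With probability cfg.start_ratio the sample is treated as the first clip of a
                # video: the motion frames (every frame after the reference image) are zeroed so
                # the model also learns to start without motion history, matching inference.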

                ref_img_and_motion = rearrange(
                    pixel_values_ref_img, "b f c h w -> (b f) c h w"
                )

                with torch.no_grad():
                    ref_image_latents = vae.encode(
                        ref_img_and_motion
                    ).latent_dist.sample()
                    ref_image_latents = ref_image_latents * 0.18215
                    image_prompt_embeds = batch["face_emb"].to(
                        dtype=imageproj.dtype, device=imageproj.device
                    )

                # add noise
                noisy_latents = train_noise_scheduler.add_noise(
                    latents, noise, timesteps
                )

                # Get the target for loss depending on the prediction type
                if train_noise_scheduler.prediction_type == "epsilon":
                    target = noise
                elif train_noise_scheduler.prediction_type == "v_prediction":
                    target = train_noise_scheduler.get_velocity(
                        latents, noise, timesteps
                    )
                else:
                    raise ValueError(
                        f"Unknown prediction type {train_noise_scheduler.prediction_type}"
                    )

                # ---- Forward!!! -----
                model_pred = net(
                    noisy_latents=noisy_latents,
                    timesteps=timesteps,
                    ref_image_latents=ref_image_latents,
                    face_emb=image_prompt_embeds,
                    mask=pixel_values_mask,
                    full_mask=pixel_values_full_mask,
                    face_mask=pixel_values_face_mask,
                    lip_mask=pixel_values_lip_mask,
                    audio_emb=batch["audio_tensor"].to(
                        dtype=weight_dtype),
                    uncond_img_fwd=uncond_img_fwd,
                    uncond_audio_fwd=uncond_audio_fwd,
                )

                if cfg.snr_gamma == 0:
                    loss = F.mse_loss(
                        model_pred.float(),
                        target.float(),
                        reduction="mean",
                    )
                else:
                    snr = compute_snr(train_noise_scheduler, timesteps)
                    if train_noise_scheduler.config.prediction_type == "v_prediction":
                        # Velocity objective requires that we add one to SNR values before we divide by them.
                        snr = snr + 1
                    mse_loss_weights = (
                        torch.stack(
                            [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
                        ).min(dim=1)[0]
                        / snr
                    )
                    # per-element loss so the min-SNR weights can be applied per sample
                    loss = F.mse_loss(
                        model_pred.float(),
                        target.float(),
                        reduction="none",
                    )
                    loss = (
                        loss.mean(dim=list(range(1, len(loss.shape))))
                        * mse_loss_weights
                    ).mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(
                    loss.repeat(cfg.data.train_bs)).mean()
                train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        trainable_params,
                        cfg.solver.max_grad_norm,
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                reference_control_reader.clear()
                reference_control_writer.clear()
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % cfg.val.validation_steps == 0 or global_step == 1:
                    if accelerator.is_main_process:
                        generator = torch.Generator(device=accelerator.device)
                        generator.manual_seed(cfg.seed)

                        log_validation(
                            accelerator=accelerator,
                            vae=vae,
                            net=net,
                            scheduler=val_noise_scheduler,
                            width=cfg.data.train_width,
                            height=cfg.data.train_height,
                            clip_length=cfg.data.n_sample_frames,
                            cfg=cfg,
                            save_dir=validation_dir,
                            global_step=global_step,
                            times=cfg.single_inference_times if cfg.single_inference_times is not None else None,
                            face_analysis_model_path=cfg.face_analysis_model_path
                        )

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "td": f"{t_data:.2f}s",
            }
            t_data_start = time.time()
            progress_bar.set_postfix(**logs)

            if (
                global_step % cfg.checkpointing_steps == 0
                or global_step == cfg.solver.max_train_steps
            ):
                # save model
                save_path = os.path.join(
                    checkpoint_dir, f"checkpoint-{global_step}")
                if accelerator.is_main_process:
                    delete_additional_ckpt(checkpoint_dir, 30)
                accelerator.wait_for_everyone()
                accelerator.save_state(save_path)

                # save model weight
                unwrap_net = accelerator.unwrap_model(net)
                if accelerator.is_main_process:
                    save_checkpoint(
                        unwrap_net,
                        module_dir,
                        "net",
                        global_step,
                        total_limit=30,
                    )
            if global_step >= cfg.solver.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    accelerator.end_training()


def load_config(config_path: str) -> dict:
    """
    Loads the configuration file.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        dict: The configuration dictionary.
    """

    if config_path.endswith(".yaml"):
        return OmegaConf.load(config_path)
    if config_path.endswith(".py"):
        return import_filename(config_path).cfg
    raise ValueError("Unsupported format for config file")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, default="./configs/train/stage2.yaml"
    )
    args = parser.parse_args()

    try:
        config = load_config(args.config)
        train_stage2_process(config)
    except Exception as e:
        logging.error("Failed to execute the training process: %s", e)