diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..27756e25cf7793a445f77d2f81ee37499fb156eb
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..6e8a9589424f44b5af9a9bfd51285d33dc9e1da5 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.task filter=lfs diff=lfs merge=lfs -text
+*.mat filter=lfs diff=lfs merge=lfs -text
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..32a9c071f21dd7615b21ff6d692fa02f40626a2e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,36 @@
+
+## Introduction
+
+This repo provides the inference Gradio demo for **Hybrid (Trajectory + Landmark)** Control of MOFA-Video.
+
+## Environment Setup
+
+```
+cd MOFA-Hybrid
+conda create -n mofa python==3.10
+conda activate mofa
+pip install -r requirements.txt
+pip install opencv-python-headless
+pip install "git+https://github.com/facebookresearch/pytorch3d.git"
+```
+
+**IMPORTANT:** Gradio Version of **4.5.0** should be used since other versions may cause errors.
+
+
+## Checkpoints Download
+1. Download the checkpoint of CMP from [here](https://huggingface.co/MyNiuuu/MOFA-Video-Hybrid/blob/main/models/cmp/experiments/semiauto_annot/resnet50_vip%2Bmpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar) and put it into `./models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints`.
+
+2. Downloading the necessary pretrained checkpoints from [huggingface](https://huggingface.co/MyNiuuu/MOFA-Video-Hybrid). It is recommended to directly using git lfs to clone the [huggingface repo](https://huggingface.co/MyNiuuu/MOFA-Video-Hybrid). The checkpoints should be orgnized as `./ckpt_tree.md` (they will be automatically organized if you use git lfs to clone the [huggingface repo](https://huggingface.co/MyNiuuu/MOFA-Video-Hybrid)).
+
+
+## Run Gradio Demo
+
+### Using audio to animate the facial part
+
+`python run_gradio_audio_driven.py`
+
+### Using refernce video to animate the facial part
+
+`python run_gradio_audio_driven.py`
+
+**IMPORTANT:** Please refer to the instructions on the gradio interface during the inference process.
\ No newline at end of file
diff --git a/aniportrait/.DS_Store b/aniportrait/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..c34099dd3c099c3da89663afea31152c2a477eca
Binary files /dev/null and b/aniportrait/.DS_Store differ
diff --git a/aniportrait/audio2ldmk.py b/aniportrait/audio2ldmk.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a4ff41987743161dd65f3a536e2e7dbc7c7c65d
--- /dev/null
+++ b/aniportrait/audio2ldmk.py
@@ -0,0 +1,310 @@
+import argparse
+import os
+# import ffmpeg
+import random
+import numpy as np
+import cv2
+import torch
+import torchvision
+from omegaconf import OmegaConf
+from PIL import Image
+
+from src.audio_models.model import Audio2MeshModel
+from src.audio_models.pose_model import Audio2PoseModel
+from src.utils.audio_util import prepare_audio_feature
+from src.utils.mp_utils import LMKExtractor
+from src.utils.pose_util import project_points, smooth_pose_seq
+
+
+PARTS = [
+ ('FACE', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], (10, 200, 10)),
+ ('LEFT_EYE', [43, 44, 45, 46, 47, 48, 43], (180, 200, 10)),
+ ('LEFT_EYEBROW', [23, 24, 25, 26, 27], (180, 220, 10)),
+ ('RIGHT_EYE', [37, 38, 39, 40, 41, 42, 37], (10, 200, 180)),
+ ('RIGHT_EYEBROW', [18, 19, 20, 21, 22], (10, 220, 180)),
+ ('NOSE_UP', [28, 29, 30, 31], (10, 200, 250)),
+ ('NOSE_DOWN', [32, 33, 34, 35, 36], (250, 200, 10)),
+ ('LIPS_OUTER_BOTTOM_LEFT', [55, 56, 57, 58], (10, 180, 20)),
+ ('LIPS_OUTER_BOTTOM_RIGHT', [49, 60, 59, 58], (20, 10, 180)),
+ ('LIPS_INNER_BOTTOM_LEFT', [65, 66, 67], (100, 100, 30)),
+ ('LIPS_INNER_BOTTOM_RIGHT', [61, 68, 67], (100, 150, 50)),
+ ('LIPS_OUTER_TOP_LEFT', [52, 53, 54, 55], (20, 80, 100)),
+ ('LIPS_OUTER_TOP_RIGHT', [52, 51, 50, 49], (80, 100, 20)),
+ ('LIPS_INNER_TOP_LEFT', [63, 64, 65], (120, 100, 200)),
+ ('LIPS_INNER_TOP_RIGHT', [63, 62, 61], (150, 120, 100)),
+]
+
+
+def draw_landmarks(keypoints, h, w):
+
+ image = np.zeros((h, w, 3))
+
+ for name, indices, color in PARTS:
+ # 选择当前部分的关键点
+ indices = np.array(indices) - 1
+ current_part_keypoints = keypoints[indices]
+
+ # 绘制关键点
+ # for point in current_part_keypoints:
+ # x, y = point
+ # image[y, x, :] = color
+
+ # 绘制连接线
+ for i in range(len(indices) - 1):
+ x1, y1 = current_part_keypoints[i]
+ x2, y2 = current_part_keypoints[i + 1]
+ cv2.line(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2)
+
+ return image
+
+
+
+def convert_ldmk_to_68(mediapipe_ldmk):
+ return np.stack([
+ # face coutour
+ mediapipe_ldmk[:, 234],
+ mediapipe_ldmk[:, 93],
+ mediapipe_ldmk[:, 132],
+ mediapipe_ldmk[:, 58],
+ mediapipe_ldmk[:, 172],
+ mediapipe_ldmk[:, 136],
+ mediapipe_ldmk[:, 150],
+ mediapipe_ldmk[:, 176],
+ mediapipe_ldmk[:, 152],
+ mediapipe_ldmk[:, 400],
+ mediapipe_ldmk[:, 379],
+ mediapipe_ldmk[:, 365],
+ mediapipe_ldmk[:, 397],
+ mediapipe_ldmk[:, 288],
+ mediapipe_ldmk[:, 361],
+ mediapipe_ldmk[:, 323],
+ mediapipe_ldmk[:, 454],
+ # right eyebrow
+ mediapipe_ldmk[:, 70],
+ mediapipe_ldmk[:, 63],
+ mediapipe_ldmk[:, 105],
+ mediapipe_ldmk[:, 66],
+ mediapipe_ldmk[:, 107],
+ # left eyebrow
+ mediapipe_ldmk[:, 336],
+ mediapipe_ldmk[:, 296],
+ mediapipe_ldmk[:, 334],
+ mediapipe_ldmk[:, 293],
+ mediapipe_ldmk[:, 300],
+ # nose
+ mediapipe_ldmk[:, 168],
+ mediapipe_ldmk[:, 6],
+ mediapipe_ldmk[:, 195],
+ mediapipe_ldmk[:, 4],
+ # nose down
+ mediapipe_ldmk[:, 239],
+ mediapipe_ldmk[:, 241],
+ mediapipe_ldmk[:, 19],
+ mediapipe_ldmk[:, 461],
+ mediapipe_ldmk[:, 459],
+ # right eye
+ mediapipe_ldmk[:, 33],
+ mediapipe_ldmk[:, 160],
+ mediapipe_ldmk[:, 158],
+ mediapipe_ldmk[:, 133],
+ mediapipe_ldmk[:, 153],
+ mediapipe_ldmk[:, 144],
+ # left eye
+ mediapipe_ldmk[:, 362],
+ mediapipe_ldmk[:, 385],
+ mediapipe_ldmk[:, 387],
+ mediapipe_ldmk[:, 263],
+ mediapipe_ldmk[:, 373],
+ mediapipe_ldmk[:, 380],
+ # outer lips
+ mediapipe_ldmk[:, 61],
+ mediapipe_ldmk[:, 40],
+ mediapipe_ldmk[:, 37],
+ mediapipe_ldmk[:, 0],
+ mediapipe_ldmk[:, 267],
+ mediapipe_ldmk[:, 270],
+ mediapipe_ldmk[:, 291],
+ mediapipe_ldmk[:, 321],
+ mediapipe_ldmk[:, 314],
+ mediapipe_ldmk[:, 17],
+ mediapipe_ldmk[:, 84],
+ mediapipe_ldmk[:, 91],
+ # inner lips
+ mediapipe_ldmk[:, 78],
+ mediapipe_ldmk[:, 81],
+ mediapipe_ldmk[:, 13],
+ mediapipe_ldmk[:, 311],
+ mediapipe_ldmk[:, 308],
+ mediapipe_ldmk[:, 402],
+ mediapipe_ldmk[:, 14],
+ mediapipe_ldmk[:, 178],
+], axis=1)
+
+
+
+# def parse_args():
+# parser = argparse.ArgumentParser()
+# parser.add_argument("--config", type=str, default='./configs/prompts/animation_audio.yaml')
+# parser.add_argument("-W", type=int, default=512)
+# parser.add_argument("-H", type=int, default=512)
+# parser.add_argument("-L", type=int)
+# parser.add_argument("--seed", type=int, default=42)
+# parser.add_argument("--cfg", type=float, default=3.5)
+# parser.add_argument("--steps", type=int, default=25)
+# parser.add_argument("--fps", type=int, default=30)
+# parser.add_argument("-acc", "--accelerate", action='store_true')
+# parser.add_argument("--fi_step", type=int, default=3)
+# args = parser.parse_args()
+
+# return args
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ref_image_path", type=str, required=True)
+ parser.add_argument("--audio_path", type=str, required=True)
+ parser.add_argument("--save_dir", type=str, required=True)
+ parser.add_argument("--fps", type=int, default=25)
+ parser.add_argument("--sr", type=int, default=16000)
+ args = parser.parse_args()
+
+ return args
+
+
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+
+
+def main():
+ args = parse_args()
+
+ config = OmegaConf.load('aniportrait/configs/config.yaml')
+
+ set_seed(42)
+
+ # if config.weight_dtype == "fp16":
+ # weight_dtype = torch.float16
+ # else:
+ # weight_dtype = torch.float32
+
+ audio_infer_config = OmegaConf.load(config.audio_inference_config)
+ # prepare model
+ a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
+ a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
+ a2m_model.cuda().eval()
+
+ a2p_model = Audio2PoseModel(audio_infer_config['a2p_model'])
+ a2p_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2p_ckpt']), strict=False)
+ a2p_model.cuda().eval()
+
+ lmk_extractor = LMKExtractor()
+
+ ref_image_path = args.ref_image_path
+ audio_path = args.audio_path
+ save_dir = args.save_dir
+
+ ref_image_pil = Image.open(ref_image_path).convert("RGB")
+ ref_image_np = cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR)
+ height, width, _ = ref_image_np.shape
+
+ face_result = lmk_extractor(ref_image_np)
+ assert face_result is not None, "No face detected."
+ lmks = face_result['lmks'].astype(np.float32)
+ lmks[:, 0] *= width
+ lmks[:, 1] *= height
+
+ # print(lmks.shape)
+
+ # assert False
+
+ sample = prepare_audio_feature(audio_path, fps=args.fps, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
+ sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
+ sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
+
+ # print(sample['audio_feature'].shape)
+
+ # inference
+ pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
+ pred = pred.squeeze().detach().cpu().numpy()
+ pred = pred.reshape(pred.shape[0], -1, 3)
+
+ pred = pred + face_result['lmks3d']
+
+ # print(pred.shape)
+
+ # assert False
+
+ id_seed = 42
+ id_seed = torch.LongTensor([id_seed]).cuda()
+
+ # Currently, only inference up to a maximum length of 10 seconds is supported.
+ chunk_duration = 5 # 5 seconds
+ chunk_size = args.sr * chunk_duration
+
+
+ audio_chunks = list(sample['audio_feature'].split(chunk_size, dim=1))
+ seq_len_list = [chunk_duration*args.fps] * (len(audio_chunks) - 1) + [sample['seq_len'] % (chunk_duration*args.fps)]
+ audio_chunks[-2] = torch.cat((audio_chunks[-2], audio_chunks[-1]), dim=1)
+ seq_len_list[-2] = seq_len_list[-2] + seq_len_list[-1]
+ del audio_chunks[-1]
+ del seq_len_list[-1]
+
+ # assert False
+
+ pose_seq = []
+ for audio, seq_len in zip(audio_chunks, seq_len_list):
+ pose_seq_chunk = a2p_model.infer(audio, seq_len, id_seed)
+ pose_seq_chunk = pose_seq_chunk.squeeze().detach().cpu().numpy()
+ pose_seq_chunk[:, :3] *= 0.5
+ pose_seq.append(pose_seq_chunk)
+
+ pose_seq = np.concatenate(pose_seq, 0)
+ pose_seq = smooth_pose_seq(pose_seq, 7)
+
+ # project 3D mesh to 2D landmark
+ projected_vertices = project_points(pred, face_result['trans_mat'], pose_seq, [height, width])
+ projected_vertices = np.concatenate([lmks[:468, :2][None, :], projected_vertices], axis=0)
+ projected_vertices = convert_ldmk_to_68(projected_vertices)
+
+ # print(projected_vertices.shape)
+
+ pose_images = []
+ for i in range(projected_vertices.shape[0]):
+ pose_img = draw_landmarks(projected_vertices[i], height, width)
+ pose_images.append(pose_img)
+ pose_images = np.array(pose_images)
+
+ # print(pose_images.shape)
+
+ ref_image_np = cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB)
+ ref_imgs = np.stack([ref_image_np]*(pose_images.shape[0]), axis=0)
+
+ all_np = np.concatenate([ref_imgs, pose_images], axis=2)
+
+ # print(projected_vertices.shape)
+
+ os.makedirs(save_dir, exist_ok=True)
+
+ np.save(os.path.join(save_dir, 'landmarks.npy'), projected_vertices)
+
+ torchvision.io.write_video(os.path.join(save_dir, 'landmarks.mp4'), all_np, fps=args.fps, video_codec='h264', options={'crf': '10'})
+
+ # stream = ffmpeg.input(os.path.join(save_dir, 'landmarks.mp4'))
+ # audio = ffmpeg.input(args.audio_path)
+ # ffmpeg.output(stream.video, audio.audio, os.path.join(save_dir, 'landmarks_audio.mp4'), vcodec='copy', acodec='aac').run()
+
+
+
+
+
+
+
+if __name__ == "__main__":
+ main()
+
\ No newline at end of file
diff --git a/aniportrait/configs/config.yaml b/aniportrait/configs/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1be3d696e3f5acc118053ebeff327ff603abb53b
--- /dev/null
+++ b/aniportrait/configs/config.yaml
@@ -0,0 +1,12 @@
+pretrained_base_model_path: 'ckpts/aniportrait/stable-diffusion-v1-5'
+pretrained_vae_path: 'ckpts/aniportrait/sd-vae-ft-mse'
+image_encoder_path: 'ckpts/aniportrait/image_encoder'
+
+denoising_unet_path: "ckpts/aniportrait/denoising_unet.pth"
+reference_unet_path: "ckpts/aniportrait/reference_unet.pth"
+pose_guider_path: "ckpts/aniportrait/pose_guider.pth"
+motion_module_path: "ckpts/aniportrait/motion_module.pth"
+
+audio_inference_config: "aniportrait/configs/inference_audio.yaml"
+inference_config: "aniportrait/configs/inference_v2.yaml"
+weight_dtype: 'fp16'
diff --git a/aniportrait/configs/inference_audio.yaml b/aniportrait/configs/inference_audio.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c63499327424dc58f81660cd7d7fbc3c9b1a4b39
--- /dev/null
+++ b/aniportrait/configs/inference_audio.yaml
@@ -0,0 +1,17 @@
+a2m_model:
+ out_dim: 1404
+ latent_dim: 512
+ model_path: ckpts/aniportrait/wav2vec2-base-960h
+ only_last_fetures: True
+ from_pretrained: True
+
+a2p_model:
+ out_dim: 6
+ latent_dim: 512
+ model_path: ckpts/aniportrait/wav2vec2-base-960h
+ only_last_fetures: True
+ from_pretrained: True
+
+pretrained_model:
+ a2m_ckpt: ckpts/aniportrait/audio2mesh.pt
+ a2p_ckpt: ckpts/aniportrait/audio2pose.pt
diff --git a/aniportrait/configs/inference_v2.yaml b/aniportrait/configs/inference_v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d613dca2d2e48a41295a89f47b5a82fd7032dba5
--- /dev/null
+++ b/aniportrait/configs/inference_v2.yaml
@@ -0,0 +1,35 @@
+unet_additional_kwargs:
+ use_inflated_groupnorm: true
+ unet_use_cross_frame_attention: false
+ unet_use_temporal_attention: false
+ use_motion_module: true
+ motion_module_resolutions:
+ - 1
+ - 2
+ - 4
+ - 8
+ motion_module_mid_block: true
+ motion_module_decoder_only: false
+ motion_module_type: Vanilla
+ motion_module_kwargs:
+ num_attention_heads: 8
+ num_transformer_block: 1
+ attention_block_types:
+ - Temporal_Self
+ - Temporal_Self
+ temporal_position_encoding: true
+ temporal_position_encoding_max_len: 32
+ temporal_attention_dim_div: 1
+
+noise_scheduler_kwargs:
+ beta_start: 0.00085
+ beta_end: 0.012
+ beta_schedule: "linear"
+ clip_sample: false
+ steps_offset: 1
+ ### Zero-SNR params
+ prediction_type: "v_prediction"
+ rescale_betas_zero_snr: True
+ timestep_spacing: "trailing"
+
+sampler: DDIM
\ No newline at end of file
diff --git a/aniportrait/src/.DS_Store b/aniportrait/src/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..8126747bcf280f1d8dbd91bb5cddc97b93191099
Binary files /dev/null and b/aniportrait/src/.DS_Store differ
diff --git a/aniportrait/src/audio_models/mish.py b/aniportrait/src/audio_models/mish.py
new file mode 100644
index 0000000000000000000000000000000000000000..607b95d33edd40bb53f93682bdcd9e0ff31ffbe4
--- /dev/null
+++ b/aniportrait/src/audio_models/mish.py
@@ -0,0 +1,51 @@
+"""
+Applies the mish function element-wise:
+mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+"""
+
+# import pytorch
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+@torch.jit.script
+def mish(input):
+ """
+ Applies the mish function element-wise:
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+ See additional documentation for mish class.
+ """
+ return input * torch.tanh(F.softplus(input))
+
+class Mish(nn.Module):
+ """
+ Applies the mish function element-wise:
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+
+ Shape:
+ - Input: (N, *) where * means, any number of additional
+ dimensions
+ - Output: (N, *), same shape as the input
+
+ Examples:
+ >>> m = Mish()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
+ """
+
+ def __init__(self):
+ """
+ Init method.
+ """
+ super().__init__()
+
+ def forward(self, input):
+ """
+ Forward pass of the function.
+ """
+ if torch.__version__ >= "1.9":
+ return F.mish(input)
+ else:
+ return mish(input)
\ No newline at end of file
diff --git a/aniportrait/src/audio_models/model.py b/aniportrait/src/audio_models/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..54040ee238677d2fc70b4f34dd78c191c13ed874
--- /dev/null
+++ b/aniportrait/src/audio_models/model.py
@@ -0,0 +1,71 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import Wav2Vec2Config
+
+from .torch_utils import get_mask_from_lengths
+from .wav2vec2 import Wav2Vec2Model
+
+
+class Audio2MeshModel(nn.Module):
+ def __init__(
+ self,
+ config
+ ):
+ super().__init__()
+ out_dim = config['out_dim']
+ latent_dim = config['latent_dim']
+ model_path = config['model_path']
+ only_last_fetures = config['only_last_fetures']
+ from_pretrained = config['from_pretrained']
+
+ self._only_last_features = only_last_fetures
+
+ self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True)
+ if from_pretrained:
+ self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True)
+ else:
+ self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config)
+ self.audio_encoder.feature_extractor._freeze_parameters()
+
+ hidden_size = self.audio_encoder_config.hidden_size
+
+ self.in_fn = nn.Linear(hidden_size, latent_dim)
+
+ self.out_fn = nn.Linear(latent_dim, out_dim)
+ nn.init.constant_(self.out_fn.weight, 0)
+ nn.init.constant_(self.out_fn.bias, 0)
+
+ def forward(self, audio, label, audio_len=None):
+ attention_mask = ~get_mask_from_lengths(audio_len) if audio_len else None
+
+ seq_len = label.shape[1]
+
+ embeddings = self.audio_encoder(audio, seq_len=seq_len, output_hidden_states=True,
+ attention_mask=attention_mask)
+
+ if self._only_last_features:
+ hidden_states = embeddings.last_hidden_state
+ else:
+ hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
+
+ layer_in = self.in_fn(hidden_states)
+ out = self.out_fn(layer_in)
+
+ return out, None
+
+ def infer(self, input_value, seq_len):
+ embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True)
+
+ if self._only_last_features:
+ hidden_states = embeddings.last_hidden_state
+ else:
+ hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
+
+ layer_in = self.in_fn(hidden_states)
+ out = self.out_fn(layer_in)
+
+ return out
+
+
diff --git a/aniportrait/src/audio_models/pose_model.py b/aniportrait/src/audio_models/pose_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f72f7477d45ac82c7fb277902cec612773c14257
--- /dev/null
+++ b/aniportrait/src/audio_models/pose_model.py
@@ -0,0 +1,125 @@
+import os
+import math
+import torch
+import torch.nn as nn
+from transformers import Wav2Vec2Config
+
+from .torch_utils import get_mask_from_lengths
+from .wav2vec2 import Wav2Vec2Model
+
+
+def init_biased_mask(n_head, max_seq_len, period):
+ def get_slopes(n):
+ def get_slopes_power_of_2(n):
+ start = (2**(-2**-(math.log2(n)-3)))
+ ratio = start
+ return [start*ratio**i for i in range(n)]
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2**math.floor(math.log2(n))
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
+ slopes = torch.Tensor(get_slopes(n_head))
+ bias = torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1,period).view(-1)//(period)
+ bias = - torch.flip(bias,dims=[0])
+ alibi = torch.zeros(max_seq_len, max_seq_len)
+ for i in range(max_seq_len):
+ alibi[i, :i+1] = bias[-(i+1):]
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
+ mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1)
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+ mask = mask.unsqueeze(0) + alibi
+ return mask
+
+
+def enc_dec_mask(device, T, S):
+ mask = torch.ones(T, S)
+ for i in range(T):
+ mask[i, i] = 0
+ return (mask==1).to(device=device)
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, max_len=600):
+ super(PositionalEncoding, self).__init__()
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len).unsqueeze(1).float()
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ x = x + self.pe[:, :x.size(1)]
+ return x
+
+
+class Audio2PoseModel(nn.Module):
+ def __init__(
+ self,
+ config
+ ):
+
+ super().__init__()
+
+ latent_dim = config['latent_dim']
+ model_path = config['model_path']
+ only_last_fetures = config['only_last_fetures']
+ from_pretrained = config['from_pretrained']
+ out_dim = config['out_dim']
+
+ self.out_dim = out_dim
+
+ self._only_last_features = only_last_fetures
+
+ self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True)
+ if from_pretrained:
+ self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True)
+ else:
+ self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config)
+ self.audio_encoder.feature_extractor._freeze_parameters()
+
+ hidden_size = self.audio_encoder_config.hidden_size
+
+ self.pose_map = nn.Linear(out_dim, latent_dim)
+ self.in_fn = nn.Linear(hidden_size, latent_dim)
+
+ self.PPE = PositionalEncoding(latent_dim)
+ self.biased_mask = init_biased_mask(n_head = 8, max_seq_len = 600, period=1)
+ decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=8, dim_feedforward=2*latent_dim, batch_first=True)
+ self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=8)
+ self.pose_map_r = nn.Linear(latent_dim, out_dim)
+
+ self.id_embed = nn.Embedding(100, latent_dim) # 100 ids
+
+
+ def infer(self, input_value, seq_len, id_seed=None):
+ embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True)
+
+ if self._only_last_features:
+ hidden_states = embeddings.last_hidden_state
+ else:
+ hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states)
+
+ hidden_states = self.in_fn(hidden_states)
+
+ id_embedding = self.id_embed(id_seed).unsqueeze(1)
+
+ init_pose = torch.zeros([hidden_states.shape[0], 1, self.out_dim]).to(hidden_states.device)
+ for i in range(seq_len):
+ if i==0:
+ pose_emb = self.pose_map(init_pose)
+ pose_input = self.PPE(pose_emb)
+ else:
+ pose_input = self.PPE(pose_emb)
+
+ pose_input = pose_input + id_embedding
+ tgt_mask = self.biased_mask[:, :pose_input.shape[1], :pose_input.shape[1]].clone().detach().to(hidden_states.device)
+ memory_mask = enc_dec_mask(hidden_states.device, pose_input.shape[1], hidden_states.shape[1])
+ pose_out = self.transformer_decoder(pose_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask)
+ pose_out = self.pose_map_r(pose_out)
+ new_output = self.pose_map(pose_out[:,-1,:]).unsqueeze(1)
+ pose_emb = torch.cat((pose_emb, new_output), 1)
+ return pose_out
+
diff --git a/aniportrait/src/audio_models/torch_utils.py b/aniportrait/src/audio_models/torch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a91940405797c55cf685e8fa8f669adbc6089067
--- /dev/null
+++ b/aniportrait/src/audio_models/torch_utils.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn.functional as F
+
+
+def get_mask_from_lengths(lengths, max_len=None):
+ lengths = lengths.to(torch.long)
+ if max_len is None:
+ max_len = torch.max(lengths).item()
+
+ ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
+ mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
+
+ return mask
+
+
+def linear_interpolation(features, seq_len):
+ features = features.transpose(1, 2)
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
+ return output_features.transpose(1, 2)
+
+
+if __name__ == "__main__":
+ import numpy as np
+ mask = ~get_mask_from_lengths(torch.from_numpy(np.array([4,6])))
+ import pdb; pdb.set_trace()
\ No newline at end of file
diff --git a/aniportrait/src/audio_models/wav2vec2.py b/aniportrait/src/audio_models/wav2vec2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ec9c2b93454d47f6820b53c511e70208710e408
--- /dev/null
+++ b/aniportrait/src/audio_models/wav2vec2.py
@@ -0,0 +1,125 @@
+from transformers import Wav2Vec2Config, Wav2Vec2Model
+from transformers.modeling_outputs import BaseModelOutput
+
+from .torch_utils import linear_interpolation
+
+# the implementation of Wav2Vec2Model is borrowed from
+# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+# initialize our encoder with the pre-trained wav2vec 2.0 weights.
+class Wav2Vec2Model(Wav2Vec2Model):
+ def __init__(self, config: Wav2Vec2Config):
+ super().__init__(config)
+
+ def forward(
+ self,
+ input_values,
+ seq_len,
+ attention_mask=None,
+ mask_time_indices=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ self.config.output_attentions = True
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states)
+
+ if not return_dict:
+ return (hidden_states, ) + encoder_outputs[1:]
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+ def feature_extract(
+ self,
+ input_values,
+ seq_len,
+ ):
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+ return extract_features
+
+ def encode(
+ self,
+ extract_features,
+ attention_mask=None,
+ mask_time_indices=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ self.config.output_attentions = True
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states)
+
+ if not return_dict:
+ return (hidden_states, ) + encoder_outputs[1:]
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
diff --git a/aniportrait/src/utils/audio_util.py b/aniportrait/src/utils/audio_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b42c5c8c25ae86f41ac87439223f822f55bb5c0
--- /dev/null
+++ b/aniportrait/src/utils/audio_util.py
@@ -0,0 +1,30 @@
+import os
+import math
+
+import librosa
+import numpy as np
+from transformers import Wav2Vec2FeatureExtractor
+
+
+class DataProcessor:
+ def __init__(self, sampling_rate, wav2vec_model_path):
+ self._processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
+ self._sampling_rate = sampling_rate
+
+ def extract_feature(self, audio_path):
+ speech_array, sampling_rate = librosa.load(audio_path, sr=self._sampling_rate)
+ input_value = np.squeeze(self._processor(speech_array, sampling_rate=sampling_rate).input_values)
+ return input_value
+
+
+def prepare_audio_feature(wav_file, fps=25, sampling_rate=16000, wav2vec_model_path=None):
+ data_preprocessor = DataProcessor(sampling_rate, wav2vec_model_path)
+
+ input_value = data_preprocessor.extract_feature(wav_file)
+ seq_len = math.ceil(len(input_value)/sampling_rate*fps)
+ return {
+ "audio_feature": input_value,
+ "seq_len": seq_len
+ }
+
+
diff --git a/aniportrait/src/utils/draw_util.py b/aniportrait/src/utils/draw_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..96dbf49426c1f60dbf617f0970536ea7d37187f5
--- /dev/null
+++ b/aniportrait/src/utils/draw_util.py
@@ -0,0 +1,149 @@
+import cv2
+import mediapipe as mp
+import numpy as np
+from mediapipe.framework.formats import landmark_pb2
+
+class FaceMeshVisualizer:
+ def __init__(self, forehead_edge=False):
+ self.mp_drawing = mp.solutions.drawing_utils
+ mp_face_mesh = mp.solutions.face_mesh
+ self.mp_face_mesh = mp_face_mesh
+ self.forehead_edge = forehead_edge
+
+ DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
+ f_thick = 2
+ f_rad = 1
+ right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
+ right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
+ right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
+ left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
+ head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_itr = DrawingSpec(color=(150 ,120, 100), thickness=f_thick, circle_radius=f_rad)
+
+ FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)]
+ FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)]
+
+ FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)]
+ FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)]
+
+ FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)]
+ FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)]
+
+ FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)]
+ FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)]
+
+ FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), (397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)]
+
+ # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
+ face_connection_spec = {}
+ if self.forehead_edge:
+ for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
+ face_connection_spec[edge] = head_draw
+ else:
+ for edge in FACEMESH_CUSTOM_FACE_OVAL:
+ face_connection_spec[edge] = head_draw
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
+ face_connection_spec[edge] = left_eye_draw
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
+ face_connection_spec[edge] = left_eyebrow_draw
+ # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
+ # face_connection_spec[edge] = left_iris_draw
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
+ face_connection_spec[edge] = right_eye_draw
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
+ face_connection_spec[edge] = right_eyebrow_draw
+ # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
+ # face_connection_spec[edge] = right_iris_draw
+ # for edge in mp_face_mesh.FACEMESH_LIPS:
+ # face_connection_spec[edge] = mouth_draw
+
+ for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT:
+ face_connection_spec[edge] = mouth_draw_obl
+ for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT:
+ face_connection_spec[edge] = mouth_draw_obr
+ for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT:
+ face_connection_spec[edge] = mouth_draw_ibl
+ for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT:
+ face_connection_spec[edge] = mouth_draw_ibr
+ for edge in FACEMESH_LIPS_OUTER_TOP_LEFT:
+ face_connection_spec[edge] = mouth_draw_otl
+ for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT:
+ face_connection_spec[edge] = mouth_draw_otr
+ for edge in FACEMESH_LIPS_INNER_TOP_LEFT:
+ face_connection_spec[edge] = mouth_draw_itl
+ for edge in FACEMESH_LIPS_INNER_TOP_RIGHT:
+ face_connection_spec[edge] = mouth_draw_itr
+
+
+ iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
+
+ self.face_connection_spec = face_connection_spec
+ def draw_pupils(self, image, landmark_list, drawing_spec, halfwidth: int = 2):
+ """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
+ landmarks. Until our PR is merged into mediapipe, we need this separate method."""
+ if len(image.shape) != 3:
+ raise ValueError("Input image must be H,W,C.")
+ image_rows, image_cols, image_channels = image.shape
+ if image_channels != 3: # BGR channels
+ raise ValueError('Input image must contain three channel bgr data.')
+ for idx, landmark in enumerate(landmark_list.landmark):
+ if (
+ (landmark.HasField('visibility') and landmark.visibility < 0.9) or
+ (landmark.HasField('presence') and landmark.presence < 0.5)
+ ):
+ continue
+ if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
+ continue
+ image_x = int(image_cols*landmark.x)
+ image_y = int(image_rows*landmark.y)
+ draw_color = None
+ if isinstance(drawing_spec, Mapping):
+ if drawing_spec.get(idx) is None:
+ continue
+ else:
+ draw_color = drawing_spec[idx].color
+ elif isinstance(drawing_spec, DrawingSpec):
+ draw_color = drawing_spec.color
+ image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color
+
+
+
+ def draw_landmarks(self, image_size, keypoints, normed=False):
+ ini_size = [512, 512]
+ image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
+ new_landmarks = landmark_pb2.NormalizedLandmarkList()
+ for i in range(keypoints.shape[0]):
+ landmark = new_landmarks.landmark.add()
+ if normed:
+ landmark.x = keypoints[i, 0]
+ landmark.y = keypoints[i, 1]
+ else:
+ landmark.x = keypoints[i, 0] / image_size[0]
+ landmark.y = keypoints[i, 1] / image_size[1]
+ landmark.z = 1.0
+
+ self.mp_drawing.draw_landmarks(
+ image=image,
+ landmark_list=new_landmarks,
+ connections=self.face_connection_spec.keys(),
+ landmark_drawing_spec=None,
+ connection_drawing_spec=self.face_connection_spec
+ )
+ # draw_pupils(image, face_landmarks, iris_landmark_spec, 2)
+ image = cv2.resize(image, (image_size[0], image_size[1]))
+
+ return image
+
diff --git a/aniportrait/src/utils/face_landmark.py b/aniportrait/src/utils/face_landmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6580cb2cded9dcfeab46b0d50c8931ed6256669
--- /dev/null
+++ b/aniportrait/src/utils/face_landmark.py
@@ -0,0 +1,3305 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MediaPipe face landmarker task."""
+
+import dataclasses
+import enum
+from typing import Callable, Mapping, Optional, List
+
+import numpy as np
+
+from mediapipe.framework.formats import classification_pb2
+from mediapipe.framework.formats import landmark_pb2
+from mediapipe.framework.formats import matrix_data_pb2
+from mediapipe.python import packet_creator
+from mediapipe.python import packet_getter
+from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.python._framework_bindings import packet as packet_module
+# pylint: disable=unused-import
+from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2
+# pylint: enable=unused-import
+from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_graph_options_pb2
+from mediapipe.tasks.python.components.containers import category as category_module
+from mediapipe.tasks.python.components.containers import landmark as landmark_module
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.core import task_info as task_info_module
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+from mediapipe.tasks.python.vision.core import base_vision_task_api
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
+from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
+
+_BaseOptions = base_options_module.BaseOptions
+_FaceLandmarkerGraphOptionsProto = (
+ face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions
+)
+_LayoutEnum = matrix_data_pb2.MatrixData.Layout
+_RunningMode = running_mode_module.VisionTaskRunningMode
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
+_TaskInfo = task_info_module.TaskInfo
+
+_IMAGE_IN_STREAM_NAME = 'image_in'
+_IMAGE_OUT_STREAM_NAME = 'image_out'
+_IMAGE_TAG = 'IMAGE'
+_NORM_RECT_STREAM_NAME = 'norm_rect_in'
+_NORM_RECT_TAG = 'NORM_RECT'
+_NORM_LANDMARKS_STREAM_NAME = 'norm_landmarks'
+_NORM_LANDMARKS_TAG = 'NORM_LANDMARKS'
+_BLENDSHAPES_STREAM_NAME = 'blendshapes'
+_BLENDSHAPES_TAG = 'BLENDSHAPES'
+_FACE_GEOMETRY_STREAM_NAME = 'face_geometry'
+_FACE_GEOMETRY_TAG = 'FACE_GEOMETRY'
+_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph'
+_MICRO_SECONDS_PER_MILLISECOND = 1000
+
+
+class Blendshapes(enum.IntEnum):
+ """The 52 blendshape coefficients."""
+
+ NEUTRAL = 0
+ BROW_DOWN_LEFT = 1
+ BROW_DOWN_RIGHT = 2
+ BROW_INNER_UP = 3
+ BROW_OUTER_UP_LEFT = 4
+ BROW_OUTER_UP_RIGHT = 5
+ CHEEK_PUFF = 6
+ CHEEK_SQUINT_LEFT = 7
+ CHEEK_SQUINT_RIGHT = 8
+ EYE_BLINK_LEFT = 9
+ EYE_BLINK_RIGHT = 10
+ EYE_LOOK_DOWN_LEFT = 11
+ EYE_LOOK_DOWN_RIGHT = 12
+ EYE_LOOK_IN_LEFT = 13
+ EYE_LOOK_IN_RIGHT = 14
+ EYE_LOOK_OUT_LEFT = 15
+ EYE_LOOK_OUT_RIGHT = 16
+ EYE_LOOK_UP_LEFT = 17
+ EYE_LOOK_UP_RIGHT = 18
+ EYE_SQUINT_LEFT = 19
+ EYE_SQUINT_RIGHT = 20
+ EYE_WIDE_LEFT = 21
+ EYE_WIDE_RIGHT = 22
+ JAW_FORWARD = 23
+ JAW_LEFT = 24
+ JAW_OPEN = 25
+ JAW_RIGHT = 26
+ MOUTH_CLOSE = 27
+ MOUTH_DIMPLE_LEFT = 28
+ MOUTH_DIMPLE_RIGHT = 29
+ MOUTH_FROWN_LEFT = 30
+ MOUTH_FROWN_RIGHT = 31
+ MOUTH_FUNNEL = 32
+ MOUTH_LEFT = 33
+ MOUTH_LOWER_DOWN_LEFT = 34
+ MOUTH_LOWER_DOWN_RIGHT = 35
+ MOUTH_PRESS_LEFT = 36
+ MOUTH_PRESS_RIGHT = 37
+ MOUTH_PUCKER = 38
+ MOUTH_RIGHT = 39
+ MOUTH_ROLL_LOWER = 40
+ MOUTH_ROLL_UPPER = 41
+ MOUTH_SHRUG_LOWER = 42
+ MOUTH_SHRUG_UPPER = 43
+ MOUTH_SMILE_LEFT = 44
+ MOUTH_SMILE_RIGHT = 45
+ MOUTH_STRETCH_LEFT = 46
+ MOUTH_STRETCH_RIGHT = 47
+ MOUTH_UPPER_UP_LEFT = 48
+ MOUTH_UPPER_UP_RIGHT = 49
+ NOSE_SNEER_LEFT = 50
+ NOSE_SNEER_RIGHT = 51
+
+
+class FaceLandmarksConnections:
+ """The connections between face landmarks."""
+
+ @dataclasses.dataclass
+ class Connection:
+ """The connection class for face landmarks."""
+
+ start: int
+ end: int
+
+ FACE_LANDMARKS_LIPS: List[Connection] = [
+ Connection(61, 146),
+ Connection(146, 91),
+ Connection(91, 181),
+ Connection(181, 84),
+ Connection(84, 17),
+ Connection(17, 314),
+ Connection(314, 405),
+ Connection(405, 321),
+ Connection(321, 375),
+ Connection(375, 291),
+ Connection(61, 185),
+ Connection(185, 40),
+ Connection(40, 39),
+ Connection(39, 37),
+ Connection(37, 0),
+ Connection(0, 267),
+ Connection(267, 269),
+ Connection(269, 270),
+ Connection(270, 409),
+ Connection(409, 291),
+ Connection(78, 95),
+ Connection(95, 88),
+ Connection(88, 178),
+ Connection(178, 87),
+ Connection(87, 14),
+ Connection(14, 317),
+ Connection(317, 402),
+ Connection(402, 318),
+ Connection(318, 324),
+ Connection(324, 308),
+ Connection(78, 191),
+ Connection(191, 80),
+ Connection(80, 81),
+ Connection(81, 82),
+ Connection(82, 13),
+ Connection(13, 312),
+ Connection(312, 311),
+ Connection(311, 310),
+ Connection(310, 415),
+ Connection(415, 308),
+ ]
+
+ FACE_LANDMARKS_LEFT_EYE: List[Connection] = [
+ Connection(263, 249),
+ Connection(249, 390),
+ Connection(390, 373),
+ Connection(373, 374),
+ Connection(374, 380),
+ Connection(380, 381),
+ Connection(381, 382),
+ Connection(382, 362),
+ Connection(263, 466),
+ Connection(466, 388),
+ Connection(388, 387),
+ Connection(387, 386),
+ Connection(386, 385),
+ Connection(385, 384),
+ Connection(384, 398),
+ Connection(398, 362),
+ ]
+
+ FACE_LANDMARKS_LEFT_EYEBROW: List[Connection] = [
+ Connection(276, 283),
+ Connection(283, 282),
+ Connection(282, 295),
+ Connection(295, 285),
+ Connection(300, 293),
+ Connection(293, 334),
+ Connection(334, 296),
+ Connection(296, 336),
+ ]
+
+ FACE_LANDMARKS_LEFT_IRIS: List[Connection] = [
+ Connection(474, 475),
+ Connection(475, 476),
+ Connection(476, 477),
+ Connection(477, 474),
+ ]
+
+ FACE_LANDMARKS_RIGHT_EYE: List[Connection] = [
+ Connection(33, 7),
+ Connection(7, 163),
+ Connection(163, 144),
+ Connection(144, 145),
+ Connection(145, 153),
+ Connection(153, 154),
+ Connection(154, 155),
+ Connection(155, 133),
+ Connection(33, 246),
+ Connection(246, 161),
+ Connection(161, 160),
+ Connection(160, 159),
+ Connection(159, 158),
+ Connection(158, 157),
+ Connection(157, 173),
+ Connection(173, 133),
+ ]
+
+ FACE_LANDMARKS_RIGHT_EYEBROW: List[Connection] = [
+ Connection(46, 53),
+ Connection(53, 52),
+ Connection(52, 65),
+ Connection(65, 55),
+ Connection(70, 63),
+ Connection(63, 105),
+ Connection(105, 66),
+ Connection(66, 107),
+ ]
+
+ FACE_LANDMARKS_RIGHT_IRIS: List[Connection] = [
+ Connection(469, 470),
+ Connection(470, 471),
+ Connection(471, 472),
+ Connection(472, 469),
+ ]
+
+ FACE_LANDMARKS_FACE_OVAL: List[Connection] = [
+ Connection(10, 338),
+ Connection(338, 297),
+ Connection(297, 332),
+ Connection(332, 284),
+ Connection(284, 251),
+ Connection(251, 389),
+ Connection(389, 356),
+ Connection(356, 454),
+ Connection(454, 323),
+ Connection(323, 361),
+ Connection(361, 288),
+ Connection(288, 397),
+ Connection(397, 365),
+ Connection(365, 379),
+ Connection(379, 378),
+ Connection(378, 400),
+ Connection(400, 377),
+ Connection(377, 152),
+ Connection(152, 148),
+ Connection(148, 176),
+ Connection(176, 149),
+ Connection(149, 150),
+ Connection(150, 136),
+ Connection(136, 172),
+ Connection(172, 58),
+ Connection(58, 132),
+ Connection(132, 93),
+ Connection(93, 234),
+ Connection(234, 127),
+ Connection(127, 162),
+ Connection(162, 21),
+ Connection(21, 54),
+ Connection(54, 103),
+ Connection(103, 67),
+ Connection(67, 109),
+ Connection(109, 10),
+ ]
+
+ FACE_LANDMARKS_CONTOURS: List[Connection] = (
+ FACE_LANDMARKS_LIPS
+ + FACE_LANDMARKS_LEFT_EYE
+ + FACE_LANDMARKS_LEFT_EYEBROW
+ + FACE_LANDMARKS_RIGHT_EYE
+ + FACE_LANDMARKS_RIGHT_EYEBROW
+ + FACE_LANDMARKS_FACE_OVAL
+ )
+
+ FACE_LANDMARKS_TESSELATION: List[Connection] = [
+ Connection(127, 34),
+ Connection(34, 139),
+ Connection(139, 127),
+ Connection(11, 0),
+ Connection(0, 37),
+ Connection(37, 11),
+ Connection(232, 231),
+ Connection(231, 120),
+ Connection(120, 232),
+ Connection(72, 37),
+ Connection(37, 39),
+ Connection(39, 72),
+ Connection(128, 121),
+ Connection(121, 47),
+ Connection(47, 128),
+ Connection(232, 121),
+ Connection(121, 128),
+ Connection(128, 232),
+ Connection(104, 69),
+ Connection(69, 67),
+ Connection(67, 104),
+ Connection(175, 171),
+ Connection(171, 148),
+ Connection(148, 175),
+ Connection(118, 50),
+ Connection(50, 101),
+ Connection(101, 118),
+ Connection(73, 39),
+ Connection(39, 40),
+ Connection(40, 73),
+ Connection(9, 151),
+ Connection(151, 108),
+ Connection(108, 9),
+ Connection(48, 115),
+ Connection(115, 131),
+ Connection(131, 48),
+ Connection(194, 204),
+ Connection(204, 211),
+ Connection(211, 194),
+ Connection(74, 40),
+ Connection(40, 185),
+ Connection(185, 74),
+ Connection(80, 42),
+ Connection(42, 183),
+ Connection(183, 80),
+ Connection(40, 92),
+ Connection(92, 186),
+ Connection(186, 40),
+ Connection(230, 229),
+ Connection(229, 118),
+ Connection(118, 230),
+ Connection(202, 212),
+ Connection(212, 214),
+ Connection(214, 202),
+ Connection(83, 18),
+ Connection(18, 17),
+ Connection(17, 83),
+ Connection(76, 61),
+ Connection(61, 146),
+ Connection(146, 76),
+ Connection(160, 29),
+ Connection(29, 30),
+ Connection(30, 160),
+ Connection(56, 157),
+ Connection(157, 173),
+ Connection(173, 56),
+ Connection(106, 204),
+ Connection(204, 194),
+ Connection(194, 106),
+ Connection(135, 214),
+ Connection(214, 192),
+ Connection(192, 135),
+ Connection(203, 165),
+ Connection(165, 98),
+ Connection(98, 203),
+ Connection(21, 71),
+ Connection(71, 68),
+ Connection(68, 21),
+ Connection(51, 45),
+ Connection(45, 4),
+ Connection(4, 51),
+ Connection(144, 24),
+ Connection(24, 23),
+ Connection(23, 144),
+ Connection(77, 146),
+ Connection(146, 91),
+ Connection(91, 77),
+ Connection(205, 50),
+ Connection(50, 187),
+ Connection(187, 205),
+ Connection(201, 200),
+ Connection(200, 18),
+ Connection(18, 201),
+ Connection(91, 106),
+ Connection(106, 182),
+ Connection(182, 91),
+ Connection(90, 91),
+ Connection(91, 181),
+ Connection(181, 90),
+ Connection(85, 84),
+ Connection(84, 17),
+ Connection(17, 85),
+ Connection(206, 203),
+ Connection(203, 36),
+ Connection(36, 206),
+ Connection(148, 171),
+ Connection(171, 140),
+ Connection(140, 148),
+ Connection(92, 40),
+ Connection(40, 39),
+ Connection(39, 92),
+ Connection(193, 189),
+ Connection(189, 244),
+ Connection(244, 193),
+ Connection(159, 158),
+ Connection(158, 28),
+ Connection(28, 159),
+ Connection(247, 246),
+ Connection(246, 161),
+ Connection(161, 247),
+ Connection(236, 3),
+ Connection(3, 196),
+ Connection(196, 236),
+ Connection(54, 68),
+ Connection(68, 104),
+ Connection(104, 54),
+ Connection(193, 168),
+ Connection(168, 8),
+ Connection(8, 193),
+ Connection(117, 228),
+ Connection(228, 31),
+ Connection(31, 117),
+ Connection(189, 193),
+ Connection(193, 55),
+ Connection(55, 189),
+ Connection(98, 97),
+ Connection(97, 99),
+ Connection(99, 98),
+ Connection(126, 47),
+ Connection(47, 100),
+ Connection(100, 126),
+ Connection(166, 79),
+ Connection(79, 218),
+ Connection(218, 166),
+ Connection(155, 154),
+ Connection(154, 26),
+ Connection(26, 155),
+ Connection(209, 49),
+ Connection(49, 131),
+ Connection(131, 209),
+ Connection(135, 136),
+ Connection(136, 150),
+ Connection(150, 135),
+ Connection(47, 126),
+ Connection(126, 217),
+ Connection(217, 47),
+ Connection(223, 52),
+ Connection(52, 53),
+ Connection(53, 223),
+ Connection(45, 51),
+ Connection(51, 134),
+ Connection(134, 45),
+ Connection(211, 170),
+ Connection(170, 140),
+ Connection(140, 211),
+ Connection(67, 69),
+ Connection(69, 108),
+ Connection(108, 67),
+ Connection(43, 106),
+ Connection(106, 91),
+ Connection(91, 43),
+ Connection(230, 119),
+ Connection(119, 120),
+ Connection(120, 230),
+ Connection(226, 130),
+ Connection(130, 247),
+ Connection(247, 226),
+ Connection(63, 53),
+ Connection(53, 52),
+ Connection(52, 63),
+ Connection(238, 20),
+ Connection(20, 242),
+ Connection(242, 238),
+ Connection(46, 70),
+ Connection(70, 156),
+ Connection(156, 46),
+ Connection(78, 62),
+ Connection(62, 96),
+ Connection(96, 78),
+ Connection(46, 53),
+ Connection(53, 63),
+ Connection(63, 46),
+ Connection(143, 34),
+ Connection(34, 227),
+ Connection(227, 143),
+ Connection(123, 117),
+ Connection(117, 111),
+ Connection(111, 123),
+ Connection(44, 125),
+ Connection(125, 19),
+ Connection(19, 44),
+ Connection(236, 134),
+ Connection(134, 51),
+ Connection(51, 236),
+ Connection(216, 206),
+ Connection(206, 205),
+ Connection(205, 216),
+ Connection(154, 153),
+ Connection(153, 22),
+ Connection(22, 154),
+ Connection(39, 37),
+ Connection(37, 167),
+ Connection(167, 39),
+ Connection(200, 201),
+ Connection(201, 208),
+ Connection(208, 200),
+ Connection(36, 142),
+ Connection(142, 100),
+ Connection(100, 36),
+ Connection(57, 212),
+ Connection(212, 202),
+ Connection(202, 57),
+ Connection(20, 60),
+ Connection(60, 99),
+ Connection(99, 20),
+ Connection(28, 158),
+ Connection(158, 157),
+ Connection(157, 28),
+ Connection(35, 226),
+ Connection(226, 113),
+ Connection(113, 35),
+ Connection(160, 159),
+ Connection(159, 27),
+ Connection(27, 160),
+ Connection(204, 202),
+ Connection(202, 210),
+ Connection(210, 204),
+ Connection(113, 225),
+ Connection(225, 46),
+ Connection(46, 113),
+ Connection(43, 202),
+ Connection(202, 204),
+ Connection(204, 43),
+ Connection(62, 76),
+ Connection(76, 77),
+ Connection(77, 62),
+ Connection(137, 123),
+ Connection(123, 116),
+ Connection(116, 137),
+ Connection(41, 38),
+ Connection(38, 72),
+ Connection(72, 41),
+ Connection(203, 129),
+ Connection(129, 142),
+ Connection(142, 203),
+ Connection(64, 98),
+ Connection(98, 240),
+ Connection(240, 64),
+ Connection(49, 102),
+ Connection(102, 64),
+ Connection(64, 49),
+ Connection(41, 73),
+ Connection(73, 74),
+ Connection(74, 41),
+ Connection(212, 216),
+ Connection(216, 207),
+ Connection(207, 212),
+ Connection(42, 74),
+ Connection(74, 184),
+ Connection(184, 42),
+ Connection(169, 170),
+ Connection(170, 211),
+ Connection(211, 169),
+ Connection(170, 149),
+ Connection(149, 176),
+ Connection(176, 170),
+ Connection(105, 66),
+ Connection(66, 69),
+ Connection(69, 105),
+ Connection(122, 6),
+ Connection(6, 168),
+ Connection(168, 122),
+ Connection(123, 147),
+ Connection(147, 187),
+ Connection(187, 123),
+ Connection(96, 77),
+ Connection(77, 90),
+ Connection(90, 96),
+ Connection(65, 55),
+ Connection(55, 107),
+ Connection(107, 65),
+ Connection(89, 90),
+ Connection(90, 180),
+ Connection(180, 89),
+ Connection(101, 100),
+ Connection(100, 120),
+ Connection(120, 101),
+ Connection(63, 105),
+ Connection(105, 104),
+ Connection(104, 63),
+ Connection(93, 137),
+ Connection(137, 227),
+ Connection(227, 93),
+ Connection(15, 86),
+ Connection(86, 85),
+ Connection(85, 15),
+ Connection(129, 102),
+ Connection(102, 49),
+ Connection(49, 129),
+ Connection(14, 87),
+ Connection(87, 86),
+ Connection(86, 14),
+ Connection(55, 8),
+ Connection(8, 9),
+ Connection(9, 55),
+ Connection(100, 47),
+ Connection(47, 121),
+ Connection(121, 100),
+ Connection(145, 23),
+ Connection(23, 22),
+ Connection(22, 145),
+ Connection(88, 89),
+ Connection(89, 179),
+ Connection(179, 88),
+ Connection(6, 122),
+ Connection(122, 196),
+ Connection(196, 6),
+ Connection(88, 95),
+ Connection(95, 96),
+ Connection(96, 88),
+ Connection(138, 172),
+ Connection(172, 136),
+ Connection(136, 138),
+ Connection(215, 58),
+ Connection(58, 172),
+ Connection(172, 215),
+ Connection(115, 48),
+ Connection(48, 219),
+ Connection(219, 115),
+ Connection(42, 80),
+ Connection(80, 81),
+ Connection(81, 42),
+ Connection(195, 3),
+ Connection(3, 51),
+ Connection(51, 195),
+ Connection(43, 146),
+ Connection(146, 61),
+ Connection(61, 43),
+ Connection(171, 175),
+ Connection(175, 199),
+ Connection(199, 171),
+ Connection(81, 82),
+ Connection(82, 38),
+ Connection(38, 81),
+ Connection(53, 46),
+ Connection(46, 225),
+ Connection(225, 53),
+ Connection(144, 163),
+ Connection(163, 110),
+ Connection(110, 144),
+ Connection(52, 65),
+ Connection(65, 66),
+ Connection(66, 52),
+ Connection(229, 228),
+ Connection(228, 117),
+ Connection(117, 229),
+ Connection(34, 127),
+ Connection(127, 234),
+ Connection(234, 34),
+ Connection(107, 108),
+ Connection(108, 69),
+ Connection(69, 107),
+ Connection(109, 108),
+ Connection(108, 151),
+ Connection(151, 109),
+ Connection(48, 64),
+ Connection(64, 235),
+ Connection(235, 48),
+ Connection(62, 78),
+ Connection(78, 191),
+ Connection(191, 62),
+ Connection(129, 209),
+ Connection(209, 126),
+ Connection(126, 129),
+ Connection(111, 35),
+ Connection(35, 143),
+ Connection(143, 111),
+ Connection(117, 123),
+ Connection(123, 50),
+ Connection(50, 117),
+ Connection(222, 65),
+ Connection(65, 52),
+ Connection(52, 222),
+ Connection(19, 125),
+ Connection(125, 141),
+ Connection(141, 19),
+ Connection(221, 55),
+ Connection(55, 65),
+ Connection(65, 221),
+ Connection(3, 195),
+ Connection(195, 197),
+ Connection(197, 3),
+ Connection(25, 7),
+ Connection(7, 33),
+ Connection(33, 25),
+ Connection(220, 237),
+ Connection(237, 44),
+ Connection(44, 220),
+ Connection(70, 71),
+ Connection(71, 139),
+ Connection(139, 70),
+ Connection(122, 193),
+ Connection(193, 245),
+ Connection(245, 122),
+ Connection(247, 130),
+ Connection(130, 33),
+ Connection(33, 247),
+ Connection(71, 21),
+ Connection(21, 162),
+ Connection(162, 71),
+ Connection(170, 169),
+ Connection(169, 150),
+ Connection(150, 170),
+ Connection(188, 174),
+ Connection(174, 196),
+ Connection(196, 188),
+ Connection(216, 186),
+ Connection(186, 92),
+ Connection(92, 216),
+ Connection(2, 97),
+ Connection(97, 167),
+ Connection(167, 2),
+ Connection(141, 125),
+ Connection(125, 241),
+ Connection(241, 141),
+ Connection(164, 167),
+ Connection(167, 37),
+ Connection(37, 164),
+ Connection(72, 38),
+ Connection(38, 12),
+ Connection(12, 72),
+ Connection(38, 82),
+ Connection(82, 13),
+ Connection(13, 38),
+ Connection(63, 68),
+ Connection(68, 71),
+ Connection(71, 63),
+ Connection(226, 35),
+ Connection(35, 111),
+ Connection(111, 226),
+ Connection(101, 50),
+ Connection(50, 205),
+ Connection(205, 101),
+ Connection(206, 92),
+ Connection(92, 165),
+ Connection(165, 206),
+ Connection(209, 198),
+ Connection(198, 217),
+ Connection(217, 209),
+ Connection(165, 167),
+ Connection(167, 97),
+ Connection(97, 165),
+ Connection(220, 115),
+ Connection(115, 218),
+ Connection(218, 220),
+ Connection(133, 112),
+ Connection(112, 243),
+ Connection(243, 133),
+ Connection(239, 238),
+ Connection(238, 241),
+ Connection(241, 239),
+ Connection(214, 135),
+ Connection(135, 169),
+ Connection(169, 214),
+ Connection(190, 173),
+ Connection(173, 133),
+ Connection(133, 190),
+ Connection(171, 208),
+ Connection(208, 32),
+ Connection(32, 171),
+ Connection(125, 44),
+ Connection(44, 237),
+ Connection(237, 125),
+ Connection(86, 87),
+ Connection(87, 178),
+ Connection(178, 86),
+ Connection(85, 86),
+ Connection(86, 179),
+ Connection(179, 85),
+ Connection(84, 85),
+ Connection(85, 180),
+ Connection(180, 84),
+ Connection(83, 84),
+ Connection(84, 181),
+ Connection(181, 83),
+ Connection(201, 83),
+ Connection(83, 182),
+ Connection(182, 201),
+ Connection(137, 93),
+ Connection(93, 132),
+ Connection(132, 137),
+ Connection(76, 62),
+ Connection(62, 183),
+ Connection(183, 76),
+ Connection(61, 76),
+ Connection(76, 184),
+ Connection(184, 61),
+ Connection(57, 61),
+ Connection(61, 185),
+ Connection(185, 57),
+ Connection(212, 57),
+ Connection(57, 186),
+ Connection(186, 212),
+ Connection(214, 207),
+ Connection(207, 187),
+ Connection(187, 214),
+ Connection(34, 143),
+ Connection(143, 156),
+ Connection(156, 34),
+ Connection(79, 239),
+ Connection(239, 237),
+ Connection(237, 79),
+ Connection(123, 137),
+ Connection(137, 177),
+ Connection(177, 123),
+ Connection(44, 1),
+ Connection(1, 4),
+ Connection(4, 44),
+ Connection(201, 194),
+ Connection(194, 32),
+ Connection(32, 201),
+ Connection(64, 102),
+ Connection(102, 129),
+ Connection(129, 64),
+ Connection(213, 215),
+ Connection(215, 138),
+ Connection(138, 213),
+ Connection(59, 166),
+ Connection(166, 219),
+ Connection(219, 59),
+ Connection(242, 99),
+ Connection(99, 97),
+ Connection(97, 242),
+ Connection(2, 94),
+ Connection(94, 141),
+ Connection(141, 2),
+ Connection(75, 59),
+ Connection(59, 235),
+ Connection(235, 75),
+ Connection(24, 110),
+ Connection(110, 228),
+ Connection(228, 24),
+ Connection(25, 130),
+ Connection(130, 226),
+ Connection(226, 25),
+ Connection(23, 24),
+ Connection(24, 229),
+ Connection(229, 23),
+ Connection(22, 23),
+ Connection(23, 230),
+ Connection(230, 22),
+ Connection(26, 22),
+ Connection(22, 231),
+ Connection(231, 26),
+ Connection(112, 26),
+ Connection(26, 232),
+ Connection(232, 112),
+ Connection(189, 190),
+ Connection(190, 243),
+ Connection(243, 189),
+ Connection(221, 56),
+ Connection(56, 190),
+ Connection(190, 221),
+ Connection(28, 56),
+ Connection(56, 221),
+ Connection(221, 28),
+ Connection(27, 28),
+ Connection(28, 222),
+ Connection(222, 27),
+ Connection(29, 27),
+ Connection(27, 223),
+ Connection(223, 29),
+ Connection(30, 29),
+ Connection(29, 224),
+ Connection(224, 30),
+ Connection(247, 30),
+ Connection(30, 225),
+ Connection(225, 247),
+ Connection(238, 79),
+ Connection(79, 20),
+ Connection(20, 238),
+ Connection(166, 59),
+ Connection(59, 75),
+ Connection(75, 166),
+ Connection(60, 75),
+ Connection(75, 240),
+ Connection(240, 60),
+ Connection(147, 177),
+ Connection(177, 215),
+ Connection(215, 147),
+ Connection(20, 79),
+ Connection(79, 166),
+ Connection(166, 20),
+ Connection(187, 147),
+ Connection(147, 213),
+ Connection(213, 187),
+ Connection(112, 233),
+ Connection(233, 244),
+ Connection(244, 112),
+ Connection(233, 128),
+ Connection(128, 245),
+ Connection(245, 233),
+ Connection(128, 114),
+ Connection(114, 188),
+ Connection(188, 128),
+ Connection(114, 217),
+ Connection(217, 174),
+ Connection(174, 114),
+ Connection(131, 115),
+ Connection(115, 220),
+ Connection(220, 131),
+ Connection(217, 198),
+ Connection(198, 236),
+ Connection(236, 217),
+ Connection(198, 131),
+ Connection(131, 134),
+ Connection(134, 198),
+ Connection(177, 132),
+ Connection(132, 58),
+ Connection(58, 177),
+ Connection(143, 35),
+ Connection(35, 124),
+ Connection(124, 143),
+ Connection(110, 163),
+ Connection(163, 7),
+ Connection(7, 110),
+ Connection(228, 110),
+ Connection(110, 25),
+ Connection(25, 228),
+ Connection(356, 389),
+ Connection(389, 368),
+ Connection(368, 356),
+ Connection(11, 302),
+ Connection(302, 267),
+ Connection(267, 11),
+ Connection(452, 350),
+ Connection(350, 349),
+ Connection(349, 452),
+ Connection(302, 303),
+ Connection(303, 269),
+ Connection(269, 302),
+ Connection(357, 343),
+ Connection(343, 277),
+ Connection(277, 357),
+ Connection(452, 453),
+ Connection(453, 357),
+ Connection(357, 452),
+ Connection(333, 332),
+ Connection(332, 297),
+ Connection(297, 333),
+ Connection(175, 152),
+ Connection(152, 377),
+ Connection(377, 175),
+ Connection(347, 348),
+ Connection(348, 330),
+ Connection(330, 347),
+ Connection(303, 304),
+ Connection(304, 270),
+ Connection(270, 303),
+ Connection(9, 336),
+ Connection(336, 337),
+ Connection(337, 9),
+ Connection(278, 279),
+ Connection(279, 360),
+ Connection(360, 278),
+ Connection(418, 262),
+ Connection(262, 431),
+ Connection(431, 418),
+ Connection(304, 408),
+ Connection(408, 409),
+ Connection(409, 304),
+ Connection(310, 415),
+ Connection(415, 407),
+ Connection(407, 310),
+ Connection(270, 409),
+ Connection(409, 410),
+ Connection(410, 270),
+ Connection(450, 348),
+ Connection(348, 347),
+ Connection(347, 450),
+ Connection(422, 430),
+ Connection(430, 434),
+ Connection(434, 422),
+ Connection(313, 314),
+ Connection(314, 17),
+ Connection(17, 313),
+ Connection(306, 307),
+ Connection(307, 375),
+ Connection(375, 306),
+ Connection(387, 388),
+ Connection(388, 260),
+ Connection(260, 387),
+ Connection(286, 414),
+ Connection(414, 398),
+ Connection(398, 286),
+ Connection(335, 406),
+ Connection(406, 418),
+ Connection(418, 335),
+ Connection(364, 367),
+ Connection(367, 416),
+ Connection(416, 364),
+ Connection(423, 358),
+ Connection(358, 327),
+ Connection(327, 423),
+ Connection(251, 284),
+ Connection(284, 298),
+ Connection(298, 251),
+ Connection(281, 5),
+ Connection(5, 4),
+ Connection(4, 281),
+ Connection(373, 374),
+ Connection(374, 253),
+ Connection(253, 373),
+ Connection(307, 320),
+ Connection(320, 321),
+ Connection(321, 307),
+ Connection(425, 427),
+ Connection(427, 411),
+ Connection(411, 425),
+ Connection(421, 313),
+ Connection(313, 18),
+ Connection(18, 421),
+ Connection(321, 405),
+ Connection(405, 406),
+ Connection(406, 321),
+ Connection(320, 404),
+ Connection(404, 405),
+ Connection(405, 320),
+ Connection(315, 16),
+ Connection(16, 17),
+ Connection(17, 315),
+ Connection(426, 425),
+ Connection(425, 266),
+ Connection(266, 426),
+ Connection(377, 400),
+ Connection(400, 369),
+ Connection(369, 377),
+ Connection(322, 391),
+ Connection(391, 269),
+ Connection(269, 322),
+ Connection(417, 465),
+ Connection(465, 464),
+ Connection(464, 417),
+ Connection(386, 257),
+ Connection(257, 258),
+ Connection(258, 386),
+ Connection(466, 260),
+ Connection(260, 388),
+ Connection(388, 466),
+ Connection(456, 399),
+ Connection(399, 419),
+ Connection(419, 456),
+ Connection(284, 332),
+ Connection(332, 333),
+ Connection(333, 284),
+ Connection(417, 285),
+ Connection(285, 8),
+ Connection(8, 417),
+ Connection(346, 340),
+ Connection(340, 261),
+ Connection(261, 346),
+ Connection(413, 441),
+ Connection(441, 285),
+ Connection(285, 413),
+ Connection(327, 460),
+ Connection(460, 328),
+ Connection(328, 327),
+ Connection(355, 371),
+ Connection(371, 329),
+ Connection(329, 355),
+ Connection(392, 439),
+ Connection(439, 438),
+ Connection(438, 392),
+ Connection(382, 341),
+ Connection(341, 256),
+ Connection(256, 382),
+ Connection(429, 420),
+ Connection(420, 360),
+ Connection(360, 429),
+ Connection(364, 394),
+ Connection(394, 379),
+ Connection(379, 364),
+ Connection(277, 343),
+ Connection(343, 437),
+ Connection(437, 277),
+ Connection(443, 444),
+ Connection(444, 283),
+ Connection(283, 443),
+ Connection(275, 440),
+ Connection(440, 363),
+ Connection(363, 275),
+ Connection(431, 262),
+ Connection(262, 369),
+ Connection(369, 431),
+ Connection(297, 338),
+ Connection(338, 337),
+ Connection(337, 297),
+ Connection(273, 375),
+ Connection(375, 321),
+ Connection(321, 273),
+ Connection(450, 451),
+ Connection(451, 349),
+ Connection(349, 450),
+ Connection(446, 342),
+ Connection(342, 467),
+ Connection(467, 446),
+ Connection(293, 334),
+ Connection(334, 282),
+ Connection(282, 293),
+ Connection(458, 461),
+ Connection(461, 462),
+ Connection(462, 458),
+ Connection(276, 353),
+ Connection(353, 383),
+ Connection(383, 276),
+ Connection(308, 324),
+ Connection(324, 325),
+ Connection(325, 308),
+ Connection(276, 300),
+ Connection(300, 293),
+ Connection(293, 276),
+ Connection(372, 345),
+ Connection(345, 447),
+ Connection(447, 372),
+ Connection(352, 345),
+ Connection(345, 340),
+ Connection(340, 352),
+ Connection(274, 1),
+ Connection(1, 19),
+ Connection(19, 274),
+ Connection(456, 248),
+ Connection(248, 281),
+ Connection(281, 456),
+ Connection(436, 427),
+ Connection(427, 425),
+ Connection(425, 436),
+ Connection(381, 256),
+ Connection(256, 252),
+ Connection(252, 381),
+ Connection(269, 391),
+ Connection(391, 393),
+ Connection(393, 269),
+ Connection(200, 199),
+ Connection(199, 428),
+ Connection(428, 200),
+ Connection(266, 330),
+ Connection(330, 329),
+ Connection(329, 266),
+ Connection(287, 273),
+ Connection(273, 422),
+ Connection(422, 287),
+ Connection(250, 462),
+ Connection(462, 328),
+ Connection(328, 250),
+ Connection(258, 286),
+ Connection(286, 384),
+ Connection(384, 258),
+ Connection(265, 353),
+ Connection(353, 342),
+ Connection(342, 265),
+ Connection(387, 259),
+ Connection(259, 257),
+ Connection(257, 387),
+ Connection(424, 431),
+ Connection(431, 430),
+ Connection(430, 424),
+ Connection(342, 353),
+ Connection(353, 276),
+ Connection(276, 342),
+ Connection(273, 335),
+ Connection(335, 424),
+ Connection(424, 273),
+ Connection(292, 325),
+ Connection(325, 307),
+ Connection(307, 292),
+ Connection(366, 447),
+ Connection(447, 345),
+ Connection(345, 366),
+ Connection(271, 303),
+ Connection(303, 302),
+ Connection(302, 271),
+ Connection(423, 266),
+ Connection(266, 371),
+ Connection(371, 423),
+ Connection(294, 455),
+ Connection(455, 460),
+ Connection(460, 294),
+ Connection(279, 278),
+ Connection(278, 294),
+ Connection(294, 279),
+ Connection(271, 272),
+ Connection(272, 304),
+ Connection(304, 271),
+ Connection(432, 434),
+ Connection(434, 427),
+ Connection(427, 432),
+ Connection(272, 407),
+ Connection(407, 408),
+ Connection(408, 272),
+ Connection(394, 430),
+ Connection(430, 431),
+ Connection(431, 394),
+ Connection(395, 369),
+ Connection(369, 400),
+ Connection(400, 395),
+ Connection(334, 333),
+ Connection(333, 299),
+ Connection(299, 334),
+ Connection(351, 417),
+ Connection(417, 168),
+ Connection(168, 351),
+ Connection(352, 280),
+ Connection(280, 411),
+ Connection(411, 352),
+ Connection(325, 319),
+ Connection(319, 320),
+ Connection(320, 325),
+ Connection(295, 296),
+ Connection(296, 336),
+ Connection(336, 295),
+ Connection(319, 403),
+ Connection(403, 404),
+ Connection(404, 319),
+ Connection(330, 348),
+ Connection(348, 349),
+ Connection(349, 330),
+ Connection(293, 298),
+ Connection(298, 333),
+ Connection(333, 293),
+ Connection(323, 454),
+ Connection(454, 447),
+ Connection(447, 323),
+ Connection(15, 16),
+ Connection(16, 315),
+ Connection(315, 15),
+ Connection(358, 429),
+ Connection(429, 279),
+ Connection(279, 358),
+ Connection(14, 15),
+ Connection(15, 316),
+ Connection(316, 14),
+ Connection(285, 336),
+ Connection(336, 9),
+ Connection(9, 285),
+ Connection(329, 349),
+ Connection(349, 350),
+ Connection(350, 329),
+ Connection(374, 380),
+ Connection(380, 252),
+ Connection(252, 374),
+ Connection(318, 402),
+ Connection(402, 403),
+ Connection(403, 318),
+ Connection(6, 197),
+ Connection(197, 419),
+ Connection(419, 6),
+ Connection(318, 319),
+ Connection(319, 325),
+ Connection(325, 318),
+ Connection(367, 364),
+ Connection(364, 365),
+ Connection(365, 367),
+ Connection(435, 367),
+ Connection(367, 397),
+ Connection(397, 435),
+ Connection(344, 438),
+ Connection(438, 439),
+ Connection(439, 344),
+ Connection(272, 271),
+ Connection(271, 311),
+ Connection(311, 272),
+ Connection(195, 5),
+ Connection(5, 281),
+ Connection(281, 195),
+ Connection(273, 287),
+ Connection(287, 291),
+ Connection(291, 273),
+ Connection(396, 428),
+ Connection(428, 199),
+ Connection(199, 396),
+ Connection(311, 271),
+ Connection(271, 268),
+ Connection(268, 311),
+ Connection(283, 444),
+ Connection(444, 445),
+ Connection(445, 283),
+ Connection(373, 254),
+ Connection(254, 339),
+ Connection(339, 373),
+ Connection(282, 334),
+ Connection(334, 296),
+ Connection(296, 282),
+ Connection(449, 347),
+ Connection(347, 346),
+ Connection(346, 449),
+ Connection(264, 447),
+ Connection(447, 454),
+ Connection(454, 264),
+ Connection(336, 296),
+ Connection(296, 299),
+ Connection(299, 336),
+ Connection(338, 10),
+ Connection(10, 151),
+ Connection(151, 338),
+ Connection(278, 439),
+ Connection(439, 455),
+ Connection(455, 278),
+ Connection(292, 407),
+ Connection(407, 415),
+ Connection(415, 292),
+ Connection(358, 371),
+ Connection(371, 355),
+ Connection(355, 358),
+ Connection(340, 345),
+ Connection(345, 372),
+ Connection(372, 340),
+ Connection(346, 347),
+ Connection(347, 280),
+ Connection(280, 346),
+ Connection(442, 443),
+ Connection(443, 282),
+ Connection(282, 442),
+ Connection(19, 94),
+ Connection(94, 370),
+ Connection(370, 19),
+ Connection(441, 442),
+ Connection(442, 295),
+ Connection(295, 441),
+ Connection(248, 419),
+ Connection(419, 197),
+ Connection(197, 248),
+ Connection(263, 255),
+ Connection(255, 359),
+ Connection(359, 263),
+ Connection(440, 275),
+ Connection(275, 274),
+ Connection(274, 440),
+ Connection(300, 383),
+ Connection(383, 368),
+ Connection(368, 300),
+ Connection(351, 412),
+ Connection(412, 465),
+ Connection(465, 351),
+ Connection(263, 467),
+ Connection(467, 466),
+ Connection(466, 263),
+ Connection(301, 368),
+ Connection(368, 389),
+ Connection(389, 301),
+ Connection(395, 378),
+ Connection(378, 379),
+ Connection(379, 395),
+ Connection(412, 351),
+ Connection(351, 419),
+ Connection(419, 412),
+ Connection(436, 426),
+ Connection(426, 322),
+ Connection(322, 436),
+ Connection(2, 164),
+ Connection(164, 393),
+ Connection(393, 2),
+ Connection(370, 462),
+ Connection(462, 461),
+ Connection(461, 370),
+ Connection(164, 0),
+ Connection(0, 267),
+ Connection(267, 164),
+ Connection(302, 11),
+ Connection(11, 12),
+ Connection(12, 302),
+ Connection(268, 12),
+ Connection(12, 13),
+ Connection(13, 268),
+ Connection(293, 300),
+ Connection(300, 301),
+ Connection(301, 293),
+ Connection(446, 261),
+ Connection(261, 340),
+ Connection(340, 446),
+ Connection(330, 266),
+ Connection(266, 425),
+ Connection(425, 330),
+ Connection(426, 423),
+ Connection(423, 391),
+ Connection(391, 426),
+ Connection(429, 355),
+ Connection(355, 437),
+ Connection(437, 429),
+ Connection(391, 327),
+ Connection(327, 326),
+ Connection(326, 391),
+ Connection(440, 457),
+ Connection(457, 438),
+ Connection(438, 440),
+ Connection(341, 382),
+ Connection(382, 362),
+ Connection(362, 341),
+ Connection(459, 457),
+ Connection(457, 461),
+ Connection(461, 459),
+ Connection(434, 430),
+ Connection(430, 394),
+ Connection(394, 434),
+ Connection(414, 463),
+ Connection(463, 362),
+ Connection(362, 414),
+ Connection(396, 369),
+ Connection(369, 262),
+ Connection(262, 396),
+ Connection(354, 461),
+ Connection(461, 457),
+ Connection(457, 354),
+ Connection(316, 403),
+ Connection(403, 402),
+ Connection(402, 316),
+ Connection(315, 404),
+ Connection(404, 403),
+ Connection(403, 315),
+ Connection(314, 405),
+ Connection(405, 404),
+ Connection(404, 314),
+ Connection(313, 406),
+ Connection(406, 405),
+ Connection(405, 313),
+ Connection(421, 418),
+ Connection(418, 406),
+ Connection(406, 421),
+ Connection(366, 401),
+ Connection(401, 361),
+ Connection(361, 366),
+ Connection(306, 408),
+ Connection(408, 407),
+ Connection(407, 306),
+ Connection(291, 409),
+ Connection(409, 408),
+ Connection(408, 291),
+ Connection(287, 410),
+ Connection(410, 409),
+ Connection(409, 287),
+ Connection(432, 436),
+ Connection(436, 410),
+ Connection(410, 432),
+ Connection(434, 416),
+ Connection(416, 411),
+ Connection(411, 434),
+ Connection(264, 368),
+ Connection(368, 383),
+ Connection(383, 264),
+ Connection(309, 438),
+ Connection(438, 457),
+ Connection(457, 309),
+ Connection(352, 376),
+ Connection(376, 401),
+ Connection(401, 352),
+ Connection(274, 275),
+ Connection(275, 4),
+ Connection(4, 274),
+ Connection(421, 428),
+ Connection(428, 262),
+ Connection(262, 421),
+ Connection(294, 327),
+ Connection(327, 358),
+ Connection(358, 294),
+ Connection(433, 416),
+ Connection(416, 367),
+ Connection(367, 433),
+ Connection(289, 455),
+ Connection(455, 439),
+ Connection(439, 289),
+ Connection(462, 370),
+ Connection(370, 326),
+ Connection(326, 462),
+ Connection(2, 326),
+ Connection(326, 370),
+ Connection(370, 2),
+ Connection(305, 460),
+ Connection(460, 455),
+ Connection(455, 305),
+ Connection(254, 449),
+ Connection(449, 448),
+ Connection(448, 254),
+ Connection(255, 261),
+ Connection(261, 446),
+ Connection(446, 255),
+ Connection(253, 450),
+ Connection(450, 449),
+ Connection(449, 253),
+ Connection(252, 451),
+ Connection(451, 450),
+ Connection(450, 252),
+ Connection(256, 452),
+ Connection(452, 451),
+ Connection(451, 256),
+ Connection(341, 453),
+ Connection(453, 452),
+ Connection(452, 341),
+ Connection(413, 464),
+ Connection(464, 463),
+ Connection(463, 413),
+ Connection(441, 413),
+ Connection(413, 414),
+ Connection(414, 441),
+ Connection(258, 442),
+ Connection(442, 441),
+ Connection(441, 258),
+ Connection(257, 443),
+ Connection(443, 442),
+ Connection(442, 257),
+ Connection(259, 444),
+ Connection(444, 443),
+ Connection(443, 259),
+ Connection(260, 445),
+ Connection(445, 444),
+ Connection(444, 260),
+ Connection(467, 342),
+ Connection(342, 445),
+ Connection(445, 467),
+ Connection(459, 458),
+ Connection(458, 250),
+ Connection(250, 459),
+ Connection(289, 392),
+ Connection(392, 290),
+ Connection(290, 289),
+ Connection(290, 328),
+ Connection(328, 460),
+ Connection(460, 290),
+ Connection(376, 433),
+ Connection(433, 435),
+ Connection(435, 376),
+ Connection(250, 290),
+ Connection(290, 392),
+ Connection(392, 250),
+ Connection(411, 416),
+ Connection(416, 433),
+ Connection(433, 411),
+ Connection(341, 463),
+ Connection(463, 464),
+ Connection(464, 341),
+ Connection(453, 464),
+ Connection(464, 465),
+ Connection(465, 453),
+ Connection(357, 465),
+ Connection(465, 412),
+ Connection(412, 357),
+ Connection(343, 412),
+ Connection(412, 399),
+ Connection(399, 343),
+ Connection(360, 363),
+ Connection(363, 440),
+ Connection(440, 360),
+ Connection(437, 399),
+ Connection(399, 456),
+ Connection(456, 437),
+ Connection(420, 456),
+ Connection(456, 363),
+ Connection(363, 420),
+ Connection(401, 435),
+ Connection(435, 288),
+ Connection(288, 401),
+ Connection(372, 383),
+ Connection(383, 353),
+ Connection(353, 372),
+ Connection(339, 255),
+ Connection(255, 249),
+ Connection(249, 339),
+ Connection(448, 261),
+ Connection(261, 255),
+ Connection(255, 448),
+ Connection(133, 243),
+ Connection(243, 190),
+ Connection(190, 133),
+ Connection(133, 155),
+ Connection(155, 112),
+ Connection(112, 133),
+ Connection(33, 246),
+ Connection(246, 247),
+ Connection(247, 33),
+ Connection(33, 130),
+ Connection(130, 25),
+ Connection(25, 33),
+ Connection(398, 384),
+ Connection(384, 286),
+ Connection(286, 398),
+ Connection(362, 398),
+ Connection(398, 414),
+ Connection(414, 362),
+ Connection(362, 463),
+ Connection(463, 341),
+ Connection(341, 362),
+ Connection(263, 359),
+ Connection(359, 467),
+ Connection(467, 263),
+ Connection(263, 249),
+ Connection(249, 255),
+ Connection(255, 263),
+ Connection(466, 467),
+ Connection(467, 260),
+ Connection(260, 466),
+ Connection(75, 60),
+ Connection(60, 166),
+ Connection(166, 75),
+ Connection(238, 239),
+ Connection(239, 79),
+ Connection(79, 238),
+ Connection(162, 127),
+ Connection(127, 139),
+ Connection(139, 162),
+ Connection(72, 11),
+ Connection(11, 37),
+ Connection(37, 72),
+ Connection(121, 232),
+ Connection(232, 120),
+ Connection(120, 121),
+ Connection(73, 72),
+ Connection(72, 39),
+ Connection(39, 73),
+ Connection(114, 128),
+ Connection(128, 47),
+ Connection(47, 114),
+ Connection(233, 232),
+ Connection(232, 128),
+ Connection(128, 233),
+ Connection(103, 104),
+ Connection(104, 67),
+ Connection(67, 103),
+ Connection(152, 175),
+ Connection(175, 148),
+ Connection(148, 152),
+ Connection(119, 118),
+ Connection(118, 101),
+ Connection(101, 119),
+ Connection(74, 73),
+ Connection(73, 40),
+ Connection(40, 74),
+ Connection(107, 9),
+ Connection(9, 108),
+ Connection(108, 107),
+ Connection(49, 48),
+ Connection(48, 131),
+ Connection(131, 49),
+ Connection(32, 194),
+ Connection(194, 211),
+ Connection(211, 32),
+ Connection(184, 74),
+ Connection(74, 185),
+ Connection(185, 184),
+ Connection(191, 80),
+ Connection(80, 183),
+ Connection(183, 191),
+ Connection(185, 40),
+ Connection(40, 186),
+ Connection(186, 185),
+ Connection(119, 230),
+ Connection(230, 118),
+ Connection(118, 119),
+ Connection(210, 202),
+ Connection(202, 214),
+ Connection(214, 210),
+ Connection(84, 83),
+ Connection(83, 17),
+ Connection(17, 84),
+ Connection(77, 76),
+ Connection(76, 146),
+ Connection(146, 77),
+ Connection(161, 160),
+ Connection(160, 30),
+ Connection(30, 161),
+ Connection(190, 56),
+ Connection(56, 173),
+ Connection(173, 190),
+ Connection(182, 106),
+ Connection(106, 194),
+ Connection(194, 182),
+ Connection(138, 135),
+ Connection(135, 192),
+ Connection(192, 138),
+ Connection(129, 203),
+ Connection(203, 98),
+ Connection(98, 129),
+ Connection(54, 21),
+ Connection(21, 68),
+ Connection(68, 54),
+ Connection(5, 51),
+ Connection(51, 4),
+ Connection(4, 5),
+ Connection(145, 144),
+ Connection(144, 23),
+ Connection(23, 145),
+ Connection(90, 77),
+ Connection(77, 91),
+ Connection(91, 90),
+ Connection(207, 205),
+ Connection(205, 187),
+ Connection(187, 207),
+ Connection(83, 201),
+ Connection(201, 18),
+ Connection(18, 83),
+ Connection(181, 91),
+ Connection(91, 182),
+ Connection(182, 181),
+ Connection(180, 90),
+ Connection(90, 181),
+ Connection(181, 180),
+ Connection(16, 85),
+ Connection(85, 17),
+ Connection(17, 16),
+ Connection(205, 206),
+ Connection(206, 36),
+ Connection(36, 205),
+ Connection(176, 148),
+ Connection(148, 140),
+ Connection(140, 176),
+ Connection(165, 92),
+ Connection(92, 39),
+ Connection(39, 165),
+ Connection(245, 193),
+ Connection(193, 244),
+ Connection(244, 245),
+ Connection(27, 159),
+ Connection(159, 28),
+ Connection(28, 27),
+ Connection(30, 247),
+ Connection(247, 161),
+ Connection(161, 30),
+ Connection(174, 236),
+ Connection(236, 196),
+ Connection(196, 174),
+ Connection(103, 54),
+ Connection(54, 104),
+ Connection(104, 103),
+ Connection(55, 193),
+ Connection(193, 8),
+ Connection(8, 55),
+ Connection(111, 117),
+ Connection(117, 31),
+ Connection(31, 111),
+ Connection(221, 189),
+ Connection(189, 55),
+ Connection(55, 221),
+ Connection(240, 98),
+ Connection(98, 99),
+ Connection(99, 240),
+ Connection(142, 126),
+ Connection(126, 100),
+ Connection(100, 142),
+ Connection(219, 166),
+ Connection(166, 218),
+ Connection(218, 219),
+ Connection(112, 155),
+ Connection(155, 26),
+ Connection(26, 112),
+ Connection(198, 209),
+ Connection(209, 131),
+ Connection(131, 198),
+ Connection(169, 135),
+ Connection(135, 150),
+ Connection(150, 169),
+ Connection(114, 47),
+ Connection(47, 217),
+ Connection(217, 114),
+ Connection(224, 223),
+ Connection(223, 53),
+ Connection(53, 224),
+ Connection(220, 45),
+ Connection(45, 134),
+ Connection(134, 220),
+ Connection(32, 211),
+ Connection(211, 140),
+ Connection(140, 32),
+ Connection(109, 67),
+ Connection(67, 108),
+ Connection(108, 109),
+ Connection(146, 43),
+ Connection(43, 91),
+ Connection(91, 146),
+ Connection(231, 230),
+ Connection(230, 120),
+ Connection(120, 231),
+ Connection(113, 226),
+ Connection(226, 247),
+ Connection(247, 113),
+ Connection(105, 63),
+ Connection(63, 52),
+ Connection(52, 105),
+ Connection(241, 238),
+ Connection(238, 242),
+ Connection(242, 241),
+ Connection(124, 46),
+ Connection(46, 156),
+ Connection(156, 124),
+ Connection(95, 78),
+ Connection(78, 96),
+ Connection(96, 95),
+ Connection(70, 46),
+ Connection(46, 63),
+ Connection(63, 70),
+ Connection(116, 143),
+ Connection(143, 227),
+ Connection(227, 116),
+ Connection(116, 123),
+ Connection(123, 111),
+ Connection(111, 116),
+ Connection(1, 44),
+ Connection(44, 19),
+ Connection(19, 1),
+ Connection(3, 236),
+ Connection(236, 51),
+ Connection(51, 3),
+ Connection(207, 216),
+ Connection(216, 205),
+ Connection(205, 207),
+ Connection(26, 154),
+ Connection(154, 22),
+ Connection(22, 26),
+ Connection(165, 39),
+ Connection(39, 167),
+ Connection(167, 165),
+ Connection(199, 200),
+ Connection(200, 208),
+ Connection(208, 199),
+ Connection(101, 36),
+ Connection(36, 100),
+ Connection(100, 101),
+ Connection(43, 57),
+ Connection(57, 202),
+ Connection(202, 43),
+ Connection(242, 20),
+ Connection(20, 99),
+ Connection(99, 242),
+ Connection(56, 28),
+ Connection(28, 157),
+ Connection(157, 56),
+ Connection(124, 35),
+ Connection(35, 113),
+ Connection(113, 124),
+ Connection(29, 160),
+ Connection(160, 27),
+ Connection(27, 29),
+ Connection(211, 204),
+ Connection(204, 210),
+ Connection(210, 211),
+ Connection(124, 113),
+ Connection(113, 46),
+ Connection(46, 124),
+ Connection(106, 43),
+ Connection(43, 204),
+ Connection(204, 106),
+ Connection(96, 62),
+ Connection(62, 77),
+ Connection(77, 96),
+ Connection(227, 137),
+ Connection(137, 116),
+ Connection(116, 227),
+ Connection(73, 41),
+ Connection(41, 72),
+ Connection(72, 73),
+ Connection(36, 203),
+ Connection(203, 142),
+ Connection(142, 36),
+ Connection(235, 64),
+ Connection(64, 240),
+ Connection(240, 235),
+ Connection(48, 49),
+ Connection(49, 64),
+ Connection(64, 48),
+ Connection(42, 41),
+ Connection(41, 74),
+ Connection(74, 42),
+ Connection(214, 212),
+ Connection(212, 207),
+ Connection(207, 214),
+ Connection(183, 42),
+ Connection(42, 184),
+ Connection(184, 183),
+ Connection(210, 169),
+ Connection(169, 211),
+ Connection(211, 210),
+ Connection(140, 170),
+ Connection(170, 176),
+ Connection(176, 140),
+ Connection(104, 105),
+ Connection(105, 69),
+ Connection(69, 104),
+ Connection(193, 122),
+ Connection(122, 168),
+ Connection(168, 193),
+ Connection(50, 123),
+ Connection(123, 187),
+ Connection(187, 50),
+ Connection(89, 96),
+ Connection(96, 90),
+ Connection(90, 89),
+ Connection(66, 65),
+ Connection(65, 107),
+ Connection(107, 66),
+ Connection(179, 89),
+ Connection(89, 180),
+ Connection(180, 179),
+ Connection(119, 101),
+ Connection(101, 120),
+ Connection(120, 119),
+ Connection(68, 63),
+ Connection(63, 104),
+ Connection(104, 68),
+ Connection(234, 93),
+ Connection(93, 227),
+ Connection(227, 234),
+ Connection(16, 15),
+ Connection(15, 85),
+ Connection(85, 16),
+ Connection(209, 129),
+ Connection(129, 49),
+ Connection(49, 209),
+ Connection(15, 14),
+ Connection(14, 86),
+ Connection(86, 15),
+ Connection(107, 55),
+ Connection(55, 9),
+ Connection(9, 107),
+ Connection(120, 100),
+ Connection(100, 121),
+ Connection(121, 120),
+ Connection(153, 145),
+ Connection(145, 22),
+ Connection(22, 153),
+ Connection(178, 88),
+ Connection(88, 179),
+ Connection(179, 178),
+ Connection(197, 6),
+ Connection(6, 196),
+ Connection(196, 197),
+ Connection(89, 88),
+ Connection(88, 96),
+ Connection(96, 89),
+ Connection(135, 138),
+ Connection(138, 136),
+ Connection(136, 135),
+ Connection(138, 215),
+ Connection(215, 172),
+ Connection(172, 138),
+ Connection(218, 115),
+ Connection(115, 219),
+ Connection(219, 218),
+ Connection(41, 42),
+ Connection(42, 81),
+ Connection(81, 41),
+ Connection(5, 195),
+ Connection(195, 51),
+ Connection(51, 5),
+ Connection(57, 43),
+ Connection(43, 61),
+ Connection(61, 57),
+ Connection(208, 171),
+ Connection(171, 199),
+ Connection(199, 208),
+ Connection(41, 81),
+ Connection(81, 38),
+ Connection(38, 41),
+ Connection(224, 53),
+ Connection(53, 225),
+ Connection(225, 224),
+ Connection(24, 144),
+ Connection(144, 110),
+ Connection(110, 24),
+ Connection(105, 52),
+ Connection(52, 66),
+ Connection(66, 105),
+ Connection(118, 229),
+ Connection(229, 117),
+ Connection(117, 118),
+ Connection(227, 34),
+ Connection(34, 234),
+ Connection(234, 227),
+ Connection(66, 107),
+ Connection(107, 69),
+ Connection(69, 66),
+ Connection(10, 109),
+ Connection(109, 151),
+ Connection(151, 10),
+ Connection(219, 48),
+ Connection(48, 235),
+ Connection(235, 219),
+ Connection(183, 62),
+ Connection(62, 191),
+ Connection(191, 183),
+ Connection(142, 129),
+ Connection(129, 126),
+ Connection(126, 142),
+ Connection(116, 111),
+ Connection(111, 143),
+ Connection(143, 116),
+ Connection(118, 117),
+ Connection(117, 50),
+ Connection(50, 118),
+ Connection(223, 222),
+ Connection(222, 52),
+ Connection(52, 223),
+ Connection(94, 19),
+ Connection(19, 141),
+ Connection(141, 94),
+ Connection(222, 221),
+ Connection(221, 65),
+ Connection(65, 222),
+ Connection(196, 3),
+ Connection(3, 197),
+ Connection(197, 196),
+ Connection(45, 220),
+ Connection(220, 44),
+ Connection(44, 45),
+ Connection(156, 70),
+ Connection(70, 139),
+ Connection(139, 156),
+ Connection(188, 122),
+ Connection(122, 245),
+ Connection(245, 188),
+ Connection(139, 71),
+ Connection(71, 162),
+ Connection(162, 139),
+ Connection(149, 170),
+ Connection(170, 150),
+ Connection(150, 149),
+ Connection(122, 188),
+ Connection(188, 196),
+ Connection(196, 122),
+ Connection(206, 216),
+ Connection(216, 92),
+ Connection(92, 206),
+ Connection(164, 2),
+ Connection(2, 167),
+ Connection(167, 164),
+ Connection(242, 141),
+ Connection(141, 241),
+ Connection(241, 242),
+ Connection(0, 164),
+ Connection(164, 37),
+ Connection(37, 0),
+ Connection(11, 72),
+ Connection(72, 12),
+ Connection(12, 11),
+ Connection(12, 38),
+ Connection(38, 13),
+ Connection(13, 12),
+ Connection(70, 63),
+ Connection(63, 71),
+ Connection(71, 70),
+ Connection(31, 226),
+ Connection(226, 111),
+ Connection(111, 31),
+ Connection(36, 101),
+ Connection(101, 205),
+ Connection(205, 36),
+ Connection(203, 206),
+ Connection(206, 165),
+ Connection(165, 203),
+ Connection(126, 209),
+ Connection(209, 217),
+ Connection(217, 126),
+ Connection(98, 165),
+ Connection(165, 97),
+ Connection(97, 98),
+ Connection(237, 220),
+ Connection(220, 218),
+ Connection(218, 237),
+ Connection(237, 239),
+ Connection(239, 241),
+ Connection(241, 237),
+ Connection(210, 214),
+ Connection(214, 169),
+ Connection(169, 210),
+ Connection(140, 171),
+ Connection(171, 32),
+ Connection(32, 140),
+ Connection(241, 125),
+ Connection(125, 237),
+ Connection(237, 241),
+ Connection(179, 86),
+ Connection(86, 178),
+ Connection(178, 179),
+ Connection(180, 85),
+ Connection(85, 179),
+ Connection(179, 180),
+ Connection(181, 84),
+ Connection(84, 180),
+ Connection(180, 181),
+ Connection(182, 83),
+ Connection(83, 181),
+ Connection(181, 182),
+ Connection(194, 201),
+ Connection(201, 182),
+ Connection(182, 194),
+ Connection(177, 137),
+ Connection(137, 132),
+ Connection(132, 177),
+ Connection(184, 76),
+ Connection(76, 183),
+ Connection(183, 184),
+ Connection(185, 61),
+ Connection(61, 184),
+ Connection(184, 185),
+ Connection(186, 57),
+ Connection(57, 185),
+ Connection(185, 186),
+ Connection(216, 212),
+ Connection(212, 186),
+ Connection(186, 216),
+ Connection(192, 214),
+ Connection(214, 187),
+ Connection(187, 192),
+ Connection(139, 34),
+ Connection(34, 156),
+ Connection(156, 139),
+ Connection(218, 79),
+ Connection(79, 237),
+ Connection(237, 218),
+ Connection(147, 123),
+ Connection(123, 177),
+ Connection(177, 147),
+ Connection(45, 44),
+ Connection(44, 4),
+ Connection(4, 45),
+ Connection(208, 201),
+ Connection(201, 32),
+ Connection(32, 208),
+ Connection(98, 64),
+ Connection(64, 129),
+ Connection(129, 98),
+ Connection(192, 213),
+ Connection(213, 138),
+ Connection(138, 192),
+ Connection(235, 59),
+ Connection(59, 219),
+ Connection(219, 235),
+ Connection(141, 242),
+ Connection(242, 97),
+ Connection(97, 141),
+ Connection(97, 2),
+ Connection(2, 141),
+ Connection(141, 97),
+ Connection(240, 75),
+ Connection(75, 235),
+ Connection(235, 240),
+ Connection(229, 24),
+ Connection(24, 228),
+ Connection(228, 229),
+ Connection(31, 25),
+ Connection(25, 226),
+ Connection(226, 31),
+ Connection(230, 23),
+ Connection(23, 229),
+ Connection(229, 230),
+ Connection(231, 22),
+ Connection(22, 230),
+ Connection(230, 231),
+ Connection(232, 26),
+ Connection(26, 231),
+ Connection(231, 232),
+ Connection(233, 112),
+ Connection(112, 232),
+ Connection(232, 233),
+ Connection(244, 189),
+ Connection(189, 243),
+ Connection(243, 244),
+ Connection(189, 221),
+ Connection(221, 190),
+ Connection(190, 189),
+ Connection(222, 28),
+ Connection(28, 221),
+ Connection(221, 222),
+ Connection(223, 27),
+ Connection(27, 222),
+ Connection(222, 223),
+ Connection(224, 29),
+ Connection(29, 223),
+ Connection(223, 224),
+ Connection(225, 30),
+ Connection(30, 224),
+ Connection(224, 225),
+ Connection(113, 247),
+ Connection(247, 225),
+ Connection(225, 113),
+ Connection(99, 60),
+ Connection(60, 240),
+ Connection(240, 99),
+ Connection(213, 147),
+ Connection(147, 215),
+ Connection(215, 213),
+ Connection(60, 20),
+ Connection(20, 166),
+ Connection(166, 60),
+ Connection(192, 187),
+ Connection(187, 213),
+ Connection(213, 192),
+ Connection(243, 112),
+ Connection(112, 244),
+ Connection(244, 243),
+ Connection(244, 233),
+ Connection(233, 245),
+ Connection(245, 244),
+ Connection(245, 128),
+ Connection(128, 188),
+ Connection(188, 245),
+ Connection(188, 114),
+ Connection(114, 174),
+ Connection(174, 188),
+ Connection(134, 131),
+ Connection(131, 220),
+ Connection(220, 134),
+ Connection(174, 217),
+ Connection(217, 236),
+ Connection(236, 174),
+ Connection(236, 198),
+ Connection(198, 134),
+ Connection(134, 236),
+ Connection(215, 177),
+ Connection(177, 58),
+ Connection(58, 215),
+ Connection(156, 143),
+ Connection(143, 124),
+ Connection(124, 156),
+ Connection(25, 110),
+ Connection(110, 7),
+ Connection(7, 25),
+ Connection(31, 228),
+ Connection(228, 25),
+ Connection(25, 31),
+ Connection(264, 356),
+ Connection(356, 368),
+ Connection(368, 264),
+ Connection(0, 11),
+ Connection(11, 267),
+ Connection(267, 0),
+ Connection(451, 452),
+ Connection(452, 349),
+ Connection(349, 451),
+ Connection(267, 302),
+ Connection(302, 269),
+ Connection(269, 267),
+ Connection(350, 357),
+ Connection(357, 277),
+ Connection(277, 350),
+ Connection(350, 452),
+ Connection(452, 357),
+ Connection(357, 350),
+ Connection(299, 333),
+ Connection(333, 297),
+ Connection(297, 299),
+ Connection(396, 175),
+ Connection(175, 377),
+ Connection(377, 396),
+ Connection(280, 347),
+ Connection(347, 330),
+ Connection(330, 280),
+ Connection(269, 303),
+ Connection(303, 270),
+ Connection(270, 269),
+ Connection(151, 9),
+ Connection(9, 337),
+ Connection(337, 151),
+ Connection(344, 278),
+ Connection(278, 360),
+ Connection(360, 344),
+ Connection(424, 418),
+ Connection(418, 431),
+ Connection(431, 424),
+ Connection(270, 304),
+ Connection(304, 409),
+ Connection(409, 270),
+ Connection(272, 310),
+ Connection(310, 407),
+ Connection(407, 272),
+ Connection(322, 270),
+ Connection(270, 410),
+ Connection(410, 322),
+ Connection(449, 450),
+ Connection(450, 347),
+ Connection(347, 449),
+ Connection(432, 422),
+ Connection(422, 434),
+ Connection(434, 432),
+ Connection(18, 313),
+ Connection(313, 17),
+ Connection(17, 18),
+ Connection(291, 306),
+ Connection(306, 375),
+ Connection(375, 291),
+ Connection(259, 387),
+ Connection(387, 260),
+ Connection(260, 259),
+ Connection(424, 335),
+ Connection(335, 418),
+ Connection(418, 424),
+ Connection(434, 364),
+ Connection(364, 416),
+ Connection(416, 434),
+ Connection(391, 423),
+ Connection(423, 327),
+ Connection(327, 391),
+ Connection(301, 251),
+ Connection(251, 298),
+ Connection(298, 301),
+ Connection(275, 281),
+ Connection(281, 4),
+ Connection(4, 275),
+ Connection(254, 373),
+ Connection(373, 253),
+ Connection(253, 254),
+ Connection(375, 307),
+ Connection(307, 321),
+ Connection(321, 375),
+ Connection(280, 425),
+ Connection(425, 411),
+ Connection(411, 280),
+ Connection(200, 421),
+ Connection(421, 18),
+ Connection(18, 200),
+ Connection(335, 321),
+ Connection(321, 406),
+ Connection(406, 335),
+ Connection(321, 320),
+ Connection(320, 405),
+ Connection(405, 321),
+ Connection(314, 315),
+ Connection(315, 17),
+ Connection(17, 314),
+ Connection(423, 426),
+ Connection(426, 266),
+ Connection(266, 423),
+ Connection(396, 377),
+ Connection(377, 369),
+ Connection(369, 396),
+ Connection(270, 322),
+ Connection(322, 269),
+ Connection(269, 270),
+ Connection(413, 417),
+ Connection(417, 464),
+ Connection(464, 413),
+ Connection(385, 386),
+ Connection(386, 258),
+ Connection(258, 385),
+ Connection(248, 456),
+ Connection(456, 419),
+ Connection(419, 248),
+ Connection(298, 284),
+ Connection(284, 333),
+ Connection(333, 298),
+ Connection(168, 417),
+ Connection(417, 8),
+ Connection(8, 168),
+ Connection(448, 346),
+ Connection(346, 261),
+ Connection(261, 448),
+ Connection(417, 413),
+ Connection(413, 285),
+ Connection(285, 417),
+ Connection(326, 327),
+ Connection(327, 328),
+ Connection(328, 326),
+ Connection(277, 355),
+ Connection(355, 329),
+ Connection(329, 277),
+ Connection(309, 392),
+ Connection(392, 438),
+ Connection(438, 309),
+ Connection(381, 382),
+ Connection(382, 256),
+ Connection(256, 381),
+ Connection(279, 429),
+ Connection(429, 360),
+ Connection(360, 279),
+ Connection(365, 364),
+ Connection(364, 379),
+ Connection(379, 365),
+ Connection(355, 277),
+ Connection(277, 437),
+ Connection(437, 355),
+ Connection(282, 443),
+ Connection(443, 283),
+ Connection(283, 282),
+ Connection(281, 275),
+ Connection(275, 363),
+ Connection(363, 281),
+ Connection(395, 431),
+ Connection(431, 369),
+ Connection(369, 395),
+ Connection(299, 297),
+ Connection(297, 337),
+ Connection(337, 299),
+ Connection(335, 273),
+ Connection(273, 321),
+ Connection(321, 335),
+ Connection(348, 450),
+ Connection(450, 349),
+ Connection(349, 348),
+ Connection(359, 446),
+ Connection(446, 467),
+ Connection(467, 359),
+ Connection(283, 293),
+ Connection(293, 282),
+ Connection(282, 283),
+ Connection(250, 458),
+ Connection(458, 462),
+ Connection(462, 250),
+ Connection(300, 276),
+ Connection(276, 383),
+ Connection(383, 300),
+ Connection(292, 308),
+ Connection(308, 325),
+ Connection(325, 292),
+ Connection(283, 276),
+ Connection(276, 293),
+ Connection(293, 283),
+ Connection(264, 372),
+ Connection(372, 447),
+ Connection(447, 264),
+ Connection(346, 352),
+ Connection(352, 340),
+ Connection(340, 346),
+ Connection(354, 274),
+ Connection(274, 19),
+ Connection(19, 354),
+ Connection(363, 456),
+ Connection(456, 281),
+ Connection(281, 363),
+ Connection(426, 436),
+ Connection(436, 425),
+ Connection(425, 426),
+ Connection(380, 381),
+ Connection(381, 252),
+ Connection(252, 380),
+ Connection(267, 269),
+ Connection(269, 393),
+ Connection(393, 267),
+ Connection(421, 200),
+ Connection(200, 428),
+ Connection(428, 421),
+ Connection(371, 266),
+ Connection(266, 329),
+ Connection(329, 371),
+ Connection(432, 287),
+ Connection(287, 422),
+ Connection(422, 432),
+ Connection(290, 250),
+ Connection(250, 328),
+ Connection(328, 290),
+ Connection(385, 258),
+ Connection(258, 384),
+ Connection(384, 385),
+ Connection(446, 265),
+ Connection(265, 342),
+ Connection(342, 446),
+ Connection(386, 387),
+ Connection(387, 257),
+ Connection(257, 386),
+ Connection(422, 424),
+ Connection(424, 430),
+ Connection(430, 422),
+ Connection(445, 342),
+ Connection(342, 276),
+ Connection(276, 445),
+ Connection(422, 273),
+ Connection(273, 424),
+ Connection(424, 422),
+ Connection(306, 292),
+ Connection(292, 307),
+ Connection(307, 306),
+ Connection(352, 366),
+ Connection(366, 345),
+ Connection(345, 352),
+ Connection(268, 271),
+ Connection(271, 302),
+ Connection(302, 268),
+ Connection(358, 423),
+ Connection(423, 371),
+ Connection(371, 358),
+ Connection(327, 294),
+ Connection(294, 460),
+ Connection(460, 327),
+ Connection(331, 279),
+ Connection(279, 294),
+ Connection(294, 331),
+ Connection(303, 271),
+ Connection(271, 304),
+ Connection(304, 303),
+ Connection(436, 432),
+ Connection(432, 427),
+ Connection(427, 436),
+ Connection(304, 272),
+ Connection(272, 408),
+ Connection(408, 304),
+ Connection(395, 394),
+ Connection(394, 431),
+ Connection(431, 395),
+ Connection(378, 395),
+ Connection(395, 400),
+ Connection(400, 378),
+ Connection(296, 334),
+ Connection(334, 299),
+ Connection(299, 296),
+ Connection(6, 351),
+ Connection(351, 168),
+ Connection(168, 6),
+ Connection(376, 352),
+ Connection(352, 411),
+ Connection(411, 376),
+ Connection(307, 325),
+ Connection(325, 320),
+ Connection(320, 307),
+ Connection(285, 295),
+ Connection(295, 336),
+ Connection(336, 285),
+ Connection(320, 319),
+ Connection(319, 404),
+ Connection(404, 320),
+ Connection(329, 330),
+ Connection(330, 349),
+ Connection(349, 329),
+ Connection(334, 293),
+ Connection(293, 333),
+ Connection(333, 334),
+ Connection(366, 323),
+ Connection(323, 447),
+ Connection(447, 366),
+ Connection(316, 15),
+ Connection(15, 315),
+ Connection(315, 316),
+ Connection(331, 358),
+ Connection(358, 279),
+ Connection(279, 331),
+ Connection(317, 14),
+ Connection(14, 316),
+ Connection(316, 317),
+ Connection(8, 285),
+ Connection(285, 9),
+ Connection(9, 8),
+ Connection(277, 329),
+ Connection(329, 350),
+ Connection(350, 277),
+ Connection(253, 374),
+ Connection(374, 252),
+ Connection(252, 253),
+ Connection(319, 318),
+ Connection(318, 403),
+ Connection(403, 319),
+ Connection(351, 6),
+ Connection(6, 419),
+ Connection(419, 351),
+ Connection(324, 318),
+ Connection(318, 325),
+ Connection(325, 324),
+ Connection(397, 367),
+ Connection(367, 365),
+ Connection(365, 397),
+ Connection(288, 435),
+ Connection(435, 397),
+ Connection(397, 288),
+ Connection(278, 344),
+ Connection(344, 439),
+ Connection(439, 278),
+ Connection(310, 272),
+ Connection(272, 311),
+ Connection(311, 310),
+ Connection(248, 195),
+ Connection(195, 281),
+ Connection(281, 248),
+ Connection(375, 273),
+ Connection(273, 291),
+ Connection(291, 375),
+ Connection(175, 396),
+ Connection(396, 199),
+ Connection(199, 175),
+ Connection(312, 311),
+ Connection(311, 268),
+ Connection(268, 312),
+ Connection(276, 283),
+ Connection(283, 445),
+ Connection(445, 276),
+ Connection(390, 373),
+ Connection(373, 339),
+ Connection(339, 390),
+ Connection(295, 282),
+ Connection(282, 296),
+ Connection(296, 295),
+ Connection(448, 449),
+ Connection(449, 346),
+ Connection(346, 448),
+ Connection(356, 264),
+ Connection(264, 454),
+ Connection(454, 356),
+ Connection(337, 336),
+ Connection(336, 299),
+ Connection(299, 337),
+ Connection(337, 338),
+ Connection(338, 151),
+ Connection(151, 337),
+ Connection(294, 278),
+ Connection(278, 455),
+ Connection(455, 294),
+ Connection(308, 292),
+ Connection(292, 415),
+ Connection(415, 308),
+ Connection(429, 358),
+ Connection(358, 355),
+ Connection(355, 429),
+ Connection(265, 340),
+ Connection(340, 372),
+ Connection(372, 265),
+ Connection(352, 346),
+ Connection(346, 280),
+ Connection(280, 352),
+ Connection(295, 442),
+ Connection(442, 282),
+ Connection(282, 295),
+ Connection(354, 19),
+ Connection(19, 370),
+ Connection(370, 354),
+ Connection(285, 441),
+ Connection(441, 295),
+ Connection(295, 285),
+ Connection(195, 248),
+ Connection(248, 197),
+ Connection(197, 195),
+ Connection(457, 440),
+ Connection(440, 274),
+ Connection(274, 457),
+ Connection(301, 300),
+ Connection(300, 368),
+ Connection(368, 301),
+ Connection(417, 351),
+ Connection(351, 465),
+ Connection(465, 417),
+ Connection(251, 301),
+ Connection(301, 389),
+ Connection(389, 251),
+ Connection(394, 395),
+ Connection(395, 379),
+ Connection(379, 394),
+ Connection(399, 412),
+ Connection(412, 419),
+ Connection(419, 399),
+ Connection(410, 436),
+ Connection(436, 322),
+ Connection(322, 410),
+ Connection(326, 2),
+ Connection(2, 393),
+ Connection(393, 326),
+ Connection(354, 370),
+ Connection(370, 461),
+ Connection(461, 354),
+ Connection(393, 164),
+ Connection(164, 267),
+ Connection(267, 393),
+ Connection(268, 302),
+ Connection(302, 12),
+ Connection(12, 268),
+ Connection(312, 268),
+ Connection(268, 13),
+ Connection(13, 312),
+ Connection(298, 293),
+ Connection(293, 301),
+ Connection(301, 298),
+ Connection(265, 446),
+ Connection(446, 340),
+ Connection(340, 265),
+ Connection(280, 330),
+ Connection(330, 425),
+ Connection(425, 280),
+ Connection(322, 426),
+ Connection(426, 391),
+ Connection(391, 322),
+ Connection(420, 429),
+ Connection(429, 437),
+ Connection(437, 420),
+ Connection(393, 391),
+ Connection(391, 326),
+ Connection(326, 393),
+ Connection(344, 440),
+ Connection(440, 438),
+ Connection(438, 344),
+ Connection(458, 459),
+ Connection(459, 461),
+ Connection(461, 458),
+ Connection(364, 434),
+ Connection(434, 394),
+ Connection(394, 364),
+ Connection(428, 396),
+ Connection(396, 262),
+ Connection(262, 428),
+ Connection(274, 354),
+ Connection(354, 457),
+ Connection(457, 274),
+ Connection(317, 316),
+ Connection(316, 402),
+ Connection(402, 317),
+ Connection(316, 315),
+ Connection(315, 403),
+ Connection(403, 316),
+ Connection(315, 314),
+ Connection(314, 404),
+ Connection(404, 315),
+ Connection(314, 313),
+ Connection(313, 405),
+ Connection(405, 314),
+ Connection(313, 421),
+ Connection(421, 406),
+ Connection(406, 313),
+ Connection(323, 366),
+ Connection(366, 361),
+ Connection(361, 323),
+ Connection(292, 306),
+ Connection(306, 407),
+ Connection(407, 292),
+ Connection(306, 291),
+ Connection(291, 408),
+ Connection(408, 306),
+ Connection(291, 287),
+ Connection(287, 409),
+ Connection(409, 291),
+ Connection(287, 432),
+ Connection(432, 410),
+ Connection(410, 287),
+ Connection(427, 434),
+ Connection(434, 411),
+ Connection(411, 427),
+ Connection(372, 264),
+ Connection(264, 383),
+ Connection(383, 372),
+ Connection(459, 309),
+ Connection(309, 457),
+ Connection(457, 459),
+ Connection(366, 352),
+ Connection(352, 401),
+ Connection(401, 366),
+ Connection(1, 274),
+ Connection(274, 4),
+ Connection(4, 1),
+ Connection(418, 421),
+ Connection(421, 262),
+ Connection(262, 418),
+ Connection(331, 294),
+ Connection(294, 358),
+ Connection(358, 331),
+ Connection(435, 433),
+ Connection(433, 367),
+ Connection(367, 435),
+ Connection(392, 289),
+ Connection(289, 439),
+ Connection(439, 392),
+ Connection(328, 462),
+ Connection(462, 326),
+ Connection(326, 328),
+ Connection(94, 2),
+ Connection(2, 370),
+ Connection(370, 94),
+ Connection(289, 305),
+ Connection(305, 455),
+ Connection(455, 289),
+ Connection(339, 254),
+ Connection(254, 448),
+ Connection(448, 339),
+ Connection(359, 255),
+ Connection(255, 446),
+ Connection(446, 359),
+ Connection(254, 253),
+ Connection(253, 449),
+ Connection(449, 254),
+ Connection(253, 252),
+ Connection(252, 450),
+ Connection(450, 253),
+ Connection(252, 256),
+ Connection(256, 451),
+ Connection(451, 252),
+ Connection(256, 341),
+ Connection(341, 452),
+ Connection(452, 256),
+ Connection(414, 413),
+ Connection(413, 463),
+ Connection(463, 414),
+ Connection(286, 441),
+ Connection(441, 414),
+ Connection(414, 286),
+ Connection(286, 258),
+ Connection(258, 441),
+ Connection(441, 286),
+ Connection(258, 257),
+ Connection(257, 442),
+ Connection(442, 258),
+ Connection(257, 259),
+ Connection(259, 443),
+ Connection(443, 257),
+ Connection(259, 260),
+ Connection(260, 444),
+ Connection(444, 259),
+ Connection(260, 467),
+ Connection(467, 445),
+ Connection(445, 260),
+ Connection(309, 459),
+ Connection(459, 250),
+ Connection(250, 309),
+ Connection(305, 289),
+ Connection(289, 290),
+ Connection(290, 305),
+ Connection(305, 290),
+ Connection(290, 460),
+ Connection(460, 305),
+ Connection(401, 376),
+ Connection(376, 435),
+ Connection(435, 401),
+ Connection(309, 250),
+ Connection(250, 392),
+ Connection(392, 309),
+ Connection(376, 411),
+ Connection(411, 433),
+ Connection(433, 376),
+ Connection(453, 341),
+ Connection(341, 464),
+ Connection(464, 453),
+ Connection(357, 453),
+ Connection(453, 465),
+ Connection(465, 357),
+ Connection(343, 357),
+ Connection(357, 412),
+ Connection(412, 343),
+ Connection(437, 343),
+ Connection(343, 399),
+ Connection(399, 437),
+ Connection(344, 360),
+ Connection(360, 440),
+ Connection(440, 344),
+ Connection(420, 437),
+ Connection(437, 456),
+ Connection(456, 420),
+ Connection(360, 420),
+ Connection(420, 363),
+ Connection(363, 360),
+ Connection(361, 401),
+ Connection(401, 288),
+ Connection(288, 361),
+ Connection(265, 372),
+ Connection(372, 353),
+ Connection(353, 265),
+ Connection(390, 339),
+ Connection(339, 249),
+ Connection(249, 390),
+ Connection(339, 448),
+ Connection(448, 255),
+ Connection(255, 339),
+ ]
+
+
+@dataclasses.dataclass
+class FaceLandmarkerResult:
+ """The face landmarks detection result from FaceLandmarker, where each vector element represents a single face detected in the image.
+
+ Attributes:
+ face_landmarks: Detected face landmarks in normalized image coordinates.
+ face_blendshapes: Optional face blendshapes results.
+ facial_transformation_matrixes: Optional facial transformation matrix.
+ """
+
+ face_landmarks: List[List[landmark_module.NormalizedLandmark]]
+ face_blendshapes: List[List[category_module.Category]]
+ facial_transformation_matrixes: List[np.ndarray]
+
+
+def _build_landmarker_result(
+ output_packets: Mapping[str, packet_module.Packet]
+) -> FaceLandmarkerResult:
+ """Constructs a `FaceLandmarkerResult` from output packets."""
+ face_landmarks_proto_list = packet_getter.get_proto_list(
+ output_packets[_NORM_LANDMARKS_STREAM_NAME]
+ )
+
+ face_landmarks_results = []
+ for proto in face_landmarks_proto_list:
+ face_landmarks = landmark_pb2.NormalizedLandmarkList()
+ face_landmarks.MergeFrom(proto)
+ face_landmarks_list = []
+ for face_landmark in face_landmarks.landmark:
+ face_landmarks_list.append(
+ landmark_module.NormalizedLandmark.create_from_pb2(face_landmark)
+ )
+ face_landmarks_results.append(face_landmarks_list)
+
+ face_blendshapes_results = []
+ if _BLENDSHAPES_STREAM_NAME in output_packets:
+ face_blendshapes_proto_list = packet_getter.get_proto_list(
+ output_packets[_BLENDSHAPES_STREAM_NAME]
+ )
+ for proto in face_blendshapes_proto_list:
+ face_blendshapes_categories = []
+ face_blendshapes_classifications = classification_pb2.ClassificationList()
+ face_blendshapes_classifications.MergeFrom(proto)
+ for face_blendshapes in face_blendshapes_classifications.classification:
+ face_blendshapes_categories.append(
+ category_module.Category(
+ index=face_blendshapes.index,
+ score=face_blendshapes.score,
+ display_name=face_blendshapes.display_name,
+ category_name=face_blendshapes.label,
+ )
+ )
+ face_blendshapes_results.append(face_blendshapes_categories)
+
+ facial_transformation_matrixes_results = []
+ if _FACE_GEOMETRY_STREAM_NAME in output_packets:
+ facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
+ output_packets[_FACE_GEOMETRY_STREAM_NAME]
+ )
+ for proto in facial_transformation_matrixes_proto_list:
+ if hasattr(proto, 'pose_transform_matrix'):
+ matrix_data = matrix_data_pb2.MatrixData()
+ matrix_data.MergeFrom(proto.pose_transform_matrix)
+ matrix = np.array(matrix_data.packed_data)
+ matrix = matrix.reshape((matrix_data.rows, matrix_data.cols))
+ matrix = (
+ matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T
+ )
+ facial_transformation_matrixes_results.append(matrix)
+
+ return FaceLandmarkerResult(
+ face_landmarks_results,
+ face_blendshapes_results,
+ facial_transformation_matrixes_results,
+ )
+
+def _build_landmarker_result2(
+ output_packets: Mapping[str, packet_module.Packet]
+) -> FaceLandmarkerResult:
+ """Constructs a `FaceLandmarkerResult` from output packets."""
+ face_landmarks_proto_list = packet_getter.get_proto_list(
+ output_packets[_NORM_LANDMARKS_STREAM_NAME]
+ )
+
+ face_landmarks_results = []
+ for proto in face_landmarks_proto_list:
+ face_landmarks = landmark_pb2.NormalizedLandmarkList()
+ face_landmarks.MergeFrom(proto)
+ face_landmarks_list = []
+ for face_landmark in face_landmarks.landmark:
+ face_landmarks_list.append(
+ landmark_module.NormalizedLandmark.create_from_pb2(face_landmark)
+ )
+ face_landmarks_results.append(face_landmarks_list)
+
+ face_blendshapes_results = []
+ if _BLENDSHAPES_STREAM_NAME in output_packets:
+ face_blendshapes_proto_list = packet_getter.get_proto_list(
+ output_packets[_BLENDSHAPES_STREAM_NAME]
+ )
+ for proto in face_blendshapes_proto_list:
+ face_blendshapes_categories = []
+ face_blendshapes_classifications = classification_pb2.ClassificationList()
+ face_blendshapes_classifications.MergeFrom(proto)
+ for face_blendshapes in face_blendshapes_classifications.classification:
+ face_blendshapes_categories.append(
+ category_module.Category(
+ index=face_blendshapes.index,
+ score=face_blendshapes.score,
+ display_name=face_blendshapes.display_name,
+ category_name=face_blendshapes.label,
+ )
+ )
+ face_blendshapes_results.append(face_blendshapes_categories)
+
+ facial_transformation_matrixes_results = []
+ if _FACE_GEOMETRY_STREAM_NAME in output_packets:
+ facial_transformation_matrixes_proto_list = packet_getter.get_proto_list(
+ output_packets[_FACE_GEOMETRY_STREAM_NAME]
+ )
+ for proto in facial_transformation_matrixes_proto_list:
+ if hasattr(proto, 'pose_transform_matrix'):
+ matrix_data = matrix_data_pb2.MatrixData()
+ matrix_data.MergeFrom(proto.pose_transform_matrix)
+ matrix = np.array(matrix_data.packed_data)
+ matrix = matrix.reshape((matrix_data.rows, matrix_data.cols))
+ matrix = (
+ matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T
+ )
+ facial_transformation_matrixes_results.append(matrix)
+
+ return FaceLandmarkerResult(
+ face_landmarks_results,
+ face_blendshapes_results,
+ facial_transformation_matrixes_results,
+ ), facial_transformation_matrixes_proto_list[0].mesh
+
+@dataclasses.dataclass
+class FaceLandmarkerOptions:
+ """Options for the face landmarker task.
+
+ Attributes:
+ base_options: Base options for the face landmarker task.
+ running_mode: The running mode of the task. Default to the image mode.
+ FaceLandmarker has three running modes: 1) The image mode for detecting
+ face landmarks on single image inputs. 2) The video mode for detecting
+ face landmarks on the decoded frames of a video. 3) The live stream mode
+ for detecting face landmarks on the live stream of input data, such as
+ from camera. In this mode, the "result_callback" below must be specified
+ to receive the detection results asynchronously.
+ num_faces: The maximum number of faces that can be detected by the
+ FaceLandmarker.
+ min_face_detection_confidence: The minimum confidence score for the face
+ detection to be considered successful.
+ min_face_presence_confidence: The minimum confidence score of face presence
+ score in the face landmark detection.
+ min_tracking_confidence: The minimum confidence score for the face tracking
+ to be considered successful.
+ output_face_blendshapes: Whether FaceLandmarker outputs face blendshapes
+ classification. Face blendshapes are used for rendering the 3D face model.
+ output_facial_transformation_matrixes: Whether FaceLandmarker outputs facial
+ transformation_matrix. Facial transformation matrix is used to transform
+ the face landmarks in canonical face to the detected face, so that users
+ can apply face effects on the detected landmarks.
+ result_callback: The user-defined result callback for processing live stream
+ data. The result callback should only be specified when the running mode
+ is set to the live stream mode.
+ """
+
+ base_options: _BaseOptions
+ running_mode: _RunningMode = _RunningMode.IMAGE
+ num_faces: int = 1
+ min_face_detection_confidence: float = 0.5
+ min_face_presence_confidence: float = 0.5
+ min_tracking_confidence: float = 0.5
+ output_face_blendshapes: bool = False
+ output_facial_transformation_matrixes: bool = False
+ result_callback: Optional[
+ Callable[[FaceLandmarkerResult, image_module.Image, int], None]
+ ] = None
+
+ @doc_controls.do_not_generate_docs
+ def to_pb2(self) -> _FaceLandmarkerGraphOptionsProto:
+ """Generates an FaceLandmarkerGraphOptions protobuf object."""
+ base_options_proto = self.base_options.to_pb2()
+ base_options_proto.use_stream_mode = (
+ False if self.running_mode == _RunningMode.IMAGE else True
+ )
+
+ # Initialize the face landmarker options from base options.
+ face_landmarker_options_proto = _FaceLandmarkerGraphOptionsProto(
+ base_options=base_options_proto
+ )
+
+ # Configure face detector options.
+ face_landmarker_options_proto.face_detector_graph_options.num_faces = (
+ self.num_faces
+ )
+ face_landmarker_options_proto.face_detector_graph_options.min_detection_confidence = (
+ self.min_face_detection_confidence
+ )
+
+ # Configure face landmark detector options.
+ face_landmarker_options_proto.min_tracking_confidence = (
+ self.min_tracking_confidence
+ )
+ face_landmarker_options_proto.face_landmarks_detector_graph_options.min_detection_confidence = (
+ self.min_face_detection_confidence
+ )
+ return face_landmarker_options_proto
+
+
+class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi):
+ """Class that performs face landmarks detection on images."""
+
+ @classmethod
+ def create_from_model_path(cls, model_path: str) -> 'FaceLandmarker':
+ """Creates an `FaceLandmarker` object from a TensorFlow Lite model and the default `FaceLandmarkerOptions`.
+
+ Note that the created `FaceLandmarker` instance is in image mode, for
+ detecting face landmarks on single image inputs.
+
+ Args:
+ model_path: Path to the model.
+
+ Returns:
+ `FaceLandmarker` object that's created from the model file and the
+ default `FaceLandmarkerOptions`.
+
+ Raises:
+ ValueError: If failed to create `FaceLandmarker` object from the
+ provided file such as invalid file path.
+ RuntimeError: If other types of error occurred.
+ """
+ base_options = _BaseOptions(model_asset_path=model_path)
+ options = FaceLandmarkerOptions(
+ base_options=base_options, running_mode=_RunningMode.IMAGE
+ )
+ return cls.create_from_options(options)
+
+ @classmethod
+ def create_from_options(
+ cls, options: FaceLandmarkerOptions
+ ) -> 'FaceLandmarker':
+ """Creates the `FaceLandmarker` object from face landmarker options.
+
+ Args:
+ options: Options for the face landmarker task.
+
+ Returns:
+ `FaceLandmarker` object that's created from `options`.
+
+ Raises:
+ ValueError: If failed to create `FaceLandmarker` object from
+ `FaceLandmarkerOptions` such as missing the model.
+ RuntimeError: If other types of error occurred.
+ """
+
+ def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
+ if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
+ return
+
+ image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
+ if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
+ return
+
+ if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
+ empty_packet = output_packets[_NORM_LANDMARKS_STREAM_NAME]
+ options.result_callback(
+ FaceLandmarkerResult([], [], []),
+ image,
+ empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
+ )
+ return
+
+ face_landmarks_result = _build_landmarker_result(output_packets)
+ timestamp = output_packets[_NORM_LANDMARKS_STREAM_NAME].timestamp
+ options.result_callback(
+ face_landmarks_result,
+ image,
+ timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
+ )
+
+ output_streams = [
+ ':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]),
+ ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
+ ]
+
+ if options.output_face_blendshapes:
+ output_streams.append(
+ ':'.join([_BLENDSHAPES_TAG, _BLENDSHAPES_STREAM_NAME])
+ )
+ if options.output_facial_transformation_matrixes:
+ output_streams.append(
+ ':'.join([_FACE_GEOMETRY_TAG, _FACE_GEOMETRY_STREAM_NAME])
+ )
+
+ task_info = _TaskInfo(
+ task_graph=_TASK_GRAPH_NAME,
+ input_streams=[
+ ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
+ ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
+ ],
+ output_streams=output_streams,
+ task_options=options,
+ )
+ return cls(
+ task_info.generate_graph_config(
+ enable_flow_limiting=options.running_mode
+ == _RunningMode.LIVE_STREAM
+ ),
+ options.running_mode,
+ packets_callback if options.result_callback else None,
+ )
+
+ def detect(
+ self,
+ image: image_module.Image,
+ image_processing_options: Optional[_ImageProcessingOptions] = None,
+ ) -> FaceLandmarkerResult:
+ """Performs face landmarks detection on the given image.
+
+ Only use this method when the FaceLandmarker is created with the image
+ running mode.
+
+ The image can be of any size with format RGB or RGBA.
+ TODO: Describes how the input image will be preprocessed after the yuv
+ support is implemented.
+
+ Args:
+ image: MediaPipe Image.
+ image_processing_options: Options for image processing.
+
+ Returns:
+ The face landmarks detection results.
+
+ Raises:
+ ValueError: If any of the input arguments is invalid.
+ RuntimeError: If face landmarker detection failed to run.
+ """
+
+ normalized_rect = self.convert_to_normalized_rect(
+ image_processing_options, image, roi_allowed=False
+ )
+ output_packets = self._process_image_data({
+ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
+ _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
+ normalized_rect.to_pb2()
+ ),
+ })
+
+ if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
+ return FaceLandmarkerResult([], [], [])
+
+ return _build_landmarker_result2(output_packets)
+
+ def detect_for_video(
+ self,
+ image: image_module.Image,
+ timestamp_ms: int,
+ image_processing_options: Optional[_ImageProcessingOptions] = None,
+ ):
+ """Performs face landmarks detection on the provided video frame.
+
+ Only use this method when the FaceLandmarker is created with the video
+ running mode.
+
+ Only use this method when the FaceLandmarker is created with the video
+ running mode. It's required to provide the video frame's timestamp (in
+ milliseconds) along with the video frame. The input timestamps should be
+ monotonically increasing for adjacent calls of this method.
+
+ Args:
+ image: MediaPipe Image.
+ timestamp_ms: The timestamp of the input video frame in milliseconds.
+ image_processing_options: Options for image processing.
+
+ Returns:
+ The face landmarks detection results.
+
+ Raises:
+ ValueError: If any of the input arguments is invalid.
+ RuntimeError: If face landmarker detection failed to run.
+ """
+ normalized_rect = self.convert_to_normalized_rect(
+ image_processing_options, image, roi_allowed=False
+ )
+ output_packets = self._process_video_data({
+ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
+ timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
+ ),
+ _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
+ normalized_rect.to_pb2()
+ ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
+ })
+
+ if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
+ return FaceLandmarkerResult([], [], [])
+
+ return _build_landmarker_result2(output_packets)
+
+ def detect_async(
+ self,
+ image: image_module.Image,
+ timestamp_ms: int,
+ image_processing_options: Optional[_ImageProcessingOptions] = None,
+ ) -> None:
+ """Sends live image data to perform face landmarks detection.
+
+ The results will be available via the "result_callback" provided in the
+ FaceLandmarkerOptions. Only use this method when the FaceLandmarker is
+ created with the live stream running mode.
+
+ Only use this method when the FaceLandmarker is created with the live
+ stream running mode. The input timestamps should be monotonically increasing
+ for adjacent calls of this method. This method will return immediately after
+ the input image is accepted. The results will be available via the
+ `result_callback` provided in the `FaceLandmarkerOptions`. The
+ `detect_async` method is designed to process live stream data such as
+ camera input. To lower the overall latency, face landmarker may drop the
+ input images if needed. In other words, it's not guaranteed to have output
+ per input image.
+
+ The `result_callback` provides:
+ - The face landmarks detection results.
+ - The input image that the face landmarker runs on.
+ - The input timestamp in milliseconds.
+
+ Args:
+ image: MediaPipe Image.
+ timestamp_ms: The timestamp of the input image in milliseconds.
+ image_processing_options: Options for image processing.
+
+ Raises:
+ ValueError: If the current input timestamp is smaller than what the
+ face landmarker has already processed.
+ """
+ normalized_rect = self.convert_to_normalized_rect(
+ image_processing_options, image, roi_allowed=False
+ )
+ self._send_live_stream_data({
+ _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
+ timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
+ ),
+ _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
+ normalized_rect.to_pb2()
+ ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
+ })
\ No newline at end of file
diff --git a/aniportrait/src/utils/frame_interpolation.py b/aniportrait/src/utils/frame_interpolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ae04817ffef8aaf8a980cfd27c728dc496eeae4
--- /dev/null
+++ b/aniportrait/src/utils/frame_interpolation.py
@@ -0,0 +1,69 @@
+# Adapted from https://github.com/dajes/frame-interpolation-pytorch
+import os
+import cv2
+import numpy as np
+import torch
+import bisect
+import shutil
+import pdb
+from tqdm import tqdm
+
+def init_frame_interpolation_model():
+ print("Initializing frame interpolation model")
+ checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt")
+
+ model = torch.jit.load(checkpoint_name, map_location='cpu')
+ model.eval()
+ model = model.half()
+ model = model.to(device="cuda")
+ return model
+
+
+def batch_images_interpolation_tool(input_tensor, model, inter_frames=1):
+
+ video_tensor = []
+ frame_num = input_tensor.shape[2] # bs, channel, frame, height, width
+
+ for idx in tqdm(range(frame_num-1)):
+ image1 = input_tensor[:,:,idx]
+ image2 = input_tensor[:,:,idx+1]
+
+ results = [image1, image2]
+
+ inter_frames = int(inter_frames)
+ idxes = [0, inter_frames + 1]
+ remains = list(range(1, inter_frames + 1))
+
+ splits = torch.linspace(0, 1, inter_frames + 2)
+
+ for _ in range(len(remains)):
+ starts = splits[idxes[:-1]]
+ ends = splits[idxes[1:]]
+ distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
+ matrix = torch.argmin(distances).item()
+ start_i, step = np.unravel_index(matrix, distances.shape)
+ end_i = start_i + 1
+
+ x0 = results[start_i]
+ x1 = results[end_i]
+
+ x0 = x0.half()
+ x1 = x1.half()
+ x0 = x0.cuda()
+ x1 = x1.cuda()
+
+ dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
+
+ with torch.no_grad():
+ prediction = model(x0, x1, dt)
+ insert_position = bisect.bisect_left(idxes, remains[step])
+ idxes.insert(insert_position, remains[step])
+ results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
+ del remains[step]
+
+ for sub_idx in range(len(results)-1):
+ video_tensor.append(results[sub_idx].unsqueeze(2))
+
+ video_tensor.append(input_tensor[:,:,-1].unsqueeze(2))
+ video_tensor = torch.cat(video_tensor, dim=2)
+ return video_tensor
\ No newline at end of file
diff --git a/aniportrait/src/utils/mp_models/blaze_face_short_range.tflite b/aniportrait/src/utils/mp_models/blaze_face_short_range.tflite
new file mode 100644
index 0000000000000000000000000000000000000000..2645898ee18d8bf53746df830303779c9deabc7d
--- /dev/null
+++ b/aniportrait/src/utils/mp_models/blaze_face_short_range.tflite
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4578f35940bf5a1a655214a1cce5cab13eba73c1297cd78e1a04c2380b0152f
+size 229746
diff --git a/aniportrait/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task b/aniportrait/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task
new file mode 100644
index 0000000000000000000000000000000000000000..fedb14de6d2b6708a56c04ae259783e23404c1aa
--- /dev/null
+++ b/aniportrait/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
+size 3758596
diff --git a/aniportrait/src/utils/mp_models/pose_landmarker_heavy.task b/aniportrait/src/utils/mp_models/pose_landmarker_heavy.task
new file mode 100644
index 0000000000000000000000000000000000000000..5f2c1e254fe2d104606a9031b20b266863d014a6
--- /dev/null
+++ b/aniportrait/src/utils/mp_models/pose_landmarker_heavy.task
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64437af838a65d18e5ba7a0d39b465540069bc8aae8308de3e318aad31fcbc7b
+size 30664242
diff --git a/aniportrait/src/utils/mp_utils.py b/aniportrait/src/utils/mp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb4385128d1d188fc9023ee6c3df51fb8be39e2f
--- /dev/null
+++ b/aniportrait/src/utils/mp_utils.py
@@ -0,0 +1,95 @@
+import os
+import numpy as np
+import cv2
+import time
+from tqdm import tqdm
+import multiprocessing
+import glob
+
+import mediapipe as mp
+from mediapipe import solutions
+from mediapipe.framework.formats import landmark_pb2
+from mediapipe.tasks import python
+from mediapipe.tasks.python import vision
+from . import face_landmark
+
+CUR_DIR = os.path.dirname(__file__)
+
+
+class LMKExtractor():
+ def __init__(self, FPS=25):
+ # Create an FaceLandmarker object.
+ self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE
+ base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task'))
+ base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU
+ options = vision.FaceLandmarkerOptions(base_options=base_options,
+ running_mode=self.mode,
+ output_face_blendshapes=True,
+ output_facial_transformation_matrixes=True,
+ num_faces=1)
+ self.detector = face_landmark.FaceLandmarker.create_from_options(options)
+ self.last_ts = 0
+ self.frame_ms = int(1000 / FPS)
+
+ det_base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/blaze_face_short_range.tflite'))
+ det_options = vision.FaceDetectorOptions(base_options=det_base_options)
+ self.det_detector = vision.FaceDetector.create_from_options(det_options)
+
+
+ def __call__(self, img):
+ frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
+ t0 = time.time()
+ if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO:
+ det_result = self.det_detector.detect(image)
+ if len(det_result.detections) != 1:
+ return None
+ self.last_ts += self.frame_ms
+ try:
+ detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts)
+ except:
+ return None
+ elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE:
+ # det_result = self.det_detector.detect(image)
+
+ # if len(det_result.detections) != 1:
+ # return None
+ try:
+ detection_result, mesh3d = self.detector.detect(image)
+ except:
+ return None
+
+
+ bs_list = detection_result.face_blendshapes
+ if len(bs_list) == 1:
+ bs = bs_list[0]
+ bs_values = []
+ for index in range(len(bs)):
+ bs_values.append(bs[index].score)
+ bs_values = bs_values[1:] # remove neutral
+ trans_mat = detection_result.facial_transformation_matrixes[0]
+ face_landmarks_list = detection_result.face_landmarks
+ face_landmarks = face_landmarks_list[0]
+ lmks = []
+ for index in range(len(face_landmarks)):
+ x = face_landmarks[index].x
+ y = face_landmarks[index].y
+ z = face_landmarks[index].z
+ lmks.append([x, y, z])
+ lmks = np.array(lmks)
+
+ lmks3d = np.array(mesh3d.vertex_buffer)
+ lmks3d = lmks3d.reshape(-1, 5)[:, :3]
+ mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1
+
+ return {
+ "lmks": lmks,
+ 'lmks3d': lmks3d,
+ "trans_mat": trans_mat,
+ 'faces': mp_tris,
+ "bs": bs_values
+ }
+ else:
+ # print('multiple faces in the image: {}'.format(img_path))
+ return None
+
\ No newline at end of file
diff --git a/aniportrait/src/utils/pose_util.py b/aniportrait/src/utils/pose_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a09d07f20d404fbdfb6d444e2896df641ccc364c
--- /dev/null
+++ b/aniportrait/src/utils/pose_util.py
@@ -0,0 +1,89 @@
+import math
+
+import numpy as np
+from scipy.spatial.transform import Rotation as R
+
+
+def create_perspective_matrix(aspect_ratio):
+ kDegreesToRadians = np.pi / 180.
+ near = 1
+ far = 10000
+ perspective_matrix = np.zeros(16, dtype=np.float32)
+
+ # Standard perspective projection matrix calculations.
+ f = 1.0 / np.tan(kDegreesToRadians * 63 / 2.)
+
+ denom = 1.0 / (near - far)
+ perspective_matrix[0] = f / aspect_ratio
+ perspective_matrix[5] = f
+ perspective_matrix[10] = (near + far) * denom
+ perspective_matrix[11] = -1.
+ perspective_matrix[14] = 1. * far * near * denom
+
+ # If the environment's origin point location is in the top left corner,
+ # then skip additional flip along Y-axis is required to render correctly.
+
+ perspective_matrix[5] *= -1.
+ return perspective_matrix
+
+
+def project_points(points_3d, transformation_matrix, pose_vectors, image_shape):
+ P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
+ L, N, _ = points_3d.shape
+ projected_points = np.zeros((L, N, 2))
+ for i in range(L):
+ points_3d_frame = points_3d[i]
+ ones = np.ones((points_3d_frame.shape[0], 1))
+ points_3d_homogeneous = np.hstack([points_3d_frame, ones])
+ transformed_points = points_3d_homogeneous @ (transformation_matrix @ euler_and_translation_to_matrix(pose_vectors[i][:3], pose_vectors[i][3:])).T @ P
+ projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1
+ projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1]
+ projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0]
+ projected_points[i] = projected_points_frame
+ return projected_points
+
+
+def project_points_with_trans(points_3d, transformation_matrix, image_shape):
+ P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
+ L, N, _ = points_3d.shape
+ projected_points = np.zeros((L, N, 2))
+ for i in range(L):
+ points_3d_frame = points_3d[i]
+ ones = np.ones((points_3d_frame.shape[0], 1))
+ points_3d_homogeneous = np.hstack([points_3d_frame, ones])
+ transformed_points = points_3d_homogeneous @ transformation_matrix[i].T @ P
+ projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1
+ projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1]
+ projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0]
+ projected_points[i] = projected_points_frame
+ return projected_points
+
+
+def euler_and_translation_to_matrix(euler_angles, translation_vector):
+ rotation = R.from_euler('xyz', euler_angles, degrees=True)
+ rotation_matrix = rotation.as_matrix()
+
+ matrix = np.eye(4)
+ matrix[:3, :3] = rotation_matrix
+ matrix[:3, 3] = translation_vector
+
+ return matrix
+
+
+def matrix_to_euler_and_translation(matrix):
+ rotation_matrix = matrix[:3, :3]
+ translation_vector = matrix[:3, 3]
+ rotation = R.from_matrix(rotation_matrix)
+ euler_angles = rotation.as_euler('xyz', degrees=True)
+ return euler_angles, translation_vector
+
+
+def smooth_pose_seq(pose_seq, window_size=5):
+ smoothed_pose_seq = np.zeros_like(pose_seq)
+
+ for i in range(len(pose_seq)):
+ start = max(0, i - window_size // 2)
+ end = min(len(pose_seq), i + window_size // 2 + 1)
+ smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0)
+
+ return smoothed_pose_seq
\ No newline at end of file
diff --git a/aniportrait/src/utils/util.py b/aniportrait/src/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e473082ae134df241868fdb4ef1ffc61129cf10
--- /dev/null
+++ b/aniportrait/src/utils/util.py
@@ -0,0 +1,181 @@
+import importlib
+import os
+import os.path as osp
+import shutil
+import sys
+import cv2
+from pathlib import Path
+
+import av
+import numpy as np
+import torch
+import torchvision
+from einops import rearrange
+from PIL import Image
+
+
+def seed_everything(seed):
+ import random
+
+ import numpy as np
+
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed % (2**32))
+ random.seed(seed)
+
+
+def import_filename(filename):
+ spec = importlib.util.spec_from_file_location("mymodule", filename)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[spec.name] = module
+ spec.loader.exec_module(module)
+ return module
+
+
+def delete_additional_ckpt(base_path, num_keep):
+ dirs = []
+ for d in os.listdir(base_path):
+ if d.startswith("checkpoint-"):
+ dirs.append(d)
+ num_tot = len(dirs)
+ if num_tot <= num_keep:
+ return
+ # ensure ckpt is sorted and delete the ealier!
+ del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
+ for d in del_dirs:
+ path_to_dir = osp.join(base_path, d)
+ if osp.exists(path_to_dir):
+ shutil.rmtree(path_to_dir)
+
+
+def save_videos_from_pil(pil_images, path, fps=8):
+ import av
+
+ save_fmt = Path(path).suffix
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ width, height = pil_images[0].size
+
+ if save_fmt == ".mp4":
+ codec = "libx264"
+ container = av.open(path, "w")
+ stream = container.add_stream(codec, rate=fps)
+
+ stream.width = width
+ stream.height = height
+
+ for pil_image in pil_images:
+ # pil_image = Image.fromarray(image_arr).convert("RGB")
+ av_frame = av.VideoFrame.from_image(pil_image)
+ container.mux(stream.encode(av_frame))
+ container.mux(stream.encode())
+ container.close()
+
+ elif save_fmt == ".gif":
+ pil_images[0].save(
+ fp=path,
+ format="GIF",
+ append_images=pil_images[1:],
+ save_all=True,
+ duration=(1 / fps * 1000),
+ loop=0,
+ )
+ else:
+ raise ValueError("Unsupported file type. Use .mp4 or .gif.")
+
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ height, width = videos.shape[-2:]
+ outputs = []
+
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = (x * 255).numpy().astype(np.uint8)
+ x = Image.fromarray(x)
+
+ outputs.append(x)
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+
+ save_videos_from_pil(outputs, path, fps)
+
+
+def read_frames(video_path):
+ container = av.open(video_path)
+
+ video_stream = next(s for s in container.streams if s.type == "video")
+ frames = []
+ for packet in container.demux(video_stream):
+ for frame in packet.decode():
+ image = Image.frombytes(
+ "RGB",
+ (frame.width, frame.height),
+ frame.to_rgb().to_ndarray(),
+ )
+ frames.append(image)
+
+ return frames
+
+
+def get_fps(video_path):
+ container = av.open(video_path)
+ video_stream = next(s for s in container.streams if s.type == "video")
+ fps = video_stream.average_rate
+ container.close()
+ return fps
+
+def crop_face(img, lmk_extractor, expand=1.5):
+ result = lmk_extractor(img) # cv2 BGR
+
+ if result is None:
+ return None
+
+ H, W, _ = img.shape
+ lmks = result['lmks']
+ lmks[:, 0] *= W
+ lmks[:, 1] *= H
+
+ x_min = np.min(lmks[:, 0])
+ x_max = np.max(lmks[:, 0])
+ y_min = np.min(lmks[:, 1])
+ y_max = np.max(lmks[:, 1])
+
+ width = x_max - x_min
+ height = y_max - y_min
+
+ if width*height >= W*H*0.15:
+ if W == H:
+ return img
+ size = min(H, W)
+ offset = int((max(H, W) - size)/2)
+ if size == H:
+ return img[:, offset:-offset]
+ else:
+ return img[offset:-offset, :]
+ else:
+ center_x = x_min + width / 2
+ center_y = y_min + height / 2
+
+ width *= expand
+ height *= expand
+
+ size = max(width, height)
+
+ x_min = int(center_x - size / 2)
+ x_max = int(center_x + size / 2)
+ y_min = int(center_y - size / 2)
+ y_max = int(center_y + size / 2)
+
+ top = max(0, -y_min)
+ bottom = max(0, y_max - img.shape[0])
+ left = max(0, -x_min)
+ right = max(0, x_max - img.shape[1])
+ img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
+
+ cropped_img = img[y_min + top:y_max + top, x_min + left:x_max + left]
+
+ return cropped_img
\ No newline at end of file
diff --git a/ckpt_tree.md b/ckpt_tree.md
new file mode 100644
index 0000000000000000000000000000000000000000..5018c23169e4fc414546ee8b439fa8696a5b61fc
--- /dev/null
+++ b/ckpt_tree.md
@@ -0,0 +1,108 @@
+
+```
+|-- ckpts
+| |-- aniportrait
+| | `-- motion_module.pth
+| | `-- audio2mesh.pt
+| | `-- film_net_fp16.pt
+| | |-- sd-vae-ft-mse
+| | | `-- diffusion_pytorch_model.safetensors
+| | | `-- config.json
+| | | `-- diffusion_pytorch_model.bin
+| | `-- denoising_unet.pth
+| | `-- audio2pose.pt
+| | `-- pose_guider.pth
+| | |-- sd-image-variations-diffusers
+| | | `-- v1-montage.jpg
+| | | |-- scheduler
+| | | | `-- scheduler_config.json
+| | | `-- README.md
+| | | `-- model_index.json
+| | | |-- unet
+| | | | `-- config.json
+| | | | `-- diffusion_pytorch_model.bin
+| | | |-- feature_extractor
+| | | | `-- preprocessor_config.json
+| | | `-- v2-montage.jpg
+| | | |-- vae
+| | | | `-- config.json
+| | | | `-- diffusion_pytorch_model.bin
+| | | `-- alias-montage.jpg
+| | | `-- inputs.jpg
+| | | |-- safety_checker
+| | | | `-- pytorch_model.bin
+| | | | `-- config.json
+| | | `-- earring.jpg
+| | | `-- default-montage.jpg
+| | |-- image_encoder
+| | | `-- pytorch_model.bin
+| | | `-- config.json
+| | |-- stable-diffusion-v1-5
+| | | `-- model_index.json
+| | | `-- v1-inference.yaml
+| | | |-- unet
+| | | | `-- config.json
+| | | | `-- diffusion_pytorch_model.bin
+| | | |-- feature_extractor
+| | | | `-- preprocessor_config.json
+| | `-- reference_unet.pth
+| | |-- wav2vec2-base-960h
+| | | `-- pytorch_model.bin
+| | | `-- README.md
+| | | `-- vocab.json
+| | | `-- config.json
+| | | `-- tf_model.h5
+| | | `-- tokenizer_config.json
+| | | `-- model.safetensors
+| | | `-- special_tokens_map.json
+| | | `-- preprocessor_config.json
+| | | `-- feature_extractor_config.json
+| |-- mofa
+| | |-- traj_controlnet
+| | | `-- diffusion_pytorch_model.safetensors
+| | | `-- config.json
+| | |-- stable-video-diffusion-img2vid-xt-1-1
+| | | |-- scheduler
+| | | | `-- scheduler_config.json
+| | | `-- README.md
+| | | `-- model_index.json
+| | | |-- unet
+| | | | `-- diffusion_pytorch_model.fp16.safetensors
+| | | | `-- config.json
+| | | |-- feature_extractor
+| | | | `-- preprocessor_config.json
+| | | |-- vae
+| | | | `-- diffusion_pytorch_model.fp16.safetensors
+| | | | `-- config.json
+| | | `-- LICENSE
+| | | `-- svd11.webp
+| | | |-- image_encoder
+| | | | `-- config.json
+| | | | `-- model.fp16.safetensors
+| | |-- ldmk_controlnet
+| | | `-- diffusion_pytorch_model.safetensors
+| | | `-- config.json
+| |-- sad_talker
+| | `-- SadTalker_V0.0.2_256.safetensors
+| | |-- hub
+| | `-- mapping_00229-model.pth.tar
+| | |-- BFM_Fitting
+| | | `-- select_vertex_id.mat
+| | | `-- facemodel_info.mat
+| | | `-- BFM_exp_idx.mat
+| | | `-- BFM_model_front.mat
+| | | `-- 01_MorphableModel.mat
+| | | `-- similarity_Lm3D_all.mat
+| | | `-- BFM_front_idx.mat
+| | | `-- Exp_Pca.bin
+| | | `-- std_exp.txt
+| | `-- SadTalker_V0.0.2_512.safetensors
+| | `-- similarity_Lm3D_all.mat
+| | `-- epoch_00190_iteration_000400000_checkpoint.pt
+| | `-- mapping_00109-model.pth.tar
+| |-- gfpgan
+| | `-- alignment_WFLW_4HG.pth
+| | `-- parsing_parsenet.pth
+| | `-- detection_Resnet50_Final.pth
+
+```
\ No newline at end of file
diff --git a/ckpts/.DS_Store b/ckpts/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..a20c84b82d1c5a4a302648a3806452d21574ec83
Binary files /dev/null and b/ckpts/.DS_Store differ
diff --git a/ckpts/aniportrait/.DS_Store b/ckpts/aniportrait/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..cbbf868ddb99054a099a5e3c3c8206f3f3d5c242
Binary files /dev/null and b/ckpts/aniportrait/.DS_Store differ
diff --git a/ckpts/aniportrait/audio2mesh.pt b/ckpts/aniportrait/audio2mesh.pt
new file mode 100644
index 0000000000000000000000000000000000000000..33ee1a9327b42015e1fc9c59160714c6fa7ec53e
--- /dev/null
+++ b/ckpts/aniportrait/audio2mesh.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04996bebdad780a33642b0046036dae5d3c6db76e8f4ef5860e551fb9a1f0a1a
+size 382031763
diff --git a/ckpts/aniportrait/audio2pose.pt b/ckpts/aniportrait/audio2pose.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5b48c188e92df76f04a08e7e8ff405f182c5d286
--- /dev/null
+++ b/ckpts/aniportrait/audio2pose.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e61ffd104f8d1fe40476de1a8df9050559315976c830e5fdead1c31d1c5661f3
+size 481586148
diff --git a/ckpts/aniportrait/denoising_unet.pth b/ckpts/aniportrait/denoising_unet.pth
new file mode 100644
index 0000000000000000000000000000000000000000..306233ea507d0193ea35013413146c5981ca8fe5
--- /dev/null
+++ b/ckpts/aniportrait/denoising_unet.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddc4990e0d33dd5393190e8609fd7b32bfc0b5c386763624a3bff8038e0c054c
+size 3438374981
diff --git a/ckpts/aniportrait/film_net_fp16.pt b/ckpts/aniportrait/film_net_fp16.pt
new file mode 100644
index 0000000000000000000000000000000000000000..6ce95f925363b2c2b8bf6d9de92fb4ed5fa0b540
--- /dev/null
+++ b/ckpts/aniportrait/film_net_fp16.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0c8314674a6ba97787584fb04d59df9c6051ad5b735c89704f60801eece34d1
+size 69032330
diff --git a/ckpts/aniportrait/image_encoder/config.json b/ckpts/aniportrait/image_encoder/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..251e37d8a59724357a8887da1716fad7b791b9c0
--- /dev/null
+++ b/ckpts/aniportrait/image_encoder/config.json
@@ -0,0 +1,23 @@
+{
+ "_name_or_path": "/home/jpinkney/.cache/huggingface/diffusers/models--lambdalabs--sd-image-variations-diffusers/snapshots/ca6f97f838ae1b5bf764f31363a21f388f4d8f3e/image_encoder",
+ "architectures": [
+ "CLIPVisionModelWithProjection"
+ ],
+ "attention_dropout": 0.0,
+ "dropout": 0.0,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 1024,
+ "image_size": 224,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "layer_norm_eps": 1e-05,
+ "model_type": "clip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 24,
+ "patch_size": 14,
+ "projection_dim": 768,
+ "torch_dtype": "float32",
+ "transformers_version": "4.25.1"
+}
diff --git a/ckpts/aniportrait/image_encoder/pytorch_model.bin b/ckpts/aniportrait/image_encoder/pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..167893f2790c143ffda7de008d70cf000136ceed
--- /dev/null
+++ b/ckpts/aniportrait/image_encoder/pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89d2aa29b5fdf64f3ad4f45fb4227ea98bc45156bbae673b85be1af7783dbabb
+size 1215993967
diff --git a/ckpts/aniportrait/motion_module.pth b/ckpts/aniportrait/motion_module.pth
new file mode 100644
index 0000000000000000000000000000000000000000..ca85f934edd01d6aed46356bf9e643f85bbb4146
--- /dev/null
+++ b/ckpts/aniportrait/motion_module.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:954aee616a81f143e0316d210445d1933cca05a397c760661b7046738c4c1f06
+size 1817900817
diff --git a/ckpts/aniportrait/pose_guider.pth b/ckpts/aniportrait/pose_guider.pth
new file mode 100644
index 0000000000000000000000000000000000000000..cfebaa337d6e835bce672d7a40aa52552f425eb7
--- /dev/null
+++ b/ckpts/aniportrait/pose_guider.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b6e8a6c72efcbb4ed49421baeb2c218b6047e94f5f0c90b554748019f757e64f
+size 670182863
diff --git a/ckpts/aniportrait/reference_unet.pth b/ckpts/aniportrait/reference_unet.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3f500bc04cd16dae0213002112577741371ab97f
--- /dev/null
+++ b/ckpts/aniportrait/reference_unet.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e3603f7d44917ca3330a4a4be81b9578b308e9b9e3398fd6b8a1c4c86c474bd
+size 3438324501
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/README.md b/ckpts/aniportrait/sd-image-variations-diffusers/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..285325aee404f7c98122462b54f2ca3d53da9feb
--- /dev/null
+++ b/ckpts/aniportrait/sd-image-variations-diffusers/README.md
@@ -0,0 +1,226 @@
+---
+thumbnail: "https://repository-images.githubusercontent.com/523487884/fdb03a69-8353-4387-b5fc-0d85f888a63f"
+datasets:
+- ChristophSchuhmann/improved_aesthetics_6plus
+license: creativeml-openrail-m
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- image-to-image
+---
+
+# Stable Diffusion Image Variations Model Card
+
+📣 V2 model released, and blurriness issues fixed! 📣
+
+🧨🎉 Image Variations is now natively supported in 🤗 Diffusers! 🎉🧨
+
+![](https://raw.githubusercontent.com/justinpinkney/stable-diffusion/main/assets/im-vars-thin.jpg)
+
+## Version 2
+
+This version of Stable Diffusion has been fine tuned from [CompVis/stable-diffusion-v1-4-original](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) to accept CLIP image embedding rather than text embeddings. This allows the creation of "image variations" similar to DALLE-2 using Stable Diffusion. This version of the weights has been ported to huggingface Diffusers, to use this with the Diffusers library requires the [Lambda Diffusers repo](https://github.com/LambdaLabsML/lambda-diffusers).
+
+This model was trained in two stages and longer than the original variations model and gives better image quality and better CLIP rated similarity compared to the original version
+
+See training details and v1 vs v2 comparison below.
+
+
+## Example
+
+Make sure you are using a version of Diffusers >=0.8.0 (for older version see the old instructions at the bottom of this model card)
+
+```python
+from diffusers import StableDiffusionImageVariationPipeline
+from PIL import Image
+
+device = "cuda:0"
+sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
+ "lambdalabs/sd-image-variations-diffusers",
+ revision="v2.0",
+ )
+sd_pipe = sd_pipe.to(device)
+
+im = Image.open("path/to/image.jpg")
+tform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Resize(
+ (224, 224),
+ interpolation=transforms.InterpolationMode.BICUBIC,
+ antialias=False,
+ ),
+ transforms.Normalize(
+ [0.48145466, 0.4578275, 0.40821073],
+ [0.26862954, 0.26130258, 0.27577711]),
+])
+inp = tform(im).to(device).unsqueeze(0)
+
+out = sd_pipe(inp, guidance_scale=3)
+out["images"][0].save("result.jpg")
+```
+
+### The importance of resizing correctly... (or not)
+
+Note that due a bit of an oversight during training, the model expects resized images without anti-aliasing. This turns out to make a big difference and is important to do the resizing the same way during inference. When passing a PIL image to the Diffusers pipeline antialiasing will be applied during resize, so it's better to input a tensor which you have prepared manually according to the transfrom in the example above!
+
+Here are examples of images generated without (top) and with (bottom) anti-aliasing during resize. (Input is [this image](https://github.com/SHI-Labs/Versatile-Diffusion/blob/master/assets/ghibli.jpg))
+
+![](alias-montage.jpg)
+
+![](default-montage.jpg)
+
+### V1 vs V2
+
+Here's an example of V1 vs V2, version two was trained more carefully and for longer, see the details below. V2-top vs V1-bottom
+
+![](v2-montage.jpg)
+
+![](v1-montage.jpg)
+
+Input images:
+
+![](inputs.jpg)
+
+One important thing to note is that due to the longer training V2 appears to have memorised some common images from the training data, e.g. now the previous example of the Girl with a Pearl Earring almosts perfectly reproduce the original rather than creating variations. You can always use v1 by specifiying `revision="v1.0"`.
+
+v2 output for girl with a pearl earing as input (guidance scale=3)
+
+![](earring.jpg)
+
+# Training
+
+
+**Training Procedure**
+This model is fine tuned from Stable Diffusion v1-3 where the text encoder has been replaced with an image encoder. The training procedure is the same as for Stable Diffusion except for the fact that images are encoded through a ViT-L/14 image-encoder including the final projection layer to the CLIP shared embedding space. The model was trained on LAION improved aesthetics 6plus.
+
+- **Hardware:** 8 x A100-40GB GPUs (provided by [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud))
+- **Optimizer:** AdamW
+
+- **Stage 1** - Fine tune only CrossAttention layer weights from Stable Diffusion v1.4 model
+ - **Steps**: 46,000
+ - **Batch:** batch size=4, GPUs=8, Gradient Accumulations=4. Total batch size=128
+ - **Learning rate:** warmup to 1e-5 for 10,000 steps and then kept constant
+
+- **Stage 2** - Resume from Stage 1 training the whole unet
+ - **Steps**: 50,000
+ - **Batch:** batch size=4, GPUs=8, Gradient Accumulations=5. Total batch size=160
+ - **Learning rate:** warmup to 1e-5 for 5,000 steps and then kept constant
+
+
+Training was done using a [modified version of the original Stable Diffusion training code](https://github.com/justinpinkney/stable-diffusion).
+
+
+# Uses
+_The following section is adapted from the [Stable Diffusion model card](https://huggingface.co/CompVis/stable-diffusion-v1-4)_
+
+## Direct Use
+The model is intended for research purposes only. Possible research areas and
+tasks include
+
+- Safe deployment of models which have the potential to generate harmful content.
+- Probing and understanding the limitations and biases of generative models.
+- Generation of artworks and use in design and other artistic processes.
+- Applications in educational or creative tools.
+- Research on generative models.
+
+Excluded uses are described below.
+
+ ### Misuse, Malicious Use, and Out-of-Scope Use
+
+The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
+
+#### Out-of-Scope Use
+The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
+
+#### Misuse and Malicious Use
+Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
+
+- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
+- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
+- Impersonating individuals without their consent.
+- Sexual content without consent of the people who might see it.
+- Mis- and disinformation
+- Representations of egregious violence and gore
+- Sharing of copyrighted or licensed material in violation of its terms of use.
+- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
+
+## Limitations and Bias
+
+### Limitations
+
+- The model does not achieve perfect photorealism
+- The model cannot render legible text
+- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
+- Faces and people in general may not be generated properly.
+- The model was trained mainly with English captions and will not work as well in other languages.
+- The autoencoding part of the model is lossy
+- The model was trained on a large-scale dataset
+ [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
+ and is not fit for product use without additional safety mechanisms and
+ considerations.
+- No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data.
+ The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images.
+
+### Bias
+
+While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
+Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
+which consists of images that are primarily limited to English descriptions.
+Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
+This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
+ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
+
+### Safety Module
+
+The intended use of this model is with the [Safety Checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) in Diffusers.
+This checker works by checking model outputs against known hard-coded NSFW concepts.
+The concepts are intentionally hidden to reduce the likelihood of reverse-engineering this filter.
+Specifically, the checker compares the class probability of harmful concepts in the embedding space of the `CLIPModel` *after generation* of the images.
+The concepts are passed into the model with the generated image and compared to a hand-engineered weight for each NSFW concept.
+
+
+## Old instructions
+
+If you are using a diffusers version <0.8.0 there is no `StableDiffusionImageVariationPipeline`,
+in this case you need to use an older revision (`2ddbd90b14bc5892c19925b15185e561bc8e5d0a`) in conjunction with the lambda-diffusers repo:
+
+
+First clone [Lambda Diffusers](https://github.com/LambdaLabsML/lambda-diffusers) and install any requirements (in a virtual environment in the example below):
+
+```bash
+git clone https://github.com/LambdaLabsML/lambda-diffusers.git
+cd lambda-diffusers
+python -m venv .venv
+source .venv/bin/activate
+pip install -r requirements.txt
+```
+
+Then run the following python code:
+
+```python
+from pathlib import Path
+from lambda_diffusers import StableDiffusionImageEmbedPipeline
+from PIL import Image
+import torch
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+pipe = StableDiffusionImageEmbedPipeline.from_pretrained(
+"lambdalabs/sd-image-variations-diffusers",
+revision="2ddbd90b14bc5892c19925b15185e561bc8e5d0a",
+)
+pipe = pipe.to(device)
+
+im = Image.open("your/input/image/here.jpg")
+num_samples = 4
+image = pipe(num_samples*[im], guidance_scale=3.0)
+image = image["sample"]
+
+base_path = Path("outputs/im2im")
+base_path.mkdir(exist_ok=True, parents=True)
+for idx, im in enumerate(image):
+ im.save(base_path/f"{idx:06}.jpg")
+```
+
+
+
+*This model card was written by: Justin Pinkney and is based on the [Stable Diffusion model card](https://huggingface.co/CompVis/stable-diffusion-v1-4).*
\ No newline at end of file
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/alias-montage.jpg b/ckpts/aniportrait/sd-image-variations-diffusers/alias-montage.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dcd30e8cc5d8c7d8185dccdcd794fbcfee0ab9f9
Binary files /dev/null and b/ckpts/aniportrait/sd-image-variations-diffusers/alias-montage.jpg differ
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/default-montage.jpg b/ckpts/aniportrait/sd-image-variations-diffusers/default-montage.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7de51409118909289694002f64040985267f9d6b
Binary files /dev/null and b/ckpts/aniportrait/sd-image-variations-diffusers/default-montage.jpg differ
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/earring.jpg b/ckpts/aniportrait/sd-image-variations-diffusers/earring.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0484498780bc00dc1f55e16b3c4471a9e68b9e58
Binary files /dev/null and b/ckpts/aniportrait/sd-image-variations-diffusers/earring.jpg differ
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/feature_extractor/preprocessor_config.json b/ckpts/aniportrait/sd-image-variations-diffusers/feature_extractor/preprocessor_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..0d9d33b883843d1b370da781f3943051067e1b2c
--- /dev/null
+++ b/ckpts/aniportrait/sd-image-variations-diffusers/feature_extractor/preprocessor_config.json
@@ -0,0 +1,28 @@
+{
+ "crop_size": {
+ "height": 224,
+ "width": 224
+ },
+ "do_center_crop": true,
+ "do_convert_rgb": true,
+ "do_normalize": true,
+ "do_rescale": true,
+ "do_resize": true,
+ "feature_extractor_type": "CLIPFeatureExtractor",
+ "image_mean": [
+ 0.48145466,
+ 0.4578275,
+ 0.40821073
+ ],
+ "image_processor_type": "CLIPImageProcessor",
+ "image_std": [
+ 0.26862954,
+ 0.26130258,
+ 0.27577711
+ ],
+ "resample": 3,
+ "rescale_factor": 0.00392156862745098,
+ "size": {
+ "shortest_edge": 224
+ }
+}
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/inputs.jpg b/ckpts/aniportrait/sd-image-variations-diffusers/inputs.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3aa3f2c9d4cc99c599fb90a7f46573172c571079
Binary files /dev/null and b/ckpts/aniportrait/sd-image-variations-diffusers/inputs.jpg differ
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/model_index.json b/ckpts/aniportrait/sd-image-variations-diffusers/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..be345ca18731424f472499e6bac22758435b13e0
--- /dev/null
+++ b/ckpts/aniportrait/sd-image-variations-diffusers/model_index.json
@@ -0,0 +1,29 @@
+{
+ "_class_name": "StableDiffusionImageVariationPipeline",
+ "_diffusers_version": "0.9.0",
+ "feature_extractor": [
+ "transformers",
+ "CLIPImageProcessor"
+ ],
+ "image_encoder": [
+ "transformers",
+ "CLIPVisionModelWithProjection"
+ ],
+ "requires_safety_checker": true,
+ "safety_checker": [
+ "stable_diffusion",
+ "StableDiffusionSafetyChecker"
+ ],
+ "scheduler": [
+ "diffusers",
+ "PNDMScheduler"
+ ],
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/safety_checker/config.json b/ckpts/aniportrait/sd-image-variations-diffusers/safety_checker/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..d3d60c31e4b8019cb51f8874e7fe6aa0079123d4
--- /dev/null
+++ b/ckpts/aniportrait/sd-image-variations-diffusers/safety_checker/config.json
@@ -0,0 +1,181 @@
+{
+ "_commit_hash": "ca6f97f838ae1b5bf764f31363a21f388f4d8f3e",
+ "_name_or_path": "/home/jpinkney/.cache/huggingface/diffusers/models--lambdalabs--sd-image-variations-diffusers/snapshots/ca6f97f838ae1b5bf764f31363a21f388f4d8f3e/safety_checker",
+ "architectures": [
+ "StableDiffusionSafetyChecker"
+ ],
+ "initializer_factor": 1.0,
+ "logit_scale_init_value": 2.6592,
+ "model_type": "clip",
+ "projection_dim": 768,
+ "text_config": {
+ "_name_or_path": "",
+ "add_cross_attention": false,
+ "architectures": null,
+ "attention_dropout": 0.0,
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": 0,
+ "chunk_size_feed_forward": 0,
+ "cross_attention_hidden_size": null,
+ "decoder_start_token_id": null,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "dropout": 0.0,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": 2,
+ "exponential_decay_length_penalty": null,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 768,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layer_norm_eps": 1e-05,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "max_position_embeddings": 77,
+ "min_length": 0,
+ "model_type": "clip_text_model",
+ "no_repeat_ngram_size": 0,
+ "num_attention_heads": 12,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_hidden_layers": 12,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": 1,
+ "prefix": null,
+ "problem_type": null,
+ "projection_dim": 512,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": null,
+ "torchscript": false,
+ "transformers_version": "4.25.1",
+ "typical_p": 1.0,
+ "use_bfloat16": false,
+ "vocab_size": 49408
+ },
+ "text_config_dict": {
+ "hidden_size": 768,
+ "intermediate_size": 3072,
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12
+ },
+ "torch_dtype": "float32",
+ "transformers_version": null,
+ "vision_config": {
+ "_name_or_path": "",
+ "add_cross_attention": false,
+ "architectures": null,
+ "attention_dropout": 0.0,
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": null,
+ "chunk_size_feed_forward": 0,
+ "cross_attention_hidden_size": null,
+ "decoder_start_token_id": null,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "dropout": 0.0,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": null,
+ "exponential_decay_length_penalty": null,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 1024,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "image_size": 224,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layer_norm_eps": 1e-05,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "min_length": 0,
+ "model_type": "clip_vision_model",
+ "no_repeat_ngram_size": 0,
+ "num_attention_heads": 16,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_channels": 3,
+ "num_hidden_layers": 24,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": null,
+ "patch_size": 14,
+ "prefix": null,
+ "problem_type": null,
+ "projection_dim": 512,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": null,
+ "torchscript": false,
+ "transformers_version": "4.25.1",
+ "typical_p": 1.0,
+ "use_bfloat16": false
+ },
+ "vision_config_dict": {
+ "hidden_size": 1024,
+ "intermediate_size": 4096,
+ "num_attention_heads": 16,
+ "num_hidden_layers": 24,
+ "patch_size": 14
+ }
+}
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/safety_checker/pytorch_model.bin b/ckpts/aniportrait/sd-image-variations-diffusers/safety_checker/pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..faa7c232cc5d9bce74eb4889572f683a773e96bf
--- /dev/null
+++ b/ckpts/aniportrait/sd-image-variations-diffusers/safety_checker/pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:193490b58ef62739077262e833bf091c66c29488058681ac25cf7df3d8190974
+size 1216061799
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/scheduler/scheduler_config.json b/ckpts/aniportrait/sd-image-variations-diffusers/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..42016462bfec03db9be66539bc0814df83d48db0
--- /dev/null
+++ b/ckpts/aniportrait/sd-image-variations-diffusers/scheduler/scheduler_config.json
@@ -0,0 +1,13 @@
+{
+ "_class_name": "PNDMScheduler",
+ "_diffusers_version": "0.9.0",
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": false,
+ "num_train_timesteps": 1000,
+ "set_alpha_to_one": false,
+ "skip_prk_steps": true,
+ "steps_offset": 1,
+ "trained_betas": null
+}
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/unet/config.json b/ckpts/aniportrait/sd-image-variations-diffusers/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..22c3fb762f46de16b6123feb536df65d673f3588
--- /dev/null
+++ b/ckpts/aniportrait/sd-image-variations-diffusers/unet/config.json
@@ -0,0 +1,40 @@
+{
+ "_class_name": "UNet2DConditionModel",
+ "_diffusers_version": "0.9.0",
+ "act_fn": "silu",
+ "attention_head_dim": 8,
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "center_input_sample": false,
+ "cross_attention_dim": 768,
+ "down_block_types": [
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "dual_cross_attention": false,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "mid_block_scale_factor": 1,
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_class_embeds": null,
+ "only_cross_attention": false,
+ "out_channels": 4,
+ "sample_size": 64,
+ "up_block_types": [
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D"
+ ],
+ "use_linear_projection": false
+}
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/unet/diffusion_pytorch_model.bin b/ckpts/aniportrait/sd-image-variations-diffusers/unet/diffusion_pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..e803f2b7a1e2cb72216971e779f3baec4c267a63
--- /dev/null
+++ b/ckpts/aniportrait/sd-image-variations-diffusers/unet/diffusion_pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee23e3368e4e7c0e4ef636ed61923609c97fcaa583f8bb416e3e0986d4a0cfc6
+size 3438354725
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/v1-montage.jpg b/ckpts/aniportrait/sd-image-variations-diffusers/v1-montage.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..709f40a2ec6c3463f60f764bd6b53f0510a46d1b
Binary files /dev/null and b/ckpts/aniportrait/sd-image-variations-diffusers/v1-montage.jpg differ
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/v2-montage.jpg b/ckpts/aniportrait/sd-image-variations-diffusers/v2-montage.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a6155b28eac735c650dac4722e4f582b4866d3fd
Binary files /dev/null and b/ckpts/aniportrait/sd-image-variations-diffusers/v2-montage.jpg differ
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/vae/config.json b/ckpts/aniportrait/sd-image-variations-diffusers/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..400385391aecb3543364ccc925627866b151e10e
--- /dev/null
+++ b/ckpts/aniportrait/sd-image-variations-diffusers/vae/config.json
@@ -0,0 +1,30 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.9.0",
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "in_channels": 3,
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ]
+}
diff --git a/ckpts/aniportrait/sd-image-variations-diffusers/vae/diffusion_pytorch_model.bin b/ckpts/aniportrait/sd-image-variations-diffusers/vae/diffusion_pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..ba36f34d64ad3be997b7cab94b0b9acd61272851
--- /dev/null
+++ b/ckpts/aniportrait/sd-image-variations-diffusers/vae/diffusion_pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b4889b6b1d4ce7ae320a02dedaeff1780ad77d415ea0d744b476155c6377ddc
+size 334707217
diff --git a/ckpts/aniportrait/sd-vae-ft-mse/config.json b/ckpts/aniportrait/sd-vae-ft-mse/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..0db26717579be63eb0ddbf15b43faa43700dfe5a
--- /dev/null
+++ b/ckpts/aniportrait/sd-vae-ft-mse/config.json
@@ -0,0 +1,29 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.4.2",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "in_channels": 3,
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ]
+}
diff --git a/ckpts/aniportrait/sd-vae-ft-mse/diffusion_pytorch_model.bin b/ckpts/aniportrait/sd-vae-ft-mse/diffusion_pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..ba36f34d64ad3be997b7cab94b0b9acd61272851
--- /dev/null
+++ b/ckpts/aniportrait/sd-vae-ft-mse/diffusion_pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b4889b6b1d4ce7ae320a02dedaeff1780ad77d415ea0d744b476155c6377ddc
+size 334707217
diff --git a/ckpts/aniportrait/sd-vae-ft-mse/diffusion_pytorch_model.safetensors b/ckpts/aniportrait/sd-vae-ft-mse/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..90464d67ac7303d0ee4696334df13da130a948ea
--- /dev/null
+++ b/ckpts/aniportrait/sd-vae-ft-mse/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1d993488569e928462932c8c38a0760b874d166399b14414135bd9c42df5815
+size 334643276
diff --git a/ckpts/aniportrait/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json b/ckpts/aniportrait/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..5294955ff7801083f720b34b55d0f1f51313c5c5
--- /dev/null
+++ b/ckpts/aniportrait/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json
@@ -0,0 +1,20 @@
+{
+ "crop_size": 224,
+ "do_center_crop": true,
+ "do_convert_rgb": true,
+ "do_normalize": true,
+ "do_resize": true,
+ "feature_extractor_type": "CLIPFeatureExtractor",
+ "image_mean": [
+ 0.48145466,
+ 0.4578275,
+ 0.40821073
+ ],
+ "image_std": [
+ 0.26862954,
+ 0.26130258,
+ 0.27577711
+ ],
+ "resample": 3,
+ "size": 224
+}
diff --git a/ckpts/aniportrait/stable-diffusion-v1-5/model_index.json b/ckpts/aniportrait/stable-diffusion-v1-5/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..daf7e2e2dfc64fb437a2b44525667111b00cb9fc
--- /dev/null
+++ b/ckpts/aniportrait/stable-diffusion-v1-5/model_index.json
@@ -0,0 +1,32 @@
+{
+ "_class_name": "StableDiffusionPipeline",
+ "_diffusers_version": "0.6.0",
+ "feature_extractor": [
+ "transformers",
+ "CLIPImageProcessor"
+ ],
+ "safety_checker": [
+ "stable_diffusion",
+ "StableDiffusionSafetyChecker"
+ ],
+ "scheduler": [
+ "diffusers",
+ "PNDMScheduler"
+ ],
+ "text_encoder": [
+ "transformers",
+ "CLIPTextModel"
+ ],
+ "tokenizer": [
+ "transformers",
+ "CLIPTokenizer"
+ ],
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/ckpts/aniportrait/stable-diffusion-v1-5/unet/config.json b/ckpts/aniportrait/stable-diffusion-v1-5/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..1a02ee8abc93e840ffbcb2d68b66ccbcb74b3ab3
--- /dev/null
+++ b/ckpts/aniportrait/stable-diffusion-v1-5/unet/config.json
@@ -0,0 +1,36 @@
+{
+ "_class_name": "UNet2DConditionModel",
+ "_diffusers_version": "0.6.0",
+ "act_fn": "silu",
+ "attention_head_dim": 8,
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "center_input_sample": false,
+ "cross_attention_dim": 768,
+ "down_block_types": [
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "mid_block_scale_factor": 1,
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "out_channels": 4,
+ "sample_size": 64,
+ "up_block_types": [
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D"
+ ]
+}
diff --git a/ckpts/aniportrait/stable-diffusion-v1-5/unet/diffusion_pytorch_model.bin b/ckpts/aniportrait/stable-diffusion-v1-5/unet/diffusion_pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..f1ffb48de7efbabc851a260efde560d49621a9bc
--- /dev/null
+++ b/ckpts/aniportrait/stable-diffusion-v1-5/unet/diffusion_pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c7da0e21ba7ea50637bee26e81c220844defdf01aafca02b2c42ecdadb813de4
+size 3438354725
diff --git a/ckpts/aniportrait/stable-diffusion-v1-5/v1-inference.yaml b/ckpts/aniportrait/stable-diffusion-v1-5/v1-inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d4effe569e897369918625f9d8be5603a0e6a0d6
--- /dev/null
+++ b/ckpts/aniportrait/stable-diffusion-v1-5/v1-inference.yaml
@@ -0,0 +1,70 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
diff --git a/ckpts/aniportrait/wav2vec2-base-960h/README.md b/ckpts/aniportrait/wav2vec2-base-960h/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c7fe2047d7ac9b9816848c657b2a492ee95b264b
--- /dev/null
+++ b/ckpts/aniportrait/wav2vec2-base-960h/README.md
@@ -0,0 +1,128 @@
+---
+language: en
+datasets:
+- librispeech_asr
+tags:
+- audio
+- automatic-speech-recognition
+- hf-asr-leaderboard
+license: apache-2.0
+widget:
+- example_title: Librispeech sample 1
+ src: https://cdn-media.huggingface.co/speech_samples/sample1.flac
+- example_title: Librispeech sample 2
+ src: https://cdn-media.huggingface.co/speech_samples/sample2.flac
+model-index:
+- name: wav2vec2-base-960h
+ results:
+ - task:
+ name: Automatic Speech Recognition
+ type: automatic-speech-recognition
+ dataset:
+ name: LibriSpeech (clean)
+ type: librispeech_asr
+ config: clean
+ split: test
+ args:
+ language: en
+ metrics:
+ - name: Test WER
+ type: wer
+ value: 3.4
+ - task:
+ name: Automatic Speech Recognition
+ type: automatic-speech-recognition
+ dataset:
+ name: LibriSpeech (other)
+ type: librispeech_asr
+ config: other
+ split: test
+ args:
+ language: en
+ metrics:
+ - name: Test WER
+ type: wer
+ value: 8.6
+---
+
+# Wav2Vec2-Base-960h
+
+[Facebook's Wav2Vec2](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/)
+
+The base model pretrained and fine-tuned on 960 hours of Librispeech on 16kHz sampled speech audio. When using the model
+make sure that your speech input is also sampled at 16Khz.
+
+[Paper](https://arxiv.org/abs/2006.11477)
+
+Authors: Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
+
+**Abstract**
+
+We show for the first time that learning powerful representations from speech audio alone followed by fine-tuning on transcribed speech can outperform the best semi-supervised methods while being conceptually simpler. wav2vec 2.0 masks the speech input in the latent space and solves a contrastive task defined over a quantization of the latent representations which are jointly learned. Experiments using all labeled data of Librispeech achieve 1.8/3.3 WER on the clean/other test sets. When lowering the amount of labeled data to one hour, wav2vec 2.0 outperforms the previous state of the art on the 100 hour subset while using 100 times less labeled data. Using just ten minutes of labeled data and pre-training on 53k hours of unlabeled data still achieves 4.8/8.2 WER. This demonstrates the feasibility of speech recognition with limited amounts of labeled data.
+
+The original model can be found under https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#wav2vec-20.
+
+
+# Usage
+
+To transcribe audio files the model can be used as a standalone acoustic model as follows:
+
+```python
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+ from datasets import load_dataset
+ import torch
+
+ # load model and tokenizer
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+ model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+
+ # load dummy dataset and read soundfiles
+ ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
+
+ # tokenize
+ input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values # Batch size 1
+
+ # retrieve logits
+ logits = model(input_values).logits
+
+ # take argmax and decode
+ predicted_ids = torch.argmax(logits, dim=-1)
+ transcription = processor.batch_decode(predicted_ids)
+ ```
+
+ ## Evaluation
+
+ This code snippet shows how to evaluate **facebook/wav2vec2-base-960h** on LibriSpeech's "clean" and "other" test data.
+
+```python
+from datasets import load_dataset
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+import torch
+from jiwer import wer
+
+
+librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
+
+model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to("cuda")
+processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+
+def map_to_pred(batch):
+ input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
+ with torch.no_grad():
+ logits = model(input_values.to("cuda")).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ transcription = processor.batch_decode(predicted_ids)
+ batch["transcription"] = transcription
+ return batch
+
+result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
+
+print("WER:", wer(result["text"], result["transcription"]))
+```
+
+*Result (WER)*:
+
+| "clean" | "other" |
+|---|---|
+| 3.4 | 8.6 |
\ No newline at end of file
diff --git a/ckpts/aniportrait/wav2vec2-base-960h/config.json b/ckpts/aniportrait/wav2vec2-base-960h/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..8ca9cc7496e145e37d09cec17d0c3bf9b8523c8e
--- /dev/null
+++ b/ckpts/aniportrait/wav2vec2-base-960h/config.json
@@ -0,0 +1,77 @@
+{
+ "_name_or_path": "facebook/wav2vec2-base-960h",
+ "activation_dropout": 0.1,
+ "apply_spec_augment": true,
+ "architectures": [
+ "Wav2Vec2ForCTC"
+ ],
+ "attention_dropout": 0.1,
+ "bos_token_id": 1,
+ "codevector_dim": 256,
+ "contrastive_logits_temperature": 0.1,
+ "conv_bias": false,
+ "conv_dim": [
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512
+ ],
+ "conv_kernel": [
+ 10,
+ 3,
+ 3,
+ 3,
+ 3,
+ 2,
+ 2
+ ],
+ "conv_stride": [
+ 5,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2
+ ],
+ "ctc_loss_reduction": "sum",
+ "ctc_zero_infinity": false,
+ "diversity_loss_weight": 0.1,
+ "do_stable_layer_norm": false,
+ "eos_token_id": 2,
+ "feat_extract_activation": "gelu",
+ "feat_extract_dropout": 0.0,
+ "feat_extract_norm": "group",
+ "feat_proj_dropout": 0.1,
+ "feat_quantizer_dropout": 0.0,
+ "final_dropout": 0.1,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout": 0.1,
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-05,
+ "layerdrop": 0.1,
+ "mask_feature_length": 10,
+ "mask_feature_prob": 0.0,
+ "mask_time_length": 10,
+ "mask_time_prob": 0.05,
+ "model_type": "wav2vec2",
+ "num_attention_heads": 12,
+ "num_codevector_groups": 2,
+ "num_codevectors_per_group": 320,
+ "num_conv_pos_embedding_groups": 16,
+ "num_conv_pos_embeddings": 128,
+ "num_feat_extract_layers": 7,
+ "num_hidden_layers": 12,
+ "num_negatives": 100,
+ "pad_token_id": 0,
+ "proj_codevector_dim": 256,
+ "transformers_version": "4.7.0.dev0",
+ "vocab_size": 32
+}
diff --git a/ckpts/aniportrait/wav2vec2-base-960h/feature_extractor_config.json b/ckpts/aniportrait/wav2vec2-base-960h/feature_extractor_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..52fdd74dc06f40033506e402269fbde5e7adc21d
--- /dev/null
+++ b/ckpts/aniportrait/wav2vec2-base-960h/feature_extractor_config.json
@@ -0,0 +1,8 @@
+{
+ "do_normalize": true,
+ "feature_dim": 1,
+ "padding_side": "right",
+ "padding_value": 0.0,
+ "return_attention_mask": false,
+ "sampling_rate": 16000
+}
diff --git a/ckpts/aniportrait/wav2vec2-base-960h/model.safetensors b/ckpts/aniportrait/wav2vec2-base-960h/model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..942562678fb28df86c055027c18216fa2a7cb5dd
--- /dev/null
+++ b/ckpts/aniportrait/wav2vec2-base-960h/model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8aa76ab2243c81747a1f832954586bc566090c83a0ac167df6f31f0fa917d74a
+size 377607901
diff --git a/ckpts/aniportrait/wav2vec2-base-960h/preprocessor_config.json b/ckpts/aniportrait/wav2vec2-base-960h/preprocessor_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..3f24dc078fcba55ee1d417a413847ead40c093a3
--- /dev/null
+++ b/ckpts/aniportrait/wav2vec2-base-960h/preprocessor_config.json
@@ -0,0 +1,8 @@
+{
+ "do_normalize": true,
+ "feature_size": 1,
+ "padding_side": "right",
+ "padding_value": 0.0,
+ "return_attention_mask": false,
+ "sampling_rate": 16000
+}
diff --git a/ckpts/aniportrait/wav2vec2-base-960h/pytorch_model.bin b/ckpts/aniportrait/wav2vec2-base-960h/pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..d630db45384aa007f54a9a1b37da83c5a208f4cf
--- /dev/null
+++ b/ckpts/aniportrait/wav2vec2-base-960h/pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c34f9827b034a1b9141dbf6f652f8a60eda61cdf5771c9e05bfa99033c92cd96
+size 377667514
diff --git a/ckpts/aniportrait/wav2vec2-base-960h/special_tokens_map.json b/ckpts/aniportrait/wav2vec2-base-960h/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..25bc39604f72700b3b8e10bd69bb2f227157edd1
--- /dev/null
+++ b/ckpts/aniportrait/wav2vec2-base-960h/special_tokens_map.json
@@ -0,0 +1 @@
+{"bos_token": "", "eos_token": "", "unk_token": "", "pad_token": ""}
\ No newline at end of file
diff --git a/ckpts/aniportrait/wav2vec2-base-960h/tf_model.h5 b/ckpts/aniportrait/wav2vec2-base-960h/tf_model.h5
new file mode 100644
index 0000000000000000000000000000000000000000..e6d1d69dc1ac70461fd5754f00e9d1d9626bd400
--- /dev/null
+++ b/ckpts/aniportrait/wav2vec2-base-960h/tf_model.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:412742825972a6e2e877255ccd8b3416e618df15a7f1e5e4f736aa3632ce33b5
+size 377840624
diff --git a/ckpts/aniportrait/wav2vec2-base-960h/tokenizer_config.json b/ckpts/aniportrait/wav2vec2-base-960h/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..978a15a96dbb2d23e2afbc70137cae6c5ce38c8d
--- /dev/null
+++ b/ckpts/aniportrait/wav2vec2-base-960h/tokenizer_config.json
@@ -0,0 +1 @@
+{"unk_token": "", "bos_token": "", "eos_token": "", "pad_token": "", "do_lower_case": false, "return_attention_mask": false, "do_normalize": true}
\ No newline at end of file
diff --git a/ckpts/aniportrait/wav2vec2-base-960h/vocab.json b/ckpts/aniportrait/wav2vec2-base-960h/vocab.json
new file mode 100644
index 0000000000000000000000000000000000000000..88181b954aa14df68be9b444b3c36585f3078c0a
--- /dev/null
+++ b/ckpts/aniportrait/wav2vec2-base-960h/vocab.json
@@ -0,0 +1 @@
+{"": 0, "": 1, "": 2, "": 3, "|": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}
\ No newline at end of file
diff --git a/ckpts/gfpgan/alignment_WFLW_4HG.pth b/ckpts/gfpgan/alignment_WFLW_4HG.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3cfeef20123eb2e74b35a4319c2111ef65783c34
--- /dev/null
+++ b/ckpts/gfpgan/alignment_WFLW_4HG.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbfd137307a4c7debd5c283b9b0ce539466cee417ac0a155e184d857f9f2899c
+size 193670248
diff --git a/ckpts/gfpgan/detection_Resnet50_Final.pth b/ckpts/gfpgan/detection_Resnet50_Final.pth
new file mode 100644
index 0000000000000000000000000000000000000000..16546738ce0a00a9fd47585e0fc52744d31cc117
--- /dev/null
+++ b/ckpts/gfpgan/detection_Resnet50_Final.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
+size 109497761
diff --git a/ckpts/gfpgan/parsing_parsenet.pth b/ckpts/gfpgan/parsing_parsenet.pth
new file mode 100644
index 0000000000000000000000000000000000000000..1ac2efc50360a79c9905dbac57d9d99cbfbe863c
--- /dev/null
+++ b/ckpts/gfpgan/parsing_parsenet.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
+size 85331193
diff --git a/ckpts/mofa/.DS_Store b/ckpts/mofa/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..6ce9da2b97cc1829b1208320f7a2587778c6a792
Binary files /dev/null and b/ckpts/mofa/.DS_Store differ
diff --git a/ckpts/mofa/ldmk_controlnet/config.json b/ckpts/mofa/ldmk_controlnet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..a38d0e54404b9e6419452bd30763b5a7214b63c0
--- /dev/null
+++ b/ckpts/mofa/ldmk_controlnet/config.json
@@ -0,0 +1,45 @@
+{
+ "_class_name": "FlowControlNet",
+ "_diffusers_version": "0.25.1",
+ "_name_or_path": "/apdcephfs_cq10/share_1290939/myniu/mofa_world_log/train_svdxt_forward_ctrlnet_occlusion_fixcmp_face_hdtfocclusion_8gpu_vfhq_preldmk_512_14_slim_vfhqtune/checkpoint-100000/controlnet",
+ "addition_time_embed_dim": 256,
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "conditioning_channels": 3,
+ "conditioning_embedding_out_channels": [
+ 16,
+ 32,
+ 96,
+ 256
+ ],
+ "cross_attention_dim": 1024,
+ "down_block_types": [
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal"
+ ],
+ "in_channels": 8,
+ "layers_per_block": 2,
+ "num_attention_heads": [
+ 5,
+ 10,
+ 10,
+ 20
+ ],
+ "num_frames": 25,
+ "out_channels": 4,
+ "projection_class_embeddings_input_dim": 768,
+ "sample_size": null,
+ "transformer_layers_per_block": 1,
+ "up_block_types": [
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal"
+ ]
+}
diff --git a/ckpts/mofa/ldmk_controlnet/diffusion_pytorch_model.safetensors b/ckpts/mofa/ldmk_controlnet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..e66bc386cab44912fce1a7c736309e5b0b1065fe
--- /dev/null
+++ b/ckpts/mofa/ldmk_controlnet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:652909df9e6dede5b2bb415cd7863d6a804e04a29af80151af4c580111e1a71b
+size 2889913940
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/LICENSE b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..5775e3e1d74fbb92d51edf7929ca73d5c70718f8
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/LICENSE
@@ -0,0 +1,42 @@
+STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE AGREEMENT
+Dated: February 2, 2024
+
+By clicking “I Accept” below or by using or distributing any portion or element of the Models, Software, Software Products or Derivative Works, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software Products or Derivative Works through this License, and you must immediately cease using the Software Products or Derivative Works. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products or Derivative Works on behalf of your employer or other entity.
+
+"Agreement" means this Stable Non-Commercial Research Community License Agreement.
+
+“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
+
+"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
+
+“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
+
+"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
+
+“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.
+
+“Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.
+
+"Stability AI" or "we" means Stability AI Ltd. and its affiliates.
+
+
+"Software" means Stability AI’s proprietary software made available under this Agreement.
+
+“Software Products” means the Models, Software and Documentation, individually or in any combination.
+
+
+
+1. License Rights and Redistribution.
+a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only.
+b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.
+c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
+2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
+3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+4. Intellectual Property.
+a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.
+b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works
+c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.
+5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.
+
+6. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law
+principles.
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/README.md b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e18fc2ae239239d609004468fe14507f997bbece
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/README.md
@@ -0,0 +1,128 @@
+---
+pipeline_tag: image-to-video
+license: other
+license_name: stable-video-diffusion-1-1-nc-community
+license_link: LICENSE
+extra_gated_prompt: >-
+ STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE AGREEMENT
+ Dated: February 2, 2024
+
+ By clicking “I Accept” below or by using or distributing any portion or element of the Models, Software, Software Products or Derivative Works, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software Products or Derivative Works through this License, and you must immediately cease using the Software Products or Derivative Works. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products or Derivative Works on behalf of your employer or other entity.
+
+ "Agreement" means this Stable Non-Commercial Research Community License Agreement.
+
+ “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
+
+ "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
+
+ “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
+
+ "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
+
+ “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.
+
+ “Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.
+
+ "Stability AI" or "we" means Stability AI Ltd. and its affiliates.
+
+
+ "Software" means Stability AI’s proprietary software made available under this Agreement.
+
+ “Software Products” means the Models, Software and Documentation, individually or in any combination.
+
+
+
+ 1. License Rights and Redistribution.
+ a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only.
+ b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.
+ c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
+ 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
+ 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+ 4. Intellectual Property.
+ a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.
+ b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works
+ c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.
+ 5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.
+
+ 6. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law
+ principles.
+
+extra_gated_description: Stable Video Diffusion 1.1 License Agreement
+extra_gated_button_content: Submit
+extra_gated_fields:
+ Name: text
+ Company Name (if applicable): text
+ Email: text
+ Other Comments: text
+ By clicking here, you accept the License agreement, and will use the Software Products and Derivative Works for non-commercial or research purposes only: checkbox
+ By clicking here, you agree to sharing with Stability AI the information contained within this form and that Stability AI can contact you for the purposes of marketing our products and services: checkbox
+---
+
+# Stable Video Diffusion 1.1 Image-to-Video Model Card
+
+
+![row01](svd11.webp)
+Stable Video Diffusion (SVD) 1.1 Image-to-Video is a diffusion model that takes in a still image as a conditioning frame, and generates a video from it.
+
+## Model Details
+
+### Model Description
+
+(SVD 1.1) Image-to-Video is a latent diffusion model trained to generate short video clips from an image conditioning.
+
+This model was trained to generate 25 frames at resolution 1024x576 given a context frame of the same size, finetuned from [SVD Image-to-Video [25 frames]](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt).
+
+Fine tuning was performed with fixed conditioning at 6FPS and Motion Bucket Id 127 to improve the consistency of outputs without the need to adjust hyper parameters. These conditions are still adjustable and have not been removed. Performance outside of the fixed conditioning settings may vary compared to SVD 1.0.
+
+
+- **Developed by:** Stability AI
+- **Funded by:** Stability AI
+- **Model type:** Generative image-to-video model
+- **Finetuned from model:** SVD Image-to-Video [25 frames]
+
+### Model Sources
+
+For research purposes, we recommend our `generative-models` Github repository (https://github.com/Stability-AI/generative-models),
+which implements the most popular diffusion frameworks (both training and inference).
+
+- **Repository:** https://github.com/Stability-AI/generative-models
+- **Paper:** https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets
+
+## Uses
+
+### Direct Use
+
+The model is intended for research purposes only. Possible research areas and tasks include
+
+- Research on generative models.
+- Safe deployment of models which have the potential to generate harmful content.
+- Probing and understanding the limitations and biases of generative models.
+- Generation of artworks and use in design and other artistic processes.
+- Applications in educational or creative tools.
+
+Excluded uses are described below.
+
+### Out-of-Scope Use
+
+The model was not trained to be factual or true representations of people or events,
+and therefore using the model to generate such content is out-of-scope for the abilities of this model.
+The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).
+
+## Limitations and Bias
+
+### Limitations
+- The generated videos are rather short (<= 4sec), and the model does not achieve perfect photorealism.
+- The model may generate videos without motion, or very slow camera pans.
+- The model cannot be controlled through text.
+- The model cannot render legible text.
+- Faces and people in general may not be generated properly.
+- The autoencoding part of the model is lossy.
+
+
+### Recommendations
+
+The model is intended for research purposes only.
+
+## How to Get Started with the Model
+
+Check out https://github.com/Stability-AI/generative-models
\ No newline at end of file
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/feature_extractor/preprocessor_config.json b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/feature_extractor/preprocessor_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..0d9d33b883843d1b370da781f3943051067e1b2c
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/feature_extractor/preprocessor_config.json
@@ -0,0 +1,28 @@
+{
+ "crop_size": {
+ "height": 224,
+ "width": 224
+ },
+ "do_center_crop": true,
+ "do_convert_rgb": true,
+ "do_normalize": true,
+ "do_rescale": true,
+ "do_resize": true,
+ "feature_extractor_type": "CLIPFeatureExtractor",
+ "image_mean": [
+ 0.48145466,
+ 0.4578275,
+ 0.40821073
+ ],
+ "image_processor_type": "CLIPImageProcessor",
+ "image_std": [
+ 0.26862954,
+ 0.26130258,
+ 0.27577711
+ ],
+ "resample": 3,
+ "rescale_factor": 0.00392156862745098,
+ "size": {
+ "shortest_edge": 224
+ }
+}
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/image_encoder/config.json b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/image_encoder/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..1f5518ccff586593a40f4eaf0e75c066dca54bec
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/image_encoder/config.json
@@ -0,0 +1,23 @@
+{
+ "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/image_encoder",
+ "architectures": [
+ "CLIPVisionModelWithProjection"
+ ],
+ "attention_dropout": 0.0,
+ "dropout": 0.0,
+ "hidden_act": "gelu",
+ "hidden_size": 1280,
+ "image_size": 224,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 5120,
+ "layer_norm_eps": 1e-05,
+ "model_type": "clip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 32,
+ "patch_size": 14,
+ "projection_dim": 1024,
+ "torch_dtype": "float16",
+ "transformers_version": "4.34.0.dev0"
+}
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/image_encoder/model.fp16.safetensors b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/image_encoder/model.fp16.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..29848b356a3062d1da3b773a82d62de5e49a8f89
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/image_encoder/model.fp16.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae616c24393dd1854372b0639e5541666f7521cbe219669255e865cb7f89466a
+size 1264217240
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/model_index.json b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..814cc99f8674db1df84d0fff0d4e5535e745d328
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/model_index.json
@@ -0,0 +1,25 @@
+{
+ "_class_name": "StableVideoDiffusionPipeline",
+ "_diffusers_version": "0.24.0.dev0",
+ "_name_or_path": "diffusers/svd-xt",
+ "feature_extractor": [
+ "transformers",
+ "CLIPImageProcessor"
+ ],
+ "image_encoder": [
+ "transformers",
+ "CLIPVisionModelWithProjection"
+ ],
+ "scheduler": [
+ "diffusers",
+ "EulerDiscreteScheduler"
+ ],
+ "unet": [
+ "diffusers",
+ "UNetSpatioTemporalConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKLTemporalDecoder"
+ ]
+}
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/scheduler/scheduler_config.json b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..05ea60ddb0d95607f9306e020e4ea355664d0275
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/scheduler/scheduler_config.json
@@ -0,0 +1,20 @@
+{
+ "_class_name": "EulerDiscreteScheduler",
+ "_diffusers_version": "0.24.0.dev0",
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": false,
+ "interpolation_type": "linear",
+ "num_train_timesteps": 1000,
+ "prediction_type": "v_prediction",
+ "set_alpha_to_one": false,
+ "sigma_max": 700.0,
+ "sigma_min": 0.002,
+ "skip_prk_steps": true,
+ "steps_offset": 1,
+ "timestep_spacing": "leading",
+ "timestep_type": "continuous",
+ "trained_betas": null,
+ "use_karras_sigmas": true
+}
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/unet/config.json b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..2a30c09f6764459c04d7dc10bf5b4bbf1e5ebc73
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/unet/config.json
@@ -0,0 +1,38 @@
+{
+ "_class_name": "UNetSpatioTemporalConditionModel",
+ "_diffusers_version": "0.24.0.dev0",
+ "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/unet",
+ "addition_time_embed_dim": 256,
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "cross_attention_dim": 1024,
+ "down_block_types": [
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal"
+ ],
+ "in_channels": 8,
+ "layers_per_block": 2,
+ "num_attention_heads": [
+ 5,
+ 10,
+ 20,
+ 20
+ ],
+ "num_frames": 25,
+ "out_channels": 4,
+ "projection_class_embeddings_input_dim": 768,
+ "sample_size": 96,
+ "transformer_layers_per_block": 1,
+ "up_block_types": [
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal"
+ ]
+}
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/unet/diffusion_pytorch_model.fp16.safetensors b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/unet/diffusion_pytorch_model.fp16.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..7772c11e0769adb8d700238beb80c0312d2e7abf
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/unet/diffusion_pytorch_model.fp16.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcc7032525a903a6c284ccb3e8cb6f9d87597b71130d8b9820925085e2215d17
+size 3049435836
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/vae/config.json b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..7c27c35b4e6ab0e705d46306f60d36839b680c03
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/vae/config.json
@@ -0,0 +1,24 @@
+{
+ "_class_name": "AutoencoderKLTemporalDecoder",
+ "_diffusers_version": "0.24.0.dev0",
+ "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/vae",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ "out_channels": 3,
+ "sample_size": 768,
+ "scaling_factor": 0.18215
+}
diff --git a/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/vae/diffusion_pytorch_model.fp16.safetensors b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/vae/diffusion_pytorch_model.fp16.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..76c1b086f4ca2c1060a3ba7c8dea11d522a67584
--- /dev/null
+++ b/ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1/vae/diffusion_pytorch_model.fp16.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af602cd0eb4ad6086ec94fbf1438dfb1be5ec9ac03fd0215640854e90d6463a3
+size 195531910
diff --git a/ckpts/mofa/traj_controlnet/config.json b/ckpts/mofa/traj_controlnet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..304d16bfe4ab918d82e654ff5fc32d167d75b48b
--- /dev/null
+++ b/ckpts/mofa/traj_controlnet/config.json
@@ -0,0 +1,45 @@
+{
+ "_class_name": "FlowControlNet",
+ "_diffusers_version": "0.25.1",
+ "_name_or_path": "mofa_ckpts/traj_controlnet",
+ "addition_time_embed_dim": 256,
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "conditioning_channels": 3,
+ "conditioning_embedding_out_channels": [
+ 16,
+ 32,
+ 96,
+ 256
+ ],
+ "cross_attention_dim": 1024,
+ "down_block_types": [
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal"
+ ],
+ "in_channels": 8,
+ "layers_per_block": 2,
+ "num_attention_heads": [
+ 5,
+ 10,
+ 10,
+ 20
+ ],
+ "num_frames": 25,
+ "out_channels": 4,
+ "projection_class_embeddings_input_dim": 768,
+ "sample_size": null,
+ "transformer_layers_per_block": 1,
+ "up_block_types": [
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal"
+ ]
+}
diff --git a/ckpts/mofa/traj_controlnet/diffusion_pytorch_model.safetensors b/ckpts/mofa/traj_controlnet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..3b575b15b495831af7a1970337ea33cf883f6d58
--- /dev/null
+++ b/ckpts/mofa/traj_controlnet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1303192a1e72d071e15e7eb37fd1ea15f6424aaf2cd6b6b1e1bb3b1e9e75d37e
+size 2777345452
diff --git a/ckpts/sad_talker/.DS_Store b/ckpts/sad_talker/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..e8853c9d7fd091248ce01d10b9443bcdfbee5f74
Binary files /dev/null and b/ckpts/sad_talker/.DS_Store differ
diff --git a/ckpts/sad_talker/BFM_Fitting/01_MorphableModel.mat b/ckpts/sad_talker/BFM_Fitting/01_MorphableModel.mat
new file mode 100644
index 0000000000000000000000000000000000000000..f251485b55d35adac0ad4f1622a47d7a39a1502c
--- /dev/null
+++ b/ckpts/sad_talker/BFM_Fitting/01_MorphableModel.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37b1f0742db356a3b1568a8365a06f5b0fe0ab687ac1c3068c803666cbd4d8e2
+size 240875364
diff --git a/ckpts/sad_talker/BFM_Fitting/BFM_exp_idx.mat b/ckpts/sad_talker/BFM_Fitting/BFM_exp_idx.mat
new file mode 100644
index 0000000000000000000000000000000000000000..5b214a5f8afbc038e6959f7f72141e448e89fb3b
--- /dev/null
+++ b/ckpts/sad_talker/BFM_Fitting/BFM_exp_idx.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62752a2cab3eea148569fb07e367e03535b4ee04aa71ea1a9aed36486d26c612
+size 91931
diff --git a/ckpts/sad_talker/BFM_Fitting/BFM_front_idx.mat b/ckpts/sad_talker/BFM_Fitting/BFM_front_idx.mat
new file mode 100644
index 0000000000000000000000000000000000000000..29d82e79f8b2558a5bf1956ab9e1261d49c2c8dd
--- /dev/null
+++ b/ckpts/sad_talker/BFM_Fitting/BFM_front_idx.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d285dd018563113496127df9c364800183172adb4d3e802f726085dab66b087
+size 44880
diff --git a/ckpts/sad_talker/BFM_Fitting/BFM_model_front.mat b/ckpts/sad_talker/BFM_Fitting/BFM_model_front.mat
new file mode 100644
index 0000000000000000000000000000000000000000..4676a8a48dc2a0f03f4db2879cf25307f06bfc63
--- /dev/null
+++ b/ckpts/sad_talker/BFM_Fitting/BFM_model_front.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e60c2a15c24a9270f1bb50505d134aa984e16483535342b540ea7763db92c2c
+size 127170280
diff --git a/ckpts/sad_talker/BFM_Fitting/Exp_Pca.bin b/ckpts/sad_talker/BFM_Fitting/Exp_Pca.bin
new file mode 100644
index 0000000000000000000000000000000000000000..3c1785e6abc52b13e54a573f9f3ebc099915b1e0
--- /dev/null
+++ b/ckpts/sad_talker/BFM_Fitting/Exp_Pca.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7f31380e6cbdaf2aeec698db220bac4f221946e4d551d88c092d47ec49b1726
+size 51086404
diff --git a/ckpts/sad_talker/BFM_Fitting/facemodel_info.mat b/ckpts/sad_talker/BFM_Fitting/facemodel_info.mat
new file mode 100644
index 0000000000000000000000000000000000000000..c2e0a3521fc040e59e07fc09384fc140234f006f
--- /dev/null
+++ b/ckpts/sad_talker/BFM_Fitting/facemodel_info.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:529398f76619ae7e22f43c25dd60a2473bcc2bcc8c894fd9c613c68624ce1c04
+size 738861
diff --git a/ckpts/sad_talker/BFM_Fitting/select_vertex_id.mat b/ckpts/sad_talker/BFM_Fitting/select_vertex_id.mat
new file mode 100644
index 0000000000000000000000000000000000000000..feadeff96a0b8e0619461f64a9bdc9e761b14c80
--- /dev/null
+++ b/ckpts/sad_talker/BFM_Fitting/select_vertex_id.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6877a7d634330f25bf1e81bc062b6507ee53ea183838e471fa21b613048fa36b
+size 62299
diff --git a/ckpts/sad_talker/BFM_Fitting/similarity_Lm3D_all.mat b/ckpts/sad_talker/BFM_Fitting/similarity_Lm3D_all.mat
new file mode 100644
index 0000000000000000000000000000000000000000..9f5b0bd4ecffb926128a29cb1bbf9d9081c3d4e7
--- /dev/null
+++ b/ckpts/sad_talker/BFM_Fitting/similarity_Lm3D_all.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53b83ce6e35c50ddc3e97603650cef4970320c157e75c241c844f29c1dcba65a
+size 994
diff --git a/ckpts/sad_talker/BFM_Fitting/std_exp.txt b/ckpts/sad_talker/BFM_Fitting/std_exp.txt
new file mode 100644
index 0000000000000000000000000000000000000000..767b8de4ea1ca78b6f22b98ff2dee4fa345500bb
--- /dev/null
+++ b/ckpts/sad_talker/BFM_Fitting/std_exp.txt
@@ -0,0 +1 @@
+453980 257264 263068 211890 135873 184721 47055.6 72732 62787.4 106226 56708.5 51439.8 34887.1 44378.7 51813.4 31030.7 23354.9 23128.1 19400 21827.6 22767.7 22057.4 19894.3 16172.8 17142.7 10035.3 14727.5 12972.5 10763.8 8953.93 8682.62 8941.81 6342.3 5205.3 7065.65 6083.35 6678.88 4666.63 5082.89 5134.76 4908.16 3964.93 3739.95 3180.09 2470.45 1866.62 1624.71 2423.74 1668.53 1471.65 1194.52 782.102 815.044 835.782 834.937 744.496 575.146 633.76 705.685 753.409 620.306 673.326 766.189 619.866 559.93 357.264 396.472 556.849 455.048 460.592 400.735 326.702 279.428 291.535 326.584 305.664 287.816 283.642 276.19
\ No newline at end of file
diff --git a/ckpts/sad_talker/SadTalker_V0.0.2_256.safetensors b/ckpts/sad_talker/SadTalker_V0.0.2_256.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..1d0eb9787332ff6c6a603c0e79ebada49010270a
--- /dev/null
+++ b/ckpts/sad_talker/SadTalker_V0.0.2_256.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c211f5d6de003516bf1bbda9f47049a4c9c99133b1ab565c6961e5af16477bff
+size 725066984
diff --git a/ckpts/sad_talker/SadTalker_V0.0.2_512.safetensors b/ckpts/sad_talker/SadTalker_V0.0.2_512.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..356c229c4a1424e157e4b22686ba72fd3daa8e81
--- /dev/null
+++ b/ckpts/sad_talker/SadTalker_V0.0.2_512.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e063f7ff5258240bdb0f7690783a7b1374e6a4a81ce8fa33456f4cd49694340
+size 725066984
diff --git a/ckpts/sad_talker/epoch_00190_iteration_000400000_checkpoint.pt b/ckpts/sad_talker/epoch_00190_iteration_000400000_checkpoint.pt
new file mode 100644
index 0000000000000000000000000000000000000000..f5258b8314f176fb9d5646d9c2a955e08180610a
--- /dev/null
+++ b/ckpts/sad_talker/epoch_00190_iteration_000400000_checkpoint.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41220d2973c0ba2eab6e8f17ed00711aef5a0d76d19808f885dc0e3251df2e80
+size 180424655
diff --git a/ckpts/sad_talker/mapping_00109-model.pth.tar b/ckpts/sad_talker/mapping_00109-model.pth.tar
new file mode 100644
index 0000000000000000000000000000000000000000..009c3190f5d903c56a2fb0a085d605dc782a83c9
--- /dev/null
+++ b/ckpts/sad_talker/mapping_00109-model.pth.tar
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84a8642468a3fcfdd9ab6be955267043116c2bec2284686a5262f1eaf017f64c
+size 155779231
diff --git a/ckpts/sad_talker/mapping_00229-model.pth.tar b/ckpts/sad_talker/mapping_00229-model.pth.tar
new file mode 100644
index 0000000000000000000000000000000000000000..6400233ae3fa5ff9426800ef761fd6c830bc0cd7
--- /dev/null
+++ b/ckpts/sad_talker/mapping_00229-model.pth.tar
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62a1e06006cc963220f6477438518ed86e9788226c62ae382ddc42fbcefb83f1
+size 155521183
diff --git a/ckpts/sad_talker/similarity_Lm3D_all.mat b/ckpts/sad_talker/similarity_Lm3D_all.mat
new file mode 100644
index 0000000000000000000000000000000000000000..9f5b0bd4ecffb926128a29cb1bbf9d9081c3d4e7
--- /dev/null
+++ b/ckpts/sad_talker/similarity_Lm3D_all.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53b83ce6e35c50ddc3e97603650cef4970320c157e75c241c844f29c1dcba65a
+size 994
diff --git a/expression.mat b/expression.mat
new file mode 100644
index 0000000000000000000000000000000000000000..bf4d3c687be74adda57b4096cf05e279b9bf72ec
--- /dev/null
+++ b/expression.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93e9d69eb46e866ed5cbb569ed2bdb3813254720fb0cb745d5b56181faf9aec5
+size 1456
diff --git a/models/.DS_Store b/models/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..087053b36a558287fc3b017d1f8393a4e004363e
Binary files /dev/null and b/models/.DS_Store differ
diff --git a/models/cmp/.DS_Store b/models/cmp/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..ff76e5324d7be1a0c28204838fcff4fc1fb09c7a
Binary files /dev/null and b/models/cmp/.DS_Store differ
diff --git a/models/cmp/experiments/.DS_Store b/models/cmp/experiments/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..9359ef6cd967551d5476ebf1ce12f2e1a233d898
Binary files /dev/null and b/models/cmp/experiments/.DS_Store differ
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2944c4056bc0683c9a94bc20017c1056352356e1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml
@@ -0,0 +1,59 @@
+model:
+ arch: CMP
+ total_iter: 140000
+ lr_steps: [80000, 120000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: alexnet_fcn_32x
+ sparse_encoder: shallownet32x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 12
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [384, 384]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.000025
+ nms_ks: 81
+ max_num_guide: 150
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ - data/youtube9000/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..06bd63a2c51db22a687f347635759f3a41ea30b2
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..644276733e346ef31fa9d3aaa4110b0b360cff3f
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f2b03a431e84f04599c76865ec14cd499ff3063
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1a9f27ef9e639802ecf29247297ff7eb022d1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/config.yaml b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9296127836bea77efc5d6d28ccf363c6e8adbf91
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/config.yaml
@@ -0,0 +1,58 @@
+model:
+ arch: CMP
+ total_iter: 70000
+ lr_steps: [40000, 60000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: alexnet_fcn_32x
+ sparse_encoder: shallownet32x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 12
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [384, 384]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.00015625
+ nms_ks: 41
+ max_num_guide: 150
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..06bd63a2c51db22a687f347635759f3a41ea30b2
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..644276733e346ef31fa9d3aaa4110b0b360cff3f
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f2b03a431e84f04599c76865ec14cd499ff3063
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1a9f27ef9e639802ecf29247297ff7eb022d1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/config.yaml b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6e8751ff794627d37449771734bb2fe1521f527a
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/config.yaml
@@ -0,0 +1,58 @@
+model:
+ arch: CMP
+ total_iter: 140000
+ lr_steps: [80000, 120000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: alexnet_fcn_32x
+ sparse_encoder: shallownet32x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 12
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [384, 384]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.00015625
+ nms_ks: 41
+ max_num_guide: 150
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..17cec90cef2555c6b7dd5acfe3b938c9be451346
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..94b7cccac61566afb3eef7924a6d8b56027b2d13
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..330d14c459f8549ea81f956c3497e13ddf68aed0
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..140bfae1f0543e2b186f06f3dfc7a934c0aeccf1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/config.yaml b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5dd44cf7642837242711326eb413b950c384dd26
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/config.yaml
@@ -0,0 +1,61 @@
+model:
+ arch: CMP
+ total_iter: 70000
+ lr_steps: [40000, 60000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: resnet50
+ sparse_encoder: shallownet8x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 10
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [320, 320]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.00015625
+ nms_ks: 15
+ max_num_guide: -1
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ - data/youtube9000/lists/train.txt
+ - data/VIP/lists/train.txt
+ - data/MPII/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..06bd63a2c51db22a687f347635759f3a41ea30b2
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..644276733e346ef31fa9d3aaa4110b0b360cff3f
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f2b03a431e84f04599c76865ec14cd499ff3063
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1a9f27ef9e639802ecf29247297ff7eb022d1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/config.yaml b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1a453c27947f570320609a61fde9c862819842bc
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/config.yaml
@@ -0,0 +1,58 @@
+model:
+ arch: CMP
+ total_iter: 42000
+ lr_steps: [24000, 36000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: resnet50
+ sparse_encoder: shallownet8x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 16
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 333
+ crop_size: [256, 256]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.00005632
+ nms_ks: 49
+ max_num_guide: -1
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..06bd63a2c51db22a687f347635759f3a41ea30b2
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..644276733e346ef31fa9d3aaa4110b0b360cff3f
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f2b03a431e84f04599c76865ec14cd499ff3063
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1a9f27ef9e639802ecf29247297ff7eb022d1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/config.yaml b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..47ba5c8c0d6f63247b7fcf6ac18554f4cddb0eac
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/config.yaml
@@ -0,0 +1,58 @@
+model:
+ arch: CMP
+ total_iter: 42000
+ lr_steps: [24000, 36000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: resnet50
+ sparse_encoder: shallownet8x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 10
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [320, 320]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.00003629
+ nms_ks: 67
+ max_num_guide: -1
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..06bd63a2c51db22a687f347635759f3a41ea30b2
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..644276733e346ef31fa9d3aaa4110b0b360cff3f
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f2b03a431e84f04599c76865ec14cd499ff3063
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1a9f27ef9e639802ecf29247297ff7eb022d1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar
new file mode 100644
index 0000000000000000000000000000000000000000..a15fde53bc352803ac906bb48f7ec6f08f55f817
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd3a385e227c29f89b5c7c6f4c89d356f6022fa7fcfc71ab1bd40e9833048dd6
+size 228465722
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fc56f53ce2088872c5f6987a0f1a44dabaf76f9d
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml
@@ -0,0 +1,59 @@
+model:
+ arch: CMP
+ total_iter: 42000
+ lr_steps: [24000, 36000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: resnet50
+ sparse_encoder: shallownet8x
+ flow_decoder: MotionDecoderSkipLayer
+ skip_layer: True
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 8
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [384, 384]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 5.74e-5
+ nms_ks: 41
+ max_num_guide: -1
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/VIP/lists/train.txt
+ - data/MPII/lists/train.txt
+ val_source:
+ - data/VIP/lists/randval.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 5000
+ save_freq: 5000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: True
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..17cec90cef2555c6b7dd5acfe3b938c9be451346
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume_slurm.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..94b7cccac61566afb3eef7924a6d8b56027b2d13
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..330d14c459f8549ea81f956c3497e13ddf68aed0
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train_slurm.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..140bfae1f0543e2b186f06f3dfc7a934c0aeccf1
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate_slurm.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..aef377a6e02a61de710eb8a72769ede93ce897e7
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition \
+ -n8 --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/losses.py b/models/cmp/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..b562ff841da3e8508ddf5e1264de382fb510376d
--- /dev/null
+++ b/models/cmp/losses.py
@@ -0,0 +1,536 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import random
+import math
+
+def MultiChannelSoftBinaryCrossEntropy(input, target, reduction='mean'):
+ '''
+ input: N x 38 x H x W --> 19N x 2 x H x W
+ target: N x 19 x H x W --> 19N x 1 x H x W
+ '''
+ input = input.view(-1, 2, input.size(2), input.size(3))
+ target = target.view(-1, 1, input.size(2), input.size(3))
+
+ logsoftmax = nn.LogSoftmax(dim=1)
+ if reduction == 'mean':
+ return torch.mean(torch.sum(-target * logsoftmax(input), dim=1))
+ else:
+ return torch.sum(torch.sum(-target * logsoftmax(input), dim=1))
+
+class EdgeAwareLoss():
+ def __init__(self, nc=2, loss_type="L1", reduction='mean'):
+ assert loss_type in ['L1', 'BCE'], "Undefined loss type: {}".format(loss_type)
+ self.nc = nc
+ self.loss_type = loss_type
+ self.kernelx = Variable(torch.Tensor([[1,0,-1],[2,0,-2],[1,0,-1]]).cuda())
+ self.kernelx = self.kernelx.repeat(nc,1,1,1)
+ self.kernely = Variable(torch.Tensor([[1,2,1],[0,0,0],[-1,-2,-1]]).cuda())
+ self.kernely = self.kernely.repeat(nc,1,1,1)
+ self.bias = Variable(torch.zeros(nc).cuda())
+ self.reduction = reduction
+ if loss_type == 'L1':
+ self.loss = nn.SmoothL1Loss(reduction=reduction)
+ elif loss_type == 'BCE':
+ self.loss = self.bce2d
+
+ def bce2d(self, input, target):
+ assert not target.requires_grad
+ beta = 1 - torch.mean(target)
+ weights = 1 - beta + (2 * beta - 1) * target
+ loss = nn.functional.binary_cross_entropy(input, target, weights, reduction=self.reduction)
+ return loss
+
+ def get_edge(self, var):
+ assert var.size(1) == self.nc, \
+ "input size at dim 1 should be consistent with nc, {} vs {}".format(var.size(1), self.nc)
+ outputx = nn.functional.conv2d(var, self.kernelx, bias=self.bias, padding=1, groups=self.nc)
+ outputy = nn.functional.conv2d(var, self.kernely, bias=self.bias, padding=1, groups=self.nc)
+ eps=1e-05
+ return torch.sqrt(outputx.pow(2) + outputy.pow(2) + eps).mean(dim=1, keepdim=True)
+
+ def __call__(self, input, target):
+ size = target.shape[2:4]
+ input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True)
+ target_edge = self.get_edge(target)
+ if self.loss_type == 'L1':
+ return self.loss(self.get_edge(input), target_edge)
+ elif self.loss_type == 'BCE':
+ raise NotImplemented
+ #target_edge = torch.sign(target_edge - 0.1)
+ #pred = self.get_edge(nn.functional.sigmoid(input))
+ #return self.loss(pred, target_edge)
+
+def KLD(mean, logvar):
+ return -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
+
+class DiscreteLoss(nn.Module):
+ def __init__(self, nbins, fmax):
+ super().__init__()
+ self.loss = nn.CrossEntropyLoss()
+ assert nbins % 2 == 1, "nbins should be odd"
+ self.nbins = nbins
+ self.fmax = fmax
+ self.step = 2 * fmax / float(nbins)
+
+ def tobin(self, target):
+ target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)
+ quantized_target = torch.floor((target + self.fmax) / self.step)
+ return quantized_target.type(torch.cuda.LongTensor)
+
+ def __call__(self, input, target):
+ size = target.shape[2:4]
+ if input.shape[2] != size[0] or input.shape[3] != size[1]:
+ input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True)
+ target = self.tobin(target)
+ assert input.size(1) == self.nbins * 2
+ # print(target.shape)
+ # print(input.shape)
+ # print(torch.max(target))
+ target[target>=99]=98 # odd bugs of the training loss. We have [0 ~ 99] in GT flow, but nbins = 99
+ return self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.loss(input[:,self.nbins:,...], target[:,1,...])
+
+class MultiDiscreteLoss():
+ def __init__(self, nbins=19, fmax=47.5, reduction='mean', xy_weight=(1., 1.), quantize_strategy='linear'):
+ self.loss = nn.CrossEntropyLoss(reduction=reduction)
+ assert nbins % 2 == 1, "nbins should be odd"
+ self.nbins = nbins
+ self.fmax = fmax
+ self.step = 2 * fmax / float(nbins)
+ self.x_weight, self.y_weight = xy_weight
+ self.quantize_strategy = quantize_strategy
+
+ def tobin(self, target):
+ target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)
+ if self.quantize_strategy == "linear":
+ quantized_target = torch.floor((target + self.fmax) / self.step)
+ elif self.quantize_strategy == "quadratic":
+ ind = target.data > 0
+ quantized_target = target.clone()
+ quantized_target[ind] = torch.floor(self.nbins * torch.sqrt(target[ind] / (4 * self.fmax)) + self.nbins / 2.)
+ quantized_target[~ind] = torch.floor(-self.nbins * torch.sqrt(-target[~ind] / (4 * self.fmax)) + self.nbins / 2.)
+ return quantized_target.type(torch.cuda.LongTensor)
+
+ def __call__(self, input, target):
+ size = target.shape[2:4]
+ target = self.tobin(target)
+ if isinstance(input, list):
+ input = [nn.functional.interpolate(ip, size=size, mode="bilinear", align_corners=True) for ip in input]
+ return sum([self.x_weight * self.loss(input[k][:,:self.nbins,...], target[:,0,...]) + self.y_weight * self.loss(input[k][:,self.nbins:,...], target[:,1,...]) for k in range(len(input))]) / float(len(input))
+ else:
+ input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True)
+ return self.x_weight * self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.y_weight * self.loss(input[:,self.nbins:,...], target[:,1,...])
+
+class MultiL1Loss():
+ def __init__(self, reduction='mean'):
+ self.loss = nn.SmoothL1Loss(reduction=reduction)
+
+ def __call__(self, input, target):
+ size = target.shape[2:4]
+ if isinstance(input, list):
+ input = [nn.functional.interpolate(ip, size=size, mode="bilinear", align_corners=True) for ip in input]
+ return sum([self.loss(input[k], target) for k in range(len(input))]) / float(len(input))
+ else:
+ input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True)
+ return self.loss(input, target)
+
+class MultiMSELoss():
+ def __init__(self):
+ self.loss = nn.MSELoss()
+
+ def __call__(self, predicts, targets):
+ loss = 0
+ for predict, target in zip(predicts, targets):
+ loss += self.loss(predict, target)
+ return loss
+
+class JointDiscreteLoss():
+ def __init__(self, nbins=19, fmax=47.5, reduction='mean', quantize_strategy='linear'):
+ self.loss = nn.CrossEntropyLoss(reduction=reduction)
+ assert nbins % 2 == 1, "nbins should be odd"
+ self.nbins = nbins
+ self.fmax = fmax
+ self.step = 2 * fmax / float(nbins)
+ self.quantize_strategy = quantize_strategy
+
+ def tobin(self, target):
+ target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)
+ if self.quantize_strategy == "linear":
+ quantized_target = torch.floor((target + self.fmax) / self.step)
+ elif self.quantize_strategy == "quadratic":
+ ind = target.data > 0
+ quantized_target = target.clone()
+ quantized_target[ind] = torch.floor(self.nbins * torch.sqrt(target[ind] / (4 * self.fmax)) + self.nbins / 2.)
+ quantized_target[~ind] = torch.floor(-self.nbins * torch.sqrt(-target[~ind] / (4 * self.fmax)) + self.nbins / 2.)
+ else:
+ raise Exception("No such quantize strategy: {}".format(self.quantize_strategy))
+ joint_target = quantized_target[:,0,:,:] * self.nbins + quantized_target[:,1,:,:]
+ return joint_target.type(torch.cuda.LongTensor)
+
+ def __call__(self, input, target):
+ target = self.tobin(target)
+ assert input.size(1) == self.nbins ** 2
+ return self.loss(input, target)
+
+class PolarDiscreteLoss():
+ def __init__(self, abins=30, rbins=20, fmax=50., reduction='mean', ar_weight=(1., 1.), quantize_strategy='linear'):
+ self.loss = nn.CrossEntropyLoss(reduction=reduction)
+ self.fmax = fmax
+ self.rbins = rbins
+ self.abins = abins
+ self.a_weight, self.r_weight = ar_weight
+ self.quantize_strategy = quantize_strategy
+
+ def tobin(self, target):
+ indxneg = target.data[:,0,:,:] < 0
+ eps = torch.zeros(target.data[:,0,:,:].size()).cuda()
+ epsind = target.data[:,0,:,:] == 0
+ eps[epsind] += 1e-5
+ angle = torch.atan(target.data[:,1,:,:] / (target.data[:,0,:,:] + eps))
+ angle[indxneg] += np.pi
+ angle += np.pi / 2 # 0 to 2pi
+ angle = torch.clamp(angle, 0, 2 * np.pi - 1e-3)
+ radius = torch.sqrt(target.data[:,0,:,:] ** 2 + target.data[:,1,:,:] ** 2)
+ radius = torch.clamp(radius, 0, self.fmax - 1e-3)
+ quantized_angle = torch.floor(self.abins * angle / (2 * np.pi))
+ if self.quantize_strategy == 'linear':
+ quantized_radius = torch.floor(self.rbins * radius / self.fmax)
+ elif self.quantize_strategy == 'quadratic':
+ quantized_radius = torch.floor(self.rbins * torch.sqrt(radius / self.fmax))
+ else:
+ raise Exception("No such quantize strategy: {}".format(self.quantize_strategy))
+ quantized_target = torch.autograd.Variable(torch.cat([torch.unsqueeze(quantized_angle, 1), torch.unsqueeze(quantized_radius, 1)], dim=1))
+ return quantized_target.type(torch.cuda.LongTensor)
+
+ def __call__(self, input, target):
+ target = self.tobin(target)
+ assert (target >= 0).all() and (target[:,0,:,:] < self.abins).all() and (target[:,1,:,:] < self.rbins).all()
+ return self.a_weight * self.loss(input[:,:self.abins,...], target[:,0,...]) + self.r_weight * self.loss(input[:,self.abins:,...], target[:,1,...])
+
+class WeightedDiscreteLoss():
+ def __init__(self, nbins=19, fmax=47.5, reduction='mean'):
+ self.loss = CrossEntropy2d(reduction=reduction)
+ assert nbins % 2 == 1, "nbins should be odd"
+ self.nbins = nbins
+ self.fmax = fmax
+ self.step = 2 * fmax / float(nbins)
+ self.weight = np.ones((nbins), dtype=np.float32)
+ self.weight[int(self.fmax / self.step)] = 0.01
+ self.weight = torch.from_numpy(self.weight).cuda()
+
+ def tobin(self, target):
+ target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)
+ return torch.floor((target + self.fmax) / self.step).type(torch.cuda.LongTensor)
+
+ def __call__(self, input, target):
+ target = self.tobin(target)
+ assert (target >= 0).all() and (target < self.nbins).all()
+ return self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.loss(input[:,self.nbins:,...], target[:,1,...], self.weight)
+
+
+class CrossEntropy2d(nn.Module):
+ def __init__(self, reduction='mean', ignore_label=-1):
+ super(CrossEntropy2d, self).__init__()
+ self.ignore_label = ignore_label
+ self.reduction = reduction
+
+ def forward(self, predict, target, weight=None):
+ """
+ Args:
+ predict:(n, c, h, w)
+ target:(n, h, w)
+ weight (Tensor, optional): a manual rescaling weight given to each class.
+ If given, has to be a Tensor of size "nclasses"
+ """
+ assert not target.requires_grad
+ assert predict.dim() == 4
+ assert target.dim() == 3
+ assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
+ assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
+ assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(3))
+ n, c, h, w = predict.size()
+ target_mask = (target >= 0) * (target != self.ignore_label)
+ target = target[target_mask]
+ predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
+ predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
+ loss = F.cross_entropy(predict, target, weight=weight, reduction=self.reduction)
+ return loss
+
+#class CrossPixelSimilarityLoss():
+# '''
+# Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py
+# '''
+# def __init__(self, sigma=0.0036, sampling_size=512):
+# self.sigma = sigma
+# self.sampling_size = sampling_size
+# self.epsilon = 1.0e-15
+# self.embed_norm = True # loss does not decrease no matter it is true or false.
+#
+# def __call__(self, embeddings, flows):
+# '''
+# embedding: Variable Nx256xHxW (not hyper-column)
+# flows: Variable Nx2xHxW
+# '''
+# assert flows.size(1) == 2
+#
+# # flow normalization
+# positive_mask = (flows > 0)
+# flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.)
+# flows[positive_mask] = -flows[positive_mask]
+#
+# # embedding normalization
+# if self.embed_norm:
+# embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True)
+#
+# # Spatially random sampling (512 samples)
+# flows_flatten = flows.view(flows.shape[0], 2, -1)
+# random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())
+# flows_sample = torch.index_select(flows_flatten, 2, random_locations)
+#
+# # K_f
+# k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) -
+# torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,
+# keepdim=False) ** 2
+# exp_k_f = torch.exp(-k_f / 2. / self.sigma)
+#
+#
+# # mask
+# eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())
+# mask = torch.ones_like(exp_k_f) - eye
+#
+# # S_f
+# masked_exp_k_f = torch.mul(mask, exp_k_f) + eye
+# s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)
+#
+# # K_theta
+# embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1)
+# embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations)
+# embeddings_sample_norm = torch.norm(embeddings_sample, p=2, dim=1, keepdim=True)
+# k_theta = 0.25 * (torch.matmul(embeddings_sample.permute(0, 2, 1), embeddings_sample)) / (self.epsilon + torch.matmul(embeddings_sample_norm.permute(0, 2, 1), embeddings_sample_norm))
+# exp_k_theta = torch.exp(k_theta)
+#
+# # S_theta
+# masked_exp_k_theta = torch.mul(mask, exp_k_theta) + math.exp(-0.75) * eye
+# s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)
+#
+# # loss
+# loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))
+#
+# return loss
+
+class CrossPixelSimilarityLoss():
+ '''
+ Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py
+ '''
+ def __init__(self, sigma=0.01, sampling_size=512):
+ self.sigma = sigma
+ self.sampling_size = sampling_size
+ self.epsilon = 1.0e-15
+ self.embed_norm = True # loss does not decrease no matter it is true or false.
+
+ def __call__(self, embeddings, flows):
+ '''
+ embedding: Variable Nx256xHxW (not hyper-column)
+ flows: Variable Nx2xHxW
+ '''
+ assert flows.size(1) == 2
+
+ # flow normalization
+ positive_mask = (flows > 0)
+ flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.)
+ flows[positive_mask] = -flows[positive_mask]
+
+ # embedding normalization
+ if self.embed_norm:
+ embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True)
+
+ # Spatially random sampling (512 samples)
+ flows_flatten = flows.view(flows.shape[0], 2, -1)
+ random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())
+ flows_sample = torch.index_select(flows_flatten, 2, random_locations)
+
+ # K_f
+ k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) -
+ torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,
+ keepdim=False) ** 2
+ exp_k_f = torch.exp(-k_f / 2. / self.sigma)
+
+
+ # mask
+ eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())
+ mask = torch.ones_like(exp_k_f) - eye
+
+ # S_f
+ masked_exp_k_f = torch.mul(mask, exp_k_f) + eye
+ s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)
+
+ # K_theta
+ embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1)
+ embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations)
+ embeddings_sample_norm = torch.norm(embeddings_sample, p=2, dim=1, keepdim=True)
+ k_theta = 0.25 * (torch.matmul(embeddings_sample.permute(0, 2, 1), embeddings_sample)) / (self.epsilon + torch.matmul(embeddings_sample_norm.permute(0, 2, 1), embeddings_sample_norm))
+ exp_k_theta = torch.exp(k_theta)
+
+ # S_theta
+ masked_exp_k_theta = torch.mul(mask, exp_k_theta) + eye
+ s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)
+
+ # loss
+ loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))
+
+ return loss
+
+
+class CrossPixelSimilarityFullLoss():
+ '''
+ Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py
+ '''
+ def __init__(self, sigma=0.01):
+ self.sigma = sigma
+ self.epsilon = 1.0e-15
+ self.embed_norm = True # loss does not decrease no matter it is true or false.
+
+ def __call__(self, embeddings, flows):
+ '''
+ embedding: Variable Nx256xHxW (not hyper-column)
+ flows: Variable Nx2xHxW
+ '''
+ assert flows.size(1) == 2
+
+ # downsample flow
+ factor = flows.shape[2] // embeddings.shape[2]
+ flows = nn.functional.avg_pool2d(flows, factor, factor)
+ assert flows.shape[2] == embeddings.shape[2]
+
+ # flow normalization
+ positive_mask = (flows > 0)
+ flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.)
+ flows[positive_mask] = -flows[positive_mask]
+
+ # embedding normalization
+ if self.embed_norm:
+ embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True)
+
+ # Spatially random sampling (512 samples)
+ flows_flatten = flows.view(flows.shape[0], 2, -1)
+ #random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())
+ #flows_sample = torch.index_select(flows_flatten, 2, random_locations)
+
+ # K_f
+ k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_flatten, dim=-1).permute(0, 3, 2, 1) -
+ torch.unsqueeze(flows_flatten, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,
+ keepdim=False) ** 2
+ exp_k_f = torch.exp(-k_f / 2. / self.sigma)
+
+
+ # mask
+ eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())
+ mask = torch.ones_like(exp_k_f) - eye
+
+ # S_f
+ masked_exp_k_f = torch.mul(mask, exp_k_f) + eye
+ s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)
+
+ # K_theta
+ embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1)
+ #embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations)
+ embeddings_flatten_norm = torch.norm(embeddings_flatten, p=2, dim=1, keepdim=True)
+ k_theta = 0.25 * (torch.matmul(embeddings_flatten.permute(0, 2, 1), embeddings_flatten)) / (self.epsilon + torch.matmul(embeddings_flatten_norm.permute(0, 2, 1), embeddings_flatten_norm))
+ exp_k_theta = torch.exp(k_theta)
+
+ # S_theta
+ masked_exp_k_theta = torch.mul(mask, exp_k_theta) + eye
+ s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)
+
+ # loss
+ loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))
+
+ return loss
+
+
+def get_column(embeddings, index, full_size):
+ col = []
+ for embd in embeddings:
+ ind = (index.float() / full_size * embd.size(2)).long()
+ col.append(torch.index_select(embd.view(embd.shape[0], embd.shape[1], -1), 2, ind))
+ return torch.cat(col, dim=1) # N x coldim x sparsenum
+
+class CrossPixelSimilarityColumnLoss(nn.Module):
+ '''
+ Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py
+ '''
+ def __init__(self, sigma=0.0036, sampling_size=512):
+ super(CrossPixelSimilarityColumnLoss, self).__init__()
+ self.sigma = sigma
+ self.sampling_size = sampling_size
+ self.epsilon = 1.0e-15
+ self.embed_norm = True # loss does not decrease no matter it is true or false.
+ self.mlp = nn.Sequential(
+ nn.Linear(96 + 96 + 384 + 256 + 4096, 256),
+ nn.ReLU(inplace=True),
+ nn.Linear(256, 16))
+
+ def forward(self, feats, flows):
+ '''
+ embedding: Variable Nx256xHxW (not hyper-column)
+ flows: Variable Nx2xHxW
+ '''
+ assert flows.size(1) == 2
+
+ # flow normalization
+ positive_mask = (flows > 0)
+ flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.)
+ flows[positive_mask] = -flows[positive_mask]
+
+ # Spatially random sampling (512 samples)
+ flows_flatten = flows.view(flows.shape[0], 2, -1)
+ random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())
+ flows_sample = torch.index_select(flows_flatten, 2, random_locations)
+
+ # K_f
+ k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) -
+ torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,
+ keepdim=False) ** 2
+ exp_k_f = torch.exp(-k_f / 2. / self.sigma)
+
+
+ # mask
+ eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())
+ mask = torch.ones_like(exp_k_f) - eye
+
+ # S_f
+ masked_exp_k_f = torch.mul(mask, exp_k_f) + eye
+ s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)
+
+
+ # column
+ column = get_column(feats, random_locations, flows.shape[2])
+ embedding = self.mlp(column)
+ # K_theta
+ embedding_norm = torch.norm(embedding, p=2, dim=1, keepdim=True)
+ k_theta = 0.25 * (torch.matmul(embedding.permute(0, 2, 1), embedding)) / (self.epsilon + torch.matmul(embedding_norm.permute(0, 2, 1), embedding_norm))
+ exp_k_theta = torch.exp(k_theta)
+
+ # S_theta
+ masked_exp_k_theta = torch.mul(mask, exp_k_theta) + math.exp(-0.75) * eye
+ s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)
+
+ # loss
+ loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))
+
+ return loss
+
+
+def print_info(name, var):
+ print(name, var.size(), torch.max(var).data.cpu()[0], torch.min(var).data.cpu()[0], torch.mean(var).data.cpu()[0])
+
+
+def MaskL1Loss(input, target, mask):
+ input_size = input.size()
+ res = torch.sum(torch.abs(input * mask - target * mask))
+ total = torch.sum(mask).item()
+ if total > 0:
+ res = res / (total * input_size[1])
+ return res
diff --git a/models/cmp/models/.DS_Store b/models/cmp/models/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..778b39036de751baf323736bc69406f0202af4ea
Binary files /dev/null and b/models/cmp/models/.DS_Store differ
diff --git a/models/cmp/models/__init__.py b/models/cmp/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..168e1aae54937d9c23f4f40eae871b7fd73dc5c8
--- /dev/null
+++ b/models/cmp/models/__init__.py
@@ -0,0 +1,4 @@
+from .single_stage_model import *
+from .cmp import *
+from . import modules
+from . import backbone
diff --git a/models/cmp/models/backbone/__init__.py b/models/cmp/models/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eea305c40902faaf491f9ed7ca70a56c0b9ae7fb
--- /dev/null
+++ b/models/cmp/models/backbone/__init__.py
@@ -0,0 +1,2 @@
+from .resnet import *
+from .alexnet import *
diff --git a/models/cmp/models/backbone/alexnet.py b/models/cmp/models/backbone/alexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ac39a8d8af096e9363854a2e7d720623ecd73e
--- /dev/null
+++ b/models/cmp/models/backbone/alexnet.py
@@ -0,0 +1,83 @@
+import torch.nn as nn
+import math
+
+class AlexNetBN_FCN(nn.Module):
+
+ def __init__(self, output_dim=256, stride=[4, 2, 2, 2], dilation=[1, 1], padding=[1, 1]):
+ super(AlexNetBN_FCN, self).__init__()
+ BN = nn.BatchNorm2d
+
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, 96, kernel_size=11, stride=stride[0], padding=5),
+ BN(96),
+ nn.ReLU(inplace=True))
+ self.pool1 = nn.MaxPool2d(kernel_size=3, stride=stride[1], padding=1)
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(96, 256, kernel_size=5, padding=2),
+ BN(256),
+ nn.ReLU(inplace=True))
+ self.pool2 = nn.MaxPool2d(kernel_size=3, stride=stride[2], padding=1)
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(256, 384, kernel_size=3, padding=1),
+ BN(384),
+ nn.ReLU(inplace=True))
+ self.conv4 = nn.Sequential(
+ nn.Conv2d(384, 384, kernel_size=3, padding=padding[0], dilation=dilation[0]),
+ BN(384),
+ nn.ReLU(inplace=True))
+ self.conv5 = nn.Sequential(
+ nn.Conv2d(384, 256, kernel_size=3, padding=padding[1], dilation=dilation[1]),
+ BN(256),
+ nn.ReLU(inplace=True))
+ self.pool5 = nn.MaxPool2d(kernel_size=3, stride=stride[3], padding=1)
+
+ self.fc6 = nn.Sequential(
+ nn.Conv2d(256, 4096, kernel_size=3, stride=1, padding=1),
+ BN(4096),
+ nn.ReLU(inplace=True))
+ self.drop6 = nn.Dropout(0.5)
+ self.fc7 = nn.Sequential(
+ nn.Conv2d(4096, 4096, kernel_size=1, stride=1, padding=0),
+ BN(4096),
+ nn.ReLU(inplace=True))
+ self.drop7 = nn.Dropout(0.5)
+ self.conv8 = nn.Conv2d(4096, output_dim, kernel_size=1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
+ scale = math.sqrt(2. / fan_in)
+ m.weight.data.uniform_(-scale, scale)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def forward(self, x, ret_feat=False):
+ if ret_feat:
+ raise NotImplemented
+ x = self.conv1(x)
+ x = self.pool1(x)
+ x = self.conv2(x)
+ x = self.pool2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+ x = self.conv5(x)
+ x = self.pool5(x)
+ x = self.fc6(x)
+ x = self.drop6(x)
+ x = self.fc7(x)
+ x = self.drop7(x)
+ x = self.conv8(x)
+ return x
+
+def alexnet_fcn_32x(output_dim, pretrained=False, **kwargs):
+ assert pretrained == False
+ model = AlexNetBN_FCN(output_dim=output_dim, **kwargs)
+ return model
+
+def alexnet_fcn_8x(output_dim, use_ppm=False, pretrained=False, **kwargs):
+ assert pretrained == False
+ model = AlexNetBN_FCN(output_dim=output_dim, stride=[2, 2, 2, 1], **kwargs)
+ return model
diff --git a/models/cmp/models/backbone/resnet.py b/models/cmp/models/backbone/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..126ef386b7abba0ff9d09b2f051494ada0cfab30
--- /dev/null
+++ b/models/cmp/models/backbone/resnet.py
@@ -0,0 +1,201 @@
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+
+BN = None
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = BN(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = BN(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = BN(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = BN(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = BN(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, output_dim, block, layers):
+
+ global BN
+
+ BN = nn.BatchNorm2d
+
+ self.inplanes = 64
+ super(ResNet, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = BN(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+ self.conv5 = nn.Conv2d(2048, output_dim, kernel_size=1)
+
+ ## dilation
+ for n, m in self.layer3.named_modules():
+ if 'conv2' in n:
+ m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
+ elif 'downsample.0' in n:
+ m.stride = (1, 1)
+ for n, m in self.layer4.named_modules():
+ if 'conv2' in n:
+ m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
+ elif 'downsample.0' in n:
+ m.stride = (1, 1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ BN(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, img, ret_feat=False):
+ x = self.conv1(img) # 1/2
+ x = self.bn1(x)
+ conv1 = self.relu(x) # 1/2
+ pool1 = self.maxpool(conv1) # 1/4
+
+ layer1 = self.layer1(pool1) # 1/4
+ layer2 = self.layer2(layer1) # 1/8
+ layer3 = self.layer3(layer2) # 1/8
+ layer4 = self.layer4(layer3) # 1/8
+ out = self.conv5(layer4)
+
+ if ret_feat:
+ return out, [img, conv1, layer1] # 3, 64, 256
+ else:
+ return out
+
+def resnet18(output_dim, pretrained=False):
+ model = ResNet(output_dim, BasicBlock, [2, 2, 2, 2])
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
+ return model
+
+
+def resnet34(output_dim, pretrained=False):
+ model = ResNet(output_dim, BasicBlock, [3, 4, 6, 3])
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
+ return model
+
+
+def resnet50(output_dim, pretrained=False):
+ model = ResNet(output_dim, Bottleneck, [3, 4, 6, 3])
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
+ return model
+
+def resnet101(output_dim, pretrained=False):
+ model = ResNet(output_dim, Bottleneck, [3, 4, 23, 3])
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
+ return model
+
+
+def resnet152(output_dim, pretrained=False):
+ model = ResNet(output_dim, Bottleneck, [3, 8, 36, 3])
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
+ return model
diff --git a/models/cmp/models/cmp.py b/models/cmp/models/cmp.py
new file mode 100644
index 0000000000000000000000000000000000000000..11987b4b7c14a2a2e7a2ad01b34e84f7f77bc03f
--- /dev/null
+++ b/models/cmp/models/cmp.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+
+import models.cmp.losses as losses
+import models.cmp.utils as utils
+
+from . import SingleStageModel
+
+class CMP(SingleStageModel):
+
+ def __init__(self, params, dist_model=False):
+ super(CMP, self).__init__(params, dist_model)
+ model_params = params['module']
+
+ # define loss
+ if model_params['flow_criterion'] == 'L1':
+ self.flow_criterion = nn.SmoothL1Loss()
+ elif model_params['flow_criterion'] == 'L2':
+ self.flow_criterion = nn.MSELoss()
+ elif model_params['flow_criterion'] == 'DiscreteLoss':
+ self.flow_criterion = losses.DiscreteLoss(
+ nbins=model_params['nbins'], fmax=model_params['fmax'])
+ else:
+ raise Exception("No such flow loss: {}".format(model_params['flow_criterion']))
+
+ self.fuser = utils.Fuser(nbins=model_params['nbins'],
+ fmax=model_params['fmax'])
+ self.model_params = model_params
+
+ def eval(self, ret_loss=True):
+ with torch.no_grad():
+ cmp_output = self.model(self.image_input, self.sparse_input)
+ if self.model_params['flow_criterion'] == "DiscreteLoss":
+ self.flow = self.fuser.convert_flow(cmp_output)
+ else:
+ self.flow = cmp_output
+ if self.flow.shape[2] != self.image_input.shape[2]:
+ self.flow = nn.functional.interpolate(
+ self.flow, size=self.image_input.shape[2:4],
+ mode="bilinear", align_corners=True)
+
+ ret_tensors = {
+ 'flow_tensors': [self.flow, self.flow_target],
+ 'common_tensors': [],
+ 'rgb_tensors': []} # except for image_input
+
+ if ret_loss:
+ if cmp_output.shape[2] != self.flow_target.shape[2]:
+ cmp_output = nn.functional.interpolate(
+ cmp_output, size=self.flow_target.shape[2:4],
+ mode="bilinear", align_corners=True)
+ loss_flow = self.flow_criterion(cmp_output, self.flow_target) / self.world_size
+ return ret_tensors, {'loss_flow': loss_flow}
+ else:
+ return ret_tensors
+
+ def step(self):
+ cmp_output = self.model(self.image_input, self.sparse_input)
+ loss_flow = self.flow_criterion(cmp_output, self.flow_target) / self.world_size
+ self.optim.zero_grad()
+ loss_flow.backward()
+ utils.average_gradients(self.model)
+ self.optim.step()
+ return {'loss_flow': loss_flow}
diff --git a/models/cmp/models/modules/__init__.py b/models/cmp/models/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eff11cb76475f299be4ce9641182686866e00f99
--- /dev/null
+++ b/models/cmp/models/modules/__init__.py
@@ -0,0 +1,6 @@
+from .warp import *
+from .others import *
+from .shallownet import *
+from .decoder import *
+from .cmp import *
+
diff --git a/models/cmp/models/modules/cmp.py b/models/cmp/models/modules/cmp.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c5130353000c6f971425a37c2588e45d8710664
--- /dev/null
+++ b/models/cmp/models/modules/cmp.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+import models.cmp.models as models
+
+
+class CMP(nn.Module):
+
+ def __init__(self, params):
+ super(CMP, self).__init__()
+ img_enc_dim = params['img_enc_dim']
+ sparse_enc_dim = params['sparse_enc_dim']
+ output_dim = params['output_dim']
+ pretrained = params['pretrained_image_encoder']
+ decoder_combo = params['decoder_combo']
+ self.skip_layer = params['skip_layer']
+ if self.skip_layer:
+ assert params['flow_decoder'] == "MotionDecoderSkipLayer"
+
+ self.image_encoder = models.backbone.__dict__[params['image_encoder']](
+ img_enc_dim, pretrained)
+ self.flow_encoder = models.modules.__dict__[params['sparse_encoder']](
+ sparse_enc_dim)
+ self.flow_decoder = models.modules.__dict__[params['flow_decoder']](
+ input_dim=img_enc_dim+sparse_enc_dim,
+ output_dim=output_dim, combo=decoder_combo)
+
+ def forward(self, image, sparse):
+ sparse_enc = self.flow_encoder(sparse)
+ if self.skip_layer:
+ img_enc, skip_feat = self.image_encoder(image, ret_feat=True)
+ flow_dec = self.flow_decoder(torch.cat((img_enc, sparse_enc), dim=1), skip_feat)
+ else:
+ img_enc = self.image_encoder(image)
+ flow_dec = self.flow_decoder(torch.cat((img_enc, sparse_enc), dim=1))
+ return flow_dec
+
+
diff --git a/models/cmp/models/modules/decoder.py b/models/cmp/models/modules/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f1c0e395f55f4e348410c98d2d37a13441d7139
--- /dev/null
+++ b/models/cmp/models/modules/decoder.py
@@ -0,0 +1,358 @@
+import torch
+import torch.nn as nn
+import math
+
+class MotionDecoderPlain(nn.Module):
+
+ def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4]):
+ super(MotionDecoderPlain, self).__init__()
+ BN = nn.BatchNorm2d
+
+ self.combo = combo
+ for c in combo:
+ assert c in [1,2,4,8], "invalid combo: {}".format(combo)
+
+ if 1 in combo:
+ self.decoder1 = nn.Sequential(
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ if 2 in combo:
+ self.decoder2 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ if 4 in combo:
+ self.decoder4 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=4, stride=4),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ if 8 in combo:
+ self.decoder8 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=8, stride=8),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.head = nn.Conv2d(128 * len(self.combo), output_dim, kernel_size=1, padding=0)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
+ scale = math.sqrt(2. / fan_in)
+ m.weight.data.uniform_(-scale, scale)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ if not m.weight is None:
+ m.weight.data.fill_(1)
+ if not m.bias is None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+
+ cat_list = []
+ if 1 in self.combo:
+ x1 = self.decoder1(x)
+ cat_list.append(x1)
+ if 2 in self.combo:
+ x2 = nn.functional.interpolate(
+ self.decoder2(x), size=(x.size(2), x.size(3)),
+ mode="bilinear", align_corners=True)
+ cat_list.append(x2)
+ if 4 in self.combo:
+ x4 = nn.functional.interpolate(
+ self.decoder4(x), size=(x.size(2), x.size(3)),
+ mode="bilinear", align_corners=True)
+ cat_list.append(x4)
+ if 8 in self.combo:
+ x8 = nn.functional.interpolate(
+ self.decoder8(x), size=(x.size(2), x.size(3)),
+ mode="bilinear", align_corners=True)
+ cat_list.append(x8)
+
+ cat = torch.cat(cat_list, dim=1)
+ flow = self.head(cat)
+ return flow
+
+
+class MotionDecoderSkipLayer(nn.Module):
+
+ def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]):
+ super(MotionDecoderSkipLayer, self).__init__()
+
+ BN = nn.BatchNorm2d
+
+ self.decoder1 = nn.Sequential(
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder2 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder4 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=4, stride=4),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder8 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=8, stride=8),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.fusion8 = nn.Sequential(
+ nn.Conv2d(512, 256, kernel_size=3, padding=1),
+ BN(256),
+ nn.ReLU(inplace=True))
+
+ self.skipconv4 = nn.Sequential(
+ nn.Conv2d(256, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+ self.fusion4 = nn.Sequential(
+ nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.skipconv2 = nn.Sequential(
+ nn.Conv2d(64, 32, kernel_size=3, padding=1),
+ BN(32),
+ nn.ReLU(inplace=True))
+ self.fusion2 = nn.Sequential(
+ nn.Conv2d(128 + 32, 64, kernel_size=3, padding=1),
+ BN(64),
+ nn.ReLU(inplace=True))
+
+ self.head = nn.Conv2d(64, output_dim, kernel_size=1, padding=0)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
+ scale = math.sqrt(2. / fan_in)
+ m.weight.data.uniform_(-scale, scale)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ if not m.weight is None:
+ m.weight.data.fill_(1)
+ if not m.bias is None:
+ m.bias.data.zero_()
+
+ def forward(self, x, skip_feat):
+ layer1, layer2, layer4 = skip_feat
+
+ x1 = self.decoder1(x)
+ x2 = nn.functional.interpolate(
+ self.decoder2(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ x4 = nn.functional.interpolate(
+ self.decoder4(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ x8 = nn.functional.interpolate(
+ self.decoder8(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ cat = torch.cat([x1, x2, x4, x8], dim=1)
+ f8 = self.fusion8(cat)
+
+ f8_up = nn.functional.interpolate(
+ f8, size=(layer4.size(2), layer4.size(3)),
+ mode="bilinear", align_corners=True)
+ f4 = self.fusion4(torch.cat([f8_up, self.skipconv4(layer4)], dim=1))
+
+ f4_up = nn.functional.interpolate(
+ f4, size=(layer2.size(2), layer2.size(3)),
+ mode="bilinear", align_corners=True)
+ f2 = self.fusion2(torch.cat([f4_up, self.skipconv2(layer2)], dim=1))
+
+ flow = self.head(f2)
+ return flow
+
+
+class MotionDecoderFlowNet(nn.Module):
+
+ def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]):
+ super(MotionDecoderFlowNet, self).__init__()
+ global BN
+
+ BN = nn.BatchNorm2d
+
+ self.decoder1 = nn.Sequential(
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder2 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder4 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=4, stride=4),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder8 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=8, stride=8),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.fusion8 = nn.Sequential(
+ nn.Conv2d(512, 256, kernel_size=3, padding=1),
+ BN(256),
+ nn.ReLU(inplace=True))
+
+ # flownet head
+ self.predict_flow8 = predict_flow(256, output_dim)
+ self.predict_flow4 = predict_flow(384 + output_dim, output_dim)
+ self.predict_flow2 = predict_flow(192 + output_dim, output_dim)
+ self.predict_flow1 = predict_flow(67 + output_dim, output_dim)
+
+ self.upsampled_flow8_to_4 = nn.ConvTranspose2d(
+ output_dim, output_dim, 4, 2, 1, bias=False)
+ self.upsampled_flow4_to_2 = nn.ConvTranspose2d(
+ output_dim, output_dim, 4, 2, 1, bias=False)
+ self.upsampled_flow2_to_1 = nn.ConvTranspose2d(
+ output_dim, output_dim, 4, 2, 1, bias=False)
+
+ self.deconv8 = deconv(256, 128)
+ self.deconv4 = deconv(384 + output_dim, 128)
+ self.deconv2 = deconv(192 + output_dim, 64)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
+ scale = math.sqrt(2. / fan_in)
+ m.weight.data.uniform_(-scale, scale)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ if not m.weight is None:
+ m.weight.data.fill_(1)
+ if not m.bias is None:
+ m.bias.data.zero_()
+
+ def forward(self, x, skip_feat):
+ layer1, layer2, layer4 = skip_feat # 3, 64, 256
+
+ # propagation nets
+ x1 = self.decoder1(x)
+ x2 = nn.functional.interpolate(
+ self.decoder2(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ x4 = nn.functional.interpolate(
+ self.decoder4(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ x8 = nn.functional.interpolate(
+ self.decoder8(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ cat = torch.cat([x1, x2, x4, x8], dim=1)
+ feat8 = self.fusion8(cat) # 256
+
+ # flownet head
+ flow8 = self.predict_flow8(feat8)
+ flow8_up = self.upsampled_flow8_to_4(flow8)
+ out_deconv8 = self.deconv8(feat8) # 128
+
+ concat4 = torch.cat((layer4, out_deconv8, flow8_up), dim=1) # 394 + out
+ flow4 = self.predict_flow4(concat4)
+ flow4_up = self.upsampled_flow4_to_2(flow4)
+ out_deconv4 = self.deconv4(concat4) # 128
+
+ concat2 = torch.cat((layer2, out_deconv4, flow4_up), dim=1) # 192 + out
+ flow2 = self.predict_flow2(concat2)
+ flow2_up = self.upsampled_flow2_to_1(flow2)
+ out_deconv2 = self.deconv2(concat2) # 64
+
+ concat1 = torch.cat((layer1, out_deconv2, flow2_up), dim=1) # 67 + out
+ flow1 = self.predict_flow1(concat1)
+
+ return [flow1, flow2, flow4, flow8]
+
+
+def predict_flow(in_planes, out_planes):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
+ stride=1, padding=1, bias=True)
+
+
+def deconv(in_planes, out_planes):
+ return nn.Sequential(
+ nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4,
+ stride=2, padding=1, bias=True),
+ nn.LeakyReLU(0.1, inplace=True)
+ )
+
+
diff --git a/models/cmp/models/modules/others.py b/models/cmp/models/modules/others.py
new file mode 100644
index 0000000000000000000000000000000000000000..591ce94f7d10db49fb3209d4d74a4a973f9a6cf5
--- /dev/null
+++ b/models/cmp/models/modules/others.py
@@ -0,0 +1,11 @@
+import torch.nn as nn
+
+class FixModule(nn.Module):
+
+ def __init__(self, m):
+ super(FixModule, self).__init__()
+ self.module = m
+
+ def forward(self, *args, **kwargs):
+ return self.module(*args, **kwargs)
+
diff --git a/models/cmp/models/modules/shallownet.py b/models/cmp/models/modules/shallownet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37fedd26b5096e34c0e6303f69e54b3d58c39b4
--- /dev/null
+++ b/models/cmp/models/modules/shallownet.py
@@ -0,0 +1,49 @@
+import torch.nn as nn
+import math
+
+class ShallowNet(nn.Module):
+
+ def __init__(self, input_dim=4, output_dim=16, stride=[2, 2, 2]):
+ super(ShallowNet, self).__init__()
+ global BN
+
+ BN = nn.BatchNorm2d
+
+ self.features = nn.Sequential(
+ nn.Conv2d(input_dim, 16, kernel_size=5, stride=stride[0], padding=2),
+ nn.BatchNorm2d(16),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=stride[1], stride=stride[1]),
+ nn.Conv2d(16, output_dim, kernel_size=3, padding=1),
+ nn.BatchNorm2d(output_dim),
+ nn.ReLU(inplace=True),
+ nn.AvgPool2d(kernel_size=stride[2], stride=stride[2]),
+ )
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
+ scale = math.sqrt(2. / fan_in)
+ m.weight.data.uniform_(-scale, scale)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ if not m.weight is None:
+ m.weight.data.fill_(1)
+ if not m.bias is None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ x = self.features(x)
+ return x
+
+
+def shallownet8x(output_dim):
+ model = ShallowNet(output_dim=output_dim, stride=[2,2,2])
+ return model
+
+def shallownet32x(output_dim, **kwargs):
+ model = ShallowNet(output_dim=output_dim, stride=[2,2,8])
+ return model
+
+
+
diff --git a/models/cmp/models/modules/warp.py b/models/cmp/models/modules/warp.py
new file mode 100644
index 0000000000000000000000000000000000000000..d32dc5db787345c9d2622fa6f65d463dd78ef8ba
--- /dev/null
+++ b/models/cmp/models/modules/warp.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+
+class WarpingLayerBWFlow(nn.Module):
+
+ def __init__(self):
+ super(WarpingLayerBWFlow, self).__init__()
+
+ def forward(self, image, flow):
+ flow_for_grip = torch.zeros_like(flow)
+ flow_for_grip[:,0,:,:] = flow[:,0,:,:] / ((flow.size(3) - 1.0) / 2.0)
+ flow_for_grip[:,1,:,:] = flow[:,1,:,:] / ((flow.size(2) - 1.0) / 2.0)
+
+ torchHorizontal = torch.linspace(
+ -1.0, 1.0, image.size(3)).view(
+ 1, 1, 1, image.size(3)).expand(
+ image.size(0), 1, image.size(2), image.size(3))
+ torchVertical = torch.linspace(
+ -1.0, 1.0, image.size(2)).view(
+ 1, 1, image.size(2), 1).expand(
+ image.size(0), 1, image.size(2), image.size(3))
+ grid = torch.cat([torchHorizontal, torchVertical], 1).cuda()
+
+ grid = (grid + flow_for_grip).permute(0, 2, 3, 1)
+ return torch.nn.functional.grid_sample(image, grid)
+
+
+class WarpingLayerFWFlow(nn.Module):
+
+ def __init__(self):
+ super(WarpingLayerFWFlow, self).__init__()
+ self.initialized = False
+
+ def forward(self, image, flow, ret_mask = False):
+ n, h, w = image.size(0), image.size(2), image.size(3)
+
+ if not self.initialized or n != self.meshx.shape[0] or h * w != self.meshx.shape[1]:
+ self.meshx = torch.arange(w).view(1, 1, w).expand(
+ n, h, w).contiguous().view(n, -1).cuda()
+ self.meshy = torch.arange(h).view(1, h, 1).expand(
+ n, h, w).contiguous().view(n, -1).cuda()
+ self.warped_image = torch.zeros((n, 3, h, w), dtype=torch.float32).cuda()
+ if ret_mask:
+ self.hole_mask = torch.ones((n, 1, h, w), dtype=torch.float32).cuda()
+ self.initialized = True
+
+ v = (flow[:,0,:,:] ** 2 + flow[:,1,:,:] ** 2).view(n, -1)
+ _, sortidx = torch.sort(v, dim=1)
+
+ warped_meshx = self.meshx + flow[:,0,:,:].long().view(n, -1)
+ warped_meshy = self.meshy + flow[:,1,:,:].long().view(n, -1)
+
+ warped_meshx = torch.clamp(warped_meshx, 0, w - 1)
+ warped_meshy = torch.clamp(warped_meshy, 0, h - 1)
+
+ self.warped_image.zero_()
+ if ret_mask:
+ self.hole_mask.fill_(1.)
+ for i in range(n):
+ for c in range(3):
+ ind = sortidx[i]
+ self.warped_image[i,c,warped_meshy[i][ind],warped_meshx[i][ind]] = image[i,c,self.meshy[i][ind],self.meshx[i][ind]]
+ if ret_mask:
+ self.hole_mask[i,0,warped_meshy[i],warped_meshx[i]] = 0.
+ if ret_mask:
+ return self.warped_image, self.hole_mask
+ else:
+ return self.warped_image
diff --git a/models/cmp/models/single_stage_model.py b/models/cmp/models/single_stage_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..79f5d4ab7ccffba99f72612ad6c77bf4dc3f2521
--- /dev/null
+++ b/models/cmp/models/single_stage_model.py
@@ -0,0 +1,72 @@
+import os
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+import models.cmp.models as models
+import models.cmp.utils as utils
+
+
+class SingleStageModel(object):
+
+ def __init__(self, params, dist_model=False):
+ model_params = params['module']
+ self.model = models.modules.__dict__[params['module']['arch']](model_params)
+ utils.init_weights(self.model, init_type='xavier')
+ self.model.cuda()
+ if dist_model:
+ self.model = utils.DistModule(self.model)
+ self.world_size = dist.get_world_size()
+ else:
+ self.model = models.modules.FixModule(self.model)
+ self.world_size = 1
+
+ if params['optim'] == 'SGD':
+ self.optim = torch.optim.SGD(
+ self.model.parameters(), lr=params['lr'],
+ momentum=0.9, weight_decay=0.0001)
+ elif params['optim'] == 'Adam':
+ self.optim = torch.optim.Adam(
+ self.model.parameters(), lr=params['lr'],
+ betas=(params['beta1'], 0.999))
+ else:
+ raise Exception("No such optimizer: {}".format(params['optim']))
+
+ cudnn.benchmark = True
+
+ def set_input(self, image_input, sparse_input, flow_target=None, rgb_target=None):
+ self.image_input = image_input
+ self.sparse_input = sparse_input
+ self.flow_target = flow_target
+ self.rgb_target = rgb_target
+
+ def eval(self, ret_loss=True):
+ pass
+
+ def step(self):
+ pass
+
+ def load_state(self, path, Iter, resume=False):
+ path = os.path.join(path, "ckpt_iter_{}.pth.tar".format(Iter))
+
+ if resume:
+ utils.load_state(path, self.model, self.optim)
+ else:
+ utils.load_state(path, self.model)
+
+ def load_pretrain(self, load_path):
+ utils.load_state(load_path, self.model)
+
+ def save_state(self, path, Iter):
+ path = os.path.join(path, "ckpt_iter_{}.pth.tar".format(Iter))
+
+ torch.save({
+ 'step': Iter,
+ 'state_dict': self.model.state_dict(),
+ 'optimizer': self.optim.state_dict()}, path)
+
+ def switch_to(self, phase):
+ if phase == 'train':
+ self.model.train()
+ else:
+ self.model.eval()
diff --git a/models/cmp/utils/__init__.py b/models/cmp/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..29be9c14049e0540b324db3fc65eedf1b492358e
--- /dev/null
+++ b/models/cmp/utils/__init__.py
@@ -0,0 +1,6 @@
+from .common_utils import *
+from .data_utils import *
+from .distributed_utils import *
+from .visualize_utils import *
+from .scheduler import *
+from . import flowlib
diff --git a/models/cmp/utils/common_utils.py b/models/cmp/utils/common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a3862068c32b5094b7ee3caa045d250c1d63264
--- /dev/null
+++ b/models/cmp/utils/common_utils.py
@@ -0,0 +1,118 @@
+import os
+import logging
+import numpy as np
+
+import torch
+from torch.nn import init
+
+def init_weights(net, init_type='normal', init_gain=0.02):
+ """Initialize network weights.
+ Parameters:
+ net (network) -- network to be initialized
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
+ work better for some applications. Feel free to try yourself.
+ """
+ def init_func(m): # define the initialization function
+ classname = m.__class__.__name__
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, init_gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=init_gain)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=init_gain)
+ else:
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+ init.normal_(m.weight.data, 1.0, init_gain)
+ init.constant_(m.bias.data, 0.0)
+
+ net.apply(init_func) # apply the initialization function
+
+def create_logger(name, log_file, level=logging.INFO):
+ l = logging.getLogger(name)
+ formatter = logging.Formatter('[%(asctime)s] %(message)s')
+ fh = logging.FileHandler(log_file)
+ fh.setFormatter(formatter)
+ sh = logging.StreamHandler()
+ sh.setFormatter(formatter)
+ l.setLevel(level)
+ l.addHandler(fh)
+ l.addHandler(sh)
+ return l
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self, length=0):
+ self.length = length
+ self.reset()
+
+ def reset(self):
+ if self.length > 0:
+ self.history = []
+ else:
+ self.count = 0
+ self.sum = 0.0
+ self.val = 0.0
+ self.avg = 0.0
+
+ def update(self, val):
+ if self.length > 0:
+ self.history.append(val)
+ if len(self.history) > self.length:
+ del self.history[0]
+
+ self.val = self.history[-1]
+ self.avg = np.mean(self.history)
+ else:
+ self.val = val
+ self.sum += val
+ self.count += 1
+ self.avg = self.sum / self.count
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k"""
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0, keepdims=True)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+def load_state(path, model, optimizer=None):
+ def map_func(storage, location):
+ return storage.cuda()
+ if os.path.isfile(path):
+ print("=> loading checkpoint '{}'".format(path))
+ checkpoint = torch.load(path, map_location=map_func)
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
+ ckpt_keys = set(checkpoint['state_dict'].keys())
+ own_keys = set(model.state_dict().keys())
+ missing_keys = own_keys - ckpt_keys
+ # print(ckpt_keys)
+ # print(own_keys)
+ for k in missing_keys:
+ print('caution: missing keys from checkpoint {}: {}'.format(path, k))
+
+ last_iter = checkpoint['step']
+ if optimizer != None:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ print("=> also loaded optimizer from checkpoint '{}' (iter {})"
+ .format(path, last_iter))
+ return last_iter
+ else:
+ print("=> no checkpoint found at '{}'".format(path))
+
+
diff --git a/models/cmp/utils/data_utils.py b/models/cmp/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0651fc5d9fefa638c76f86ebdb9db3139fecf991
--- /dev/null
+++ b/models/cmp/utils/data_utils.py
@@ -0,0 +1,280 @@
+from PIL import Image, ImageOps
+import scipy.ndimage as ndimage
+import cv2
+import random
+import numpy as np
+from scipy.ndimage.filters import maximum_filter
+from scipy import signal
+cv2.ocl.setUseOpenCL(False)
+
+def get_edge(data, blur=False):
+ if blur:
+ data = cv2.GaussianBlur(data, (3, 3), 1.)
+ sobel = np.array([[1,0,-1],[2,0,-2],[1,0,-1]]).astype(np.float32)
+ ch_edges = []
+ for k in range(data.shape[2]):
+ edgex = signal.convolve2d(data[:,:,k], sobel, boundary='symm', mode='same')
+ edgey = signal.convolve2d(data[:,:,k], sobel.T, boundary='symm', mode='same')
+ ch_edges.append(np.sqrt(edgex**2 + edgey**2))
+ return sum(ch_edges)
+
+def get_max(score, bbox):
+ u = max(0, bbox[0])
+ d = min(score.shape[0], bbox[1])
+ l = max(0, bbox[2])
+ r = min(score.shape[1], bbox[3])
+ return score[u:d,l:r].max()
+
+def nms(score, ks):
+ assert ks % 2 == 1
+ ret_score = score.copy()
+ maxpool = maximum_filter(score, footprint=np.ones((ks, ks)))
+ ret_score[score < maxpool] = 0.
+ return ret_score
+
+def image_flow_crop(img1, img2, flow, crop_size, phase):
+ assert len(crop_size) == 2
+ pad_h = max(crop_size[0] - img1.height, 0)
+ pad_w = max(crop_size[1] - img1.width, 0)
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+ if pad_h > 0 or pad_w > 0:
+ flow_expand = np.zeros((img1.height + pad_h, img1.width + pad_w, 2), dtype=np.float32)
+ flow_expand[pad_h_half:pad_h_half+img1.height, pad_w_half:pad_w_half+img1.width, :] = flow
+ flow = flow_expand
+ border = (pad_w_half, pad_h_half, pad_w - pad_w_half, pad_h - pad_h_half)
+ img1 = ImageOps.expand(img1, border=border, fill=(0,0,0))
+ img2 = ImageOps.expand(img2, border=border, fill=(0,0,0))
+ if phase == 'train':
+ hoff = int(np.random.rand() * (img1.height - crop_size[0]))
+ woff = int(np.random.rand() * (img1.width - crop_size[1]))
+ else:
+ hoff = (img1.height - crop_size[0]) // 2
+ woff = (img1.width - crop_size[1]) // 2
+
+ img1 = img1.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))
+ img2 = img2.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))
+ flow = flow[hoff:hoff+crop_size[0], woff:woff+crop_size[1], :]
+ offset = (hoff, woff)
+ return img1, img2, flow, offset
+
+def image_crop(img, crop_size):
+ pad_h = max(crop_size[0] - img.height, 0)
+ pad_w = max(crop_size[1] - img.width, 0)
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+ if pad_h > 0 or pad_w > 0:
+ border = (pad_w_half, pad_h_half, pad_w - pad_w_half, pad_h - pad_h_half)
+ img = ImageOps.expand(img, border=border, fill=(0,0,0))
+ hoff = (img.height - crop_size[0]) // 2
+ woff = (img.width - crop_size[1]) // 2
+ return img.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0])), (pad_w_half, pad_h_half)
+
+def image_flow_resize(img1, img2, flow, short_size=None, long_size=None):
+ assert (short_size is None) ^ (long_size is None)
+ w, h = img1.width, img1.height
+ if short_size is not None:
+ if w < h:
+ neww = short_size
+ newh = int(short_size / float(w) * h)
+ else:
+ neww = int(short_size / float(h) * w)
+ newh = short_size
+ else:
+ if w < h:
+ neww = int(long_size / float(h) * w)
+ newh = long_size
+ else:
+ neww = long_size
+ newh = int(long_size / float(w) * h)
+ img1 = img1.resize((neww, newh), Image.BICUBIC)
+ img2 = img2.resize((neww, newh), Image.BICUBIC)
+ ratio = float(newh) / h
+ flow = cv2.resize(flow.copy(), (neww, newh), interpolation=cv2.INTER_LINEAR) * ratio
+ return img1, img2, flow, ratio
+
+def image_resize(img, short_size=None, long_size=None):
+ assert (short_size is None) ^ (long_size is None)
+ w, h = img.width, img.height
+ if short_size is not None:
+ if w < h:
+ neww = short_size
+ newh = int(short_size / float(w) * h)
+ else:
+ neww = int(short_size / float(h) * w)
+ newh = short_size
+ else:
+ if w < h:
+ neww = int(long_size / float(h) * w)
+ newh = long_size
+ else:
+ neww = long_size
+ newh = int(long_size / float(w) * h)
+ img = img.resize((neww, newh), Image.BICUBIC)
+ return img, [w, h]
+
+
+def image_pose_crop(img, posemap, crop_size, scale):
+ assert len(crop_size) == 2
+ assert crop_size[0] <= img.height
+ assert crop_size[1] <= img.width
+ hoff = (img.height - crop_size[0]) // 2
+ woff = (img.width - crop_size[1]) // 2
+ img = img.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))
+ posemap = posemap[hoff//scale:hoff//scale+crop_size[0]//scale, woff//scale:woff//scale+crop_size[1]//scale,:]
+ return img, posemap
+
+def neighbor_elim(ph, pw, d):
+ valid = np.ones((len(ph))).astype(np.int)
+ h_dist = np.fabs(np.tile(ph[:,np.newaxis], [1,len(ph)]) - np.tile(ph.T[np.newaxis,:], [len(ph),1]))
+ w_dist = np.fabs(np.tile(pw[:,np.newaxis], [1,len(pw)]) - np.tile(pw.T[np.newaxis,:], [len(pw),1]))
+ idx1, idx2 = np.where((h_dist < d) & (w_dist < d))
+ for i,j in zip(idx1, idx2):
+ if valid[i] and valid[j] and i != j:
+ if np.random.rand() > 0.5:
+ valid[i] = 0
+ else:
+ valid[j] = 0
+ valid_idx = np.where(valid==1)
+ return ph[valid_idx], pw[valid_idx]
+
+def remove_border(mask):
+ mask[0,:] = 0
+ mask[:,0] = 0
+ mask[mask.shape[0]-1,:] = 0
+ mask[:,mask.shape[1]-1] = 0
+
+def flow_sampler(flow, strategy=['grid'], bg_ratio=1./6400, nms_ks=15, max_num_guide=-1, guidepoint=None):
+ assert bg_ratio >= 0 and bg_ratio <= 1, "sampling ratio must be in (0, 1]"
+ for s in strategy:
+ assert s in ['grid', 'uniform', 'gradnms', 'watershed', 'single', 'full', 'specified'], "No such strategy: {}".format(s)
+ h = flow.shape[0]
+ w = flow.shape[1]
+ ds = max(1, max(h, w) // 400) # reduce computation
+
+ if 'full' in strategy:
+ sparse = flow.copy()
+ mask = np.ones(flow.shape, dtype=np.int)
+ return sparse, mask
+
+ pts_h = []
+ pts_w = []
+ if 'grid' in strategy:
+ stride = int(np.sqrt(1./bg_ratio))
+ mesh_start_h = int((h - h // stride * stride) / 2)
+ mesh_start_w = int((w - w // stride * stride) / 2)
+ mesh = np.meshgrid(np.arange(mesh_start_h, h, stride), np.arange(mesh_start_w, w, stride))
+ pts_h.append(mesh[0].flat)
+ pts_w.append(mesh[1].flat)
+ if 'uniform' in strategy:
+ pts_h.append(np.random.randint(0, h, int(bg_ratio * h * w)))
+ pts_w.append(np.random.randint(0, w, int(bg_ratio * h * w)))
+ if "gradnms" in strategy:
+ ks = w // ds // 20
+ edge = get_edge(flow[::ds,::ds,:])
+ kernel = np.ones((ks, ks), dtype=np.float32) / (ks * ks)
+ subkernel = np.ones((ks//2, ks//2), dtype=np.float32) / (ks//2 * ks//2)
+ score = signal.convolve2d(edge, kernel, boundary='symm', mode='same')
+ subscore = signal.convolve2d(edge, subkernel, boundary='symm', mode='same')
+ score = score / score.max() - subscore / subscore.max()
+ nms_res = nms(score, nms_ks)
+ pth, ptw = np.where(nms_res > 0.1)
+ pts_h.append(pth * ds)
+ pts_w.append(ptw * ds)
+ if "watershed" in strategy:
+ edge = get_edge(flow[::ds,::ds,:])
+ edge /= max(edge.max(), 0.01)
+ edge = (edge > 0.1).astype(np.float32)
+ watershed = ndimage.distance_transform_edt(1-edge)
+ nms_res = nms(watershed, nms_ks)
+ remove_border(nms_res)
+ pth, ptw = np.where(nms_res > 0)
+ pth, ptw = neighbor_elim(pth, ptw, (nms_ks-1)/2)
+ pts_h.append(pth * ds)
+ pts_w.append(ptw * ds)
+ if "single" in strategy:
+ pth, ptw = np.where((flow[:,:,0] != 0) | (flow[:,:,1] != 0))
+ randidx = np.random.randint(len(pth))
+ pts_h.append(pth[randidx:randidx+1])
+ pts_w.append(ptw[randidx:randidx+1])
+ if 'specified' in strategy:
+ assert guidepoint is not None, "if using \"specified\", switch \"with_info\" on."
+ pts_h.append(guidepoint[:,1])
+ pts_w.append(guidepoint[:,0])
+
+ pts_h = np.concatenate(pts_h)
+ pts_w = np.concatenate(pts_w)
+
+ if max_num_guide == -1:
+ max_num_guide = np.inf
+
+ randsel = np.random.permutation(len(pts_h))[:len(pts_h)]
+ selidx = randsel[np.arange(min(max_num_guide, len(randsel)))]
+ pts_h = pts_h[selidx]
+ pts_w = pts_w[selidx]
+
+ sparse = np.zeros(flow.shape, dtype=flow.dtype)
+ mask = np.zeros(flow.shape, dtype=np.int)
+
+ sparse[:, :, 0][(pts_h, pts_w)] = flow[:, :, 0][(pts_h, pts_w)]
+ sparse[:, :, 1][(pts_h, pts_w)] = flow[:, :, 1][(pts_h, pts_w)]
+
+ mask[:,:,0][(pts_h, pts_w)] = 1
+ mask[:,:,1][(pts_h, pts_w)] = 1
+ return sparse, mask
+
+def image_flow_aug(img1, img2, flow, flip_horizon=True):
+ if flip_horizon:
+ if random.random() < 0.5:
+ img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)
+ img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
+ flow = flow[:,::-1,:].copy()
+ flow[:,:,0] = -flow[:,:,0]
+ return img1, img2, flow
+
+def flow_aug(flow, reverse=True, scale=True, rotate=True):
+ if reverse:
+ if random.random() < 0.5:
+ flow = -flow
+ if scale:
+ rand_scale = random.uniform(0.5, 2.0)
+ flow = flow * rand_scale
+ if rotate and random.random() < 0.5:
+ lengh = np.sqrt(np.square(flow[:,:,0]) + np.square(flow[:,:,1]))
+ alpha = np.arctan(flow[:,:,1] / flow[:,:,0])
+ theta = random.uniform(0, np.pi*2)
+ flow[:,:,0] = lengh * np.cos(alpha + theta)
+ flow[:,:,1] = lengh * np.sin(alpha + theta)
+ return flow
+
+def draw_gaussian(img, pt, sigma, type='Gaussian'):
+ # Check that any part of the gaussian is in-bounds
+ ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
+ br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
+ if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
+ br[0] < 0 or br[1] < 0):
+ # If not, just return the image as is
+ return img
+
+ # Generate gaussian
+ size = 6 * sigma + 1
+ x = np.arange(0, size, 1, float)
+ y = x[:, np.newaxis]
+ x0 = y0 = size // 2
+ # The gaussian is not normalized, we want the center value to equal 1
+ if type == 'Gaussian':
+ g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+ elif type == 'Cauchy':
+ g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
+
+ # Usable gaussian range
+ g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
+ g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
+ # Image range
+ img_x = max(0, ul[0]), min(br[0], img.shape[1])
+ img_y = max(0, ul[1]), min(br[1], img.shape[0])
+
+ img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+ return img
+
+
diff --git a/models/cmp/utils/distributed_utils.py b/models/cmp/utils/distributed_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..97056fc313c198ea11ec96a6ad7575db5de6b302
--- /dev/null
+++ b/models/cmp/utils/distributed_utils.py
@@ -0,0 +1,229 @@
+import os
+import subprocess
+import numpy as np
+import multiprocessing as mp
+import math
+
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import Sampler
+from torch.nn import Module
+
+class DistModule(Module):
+ def __init__(self, module):
+ super(DistModule, self).__init__()
+ self.module = module
+ broadcast_params(self.module)
+ def forward(self, *inputs, **kwargs):
+ return self.module(*inputs, **kwargs)
+ def train(self, mode=True):
+ super(DistModule, self).train(mode)
+ self.module.train(mode)
+
+def average_gradients(model):
+ """ average gradients """
+ for param in model.parameters():
+ if param.requires_grad:
+ dist.all_reduce(param.grad.data)
+
+def broadcast_params(model):
+ """ broadcast model parameters """
+ for p in model.state_dict().values():
+ dist.broadcast(p, 0)
+
+def dist_init(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'mpi':
+ _init_dist_mpi(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError('Invalid launcher type: {}'.format(launcher))
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+def _init_dist_mpi(backend, **kwargs):
+ raise NotImplementedError
+
+def _init_dist_slurm(backend, port=10086, **kwargs):
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(
+ 'scontrol show hostname {} | head -n1'.format(node_list))
+ os.environ['MASTER_PORT'] = str(port)
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+def gather_tensors(input_array):
+ world_size = dist.get_world_size()
+ ## gather shapes first
+ myshape = input_array.shape
+ mycount = input_array.size
+ shape_tensor = torch.Tensor(np.array(myshape)).cuda()
+ all_shape = [torch.Tensor(np.array(myshape)).cuda() for i in range(world_size)]
+ dist.all_gather(all_shape, shape_tensor)
+ ## compute largest shapes
+ all_shape = [x.cpu().numpy() for x in all_shape]
+ all_count = [int(x.prod()) for x in all_shape]
+ all_shape = [list(map(int, x)) for x in all_shape]
+ max_count = max(all_count)
+ ## padding tensors and gather them
+ output_tensors = [torch.Tensor(max_count).cuda() for i in range(world_size)]
+ padded_input_array = np.zeros(max_count)
+ padded_input_array[:mycount] = input_array.reshape(-1)
+ input_tensor = torch.Tensor(padded_input_array).cuda()
+ dist.all_gather(output_tensors, input_tensor)
+ ## unpadding gathered tensors
+ padded_output = [x.cpu().numpy() for x in output_tensors]
+ output = [x[:all_count[i]].reshape(all_shape[i]) for i,x in enumerate(padded_output)]
+ return output
+
+def gather_tensors_batch(input_array, part_size=10):
+ # gather
+ rank = dist.get_rank()
+ all_features = []
+ part_num = input_array.shape[0] // part_size + 1 if input_array.shape[0] % part_size != 0 else input_array.shape[0] // part_size
+ for i in range(part_num):
+ part_feat = input_array[i * part_size:min((i+1)*part_size, input_array.shape[0]),...]
+ assert part_feat.shape[0] > 0, "rank: {}, length of part features should > 0".format(rank)
+ print("rank: {}, gather part: {}/{}, length: {}".format(rank, i, part_num, len(part_feat)))
+ gather_part_feat = gather_tensors(part_feat)
+ all_features.append(gather_part_feat)
+ print("rank: {}, gather done.".format(rank))
+ all_features = np.concatenate([np.concatenate([all_features[i][j] for i in range(part_num)], axis=0) for j in range(len(all_features[0]))], axis=0)
+ return all_features
+
+def reduce_tensors(tensor):
+ reduced_tensor = tensor.clone()
+ dist.all_reduce(reduced_tensor)
+ return reduced_tensor
+
+class DistributedSequentialSampler(Sampler):
+ def __init__(self, dataset, world_size=None, rank=None):
+ if world_size == None:
+ world_size = dist.get_world_size()
+ if rank == None:
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.world_size = world_size
+ self.rank = rank
+ assert len(self.dataset) >= self.world_size, '{} vs {}'.format(len(self.dataset), self.world_size)
+ sub_num = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))
+ self.beg = sub_num * self.rank
+ #self.end = min(self.beg+sub_num, len(self.dataset))
+ self.end = self.beg + sub_num
+ self.padded_ind = list(range(len(self.dataset))) + list(range(sub_num * self.world_size - len(self.dataset)))
+
+ def __iter__(self):
+ indices = [self.padded_ind[i] for i in range(self.beg, self.end)]
+ return iter(indices)
+
+ def __len__(self):
+ return self.end - self.beg
+
+class GivenIterationSampler(Sampler):
+ def __init__(self, dataset, total_iter, batch_size, last_iter=-1):
+ self.dataset = dataset
+ self.total_iter = total_iter
+ self.batch_size = batch_size
+ self.last_iter = last_iter
+
+ self.total_size = self.total_iter * self.batch_size
+ self.indices = self.gen_new_list()
+ self.call = 0
+
+ def __iter__(self):
+ if self.call == 0:
+ self.call = 1
+ return iter(self.indices[(self.last_iter + 1) * self.batch_size:])
+ else:
+ raise RuntimeError("this sampler is not designed to be called more than once!!")
+
+ def gen_new_list(self):
+
+ # each process shuffle all list with same seed, and pick one piece according to rank
+ np.random.seed(0)
+
+ all_size = self.total_size
+ indices = np.arange(len(self.dataset))
+ indices = indices[:all_size]
+ num_repeat = (all_size-1) // indices.shape[0] + 1
+ indices = np.tile(indices, num_repeat)
+ indices = indices[:all_size]
+
+ np.random.shuffle(indices)
+
+ assert len(indices) == self.total_size
+
+ return indices
+
+ def __len__(self):
+ return self.total_size
+
+
+class DistributedGivenIterationSampler(Sampler):
+ def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1):
+ if world_size is None:
+ world_size = dist.get_world_size()
+ if rank is None:
+ rank = dist.get_rank()
+ assert rank < world_size
+ self.dataset = dataset
+ self.total_iter = total_iter
+ self.batch_size = batch_size
+ self.world_size = world_size
+ self.rank = rank
+ self.last_iter = last_iter
+
+ self.total_size = self.total_iter*self.batch_size
+
+ self.indices = self.gen_new_list()
+ self.call = 0
+
+ def __iter__(self):
+ if self.call == 0:
+ self.call = 1
+ return iter(self.indices[(self.last_iter+1)*self.batch_size:])
+ else:
+ raise RuntimeError("this sampler is not designed to be called more than once!!")
+
+ def gen_new_list(self):
+
+ # each process shuffle all list with same seed, and pick one piece according to rank
+ np.random.seed(0)
+
+ all_size = self.total_size * self.world_size
+ indices = np.arange(len(self.dataset))
+ indices = indices[:all_size]
+ num_repeat = (all_size-1) // indices.shape[0] + 1
+ indices = np.tile(indices, num_repeat)
+ indices = indices[:all_size]
+
+ np.random.shuffle(indices)
+ beg = self.total_size * self.rank
+ indices = indices[beg:beg+self.total_size]
+
+ assert len(indices) == self.total_size
+
+ return indices
+
+ def __len__(self):
+ # note here we do not take last iter into consideration, since __len__
+ # should only be used for displaying, the correct remaining size is
+ # handled by dataloader
+ #return self.total_size - (self.last_iter+1)*self.batch_size
+ return self.total_size
+
+
diff --git a/models/cmp/utils/flowlib.py b/models/cmp/utils/flowlib.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a0ab1a8bf3cbe05b55c50449319a55d4ae8d1ee
--- /dev/null
+++ b/models/cmp/utils/flowlib.py
@@ -0,0 +1,308 @@
+#!/usr/bin/python
+"""
+# ==============================
+# flowlib.py
+# library for optical flow processing
+# Author: Ruoteng Li
+# Date: 6th Aug 2016
+# ==============================
+"""
+#import png
+import numpy as np
+from PIL import Image
+import io
+
+UNKNOWN_FLOW_THRESH = 1e7
+SMALLFLOW = 0.0
+LARGEFLOW = 1e8
+
+"""
+=============
+Flow Section
+=============
+"""
+
+def write_flow(flow, filename):
+ """
+ write optical flow in Middlebury .flo format
+ :param flow: optical flow map
+ :param filename: optical flow file path to be saved
+ :return: None
+ """
+ f = open(filename, 'wb')
+ magic = np.array([202021.25], dtype=np.float32)
+ (height, width) = flow.shape[0:2]
+ w = np.array([width], dtype=np.int32)
+ h = np.array([height], dtype=np.int32)
+ magic.tofile(f)
+ w.tofile(f)
+ h.tofile(f)
+ flow.tofile(f)
+ f.close()
+
+
+def save_flow_image(flow, image_file):
+ """
+ save flow visualization into image file
+ :param flow: optical flow data
+ :param flow_fil
+ :return: None
+ """
+ flow_img = flow_to_image(flow)
+ img_out = Image.fromarray(flow_img)
+ img_out.save(image_file)
+
+def segment_flow(flow):
+ h = flow.shape[0]
+ w = flow.shape[1]
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW))
+ idx2 = (abs(u) == SMALLFLOW)
+ class0 = (v == 0) & (u == 0)
+ u[idx2] = 0.00001
+ tan_value = v / u
+
+ class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0)
+ class2 = (tan_value >= 1) & (u >= 0) & (v >= 0)
+ class3 = (tan_value < -1) & (u <= 0) & (v >= 0)
+ class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0)
+ class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0)
+ class7 = (tan_value < -1) & (u >= 0) & (v <= 0)
+ class6 = (tan_value >= 1) & (u <= 0) & (v <= 0)
+ class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) & (v <= 0)
+
+ seg = np.zeros((h, w))
+
+ seg[class1] = 1
+ seg[class2] = 2
+ seg[class3] = 3
+ seg[class4] = 4
+ seg[class5] = 5
+ seg[class6] = 6
+ seg[class7] = 7
+ seg[class8] = 8
+ seg[class0] = 0
+ seg[idx] = 0
+
+ return seg
+
+def flow_to_image(flow):
+ """
+ Convert flow into middlebury color code image
+ :param flow: optical flow map
+ :return: optical flow image in middlebury color
+ """
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ maxu = -999.
+ maxv = -999.
+ minu = 999.
+ minv = 999.
+
+ idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
+ u[idxUnknow] = 0
+ v[idxUnknow] = 0
+
+ maxu = max(maxu, np.max(u))
+ minu = min(minu, np.min(u))
+
+ maxv = max(maxv, np.max(v))
+ minv = min(minv, np.min(v))
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+ maxrad = max(5, np.max(rad))
+ #maxrad = max(-1, 99)
+
+ u = u/(maxrad + np.finfo(float).eps)
+ v = v/(maxrad + np.finfo(float).eps)
+
+ img = compute_color(u, v)
+
+ idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
+ img[idx] = 0
+
+ return np.uint8(img)
+
+def disp_to_flowfile(disp, filename):
+ """
+ Read KITTI disparity file in png format
+ :param disp: disparity matrix
+ :param filename: the flow file name to save
+ :return: None
+ """
+ f = open(filename, 'wb')
+ magic = np.array([202021.25], dtype=np.float32)
+ (height, width) = disp.shape[0:2]
+ w = np.array([width], dtype=np.int32)
+ h = np.array([height], dtype=np.int32)
+ empty_map = np.zeros((height, width), dtype=np.float32)
+ data = np.dstack((disp, empty_map))
+ magic.tofile(f)
+ w.tofile(f)
+ h.tofile(f)
+ data.tofile(f)
+ f.close()
+
+def compute_color(u, v):
+ """
+ compute optical flow color map
+ :param u: optical flow horizontal map
+ :param v: optical flow vertical map
+ :return: optical flow in color code
+ """
+ [h, w] = u.shape
+ img = np.zeros([h, w, 3])
+ nanIdx = np.isnan(u) | np.isnan(v)
+ u[nanIdx] = 0
+ v[nanIdx] = 0
+
+ colorwheel = make_color_wheel()
+ ncols = np.size(colorwheel, 0)
+
+ rad = np.sqrt(u**2+v**2)
+
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a+1) / 2 * (ncols - 1) + 1
+
+ k0 = np.floor(fk).astype(int)
+
+ k1 = k0 + 1
+ k1[k1 == ncols+1] = 1
+ f = fk - k0
+
+ for i in range(0, np.size(colorwheel,1)):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0-1] / 255
+ col1 = tmp[k1-1] / 255
+ col = (1-f) * col0 + f * col1
+
+ idx = rad <= 1
+ col[idx] = 1-rad[idx]*(1-col[idx])
+ notidx = np.logical_not(idx)
+
+ col[notidx] *= 0.75
+ img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
+
+ return img
+
+
+def make_color_wheel():
+ """
+ Generate color wheel according Middlebury color code
+ :return: Color wheel
+ """
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+
+ colorwheel = np.zeros([ncols, 3])
+
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
+ col += RY
+
+ # YG
+ colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
+ colorwheel[col:col+YG, 1] = 255
+ col += YG
+
+ # GC
+ colorwheel[col:col+GC, 1] = 255
+ colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
+ col += GC
+
+ # CB
+ colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
+ colorwheel[col:col+CB, 2] = 255
+ col += CB
+
+ # BM
+ colorwheel[col:col+BM, 2] = 255
+ colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
+ col += + BM
+
+ # MR
+ colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+ colorwheel[col:col+MR, 0] = 255
+
+ return colorwheel
+
+
+def read_flo_file(filename, memcached=False):
+ """
+ Read from Middlebury .flo file
+ :param flow_file: name of the flow file
+ :return: optical flow data in matrix
+ """
+ if memcached:
+ filename = io.BytesIO(filename)
+ f = open(filename, 'rb')
+ magic = np.fromfile(f, np.float32, count=1)[0]
+ data2d = None
+
+ if 202021.25 != magic:
+ print('Magic number incorrect. Invalid .flo file')
+ else:
+ w = np.fromfile(f, np.int32, count=1)[0]
+ h = np.fromfile(f, np.int32, count=1)[0]
+ data2d = np.fromfile(f, np.float32, count=2 * w * h)
+ # reshape data into 3D array (columns, rows, channels)
+ data2d = np.resize(data2d, (h, w, 2))
+ f.close()
+ return data2d
+
+
+# fast resample layer
+def resample(img, sz):
+ """
+ img: flow map to be resampled
+ sz: new flow map size. Must be [height,weight]
+ """
+ original_image_size = img.shape
+ in_height = img.shape[0]
+ in_width = img.shape[1]
+ out_height = sz[0]
+ out_width = sz[1]
+ out_flow = np.zeros((out_height, out_width, 2))
+ # find scale
+ height_scale = float(in_height) / float(out_height)
+ width_scale = float(in_width) / float(out_width)
+
+ [x,y] = np.meshgrid(range(out_width), range(out_height))
+ xx = x * width_scale
+ yy = y * height_scale
+ x0 = np.floor(xx).astype(np.int32)
+ x1 = x0 + 1
+ y0 = np.floor(yy).astype(np.int32)
+ y1 = y0 + 1
+
+ x0 = np.clip(x0,0,in_width-1)
+ x1 = np.clip(x1,0,in_width-1)
+ y0 = np.clip(y0,0,in_height-1)
+ y1 = np.clip(y1,0,in_height-1)
+
+ Ia = img[y0,x0,:]
+ Ib = img[y1,x0,:]
+ Ic = img[y0,x1,:]
+ Id = img[y1,x1,:]
+
+ wa = (y1-yy) * (x1-xx)
+ wb = (yy-y0) * (x1-xx)
+ wc = (y1-yy) * (xx-x0)
+ wd = (yy-y0) * (xx-x0)
+ out_flow[:,:,0] = (Ia[:,:,0]*wa + Ib[:,:,0]*wb + Ic[:,:,0]*wc + Id[:,:,0]*wd) * out_width / in_width
+ out_flow[:,:,1] = (Ia[:,:,1]*wa + Ib[:,:,1]*wb + Ic[:,:,1]*wc + Id[:,:,1]*wd) * out_height / in_height
+
+ return out_flow
diff --git a/models/cmp/utils/scheduler.py b/models/cmp/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f34f6321b0b1f567c656e57ee0e32c6baafe9ce
--- /dev/null
+++ b/models/cmp/utils/scheduler.py
@@ -0,0 +1,102 @@
+import torch
+from bisect import bisect_right
+
+class _LRScheduler(object):
+ def __init__(self, optimizer, last_iter=-1):
+ if not isinstance(optimizer, torch.optim.Optimizer):
+ raise TypeError('{} is not an Optimizer'.format(
+ type(optimizer).__name__))
+ self.optimizer = optimizer
+ if last_iter == -1:
+ for group in optimizer.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ else:
+ for i, group in enumerate(optimizer.param_groups):
+ if 'initial_lr' not in group:
+ raise KeyError("param 'initial_lr' is not specified "
+ "in param_groups[{}] when resuming an optimizer".format(i))
+ self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+ self.last_iter = last_iter
+
+ def _get_new_lr(self):
+ raise NotImplementedError
+
+ def get_lr(self):
+ return list(map(lambda group: group['lr'], self.optimizer.param_groups))
+
+ def step(self, this_iter=None):
+ if this_iter is None:
+ this_iter = self.last_iter + 1
+ self.last_iter = this_iter
+ for param_group, lr in zip(self.optimizer.param_groups, self._get_new_lr()):
+ param_group['lr'] = lr
+
+class _WarmUpLRSchedulerOld(_LRScheduler):
+
+ def __init__(self, optimizer, base_lr, warmup_lr, warmup_steps, last_iter=-1):
+ self.base_lr = base_lr
+ self.warmup_steps = warmup_steps
+ if warmup_steps == 0:
+ self.warmup_lr = base_lr
+ else:
+ self.warmup_lr = warmup_lr
+ super(_WarmUpLRSchedulerOld, self).__init__(optimizer, last_iter)
+
+ def _get_warmup_lr(self):
+ if self.warmup_steps > 0 and self.last_iter < self.warmup_steps:
+ # first compute relative scale for self.base_lr, then multiply to base_lr
+ scale = ((self.last_iter/self.warmup_steps)*(self.warmup_lr - self.base_lr) + self.base_lr)/self.base_lr
+ #print('last_iter: {}, warmup_lr: {}, base_lr: {}, scale: {}'.format(self.last_iter, self.warmup_lr, self.base_lr, scale))
+ return [scale * base_lr for base_lr in self.base_lrs]
+ else:
+ return None
+
+class _WarmUpLRScheduler(_LRScheduler):
+
+ def __init__(self, optimizer, base_lr, warmup_lr, warmup_steps, last_iter=-1):
+ self.base_lr = base_lr
+ self.warmup_lr = warmup_lr
+ self.warmup_steps = warmup_steps
+ assert isinstance(warmup_lr, list)
+ assert isinstance(warmup_steps, list)
+ assert len(warmup_lr) == len(warmup_steps)
+ super(_WarmUpLRScheduler, self).__init__(optimizer, last_iter)
+
+ def _get_warmup_lr(self):
+ pos = bisect_right(self.warmup_steps, self.last_iter)
+ if pos >= len(self.warmup_steps):
+ return None
+ else:
+ if pos == 0:
+ curr_lr = self.base_lr + self.last_iter * (self.warmup_lr[pos] - self.base_lr) / self.warmup_steps[pos]
+ else:
+ curr_lr = self.warmup_lr[pos - 1] + (self.last_iter - self.warmup_steps[pos - 1]) * (self.warmup_lr[pos] - self.warmup_lr[pos - 1]) / (self.warmup_steps[pos] - self.warmup_steps[pos - 1])
+ scale = curr_lr / self.base_lr
+ return [scale * base_lr for base_lr in self.base_lrs]
+
+class StepLRScheduler(_WarmUpLRScheduler):
+ def __init__(self, optimizer, milestones, lr_mults, base_lr, warmup_lr, warmup_steps, last_iter=-1):
+ super(StepLRScheduler, self).__init__(optimizer, base_lr, warmup_lr, warmup_steps, last_iter)
+
+ assert len(milestones) == len(lr_mults), "{} vs {}".format(milestones, lr_mults)
+ for x in milestones:
+ assert isinstance(x, int)
+ if not list(milestones) == sorted(milestones):
+ raise ValueError('Milestones should be a list of'
+ ' increasing integers. Got {}', milestones)
+ self.milestones = milestones
+ self.lr_mults = [1.0]
+ for x in lr_mults:
+ self.lr_mults.append(self.lr_mults[-1]*x)
+
+ def _get_new_lr(self):
+ warmup_lrs = self._get_warmup_lr()
+ if warmup_lrs is not None:
+ return warmup_lrs
+
+ pos = bisect_right(self.milestones, self.last_iter)
+ if len(self.warmup_lr) == 0:
+ scale = self.lr_mults[pos]
+ else:
+ scale = self.warmup_lr[-1] * self.lr_mults[pos] / self.base_lr
+ return [base_lr * scale for base_lr in self.base_lrs]
diff --git a/models/cmp/utils/visualize_utils.py b/models/cmp/utils/visualize_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfb4796a980156e9a9e23f0cf86604ba24dfbc4e
--- /dev/null
+++ b/models/cmp/utils/visualize_utils.py
@@ -0,0 +1,109 @@
+import numpy as np
+
+import torch
+from . import flowlib
+
+class Fuser(object):
+ def __init__(self, nbins, fmax):
+ self.nbins = nbins
+ self.fmax = fmax
+ self.step = 2 * fmax / float(nbins)
+ self.mesh = torch.arange(nbins).view(1,-1,1,1).float().cuda() * self.step - fmax + self.step / 2
+
+ def convert_flow(self, flow_prob):
+ flow_probx = torch.nn.functional.softmax(flow_prob[:, :self.nbins, :, :], dim=1)
+ flow_proby = torch.nn.functional.softmax(flow_prob[:, self.nbins:, :, :], dim=1)
+ flow_probx = flow_probx * self.mesh
+ flow_proby = flow_proby * self.mesh
+ flow = torch.cat([flow_probx.sum(dim=1, keepdim=True), flow_proby.sum(dim=1, keepdim=True)], dim=1)
+ return flow
+
+def visualize_tensor_old(image, mask, flow_pred, flow_target, warped, rgb_gen, image_target, image_mean, image_div):
+ together = [
+ draw_cross(unormalize(image.cpu(), mean=image_mean, div=image_div), mask.cpu(), radius=int(image.size(3) / 50.)),
+ flow_to_image(flow_pred.detach().cpu()),
+ flow_to_image(flow_target.detach().cpu())]
+ if warped is not None:
+ together.append(torch.clamp(unormalize(warped.detach().cpu(), mean=image_mean, div=image_div), 0, 255))
+ if rgb_gen is not None:
+ together.append(torch.clamp(unormalize(rgb_gen.detach().cpu(), mean=image_mean, div=image_div), 0, 255))
+ if image_target is not None:
+ together.append(torch.clamp(unormalize(image_target.cpu(), mean=image_mean, div=image_div), 0, 255))
+ together = torch.cat(together, dim=3)
+ return together
+
+def visualize_tensor(image, mask, flow_tensors, common_tensors, rgb_tensors, image_mean, image_div):
+ together = [
+ draw_cross(unormalize(image.cpu(), mean=image_mean, div=image_div), mask.cpu(), radius=int(image.size(3) / 50.))]
+ for ft in flow_tensors:
+ together.append(flow_to_image(ft.cpu()))
+ for ct in common_tensors:
+ together.append(torch.clamp(ct.cpu(), 0, 255))
+ for rt in rgb_tensors:
+ together.append(torch.clamp(unormalize(rt.cpu(), mean=image_mean, div=image_div), 0, 255))
+ together = torch.cat(together, dim=3)
+ return together
+
+
+def unormalize(tensor, mean, div):
+ for c, (m, d) in enumerate(zip(mean, div)):
+ tensor[:,c,:,:].mul_(d).add_(m)
+ return tensor
+
+
+def flow_to_image(flow):
+ flow = flow.numpy()
+ flow_img = np.array([flowlib.flow_to_image(fl.transpose((1,2,0))).transpose((2,0,1)) for fl in flow]).astype(np.float32)
+ return torch.from_numpy(flow_img)
+
+def shift_tensor(input, offh, offw):
+ new = torch.zeros(input.size())
+ h = input.size(2)
+ w = input.size(3)
+ new[:,:,max(0,offh):min(h,h+offh),max(0,offw):min(w,w+offw)] = input[:,:,max(0,-offh):min(h,h-offh),max(0,-offw):min(w,w-offw)]
+ return new
+
+def draw_block(mask, radius=5):
+ '''
+ input: tensor (NxCxHxW)
+ output: block_mask (Nx1xHxW)
+ '''
+ all_mask = []
+ mask = mask[:,0:1,:,:]
+ for offh in range(-radius, radius+1):
+ for offw in range(-radius, radius+1):
+ all_mask.append(shift_tensor(mask, offh, offw))
+ block_mask = sum(all_mask)
+ block_mask[block_mask > 0] = 1
+ return block_mask
+
+def expand_block(sparse, radius=5):
+ '''
+ input: sparse (NxCxHxW)
+ output: block_sparse (NxCxHxW)
+ '''
+ all_sparse = []
+ for offh in range(-radius, radius+1):
+ for offw in range(-radius, radius+1):
+ all_sparse.append(shift_tensor(sparse, offh, offw))
+ block_sparse = sum(all_sparse)
+ return block_sparse
+
+def draw_cross(tensor, mask, radius=5, thickness=2):
+ '''
+ input: tensor (NxCxHxW)
+ mask (NxXxHxW)
+ output: new_tensor (NxCxHxW)
+ '''
+ all_mask = []
+ mask = mask[:,0:1,:,:]
+ for off in range(-radius, radius+1):
+ for t in range(-thickness, thickness+1):
+ all_mask.append(shift_tensor(mask, off, t))
+ all_mask.append(shift_tensor(mask, t, off))
+ cross_mask = sum(all_mask)
+ new_tensor = tensor.clone()
+ new_tensor[:,0:1,:,:][cross_mask > 0] = 255.0
+ new_tensor[:,1:2,:,:][cross_mask > 0] = 0.0
+ new_tensor[:,2:3,:,:][cross_mask > 0] = 0.0
+ return new_tensor
diff --git a/models/controlnet_sdv.py b/models/controlnet_sdv.py
new file mode 100644
index 0000000000000000000000000000000000000000..d45f1597955446b5e8e6e92ac0346a94a56828f4
--- /dev/null
+++ b/models/controlnet_sdv.py
@@ -0,0 +1,782 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalControlnetMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unet_3d_blocks import (
+ get_down_block, get_up_block,UNetMidBlockSpatioTemporal,
+)
+from diffusers.models import UNetSpatioTemporalConditionModel
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+ """
+ The output of [`ControlNetModel`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_down_block_re_sample (`torch.Tensor`):
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class ControlNetConditioningEmbeddingSVD(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ #this seeems appropriate? idk if i should be applying a more complex setup to handle the frames
+ #combine batch and frames dimensions
+ batch_size, frames, channels, height, width = conditioning.size()
+ conditioning = conditioning.view(batch_size * frames, channels, height, width)
+
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ #split them apart again
+ #actually not needed
+ #new_channels, new_height, new_width = embedding.shape[1], embedding.shape[2], embedding.shape[3]
+ #embedding = embedding.view(batch_size, frames, new_channels, new_height, new_width)
+
+
+ return embedding
+
+
+class ControlNetSDVModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
+ r"""
+ A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ addition_time_embed_dim: (`int`, defaults to 256):
+ Dimension to to encode the additional time ids.
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
+ The dimension of the projection of encoded `added_time_ids`.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
+ The number of attention heads.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 8,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ addition_time_embed_dim: int = 256,
+ projection_class_embeddings_input_dim: int = 768,
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+ num_frames: int = 25,
+ conditioning_channels: int = 3,
+ conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+ self.sample_size = sample_size
+
+ print("layers per block is", layers_per_block)
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ padding=1,
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ blocks_time_embed_dim = time_embed_dim
+ self.controlnet_cond_embedding = ControlNetConditioningEmbeddingSVD(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ # down
+ output_channel = block_out_channels[0]
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-5,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block[i]):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+
+ self.mid_block = UNetMidBlockSpatioTemporal(
+ block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ )
+
+
+
+
+ # out
+ #self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
+ #self.conv_act = nn.SiLU()
+
+ #self.conv_out = nn.Conv2d(
+ # block_out_channels[0],
+ # out_channels,
+ # kernel_size=3,
+ # padding=1,
+ #)
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ added_time_ids: torch.Tensor,
+ controlnet_cond: torch.FloatTensor = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ guess_mode: bool = False,
+ conditioning_scale: float = 1.0,
+
+
+ ) -> Union[ControlNetOutput, Tuple]:
+ r"""
+ The [`UNetSpatioTemporalConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
+ added_time_ids: (`torch.FloatTensor`):
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
+ embeddings and added to the time embeddings.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
+ tuple.
+ Returns:
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ # print(t_emb.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ #controlnet cond
+ if controlnet_cond != None:
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ sample = sample + controlnet_cond
+
+
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNetSpatioTemporalConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 3,
+ ):
+ r"""
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
+ where applicable.
+ """
+
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+ print(unet.config)
+ controlnet = cls(
+ in_channels=unet.config.in_channels,
+ down_block_types=unet.config.down_block_types,
+ block_out_channels=unet.config.block_out_channels,
+ addition_time_embed_dim=unet.config.addition_time_embed_dim,
+ transformer_layers_per_block=unet.config.transformer_layers_per_block,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ num_frames=unet.config.num_frames,
+ sample_size=unet.config.sample_size, # Added based on the dict
+ layers_per_block=unet.config.layers_per_block,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ conditioning_channels = conditioning_channels,
+ conditioning_embedding_out_channels = conditioning_embedding_out_channels,
+ )
+ #controlnet rgb channel order ignored, set to not makea difference by default
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ # if controlnet.class_embedding:
+ # controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+ return controlnet
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ # def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ # if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+ # module.gradient_checkpointing = value
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/models/ldmk_ctrlnet.py b/models/ldmk_ctrlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..50239721612b1db928e9f59a45c82364cc40a967
--- /dev/null
+++ b/models/ldmk_ctrlnet.py
@@ -0,0 +1,575 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.configuration_utils import register_to_config
+from diffusers.utils import BaseOutput
+
+from models.controlnet_sdv import ControlNetSDVModel, zero_module
+from models.softsplat import softsplat
+import models.cmp.models as cmp_models
+import models.cmp.utils as cmp_utils
+from models.occlusion.hourglass import ForegroundMatting
+
+import yaml
+import os
+import torchvision.transforms as transforms
+
+
+class ArgObj(object):
+ def __init__(self):
+ pass
+
+
+class CMP_demo(nn.Module):
+ def __init__(self, configfn, load_iter):
+ super().__init__()
+ args = ArgObj()
+ with open(configfn) as f:
+ config = yaml.full_load(f)
+ for k, v in config.items():
+ setattr(args, k, v)
+ setattr(args, 'load_iter', load_iter)
+ setattr(args, 'exp_path', os.path.dirname(configfn))
+
+ self.model = cmp_models.__dict__[args.model['arch']](args.model, dist_model=False)
+ self.model.load_state("{}/checkpoints".format(args.exp_path), args.load_iter, False)
+ self.model.switch_to('eval')
+
+ self.data_mean = args.data['data_mean']
+ self.data_div = args.data['data_div']
+
+ self.img_transform = transforms.Compose([
+ transforms.Normalize(self.data_mean, self.data_div)])
+
+ self.args = args
+ self.fuser = cmp_utils.Fuser(args.model['module']['nbins'], args.model['module']['fmax'])
+ torch.cuda.synchronize()
+
+ def run(self, image, sparse, mask):
+ image = image * 2 - 1
+ cmp_output = self.model.model(image, torch.cat([sparse, mask], dim=1))
+ flow = self.fuser.convert_flow(cmp_output)
+ if flow.shape[2] != image.shape[2]:
+ flow = nn.functional.interpolate(
+ flow, size=image.shape[2:4],
+ mode="bilinear", align_corners=True)
+
+ return flow # [b, 2, h, w]
+
+ # tensor_dict = self.model.eval(ret_loss=False)
+ # flow = tensor_dict['flow_tensors'][0].cpu().numpy().squeeze().transpose(1,2,0)
+
+ # return flow
+
+
+
+class FlowControlNetConditioningEmbeddingSVD(nn.Module):
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
+
+
+
+
+class FlowControlNetFirstFrameEncoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ c_in,
+ c_out,
+ is_downsample=False
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2 if is_downsample else 1)
+
+ def forward(self, feature):
+ '''
+ feature: [b, c, h, w]
+ '''
+
+ embedding = self.conv_in(feature)
+ embedding = F.silu(embedding)
+
+ return embedding
+
+
+
+class FlowControlNetFirstFrameEncoder(nn.Module):
+ def __init__(
+ self,
+ c_in=320,
+ channels=[320, 640, 1280],
+ downsamples=[True, True, True],
+ use_zeroconv=True
+ ):
+ super().__init__()
+
+ self.encoders = nn.ModuleList([])
+ # self.zeroconvs = nn.ModuleList([])
+
+ for channel, downsample in zip(channels, downsamples):
+ self.encoders.append(FlowControlNetFirstFrameEncoderLayer(c_in, channel, is_downsample=downsample))
+ # self.zeroconvs.append(zero_module(nn.Conv2d(channel, channel, kernel_size=1)) if use_zeroconv else nn.Identity())
+ c_in = channel
+
+ def forward(self, first_frame):
+ feature = first_frame
+ deep_features = []
+ # for encoder, zeroconv in zip(self.encoders, self.zeroconvs):
+ for encoder in self.encoders:
+ feature = encoder(feature)
+ # print(feature.shape)
+ # deep_features.append(zeroconv(feature))
+ deep_features.append(feature)
+ return deep_features
+
+
+
+@dataclass
+class FlowControlNetOutput(BaseOutput):
+ """
+ The output of [`FlowControlNetOutput`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_down_block_re_sample (`torch.Tensor`):
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+ controlnet_flow: torch.Tensor
+ occlusion_masks: torch.Tensor
+
+
+class FlowControlNet(ControlNetSDVModel):
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 8,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ addition_time_embed_dim: int = 256,
+ projection_class_embeddings_input_dim: int = 768,
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+ num_frames: int = 25,
+ conditioning_channels: int = 3,
+ conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.flow_encoder = FlowControlNetFirstFrameEncoder()
+
+ # time_embed_dim = block_out_channels[0] * 4
+ # blocks_time_embed_dim = time_embed_dim
+ self.controlnet_cond_embedding = FlowControlNetConditioningEmbeddingSVD(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ self.controlnet_ldmk_embedding = FlowControlNetConditioningEmbeddingSVD(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=(16, 32, 64, 128),
+ conditioning_channels=conditioning_channels,
+ )
+
+ self.zero_outs = nn.ModuleDict(
+ {
+ '8': zero_module(nn.Conv2d(320, 320, kernel_size=1)),
+ '16': zero_module(nn.Conv2d(320, 320, kernel_size=1)),
+ '32': zero_module(nn.Conv2d(640, 640, kernel_size=1)),
+ '64': zero_module(nn.Conv2d(1280, 1280, kernel_size=1))
+ }
+ )
+
+ self.occlusions = nn.ModuleDict(
+ {
+ '8': ForegroundMatting(320),
+ '16': ForegroundMatting(320),
+ '32': ForegroundMatting(640),
+ '64': ForegroundMatting(1280),
+ }
+ )
+
+ # self.occlusions = nn.ModuleDict(
+ # {'8': nn.Sequential(
+ # nn.Conv2d(320+320, 128, 7, 1, 3),
+ # nn.SiLU(),
+ # nn.Conv2d(128, 64, 5, 1, 2),
+ # nn.SiLU(),
+ # nn.Conv2d(64, 1, 3, 1, 1),
+ # nn.Sigmoid()
+ # ),
+ # '16': nn.Sequential(
+ # nn.Conv2d(320+320, 128, 5, 1, 2),
+ # nn.SiLU(),
+ # nn.Conv2d(128, 64, 5, 1, 2),
+ # nn.SiLU(),
+ # nn.Conv2d(64, 1, 3, 1, 1),
+ # nn.Sigmoid()
+ # ),
+ # '32': nn.Sequential(
+ # nn.Conv2d(640+640, 128, 5, 1, 2),
+ # nn.SiLU(),
+ # nn.Conv2d(128, 64, 3, 1, 1),
+ # nn.SiLU(),
+ # nn.Conv2d(64, 1, 3, 1, 1),
+ # nn.Sigmoid()
+ # ),
+ # '64': nn.Sequential(
+ # nn.Conv2d(1280+1280, 128, 3, 1, 1),
+ # nn.SiLU(),
+ # nn.Conv2d(128, 64, 3, 1, 1),
+ # nn.SiLU(),
+ # nn.Conv2d(64, 1, 3, 1, 1),
+ # nn.Sigmoid()
+ # )}
+ # )
+
+ def get_warped_frames(self, first_frame, flows, scale):
+ '''
+ video_frame: [b, c, w, h]
+ flows: [b, t-1, c, w, h]
+ '''
+ dtype = first_frame.dtype
+ warped_frames = []
+ occlusion_masks = []
+ for i in range(flows.shape[1]):
+ warped_frame = softsplat(tenIn=first_frame.float(), tenFlow=flows[:, i].float(), tenMetric=None, strMode='avg').to(dtype) # [b, c, w, h]
+
+ # print(first_frame.shape)
+ # print(warped_frame.shape)
+
+ # occlusion_mask = self.occlusions[str(scale)](torch.cat([first_frame, warped_frame], dim=1)) # [b, 1, w, h]
+ # warped_frame = warped_frame * occlusion_mask
+
+ warped_frame, occlusion_mask = self.occlusions[str(scale)](
+ first_frame, flows[:, i], warped_frame
+ )
+
+ # occlusion_mask = torch.ones_like(warped_frame[:, 0:1, :, :])
+
+ warped_frame = self.zero_outs[str(scale)](warped_frame)
+
+ warped_frames.append(warped_frame.unsqueeze(1)) # [b, 1, c, w, h]
+ occlusion_masks.append(occlusion_mask.unsqueeze(1)) # [b, 1, 1, w, h]
+ warped_frames = torch.cat(warped_frames, dim=1) # [b, t-1, c, w, h]
+ occlusion_masks = torch.cat(occlusion_masks, dim=1) # [b, t-1, 1, w, h]
+ return warped_frames, occlusion_masks
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ added_time_ids: torch.Tensor,
+ controlnet_cond: torch.FloatTensor = None, # [b, 3, h, w]
+ controlnet_flow: torch.FloatTensor = None, # [b, 13, 2, h, w]
+ landmarks: torch.FloatTensor = None, # [b, 14, 2, h, w]
+ # controlnet_mask: torch.FloatTensor = None, # [b, 13, 2, h, w]
+ # pixel_values_384: torch.FloatTensor = None,
+ # sparse_optical_flow_384: torch.FloatTensor = None,
+ # mask_384: torch.FloatTensor = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ guess_mode: bool = False,
+ conditioning_scale: float = 1.0,
+ ) -> Union[FlowControlNetOutput, Tuple]:
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+ # 2. pre-process
+ sample = self.conv_in(sample) # [b*l, 320, h//8, w//8]
+
+ # controlnet cond
+ if controlnet_cond != None:
+ # embed 成 64*64,和latent一个shape
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) # [b, 320, h//8, w//8]
+ # sample = sample + controlnet_cond
+
+ # ldb, ldl, ldc, ldh, ldw = landmarks.shape
+
+ landmarks = landmarks.flatten(0, 1)
+
+ # print(landmarks.shape)
+ # print(sample.shape)
+
+ if landmarks != None:
+ # embed 成 64*64,和latent一个shape
+ landmarks = self.controlnet_ldmk_embedding(landmarks) # [b, 320, h//8, w//8]
+
+ scale_landmarks = {landmarks.shape[-2]: landmarks}
+ for scale in [2, 4]:
+ scaled_ldmk = F.interpolate(landmarks, scale_factor=1/scale)
+ # print(scaled_ldmk.shape)
+ scale_landmarks[scaled_ldmk.shape[-2]] = scaled_ldmk
+
+
+ # assert False
+ controlnet_cond_features = [controlnet_cond] + self.flow_encoder(controlnet_cond) # [4]
+
+ # print(controlnet_cond.shape)
+
+ '''
+ torch.Size([2, 320, 32, 32])
+ torch.Size([2, 320, 16, 16])
+ torch.Size([2, 640, 8, 8])
+ torch.Size([2, 1280, 4, 4])
+ '''
+
+ # for x in controlnet_cond_features:
+ # print(x.shape)
+
+ # assert False
+
+ scales = [8, 16, 32, 64]
+ scale_flows = {}
+ fb, fl, fc, fh, fw = controlnet_flow.shape
+ # print(controlnet_flow.shape)
+ for scale in scales:
+ scaled_flow = F.interpolate(controlnet_flow.reshape(-1, fc, fh, fw), scale_factor=1/scale)
+ scaled_flow = scaled_flow.reshape(fb, fl, fc, fh // scale, fw // scale) / scale
+ scale_flows[scale] = scaled_flow
+
+ # for k in scale_flows.keys():
+ # print(scale_flows[k].shape)
+
+ # assert False
+
+ warped_cond_features = []
+ occlusion_masks = []
+ for cond_feature in controlnet_cond_features:
+ cb, cc, ch, cw = cond_feature.shape
+ # print(cond_feature.shape)
+ warped_cond_feature, occlusion_mask = self.get_warped_frames(cond_feature, scale_flows[fh // ch], fh // ch)
+ warped_cond_feature = torch.cat([cond_feature.unsqueeze(1), warped_cond_feature], dim=1) # [b, c, h, w]
+ wb, wl, wc, wh, ww = warped_cond_feature.shape
+ # print(warped_cond_feature.shape)
+ warped_cond_features.append(warped_cond_feature.reshape(wb * wl, wc, wh, ww))
+ occlusion_masks.append(occlusion_mask)
+
+ # for x in warped_cond_features:
+ # print(x.shape)
+ # assert False
+
+ '''
+ torch.Size([28, 320, 32, 32])
+ torch.Size([28, 320, 16, 16])
+ torch.Size([28, 640, 8, 8])
+ torch.Size([28, 1280, 4, 4])
+ '''
+
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+
+ count = 0
+ length = len(warped_cond_features)
+
+ # print(sample.shape)
+ # print(warped_cond_features[count].shape)
+
+ # add the warped feature in the first scale
+ sample = sample + warped_cond_features[count] + scale_landmarks[sample.shape[-2]]
+ count += 1
+
+ down_block_res_samples = (sample,)
+
+ # print(sample.shape)
+
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # print(sample.shape)
+ # print(warped_cond_features[min(count, length - 1)].shape)
+ # print(sample.shape[-2])
+ # print(scale_landmarks[sample.shape[-2]].shape)
+
+ if sample.shape[1] == 320:
+ sample = sample + warped_cond_features[min(count, length - 1)] + scale_landmarks[sample.shape[-2]]
+ else:
+ sample = sample + warped_cond_features[min(count, length - 1)]
+
+ count += 1
+
+ down_block_res_samples += res_samples
+
+ # print(len(res_samples))
+ # for i in range(len(res_samples)):
+ # print(res_samples[i].shape)
+
+ # [28, 320, 32, 32]
+ # [28, 320, 32, 32]
+ # [28, 320, 16, 16]
+
+ # [28, 640, 16, 16]
+ # [28, 640, 16, 16]
+ # [28, 640, 8, 8]
+
+ # [28, 1280, 8, 8]
+ # [28, 1280, 8, 8]
+ # [28, 1280, 4, 4]
+
+ # [28, 1280, 4, 4]
+ # [28, 1280, 4, 4]
+
+ # print(sample.shape)
+ # print(warped_cond_features[-1].shape)
+
+ # add the warped feature in the last scale
+ sample = sample + warped_cond_features[-1]
+
+ # sample = sample + warped_cond_features[-1] + scale_landmarks[sample.shape[-2]]
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ ) # [b*l, 1280, h // 64, w // 64]
+
+ # print(sample.shape)
+
+ # assert False
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ # for sample in down_block_res_samples:
+ # print(torch.max(sample), torch.min(sample))
+ # print(torch.max(mid_block_res_sample), torch.min(mid_block_res_sample))
+ # assert False
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample, controlnet_flow, occlusion_masks)
+
+ return FlowControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample, controlnet_flow=controlnet_flow, occlusion_masks=occlusion_masks
+ )
+
diff --git a/models/occlusion/hourglass.py b/models/occlusion/hourglass.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0b57090ae30b037d66373fb2394b7fd2791ce99
--- /dev/null
+++ b/models/occlusion/hourglass.py
@@ -0,0 +1,298 @@
+from torch import nn
+from torch import nn
+import torch.nn.functional as F
+import torch
+# from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+
+# class ResBlock2d(nn.Module):
+# def __init__(self, in_features, kernel_size, padding):
+# super(ResBlock2d, self).__init__()
+# self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+# padding=padding)
+# self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+# padding=padding)
+# self.norm1 = BatchNorm2d(in_features)
+# self.norm2 = BatchNorm2d(in_features)
+# self.relu = nn.ReLU()
+# def forward(self, x):
+# out = self.norm1(x)
+# out = self.relu(out)
+# out = self.conv1(out)
+# out = self.norm2(out)
+# out = self.relu(out)
+# out = self.conv2(out)
+# out += x
+# return out
+
+class UpBlock2d(nn.Module):
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(UpBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ # self.norm = BatchNorm2d(out_features)
+ self.relu = nn.ReLU()
+ def forward(self, x):
+ out = x
+ # out = F.interpolate(x, scale_factor=2)
+ out = self.conv(out)
+ # out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+class DownBlock2d(nn.Module):
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ # self.norm = BatchNorm2d(out_features)
+ # self.pool = nn.AvgPool2d(kernel_size=(2, 2))
+ self.relu = nn.ReLU()
+ def forward(self, x):
+ out = self.conv(x)
+ # out = self.norm(out)
+ out = self.relu(out)
+ # out = self.pool(out)
+ return out
+
+class SameBlock2d(nn.Module):
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
+ super(SameBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
+ kernel_size=kernel_size, padding=padding, groups=groups)
+ # self.norm = BatchNorm2d(out_features)
+ self.relu = nn.ReLU()
+ def forward(self, x):
+ out = self.conv(x)
+ # out = self.norm(out)
+ out = self.relu(out)
+ return out
+
+class HourglassEncoder(nn.Module):
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(HourglassEncoder, self).__init__()
+ down_blocks = []
+ for i in range(num_blocks):
+ down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ kernel_size=3, padding=1))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ def forward(self, x):
+ outs = [x]
+ for down_block in self.down_blocks:
+ outs.append(down_block(outs[-1]))
+ outs = outs[1:]
+ return outs
+
+class HourglassDecoder(nn.Module):
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(HourglassDecoder, self).__init__()
+ up_blocks = []
+ for i in range(num_blocks)[::-1]:
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
+ out_filters = min(max_features, block_expansion * (2 ** i))
+ up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
+ self.up_blocks = nn.ModuleList(up_blocks)
+ self.out_filters = block_expansion
+ def forward(self, x):
+ new_out = None
+ for up_block in self.up_blocks:
+ out = x.pop()
+ if new_out is not None:
+ out = torch.cat([out, new_out], dim=1)
+ new_out = up_block(out)
+
+ return new_out
+
+class Hourglass(nn.Module):
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Hourglass, self).__init__()
+ self.encoder = HourglassEncoder(block_expansion, in_features, num_blocks, max_features)
+ self.decoder = HourglassDecoder(block_expansion, in_features, num_blocks, max_features)
+ self.out_filters = self.decoder.out_filters
+ def forward(self, x):
+ return self.decoder(self.encoder(x))
+
+# class AntiAliasInterpolation2d(nn.Module):
+# """
+# Band-limited downsampling, for better preservation of the input signal.
+# """
+# def __init__(self, channels, scale):
+# super(AntiAliasInterpolation2d, self).__init__()
+# sigma = (1 / scale - 1) / 2
+# kernel_size = 2 * round(sigma * 4) + 1
+# self.ka = kernel_size // 2
+# self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
+
+# kernel_size = [kernel_size, kernel_size]
+# sigma = [sigma, sigma]
+# # The gaussian kernel is the product of the
+# # gaussian function of each dimension.
+# kernel = 1
+# meshgrids = torch.meshgrid(
+# [
+# torch.arange(size, dtype=torch.float32)
+# for size in kernel_size
+# ]
+# )
+# for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+# mean = (size - 1) / 2
+# kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
+
+# # Make sure sum of values in gaussian kernel equals 1.
+# kernel = kernel / torch.sum(kernel)
+# # Reshape to depthwise convolutional weight
+# kernel = kernel.view(1, 1, *kernel.size())
+# kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+# self.register_buffer('weight', kernel)
+# self.groups = channels
+# self.scale = scale
+
+# def forward(self, input):
+# if self.scale == 1.0:
+# return input
+
+# out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
+# out = F.conv2d(out, weight=self.weight, groups=self.groups)
+# out = F.interpolate(out, scale_factor=(self.scale, self.scale))
+
+# return out
+
+# class Encoder(nn.Module):
+# def __init__(self, num_channels, num_down_blocks=3, block_expansion=64, max_features=512,
+# ):
+# super(Encoder, self).__init__()
+# self.in_conv = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
+# down_blocks = []
+# for i in range(num_down_blocks):
+# in_features = min(max_features, block_expansion * (2 ** i))
+# out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+# down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+# self.down_blocks = nn.Sequential(*down_blocks)
+# def forward(self, image):
+# out = self.in_conv(image)
+# out = self.down_blocks(out)
+# return out
+
+# class Bottleneck(nn.Module):
+# def __init__(self, num_bottleneck_blocks,num_down_blocks=3, block_expansion=64, max_features=512):
+# super(Bottleneck, self).__init__()
+# bottleneck = []
+# in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
+# for i in range(num_bottleneck_blocks):
+# bottleneck.append(ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
+# self.bottleneck = nn.Sequential(*bottleneck)
+# def forward(self, feature_map):
+# out = self.bottleneck(feature_map)
+# return out
+
+class Decoder(nn.Module):
+ def __init__(self,num_channels, num_down_blocks=3, block_expansion=64, max_features=512):
+ super(Decoder, self).__init__()
+ up_blocks = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
+ out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
+ up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.up_blocks = nn.Sequential(*up_blocks)
+ self.out_conv = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
+ self.sigmoid = nn.Sigmoid()
+ def forward(self, feature_map):
+ out = self.up_blocks(feature_map)
+ out = self.out_conv(out)
+ out = self.sigmoid(out)
+ return out
+
+# def warp_image(image, motion_flow):
+# _, h_old, w_old, _ = motion_flow.shape
+# _, _, h, w = image.shape
+# if h_old != h or w_old != w:
+# motion_flow = motion_flow.permute(0, 3, 1, 2)
+# motion_flow = F.interpolate(motion_flow, size=(h, w), mode='bilinear')
+# motion_flow = motion_flow.permute(0, 2, 3, 1)
+# return F.grid_sample(image, motion_flow)
+
+# def make_coordinate_grid(spatial_size, type):
+# h, w = spatial_size
+# x = torch.arange(w).type(type)
+# y = torch.arange(h).type(type)
+# x = (2 * (x / (w - 1)) - 1)
+# y = (2 * (y / (h - 1)) - 1)
+# yy = y.view(-1, 1).repeat(1, w)
+# xx = x.view(1, -1).repeat(h, 1)
+# meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+# return meshed
+
+class ForegroundMatting(nn.Module):
+ def __init__(self, num_channels, num_blocks=3, block_expansion=64, max_features=512):
+ super(ForegroundMatting, self).__init__()
+ # self.down_sample_image = AntiAliasInterpolation2d(num_channels, scale_factor)
+ # self.down_sample_flow = AntiAliasInterpolation2d(2, scale_factor)
+ self.hourglass = Hourglass(
+ block_expansion=block_expansion,
+ in_features=num_channels * 2 + 2,
+ max_features=max_features,
+ num_blocks=num_blocks
+ )
+
+ # self.foreground_mask = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
+
+ self.matting_mask = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
+ self.matting = nn.Conv2d(self.hourglass.out_filters, num_channels, kernel_size=(7, 7), padding=(3, 3))
+
+ # self.scale_factor = scale_factor
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, reference_image, dense_flow, warped_image):
+ '''
+ source_image : b, c, h, w
+ dense_tensor: b, 2, h, w
+ warped_image: b, c, h, w
+ '''
+
+ # res_out = {}
+ # batch, _, h, w = reference_image.shape
+
+ # warped_image = warp_image(reference_image, dense_flow)#warp the image with dense flow
+ # res_out['warped_image'] = warped_image
+
+ hourglass_input = torch.cat([reference_image, dense_flow, warped_image], dim=1)
+ hourglass_out = self.hourglass(hourglass_input)
+
+ # foreground_mask = self.foreground_mask(hourglass_out) # compute foreground mask
+ # foreground_mask = self.sigmoid(foreground_mask).permute(0,2,3,1)
+ # res_out['foreground_mask'] = foreground_mask
+ # grid_flow = make_coordinate_grid((h, w), dense_flow.type())
+ # dense_flow_foreground = dense_flow * foreground_mask + (1-foreground_mask) * grid_flow.unsqueeze(0) ## revise the dense flow
+ # res_out['dense_flow_foreground'] = dense_flow_foreground
+ # res_out['dense_flow_foreground_vis'] = dense_flow * foreground_mask
+
+ matting_mask = self.matting_mask(hourglass_out) # compute matting mask
+ matting_mask = self.sigmoid(matting_mask)
+ # res_out['matting_mask'] = matting_mask
+
+ matting_image = self.matting(hourglass_out) # computing matting image
+ # res_out['matting_image'] = matting_image
+
+ out = warped_image * matting_mask + matting_image * (1 - matting_mask)
+
+ return out, matting_mask
+
+
+
+if __name__ == '__main__':
+
+ device = 'cuda'
+ b, c, h, w = 2, 1280, 40, 40
+
+ m = ForegroundMatting(c).to(device)
+
+ print(m)
+
+
+ reference_image = torch.randn(b, c, h, w).to(device)
+ dense_flow = torch.randn(b, 2, h, w).to(device)
+ warped_image = torch.randn(b, c, h, w).to(device)
+
+ o = m(reference_image, dense_flow, warped_image)
\ No newline at end of file
diff --git a/models/softsplat.py b/models/softsplat.py
new file mode 100644
index 0000000000000000000000000000000000000000..f35ccc21604479940c2c86580c287e73f3dc327d
--- /dev/null
+++ b/models/softsplat.py
@@ -0,0 +1,529 @@
+#!/usr/bin/env python
+
+import collections
+import cupy
+import os
+import re
+import torch
+import typing
+
+
+##########################################################
+
+
+objCudacache = {}
+
+
+def cuda_int32(intIn:int):
+ return cupy.int32(intIn)
+# end
+
+
+def cuda_float32(fltIn:float):
+ return cupy.float32(fltIn)
+# end
+
+
+def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
+ if 'device' not in objCudacache:
+ objCudacache['device'] = torch.cuda.get_device_name()
+ # end
+
+ strKey = strFunction
+
+ for strVariable in objVariables:
+ objValue = objVariables[strVariable]
+
+ strKey += strVariable
+
+ if objValue is None:
+ continue
+
+ elif type(objValue) == int:
+ strKey += str(objValue)
+
+ elif type(objValue) == float:
+ strKey += str(objValue)
+
+ elif type(objValue) == bool:
+ strKey += str(objValue)
+
+ elif type(objValue) == str:
+ strKey += objValue
+
+ elif type(objValue) == torch.Tensor:
+ strKey += str(objValue.dtype)
+ strKey += str(objValue.shape)
+ strKey += str(objValue.stride())
+
+ elif True:
+ print(strVariable, type(objValue))
+ assert(False)
+
+ # end
+ # end
+
+ strKey += objCudacache['device']
+
+ if strKey not in objCudacache:
+ for strVariable in objVariables:
+ objValue = objVariables[strVariable]
+
+ if objValue is None:
+ continue
+
+ elif type(objValue) == int:
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
+
+ elif type(objValue) == float:
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
+
+ elif type(objValue) == bool:
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
+
+ elif type(objValue) == str:
+ strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
+ strKernel = strKernel.replace('{{type}}', 'unsigned char')
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
+ strKernel = strKernel.replace('{{type}}', 'half')
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
+ strKernel = strKernel.replace('{{type}}', 'float')
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
+ strKernel = strKernel.replace('{{type}}', 'double')
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
+ strKernel = strKernel.replace('{{type}}', 'int')
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
+ strKernel = strKernel.replace('{{type}}', 'long')
+
+ elif type(objValue) == torch.Tensor:
+ print(strVariable, objValue.dtype)
+ assert(False)
+
+ elif True:
+ print(strVariable, type(objValue))
+ assert(False)
+
+ # end
+ # end
+
+ while True:
+ objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
+
+ if objMatch is None:
+ break
+ # end
+
+ intArg = int(objMatch.group(2))
+
+ strTensor = objMatch.group(4)
+ intSizes = objVariables[strTensor].size()
+
+ strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
+ # end
+
+ while True:
+ objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel)
+
+ if objMatch is None:
+ break
+ # end
+
+ intStart = objMatch.span()[1]
+ intStop = objMatch.span()[1]
+ intParentheses = 1
+
+ while True:
+ intParentheses += 1 if strKernel[intStop] == '(' else 0
+ intParentheses -= 1 if strKernel[intStop] == ')' else 0
+
+ if intParentheses == 0:
+ break
+ # end
+
+ intStop += 1
+ # end
+
+ intArgs = int(objMatch.group(2))
+ strArgs = strKernel[intStart:intStop].split(',')
+
+ assert(intArgs == len(strArgs) - 1)
+
+ strTensor = strArgs[0]
+ intStrides = objVariables[strTensor].stride()
+
+ strIndex = []
+
+ for intArg in range(intArgs):
+ strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
+ # end
+
+ strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
+ # end
+
+ while True:
+ objMatch = re.search('(VALUE_)([0-4])(\()', strKernel)
+
+ if objMatch is None:
+ break
+ # end
+
+ intStart = objMatch.span()[1]
+ intStop = objMatch.span()[1]
+ intParentheses = 1
+
+ while True:
+ intParentheses += 1 if strKernel[intStop] == '(' else 0
+ intParentheses -= 1 if strKernel[intStop] == ')' else 0
+
+ if intParentheses == 0:
+ break
+ # end
+
+ intStop += 1
+ # end
+
+ intArgs = int(objMatch.group(2))
+ strArgs = strKernel[intStart:intStop].split(',')
+
+ assert(intArgs == len(strArgs) - 1)
+
+ strTensor = strArgs[0]
+ intStrides = objVariables[strTensor].stride()
+
+ strIndex = []
+
+ for intArg in range(intArgs):
+ strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
+ # end
+
+ strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
+ # end
+
+ objCudacache[strKey] = {
+ 'strFunction': strFunction,
+ 'strKernel': strKernel
+ }
+ # end
+
+ return strKey
+# end
+
+
+@cupy.memoize(for_each_device=True)
+def cuda_launch(strKey:str):
+ if 'CUDA_HOME' not in os.environ:
+ os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
+ # end
+
+ return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction'])
+# end
+
+
+##########################################################
+
+
+def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str):
+ assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
+
+ if strMode == 'sum': assert(tenMetric is None)
+ if strMode == 'avg': assert(tenMetric is None)
+ if strMode.split('-')[0] == 'linear': assert(tenMetric is not None)
+ if strMode.split('-')[0] == 'soft': assert(tenMetric is not None)
+
+ if strMode == 'avg':
+ tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1)
+
+ elif strMode.split('-')[0] == 'linear':
+ tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
+
+ elif strMode.split('-')[0] == 'soft':
+ tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)
+
+ # end
+
+ tenOut = softsplat_func.apply(tenIn, tenFlow)
+
+ if strMode.split('-')[0] in ['avg', 'linear', 'soft']:
+ tenNormalize = tenOut[:, -1:, :, :]
+
+ if len(strMode.split('-')) == 1:
+ tenNormalize = tenNormalize + 0.0000001
+
+ elif strMode.split('-')[1] == 'addeps':
+ tenNormalize = tenNormalize + 0.0000001
+
+ elif strMode.split('-')[1] == 'zeroeps':
+ tenNormalize[tenNormalize == 0.0] = 1.0
+
+ elif strMode.split('-')[1] == 'clipeps':
+ tenNormalize = tenNormalize.clip(0.0000001, None)
+
+ # end
+
+ tenOut = tenOut[:, :-1, :, :] / tenNormalize
+ # end
+
+ return tenOut
+# end
+
+
+class softsplat_func(torch.autograd.Function):
+ @staticmethod
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(self, tenIn, tenFlow):
+ tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
+
+ if tenIn.is_cuda == True:
+ cuda_launch(cuda_kernel('softsplat_out', '''
+ extern "C" __global__ void __launch_bounds__(512) softsplat_out(
+ const int n,
+ const {{type}}* __restrict__ tenIn,
+ const {{type}}* __restrict__ tenFlow,
+ {{type}}* __restrict__ tenOut
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
+ const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
+ const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
+ const int intX = ( intIndex ) % SIZE_3(tenOut);
+
+ assert(SIZE_1(tenFlow) == 2);
+
+ {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
+ {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
+
+ if (isfinite(fltX) == false) { return; }
+ if (isfinite(fltY) == false) { return; }
+
+ {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);
+
+ int intNorthwestX = (int) (floor(fltX));
+ int intNorthwestY = (int) (floor(fltY));
+ int intNortheastX = intNorthwestX + 1;
+ int intNortheastY = intNorthwestY;
+ int intSouthwestX = intNorthwestX;
+ int intSouthwestY = intNorthwestY + 1;
+ int intSoutheastX = intNorthwestX + 1;
+ int intSoutheastY = intNorthwestY + 1;
+
+ {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
+ {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
+ {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
+ {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
+
+ if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
+ }
+
+ if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
+ }
+
+ if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
+ }
+
+ if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
+ }
+ } }
+ ''', {
+ 'tenIn': tenIn,
+ 'tenFlow': tenFlow,
+ 'tenOut': tenOut
+ }))(
+ grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
+ stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
+ )
+
+ elif tenIn.is_cuda != True:
+ assert(False)
+
+ # end
+
+ self.save_for_backward(tenIn, tenFlow)
+
+ return tenOut
+ # end
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd
+ def backward(self, tenOutgrad):
+ tenIn, tenFlow = self.saved_tensors
+
+ tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)
+
+ tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None
+ tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None
+
+ if tenIngrad is not None:
+ cuda_launch(cuda_kernel('softsplat_ingrad', '''
+ extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
+ const int n,
+ const {{type}}* __restrict__ tenIn,
+ const {{type}}* __restrict__ tenFlow,
+ const {{type}}* __restrict__ tenOutgrad,
+ {{type}}* __restrict__ tenIngrad,
+ {{type}}* __restrict__ tenFlowgrad
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
+ const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
+ const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
+ const int intX = ( intIndex ) % SIZE_3(tenIngrad);
+
+ assert(SIZE_1(tenFlow) == 2);
+
+ {{type}} fltIngrad = 0.0f;
+
+ {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
+ {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
+
+ if (isfinite(fltX) == false) { return; }
+ if (isfinite(fltY) == false) { return; }
+
+ int intNorthwestX = (int) (floor(fltX));
+ int intNorthwestY = (int) (floor(fltY));
+ int intNortheastX = intNorthwestX + 1;
+ int intNortheastY = intNorthwestY;
+ int intSouthwestX = intNorthwestX;
+ int intSouthwestY = intNorthwestY + 1;
+ int intSoutheastX = intNorthwestX + 1;
+ int intSoutheastY = intNorthwestY + 1;
+
+ {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
+ {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
+ {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
+ {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
+
+ if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
+ }
+
+ if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
+ }
+
+ if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
+ }
+
+ if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
+ }
+
+ tenIngrad[intIndex] = fltIngrad;
+ } }
+ ''', {
+ 'tenIn': tenIn,
+ 'tenFlow': tenFlow,
+ 'tenOutgrad': tenOutgrad,
+ 'tenIngrad': tenIngrad,
+ 'tenFlowgrad': tenFlowgrad
+ }))(
+ grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None],
+ stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
+ )
+ # end
+
+ if tenFlowgrad is not None:
+ cuda_launch(cuda_kernel('softsplat_flowgrad', '''
+ extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
+ const int n,
+ const {{type}}* __restrict__ tenIn,
+ const {{type}}* __restrict__ tenFlow,
+ const {{type}}* __restrict__ tenOutgrad,
+ {{type}}* __restrict__ tenIngrad,
+ {{type}}* __restrict__ tenFlowgrad
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
+ const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
+ const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
+ const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);
+
+ assert(SIZE_1(tenFlow) == 2);
+
+ {{type}} fltFlowgrad = 0.0f;
+
+ {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
+ {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
+
+ if (isfinite(fltX) == false) { return; }
+ if (isfinite(fltY) == false) { return; }
+
+ int intNorthwestX = (int) (floor(fltX));
+ int intNorthwestY = (int) (floor(fltY));
+ int intNortheastX = intNorthwestX + 1;
+ int intNortheastY = intNorthwestY;
+ int intSouthwestX = intNorthwestX;
+ int intSouthwestY = intNorthwestY + 1;
+ int intSoutheastX = intNorthwestX + 1;
+ int intSoutheastY = intNorthwestY + 1;
+
+ {{type}} fltNorthwest = 0.0f;
+ {{type}} fltNortheast = 0.0f;
+ {{type}} fltSouthwest = 0.0f;
+ {{type}} fltSoutheast = 0.0f;
+
+ if (intC == 0) {
+ fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
+ fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
+ fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
+ fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));
+
+ } else if (intC == 1) {
+ fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
+ fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
+ fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
+ fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));
+
+ }
+
+ for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
+ {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);
+
+ if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
+ }
+
+ if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
+ }
+
+ if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
+ }
+
+ if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
+ }
+ }
+
+ tenFlowgrad[intIndex] = fltFlowgrad;
+ } }
+ ''', {
+ 'tenIn': tenIn,
+ 'tenFlow': tenFlow,
+ 'tenOutgrad': tenOutgrad,
+ 'tenIngrad': tenIngrad,
+ 'tenFlowgrad': tenFlowgrad
+ }))(
+ grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()],
+ stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
+ )
+ # end
+
+ return tenIngrad, tenFlowgrad
+ # end
+# end
diff --git a/models/traj_ctrlnet.py b/models/traj_ctrlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9e8305c6a2cf8b8bb85d9251ecd549b947175d1
--- /dev/null
+++ b/models/traj_ctrlnet.py
@@ -0,0 +1,515 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.configuration_utils import register_to_config
+from diffusers.utils import BaseOutput
+
+from models.controlnet_sdv import ControlNetSDVModel, zero_module
+# from unimatch.unimatch.geometry import flow_warp
+from models.softsplat import softsplat
+# from models.hourglass.dense_motion import DenseMotionNetwork
+import models.cmp.models as cmp_models
+import models.cmp.utils as cmp_utils
+
+import yaml
+import os
+import torchvision.transforms as transforms
+
+
+class ArgObj(object):
+ def __init__(self):
+ pass
+
+
+class CMP_demo(nn.Module):
+ def __init__(self, configfn, load_iter):
+ super().__init__()
+ args = ArgObj()
+ with open(configfn) as f:
+ config = yaml.full_load(f)
+ for k, v in config.items():
+ setattr(args, k, v)
+ setattr(args, 'load_iter', load_iter)
+ setattr(args, 'exp_path', os.path.dirname(configfn))
+
+ self.model = cmp_models.__dict__[args.model['arch']](args.model, dist_model=False)
+ self.model.load_state("{}/checkpoints".format(args.exp_path), args.load_iter, False)
+ self.model.switch_to('eval')
+
+ self.data_mean = args.data['data_mean']
+ self.data_div = args.data['data_div']
+
+ self.img_transform = transforms.Compose([
+ transforms.Normalize(self.data_mean, self.data_div)])
+
+ self.args = args
+ self.fuser = cmp_utils.Fuser(args.model['module']['nbins'], args.model['module']['fmax'])
+ torch.cuda.synchronize()
+
+ def run(self, image, sparse, mask):
+ dtype = image.dtype
+ image = image * 2 - 1
+ self.model.set_input(image.float(), torch.cat([sparse, mask], dim=1).float(), None)
+ cmp_output = self.model.model(self.model.image_input, self.model.sparse_input)
+ flow = self.fuser.convert_flow(cmp_output)
+ if flow.shape[2] != self.model.image_input.shape[2]:
+ flow = nn.functional.interpolate(
+ flow, size=self.model.image_input.shape[2:4],
+ mode="bilinear", align_corners=True)
+
+ return flow.to(dtype) # [b, 2, h, w]
+
+ # tensor_dict = self.model.eval(ret_loss=False)
+ # flow = tensor_dict['flow_tensors'][0].cpu().numpy().squeeze().transpose(1,2,0)
+
+ # return flow
+
+
+
+class FlowControlNetConditioningEmbeddingSVD(nn.Module):
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
+
+
+
+
+class FlowControlNetFirstFrameEncoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ c_in,
+ c_out,
+ is_downsample=False
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2 if is_downsample else 1)
+
+ def forward(self, feature):
+ '''
+ feature: [b, c, h, w]
+ '''
+
+ embedding = self.conv_in(feature)
+ embedding = F.silu(embedding)
+
+ return embedding
+
+
+
+class FlowControlNetFirstFrameEncoder(nn.Module):
+ def __init__(
+ self,
+ c_in=320,
+ channels=[320, 640, 1280],
+ downsamples=[True, True, True],
+ use_zeroconv=True
+ ):
+ super().__init__()
+
+ self.encoders = nn.ModuleList([])
+ self.zeroconvs = nn.ModuleList([])
+
+ for channel, downsample in zip(channels, downsamples):
+ self.encoders.append(FlowControlNetFirstFrameEncoderLayer(c_in, channel, is_downsample=downsample))
+ self.zeroconvs.append(zero_module(nn.Conv2d(channel, channel, kernel_size=1)) if use_zeroconv else nn.Identity())
+ c_in = channel
+
+ def forward(self, first_frame):
+ feature = first_frame
+ deep_features = []
+ for encoder, zeroconv in zip(self.encoders, self.zeroconvs):
+ feature = encoder(feature)
+ # print(feature.shape)
+ deep_features.append(zeroconv(feature))
+ return deep_features
+
+
+@dataclass
+class FlowControlNetOutput(BaseOutput):
+ """
+ The output of [`FlowControlNetOutput`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_down_block_re_sample (`torch.Tensor`):
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+ controlnet_flow: torch.Tensor
+ cmp_output: torch.Tensor
+
+
+class FlowControlNet(ControlNetSDVModel):
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 8,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ addition_time_embed_dim: int = 256,
+ projection_class_embeddings_input_dim: int = 768,
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+ num_frames: int = 25,
+ conditioning_channels: int = 3,
+ conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.flow_encoder = FlowControlNetFirstFrameEncoder()
+
+ # time_embed_dim = block_out_channels[0] * 4
+ # blocks_time_embed_dim = time_embed_dim
+ self.controlnet_cond_embedding = FlowControlNetConditioningEmbeddingSVD(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ def get_warped_frames(self, first_frame, flows):
+ '''
+ video_frame: [b, c, w, h]
+ flows: [b, t-1, c, w, h]
+ '''
+ dtype = first_frame.dtype
+ warped_frames = []
+ for i in range(flows.shape[1]):
+ warped_frame = softsplat(tenIn=first_frame.float(), tenFlow=flows[:, i].float(), tenMetric=None, strMode='avg').to(dtype) # [b, c, w, h]
+ warped_frames.append(warped_frame.unsqueeze(1)) # [b, 1, c, w, h]
+ warped_frames = torch.cat(warped_frames, dim=1) # [b, t-1, c, w, h]
+ return warped_frames
+
+ def get_cmp_flow(self, frames, sparse_optical_flow, mask):
+ '''
+ frames: [b, 13, 3, 384, 384] (0, 1) tensor
+ sparse_optical_flow: [b, 13, 2, 384, 384] (-384, 384) tensor
+ mask: [b, 13, 2, 384, 384] {0, 1} tensor
+ '''
+ b, t, c, h, w = frames.shape
+ assert h == 384 and w == 384
+ frames = frames.flatten(0, 1) # [b*13, 3, 256, 256]
+ sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256]
+ mask = mask.flatten(0, 1) # [b*13, 2, 256, 256]
+ cmp_flow, cmp_output = self.run(frames, sparse_optical_flow, mask) # [b*13, 2, 256, 256]
+ # cmp_flow = self.run(frames.float(), sparse_optical_flow.float(), mask.float()) # [b*13, 2, 256, 256]
+ cmp_flow = cmp_flow.reshape(b, t, 2, h, w)
+ return cmp_flow, cmp_output
+ # return cmp_flow.to(dtype=dtype)
+
+ def run(self, image, sparse, mask):
+ image = image * 2 - 1
+ cmp_output = self.cmp_model(image, torch.cat([sparse, mask], dim=1))
+ flow = self.fuser.convert_flow(cmp_output)
+ if flow.shape[2] != image.shape[2]:
+ flow = nn.functional.interpolate(
+ flow, size=image.shape[2:4],
+ mode="bilinear", align_corners=True)
+
+ return flow, cmp_output # [b, 2, h, w]
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ added_time_ids: torch.Tensor,
+ controlnet_cond: torch.FloatTensor = None, # [b, 3, h, w]
+ controlnet_flow: torch.FloatTensor = None, # [b, 13, 2, h, w]
+ # controlnet_mask: torch.FloatTensor = None, # [b, 13, 2, h, w]
+ # pixel_values_384: torch.FloatTensor = None,
+ # sparse_optical_flow_384: torch.FloatTensor = None,
+ # mask_384: torch.FloatTensor = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ guess_mode: bool = False,
+ conditioning_scale: float = 1.0,
+ ) -> Union[FlowControlNetOutput, Tuple]:
+
+
+ # print(sample.shape)
+ # print(controlnet_cond.shape)
+ # print(controlnet_flow.shape)
+
+ # assert False
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+
+
+ # hourglass_output = self.hourglass_forward(
+ # controlnet_cond, controlnet_sparse_flow, controlnet_mask, controlnet_init_flow) # [b, l, 3+2+2, h, w]
+
+ # controlnet_flow = controlnet_init_flow + hourglass_output
+
+ # 2. pre-process
+ sample = self.conv_in(sample) # [b*l, 320, h//8, w//8]
+
+ # print(controlnet_cond.shape)
+
+ # controlnet cond
+ if controlnet_cond != None:
+ # embed 成 64*64,和latent一个shape
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) # [b, 320, h//8, w//8]
+ # sample = sample + controlnet_cond
+
+ # print(controlnet_cond.shape)
+
+ # assert False
+ controlnet_cond_features = [controlnet_cond] + self.flow_encoder(controlnet_cond) # [4]
+
+ # print(controlnet_cond.shape)
+
+ '''
+ torch.Size([2, 320, 32, 32])
+ torch.Size([2, 320, 16, 16])
+ torch.Size([2, 640, 8, 8])
+ torch.Size([2, 1280, 4, 4])
+ '''
+
+ # for x in controlnet_cond_features:
+ # print(x.shape)
+
+ # assert False
+
+ scales = [8, 16, 32, 64]
+ scale_flows = {}
+ fb, fl, fc, fh, fw = controlnet_flow.shape
+ # print(controlnet_flow.shape)
+ for scale in scales:
+ scaled_flow = F.interpolate(controlnet_flow.reshape(-1, fc, fh, fw), scale_factor=1/scale)
+ scaled_flow = scaled_flow.reshape(fb, fl, fc, fh // scale, fw // scale) / scale
+ scale_flows[scale] = scaled_flow
+
+ # for k in scale_flows.keys():
+ # print(scale_flows[k].shape)
+
+ # assert False
+
+ warped_cond_features = []
+ for cond_feature in controlnet_cond_features:
+ cb, cc, ch, cw = cond_feature.shape
+ # print(cond_feature.shape)
+ warped_cond_feature = self.get_warped_frames(cond_feature, scale_flows[fh // ch])
+ warped_cond_feature = torch.cat([cond_feature.unsqueeze(1), warped_cond_feature], dim=1) # [b, c, h, w]
+ wb, wl, wc, wh, ww = warped_cond_feature.shape
+ # print(warped_cond_feature.shape)
+ warped_cond_features.append(warped_cond_feature.reshape(wb * wl, wc, wh, ww))
+
+ # for x in warped_cond_features:
+ # print(x.shape)
+ # assert False
+
+ '''
+ torch.Size([28, 320, 32, 32])
+ torch.Size([28, 320, 16, 16])
+ torch.Size([28, 640, 8, 8])
+ torch.Size([28, 1280, 4, 4])
+ '''
+
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+
+ count = 0
+ length = len(warped_cond_features)
+
+ # print(sample.shape)
+ # print(warped_cond_features[0].shape)
+
+ # add the warped feature in the first scale
+ sample = sample + warped_cond_features[count]
+ count += 1
+
+ down_block_res_samples = (sample,)
+
+ # print(sample.shape)
+
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # print(sample.shape)
+ # print(warped_cond_features[min(count, length - 1)].shape)
+
+ sample = sample + warped_cond_features[min(count, length - 1)]
+ count += 1
+
+ down_block_res_samples += res_samples
+
+ # print(len(res_samples))
+ # for i in range(len(res_samples)):
+ # print(res_samples[i].shape)
+
+ # [28, 320, 32, 32]
+ # [28, 320, 32, 32]
+ # [28, 320, 16, 16]
+
+ # [28, 640, 16, 16]
+ # [28, 640, 16, 16]
+ # [28, 640, 8, 8]
+
+ # [28, 1280, 8, 8]
+ # [28, 1280, 8, 8]
+ # [28, 1280, 4, 4]
+
+ # [28, 1280, 4, 4]
+ # [28, 1280, 4, 4]
+
+ # print(sample.shape)
+ # print(warped_cond_features[-1].shape)
+
+ # add the warped feature in the last scale
+ sample = sample + warped_cond_features[-1]
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ ) # [b*l, 1280, h // 64, w // 64]
+
+ # print(sample.shape)
+
+ # assert False
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ # for sample in down_block_res_samples:
+ # print(sample.shape)
+ # print(mid_block_res_sample.shape)
+ # assert False
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample, controlnet_flow, None)
+
+ return FlowControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample, controlnet_flow=controlnet_flow, cmp_output=None
+ )
+
diff --git a/models/unet_spatio_temporal_condition_controlnet.py b/models/unet_spatio_temporal_condition_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1361eeb83ab634ed298c05d3dbddda2b56376c8b
--- /dev/null
+++ b/models/unet_spatio_temporal_condition_controlnet.py
@@ -0,0 +1,504 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNetSpatioTemporalConditionOutput(BaseOutput):
+ """
+ The output of [`UNetSpatioTemporalConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNetSpatioTemporalConditionControlNetModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ addition_time_embed_dim: (`int`, defaults to 256):
+ Dimension to to encode the additional time ids.
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
+ The dimension of the projection of encoded `added_time_ids`.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
+ The number of attention heads.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 8,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ addition_time_embed_dim: int = 256,
+ projection_class_embeddings_input_dim: int = 768,
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+ num_frames: int = 25,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ padding=1,
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-5,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlockSpatioTemporal(
+ block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=1e-5,
+ resolution_idx=i,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
+ self.conv_act = nn.SiLU()
+
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0],
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ added_time_ids: torch.Tensor=None,
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
+ r"""
+ The [`UNetSpatioTemporalConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
+ added_time_ids: (`torch.FloatTensor`):
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
+ embeddings and added to the time embeddings.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
+ tuple.
+ Returns:
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ down_block_res_samples += res_samples
+
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ sample = sample + mid_block_additional_residual
+
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ # 7. Reshape back to original shape
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
+
+ if not return_dict:
+ return (sample,)
+
+ return UNetSpatioTemporalConditionOutput(sample=sample)
diff --git a/pipeline/pipeline.py b/pipeline/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4e30aab19574d9a94973e27315c7f8ef01cf687
--- /dev/null
+++ b/pipeline/pipeline.py
@@ -0,0 +1,660 @@
+import inspect
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKLTemporalDecoder
+from diffusers.utils import BaseOutput, logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from utils.scheduling_euler_discrete_karras_fix import EulerDiscreteScheduler
+
+from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
+from models.traj_ctrlnet import FlowControlNet as DragControlNet
+from models.ldmk_ctrlnet import FlowControlNet as FaceControlNet
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_add_time_ids(
+ noise_aug_strength,
+ dtype,
+ batch_size,
+ fps=4,
+ motion_bucket_id=128,
+ unet=None,
+ ):
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
+
+ passed_add_embed_dim = unet.config.addition_time_embed_dim * len(add_time_ids)
+ expected_add_embed_dim = unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ # add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
+
+
+ return add_time_ids
+
+
+def _append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def tensor2vid(video: torch.Tensor, processor, output_type="np"):
+ # Based on:
+ # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
+
+ batch_size, channels, num_frames, height, width = video.shape
+ outputs = []
+ for batch_idx in range(batch_size):
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+ batch_output = processor.postprocess(batch_vid, output_type)
+
+ outputs.append(batch_output)
+
+ return outputs
+
+
+@dataclass
+class FlowControlNetPipelineOutput(BaseOutput):
+ r"""
+ Output class for zero-shot text-to-video pipeline.
+
+ Args:
+ frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+ num_channels)`.
+ """
+
+ frames: Union[List[PIL.Image.Image], np.ndarray]
+ controlnet_flow: torch.Tensor
+
+
+class FlowControlNetPipeline(DiffusionPipeline):
+ model_cpu_offload_seq = "image_encoder->unet->vae"
+ _callback_tensor_inputs = ["latents"]
+ def __init__(
+ self,
+ vae: AutoencoderKLTemporalDecoder,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNetSpatioTemporalConditionControlNetModel,
+ drag_controlnet: DragControlNet,
+ face_controlnet: FaceControlNet,
+ scheduler: EulerDiscreteScheduler,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ image_encoder=image_encoder,
+ drag_controlnet=drag_controlnet,
+ face_controlnet=face_controlnet,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+
+ def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.image_processor.pil_to_numpy(image)
+ image = self.image_processor.numpy_to_pt(image)
+
+ #image = image.unsqueeze(0)
+ image = _resize_with_antialiasing(image, (224, 224))
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder(image).image_embeds
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
+
+ return image_embeddings
+
+ def _encode_vae_image(
+ self,
+ image: torch.Tensor,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ ):
+ image = image.to(device=device)
+ image_latents = self.vae.encode(image).latent_dist.mode()
+
+ if do_classifier_free_guidance:
+ negative_image_latents = torch.zeros_like(image_latents)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_latents = torch.cat([negative_image_latents, image_latents])
+
+ # duplicate image_latents for each generation per prompt, using mps friendly method
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
+
+ return image_latents
+
+ def _get_add_time_ids(
+ self,
+ fps,
+ motion_bucket_id,
+ noise_aug_strength,
+ dtype,
+ batch_size,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ ):
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
+
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
+
+ if do_classifier_free_guidance:
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
+
+ return add_time_ids
+
+ def decode_latents(self, latents, num_frames, decode_chunk_size=14):
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
+ latents = latents.flatten(0, 1)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys())
+
+ # decode decode_chunk_size frames at a time to avoid OOM
+ frames = []
+ for i in range(0, latents.shape[0], decode_chunk_size):
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
+ decode_kwargs = {}
+ if accepts_num_frames:
+ # we only pass num_frames_in if it's expected
+ decode_kwargs["num_frames"] = num_frames_in
+
+ frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
+ frames.append(frame)
+ frames = torch.cat(frames, dim=0)
+
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ frames = frames.float()
+ return frames
+
+ def check_inputs(self, image, height, width):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_frames,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_frames,
+ num_channels_latents // 2,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+
+ # print(shape)
+
+ # assert False
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, torch.FloatTensor],
+ controlnet_condition: torch.FloatTensor = None,
+
+ controlnet_flow: torch.FloatTensor = None,
+ landmarks: torch.FloatTensor = None,
+
+ drag_flow: torch.FloatTensor = None,
+ mask: torch.FloatTensor = None,
+
+ height: int = 576,
+ width: int = 1024,
+ num_frames: Optional[int] = None,
+ num_inference_steps: int = 25,
+ min_guidance_scale: float = 1.0,
+ max_guidance_scale: float = 3.0,
+ fps: int = 7,
+ motion_bucket_id: int = 127,
+ noise_aug_strength: int = 0.02,
+ decode_chunk_size: Optional[int] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ return_dict: bool = True,
+ ctrl_scale_traj=1.0,
+ ctrl_scale_ldmk=1.0,
+ batch_size=1,
+ ):
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width)
+
+ # 2. Define call parameters
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = max_guidance_scale > 1.0
+
+ # 3. Encode input image
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
+
+ # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
+ # is why it is reduced here.
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
+ fps = fps - 1
+
+ # 4. Encode input image using VAE
+ image = self.image_processor.preprocess(image, height=height, width=width)
+ noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
+ image = image + noise_aug_strength * noise
+
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float32)
+
+ image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
+ image_latents = image_latents.to(image_embeddings.dtype)
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ # Repeat the image latents for each frame so we can concatenate them with the noise
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+ #image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
+
+ # 5. Get Added Time IDs
+ added_time_ids = self._get_add_time_ids(
+ fps,
+ motion_bucket_id,
+ noise_aug_strength,
+ image_embeddings.dtype,
+ batch_size,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ )
+ added_time_ids = added_time_ids.to(device)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_frames,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+
+ #prepare controlnet condition
+ controlnet_condition = self.image_processor.preprocess(controlnet_condition, height=height, width=width)
+ # controlnet_condition = controlnet_condition.unsqueeze(0)
+ controlnet_condition = torch.cat([controlnet_condition] * 2) if do_classifier_free_guidance else controlnet_condition
+ controlnet_condition = controlnet_condition.to(device, latents.dtype)
+
+ controlnet_flow = torch.cat([controlnet_flow] * 2) if do_classifier_free_guidance else controlnet_flow
+ controlnet_flow = controlnet_flow.to(device, latents.dtype)
+
+ drag_flow = torch.cat([drag_flow] * 2) if do_classifier_free_guidance else drag_flow
+ drag_flow = drag_flow.to(device, latents.dtype)
+
+ # mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ mask = mask.to(device, latents.dtype)
+
+ landmarks = torch.cat([landmarks] * 2) if do_classifier_free_guidance else landmarks
+ landmarks = landmarks.to(device, latents.dtype)
+
+ # 7. Prepare guidance scale
+ # modified num_frames to window_size here !!!!!!!!!!!!!!
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
+ guidance_scale = guidance_scale.to(device, latents.dtype)
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
+
+ self._guidance_scale = guidance_scale
+
+ noise_aug_strength = 0.02 #"¯\_(ツ)_/¯
+ added_time_ids = _get_add_time_ids(
+ noise_aug_strength,
+ image_embeddings.dtype,
+ batch_size,
+ 6,
+ 128,
+ unet=self.unet,
+ )
+ added_time_ids = torch.cat([added_time_ids] * 2)
+ added_time_ids = added_time_ids.to(latents.device)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input_tmp = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input_tmp = self.scheduler.scale_model_input(latent_model_input_tmp, t)
+
+ # Concatenate image_latents over channels dimention
+ latent_model_input_tmp = torch.cat([latent_model_input_tmp, image_latents], dim=2)
+
+ down_res_face_tmp, mid_res_face_tmp, controlnet_flow, _ = self.face_controlnet(
+ latent_model_input_tmp,
+ t,
+ encoder_hidden_states=image_embeddings,
+ controlnet_cond=controlnet_condition,
+ controlnet_flow=controlnet_flow,
+ landmarks=landmarks,
+ added_time_ids=added_time_ids,
+ conditioning_scale=ctrl_scale_ldmk,
+ guess_mode=False,
+ return_dict=False,
+ )
+
+ down_res_drag_tmp, mid_res_drag_tmp, _, _ = self.drag_controlnet(
+ latent_model_input_tmp,
+ t,
+ encoder_hidden_states=image_embeddings,
+ controlnet_cond=controlnet_condition,
+ controlnet_flow=drag_flow,
+ added_time_ids=added_time_ids,
+ conditioning_scale=ctrl_scale_traj,
+ guess_mode=False,
+ return_dict=False,
+ )
+
+ down_block_res_samples_tmp = []
+ for down_face, down_drag in zip(down_res_face_tmp, down_res_drag_tmp):
+ _, _, h, w = down_face.shape
+ mask_tmp = F.interpolate(mask, (h, w), mode='nearest')
+ res = down_face * mask_tmp + down_drag * (1 - mask_tmp)
+ down_block_res_samples_tmp.append(res)
+
+ _, _, h, w = mid_res_face_tmp.shape
+ mask_tmp = F.interpolate(mask, (h, w), mode='nearest')
+ mid_block_res_sample_tmp = mid_res_face_tmp * mask_tmp + mid_res_drag_tmp * (1 - mask_tmp)
+
+ # predict the noise residual
+ noise_pred_tmp = self.unet(
+ latent_model_input_tmp,
+ t,
+ encoder_hidden_states=image_embeddings,
+ down_block_additional_residuals=down_block_res_samples_tmp,
+ mid_block_additional_residual=mid_block_res_sample_tmp,
+ added_time_ids=added_time_ids,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond_tmp, noise_pred_cond_tmp = noise_pred_tmp.chunk(2)
+ noise_pred_tmp = noise_pred_uncond_tmp + self.guidance_scale * (noise_pred_cond_tmp - noise_pred_uncond_tmp)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred_tmp, t, latents).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ frames = self.decode_latents(latents.to(self.vae.dtype), num_frames, decode_chunk_size)
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
+ else:
+ frames = latents
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return frames, controlnet_flow
+
+ return FlowControlNetPipelineOutput(
+ frames=frames,
+ controlnet_flow=controlnet_flow
+ )
+
+
+# resizing utils
+# TODO: clean up later
+def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
+
+ if input.ndim == 3:
+ input = input.unsqueeze(0) # Add a batch dimension
+
+ h, w = input.shape[-2:]
+ factors = (h / size[0], w / size[1])
+
+ # First, we have to determine sigma
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+ sigmas = (
+ max((factors[0] - 1.0) / 2.0, 0.001),
+ max((factors[1] - 1.0) / 2.0, 0.001),
+ )
+
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+ # Make sure it is odd
+ if (ks[0] % 2) == 0:
+ ks = ks[0] + 1, ks[1]
+
+ if (ks[1] % 2) == 0:
+ ks = ks[0], ks[1] + 1
+
+ input = _gaussian_blur2d(input, ks, sigmas)
+
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
+ return output
+
+
+def _compute_padding(kernel_size):
+ """Compute padding tuple."""
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+ if len(kernel_size) < 2:
+ raise AssertionError(kernel_size)
+ computed = [k - 1 for k in kernel_size]
+
+ # for even kernels we need to do asymmetric padding :(
+ out_padding = 2 * len(kernel_size) * [0]
+
+ for i in range(len(kernel_size)):
+ computed_tmp = computed[-(i + 1)]
+
+ pad_front = computed_tmp // 2
+ pad_rear = computed_tmp - pad_front
+
+ out_padding[2 * i + 0] = pad_front
+ out_padding[2 * i + 1] = pad_rear
+
+ return out_padding
+
+
+def _filter2d(input, kernel):
+ # prepare kernel
+ b, c, h, w = input.shape
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
+
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+ height, width = tmp_kernel.shape[-2:]
+
+ padding_shape: list[int] = _compute_padding([height, width])
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+ # kernel and input tensor reshape to align element-wise or batch-wise params
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+ # convolve the tensor with the kernel.
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+ out = output.view(b, c, h, w)
+ return out
+
+
+def _gaussian(window_size: int, sigma):
+ if isinstance(sigma, float):
+ sigma = torch.tensor([[sigma]])
+
+ batch_size = sigma.shape[0]
+
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+ if window_size % 2 == 0:
+ x = x + 0.5
+
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+ return gauss / gauss.sum(-1, keepdim=True)
+
+
+def _gaussian_blur2d(input, kernel_size, sigma):
+ if isinstance(sigma, tuple):
+ sigma = torch.tensor([sigma], dtype=input.dtype)
+ else:
+ sigma = sigma.to(dtype=input.dtype)
+
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
+ bs = sigma.shape[0]
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
+ out_x = _filter2d(input, kernel_x[..., None, :])
+ out = _filter2d(out_x, kernel_y[..., None])
+
+ return out
+
+
+def get_views(video_length, window_size=14, stride=7):
+ num_blocks_time = (video_length - window_size) // stride + 1
+ views = []
+ for i in range(num_blocks_time):
+ t_start = int(i * stride)
+ t_end = t_start + window_size
+ views.append((t_start,t_end))
+ return views
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bde8762942e5c906456b7ac7973574783e762967
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,21 @@
+diffusers==0.24.0
+gradio==4.5.0
+scikit-image
+torch==2.0.1
+torchvision==0.15.2
+einops==0.8.0
+accelerate==0.30.1
+transformers==4.41.1
+colorlog==6.8.2
+cupy-cuda117==10.6.0
+av==12.1.0
+gpustat==1.1.1
+trimesh==4.4.1
+facexlib==0.3.0
+omegaconf==2.3.0
+librosa==0.10.2.post1
+mediapipe==0.10.14
+kornia==0.7.2
+yacs==0.1.8
+gfpgan==1.3.8
+numpy==1.23.0
\ No newline at end of file
diff --git a/run_gradio_audio_driven.py b/run_gradio_audio_driven.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0d5f2b09db6597380e8b2c3a23e5640bae1e51f
--- /dev/null
+++ b/run_gradio_audio_driven.py
@@ -0,0 +1,1240 @@
+import gradio as gr
+import numpy as np
+import cv2
+import os
+from PIL import Image
+from scipy.interpolate import PchipInterpolator
+import torchvision
+import time
+from tqdm import tqdm
+import imageio
+
+import torch
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+from einops import repeat
+
+from pydub import AudioSegment
+
+from packaging import version
+
+from accelerate.utils import set_seed
+from transformers import CLIPVisionModelWithProjection
+
+from diffusers import AutoencoderKLTemporalDecoder
+from diffusers.utils.import_utils import is_xformers_available
+
+from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
+from pipeline.pipeline import FlowControlNetPipeline
+from models.traj_ctrlnet import FlowControlNet as DragControlNet, CMP_demo
+from models.ldmk_ctrlnet import FlowControlNet as FaceControlNet
+
+from utils.flow_viz import flow_to_image
+from utils.utils import split_filename, image2arr, image2pil, ensure_dirname
+
+
+output_dir = "Output_audio_driven"
+
+
+ensure_dirname(output_dir)
+
+
+def draw_landmarks_cv2(image, landmarks):
+ for i, point in enumerate(landmarks):
+ cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 0, 255), -1)
+ return image
+
+
+def sample_optical_flow(A, B, h, w):
+ b, l, k, _ = A.shape
+
+ sparse_optical_flow = torch.zeros((b, l, h, w, 2), dtype=B.dtype, device=B.device)
+ mask = torch.zeros((b, l, h, w), dtype=torch.uint8, device=B.device)
+
+ x_coords = A[..., 0].long()
+ y_coords = A[..., 1].long()
+
+ x_coords = torch.clip(x_coords, 0, h - 1)
+ y_coords = torch.clip(y_coords, 0, w - 1)
+
+ b_idx = torch.arange(b)[:, None, None].repeat(1, l, k)
+ l_idx = torch.arange(l)[None, :, None].repeat(b, 1, k)
+
+ sparse_optical_flow[b_idx, l_idx, x_coords, y_coords] = B
+
+ mask[b_idx, l_idx, x_coords, y_coords] = 1
+
+ mask = mask.unsqueeze(-1).repeat(1, 1, 1, 1, 2)
+
+ return sparse_optical_flow, mask
+
+
+@torch.no_grad()
+def get_sparse_flow(landmarks, h, w, t):
+
+ landmarks = torch.flip(landmarks, dims=[3])
+
+ pose_flow = (landmarks - landmarks[:, 0:1].repeat(1, t, 1, 1))[:, 1:] # 前向光流
+ according_poses = landmarks[:, 0:1].repeat(1, t - 1, 1, 1)
+
+ pose_flow = torch.flip(pose_flow, dims=[3])
+
+ b, t, K, _ = pose_flow.shape
+
+ sparse_optical_flow, mask = sample_optical_flow(according_poses, pose_flow, h, w)
+
+ return sparse_optical_flow.permute(0, 1, 4, 2, 3), mask.permute(0, 1, 4, 2, 3)
+
+
+
+def sample_inputs_face(first_frame, landmarks):
+
+ pc, ph, pw = first_frame.shape
+ landmarks = landmarks.unsqueeze(0)
+
+ pl = landmarks.shape[1]
+
+ sparse_optical_flow, mask = get_sparse_flow(landmarks, ph, pw, pl)
+
+ if ph != 384 or pw != 384:
+
+ first_frame_384 = F.interpolate(first_frame.unsqueeze(0), (384, 384)) # [3, 384, 384]
+
+ landmarks_384 = torch.zeros_like(landmarks)
+ landmarks_384[:, :, :, 0] = landmarks[:, :, :, 0] / pw * 384
+ landmarks_384[:, :, :, 1] = landmarks[:, :, :, 1] / ph * 384
+
+ sparse_optical_flow_384, mask_384 = get_sparse_flow(landmarks_384, 384, 384, pl)
+
+ else:
+ first_frame_384, landmarks_384 = first_frame, landmarks
+ sparse_optical_flow_384, mask_384 = sparse_optical_flow, mask
+
+ controlnet_image = first_frame.unsqueeze(0)
+
+ return controlnet_image, sparse_optical_flow, mask, first_frame_384, sparse_optical_flow_384, mask_384
+
+
+
+PARTS = [
+ ('FACE', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], (10, 200, 10)),
+ ('LEFT_EYE', [43, 44, 45, 46, 47, 48, 43], (180, 200, 10)),
+ ('LEFT_EYEBROW', [23, 24, 25, 26, 27], (180, 220, 10)),
+ ('RIGHT_EYE', [37, 38, 39, 40, 41, 42, 37], (10, 200, 180)),
+ ('RIGHT_EYEBROW', [18, 19, 20, 21, 22], (10, 220, 180)),
+ ('NOSE_UP', [28, 29, 30, 31], (10, 200, 250)),
+ ('NOSE_DOWN', [32, 33, 34, 35, 36], (250, 200, 10)),
+ ('LIPS_OUTER_BOTTOM_LEFT', [55, 56, 57, 58], (10, 180, 20)),
+ ('LIPS_OUTER_BOTTOM_RIGHT', [49, 60, 59, 58], (20, 10, 180)),
+ ('LIPS_INNER_BOTTOM_LEFT', [65, 66, 67], (100, 100, 30)),
+ ('LIPS_INNER_BOTTOM_RIGHT', [61, 68, 67], (100, 150, 50)),
+ ('LIPS_OUTER_TOP_LEFT', [52, 53, 54, 55], (20, 80, 100)),
+ ('LIPS_OUTER_TOP_RIGHT', [52, 51, 50, 49], (80, 100, 20)),
+ ('LIPS_INNER_TOP_LEFT', [63, 64, 65], (120, 100, 200)),
+ ('LIPS_INNER_TOP_RIGHT', [63, 62, 61], (150, 120, 100)),
+]
+
+
+def draw_landmarks(keypoints, h, w):
+
+ image = np.zeros((h, w, 3))
+
+ for name, indices, color in PARTS:
+ indices = np.array(indices) - 1
+ current_part_keypoints = keypoints[indices]
+
+ for i in range(len(indices) - 1):
+ x1, y1 = current_part_keypoints[i]
+ x2, y2 = current_part_keypoints[i + 1]
+ cv2.line(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2)
+
+ return image
+
+
+def divide_points_afterinterpolate(resized_all_points, motion_brush_mask):
+ k = resized_all_points.shape[0]
+ starts = resized_all_points[:, 0] # [K, 2]
+
+ in_masks = []
+ out_masks = []
+
+ for i in range(k):
+ x, y = int(starts[i][1]), int(starts[i][0])
+ if motion_brush_mask[x][y] == 255:
+ in_masks.append(resized_all_points[i])
+ else:
+ out_masks.append(resized_all_points[i])
+
+ in_masks = np.array(in_masks)
+ out_masks = np.array(out_masks)
+
+ return in_masks, out_masks
+
+
+def get_sparseflow_and_mask_forward(
+ resized_all_points,
+ n_steps, H, W,
+ is_backward_flow=False
+ ):
+
+ K = resized_all_points.shape[0]
+
+ starts = resized_all_points[:, 0]
+
+ interpolated_ends = resized_all_points[:, 1:]
+
+ s_flow = np.zeros((K, n_steps, H, W, 2))
+ mask = np.zeros((K, n_steps, H, W))
+
+ for k in range(K):
+ for i in range(n_steps):
+ start, end = starts[k], interpolated_ends[k][i]
+ flow = np.int64(end - start) * (-1 if is_backward_flow is True else 1)
+ s_flow[k][i][int(start[1]), int(start[0])] = flow
+ mask[k][i][int(start[1]), int(start[0])] = 1
+
+ s_flow = np.sum(s_flow, axis=0)
+ mask = np.sum(mask, axis=0)
+
+ return s_flow, mask
+
+
+def init_models(pretrained_model_name_or_path, weight_dtype, device='cuda', enable_xformers_memory_efficient_attention=False, allow_tf32=False):
+
+ drag_ckpt = "./ckpts/mofa/traj_controlnet"
+ face_ckpt = "./ckpts/mofa/ldmk_controlnet"
+
+ print('start loading models...')
+
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ pretrained_model_name_or_path, subfolder="image_encoder", revision=None, variant="fp16"
+ )
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
+ pretrained_model_name_or_path, subfolder="vae", revision=None, variant="fp16")
+ unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="unet",
+ low_cpu_mem_usage=True,
+ variant="fp16",
+ )
+
+ drag_controlnet = DragControlNet.from_pretrained(drag_ckpt)
+ face_controlnet = FaceControlNet.from_pretrained(face_ckpt)
+
+ cmp = CMP_demo(
+ './models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml',
+ 42000
+ ).to(device)
+ cmp.requires_grad_(False)
+
+ # Freeze vae and image_encoder
+ vae.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+ drag_controlnet.requires_grad_(False)
+ face_controlnet.requires_grad_(False)
+
+ # Move image_encoder and vae to gpu and cast to weight_dtype
+ image_encoder.to(device, dtype=weight_dtype)
+ vae.to(device, dtype=weight_dtype)
+ unet.to(device, dtype=weight_dtype)
+ drag_controlnet.to(device, dtype=weight_dtype)
+ face_controlnet.to(device, dtype=weight_dtype)
+
+ if enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ print(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError(
+ "xformers is not available. Make sure it is installed correctly")
+
+ if allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ pipeline = FlowControlNetPipeline.from_pretrained(
+ pretrained_model_name_or_path,
+ unet=unet,
+ face_controlnet=face_controlnet,
+ drag_controlnet=drag_controlnet,
+ image_encoder=image_encoder,
+ vae=vae,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(device)
+
+ print('models loaded.')
+
+ return pipeline, cmp
+
+
+def interpolate_trajectory(points, n_points):
+ x = [point[0] for point in points]
+ y = [point[1] for point in points]
+
+ t = np.linspace(0, 1, len(points))
+
+ fx = PchipInterpolator(t, x)
+ fy = PchipInterpolator(t, y)
+
+ new_t = np.linspace(0, 1, n_points)
+
+ new_x = fx(new_t)
+ new_y = fy(new_t)
+ new_points = list(zip(new_x, new_y))
+
+ return new_points
+
+
+def visualize_drag_v2(background_image_path, splited_tracks, width, height):
+ trajectory_maps = []
+
+ background_image = Image.open(background_image_path).convert('RGBA')
+ background_image = background_image.resize((width, height))
+ w, h = background_image.size
+ transparent_background = np.array(background_image)
+ transparent_background[:, :, -1] = 128
+ transparent_background = Image.fromarray(transparent_background)
+
+ # Create a transparent layer with the same size as the background image
+ transparent_layer = np.zeros((h, w, 4))
+ for splited_track in splited_tracks:
+ if len(splited_track) > 1:
+ splited_track = interpolate_trajectory(splited_track, 16)
+ splited_track = splited_track[:16]
+ for i in range(len(splited_track)-1):
+ start_point = (int(splited_track[i][0]), int(splited_track[i][1]))
+ end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1]))
+ vx = end_point[0] - start_point[0]
+ vy = end_point[1] - start_point[1]
+ arrow_length = np.sqrt(vx**2 + vy**2)
+ if i == len(splited_track)-2:
+ cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length)
+ else:
+ cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2)
+ else:
+ cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 2, (255, 0, 0, 192), -1)
+
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+ trajectory_maps.append(trajectory_map)
+ return trajectory_maps, transparent_layer
+
+
+class Drag:
+ def __init__(self, device, height, width, model_length):
+ self.device = device
+
+ pretrained_model_name_or_path = "/apdcephfs/share_1290939/vg_zoo/huggingface/stable-video-diffusion-img2vid-xt-1-1"
+
+ self.device = 'cuda'
+ self.weight_dtype = torch.float16
+
+ self.pipeline, self.cmp = init_models(
+ pretrained_model_name_or_path,
+ weight_dtype=self.weight_dtype,
+ device=self.device,
+ )
+
+ self.height = height
+ self.width = width
+ self.model_length = model_length
+
+ def get_cmp_flow(self, frames, sparse_optical_flow, mask, brush_mask=None):
+
+ b, t, c, h, w = frames.shape
+ assert h == 384 and w == 384
+ frames = frames.flatten(0, 1) # [b*13, 3, 256, 256]
+ sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256]
+ mask = mask.flatten(0, 1) # [b*13, 2, 256, 256]
+
+ cmp_flow = []
+ for i in range(b*t):
+ tmp_flow = self.cmp.run(frames[i:i+1], sparse_optical_flow[i:i+1], mask[i:i+1]) # [1, 2, 256, 256]
+ cmp_flow.append(tmp_flow)
+ cmp_flow = torch.cat(cmp_flow, dim=0) # [b*13, 2, 256, 256]
+
+ if brush_mask is not None:
+ brush_mask = torch.from_numpy(brush_mask) / 255.
+ brush_mask = brush_mask.to(cmp_flow.device, dtype=cmp_flow.dtype)
+ brush_mask = brush_mask.unsqueeze(0).unsqueeze(0)
+ cmp_flow = cmp_flow * brush_mask
+
+ cmp_flow = cmp_flow.reshape(b, t, 2, h, w)
+
+ return cmp_flow
+
+
+ def get_flow(self, pixel_values_384, sparse_optical_flow_384, mask_384, motion_brush_mask=None):
+
+ fb, fl, fc, _, _ = pixel_values_384.shape
+
+ controlnet_flow = self.get_cmp_flow(
+ pixel_values_384[:, 0:1, :, :, :].repeat(1, fl, 1, 1, 1),
+ sparse_optical_flow_384,
+ mask_384, motion_brush_mask
+ )
+
+ if self.height != 384 or self.width != 384:
+ scales = [self.height / 384, self.width / 384]
+ controlnet_flow = F.interpolate(controlnet_flow.flatten(0, 1), (self.height, self.width), mode='nearest').reshape(fb, fl, 2, self.height, self.width)
+ controlnet_flow[:, :, 0] *= scales[1]
+ controlnet_flow[:, :, 1] *= scales[0]
+
+ return controlnet_flow
+
+ @torch.no_grad()
+ def forward_sample(self, save_root, first_frame_path, audio_path, hint_path, input_drag_384_inmask, input_drag_384_outmask, input_first_frame, input_mask_384_inmask, input_mask_384_outmask, in_mask_flag, out_mask_flag, motion_brush_mask_384=None, ldmk_mask_mask_origin=None, ctrl_scale_traj=1., ctrl_scale_ldmk=1., ldmk_render='sadtalker'):
+
+ seed = 42
+
+ num_frames = self.model_length
+
+ set_seed(seed)
+
+ input_first_frame_384 = F.interpolate(input_first_frame, (384, 384))
+ input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0)
+ input_first_frame_pil = Image.fromarray(np.uint8(input_first_frame[0].cpu().permute(1, 2, 0)*255))
+ height, width = input_first_frame.shape[-2:]
+
+ input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+ input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+
+ input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype)
+ input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype)
+
+ input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype)
+
+ if in_mask_flag:
+ flow_inmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384
+ )
+ else:
+ fb, fl = mask_384_inmask.shape[:2]
+ flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ if out_mask_flag:
+ flow_outmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_outmask, mask_384_outmask
+ )
+ else:
+ fb, fl = mask_384_outmask.shape[:2]
+ flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ inmask_no_zero = (flow_inmask != 0).all(dim=2)
+ inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
+
+ controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask)
+
+ ldmk_controlnet_flow, ldmk_pose_imgs, landmarks, num_frames = self.get_landmarks(save_root, first_frame_path, audio_path, input_first_frame[0], self.model_length, ldmk_render=ldmk_render)
+
+ ldmk_flow_len = ldmk_controlnet_flow.shape[1]
+ drag_flow_len = controlnet_flow.shape[1]
+ repeat_num = ldmk_flow_len // drag_flow_len + 1
+ drag_controlnet_flow = controlnet_flow.repeat(1, repeat_num, 1, 1, 1)
+ drag_controlnet_flow = drag_controlnet_flow[:, :ldmk_flow_len]
+
+ ldmk_mask_mask_origin = ldmk_mask_mask_origin.unsqueeze(0).unsqueeze(0) # [1, 1, h, w]
+
+ val_output = self.pipeline(
+ input_first_frame_pil,
+ input_first_frame_pil,
+
+ ldmk_controlnet_flow,
+ ldmk_pose_imgs,
+
+ drag_controlnet_flow,
+ ldmk_mask_mask_origin,
+
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ decode_chunk_size=8,
+ motion_bucket_id=127,
+ fps=7,
+ noise_aug_strength=0.02,
+ ctrl_scale_traj=ctrl_scale_traj,
+ ctrl_scale_ldmk=ctrl_scale_ldmk,
+ )
+
+ video_frames, estimated_flow = val_output.frames[0], val_output.controlnet_flow
+
+ for i in range(num_frames):
+ img = video_frames[i]
+ video_frames[i] = np.array(img)
+
+ video_frames = np.array(video_frames)
+
+ outputs = self.save_video(ldmk_pose_imgs, first_frame_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow)
+
+ return outputs
+
+ def save_video(self, pose_imgs, image_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow, outputs=dict()):
+
+ pose_img_nps = (pose_imgs[0].permute(0, 2, 3, 1).cpu().numpy()*255).astype(np.uint8)
+
+ cv2_firstframe = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+ cv2_hint = cv2.cvtColor(cv2.imread(hint_path), cv2.COLOR_BGR2RGB)
+
+ viz_landmarks = []
+ for k in tqdm(range(len(landmarks))):
+ im = draw_landmarks_cv2(video_frames[k].copy(), landmarks[k])
+ viz_landmarks.append(im)
+ viz_landmarks = np.stack(viz_landmarks)
+
+ viz_esti_flows = []
+ for i in range(estimated_flow.shape[1]):
+ temp_flow = estimated_flow[0][i].permute(1, 2, 0)
+ viz_esti_flows.append(flow_to_image(temp_flow))
+ viz_esti_flows = [np.uint8(np.ones_like(viz_esti_flows[-1]) * 255)] + viz_esti_flows
+ viz_esti_flows = np.stack(viz_esti_flows) # [t-1, h, w, c]
+
+ viz_drag_flows = []
+ for i in range(drag_controlnet_flow.shape[1]):
+ temp_flow = drag_controlnet_flow[0][i].permute(1, 2, 0)
+ viz_drag_flows.append(flow_to_image(temp_flow))
+ viz_drag_flows = [np.uint8(np.ones_like(viz_drag_flows[-1]) * 255)] + viz_drag_flows
+ viz_drag_flows = np.stack(viz_drag_flows) # [t-1, h, w, c]
+
+ out_nps = []
+ for plen in range(video_frames.shape[0]):
+ out_nps.append(video_frames[plen])
+ out_nps = np.stack(out_nps)
+
+ first_frames = np.stack([cv2_firstframe] * out_nps.shape[0])
+ hints = np.stack([cv2_hint] * out_nps.shape[0])
+
+ total_nps = np.concatenate([
+ first_frames, hints, viz_drag_flows, viz_esti_flows, pose_img_nps, viz_landmarks, out_nps
+ ], axis=2)
+
+ video_frames_tensor = torch.from_numpy(video_frames).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+
+ outputs['logits_imgs'] = video_frames_tensor
+ outputs['traj_flows'] = torch.from_numpy(viz_drag_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+ outputs['ldmk_flows'] = torch.from_numpy(viz_esti_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+ outputs['viz_ldmk'] = torch.from_numpy(pose_img_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+ outputs['out_with_ldmk'] = torch.from_numpy(viz_landmarks).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+ outputs['total'] = torch.from_numpy(total_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+
+ return outputs
+
+ @torch.no_grad()
+ def get_cmp_flow_from_tracking_points(self, tracking_points, motion_brush_mask, first_frame_path):
+
+ original_width, original_height = self.width, self.height
+
+ flow_div = self.model_length
+
+ input_all_points = tracking_points.constructor_args['value']
+
+ if len(input_all_points) == 0 or len(input_all_points[-1]) == 1:
+ return np.uint8(np.ones((original_width, original_height, 3))*255)
+
+ resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
+ resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points]
+
+ new_resized_all_points = []
+ new_resized_all_points_384 = []
+ for tnum in range(len(resized_all_points)):
+ new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div))
+ new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div))
+
+ resized_all_points = np.array(new_resized_all_points)
+ resized_all_points_384 = np.array(new_resized_all_points_384)
+
+ motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST)
+
+ resized_all_points_384_inmask, resized_all_points_384_outmask = \
+ divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384)
+
+ in_mask_flag = False
+ out_mask_flag = False
+
+ if resized_all_points_384_inmask.shape[0] != 0:
+ in_mask_flag = True
+ input_drag_384_inmask, input_mask_384_inmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_inmask,
+ flow_div - 1, 384, 384
+ )
+ else:
+ input_drag_384_inmask, input_mask_384_inmask = \
+ np.zeros((flow_div - 1, 384, 384, 2)), \
+ np.zeros((flow_div - 1, 384, 384))
+
+ if resized_all_points_384_outmask.shape[0] != 0:
+ out_mask_flag = True
+ input_drag_384_outmask, input_mask_384_outmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_outmask,
+ flow_div - 1, 384, 384
+ )
+ else:
+ input_drag_384_outmask, input_mask_384_outmask = \
+ np.zeros((flow_div - 1, 384, 384, 2)), \
+ np.zeros((flow_div - 1, 384, 384))
+
+ input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2]
+ input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w]
+ input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2]
+ input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w]
+
+ first_frames_transform = transforms.Compose([
+ lambda x: Image.fromarray(x),
+ transforms.ToTensor(),
+ ])
+
+ input_first_frame = image2arr(first_frame_path)
+ input_first_frame = repeat(first_frames_transform(input_first_frame), 'c h w -> b c h w', b=1).to(self.device)
+
+ seed = 42
+ num_frames = flow_div
+
+ set_seed(seed)
+
+ input_first_frame_384 = F.interpolate(input_first_frame, (384, 384))
+ input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0)
+
+ input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+ input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+
+ input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype)
+ input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype)
+
+ input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype)
+
+ if in_mask_flag:
+ flow_inmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384
+ )
+ else:
+ fb, fl = mask_384_inmask.shape[:2]
+ flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ if out_mask_flag:
+ flow_outmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_outmask, mask_384_outmask
+ )
+ else:
+ fb, fl = mask_384_outmask.shape[:2]
+ flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ inmask_no_zero = (flow_inmask != 0).all(dim=2)
+ inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
+
+ controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask)
+
+ print(controlnet_flow.shape)
+
+ controlnet_flow = controlnet_flow[0, -1].permute(1, 2, 0)
+ viz_esti_flows = flow_to_image(controlnet_flow) # [h, w, c]
+
+ return viz_esti_flows
+
+ @torch.no_grad()
+ def get_cmp_flow_landmarks(self, frames, sparse_optical_flow, mask):
+
+ dtype = frames.dtype
+ b, t, c, h, w = sparse_optical_flow.shape
+ assert h == 384 and w == 384
+ frames = frames.flatten(0, 1) # [b*13, 3, 256, 256]
+ sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256]
+ mask = mask.flatten(0, 1) # [b*13, 2, 256, 256]
+
+ cmp_flow = []
+ for i in range(b*t):
+ tmp_flow = self.cmp.run(frames[i:i+1].float(), sparse_optical_flow[i:i+1].float(), mask[i:i+1].float()) # [b*13, 2, 256, 256]
+ cmp_flow.append(tmp_flow)
+ cmp_flow = torch.cat(cmp_flow, dim=0)
+ cmp_flow = cmp_flow.reshape(b, t, 2, h, w)
+
+ return cmp_flow.to(dtype=dtype)
+
+ def audio2landmark(self, audio_path, img_path, ldmk_result_dir, ldmk_render=0):
+
+ if ldmk_render == 'sadtalker':
+ return_code = os.system(
+ f'''
+ python sadtalker_audio2pose/inference.py \
+ --preprocess full \
+ --size 256 \
+ --driven_audio {audio_path} \
+ --source_image {img_path} \
+ --result_dir {ldmk_result_dir} \
+ --facerender pirender \
+ --verbose \
+ --face3dvis
+ ''')
+ assert return_code == 0, "Errors in generating landmarks! Please trace back up for detailed error report."
+ elif ldmk_render == 'aniportrait':
+ return_code = os.system(
+ f'''
+ python aniportrait/audio2ldmk.py \
+ --ref_image_path {img_path} \
+ --audio_path {audio_path} \
+ --save_dir {ldmk_result_dir} \
+ '''
+ )
+ assert return_code == 0, "Errors in generating landmarks! Please trace back up for detailed error report."
+ else:
+ assert False
+
+ return os.path.join(ldmk_result_dir, 'landmarks.npy')
+
+
+ def get_landmarks(self, save_root, first_frame_path, audio_path, first_frame, num_frames=25, ldmk_render='sadtalker'):
+
+ ldmk_dir = os.path.join(save_root, 'landmarks')
+ ldmknpy_dir = self.audio2landmark(audio_path, first_frame_path, ldmk_dir, ldmk_render)
+
+ landmarks = np.load(ldmknpy_dir)
+ landmarks = landmarks[:num_frames] # [25, 68, 2]
+ flow_len = landmarks.shape[0]
+
+ ldmk_clip = landmarks.copy()
+
+ assert ldmk_clip.ndim == 3
+
+ ldmk_clip[:, :, 0] = ldmk_clip[:, :, 0] / self.width * 320
+ ldmk_clip[:, :, 1] = ldmk_clip[:, :, 1] / self.height * 320
+
+ pose_imgs = []
+ for i in range(ldmk_clip.shape[0]):
+ pose_img = draw_landmarks(ldmk_clip[i], 320, 320)
+ pose_img = cv2.resize(pose_img, (self.width, self.height), cv2.INTER_NEAREST)
+ pose_imgs.append(pose_img)
+ pose_imgs = np.array(pose_imgs)
+ pose_imgs = torch.from_numpy(pose_imgs).permute(0, 3, 1, 2).float() / 255.
+ pose_imgs = pose_imgs.unsqueeze(0).to(self.weight_dtype).to(self.device)
+
+ landmarks = torch.from_numpy(landmarks).to(self.weight_dtype).to(self.device)
+
+ val_controlnet_image, val_sparse_optical_flow, \
+ val_mask, val_first_frame_384, \
+ val_sparse_optical_flow_384, val_mask_384 = sample_inputs_face(first_frame, landmarks)
+
+ fb, fl, fc, fh, fw = val_sparse_optical_flow.shape
+
+ val_controlnet_flow = self.get_cmp_flow_landmarks(
+ val_first_frame_384.unsqueeze(0).repeat(1, fl, 1, 1, 1),
+ val_sparse_optical_flow_384,
+ val_mask_384
+ )
+
+ if fh != 384 or fw != 384:
+ scales = [fh / 384, fw / 384]
+ val_controlnet_flow = F.interpolate(val_controlnet_flow.flatten(0, 1), (fh, fw), mode='nearest').reshape(fb, fl, 2, fh, fw)
+ val_controlnet_flow[:, :, 0] *= scales[1]
+ val_controlnet_flow[:, :, 1] *= scales[0]
+
+ val_controlnet_image = val_controlnet_image.unsqueeze(0).repeat(1, fl, 1, 1, 1)
+
+ return val_controlnet_flow, pose_imgs, landmarks, flow_len
+
+
+ def run(self, first_frame_path, audio_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render):
+
+
+ timestamp = str(time.time()).split('.')[0]
+ save_name = f"trajscale{ctrl_scale_traj}_ldmkscale{ctrl_scale_ldmk}_{ldmk_render}_ts{timestamp}"
+ save_root = os.path.join(os.path.dirname(audio_path), save_name)
+ os.makedirs(save_root, exist_ok=True)
+
+
+ original_width, original_height = self.width, self.height
+
+ flow_div = self.model_length
+
+ input_all_points = tracking_points.constructor_args['value']
+
+ # print(input_all_points)
+
+ resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
+ resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points]
+
+ new_resized_all_points = []
+ new_resized_all_points_384 = []
+ for tnum in range(len(resized_all_points)):
+ new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div))
+ new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div))
+
+ resized_all_points = np.array(new_resized_all_points)
+ resized_all_points_384 = np.array(new_resized_all_points_384)
+
+ motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST)
+ # ldmk_mask_mask_384 = cv2.resize(ldmk_mask_mask, (384, 384), cv2.INTER_NEAREST)
+
+ # motion_brush_mask = torch.from_numpy(motion_brush_mask) / 255.
+ # motion_brush_mask = motion_brush_mask.to(self.device)
+
+ ldmk_mask_mask = torch.from_numpy(ldmk_mask_mask) / 255.
+ ldmk_mask_mask = ldmk_mask_mask.to(self.device)
+
+ if resized_all_points_384.shape[0] != 0:
+ resized_all_points_384_inmask, resized_all_points_384_outmask = \
+ divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384)
+ else:
+ resized_all_points_384_inmask = np.array([])
+ resized_all_points_384_outmask = np.array([])
+
+ in_mask_flag = False
+ out_mask_flag = False
+
+ if resized_all_points_384_inmask.shape[0] != 0:
+ in_mask_flag = True
+ input_drag_384_inmask, input_mask_384_inmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_inmask,
+ flow_div - 1, 384, 384
+ )
+ else:
+ input_drag_384_inmask, input_mask_384_inmask = \
+ np.zeros((flow_div - 1, 384, 384, 2)), \
+ np.zeros((flow_div - 1, 384, 384))
+
+ if resized_all_points_384_outmask.shape[0] != 0:
+ out_mask_flag = True
+ input_drag_384_outmask, input_mask_384_outmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_outmask,
+ flow_div - 1, 384, 384
+ )
+ else:
+ input_drag_384_outmask, input_mask_384_outmask = \
+ np.zeros((flow_div - 1, 384, 384, 2)), \
+ np.zeros((flow_div - 1, 384, 384))
+
+ input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0) # [1, 13, h, w, 2]
+ input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0) # [1, 13, h, w]
+ input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0) # [1, 13, h, w, 2]
+ input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0) # [1, 13, h, w]
+
+ dir, base, ext = split_filename(first_frame_path)
+ id = base.split('_')[0]
+
+ image_pil = image2pil(first_frame_path)
+ image_pil = image_pil.resize((self.width, self.height), Image.BILINEAR).convert('RGBA')
+
+ visualized_drag, _ = visualize_drag_v2(first_frame_path, resized_all_points, self.width, self.height)
+
+ motion_brush_viz_pil = Image.fromarray(motion_brush_viz.astype(np.uint8)).convert('RGBA')
+ visualized_drag = visualized_drag[0].convert('RGBA')
+ ldmk_mask_viz_pil = Image.fromarray(ldmk_mask_viz.astype(np.uint8)).convert('RGBA')
+
+ drag_input = Image.alpha_composite(image_pil, visualized_drag)
+ motionbrush_ldmkmask = Image.alpha_composite(motion_brush_viz_pil, ldmk_mask_viz_pil)
+
+ visualized_drag_brush_ldmk_mask = Image.alpha_composite(drag_input, motionbrush_ldmkmask)
+
+ first_frames_transform = transforms.Compose([
+ lambda x: Image.fromarray(x),
+ transforms.ToTensor(),
+ ])
+
+ hint_path = os.path.join(save_root, f'hint.png')
+ visualized_drag_brush_ldmk_mask.save(hint_path)
+
+ first_frames = image2arr(first_frame_path)
+ first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=1).to(self.device)
+
+ outputs = self.forward_sample(
+ save_root,
+ first_frame_path,
+ audio_path,
+ hint_path,
+ input_drag_384_inmask.to(self.device),
+ input_drag_384_outmask.to(self.device),
+ first_frames.to(self.device),
+ input_mask_384_inmask.to(self.device),
+ input_mask_384_outmask.to(self.device),
+ in_mask_flag,
+ out_mask_flag,
+ motion_brush_mask_384, ldmk_mask_mask,
+ ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render=ldmk_render)
+
+ traj_flow_tensor = outputs['traj_flows'][0] # [25, 3, h, w]
+ ldmk_flow_tensor = outputs['ldmk_flows'][0] # [25, 3, h, w]
+ viz_ldmk_tensor = outputs['viz_ldmk'][0] # [25, 3, h, w]
+ out_with_ldmk_tensor = outputs['out_with_ldmk'][0] # [25, 3, h, w]
+ output_tensor = outputs['logits_imgs'][0] # [25, 3, h, w]
+ total_tensor = outputs['total'][0] # [25, 3, h, w]
+
+ traj_flows_path = os.path.join(save_root, f'traj_flow.gif')
+ ldmk_flows_path = os.path.join(save_root, f'ldmk_flow.gif')
+ viz_ldmk_path = os.path.join(save_root, f'viz_ldmk.gif')
+ out_with_ldmk_path = os.path.join(save_root, f'output_w_ldmk.gif')
+ outputs_path = os.path.join(save_root, f'output.gif')
+ total_path = os.path.join(save_root, f'total.gif')
+
+ traj_flows_path_mp4 = os.path.join(save_root, f'traj_flow.mp4')
+ ldmk_flows_path_mp4 = os.path.join(save_root, f'ldmk_flow.mp4')
+ viz_ldmk_path_mp4 = os.path.join(save_root, f'viz_ldmk.mp4')
+ out_with_ldmk_path_mp4 = os.path.join(save_root, f'output_w_ldmk.mp4')
+ outputs_path_mp4 = os.path.join(save_root, f'output.mp4')
+ total_path_mp4 = os.path.join(save_root, f'total.mp4')
+
+ # print(output_tensor.shape)
+
+ traj_flow_np = traj_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+ ldmk_flow_np = ldmk_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+ viz_ldmk_np = viz_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+ out_with_ldmk_np = out_with_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+ output_np = output_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+ total_np = total_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+
+ torchvision.io.write_video(
+ traj_flows_path_mp4,
+ traj_flow_np,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+ torchvision.io.write_video(
+ ldmk_flows_path_mp4,
+ ldmk_flow_np,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+ torchvision.io.write_video(
+ viz_ldmk_path_mp4,
+ viz_ldmk_np,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+ torchvision.io.write_video(
+ out_with_ldmk_path_mp4,
+ out_with_ldmk_np,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+ torchvision.io.write_video(
+ outputs_path_mp4,
+ output_np,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+
+ imageio.mimsave(traj_flows_path, np.uint8(traj_flow_np), fps=20, loop=0)
+ imageio.mimsave(ldmk_flows_path, np.uint8(ldmk_flow_np), fps=20, loop=0)
+ imageio.mimsave(viz_ldmk_path, np.uint8(viz_ldmk_np), fps=20, loop=0)
+ imageio.mimsave(out_with_ldmk_path, np.uint8(out_with_ldmk_np), fps=20, loop=0)
+ imageio.mimsave(outputs_path, np.uint8(output_np), fps=20, loop=0)
+
+ torchvision.io.write_video(total_path_mp4, total_np, fps=20, video_codec='h264', options={'crf': '10'})
+ imageio.mimsave(total_path, np.uint8(total_np), fps=20, loop=0)
+
+ return hint_path, traj_flows_path, ldmk_flows_path, viz_ldmk_path, outputs_path, traj_flows_path_mp4, ldmk_flows_path_mp4, viz_ldmk_path_mp4, outputs_path_mp4
+
+
+with gr.Blocks() as demo:
+ gr.Markdown("""MOFA-Video
""")
+
+ gr.Markdown("""Official Gradio Demo for MOFA-Video: Controllable Image Animation via Generative Motion Field Adaptions in Frozen Image-to-Video Diffusion Model.
""")
+
+ gr.Markdown(
+ """
+ 1. Use the "Upload Image" button to upload an image. Avoid dragging the image directly into the window.
+ 2. Proceed to trajectory control:
+ 2.1. Click "Add Trajectory" first, then select points on the "Add Trajectory Here" image. The first click sets the starting point. Click multiple points to create a non-linear trajectory. To add a new trajectory, click "Add Trajectory" again and select points on the image.
+ 2.2. After adding each trajectory, an optical flow image will be displayed automatically in "Temporary Trajectory Flow Visualization". Use it as a reference to adjust the trajectory for desired effects (e.g., area, intensity).
+ 2.3. To delete the latest trajectory, click "Delete Last Trajectory."
+ 2.4. To use the motion brush for restraining the control area of the trajectory, click to add masks on the "Add Motion Brush Here" image. The motion brush restricts the optical flow area derived from the trajectory whose starting point is within the motion brush. The displayed optical flow image will change correspondingly. Adjust the motion brush radius using the "Motion Brush Radius" slider.
+ 2.5. Choose the Control scale for trajectory using the "Control Scale for Trajectory" slider. This determines the control intensity of trajectory. Setting it to 0 means no control (pure generation result of SVD itself), while setting it to 1 results in the strongest control (which will not lead to good results in most cases because of twisting artifacts). A preset value of 0.6 is recommended for most cases.
+ 3. Proceed to landmark control from audio:
+ 3.1. Use the "Upload Audio" button to upload an audio (currently support .wav and .mp3 extensions).
+ 3.2. Click to add masks on the "Add Landmark Mask Here" image. This mask restricts the optical flow area derived from the landmarks, which should usually covers the area of the person's head parts, and, if desired, body parts for more natural body movement instead of being stationary. Adjust the landmark brush radius using the "Landmark Brush Radius" slider.
+ 3.3. Choose the Control scale for landmarks using the "Control Scale for Landmark" slider. This determines the control intensity of landmarks. Different from trajectory controls, a preset value of 1 is recommended for most cases.
+ 3.4. Choose the landmark renderer to generate landmark sequences from the input audio. The landmark generation codes are based on either SadTalker or AniPortrait. We empirically find that SadTalker provides landmarks that follow the audio more precisely in the lips part, while Aniportrait provides more significant lips movement. Note that while pure landmark-based control of MOFA-Video supports long video generation via the periodic sampling strategy, current version of hybrid control only supports short video generation (25 frames), which means that the first 25 frames of the generated landmark sequences are used to obtain the result.
+ 4. Click the "Run" button to animate the image according to the trajectory and the landmark.
+ """
+ )
+
+ target_size = 512 # NOTICE: changing to lower resolution may impair the performance of the model.
+ DragNUWA_net = Drag("cuda:0", target_size, target_size, 25)
+ first_frame_path = gr.State()
+ audio_path = gr.State()
+ tracking_points = gr.State([])
+ motion_brush_points = gr.State([])
+ motion_brush_mask = gr.State()
+ motion_brush_viz = gr.State()
+ ldmk_mask_mask = gr.State()
+ ldmk_mask_viz = gr.State()
+
+ def preprocess_image(image):
+
+ image_pil = image2pil(image.name)
+ raw_w, raw_h = image_pil.size
+
+ max_edge = min(raw_w, raw_h)
+ resize_ratio = target_size / max_edge
+
+ image_pil = image_pil.resize((round(raw_w * resize_ratio), round(raw_h * resize_ratio)), Image.BILINEAR)
+
+ new_w, new_h = image_pil.size
+ crop_w = new_w - (new_w % 64)
+ crop_h = new_h - (new_h % 64)
+
+ image_pil = transforms.CenterCrop((crop_h, crop_w))(image_pil.convert('RGB'))
+
+ DragNUWA_net.width = crop_w
+ DragNUWA_net.height = crop_h
+
+ id = str(time.time()).split('.')[0]
+ os.makedirs(os.path.join(output_dir, str(id)), exist_ok=True)
+
+ first_frame_path = os.path.join(output_dir, str(id), f"input.png")
+ image_pil.save(first_frame_path)
+
+ return first_frame_path, first_frame_path, first_frame_path, first_frame_path, gr.State([]), gr.State([]), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4)), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4))
+
+ def convert_audio_to_wav(input_audio_file, output_wav_file):
+
+ extension = os.path.splitext(os.path.basename(input_audio_file))[-1]
+
+ if extension.lower() == ".mp3":
+ audio = AudioSegment.from_mp3(input_audio_file)
+ elif extension.lower() == ".wav":
+ audio = AudioSegment.from_wav(input_audio_file)
+ elif extension.lower() == ".ogg":
+ audio = AudioSegment.from_ogg(input_audio_file)
+ elif extension.lower() == ".flac":
+ audio = AudioSegment.from_file(input_audio_file, "flac")
+ else:
+ raise ValueError(f"Not supported extension: {extension}")
+
+ audio.export(output_wav_file, format="wav")
+
+ def save_audio(audio, first_frame_path):
+
+ assert first_frame_path is not None, "First upload image, then audio!"
+
+ img_basedir = os.path.dirname(first_frame_path)
+
+ id = str(time.time()).split('.')[0]
+
+ audio_path = os.path.join(img_basedir, f'audio_{str(id)}', 'audio.wav')
+ os.makedirs(os.path.dirname(audio_path), exist_ok=True)
+
+ # os.system(f'cp -r {audio.name} {audio_path}')
+
+ convert_audio_to_wav(audio.name, audio_path)
+
+ return audio_path, audio_path
+
+ def add_drag(tracking_points):
+ if len(tracking_points.constructor_args['value']) != 0 and tracking_points.constructor_args['value'][-1] == []:
+ return tracking_points
+ tracking_points.constructor_args['value'].append([])
+ return tracking_points
+
+ def delete_last_drag(tracking_points, first_frame_path, motion_brush_mask):
+
+ if len(tracking_points.constructor_args['value']) > 0:
+ tracking_points.constructor_args['value'].pop()
+
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+ transparent_layer = np.zeros((h, w, 4))
+ for track in tracking_points.constructor_args['value']:
+ if len(track) > 1:
+ for i in range(len(track)-1):
+ start_point = track[i]
+ end_point = track[i+1]
+ vx = end_point[0] - start_point[0]
+ vy = end_point[1] - start_point[1]
+ arrow_length = np.sqrt(vx**2 + vy**2)
+ if i == len(track)-2:
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
+ else:
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
+ else:
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
+
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+ viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path)
+
+ return tracking_points, trajectory_map, viz_flow
+
+ def add_motion_brushes(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, tracking_points, evt: gr.SelectData):
+
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+
+ motion_points = motion_brush_points.constructor_args['value']
+ motion_points.append(evt.index)
+
+ x, y = evt.index
+
+ cv2.circle(motion_brush_mask, (x, y), radius, 255, -1)
+ cv2.circle(transparent_layer, (x, y), radius, (128, 0, 128, 127), -1)
+
+ transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8))
+ motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil)
+
+ viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path)
+
+ return motion_brush_mask, transparent_layer, motion_map, viz_flow
+
+
+ def add_ldmk_mask(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, evt: gr.SelectData):
+
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+
+ motion_points = motion_brush_points.constructor_args['value']
+ motion_points.append(evt.index)
+
+ x, y = evt.index
+
+ cv2.circle(motion_brush_mask, (x, y), radius, 255, -1)
+ cv2.circle(transparent_layer, (x, y), radius, (0, 0, 255, 127), -1)
+
+ transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8))
+ motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil)
+
+ return motion_brush_mask, transparent_layer, motion_map
+
+
+
+ def add_tracking_points(tracking_points, first_frame_path, motion_brush_mask, evt: gr.SelectData): # SelectData is a subclass of EventData
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+
+ if len(tracking_points.constructor_args['value']) == 0:
+ tracking_points.constructor_args['value'].append([])
+
+ tracking_points.constructor_args['value'][-1].append(evt.index)
+
+ print(tracking_points.constructor_args['value'])
+
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+ transparent_layer = np.zeros((h, w, 4))
+ for track in tracking_points.constructor_args['value']:
+ if len(track) > 1:
+ for i in range(len(track)-1):
+ start_point = track[i]
+ end_point = track[i+1]
+ vx = end_point[0] - start_point[0]
+ vy = end_point[1] - start_point[1]
+ arrow_length = np.sqrt(vx**2 + vy**2)
+ if i == len(track)-2:
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
+ else:
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
+ else:
+ cv2.circle(transparent_layer, tuple(track[0]), 3, (255, 0, 0, 255), -1)
+
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+ viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path)
+
+ return tracking_points, trajectory_map, viz_flow
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
+ audio_upload_button = gr.UploadButton(label="Upload Audio", file_types=["audio"])
+ input_audio = gr.Audio(label="Audio")
+ with gr.Column(scale=3):
+ add_drag_button = gr.Button(value="Add Trajectory")
+ delete_last_drag_button = gr.Button(value="Delete Last Trajectory")
+ run_button = gr.Button(value="Run")
+ with gr.Column(scale=3):
+ motion_brush_radius = gr.Slider(label='Motion Brush Radius',
+ minimum=1,
+ maximum=200,
+ step=1,
+ value=10)
+ ldmk_mask_radius = gr.Slider(label='Landmark Brush Radius',
+ minimum=1,
+ maximum=200,
+ step=1,
+ value=10)
+ with gr.Column(scale=3):
+ ctrl_scale_traj = gr.Slider(label='Control Scale for Trajectory',
+ minimum=0,
+ maximum=1.,
+ step=0.01,
+ value=0.6)
+ ctrl_scale_ldmk = gr.Slider(label='Control Scale for Landmark',
+ minimum=0,
+ maximum=1.,
+ step=0.01,
+ value=1.)
+ ldmk_render = gr.Radio(label='Landmark Renderer',
+ choices=['sadtalker', 'aniportrait'],
+ value='aniportrait')
+
+ with gr.Column(scale=4):
+ input_image = gr.Image(label="Add Trajectory Here",
+ interactive=True)
+ with gr.Column(scale=4):
+ motion_brush_image = gr.Image(label="Add Motion Brush Here",
+ interactive=True)
+ with gr.Column(scale=4):
+ ldmk_mask_image = gr.Image(label="Add Landmark Mask Here",
+ interactive=True)
+
+ with gr.Row():
+ with gr.Column(scale=6):
+ viz_flow = gr.Image(label="Temporary Trajectory Flow Visualization")
+ with gr.Column(scale=6):
+ hint_image = gr.Image(label="Final Hint Image")
+
+ with gr.Row():
+ with gr.Column(scale=6):
+ traj_flows_gif = gr.Image(label="Trajectory Flow GIF")
+ with gr.Column(scale=6):
+ ldmk_flows_gif = gr.Image(label="Landmark Flow GIF")
+ with gr.Row():
+ with gr.Column(scale=6):
+ viz_ldmk_gif = gr.Image(label="Landmark Visualization GIF")
+ with gr.Column(scale=6):
+ outputs_gif = gr.Image(label="Output GIF")
+
+ with gr.Row():
+ with gr.Column(scale=6):
+ traj_flows_mp4 = gr.Video(label="Trajectory Flow MP4")
+ with gr.Column(scale=6):
+ ldmk_flows_mp4 = gr.Video(label="Landmark Flow MP4")
+ with gr.Row():
+ with gr.Column(scale=6):
+ viz_ldmk_mp4 = gr.Video(label="Landmark Visualization MP4")
+ with gr.Column(scale=6):
+ outputs_mp4 = gr.Video(label="Output MP4")
+
+ image_upload_button.upload(preprocess_image, image_upload_button, [input_image, motion_brush_image, ldmk_mask_image, first_frame_path, tracking_points, motion_brush_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz])
+
+ audio_upload_button.upload(save_audio, [audio_upload_button, first_frame_path], [input_audio, audio_path])
+
+ add_drag_button.click(add_drag, tracking_points, tracking_points)
+
+ delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow])
+
+ input_image.select(add_tracking_points, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow])
+
+ motion_brush_image.select(add_motion_brushes, [motion_brush_points, motion_brush_mask, motion_brush_viz, first_frame_path, motion_brush_radius, tracking_points], [motion_brush_mask, motion_brush_viz, motion_brush_image, viz_flow])
+
+ ldmk_mask_image.select(add_ldmk_mask, [motion_brush_points, ldmk_mask_mask, ldmk_mask_viz, first_frame_path, ldmk_mask_radius], [ldmk_mask_mask, ldmk_mask_viz, ldmk_mask_image])
+
+ run_button.click(DragNUWA_net.run, [first_frame_path, audio_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render], [hint_image, traj_flows_gif, ldmk_flows_gif, viz_ldmk_gif, outputs_gif, traj_flows_mp4, ldmk_flows_mp4, viz_ldmk_mp4, outputs_mp4])
+
+ # demo.launch(server_name="0.0.0.0", debug=True, server_port=80)
+ demo.launch(server_name="127.0.0.1", debug=True, server_port=9080)
diff --git a/run_gradio_video_driven.py b/run_gradio_video_driven.py
new file mode 100644
index 0000000000000000000000000000000000000000..db708c711ebb68d7da9afcaf93a70496a842dff0
--- /dev/null
+++ b/run_gradio_video_driven.py
@@ -0,0 +1,1234 @@
+import gradio as gr
+import numpy as np
+import cv2
+import os
+from PIL import Image
+from scipy.interpolate import PchipInterpolator
+import torchvision
+import time
+from tqdm import tqdm
+import imageio
+
+import torch
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+from einops import repeat
+
+from pydub import AudioSegment
+
+from packaging import version
+
+from accelerate.utils import set_seed
+from transformers import CLIPVisionModelWithProjection
+
+from diffusers import AutoencoderKLTemporalDecoder
+from diffusers.utils.import_utils import is_xformers_available
+
+from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
+from pipeline.pipeline import FlowControlNetPipeline
+from models.traj_ctrlnet import FlowControlNet as DragControlNet, CMP_demo
+from models.ldmk_ctrlnet import FlowControlNet as FaceControlNet
+
+from utils.flow_viz import flow_to_image
+from utils.utils import split_filename, image2arr, image2pil, ensure_dirname
+
+
+output_dir = "Output_video_driven"
+
+
+ensure_dirname(output_dir)
+
+
+def draw_landmarks_cv2(image, landmarks):
+ for i, point in enumerate(landmarks):
+ cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 0, 255), -1)
+ return image
+
+
+def sample_optical_flow(A, B, h, w):
+ b, l, k, _ = A.shape
+
+ sparse_optical_flow = torch.zeros((b, l, h, w, 2), dtype=B.dtype, device=B.device)
+ mask = torch.zeros((b, l, h, w), dtype=torch.uint8, device=B.device)
+
+ x_coords = A[..., 0].long()
+ y_coords = A[..., 1].long()
+
+ x_coords = torch.clip(x_coords, 0, h - 1)
+ y_coords = torch.clip(y_coords, 0, w - 1)
+
+ b_idx = torch.arange(b)[:, None, None].repeat(1, l, k)
+ l_idx = torch.arange(l)[None, :, None].repeat(b, 1, k)
+
+ sparse_optical_flow[b_idx, l_idx, x_coords, y_coords] = B
+
+ mask[b_idx, l_idx, x_coords, y_coords] = 1
+
+ mask = mask.unsqueeze(-1).repeat(1, 1, 1, 1, 2)
+
+ return sparse_optical_flow, mask
+
+
+@torch.no_grad()
+def get_sparse_flow(landmarks, h, w, t):
+
+ landmarks = torch.flip(landmarks, dims=[3])
+
+ pose_flow = (landmarks - landmarks[:, 0:1].repeat(1, t, 1, 1))[:, 1:] # 前向光流
+ according_poses = landmarks[:, 0:1].repeat(1, t - 1, 1, 1)
+
+ pose_flow = torch.flip(pose_flow, dims=[3])
+
+ b, t, K, _ = pose_flow.shape
+
+ sparse_optical_flow, mask = sample_optical_flow(according_poses, pose_flow, h, w)
+
+ return sparse_optical_flow.permute(0, 1, 4, 2, 3), mask.permute(0, 1, 4, 2, 3)
+
+
+
+def sample_inputs_face(first_frame, landmarks):
+
+ pc, ph, pw = first_frame.shape
+ landmarks = landmarks.unsqueeze(0)
+
+ pl = landmarks.shape[1]
+
+ sparse_optical_flow, mask = get_sparse_flow(landmarks, ph, pw, pl)
+
+ if ph != 384 or pw != 384:
+
+ first_frame_384 = F.interpolate(first_frame.unsqueeze(0), (384, 384)) # [3, 384, 384]
+
+ landmarks_384 = torch.zeros_like(landmarks)
+ landmarks_384[:, :, :, 0] = landmarks[:, :, :, 0] / pw * 384
+ landmarks_384[:, :, :, 1] = landmarks[:, :, :, 1] / ph * 384
+
+ sparse_optical_flow_384, mask_384 = get_sparse_flow(landmarks_384, 384, 384, pl)
+
+ else:
+ first_frame_384, landmarks_384 = first_frame, landmarks
+ sparse_optical_flow_384, mask_384 = sparse_optical_flow, mask
+
+ controlnet_image = first_frame.unsqueeze(0)
+
+ return controlnet_image, sparse_optical_flow, mask, first_frame_384, sparse_optical_flow_384, mask_384
+
+
+
+PARTS = [
+ ('FACE', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], (10, 200, 10)),
+ ('LEFT_EYE', [43, 44, 45, 46, 47, 48, 43], (180, 200, 10)),
+ ('LEFT_EYEBROW', [23, 24, 25, 26, 27], (180, 220, 10)),
+ ('RIGHT_EYE', [37, 38, 39, 40, 41, 42, 37], (10, 200, 180)),
+ ('RIGHT_EYEBROW', [18, 19, 20, 21, 22], (10, 220, 180)),
+ ('NOSE_UP', [28, 29, 30, 31], (10, 200, 250)),
+ ('NOSE_DOWN', [32, 33, 34, 35, 36], (250, 200, 10)),
+ ('LIPS_OUTER_BOTTOM_LEFT', [55, 56, 57, 58], (10, 180, 20)),
+ ('LIPS_OUTER_BOTTOM_RIGHT', [49, 60, 59, 58], (20, 10, 180)),
+ ('LIPS_INNER_BOTTOM_LEFT', [65, 66, 67], (100, 100, 30)),
+ ('LIPS_INNER_BOTTOM_RIGHT', [61, 68, 67], (100, 150, 50)),
+ ('LIPS_OUTER_TOP_LEFT', [52, 53, 54, 55], (20, 80, 100)),
+ ('LIPS_OUTER_TOP_RIGHT', [52, 51, 50, 49], (80, 100, 20)),
+ ('LIPS_INNER_TOP_LEFT', [63, 64, 65], (120, 100, 200)),
+ ('LIPS_INNER_TOP_RIGHT', [63, 62, 61], (150, 120, 100)),
+]
+
+
+def draw_landmarks(keypoints, h, w):
+
+ image = np.zeros((h, w, 3))
+
+ for name, indices, color in PARTS:
+ indices = np.array(indices) - 1
+ current_part_keypoints = keypoints[indices]
+
+ for i in range(len(indices) - 1):
+ x1, y1 = current_part_keypoints[i]
+ x2, y2 = current_part_keypoints[i + 1]
+ cv2.line(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2)
+
+ return image
+
+
+def divide_points_afterinterpolate(resized_all_points, motion_brush_mask):
+ k = resized_all_points.shape[0]
+ starts = resized_all_points[:, 0] # [K, 2]
+
+ in_masks = []
+ out_masks = []
+
+ for i in range(k):
+ x, y = int(starts[i][1]), int(starts[i][0])
+ if motion_brush_mask[x][y] == 255:
+ in_masks.append(resized_all_points[i])
+ else:
+ out_masks.append(resized_all_points[i])
+
+ in_masks = np.array(in_masks)
+ out_masks = np.array(out_masks)
+
+ return in_masks, out_masks
+
+
+def get_sparseflow_and_mask_forward(
+ resized_all_points,
+ n_steps, H, W,
+ is_backward_flow=False
+ ):
+
+ K = resized_all_points.shape[0]
+
+ starts = resized_all_points[:, 0]
+
+ interpolated_ends = resized_all_points[:, 1:]
+
+ s_flow = np.zeros((K, n_steps, H, W, 2))
+ mask = np.zeros((K, n_steps, H, W))
+
+ for k in range(K):
+ for i in range(n_steps):
+ start, end = starts[k], interpolated_ends[k][i]
+ flow = np.int64(end - start) * (-1 if is_backward_flow is True else 1)
+ s_flow[k][i][int(start[1]), int(start[0])] = flow
+ mask[k][i][int(start[1]), int(start[0])] = 1
+
+ s_flow = np.sum(s_flow, axis=0)
+ mask = np.sum(mask, axis=0)
+
+ return s_flow, mask
+
+
+def init_models(pretrained_model_name_or_path, weight_dtype, device='cuda', enable_xformers_memory_efficient_attention=False, allow_tf32=False):
+
+ drag_ckpt = "./ckpts/mofa/traj_controlnet"
+ face_ckpt = "./ckpts/mofa/ldmk_controlnet"
+
+ print('start loading models...')
+
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ pretrained_model_name_or_path, subfolder="image_encoder", revision=None, variant="fp16"
+ )
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
+ pretrained_model_name_or_path, subfolder="vae", revision=None, variant="fp16")
+ unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="unet",
+ low_cpu_mem_usage=True,
+ variant="fp16",
+ )
+
+ drag_controlnet = DragControlNet.from_pretrained(drag_ckpt)
+ face_controlnet = FaceControlNet.from_pretrained(face_ckpt)
+
+ cmp = CMP_demo(
+ './models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml',
+ 42000
+ ).to(device)
+ cmp.requires_grad_(False)
+
+ # Freeze vae and image_encoder
+ vae.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+ drag_controlnet.requires_grad_(False)
+ face_controlnet.requires_grad_(False)
+
+ # Move image_encoder and vae to gpu and cast to weight_dtype
+ image_encoder.to(device, dtype=weight_dtype)
+ vae.to(device, dtype=weight_dtype)
+ unet.to(device, dtype=weight_dtype)
+ drag_controlnet.to(device, dtype=weight_dtype)
+ face_controlnet.to(device, dtype=weight_dtype)
+
+ if enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ print(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError(
+ "xformers is not available. Make sure it is installed correctly")
+
+ if allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ pipeline = FlowControlNetPipeline.from_pretrained(
+ pretrained_model_name_or_path,
+ unet=unet,
+ face_controlnet=face_controlnet,
+ drag_controlnet=drag_controlnet,
+ image_encoder=image_encoder,
+ vae=vae,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(device)
+
+ print('models loaded.')
+
+ return pipeline, cmp
+
+
+def interpolate_trajectory(points, n_points):
+ x = [point[0] for point in points]
+ y = [point[1] for point in points]
+
+ t = np.linspace(0, 1, len(points))
+
+ fx = PchipInterpolator(t, x)
+ fy = PchipInterpolator(t, y)
+
+ new_t = np.linspace(0, 1, n_points)
+
+ new_x = fx(new_t)
+ new_y = fy(new_t)
+ new_points = list(zip(new_x, new_y))
+
+ return new_points
+
+
+def visualize_drag_v2(background_image_path, splited_tracks, width, height):
+ trajectory_maps = []
+
+ background_image = Image.open(background_image_path).convert('RGBA')
+ background_image = background_image.resize((width, height))
+ w, h = background_image.size
+ transparent_background = np.array(background_image)
+ transparent_background[:, :, -1] = 128
+ transparent_background = Image.fromarray(transparent_background)
+
+ # Create a transparent layer with the same size as the background image
+ transparent_layer = np.zeros((h, w, 4))
+ for splited_track in splited_tracks:
+ if len(splited_track) > 1:
+ splited_track = interpolate_trajectory(splited_track, 16)
+ splited_track = splited_track[:16]
+ for i in range(len(splited_track)-1):
+ start_point = (int(splited_track[i][0]), int(splited_track[i][1]))
+ end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1]))
+ vx = end_point[0] - start_point[0]
+ vy = end_point[1] - start_point[1]
+ arrow_length = np.sqrt(vx**2 + vy**2)
+ if i == len(splited_track)-2:
+ cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length)
+ else:
+ cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2)
+ else:
+ cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 2, (255, 0, 0, 192), -1)
+
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+ trajectory_maps.append(trajectory_map)
+ return trajectory_maps, transparent_layer
+
+
+class Drag:
+ def __init__(self, device, height, width, model_length):
+ self.device = device
+
+ pretrained_model_name_or_path = "./ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1"
+
+ self.device = 'cuda'
+ self.weight_dtype = torch.float16
+
+ self.pipeline, self.cmp = init_models(
+ pretrained_model_name_or_path,
+ weight_dtype=self.weight_dtype,
+ device=self.device,
+ )
+
+ self.height = height
+ self.width = width
+ self.model_length = model_length
+
+ def get_cmp_flow(self, frames, sparse_optical_flow, mask, brush_mask=None):
+
+ b, t, c, h, w = frames.shape
+ assert h == 384 and w == 384
+ frames = frames.flatten(0, 1) # [b*13, 3, 256, 256]
+ sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256]
+ mask = mask.flatten(0, 1) # [b*13, 2, 256, 256]
+
+ cmp_flow = []
+ for i in range(b*t):
+ tmp_flow = self.cmp.run(frames[i:i+1], sparse_optical_flow[i:i+1], mask[i:i+1]) # [1, 2, 256, 256]
+ cmp_flow.append(tmp_flow)
+ cmp_flow = torch.cat(cmp_flow, dim=0) # [b*13, 2, 256, 256]
+
+ if brush_mask is not None:
+ brush_mask = torch.from_numpy(brush_mask) / 255.
+ brush_mask = brush_mask.to(cmp_flow.device, dtype=cmp_flow.dtype)
+ brush_mask = brush_mask.unsqueeze(0).unsqueeze(0)
+ cmp_flow = cmp_flow * brush_mask
+
+ cmp_flow = cmp_flow.reshape(b, t, 2, h, w)
+
+ return cmp_flow
+
+
+ def get_flow(self, pixel_values_384, sparse_optical_flow_384, mask_384, motion_brush_mask=None):
+
+ fb, fl, fc, _, _ = pixel_values_384.shape
+
+ controlnet_flow = self.get_cmp_flow(
+ pixel_values_384[:, 0:1, :, :, :].repeat(1, fl, 1, 1, 1),
+ sparse_optical_flow_384,
+ mask_384, motion_brush_mask
+ )
+
+ if self.height != 384 or self.width != 384:
+ scales = [self.height / 384, self.width / 384]
+ controlnet_flow = F.interpolate(controlnet_flow.flatten(0, 1), (self.height, self.width), mode='nearest').reshape(fb, fl, 2, self.height, self.width)
+ controlnet_flow[:, :, 0] *= scales[1]
+ controlnet_flow[:, :, 1] *= scales[0]
+
+ return controlnet_flow
+
+ @torch.no_grad()
+ def forward_sample(self, save_root, first_frame_path, driven_video_path, hint_path, input_drag_384_inmask, input_drag_384_outmask, input_first_frame, input_mask_384_inmask, input_mask_384_outmask, in_mask_flag, out_mask_flag, motion_brush_mask_384=None, ldmk_mask_mask_origin=None, ctrl_scale_traj=1., ctrl_scale_ldmk=1., ldmk_render='sadtalker'):
+
+ seed = 42
+
+ num_frames = self.model_length
+
+ set_seed(seed)
+
+ input_first_frame_384 = F.interpolate(input_first_frame, (384, 384))
+ input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0)
+ input_first_frame_pil = Image.fromarray(np.uint8(input_first_frame[0].cpu().permute(1, 2, 0)*255))
+ height, width = input_first_frame.shape[-2:]
+
+ input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+ input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+
+ input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype)
+ input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype)
+
+ input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype)
+
+ if in_mask_flag:
+ flow_inmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384
+ )
+ else:
+ fb, fl = mask_384_inmask.shape[:2]
+ flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ if out_mask_flag:
+ flow_outmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_outmask, mask_384_outmask
+ )
+ else:
+ fb, fl = mask_384_outmask.shape[:2]
+ flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ inmask_no_zero = (flow_inmask != 0).all(dim=2)
+ inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
+
+ controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask)
+
+ ldmk_controlnet_flow, ldmk_pose_imgs, landmarks, num_frames = self.get_landmarks(save_root, first_frame_path, driven_video_path, input_first_frame[0], self.model_length, ldmk_render=ldmk_render)
+
+ ldmk_flow_len = ldmk_controlnet_flow.shape[1]
+ drag_flow_len = controlnet_flow.shape[1]
+ repeat_num = ldmk_flow_len // drag_flow_len + 1
+ drag_controlnet_flow = controlnet_flow.repeat(1, repeat_num, 1, 1, 1)
+ drag_controlnet_flow = drag_controlnet_flow[:, :ldmk_flow_len]
+
+ ldmk_mask_mask_origin = ldmk_mask_mask_origin.unsqueeze(0).unsqueeze(0) # [1, 1, h, w]
+
+ val_output = self.pipeline(
+ input_first_frame_pil,
+ input_first_frame_pil,
+
+ ldmk_controlnet_flow,
+ ldmk_pose_imgs,
+
+ drag_controlnet_flow,
+ ldmk_mask_mask_origin,
+
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ decode_chunk_size=8,
+ motion_bucket_id=127,
+ fps=7,
+ noise_aug_strength=0.02,
+ ctrl_scale_traj=ctrl_scale_traj,
+ ctrl_scale_ldmk=ctrl_scale_ldmk,
+ )
+
+ video_frames, estimated_flow = val_output.frames[0], val_output.controlnet_flow
+
+ for i in range(num_frames):
+ img = video_frames[i]
+ video_frames[i] = np.array(img)
+
+ video_frames = np.array(video_frames)
+
+ outputs = self.save_video(ldmk_pose_imgs, first_frame_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow)
+
+ return outputs
+
+ def save_video(self, pose_imgs, image_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow, outputs=dict()):
+
+ pose_img_nps = (pose_imgs[0].permute(0, 2, 3, 1).cpu().numpy()*255).astype(np.uint8)
+
+ cv2_firstframe = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+ cv2_hint = cv2.cvtColor(cv2.imread(hint_path), cv2.COLOR_BGR2RGB)
+
+ viz_landmarks = []
+ for k in tqdm(range(len(landmarks))):
+ im = draw_landmarks_cv2(video_frames[k].copy(), landmarks[k])
+ viz_landmarks.append(im)
+ viz_landmarks = np.stack(viz_landmarks)
+
+ viz_esti_flows = []
+ for i in range(estimated_flow.shape[1]):
+ temp_flow = estimated_flow[0][i].permute(1, 2, 0)
+ viz_esti_flows.append(flow_to_image(temp_flow))
+ viz_esti_flows = [np.uint8(np.ones_like(viz_esti_flows[-1]) * 255)] + viz_esti_flows
+ viz_esti_flows = np.stack(viz_esti_flows) # [t-1, h, w, c]
+
+ viz_drag_flows = []
+ for i in range(drag_controlnet_flow.shape[1]):
+ temp_flow = drag_controlnet_flow[0][i].permute(1, 2, 0)
+ viz_drag_flows.append(flow_to_image(temp_flow))
+ viz_drag_flows = [np.uint8(np.ones_like(viz_drag_flows[-1]) * 255)] + viz_drag_flows
+ viz_drag_flows = np.stack(viz_drag_flows) # [t-1, h, w, c]
+
+ out_nps = []
+ for plen in range(video_frames.shape[0]):
+ out_nps.append(video_frames[plen])
+ out_nps = np.stack(out_nps)
+
+ first_frames = np.stack([cv2_firstframe] * out_nps.shape[0])
+ hints = np.stack([cv2_hint] * out_nps.shape[0])
+
+ total_nps = np.concatenate([
+ first_frames, hints, viz_drag_flows, viz_esti_flows, pose_img_nps, viz_landmarks, out_nps
+ ], axis=2)
+
+ video_frames_tensor = torch.from_numpy(video_frames).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+
+ outputs['logits_imgs'] = video_frames_tensor
+ outputs['traj_flows'] = torch.from_numpy(viz_drag_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+ outputs['ldmk_flows'] = torch.from_numpy(viz_esti_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+ outputs['viz_ldmk'] = torch.from_numpy(pose_img_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+ outputs['out_with_ldmk'] = torch.from_numpy(viz_landmarks).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+ outputs['total'] = torch.from_numpy(total_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+
+ return outputs
+
+ @torch.no_grad()
+ def get_cmp_flow_from_tracking_points(self, tracking_points, motion_brush_mask, first_frame_path):
+
+ original_width, original_height = self.width, self.height
+
+ flow_div = self.model_length
+
+ input_all_points = tracking_points.constructor_args['value']
+
+ if len(input_all_points) == 0 or len(input_all_points[-1]) == 1:
+ return np.uint8(np.ones((original_width, original_height, 3))*255)
+
+ resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
+ resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points]
+
+ new_resized_all_points = []
+ new_resized_all_points_384 = []
+ for tnum in range(len(resized_all_points)):
+ new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div))
+ new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div))
+
+ resized_all_points = np.array(new_resized_all_points)
+ resized_all_points_384 = np.array(new_resized_all_points_384)
+
+ motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST)
+
+ resized_all_points_384_inmask, resized_all_points_384_outmask = \
+ divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384)
+
+ in_mask_flag = False
+ out_mask_flag = False
+
+ if resized_all_points_384_inmask.shape[0] != 0:
+ in_mask_flag = True
+ input_drag_384_inmask, input_mask_384_inmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_inmask,
+ flow_div - 1, 384, 384
+ )
+ else:
+ input_drag_384_inmask, input_mask_384_inmask = \
+ np.zeros((flow_div - 1, 384, 384, 2)), \
+ np.zeros((flow_div - 1, 384, 384))
+
+ if resized_all_points_384_outmask.shape[0] != 0:
+ out_mask_flag = True
+ input_drag_384_outmask, input_mask_384_outmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_outmask,
+ flow_div - 1, 384, 384
+ )
+ else:
+ input_drag_384_outmask, input_mask_384_outmask = \
+ np.zeros((flow_div - 1, 384, 384, 2)), \
+ np.zeros((flow_div - 1, 384, 384))
+
+ input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2]
+ input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w]
+ input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2]
+ input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w]
+
+ first_frames_transform = transforms.Compose([
+ lambda x: Image.fromarray(x),
+ transforms.ToTensor(),
+ ])
+
+ input_first_frame = image2arr(first_frame_path)
+ input_first_frame = repeat(first_frames_transform(input_first_frame), 'c h w -> b c h w', b=1).to(self.device)
+
+ seed = 42
+ num_frames = flow_div
+
+ set_seed(seed)
+
+ input_first_frame_384 = F.interpolate(input_first_frame, (384, 384))
+ input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0)
+
+ input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+ input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+
+ input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype)
+ input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype)
+
+ input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype)
+
+ if in_mask_flag:
+ flow_inmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384
+ )
+ else:
+ fb, fl = mask_384_inmask.shape[:2]
+ flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ if out_mask_flag:
+ flow_outmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_outmask, mask_384_outmask
+ )
+ else:
+ fb, fl = mask_384_outmask.shape[:2]
+ flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ inmask_no_zero = (flow_inmask != 0).all(dim=2)
+ inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
+
+ controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask)
+
+ print(controlnet_flow.shape)
+
+ controlnet_flow = controlnet_flow[0, -1].permute(1, 2, 0)
+ viz_esti_flows = flow_to_image(controlnet_flow) # [h, w, c]
+
+ return viz_esti_flows
+
+ @torch.no_grad()
+ def get_cmp_flow_landmarks(self, frames, sparse_optical_flow, mask):
+
+ dtype = frames.dtype
+ b, t, c, h, w = sparse_optical_flow.shape
+ assert h == 384 and w == 384
+ frames = frames.flatten(0, 1) # [b*13, 3, 256, 256]
+ sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256]
+ mask = mask.flatten(0, 1) # [b*13, 2, 256, 256]
+
+ cmp_flow = []
+ for i in range(b*t):
+ tmp_flow = self.cmp.run(frames[i:i+1].float(), sparse_optical_flow[i:i+1].float(), mask[i:i+1].float()) # [b*13, 2, 256, 256]
+ cmp_flow.append(tmp_flow)
+ cmp_flow = torch.cat(cmp_flow, dim=0)
+ cmp_flow = cmp_flow.reshape(b, t, 2, h, w)
+
+ return cmp_flow.to(dtype=dtype)
+
+ def video2landmark(self, driven_video_path, img_path, ldmk_result_dir, ldmk_render=0):
+
+ if ldmk_render == 'sadtalker':
+ return_code = os.system(
+ f'''
+ python sadtalker_video2pose/inference.py \
+ --preprocess full \
+ --size 256 \
+ --ref_pose {driven_video_path} \
+ --source_image {img_path} \
+ --result_dir {ldmk_result_dir} \
+ --facerender pirender \
+ --verbose \
+ --face3dvis
+ ''')
+ assert return_code == 0, "Errors in generating landmarks! Maybe Sadtalker can not detect the landmark from source video. Please trace back up for detailed error report."
+ else:
+ assert False
+
+ return os.path.join(ldmk_result_dir, 'landmarks.npy')
+
+
+ def get_landmarks(self, save_root, first_frame_path, driven_video_path, first_frame, num_frames=25, ldmk_render='sadtalker'):
+
+ ldmk_dir = os.path.join(save_root, 'landmarks')
+ ldmknpy_dir = self.video2landmark(driven_video_path, first_frame_path, ldmk_dir, ldmk_render)
+
+ landmarks = np.load(ldmknpy_dir)
+ landmarks = landmarks[:num_frames] # [25, 68, 2]
+ flow_len = landmarks.shape[0]
+
+ ldmk_clip = landmarks.copy()
+
+ assert ldmk_clip.ndim == 3
+
+ ldmk_clip[:, :, 0] = ldmk_clip[:, :, 0] / self.width * 320
+ ldmk_clip[:, :, 1] = ldmk_clip[:, :, 1] / self.height * 320
+
+ pose_imgs = []
+ for i in range(ldmk_clip.shape[0]):
+ pose_img = draw_landmarks(ldmk_clip[i], 320, 320)
+ pose_img = cv2.resize(pose_img, (self.width, self.height), cv2.INTER_NEAREST)
+ pose_imgs.append(pose_img)
+ pose_imgs = np.array(pose_imgs)
+ pose_imgs = torch.from_numpy(pose_imgs).permute(0, 3, 1, 2).float() / 255.
+ pose_imgs = pose_imgs.unsqueeze(0).to(self.weight_dtype).to(self.device)
+
+ landmarks = torch.from_numpy(landmarks).to(self.weight_dtype).to(self.device)
+
+ val_controlnet_image, val_sparse_optical_flow, \
+ val_mask, val_first_frame_384, \
+ val_sparse_optical_flow_384, val_mask_384 = sample_inputs_face(first_frame, landmarks)
+
+ fb, fl, fc, fh, fw = val_sparse_optical_flow.shape
+
+ val_controlnet_flow = self.get_cmp_flow_landmarks(
+ val_first_frame_384.unsqueeze(0).repeat(1, fl, 1, 1, 1),
+ val_sparse_optical_flow_384,
+ val_mask_384
+ )
+
+ if fh != 384 or fw != 384:
+ scales = [fh / 384, fw / 384]
+ val_controlnet_flow = F.interpolate(val_controlnet_flow.flatten(0, 1), (fh, fw), mode='nearest').reshape(fb, fl, 2, fh, fw)
+ val_controlnet_flow[:, :, 0] *= scales[1]
+ val_controlnet_flow[:, :, 1] *= scales[0]
+
+ val_controlnet_image = val_controlnet_image.unsqueeze(0).repeat(1, fl, 1, 1, 1)
+
+ return val_controlnet_flow, pose_imgs, landmarks, flow_len
+
+
+ def run(self, first_frame_path, driven_video_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render):
+
+
+ timestamp = str(time.time()).split('.')[0]
+ save_name = f"trajscale{ctrl_scale_traj}_ldmkscale{ctrl_scale_ldmk}_{ldmk_render}_ts{timestamp}"
+ save_root = os.path.join(os.path.dirname(driven_video_path), save_name)
+ os.makedirs(save_root, exist_ok=True)
+
+
+ original_width, original_height = self.width, self.height
+
+ flow_div = self.model_length
+
+ input_all_points = tracking_points.constructor_args['value']
+
+ # print(input_all_points)
+
+ resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
+ resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points]
+
+ new_resized_all_points = []
+ new_resized_all_points_384 = []
+ for tnum in range(len(resized_all_points)):
+ new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div))
+ new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div))
+
+ resized_all_points = np.array(new_resized_all_points)
+ resized_all_points_384 = np.array(new_resized_all_points_384)
+
+ motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST)
+ # ldmk_mask_mask_384 = cv2.resize(ldmk_mask_mask, (384, 384), cv2.INTER_NEAREST)
+
+ # motion_brush_mask = torch.from_numpy(motion_brush_mask) / 255.
+ # motion_brush_mask = motion_brush_mask.to(self.device)
+
+ ldmk_mask_mask = torch.from_numpy(ldmk_mask_mask) / 255.
+ ldmk_mask_mask = ldmk_mask_mask.to(self.device)
+
+ if resized_all_points_384.shape[0] != 0:
+ resized_all_points_384_inmask, resized_all_points_384_outmask = \
+ divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384)
+ else:
+ resized_all_points_384_inmask = np.array([])
+ resized_all_points_384_outmask = np.array([])
+
+ in_mask_flag = False
+ out_mask_flag = False
+
+ if resized_all_points_384_inmask.shape[0] != 0:
+ in_mask_flag = True
+ input_drag_384_inmask, input_mask_384_inmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_inmask,
+ flow_div - 1, 384, 384
+ )
+ else:
+ input_drag_384_inmask, input_mask_384_inmask = \
+ np.zeros((flow_div - 1, 384, 384, 2)), \
+ np.zeros((flow_div - 1, 384, 384))
+
+ if resized_all_points_384_outmask.shape[0] != 0:
+ out_mask_flag = True
+ input_drag_384_outmask, input_mask_384_outmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_outmask,
+ flow_div - 1, 384, 384
+ )
+ else:
+ input_drag_384_outmask, input_mask_384_outmask = \
+ np.zeros((flow_div - 1, 384, 384, 2)), \
+ np.zeros((flow_div - 1, 384, 384))
+
+ input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0) # [1, 13, h, w, 2]
+ input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0) # [1, 13, h, w]
+ input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0) # [1, 13, h, w, 2]
+ input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0) # [1, 13, h, w]
+
+ dir, base, ext = split_filename(first_frame_path)
+ id = base.split('_')[0]
+
+ image_pil = image2pil(first_frame_path)
+ image_pil = image_pil.resize((self.width, self.height), Image.BILINEAR).convert('RGBA')
+
+ visualized_drag, _ = visualize_drag_v2(first_frame_path, resized_all_points, self.width, self.height)
+
+ motion_brush_viz_pil = Image.fromarray(motion_brush_viz.astype(np.uint8)).convert('RGBA')
+ visualized_drag = visualized_drag[0].convert('RGBA')
+ ldmk_mask_viz_pil = Image.fromarray(ldmk_mask_viz.astype(np.uint8)).convert('RGBA')
+
+ drag_input = Image.alpha_composite(image_pil, visualized_drag)
+ motionbrush_ldmkmask = Image.alpha_composite(motion_brush_viz_pil, ldmk_mask_viz_pil)
+
+ visualized_drag_brush_ldmk_mask = Image.alpha_composite(drag_input, motionbrush_ldmkmask)
+
+ first_frames_transform = transforms.Compose([
+ lambda x: Image.fromarray(x),
+ transforms.ToTensor(),
+ ])
+
+ hint_path = os.path.join(save_root, f'hint.png')
+ visualized_drag_brush_ldmk_mask.save(hint_path)
+
+ first_frames = image2arr(first_frame_path)
+ first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=1).to(self.device)
+
+ outputs = self.forward_sample(
+ save_root,
+ first_frame_path,
+ driven_video_path,
+ hint_path,
+ input_drag_384_inmask.to(self.device),
+ input_drag_384_outmask.to(self.device),
+ first_frames.to(self.device),
+ input_mask_384_inmask.to(self.device),
+ input_mask_384_outmask.to(self.device),
+ in_mask_flag,
+ out_mask_flag,
+ motion_brush_mask_384, ldmk_mask_mask,
+ ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render=ldmk_render)
+
+ traj_flow_tensor = outputs['traj_flows'][0] # [25, 3, h, w]
+ ldmk_flow_tensor = outputs['ldmk_flows'][0] # [25, 3, h, w]
+ viz_ldmk_tensor = outputs['viz_ldmk'][0] # [25, 3, h, w]
+ out_with_ldmk_tensor = outputs['out_with_ldmk'][0] # [25, 3, h, w]
+ output_tensor = outputs['logits_imgs'][0] # [25, 3, h, w]
+ total_tensor = outputs['total'][0] # [25, 3, h, w]
+
+ traj_flows_path = os.path.join(save_root, f'traj_flow.gif')
+ ldmk_flows_path = os.path.join(save_root, f'ldmk_flow.gif')
+ viz_ldmk_path = os.path.join(save_root, f'viz_ldmk.gif')
+ out_with_ldmk_path = os.path.join(save_root, f'output_w_ldmk.gif')
+ outputs_path = os.path.join(save_root, f'output.gif')
+ total_path = os.path.join(save_root, f'total.gif')
+
+ traj_flows_path_mp4 = os.path.join(save_root, f'traj_flow.mp4')
+ ldmk_flows_path_mp4 = os.path.join(save_root, f'ldmk_flow.mp4')
+ viz_ldmk_path_mp4 = os.path.join(save_root, f'viz_ldmk.mp4')
+ out_with_ldmk_path_mp4 = os.path.join(save_root, f'output_w_ldmk.mp4')
+ outputs_path_mp4 = os.path.join(save_root, f'output.mp4')
+ total_path_mp4 = os.path.join(save_root, f'total.mp4')
+
+ # print(output_tensor.shape)
+
+ traj_flow_np = traj_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+ ldmk_flow_np = ldmk_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+ viz_ldmk_np = viz_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+ out_with_ldmk_np = out_with_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+ output_np = output_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+ total_np = total_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
+
+ torchvision.io.write_video(
+ traj_flows_path_mp4,
+ traj_flow_np,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+ torchvision.io.write_video(
+ ldmk_flows_path_mp4,
+ ldmk_flow_np,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+ torchvision.io.write_video(
+ viz_ldmk_path_mp4,
+ viz_ldmk_np,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+ torchvision.io.write_video(
+ out_with_ldmk_path_mp4,
+ out_with_ldmk_np,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+ torchvision.io.write_video(
+ outputs_path_mp4,
+ output_np,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+
+ imageio.mimsave(traj_flows_path, np.uint8(traj_flow_np), fps=20, loop=0)
+ imageio.mimsave(ldmk_flows_path, np.uint8(ldmk_flow_np), fps=20, loop=0)
+ imageio.mimsave(viz_ldmk_path, np.uint8(viz_ldmk_np), fps=20, loop=0)
+ imageio.mimsave(out_with_ldmk_path, np.uint8(out_with_ldmk_np), fps=20, loop=0)
+ imageio.mimsave(outputs_path, np.uint8(output_np), fps=20, loop=0)
+
+ torchvision.io.write_video(total_path_mp4, total_np, fps=20, video_codec='h264', options={'crf': '10'})
+ imageio.mimsave(total_path, np.uint8(total_np), fps=20, loop=0)
+
+ return hint_path, traj_flows_path, ldmk_flows_path, viz_ldmk_path, outputs_path, traj_flows_path_mp4, ldmk_flows_path_mp4, viz_ldmk_path_mp4, outputs_path_mp4
+
+
+with gr.Blocks() as demo:
+ gr.Markdown("""MOFA-Video
""")
+
+ gr.Markdown("""Official Gradio Demo for MOFA-Video: Controllable Image Animation via Generative Motion Field Adaptions in Frozen Image-to-Video Diffusion Model.
""")
+
+ gr.Markdown(
+ """
+ 1. Use the "Upload Image" button to upload an image. Avoid dragging the image directly into the window.
+ 2. Proceed to trajectory control:
+ 2.1. Click "Add Trajectory" first, then select points on the "Add Trajectory Here" image. The first click sets the starting point. Click multiple points to create a non-linear trajectory. To add a new trajectory, click "Add Trajectory" again and select points on the image.
+ 2.2. After adding each trajectory, an optical flow image will be displayed automatically in "Temporary Trajectory Flow Visualization". Use it as a reference to adjust the trajectory for desired effects (e.g., area, intensity).
+ 2.3. To delete the latest trajectory, click "Delete Last Trajectory."
+ 2.4. To use the motion brush for restraining the control area of the trajectory, click to add masks on the "Add Motion Brush Here" image. The motion brush restricts the optical flow area derived from the trajectory whose starting point is within the motion brush. The displayed optical flow image will change correspondingly. Adjust the motion brush radius using the "Motion Brush Radius" slider.
+ 2.5. Choose the Control scale for trajectory using the "Control Scale for Trajectory" slider. This determines the control intensity of trajectory. Setting it to 0 means no control (pure generation result of SVD itself), while setting it to 1 results in the strongest control (which will not lead to good results in most cases because of twisting artifacts). A preset value of 0.6 is recommended for most cases.
+ 3. Proceed to landmark control from driven video:
+ 3.1. Use the "Upload Driven Video" button to upload an driven video (We have tested .mp4 extensions, and other formats compatible with `cv2.VideoCapture` may also be uploaded without causing errors.).
+ 3.2. Click to add masks on the "Add Landmark Mask Here" image. This mask restricts the optical flow area derived from the landmarks, which should usually covers the area of the person's head parts, and, if desired, body parts for more natural body movement instead of being stationary. Adjust the landmark brush radius using the "Landmark Brush Radius" slider.
+ 3.3. Choose the Control scale for landmarks using the "Control Scale for Landmark" slider. This determines the control intensity of landmarks. Different from trajectory controls, a preset value of 1 is recommended for most cases.
+ 3.4. For video-driven landmark generation, our codes are modified based on SadTalker. Note that while pure landmark-based control of MOFA-Video supports long video generation via the periodic sampling strategy, current version of hybrid control only supports short video generation (25 frames), which means that the first 25 frames of the generated landmark sequences are used to obtain the result.
+ 4. Click the "Run" button to animate the image according to the trajectory and the landmark.
+ """
+ )
+
+ target_size = 512 # NOTICE: changing to lower resolution may impair the performance of the model.
+ DragNUWA_net = Drag("cuda:0", target_size, target_size, 25)
+ first_frame_path = gr.State()
+ driven_video_path = gr.State()
+ tracking_points = gr.State([])
+ motion_brush_points = gr.State([])
+ motion_brush_mask = gr.State()
+ motion_brush_viz = gr.State()
+ ldmk_mask_mask = gr.State()
+ ldmk_mask_viz = gr.State()
+
+ def preprocess_image(image):
+
+ image_pil = image2pil(image.name)
+ raw_w, raw_h = image_pil.size
+
+ max_edge = min(raw_w, raw_h)
+ resize_ratio = target_size / max_edge
+
+ image_pil = image_pil.resize((round(raw_w * resize_ratio), round(raw_h * resize_ratio)), Image.BILINEAR)
+
+ new_w, new_h = image_pil.size
+ crop_w = new_w - (new_w % 64)
+ crop_h = new_h - (new_h % 64)
+
+ image_pil = transforms.CenterCrop((crop_h, crop_w))(image_pil.convert('RGB'))
+
+ DragNUWA_net.width = crop_w
+ DragNUWA_net.height = crop_h
+
+ id = str(time.time()).split('.')[0]
+ os.makedirs(os.path.join(output_dir, str(id)), exist_ok=True)
+
+ first_frame_path = os.path.join(output_dir, str(id), f"input.png")
+ image_pil.save(first_frame_path)
+
+ return first_frame_path, first_frame_path, first_frame_path, first_frame_path, gr.State([]), gr.State([]), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4)), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4))
+
+ def video_to_numpy_array(video_path):
+ video = cv2.VideoCapture(video_path)
+ num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+ frames = []
+ for i in range(num_frames):
+ ret, frame = video.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+ frames.append(frame)
+ else:
+ break
+ video.release()
+ frames = np.stack(frames, axis=0)
+ return frames
+
+ def convert_video_to_mp4(input_audio_file, output_wav_file):
+ video_np = np.uint8(video_to_numpy_array(input_audio_file))
+ torchvision.io.write_video(
+ output_wav_file,
+ video_np,
+ fps=25, video_codec='h264', options={'crf': '10'}
+ )
+
+ def save_driven_video(driven_video, first_frame_path):
+
+ assert first_frame_path is not None, "Please first upload image, then upload audio."
+
+ img_basedir = os.path.dirname(first_frame_path)
+
+ id = str(time.time()).split('.')[0]
+
+ driven_video_path = os.path.join(img_basedir, f'driven_video_{str(id)}', 'driven_video.mp4')
+ os.makedirs(os.path.dirname(driven_video_path), exist_ok=True)
+
+ convert_video_to_mp4(driven_video.name, driven_video_path)
+
+ return driven_video_path, driven_video_path
+
+ def add_drag(tracking_points):
+ if len(tracking_points.constructor_args['value']) != 0 and tracking_points.constructor_args['value'][-1] == []:
+ return tracking_points
+ tracking_points.constructor_args['value'].append([])
+ return tracking_points
+
+ def delete_last_drag(tracking_points, first_frame_path, motion_brush_mask):
+
+ if len(tracking_points.constructor_args['value']) > 0:
+ tracking_points.constructor_args['value'].pop()
+
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+ transparent_layer = np.zeros((h, w, 4))
+ for track in tracking_points.constructor_args['value']:
+ if len(track) > 1:
+ for i in range(len(track)-1):
+ start_point = track[i]
+ end_point = track[i+1]
+ vx = end_point[0] - start_point[0]
+ vy = end_point[1] - start_point[1]
+ arrow_length = np.sqrt(vx**2 + vy**2)
+ if i == len(track)-2:
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
+ else:
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
+ else:
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
+
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+ viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path)
+
+ return tracking_points, trajectory_map, viz_flow
+
+ def add_motion_brushes(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, tracking_points, evt: gr.SelectData):
+
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+
+ motion_points = motion_brush_points.constructor_args['value']
+ motion_points.append(evt.index)
+
+ x, y = evt.index
+
+ cv2.circle(motion_brush_mask, (x, y), radius, 255, -1)
+ cv2.circle(transparent_layer, (x, y), radius, (128, 0, 128, 127), -1)
+
+ transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8))
+ motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil)
+
+ viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path)
+
+ return motion_brush_mask, transparent_layer, motion_map, viz_flow
+
+
+ def add_ldmk_mask(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, evt: gr.SelectData):
+
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+
+ motion_points = motion_brush_points.constructor_args['value']
+ motion_points.append(evt.index)
+
+ x, y = evt.index
+
+ cv2.circle(motion_brush_mask, (x, y), radius, 255, -1)
+ cv2.circle(transparent_layer, (x, y), radius, (0, 0, 255, 127), -1)
+
+ transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8))
+ motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil)
+
+ return motion_brush_mask, transparent_layer, motion_map
+
+
+
+ def add_tracking_points(tracking_points, first_frame_path, motion_brush_mask, evt: gr.SelectData): # SelectData is a subclass of EventData
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+
+ if len(tracking_points.constructor_args['value']) == 0:
+ tracking_points.constructor_args['value'].append([])
+
+ tracking_points.constructor_args['value'][-1].append(evt.index)
+
+ print(tracking_points.constructor_args['value'])
+
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+ transparent_layer = np.zeros((h, w, 4))
+ for track in tracking_points.constructor_args['value']:
+ if len(track) > 1:
+ for i in range(len(track)-1):
+ start_point = track[i]
+ end_point = track[i+1]
+ vx = end_point[0] - start_point[0]
+ vy = end_point[1] - start_point[1]
+ arrow_length = np.sqrt(vx**2 + vy**2)
+ if i == len(track)-2:
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
+ else:
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
+ else:
+ cv2.circle(transparent_layer, tuple(track[0]), 3, (255, 0, 0, 255), -1)
+
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+ viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path)
+
+ return tracking_points, trajectory_map, viz_flow
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
+ video_upload_button = gr.UploadButton(label="Upload Driven Video", file_types=["video"])
+ driven_video = gr.Video(label="Driven Video")
+ with gr.Column(scale=3):
+ add_drag_button = gr.Button(value="Add Trajectory")
+ delete_last_drag_button = gr.Button(value="Delete Last Trajectory")
+ run_button = gr.Button(value="Run")
+ with gr.Column(scale=3):
+ motion_brush_radius = gr.Slider(label='Motion Brush Radius',
+ minimum=1,
+ maximum=200,
+ step=1,
+ value=10)
+ ldmk_mask_radius = gr.Slider(label='Landmark Brush Radius',
+ minimum=1,
+ maximum=200,
+ step=1,
+ value=10)
+ with gr.Column(scale=3):
+ ctrl_scale_traj = gr.Slider(label='Control Scale for Trajectory',
+ minimum=0,
+ maximum=1.,
+ step=0.01,
+ value=0.6)
+ ctrl_scale_ldmk = gr.Slider(label='Control Scale for Landmark',
+ minimum=0,
+ maximum=1.,
+ step=0.01,
+ value=1.)
+ ldmk_render = gr.Radio(label='Landmark Renderer',
+ choices=['sadtalker'],
+ value='sadtalker')
+
+ with gr.Column(scale=4):
+ input_image = gr.Image(label="Add Trajectory Here",
+ interactive=True)
+ with gr.Column(scale=4):
+ motion_brush_image = gr.Image(label="Add Motion Brush Here",
+ interactive=True)
+ with gr.Column(scale=4):
+ ldmk_mask_image = gr.Image(label="Add Landmark Mask Here",
+ interactive=True)
+
+ with gr.Row():
+ with gr.Column(scale=6):
+ viz_flow = gr.Image(label="Temporary Trajectory Flow Visualization")
+ with gr.Column(scale=6):
+ hint_image = gr.Image(label="Final Hint Image")
+
+ with gr.Row():
+ with gr.Column(scale=6):
+ traj_flows_gif = gr.Image(label="Trajectory Flow GIF")
+ with gr.Column(scale=6):
+ ldmk_flows_gif = gr.Image(label="Landmark Flow GIF")
+ with gr.Row():
+ with gr.Column(scale=6):
+ viz_ldmk_gif = gr.Image(label="Landmark Visualization GIF")
+ with gr.Column(scale=6):
+ outputs_gif = gr.Image(label="Output GIF")
+
+ with gr.Row():
+ with gr.Column(scale=6):
+ traj_flows_mp4 = gr.Video(label="Trajectory Flow MP4")
+ with gr.Column(scale=6):
+ ldmk_flows_mp4 = gr.Video(label="Landmark Flow MP4")
+ with gr.Row():
+ with gr.Column(scale=6):
+ viz_ldmk_mp4 = gr.Video(label="Landmark Visualization MP4")
+ with gr.Column(scale=6):
+ outputs_mp4 = gr.Video(label="Output MP4")
+
+ image_upload_button.upload(preprocess_image, image_upload_button, [input_image, motion_brush_image, ldmk_mask_image, first_frame_path, tracking_points, motion_brush_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz])
+
+ video_upload_button.upload(save_driven_video, [video_upload_button, first_frame_path], [driven_video, driven_video_path])
+
+ add_drag_button.click(add_drag, tracking_points, tracking_points)
+
+ delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow])
+
+ input_image.select(add_tracking_points, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow])
+
+ motion_brush_image.select(add_motion_brushes, [motion_brush_points, motion_brush_mask, motion_brush_viz, first_frame_path, motion_brush_radius, tracking_points], [motion_brush_mask, motion_brush_viz, motion_brush_image, viz_flow])
+
+ ldmk_mask_image.select(add_ldmk_mask, [motion_brush_points, ldmk_mask_mask, ldmk_mask_viz, first_frame_path, ldmk_mask_radius], [ldmk_mask_mask, ldmk_mask_viz, ldmk_mask_image])
+
+ run_button.click(DragNUWA_net.run, [first_frame_path, driven_video_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render], [hint_image, traj_flows_gif, ldmk_flows_gif, viz_ldmk_gif, outputs_gif, traj_flows_mp4, ldmk_flows_mp4, viz_ldmk_mp4, outputs_mp4])
+
+ # demo.launch(server_name="0.0.0.0", debug=True, server_port=80)
+ demo.launch(server_name="127.0.0.1", debug=True, server_port=9080)
diff --git a/sadtalker_audio2pose/.DS_Store b/sadtalker_audio2pose/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..27cf1b75b1e9bcb455443bb0050ad3478243664f
Binary files /dev/null and b/sadtalker_audio2pose/.DS_Store differ
diff --git a/sadtalker_audio2pose/inference.py b/sadtalker_audio2pose/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..6198f61aa619cd7ee7a0ce973541a2c78747011b
--- /dev/null
+++ b/sadtalker_audio2pose/inference.py
@@ -0,0 +1,188 @@
+from glob import glob
+import shutil
+import torch
+from time import strftime
+import os, sys, time
+from argparse import ArgumentParser
+import platform
+
+from src.utils.preprocess import CropAndExtract
+from src.test_audio2coeff import Audio2Coeff
+from src.facerender.animate import AnimateFromCoeff
+from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
+from src.generate_batch import get_data
+from src.generate_facerender_batch import get_facerender_data
+from src.utils.init_path import init_path
+
+import random
+import numpy as np
+
+
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+
+
+def main(args):
+ #torch.backends.cudnn.enabled = False
+
+ set_seed(42)
+
+ # args.facerender = 'pirender'
+
+
+
+ pic_path = args.source_image
+ audio_path = args.driven_audio
+ save_dir = args.result_dir
+ os.makedirs(save_dir, exist_ok=True)
+ pose_style = args.pose_style
+ device = args.device
+ batch_size = args.batch_size
+ input_yaw_list = args.input_yaw
+ input_pitch_list = args.input_pitch
+ input_roll_list = args.input_roll
+ ref_eyeblink = args.ref_eyeblink
+ ref_pose = args.ref_pose
+
+ # print(args.still)
+ # assert False
+
+ current_root_path = os.path.split(sys.argv[0])[0]
+
+ sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
+
+ #init model
+ preprocess_model = CropAndExtract(sadtalker_paths, device)
+
+ audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
+
+ if args.facerender == 'facevid2vid':
+ animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
+ elif args.facerender == 'pirender':
+ animate_from_coeff = AnimateFromCoeff_PIRender(sadtalker_paths, device)
+ else:
+ raise(RuntimeError('Unknown model: {}'.format(args.facerender)))
+
+ #crop image and extract 3dmm from image
+ first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+ os.makedirs(first_frame_dir, exist_ok=True)
+ print('3DMM Extraction for source image')
+ first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\
+ source_image_flag=True, pic_size=args.size)
+ if first_coeff_path is None:
+ print("Can't get the coeffs of the input")
+ return
+
+ if ref_eyeblink is not None:
+ ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
+ ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
+ os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
+ print('3DMM Extraction for the reference video providing eye blinking')
+ ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
+ else:
+ ref_eyeblink_coeff_path=None
+
+ if ref_pose is not None:
+ if ref_pose == ref_eyeblink:
+ ref_pose_coeff_path = ref_eyeblink_coeff_path
+ else:
+ ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
+ ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
+ os.makedirs(ref_pose_frame_dir, exist_ok=True)
+ print('3DMM Extraction for the reference video providing pose')
+ ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
+ else:
+ ref_pose_coeff_path=None
+
+ #audio2ceoff
+ batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
+ coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+
+ # print(ref_pose_coeff_path.shape)
+ # print(coeff_path.shape)
+
+ # assert False
+
+ # 3dface render
+ if args.face3dvis:
+ from src.face3d.visualize import gen_composed_video
+ gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, \
+ os.path.join(save_dir, '3dface.mp4'), os.path.join(save_dir, 'landmarks.mp4'), crop_info, extended_crop= True if 'ext' in args.preprocess else False )
+ return
+
+ #coeff2video
+ data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
+ batch_size, input_yaw_list, input_pitch_list, input_roll_list,
+ expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size, facemodel=args.facerender)
+
+ result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
+ enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
+
+ shutil.move(result, save_dir+'.mp4')
+ print('The generated video is named:', save_dir+'.mp4')
+
+ # result = animate_from_coeff.generate_flow(data, args.result_dir, pic_path, crop_info, \
+ # enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
+
+ # if not args.verbose:
+ # shutil.rmtree(save_dir)
+
+
+if __name__ == '__main__':
+
+ parser = ArgumentParser()
+ parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio")
+ parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image")
+ parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
+ parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
+ parser.add_argument("--checkpoint_dir", default='./ckpts/sad_talker', help="path to output")
+ parser.add_argument("--result_dir", default='./results', help="path to output")
+ parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
+ parser.add_argument("--batch_size", type=int, default=1, help="the batch size of facerender")
+ parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
+ parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender")
+ parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
+ parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
+ parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
+ parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
+ parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
+ parser.add_argument("--cpu", dest="cpu", action="store_true")
+ parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
+ parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion")
+ parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" )
+ parser.add_argument("--verbose",action="store_true", help="saving the intermedia output or not" )
+ parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" )
+ parser.add_argument("--facerender", default='facevid2vid', choices=['pirender', 'facevid2vid'] )
+
+
+ # net structure and parameters
+ parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
+ parser.add_argument('--init_path', type=str, default=None, help='Useless')
+ parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc')
+ parser.add_argument('--bfm_folder', type=str, default='./ckpts/sad_talker/BFM_Fitting/')
+ parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+ # default renderer parameters
+ parser.add_argument('--focal', type=float, default=1015.)
+ parser.add_argument('--center', type=float, default=112.)
+ parser.add_argument('--camera_d', type=float, default=10.)
+ parser.add_argument('--z_near', type=float, default=5.)
+ parser.add_argument('--z_far', type=float, default=15.)
+
+ args = parser.parse_args()
+
+ if torch.cuda.is_available() and not args.cpu:
+ args.device = "cuda"
+ elif platform.system() == 'Darwin' and args.facerender == 'pirender': # macos
+ args.device = "mps"
+ else:
+ args.device = "cpu"
+
+ main(args)
+
diff --git a/sadtalker_audio2pose/src/.DS_Store b/sadtalker_audio2pose/src/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..e3eba8349f8b3d6836329847e5a266205475acf2
Binary files /dev/null and b/sadtalker_audio2pose/src/.DS_Store differ
diff --git a/sadtalker_audio2pose/src/audio2exp_models/audio2exp.py b/sadtalker_audio2pose/src/audio2exp_models/audio2exp.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1062ab6684df01e0b3c48b6b577cc8df0503c91
--- /dev/null
+++ b/sadtalker_audio2pose/src/audio2exp_models/audio2exp.py
@@ -0,0 +1,41 @@
+from tqdm import tqdm
+import torch
+from torch import nn
+
+
+class Audio2Exp(nn.Module):
+ def __init__(self, netG, cfg, device, prepare_training_loss=False):
+ super(Audio2Exp, self).__init__()
+ self.cfg = cfg
+ self.device = device
+ self.netG = netG.to(device)
+
+ def test(self, batch):
+
+ mel_input = batch['indiv_mels'] # bs T 1 80 16
+ bs = mel_input.shape[0]
+ T = mel_input.shape[1]
+
+ exp_coeff_pred = []
+
+ for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
+
+ current_mel_input = mel_input[:,i:i+10]
+
+ #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
+ ref = batch['ref'][:, :, :64][:, i:i+10]
+ ratio = batch['ratio_gt'][:, i:i+10] #bs T
+
+ audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
+
+ curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
+
+ exp_coeff_pred += [curr_exp_coeff_pred]
+
+ # BS x T x 64
+ results_dict = {
+ 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
+ }
+ return results_dict
+
+
diff --git a/sadtalker_audio2pose/src/audio2exp_models/networks.py b/sadtalker_audio2pose/src/audio2exp_models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd77a2f48d7c00ce85fe2eefe3a3e820730fbb74
--- /dev/null
+++ b/sadtalker_audio2pose/src/audio2exp_models/networks.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+ self.residual = residual
+ self.use_act = use_act
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ if self.residual:
+ out += x
+
+ if self.use_act:
+ return self.act(out)
+ else:
+ return out
+
+class SimpleWrapperV2(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
+ )
+
+ #### load the pre-trained audio_encoder
+ #self.audio_encoder = self.audio_encoder.to(device)
+ '''
+ wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
+ state_dict = self.audio_encoder.state_dict()
+
+ for k,v in wav2lip_state_dict.items():
+ if 'audio_encoder' in k:
+ print('init:', k)
+ state_dict[k.replace('module.audio_encoder.', '')] = v
+ self.audio_encoder.load_state_dict(state_dict)
+ '''
+
+ self.mapping1 = nn.Linear(512+64+1, 64)
+ #self.mapping2 = nn.Linear(30, 64)
+ #nn.init.constant_(self.mapping1.weight, 0.)
+ nn.init.constant_(self.mapping1.bias, 0.)
+
+ def forward(self, x, ref, ratio):
+ x = self.audio_encoder(x).view(x.size(0), -1)
+ ref_reshape = ref.reshape(x.size(0), -1)
+ ratio = ratio.reshape(x.size(0), -1)
+
+ y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
+ out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
+ return out
diff --git a/sadtalker_audio2pose/src/audio2pose_models/audio2pose.py b/sadtalker_audio2pose/src/audio2pose_models/audio2pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..53883adc508037294ba664d05d34e5459f1879f8
--- /dev/null
+++ b/sadtalker_audio2pose/src/audio2pose_models/audio2pose.py
@@ -0,0 +1,94 @@
+import torch
+from torch import nn
+from src.audio2pose_models.cvae import CVAE
+from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
+from src.audio2pose_models.audio_encoder import AudioEncoder
+
+class Audio2Pose(nn.Module):
+ def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
+ super().__init__()
+ self.cfg = cfg
+ self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
+ self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
+ self.device = device
+
+ self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
+ self.audio_encoder.eval()
+ for param in self.audio_encoder.parameters():
+ param.requires_grad = False
+
+ self.netG = CVAE(cfg)
+ self.netD_motion = PoseSequenceDiscriminator(cfg)
+
+
+ def forward(self, x):
+
+ batch = {}
+ coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
+ batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
+ batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
+ batch['class'] = x['class'].squeeze(0).cuda() # bs
+ indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
+
+ # forward
+ audio_emb_list = []
+ audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
+ batch['audio_emb'] = audio_emb
+ batch = self.netG(batch)
+
+ pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
+ pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
+ pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
+
+ batch['pose_pred'] = pose_pred
+ batch['pose_gt'] = pose_gt
+
+ return batch
+
+ def test(self, x):
+
+ batch = {}
+ ref = x['ref'] #bs 1 70
+ batch['ref'] = x['ref'][:,0,-6:]
+ batch['class'] = x['class']
+ bs = ref.shape[0]
+
+ indiv_mels= x['indiv_mels'] # bs T 1 80 16
+ indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
+ num_frames = x['num_frames']
+ num_frames = int(num_frames) - 1
+
+ #
+ div = num_frames//self.seq_len
+ re = num_frames%self.seq_len
+ audio_emb_list = []
+ pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
+ device=batch['ref'].device)]
+
+ for i in range(div):
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
+ batch['z'] = z
+ audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
+ batch['audio_emb'] = audio_emb
+ batch = self.netG.test(batch)
+ pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
+
+ if re != 0:
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
+ batch['z'] = z
+ audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
+ if audio_emb.shape[1] != self.seq_len:
+ pad_dim = self.seq_len-audio_emb.shape[1]
+ pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
+ audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
+ batch['audio_emb'] = audio_emb
+ batch = self.netG.test(batch)
+ pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
+
+ pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
+ batch['pose_motion_pred'] = pose_motion_pred
+
+ pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
+
+ batch['pose_pred'] = pose_pred
+ return batch
diff --git a/sadtalker_audio2pose/src/audio2pose_models/audio_encoder.py b/sadtalker_audio2pose/src/audio2pose_models/audio_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0c165afbc25910cb66828d8676973fe727cb3a3
--- /dev/null
+++ b/sadtalker_audio2pose/src/audio2pose_models/audio_encoder.py
@@ -0,0 +1,64 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+ self.residual = residual
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ if self.residual:
+ out += x
+ return self.act(out)
+
+class AudioEncoder(nn.Module):
+ def __init__(self, wav2lip_checkpoint, device):
+ super(AudioEncoder, self).__init__()
+
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+ #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
+ # state_dict = self.audio_encoder.state_dict()
+
+ # for k,v in wav2lip_state_dict.items():
+ # if 'audio_encoder' in k:
+ # state_dict[k.replace('module.audio_encoder.', '')] = v
+ # self.audio_encoder.load_state_dict(state_dict)
+
+
+ def forward(self, audio_sequences):
+ # audio_sequences = (B, T, 1, 80, 16)
+ B = audio_sequences.size(0)
+
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
+ dim = audio_embedding.shape[1]
+ audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
+
+ return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
diff --git a/sadtalker_audio2pose/src/audio2pose_models/cvae.py b/sadtalker_audio2pose/src/audio2pose_models/cvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..407b78894cde564dd3f2819772a84e8bb1de251d
--- /dev/null
+++ b/sadtalker_audio2pose/src/audio2pose_models/cvae.py
@@ -0,0 +1,149 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from src.audio2pose_models.res_unet import ResUnet
+
+def class2onehot(idx, class_num):
+
+ assert torch.max(idx).item() < class_num
+ onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
+ onehot.scatter_(1, idx, 1)
+ return onehot
+
+class CVAE(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
+ decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
+ latent_size = cfg.MODEL.CVAE.LATENT_SIZE
+ num_classes = cfg.DATASET.NUM_CLASSES
+ audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
+ audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
+ seq_len = cfg.MODEL.CVAE.SEQ_LEN
+
+ self.latent_size = latent_size
+
+ self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len)
+ self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len)
+ def reparameterize(self, mu, logvar):
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ return mu + eps * std
+
+ def forward(self, batch):
+ batch = self.encoder(batch)
+ mu = batch['mu']
+ logvar = batch['logvar']
+ z = self.reparameterize(mu, logvar)
+ batch['z'] = z
+ return self.decoder(batch)
+
+ def test(self, batch):
+ '''
+ class_id = batch['class']
+ z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
+ batch['z'] = z
+ '''
+ return self.decoder(batch)
+
+class ENCODER(nn.Module):
+ def __init__(self, layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len):
+ super().__init__()
+
+ self.resunet = ResUnet()
+ self.num_classes = num_classes
+ self.seq_len = seq_len
+
+ self.MLP = nn.Sequential()
+ layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
+ for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
+ self.MLP.add_module(
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+
+ self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
+ self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+ def forward(self, batch):
+ class_id = batch['class']
+ pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
+ ref = batch['ref'] #bs 6
+ bs = pose_motion_gt.shape[0]
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
+
+ #pose encode
+ pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
+ pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
+
+ #audio mapping
+ print(audio_in.shape)
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
+ audio_out = audio_out.reshape(bs, -1)
+
+ class_bias = self.classbias[class_id] #bs latent_size
+ x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
+ x_out = self.MLP(x_in)
+
+ mu = self.linear_means(x_out)
+ logvar = self.linear_means(x_out) #bs latent_size
+
+ batch.update({'mu':mu, 'logvar':logvar})
+ return batch
+
+class DECODER(nn.Module):
+ def __init__(self, layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len):
+ super().__init__()
+
+ self.resunet = ResUnet()
+ self.num_classes = num_classes
+ self.seq_len = seq_len
+
+ self.MLP = nn.Sequential()
+ input_size = latent_size + seq_len*audio_emb_out_size + 6
+ for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
+ self.MLP.add_module(
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+ if i+1 < len(layer_sizes):
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+ else:
+ self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
+
+ self.pose_linear = nn.Linear(6, 6)
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+ def forward(self, batch):
+
+ z = batch['z'] #bs latent_size
+ bs = z.shape[0]
+ class_id = batch['class']
+ ref = batch['ref'] #bs 6
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
+ #print('audio_in: ', audio_in[:, :, :10])
+
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
+ #print('audio_out: ', audio_out[:, :, :10])
+ audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
+ class_bias = self.classbias[class_id] #bs latent_size
+
+ z = z + class_bias
+ x_in = torch.cat([ref, z, audio_out], dim=-1)
+ x_out = self.MLP(x_in) # bs layer_sizes[-1]
+ x_out = x_out.reshape((bs, self.seq_len, -1))
+
+ #print('x_out: ', x_out)
+
+ pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
+
+ pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
+
+ batch.update({'pose_motion_pred':pose_motion_pred})
+ return batch
diff --git a/sadtalker_audio2pose/src/audio2pose_models/discriminator.py b/sadtalker_audio2pose/src/audio2pose_models/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f8ed6e36708d4a70227ff90109f56c6f73a17d2
--- /dev/null
+++ b/sadtalker_audio2pose/src/audio2pose_models/discriminator.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class ConvNormRelu(nn.Module):
+ def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
+ kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
+ super().__init__()
+ if kernel_size is None:
+ if downsample:
+ kernel_size, stride, padding = 4, 2, 1
+ else:
+ kernel_size, stride, padding = 3, 1, 1
+
+ if conv_type == '2d':
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ bias=False,
+ )
+ if norm == 'BN':
+ self.norm = nn.BatchNorm2d(out_channels)
+ elif norm == 'IN':
+ self.norm = nn.InstanceNorm2d(out_channels)
+ else:
+ raise NotImplementedError
+ elif conv_type == '1d':
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ bias=False,
+ )
+ if norm == 'BN':
+ self.norm = nn.BatchNorm1d(out_channels)
+ elif norm == 'IN':
+ self.norm = nn.InstanceNorm1d(out_channels)
+ else:
+ raise NotImplementedError
+ nn.init.kaiming_normal_(self.conv.weight)
+
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ if isinstance(self.norm, nn.InstanceNorm1d):
+ x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
+ else:
+ x = self.norm(x)
+ x = self.act(x)
+ return x
+
+
+class PoseSequenceDiscriminator(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
+
+ self.seq = nn.Sequential(
+ ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
+ ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
+ ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
+ nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
+ )
+
+ def forward(self, x):
+ x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
+ x = self.seq(x)
+ x = x.squeeze(1)
+ return x
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/audio2pose_models/networks.py b/sadtalker_audio2pose/src/audio2pose_models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..9212b49836d9221895993d1d490a476707599922
--- /dev/null
+++ b/sadtalker_audio2pose/src/audio2pose_models/networks.py
@@ -0,0 +1,140 @@
+import torch.nn as nn
+import torch
+
+
+class ResidualConv(nn.Module):
+ def __init__(self, input_dim, output_dim, stride, padding):
+ super(ResidualConv, self).__init__()
+
+ self.conv_block = nn.Sequential(
+ nn.BatchNorm2d(input_dim),
+ nn.ReLU(),
+ nn.Conv2d(
+ input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
+ ),
+ nn.BatchNorm2d(output_dim),
+ nn.ReLU(),
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+ )
+ self.conv_skip = nn.Sequential(
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
+ nn.BatchNorm2d(output_dim),
+ )
+
+ def forward(self, x):
+
+ return self.conv_block(x) + self.conv_skip(x)
+
+
+class Upsample(nn.Module):
+ def __init__(self, input_dim, output_dim, kernel, stride):
+ super(Upsample, self).__init__()
+
+ self.upsample = nn.ConvTranspose2d(
+ input_dim, output_dim, kernel_size=kernel, stride=stride
+ )
+
+ def forward(self, x):
+ return self.upsample(x)
+
+
+class Squeeze_Excite_Block(nn.Module):
+ def __init__(self, channel, reduction=16):
+ super(Squeeze_Excite_Block, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel, bias=False),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y.expand_as(x)
+
+
+class ASPP(nn.Module):
+ def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
+ super(ASPP, self).__init__()
+
+ self.aspp_block1 = nn.Sequential(
+ nn.Conv2d(
+ in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(out_dims),
+ )
+ self.aspp_block2 = nn.Sequential(
+ nn.Conv2d(
+ in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(out_dims),
+ )
+ self.aspp_block3 = nn.Sequential(
+ nn.Conv2d(
+ in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(out_dims),
+ )
+
+ self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
+ self._init_weights()
+
+ def forward(self, x):
+ x1 = self.aspp_block1(x)
+ x2 = self.aspp_block2(x)
+ x3 = self.aspp_block3(x)
+ out = torch.cat([x1, x2, x3], dim=1)
+ return self.output(out)
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+
+class Upsample_(nn.Module):
+ def __init__(self, scale=2):
+ super(Upsample_, self).__init__()
+
+ self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
+
+ def forward(self, x):
+ return self.upsample(x)
+
+
+class AttentionBlock(nn.Module):
+ def __init__(self, input_encoder, input_decoder, output_dim):
+ super(AttentionBlock, self).__init__()
+
+ self.conv_encoder = nn.Sequential(
+ nn.BatchNorm2d(input_encoder),
+ nn.ReLU(),
+ nn.Conv2d(input_encoder, output_dim, 3, padding=1),
+ nn.MaxPool2d(2, 2),
+ )
+
+ self.conv_decoder = nn.Sequential(
+ nn.BatchNorm2d(input_decoder),
+ nn.ReLU(),
+ nn.Conv2d(input_decoder, output_dim, 3, padding=1),
+ )
+
+ self.conv_attn = nn.Sequential(
+ nn.BatchNorm2d(output_dim),
+ nn.ReLU(),
+ nn.Conv2d(output_dim, 1, 1),
+ )
+
+ def forward(self, x1, x2):
+ out = self.conv_encoder(x1) + self.conv_decoder(x2)
+ out = self.conv_attn(out)
+ return out * x2
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/audio2pose_models/res_unet.py b/sadtalker_audio2pose/src/audio2pose_models/res_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..280404c2a2804038705f792dd800ddf707b75cf8
--- /dev/null
+++ b/sadtalker_audio2pose/src/audio2pose_models/res_unet.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from src.audio2pose_models.networks import ResidualConv, Upsample
+
+
+class ResUnet(nn.Module):
+ def __init__(self, channel=1, filters=[32, 64, 128, 256]):
+ super(ResUnet, self).__init__()
+
+ self.input_layer = nn.Sequential(
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
+ nn.BatchNorm2d(filters[0]),
+ nn.ReLU(),
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
+ )
+ self.input_skip = nn.Sequential(
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
+ )
+
+ self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
+ self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
+
+ self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
+
+ self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
+ self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
+
+ self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
+ self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
+
+ self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
+ self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
+
+ self.output_layer = nn.Sequential(
+ nn.Conv2d(filters[0], 1, 1, 1),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ # Encode
+ x1 = self.input_layer(x) + self.input_skip(x)
+ x2 = self.residual_conv_1(x1)
+ x3 = self.residual_conv_2(x2)
+ # Bridge
+ x4 = self.bridge(x3)
+
+ # Decode
+ x4 = self.upsample_1(x4)
+ x5 = torch.cat([x4, x3], dim=1)
+
+ x6 = self.up_residual_conv1(x5)
+
+ x6 = self.upsample_2(x6)
+ x7 = torch.cat([x6, x2], dim=1)
+
+ x8 = self.up_residual_conv2(x7)
+
+ x8 = self.upsample_3(x8)
+ x9 = torch.cat([x8, x1], dim=1)
+
+ x10 = self.up_residual_conv3(x9)
+
+ output = self.output_layer(x10)
+
+ return output
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/config/auido2exp.yaml b/sadtalker_audio2pose/src/config/auido2exp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7e0e8fbba267158d26a147c8cb2ec5acdd73f432
--- /dev/null
+++ b/sadtalker_audio2pose/src/config/auido2exp.yaml
@@ -0,0 +1,58 @@
+DATASET:
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
+ TRAIN_BATCH_SIZE: 32
+ EVAL_BATCH_SIZE: 32
+ EXP: True
+ EXP_DIM: 64
+ FRAME_LEN: 32
+ COEFF_LEN: 73
+ NUM_CLASSES: 46
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
+ LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+ DEBUG: True
+ NUM_REPEATS: 2
+ T: 40
+
+
+MODEL:
+ FRAMEWORK: V2
+ AUDIOENCODER:
+ LEAKY_RELU: True
+ NORM: 'IN'
+ DISCRIMINATOR:
+ LEAKY_RELU: False
+ INPUT_CHANNELS: 6
+ CVAE:
+ AUDIO_EMB_IN_SIZE: 512
+ AUDIO_EMB_OUT_SIZE: 128
+ SEQ_LEN: 32
+ LATENT_SIZE: 256
+ ENCODER_LAYER_SIZES: [192, 1024]
+ DECODER_LAYER_SIZES: [1024, 192]
+
+
+TRAIN:
+ MAX_EPOCH: 300
+ GENERATOR:
+ LR: 2.0e-5
+ DISCRIMINATOR:
+ LR: 1.0e-5
+ LOSS:
+ W_FEAT: 0
+ W_COEFF_EXP: 2
+ W_LM: 1.0e-2
+ W_LM_MOUTH: 0
+ W_REG: 0
+ W_SYNC: 0
+ W_COLOR: 0
+ W_EXPRESSION: 0
+ W_LIPREADING: 0.01
+ W_LIPREADING_VV: 0
+ W_EYE_BLINK: 4
+
+TAG:
+ NAME: small_dataset
+
+
diff --git a/sadtalker_audio2pose/src/config/auido2pose.yaml b/sadtalker_audio2pose/src/config/auido2pose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7702414b11581ff99aef7a3187f0d0d1388ae3f3
--- /dev/null
+++ b/sadtalker_audio2pose/src/config/auido2pose.yaml
@@ -0,0 +1,49 @@
+DATASET:
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
+ TRAIN_BATCH_SIZE: 64
+ EVAL_BATCH_SIZE: 1
+ EXP: True
+ EXP_DIM: 64
+ FRAME_LEN: 32
+ COEFF_LEN: 73
+ NUM_CLASSES: 46
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+ DEBUG: True
+
+
+MODEL:
+ AUDIOENCODER:
+ LEAKY_RELU: True
+ NORM: 'IN'
+ DISCRIMINATOR:
+ LEAKY_RELU: False
+ INPUT_CHANNELS: 6
+ CVAE:
+ AUDIO_EMB_IN_SIZE: 512
+ AUDIO_EMB_OUT_SIZE: 6
+ SEQ_LEN: 32
+ LATENT_SIZE: 64
+ ENCODER_LAYER_SIZES: [192, 128]
+ DECODER_LAYER_SIZES: [128, 192]
+
+
+TRAIN:
+ MAX_EPOCH: 150
+ GENERATOR:
+ LR: 1.0e-4
+ DISCRIMINATOR:
+ LR: 1.0e-4
+ LOSS:
+ LAMBDA_REG: 1
+ LAMBDA_LANDMARKS: 0
+ LAMBDA_VERTICES: 0
+ LAMBDA_GAN_MOTION: 0.7
+ LAMBDA_GAN_COEFF: 0
+ LAMBDA_KL: 1
+
+TAG:
+ NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
+
+
diff --git a/sadtalker_audio2pose/src/config/facerender.yaml b/sadtalker_audio2pose/src/config/facerender.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dd1e1ddfe265698e49dac4a6e103cba0aac4f3ce
--- /dev/null
+++ b/sadtalker_audio2pose/src/config/facerender.yaml
@@ -0,0 +1,45 @@
+model_params:
+ common_params:
+ num_kp: 15
+ image_channel: 3
+ feature_channel: 32
+ estimate_jacobian: False # True
+ kp_detector_params:
+ temperature: 0.1
+ block_expansion: 32
+ max_features: 1024
+ scale_factor: 0.25 # 0.25
+ num_blocks: 5
+ reshape_channel: 16384 # 16384 = 1024 * 16
+ reshape_depth: 16
+ he_estimator_params:
+ block_expansion: 64
+ max_features: 2048
+ num_bins: 66
+ generator_params:
+ block_expansion: 64
+ max_features: 512
+ num_down_blocks: 2
+ reshape_channel: 32
+ reshape_depth: 16 # 512 = 32 * 16
+ num_resblocks: 6
+ estimate_occlusion_map: True
+ dense_motion_params:
+ block_expansion: 32
+ max_features: 1024
+ num_blocks: 5
+ reshape_depth: 16
+ compress: 4
+ discriminator_params:
+ scales: [1]
+ block_expansion: 32
+ max_features: 512
+ num_blocks: 4
+ sn: True
+ mapping_params:
+ coeff_nc: 70
+ descriptor_nc: 1024
+ layer: 3
+ num_kp: 15
+ num_bins: 66
+
diff --git a/sadtalker_audio2pose/src/config/facerender_pirender.yaml b/sadtalker_audio2pose/src/config/facerender_pirender.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f893b5d0a22f0546642c2d2bdafda88740c81138
--- /dev/null
+++ b/sadtalker_audio2pose/src/config/facerender_pirender.yaml
@@ -0,0 +1,83 @@
+# How often do you want to log the training stats.
+# network_list:
+# gen: gen_optimizer
+# dis: dis_optimizer
+
+distributed: False
+image_to_tensorboard: True
+snapshot_save_iter: 40000
+snapshot_save_epoch: 20
+snapshot_save_start_iter: 20000
+snapshot_save_start_epoch: 10
+image_save_iter: 1000
+max_epoch: 200
+logging_iter: 100
+results_dir: ./eval_results
+
+gen_optimizer:
+ type: adam
+ lr: 0.0001
+ adam_beta1: 0.5
+ adam_beta2: 0.999
+ lr_policy:
+ iteration_mode: True
+ type: step
+ step_size: 300000
+ gamma: 0.2
+
+trainer:
+ type: trainers.face_trainer::FaceTrainer
+ pretrain_warp_iteration: 200000
+ loss_weight:
+ weight_perceptual_warp: 2.5
+ weight_perceptual_final: 4
+ vgg_param_warp:
+ network: vgg19
+ layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+ use_style_loss: False
+ num_scales: 4
+ vgg_param_final:
+ network: vgg19
+ layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+ use_style_loss: True
+ num_scales: 4
+ style_to_perceptual: 250
+ init:
+ type: 'normal'
+ gain: 0.02
+gen:
+ type: generators.face_model::FaceGenerator
+ param:
+ mapping_net:
+ coeff_nc: 73
+ descriptor_nc: 256
+ layer: 3
+ warpping_net:
+ encoder_layer: 5
+ decoder_layer: 3
+ base_nc: 32
+ editing_net:
+ layer: 3
+ num_res_blocks: 2
+ base_nc: 64
+ common:
+ image_nc: 3
+ descriptor_nc: 256
+ max_nc: 256
+ use_spect: False
+
+
+# Data options.
+data:
+ type: data.vox_dataset::VoxDataset
+ path: ./dataset/vox_lmdb
+ resolution: 256
+ semantic_radius: 13
+ train:
+ batch_size: 5
+ distributed: True
+ val:
+ batch_size: 8
+ distributed: True
+
+
diff --git a/sadtalker_audio2pose/src/config/facerender_still.yaml b/sadtalker_audio2pose/src/config/facerender_still.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d6b84181763caf7184a0769e53a7e419e2e3f604
--- /dev/null
+++ b/sadtalker_audio2pose/src/config/facerender_still.yaml
@@ -0,0 +1,45 @@
+model_params:
+ common_params:
+ num_kp: 15
+ image_channel: 3
+ feature_channel: 32
+ estimate_jacobian: False # True
+ kp_detector_params:
+ temperature: 0.1
+ block_expansion: 32
+ max_features: 1024
+ scale_factor: 0.25 # 0.25
+ num_blocks: 5
+ reshape_channel: 16384 # 16384 = 1024 * 16
+ reshape_depth: 16
+ he_estimator_params:
+ block_expansion: 64
+ max_features: 2048
+ num_bins: 66
+ generator_params:
+ block_expansion: 64
+ max_features: 512
+ num_down_blocks: 2
+ reshape_channel: 32
+ reshape_depth: 16 # 512 = 32 * 16
+ num_resblocks: 6
+ estimate_occlusion_map: True
+ dense_motion_params:
+ block_expansion: 32
+ max_features: 1024
+ num_blocks: 5
+ reshape_depth: 16
+ compress: 4
+ discriminator_params:
+ scales: [1]
+ block_expansion: 32
+ max_features: 512
+ num_blocks: 4
+ sn: True
+ mapping_params:
+ coeff_nc: 73
+ descriptor_nc: 1024
+ layer: 3
+ num_kp: 15
+ num_bins: 66
+
diff --git a/sadtalker_audio2pose/src/config/similarity_Lm3D_all.mat b/sadtalker_audio2pose/src/config/similarity_Lm3D_all.mat
new file mode 100644
index 0000000000000000000000000000000000000000..9f5b0bd4ecffb926128a29cb1bbf9d9081c3d4e7
--- /dev/null
+++ b/sadtalker_audio2pose/src/config/similarity_Lm3D_all.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53b83ce6e35c50ddc3e97603650cef4970320c157e75c241c844f29c1dcba65a
+size 994
diff --git a/sadtalker_audio2pose/src/face3d/data/__init__.py b/sadtalker_audio2pose/src/face3d/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..be2378c5877af8e749db18d8a67a382f3eb0912b
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/data/__init__.py
@@ -0,0 +1,116 @@
+"""This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+ -- <__len__>: return the size of dataset.
+ -- <__getitem__>: get a data point from data loader.
+ -- : (optionally) add dataset-specific options and set default options.
+
+Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+See our template dataset class 'template_dataset.py' for more details.
+"""
+import numpy as np
+import importlib
+import torch.utils.data
+from face3d.data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+ """Import the module "data/[dataset_name]_dataset.py".
+
+ In the file, the class called DatasetNameDataset() will
+ be instantiated. It has to be a subclass of BaseDataset,
+ and it is case-insensitive.
+ """
+ dataset_filename = "data." + dataset_name + "_dataset"
+ datasetlib = importlib.import_module(dataset_filename)
+
+ dataset = None
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+ for name, cls in datasetlib.__dict__.items():
+ if name.lower() == target_dataset_name.lower() \
+ and issubclass(cls, BaseDataset):
+ dataset = cls
+
+ if dataset is None:
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+ return dataset
+
+
+def get_option_setter(dataset_name):
+ """Return the static method of the dataset class."""
+ dataset_class = find_dataset_using_name(dataset_name)
+ return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt, rank=0):
+ """Create a dataset given the option.
+
+ This function wraps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from data import create_dataset
+ >>> dataset = create_dataset(opt)
+ """
+ data_loader = CustomDatasetDataLoader(opt, rank=rank)
+ dataset = data_loader.load_data()
+ return dataset
+
+class CustomDatasetDataLoader():
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+ def __init__(self, opt, rank=0):
+ """Initialize this class
+
+ Step 1: create a dataset instance given the name [dataset_mode]
+ Step 2: create a multi-threaded data loader.
+ """
+ self.opt = opt
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
+ self.dataset = dataset_class(opt)
+ self.sampler = None
+ print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
+ if opt.use_ddp and opt.isTrain:
+ world_size = opt.world_size
+ self.sampler = torch.utils.data.distributed.DistributedSampler(
+ self.dataset,
+ num_replicas=world_size,
+ rank=rank,
+ shuffle=not opt.serial_batches
+ )
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ sampler=self.sampler,
+ num_workers=int(opt.num_threads / world_size),
+ batch_size=int(opt.batch_size / world_size),
+ drop_last=True)
+ else:
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batch_size,
+ shuffle=(not opt.serial_batches) and opt.isTrain,
+ num_workers=int(opt.num_threads),
+ drop_last=True
+ )
+
+ def set_epoch(self, epoch):
+ self.dataset.current_epoch = epoch
+ if self.sampler is not None:
+ self.sampler.set_epoch(epoch)
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ """Return the number of data in the dataset"""
+ return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ """Return a batch of data"""
+ for i, data in enumerate(self.dataloader):
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
+ break
+ yield data
diff --git a/sadtalker_audio2pose/src/face3d/data/base_dataset.py b/sadtalker_audio2pose/src/face3d/data/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..34a7ea5024206e6e58c2f404ac6a1bf0987f5fd4
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/data/base_dataset.py
@@ -0,0 +1,125 @@
+"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+"""
+import random
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+from abc import ABC, abstractmethod
+
+
+class BaseDataset(data.Dataset, ABC):
+ """This class is an abstract base class (ABC) for datasets.
+
+ To create a subclass, you need to implement the following four functions:
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+ -- <__len__>: return the size of dataset.
+ -- <__getitem__>: get a data point.
+ -- : (optionally) add dataset-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the class; save the options in the class
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ self.opt = opt
+ # self.root = opt.dataroot
+ self.current_epoch = 0
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def __len__(self):
+ """Return the total number of images in the dataset."""
+ return 0
+
+ @abstractmethod
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index - - a random integer for data indexing
+
+ Returns:
+ a dictionary of data with their names. It ususally contains the data itself and its metadata information.
+ """
+ pass
+
+
+def get_transform(grayscale=False):
+ transform_list = []
+ if grayscale:
+ transform_list.append(transforms.Grayscale(1))
+ transform_list += [transforms.ToTensor()]
+ return transforms.Compose(transform_list)
+
+def get_affine_mat(opt, size):
+ shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
+ w, h = size
+
+ if 'shift' in opt.preprocess:
+ shift_pixs = int(opt.shift_pixs)
+ shift_x = random.randint(-shift_pixs, shift_pixs)
+ shift_y = random.randint(-shift_pixs, shift_pixs)
+ if 'scale' in opt.preprocess:
+ scale = 1 + opt.scale_delta * (2 * random.random() - 1)
+ if 'rot' in opt.preprocess:
+ rot_angle = opt.rot_angle * (2 * random.random() - 1)
+ rot_rad = -rot_angle * np.pi/180
+ if 'flip' in opt.preprocess:
+ flip = random.random() > 0.5
+
+ shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
+ flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
+ shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
+ rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
+ scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
+ shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
+
+ affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
+ affine_inv = np.linalg.inv(affine)
+ return affine, affine_inv, flip
+
+def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
+ return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
+
+def apply_lm_affine(landmark, affine, flip, size):
+ _, h = size
+ lm = landmark.copy()
+ lm[:, 1] = h - 1 - lm[:, 1]
+ lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
+ lm = lm @ np.transpose(affine)
+ lm[:, :2] = lm[:, :2] / lm[:, 2:]
+ lm = lm[:, :2]
+ lm[:, 1] = h - 1 - lm[:, 1]
+ if flip:
+ lm_ = lm.copy()
+ lm_[:17] = lm[16::-1]
+ lm_[17:22] = lm[26:21:-1]
+ lm_[22:27] = lm[21:16:-1]
+ lm_[31:36] = lm[35:30:-1]
+ lm_[36:40] = lm[45:41:-1]
+ lm_[40:42] = lm[47:45:-1]
+ lm_[42:46] = lm[39:35:-1]
+ lm_[46:48] = lm[41:39:-1]
+ lm_[48:55] = lm[54:47:-1]
+ lm_[55:60] = lm[59:54:-1]
+ lm_[60:65] = lm[64:59:-1]
+ lm_[65:68] = lm[67:64:-1]
+ lm = lm_
+ return lm
diff --git a/sadtalker_audio2pose/src/face3d/data/flist_dataset.py b/sadtalker_audio2pose/src/face3d/data/flist_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b49caa8020f8e9aedb73a839b7112320cad68a
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/data/flist_dataset.py
@@ -0,0 +1,125 @@
+"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
+"""
+
+import os.path
+from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
+from data.image_folder import make_dataset
+from PIL import Image
+import random
+import util.util as util
+import numpy as np
+import json
+import torch
+from scipy.io import loadmat, savemat
+import pickle
+from util.preprocess import align_img, estimate_norm
+from util.load_mats import load_lm3d
+
+
+def default_flist_reader(flist):
+ """
+ flist format: impath label\nimpath label\n ...(same to caffe's filelist)
+ """
+ imlist = []
+ with open(flist, 'r') as rf:
+ for line in rf.readlines():
+ impath = line.strip()
+ imlist.append(impath)
+
+ return imlist
+
+def jason_flist_reader(flist):
+ with open(flist, 'r') as fp:
+ info = json.load(fp)
+ return info
+
+def parse_label(label):
+ return torch.tensor(np.array(label).astype(np.float32))
+
+
+class FlistDataset(BaseDataset):
+ """
+ It requires one directories to host training images '/path/to/data/train'
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
+ """
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ BaseDataset.__init__(self, opt)
+
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
+
+ msk_names = default_flist_reader(opt.flist)
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
+
+ self.size = len(self.msk_paths)
+ self.opt = opt
+
+ self.name = 'train' if opt.isTrain else 'val'
+ if '_' in opt.flist:
+ self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
+
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index (int) -- a random integer for data indexing
+
+ Returns a dictionary that contains A, B, A_paths and B_paths
+ img (tensor) -- an image in the input domain
+ msk (tensor) -- its corresponding attention mask
+ lm (tensor) -- its corresponding 3d landmarks
+ im_paths (str) -- image paths
+ aug_flag (bool) -- a flag used to tell whether its raw or augmented
+ """
+ msk_path = self.msk_paths[index % self.size] # make sure index is within then range
+ img_path = msk_path.replace('mask/', '')
+ lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
+
+ raw_img = Image.open(img_path).convert('RGB')
+ raw_msk = Image.open(msk_path).convert('RGB')
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
+
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
+
+ aug_flag = self.opt.use_aug and self.opt.isTrain
+ if aug_flag:
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
+
+ _, H = img.size
+ M = estimate_norm(lm, H)
+ transform = get_transform()
+ img_tensor = transform(img)
+ msk_tensor = transform(msk)[:1, ...]
+ lm_tensor = parse_label(lm)
+ M_tensor = parse_label(M)
+
+
+ return {'imgs': img_tensor,
+ 'lms': lm_tensor,
+ 'msks': msk_tensor,
+ 'M': M_tensor,
+ 'im_paths': img_path,
+ 'aug_flag': aug_flag,
+ 'dataset': self.name}
+
+ def _augmentation(self, img, lm, opt, msk=None):
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
+ img = apply_img_affine(img, affine_inv)
+ lm = apply_lm_affine(lm, affine, flip, img.size)
+ if msk is not None:
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
+ return img, lm, msk
+
+
+
+
+ def __len__(self):
+ """Return the total number of images in the dataset.
+ """
+ return self.size
diff --git a/sadtalker_audio2pose/src/face3d/data/image_folder.py b/sadtalker_audio2pose/src/face3d/data/image_folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ef069029b0db1fc40b9b5f9a6f52a48c1cd162
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/data/image_folder.py
@@ -0,0 +1,66 @@
+"""A modified image folder class
+
+We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
+so that this class can load images from both current directory and its subdirectories.
+"""
+import numpy as np
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+ '.tif', '.TIF', '.tiff', '.TIFF',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir, max_dataset_size=float("inf")):
+ images = []
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
+
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+ return images[:min(max_dataset_size, len(images))]
+
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+ def __init__(self, root, transform=None, return_paths=False,
+ loader=default_loader):
+ imgs = make_dataset(root)
+ if len(imgs) == 0:
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+
+ self.root = root
+ self.imgs = imgs
+ self.transform = transform
+ self.return_paths = return_paths
+ self.loader = loader
+
+ def __getitem__(self, index):
+ path = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.return_paths:
+ return img, path
+ else:
+ return img
+
+ def __len__(self):
+ return len(self.imgs)
diff --git a/sadtalker_audio2pose/src/face3d/data/template_dataset.py b/sadtalker_audio2pose/src/face3d/data/template_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..693b6b09085ad424e53f26e0938b61eea30ed644
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/data/template_dataset.py
@@ -0,0 +1,75 @@
+"""Dataset class template
+
+This module provides a template for users to implement custom datasets.
+You can specify '--dataset_mode template' to use this dataset.
+The class name should be consistent with both the filename and its dataset_mode option.
+The filename should be _dataset.py
+The class name should be Dataset.py
+You need to implement the following functions:
+ -- : Add dataset-specific options and rewrite default values for existing options.
+ -- <__init__>: Initialize this dataset class.
+ -- <__getitem__>: Return a data point and its metadata information.
+ -- <__len__>: Return the number of images.
+"""
+from data.base_dataset import BaseDataset, get_transform
+# from data.image_folder import make_dataset
+# from PIL import Image
+
+
+class TemplateDataset(BaseDataset):
+ """A template dataset class for you to implement custom datasets."""
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ A few things can be done here.
+ - save the options (have been done in BaseDataset)
+ - get image paths and meta information of the dataset.
+ - define the image transformation.
+ """
+ # save the option and dataset root
+ BaseDataset.__init__(self, opt)
+ # get the image paths of your dataset;
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
+ # define the default transform function. You can use ; You can also define your custom transform function
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index -- a random integer for data indexing
+
+ Returns:
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
+
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
+ Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
+ Step 4: return a data point as a dictionary.
+ """
+ path = 'temp' # needs to be a string
+ data_A = None # needs to be a tensor
+ data_B = None # needs to be a tensor
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
+
+ def __len__(self):
+ """Return the total number of images."""
+ return len(self.image_paths)
diff --git a/sadtalker_audio2pose/src/face3d/extract_kp_videos.py b/sadtalker_audio2pose/src/face3d/extract_kp_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..68dd79badafd406113ee85cde83492b6c7c66a9b
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/extract_kp_videos.py
@@ -0,0 +1,108 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import face_alignment
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from itertools import cycle
+
+from torch.multiprocessing import Pool, Process, set_start_method
+
+class KeypointExtractor():
+ def __init__(self, device):
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
+ device=device)
+
+ def extract_keypoint(self, images, name=None, info=True):
+ if isinstance(images, list):
+ keypoints = []
+ if info:
+ i_range = tqdm(images,desc='landmark Det:')
+ else:
+ i_range = images
+
+ for image in i_range:
+ current_kp = self.extract_keypoint(image)
+ if np.mean(current_kp) == -1 and keypoints:
+ keypoints.append(keypoints[-1])
+ else:
+ keypoints.append(current_kp[None])
+
+ keypoints = np.concatenate(keypoints, 0)
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+ else:
+ while True:
+ try:
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
+ break
+ except RuntimeError as e:
+ if str(e).startswith('CUDA'):
+ print("Warning: out of memory, sleep for 1s")
+ time.sleep(1)
+ else:
+ print(e)
+ break
+ except TypeError:
+ print('No face detected in this image')
+ shape = [68, 2]
+ keypoints = -1. * np.ones(shape)
+ break
+ if name is not None:
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+
+def read_video(filename):
+ frames = []
+ cap = cv2.VideoCapture(filename)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frame = Image.fromarray(frame)
+ frames.append(frame)
+ else:
+ break
+ cap.release()
+ return frames
+
+def run(data):
+ filename, opt, device = data
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
+ kp_extractor = KeypointExtractor()
+ images = read_video(filename)
+ name = filename.split('/')[-2:]
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+ kp_extractor.extract_keypoint(
+ images,
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
+ )
+
+if __name__ == '__main__':
+ set_start_method('spawn')
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+ parser.add_argument('--device_ids', type=str, default='0,1')
+ parser.add_argument('--workers', type=int, default=4)
+
+ opt = parser.parse_args()
+ filenames = list()
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+ extensions = VIDEO_EXTENSIONS
+
+ for ext in extensions:
+ os.listdir(f'{opt.input_dir}')
+ print(f'{opt.input_dir}/*.{ext}')
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
+ print('Total number of videos:', len(filenames))
+ pool = Pool(opt.workers)
+ args_list = cycle([opt])
+ device_ids = opt.device_ids.split(",")
+ device_ids = cycle(device_ids)
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+ None
diff --git a/sadtalker_audio2pose/src/face3d/extract_kp_videos_safe.py b/sadtalker_audio2pose/src/face3d/extract_kp_videos_safe.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbe5a01151d3e3722b4a6e3e041fd4f352eee9e8
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/extract_kp_videos_safe.py
@@ -0,0 +1,146 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import numpy as np
+from PIL import Image
+import torch
+from tqdm import tqdm
+from itertools import cycle
+from torch.multiprocessing import Pool, Process, set_start_method
+
+from facexlib.alignment import landmark_98_to_68
+from facexlib.detection import init_detection_model
+
+from facexlib.utils import load_file_from_url
+from facexlib.alignment.awing_arch import FAN
+
+def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
+ if model_name == 'awing_fan':
+ model = FAN(num_modules=4, num_landmarks=98, device=device)
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(
+ url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
+ model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
+ model.eval()
+ model = model.to(device)
+ return model
+
+
+class KeypointExtractor():
+ def __init__(self, device='cuda'):
+
+ ### gfpgan/weights
+ root_path = 'ckpts/gfpgan'
+
+ self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
+ self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
+
+ def extract_keypoint(self, images, name=None, info=True):
+ if isinstance(images, list):
+ keypoints = []
+ if info:
+ i_range = tqdm(images,desc='landmark Det:')
+ else:
+ i_range = images
+
+ for image in i_range:
+ current_kp = self.extract_keypoint(image)
+ # current_kp = self.detector.get_landmarks(np.array(image))
+ if np.mean(current_kp) == -1 and keypoints:
+ keypoints.append(keypoints[-1])
+ else:
+ keypoints.append(current_kp[None])
+
+ keypoints = np.concatenate(keypoints, 0)
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+ else:
+ while True:
+ try:
+ with torch.no_grad():
+ # face detection -> face alignment.
+ img = np.array(images)
+ bboxes = self.det_net.detect_faces(images, 0.97)
+
+ bboxes = bboxes[0]
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
+
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
+
+ #### keypoints to the original location
+ keypoints[:,0] += int(bboxes[0])
+ keypoints[:,1] += int(bboxes[1])
+
+ break
+ except RuntimeError as e:
+ if str(e).startswith('CUDA'):
+ print("Warning: out of memory, sleep for 1s")
+ time.sleep(1)
+ else:
+ print(e)
+ break
+ except TypeError:
+ print('No face detected in this image')
+ shape = [68, 2]
+ keypoints = -1. * np.ones(shape)
+ break
+ if name is not None:
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+
+def read_video(filename):
+ frames = []
+ cap = cv2.VideoCapture(filename)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frame = Image.fromarray(frame)
+ frames.append(frame)
+ else:
+ break
+ cap.release()
+ return frames
+
+def run(data):
+ filename, opt, device = data
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
+ kp_extractor = KeypointExtractor()
+ images = read_video(filename)
+ name = filename.split('/')[-2:]
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+ kp_extractor.extract_keypoint(
+ images,
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
+ )
+
+if __name__ == '__main__':
+ set_start_method('spawn')
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+ parser.add_argument('--device_ids', type=str, default='0,1')
+ parser.add_argument('--workers', type=int, default=4)
+
+ opt = parser.parse_args()
+ filenames = list()
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+ extensions = VIDEO_EXTENSIONS
+
+ for ext in extensions:
+ os.listdir(f'{opt.input_dir}')
+ print(f'{opt.input_dir}/*.{ext}')
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
+ print('Total number of videos:', len(filenames))
+ pool = Pool(opt.workers)
+ args_list = cycle([opt])
+ device_ids = opt.device_ids.split(",")
+ device_ids = cycle(device_ids)
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+ None
diff --git a/sadtalker_audio2pose/src/face3d/models/__init__.py b/sadtalker_audio2pose/src/face3d/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef6b5e399254bd42850f3385878f35d4acf90852
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- : unpack data from dataset and apply preprocessing.
+ -- : produce intermediate results.
+ -- : calculate loss, gradients, and update network weights.
+ -- : (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from src.face3d.models.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+ """Import the module "models/[model_name]_model.py".
+
+ In the file, the class called DatasetNameModel() will
+ be instantiated. It has to be a subclass of BaseModel,
+ and it is case-insensitive.
+ """
+ model_filename = "face3d.models." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+ model = None
+ target_model_name = model_name.replace('_', '') + 'model'
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() \
+ and issubclass(cls, BaseModel):
+ model = cls
+
+ if model is None:
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+ exit(0)
+
+ return model
+
+
+def get_option_setter(model_name):
+ """Return the static method of the model class."""
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+def create_model(opt):
+ """Create a model given the option.
+
+ This function warps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from models import create_model
+ >>> model = create_model(opt)
+ """
+ model = find_model_using_name(opt.model)
+ instance = model(opt)
+ print("model [%s] was created" % type(instance).__name__)
+ return instance
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/README.md b/sadtalker_audio2pose/src/face3d/models/arcface_torch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cc7f1d45f2f5e4b752c42dc81d3e2879c1459c6e
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/README.md
@@ -0,0 +1,164 @@
+# Distributed Arcface Training in Pytorch
+
+This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions
+identity on a single server.
+
+## Requirements
+
+- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
+- `pip install -r requirements.txt`.
+- Download the dataset
+ from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
+ .
+
+## How to Training
+
+To train a model, run `train.py` with the path to the configs:
+
+### 1. Single node, 8 GPUs:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
+```
+
+### 2. Multiple nodes, each node 8 GPUs:
+
+Node 0:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
+```
+
+Node 1:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
+```
+
+### 3.Training resnet2060 with 8 GPUs:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
+```
+
+## Model Zoo
+
+- The models are available for non-commercial research purposes only.
+- All models can be found in here.
+- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
+- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
+
+### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
+
+ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face
+recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities.
+As the result, we can evaluate the FAIR performance for different algorithms.
+
+For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The
+globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
+
+For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4).
+Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
+There are totally 13,928 positive pairs and 96,983,824 negative pairs.
+
+| Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
+| :---: | :--- | :--- | :--- |:--- |:--- |
+| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
+| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
+| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
+| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
+| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
+| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
+| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
+| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
+| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
+| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
+
+### Performance on IJB-C and Verification Datasets
+
+| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
+| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
+| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
+| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
+| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
+| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
+| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
+| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
+| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
+| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
+| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
+
+[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.)
+
+
+## [Speed Benchmark](docs/speed_benchmark.md)
+
+**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of
+classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same
+accuracy with several times faster training performance and smaller GPU memory.
+Partial FC is a sparse variant of the model parallel architecture for large sacle face recognition. Partial FC use a
+sparse softmax, where each batch dynamicly sample a subset of class centers for training. In each iteration, only a
+sparse part of the parameters will be updated, which can reduce a lot of GPU memory and calculations. With Partial FC,
+we can scale trainset of 29 millions identities, the largest to date. Partial FC also supports multi-machine distributed
+training and mixed precision training.
+
+![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
+
+More details see
+[speed_benchmark.md](docs/speed_benchmark.md) in docs.
+
+### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
+
+`-` means training failed because of gpu memory limitations.
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|1400000 | **1672** | 3043 | 4738 |
+|5500000 | **-** | **1389** | 3975 |
+|8000000 | **-** | **-** | 3565 |
+|16000000 | **-** | **-** | 2679 |
+|29000000 | **-** | **-** | **1855** |
+
+### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 7358 | 5306 | 4868 |
+|1400000 | 32252 | 11178 | 6056 |
+|5500000 | **-** | 32188 | 9854 |
+|8000000 | **-** | **-** | 12310 |
+|16000000 | **-** | **-** | 19950 |
+|29000000 | **-** | **-** | 32324 |
+
+## Evaluation ICCV2021-MFR and IJB-C
+
+More details see [eval.md](docs/eval.md) in docs.
+
+## Test
+
+We tested many versions of PyTorch. Please create an issue if you are having trouble.
+
+- [x] torch 1.6.0
+- [x] torch 1.7.1
+- [x] torch 1.8.0
+- [x] torch 1.9.0
+
+## Citation
+
+```
+@inproceedings{deng2019arcface,
+ title={Arcface: Additive angular margin loss for deep face recognition},
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={4690--4699},
+ year={2019}
+}
+@inproceedings{an2020partical_fc,
+ title={Partial FC: Training 10 Million Identities on a Single Machine},
+ author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
+ Zhang, Debing and Fu Ying},
+ booktitle={Arxiv 2010.05222},
+ year={2020}
+}
+```
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/__init__.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5650187b4fdea84c5a23e0445440901690ab682a
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/__init__.py
@@ -0,0 +1,25 @@
+from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
+from .mobilefacenet import get_mbf
+
+
+def get_model(name, **kwargs):
+ # resnet
+ if name == "r18":
+ return iresnet18(False, **kwargs)
+ elif name == "r34":
+ return iresnet34(False, **kwargs)
+ elif name == "r50":
+ return iresnet50(False, **kwargs)
+ elif name == "r100":
+ return iresnet100(False, **kwargs)
+ elif name == "r200":
+ return iresnet200(False, **kwargs)
+ elif name == "r2060":
+ from .iresnet2060 import iresnet2060
+ return iresnet2060(False, **kwargs)
+ elif name == "mbf":
+ fp16 = kwargs.get("fp16", False)
+ num_features = kwargs.get("num_features", 512)
+ return get_mbf(fp16=fp16, num_features=num_features)
+ else:
+ raise ValueError()
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29f5f2bfbd444273717c4bc8aa20ba7edd08f80
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet.py
@@ -0,0 +1,187 @@
+import torch
+from torch import nn
+
+__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+ def __init__(self,
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+ super(IResNet, self).__init__()
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block,
+ 128,
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block,
+ 256,
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block,
+ 512,
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+ )
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet18(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
+ progress, **kwargs)
+
+
+def iresnet34(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet50(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet100(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet200(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
+ progress, **kwargs)
+
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..39bb4335716b653bd5924e20d616d825ef48339f
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py
@@ -0,0 +1,176 @@
+import torch
+from torch import nn
+
+assert torch.__version__ >= "1.8.1"
+from torch.utils.checkpoint import checkpoint_sequential
+
+__all__ = ['iresnet2060']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+
+ def __init__(self,
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+ super(IResNet, self).__init__()
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block,
+ 128,
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block,
+ 256,
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block,
+ 512,
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+ )
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation))
+
+ return nn.Sequential(*layers)
+
+ def checkpoint(self, func, num_seg, x):
+ if self.training:
+ return checkpoint_sequential(func, num_seg, x)
+ else:
+ return func(x)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.checkpoint(self.layer2, 20, x)
+ x = self.checkpoint(self.layer3, 100, x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet2060(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02c6c1e4fa6a6ddf09f5b01dec96971427cb110
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py
@@ -0,0 +1,130 @@
+'''
+Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
+Original author cavalleria
+'''
+
+import torch.nn as nn
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
+import torch
+
+
+class Flatten(Module):
+ def forward(self, x):
+ return x.view(x.size(0), -1)
+
+
+class ConvBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(ConvBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
+ BatchNorm2d(num_features=out_c),
+ PReLU(num_parameters=out_c)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class LinearBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(LinearBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
+ BatchNorm2d(num_features=out_c)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class DepthWise(Module):
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
+ super(DepthWise, self).__init__()
+ self.residual = residual
+ self.layers = nn.Sequential(
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
+ )
+
+ def forward(self, x):
+ short_cut = None
+ if self.residual:
+ short_cut = x
+ x = self.layers(x)
+ if self.residual:
+ output = short_cut + x
+ else:
+ output = x
+ return output
+
+
+class Residual(Module):
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
+ super(Residual, self).__init__()
+ modules = []
+ for _ in range(num_block):
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
+ self.layers = Sequential(*modules)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class GDC(Module):
+ def __init__(self, embedding_size):
+ super(GDC, self).__init__()
+ self.layers = nn.Sequential(
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
+ Flatten(),
+ Linear(512, embedding_size, bias=False),
+ BatchNorm1d(embedding_size))
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class MobileFaceNet(Module):
+ def __init__(self, fp16=False, num_features=512):
+ super(MobileFaceNet, self).__init__()
+ scale = 2
+ self.fp16 = fp16
+ self.layers = nn.Sequential(
+ ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
+ ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
+ DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
+ Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
+ Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
+ Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ )
+ self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
+ self.features = GDC(num_features)
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.layers(x)
+ x = self.conv_sep(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def get_mbf(fp16, num_features):
+ return MobileFaceNet(fp16, num_features)
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bee7cb4236e8b842a1bd1e8c26de7a11df0bf43
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 300 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf7df5f04e2509e5dcc14adebbb9302a18f03f2b
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 300 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/__init__.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/base.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f98c62fed44afde276dcbacecd9da0a8f474963c
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/base.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = "ms1mv3_arcface_r50"
+
+config.dataset = "ms1m-retinaface-t1"
+config.embedding_size = 512
+config.sample_rate = 1
+config.fp16 = False
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+if config.dataset == "emore":
+ config.rec = "/train_tmp/faces_emore"
+ config.num_classes = 85742
+ config.num_image = 5822653
+ config.num_epoch = 16
+ config.warmup_epoch = -1
+ config.decay_epoch = [8, 14, ]
+ config.val_targets = ["lfw", ]
+
+elif config.dataset == "ms1m-retinaface-t1":
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
+ config.num_classes = 93431
+ config.num_image = 5179510
+ config.num_epoch = 25
+ config.warmup_epoch = -1
+ config.decay_epoch = [11, 17, 22]
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
+
+elif config.dataset == "glint360k":
+ config.rec = "/train_tmp/glint360k"
+ config.num_classes = 360232
+ config.num_image = 17091657
+ config.num_epoch = 20
+ config.warmup_epoch = -1
+ config.decay_epoch = [8, 12, 15, 18]
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
+
+elif config.dataset == "webface":
+ config.rec = "/train_tmp/faces_webface_112x112"
+ config.num_classes = 10572
+ config.num_image = "forget"
+ config.num_epoch = 34
+ config.warmup_epoch = -1
+ config.decay_epoch = [20, 28, 32]
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ee5e8d96249d57196df43418f6fda4ab339877
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 2e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8f8ef745c0efb9d5ea67409edc8c904def8a9d9
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..473b59a954fffcaddca132fb6e0f32cbe70c70f4
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c22ff0c82cc98bbbe81c9a1c26c9b3fc186105
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r34"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ecbfda06730e3842e7b347db366e82f0714912f
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..47c87a99867db55c7f689574c331c14cda23ea96
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 2e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 20, 25]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aeb851b05ea22e01da87b3d387812f0253989f8
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..8693e67080dac7e7b84da08a62df326c7b12d465
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r2060"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 64
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py
new file mode 100644
index 0000000000000000000000000000000000000000..52bff483db179045c0e3acc8e2975477182b0756
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r34"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..de81ffdd84edd6fcea7fcb4d3594db031b9e4e26
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/speed.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/speed.py
new file mode 100644
index 0000000000000000000000000000000000000000..c172f9d44d39b534f2253630471e91cf78e6fba7
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/speed.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 100 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/dataset.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bead250243237c650fa3138f6aa172d4f98535f
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/dataset.py
@@ -0,0 +1,124 @@
+import numbers
+import os
+import queue as Queue
+import threading
+
+import mxnet as mx
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+
+class BackgroundGenerator(threading.Thread):
+ def __init__(self, generator, local_rank, max_prefetch=6):
+ super(BackgroundGenerator, self).__init__()
+ self.queue = Queue.Queue(max_prefetch)
+ self.generator = generator
+ self.local_rank = local_rank
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ torch.cuda.set_device(self.local_rank)
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def next(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __next__(self):
+ return self.next()
+
+ def __iter__(self):
+ return self
+
+
+class DataLoaderX(DataLoader):
+
+ def __init__(self, local_rank, **kwargs):
+ super(DataLoaderX, self).__init__(**kwargs)
+ self.stream = torch.cuda.Stream(local_rank)
+ self.local_rank = local_rank
+
+ def __iter__(self):
+ self.iter = super(DataLoaderX, self).__iter__()
+ self.iter = BackgroundGenerator(self.iter, self.local_rank)
+ self.preload()
+ return self
+
+ def preload(self):
+ self.batch = next(self.iter, None)
+ if self.batch is None:
+ return None
+ with torch.cuda.stream(self.stream):
+ for k in range(len(self.batch)):
+ self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+
+ def __next__(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ if batch is None:
+ raise StopIteration
+ self.preload()
+ return batch
+
+
+class MXFaceDataset(Dataset):
+ def __init__(self, root_dir, local_rank):
+ super(MXFaceDataset, self).__init__()
+ self.transform = transforms.Compose(
+ [transforms.ToPILImage(),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ])
+ self.root_dir = root_dir
+ self.local_rank = local_rank
+ path_imgrec = os.path.join(root_dir, 'train.rec')
+ path_imgidx = os.path.join(root_dir, 'train.idx')
+ self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
+ s = self.imgrec.read_idx(0)
+ header, _ = mx.recordio.unpack(s)
+ if header.flag > 0:
+ self.header0 = (int(header.label[0]), int(header.label[1]))
+ self.imgidx = np.array(range(1, int(header.label[0])))
+ else:
+ self.imgidx = np.array(list(self.imgrec.keys))
+
+ def __getitem__(self, index):
+ idx = self.imgidx[index]
+ s = self.imgrec.read_idx(idx)
+ header, img = mx.recordio.unpack(s)
+ label = header.label
+ if not isinstance(label, numbers.Number):
+ label = label[0]
+ label = torch.tensor(label, dtype=torch.long)
+ sample = mx.image.imdecode(img).asnumpy()
+ if self.transform is not None:
+ sample = self.transform(sample)
+ return sample, label
+
+ def __len__(self):
+ return len(self.imgidx)
+
+
+class SyntheticDataset(Dataset):
+ def __init__(self, local_rank):
+ super(SyntheticDataset, self).__init__()
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).squeeze(0).float()
+ img = ((img / 255) - 0.5) / 0.5
+ self.img = img
+ self.label = 1
+
+ def __getitem__(self, index):
+ return self.img, self.label
+
+ def __len__(self):
+ return 1000000
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/eval.md b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/eval.md
new file mode 100644
index 0000000000000000000000000000000000000000..4d29c855fc6e4245ed264216c1f96ab2efc57248
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/eval.md
@@ -0,0 +1,31 @@
+## Eval on ICCV2021-MFR
+
+coming soon.
+
+
+## Eval IJBC
+You can eval ijbc with pytorch or onnx.
+
+
+1. Eval IJBC With Onnx
+```shell
+CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
+```
+
+2. Eval IJBC With Pytorch
+```shell
+CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
+--model-prefix ms1mv3_arcface_r50/backbone.pth \
+--image-path IJB_release/IJBC \
+--result-dir ms1mv3_arcface_r50 \
+--batch-size 128 \
+--job ms1mv3_arcface_r50 \
+--target IJBC \
+--network iresnet50
+```
+
+## Inference
+
+```shell
+python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
+```
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/install.md b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..b1b770a0d93dac1f160185b5bbf4da2f414f21f6
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/install.md
@@ -0,0 +1,51 @@
+## v1.8.0
+### Linux and Windows
+```shell
+# CUDA 11.0
+pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 10.2
+pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
+
+# CPU only
+pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+```
+
+
+## v1.7.1
+### Linux and Windows
+```shell
+# CUDA 11.0
+pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 10.2
+pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
+
+# CUDA 10.1
+pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 9.2
+pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU only
+pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+```
+
+
+## v1.6.0
+
+### Linux and Windows
+```shell
+# CUDA 10.2
+pip install torch==1.6.0 torchvision==0.7.0
+
+# CUDA 10.1
+pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 9.2
+pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU only
+pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+```
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/modelzoo.md b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/modelzoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md
new file mode 100644
index 0000000000000000000000000000000000000000..d54904587df4e13784dc68d5709b4d7d97490890
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md
@@ -0,0 +1,93 @@
+## Test Training Speed
+
+- Test Commands
+
+You need to use the following two commands to test the Partial FC training performance.
+The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50,
+batch size is 1024.
+```shell
+# Model Parallel
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
+# Partial FC 0.1
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
+```
+
+- GPU Memory
+
+```
+# (Model Parallel) gpustat -i
+[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
+[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
+[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
+[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
+[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
+[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
+[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
+[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
+
+# (Partial FC 0.1) gpustat -i
+[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │·······················
+[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │·······················
+[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │·······················
+[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │·······················
+[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │·······················
+[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │·······················
+[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │·······················
+[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │·······················
+```
+
+- Training Speed
+
+```python
+# (Model Parallel) trainging.log
+Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+
+# (Partial FC 0.1) trainging.log
+Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+```
+
+In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel,
+and the training speed is 2.5 times faster than the model parallel.
+
+
+## Speed Benchmark
+
+1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|250000 | 4047 | 4521 | 4976 |
+|500000 | 3087 | 4013 | 4900 |
+|1000000 | 2090 | 3449 | 4803 |
+|1400000 | 1672 | 3043 | 4738 |
+|2000000 | - | 2593 | 4626 |
+|4000000 | - | 1748 | 4208 |
+|5500000 | - | 1389 | 3975 |
+|8000000 | - | - | 3565 |
+|16000000 | - | - | 2679 |
+|29000000 | - | - | 1855 |
+
+2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 7358 | 5306 | 4868 |
+|250000 | 9940 | 5826 | 5004 |
+|500000 | 14220 | 7114 | 5202 |
+|1000000 | 23708 | 9966 | 5620 |
+|1400000 | 32252 | 11178 | 6056 |
+|2000000 | - | 13978 | 6472 |
+|4000000 | - | 23238 | 8284 |
+|5500000 | - | 32188 | 9854 |
+|8000000 | - | - | 12310 |
+|16000000 | - | - | 19950 |
+|29000000 | - | - | 32324 |
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval/__init__.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval/verification.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval/verification.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b1f5618184effae64895847af1a65d43d2e4418
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval/verification.py
@@ -0,0 +1,407 @@
+"""Helper for evaluation on the Labeled Faces in the Wild dataset
+"""
+
+# MIT License
+#
+# Copyright (c) 2016 David Sandberg
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import datetime
+import os
+import pickle
+
+import mxnet as mx
+import numpy as np
+import sklearn
+import torch
+from mxnet import ndarray as nd
+from scipy import interpolate
+from sklearn.decomposition import PCA
+from sklearn.model_selection import KFold
+
+
+class LFold:
+ def __init__(self, n_splits=2, shuffle=False):
+ self.n_splits = n_splits
+ if self.n_splits > 1:
+ self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
+
+ def split(self, indices):
+ if self.n_splits > 1:
+ return self.k_fold.split(indices)
+ else:
+ return [(indices, indices)]
+
+
+def calculate_roc(thresholds,
+ embeddings1,
+ embeddings2,
+ actual_issame,
+ nrof_folds=10,
+ pca=0):
+ assert (embeddings1.shape[0] == embeddings2.shape[0])
+ assert (embeddings1.shape[1] == embeddings2.shape[1])
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ tprs = np.zeros((nrof_folds, nrof_thresholds))
+ fprs = np.zeros((nrof_folds, nrof_thresholds))
+ accuracy = np.zeros((nrof_folds))
+ indices = np.arange(nrof_pairs)
+
+ if pca == 0:
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+ if pca > 0:
+ print('doing pca on', fold_idx)
+ embed1_train = embeddings1[train_set]
+ embed2_train = embeddings2[train_set]
+ _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
+ pca_model = PCA(n_components=pca)
+ pca_model.fit(_embed_train)
+ embed1 = pca_model.transform(embeddings1)
+ embed2 = pca_model.transform(embeddings2)
+ embed1 = sklearn.preprocessing.normalize(embed1)
+ embed2 = sklearn.preprocessing.normalize(embed2)
+ diff = np.subtract(embed1, embed2)
+ dist = np.sum(np.square(diff), 1)
+
+ # Find the best threshold for the fold
+ acc_train = np.zeros((nrof_thresholds))
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, _, acc_train[threshold_idx] = calculate_accuracy(
+ threshold, dist[train_set], actual_issame[train_set])
+ best_threshold_index = np.argmax(acc_train)
+ for threshold_idx, threshold in enumerate(thresholds):
+ tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
+ threshold, dist[test_set],
+ actual_issame[test_set])
+ _, _, accuracy[fold_idx] = calculate_accuracy(
+ thresholds[best_threshold_index], dist[test_set],
+ actual_issame[test_set])
+
+ tpr = np.mean(tprs, 0)
+ fpr = np.mean(fprs, 0)
+ return tpr, fpr, accuracy
+
+
+def calculate_accuracy(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ tp = np.sum(np.logical_and(predict_issame, actual_issame))
+ fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ tn = np.sum(
+ np.logical_and(np.logical_not(predict_issame),
+ np.logical_not(actual_issame)))
+ fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
+
+ tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
+ fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
+ acc = float(tp + tn) / dist.size
+ return tpr, fpr, acc
+
+
+def calculate_val(thresholds,
+ embeddings1,
+ embeddings2,
+ actual_issame,
+ far_target,
+ nrof_folds=10):
+ assert (embeddings1.shape[0] == embeddings2.shape[0])
+ assert (embeddings1.shape[1] == embeddings2.shape[1])
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ val = np.zeros(nrof_folds)
+ far = np.zeros(nrof_folds)
+
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+ indices = np.arange(nrof_pairs)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+
+ # Find the threshold that gives FAR = far_target
+ far_train = np.zeros(nrof_thresholds)
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, far_train[threshold_idx] = calculate_val_far(
+ threshold, dist[train_set], actual_issame[train_set])
+ if np.max(far_train) >= far_target:
+ f = interpolate.interp1d(far_train, thresholds, kind='slinear')
+ threshold = f(far_target)
+ else:
+ threshold = 0.0
+
+ val[fold_idx], far[fold_idx] = calculate_val_far(
+ threshold, dist[test_set], actual_issame[test_set])
+
+ val_mean = np.mean(val)
+ far_mean = np.mean(far)
+ val_std = np.std(val)
+ return val_mean, val_std, far_mean
+
+
+def calculate_val_far(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
+ false_accept = np.sum(
+ np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ n_same = np.sum(actual_issame)
+ n_diff = np.sum(np.logical_not(actual_issame))
+ # print(true_accept, false_accept)
+ # print(n_same, n_diff)
+ val = float(true_accept) / float(n_same)
+ far = float(false_accept) / float(n_diff)
+ return val, far
+
+
+def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
+ # Calculate evaluation metrics
+ thresholds = np.arange(0, 4, 0.01)
+ embeddings1 = embeddings[0::2]
+ embeddings2 = embeddings[1::2]
+ tpr, fpr, accuracy = calculate_roc(thresholds,
+ embeddings1,
+ embeddings2,
+ np.asarray(actual_issame),
+ nrof_folds=nrof_folds,
+ pca=pca)
+ thresholds = np.arange(0, 4, 0.001)
+ val, val_std, far = calculate_val(thresholds,
+ embeddings1,
+ embeddings2,
+ np.asarray(actual_issame),
+ 1e-3,
+ nrof_folds=nrof_folds)
+ return tpr, fpr, accuracy, val, val_std, far
+
+@torch.no_grad()
+def load_bin(path, image_size):
+ try:
+ with open(path, 'rb') as f:
+ bins, issame_list = pickle.load(f) # py2
+ except UnicodeDecodeError as e:
+ with open(path, 'rb') as f:
+ bins, issame_list = pickle.load(f, encoding='bytes') # py3
+ data_list = []
+ for flip in [0, 1]:
+ data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
+ data_list.append(data)
+ for idx in range(len(issame_list) * 2):
+ _bin = bins[idx]
+ img = mx.image.imdecode(_bin)
+ if img.shape[1] != image_size[0]:
+ img = mx.image.resize_short(img, image_size[0])
+ img = nd.transpose(img, axes=(2, 0, 1))
+ for flip in [0, 1]:
+ if flip == 1:
+ img = mx.ndarray.flip(data=img, axis=2)
+ data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
+ if idx % 1000 == 0:
+ print('loading bin', idx)
+ print(data_list[0].shape)
+ return data_list, issame_list
+
+@torch.no_grad()
+def test(data_set, backbone, batch_size, nfolds=10):
+ print('testing verification..')
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+ _data = data[bb - batch_size: bb]
+ time0 = datetime.datetime.now()
+ img = ((_data / 255) - 0.5) / 0.5
+ net_out: torch.Tensor = backbone(img)
+ _embeddings = net_out.detach().cpu().numpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+
+ _xnorm = 0.0
+ _xnorm_cnt = 0
+ for embed in embeddings_list:
+ for i in range(embed.shape[0]):
+ _em = embed[i]
+ _norm = np.linalg.norm(_em)
+ _xnorm += _norm
+ _xnorm_cnt += 1
+ _xnorm /= _xnorm_cnt
+
+ acc1 = 0.0
+ std1 = 0.0
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ print(embeddings.shape)
+ print('infer time', time_consumed)
+ _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
+ acc2, std2 = np.mean(accuracy), np.std(accuracy)
+ return acc1, std1, acc2, std2, _xnorm, embeddings_list
+
+
+def dumpR(data_set,
+ backbone,
+ batch_size,
+ name='',
+ data_extra=None,
+ label_shape=None):
+ print('dump verification embedding..')
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+
+ _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
+ time0 = datetime.datetime.now()
+ if data_extra is None:
+ db = mx.io.DataBatch(data=(_data,), label=(_label,))
+ else:
+ db = mx.io.DataBatch(data=(_data, _data_extra),
+ label=(_label,))
+ model.forward(db, is_train=False)
+ net_out = model.get_outputs()
+ _embeddings = net_out[0].asnumpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ actual_issame = np.asarray(issame_list)
+ outname = os.path.join('temp.bin')
+ with open(outname, 'wb') as f:
+ pickle.dump((embeddings, issame_list),
+ f,
+ protocol=pickle.HIGHEST_PROTOCOL)
+
+
+# if __name__ == '__main__':
+#
+# parser = argparse.ArgumentParser(description='do verification')
+# # general
+# parser.add_argument('--data-dir', default='', help='')
+# parser.add_argument('--model',
+# default='../model/softmax,50',
+# help='path to load model.')
+# parser.add_argument('--target',
+# default='lfw,cfp_ff,cfp_fp,agedb_30',
+# help='test targets.')
+# parser.add_argument('--gpu', default=0, type=int, help='gpu id')
+# parser.add_argument('--batch-size', default=32, type=int, help='')
+# parser.add_argument('--max', default='', type=str, help='')
+# parser.add_argument('--mode', default=0, type=int, help='')
+# parser.add_argument('--nfolds', default=10, type=int, help='')
+# args = parser.parse_args()
+# image_size = [112, 112]
+# print('image_size', image_size)
+# ctx = mx.gpu(args.gpu)
+# nets = []
+# vec = args.model.split(',')
+# prefix = args.model.split(',')[0]
+# epochs = []
+# if len(vec) == 1:
+# pdir = os.path.dirname(prefix)
+# for fname in os.listdir(pdir):
+# if not fname.endswith('.params'):
+# continue
+# _file = os.path.join(pdir, fname)
+# if _file.startswith(prefix):
+# epoch = int(fname.split('.')[0].split('-')[1])
+# epochs.append(epoch)
+# epochs = sorted(epochs, reverse=True)
+# if len(args.max) > 0:
+# _max = [int(x) for x in args.max.split(',')]
+# assert len(_max) == 2
+# if len(epochs) > _max[1]:
+# epochs = epochs[_max[0]:_max[1]]
+#
+# else:
+# epochs = [int(x) for x in vec[1].split('|')]
+# print('model number', len(epochs))
+# time0 = datetime.datetime.now()
+# for epoch in epochs:
+# print('loading', prefix, epoch)
+# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
+# all_layers = sym.get_internals()
+# sym = all_layers['fc1_output']
+# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
+# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
+# image_size[1]))])
+# model.set_params(arg_params, aux_params)
+# nets.append(model)
+# time_now = datetime.datetime.now()
+# diff = time_now - time0
+# print('model loading time', diff.total_seconds())
+#
+# ver_list = []
+# ver_name_list = []
+# for name in args.target.split(','):
+# path = os.path.join(args.data_dir, name + ".bin")
+# if os.path.exists(path):
+# print('loading.. ', name)
+# data_set = load_bin(path, image_size)
+# ver_list.append(data_set)
+# ver_name_list.append(name)
+#
+# if args.mode == 0:
+# for i in range(len(ver_list)):
+# results = []
+# for model in nets:
+# acc1, std1, acc2, std2, xnorm, embeddings_list = test(
+# ver_list[i], model, args.batch_size, args.nfolds)
+# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
+# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
+# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
+# results.append(acc2)
+# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
+# elif args.mode == 1:
+# raise ValueError
+# else:
+# model = nets[0]
+# dumpR(ver_list[0], model, args.batch_size, args.target)
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval_ijbc.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..64844c4723a88b4b160d2fee9a7b626b987981d9
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval_ijbc.py
@@ -0,0 +1,483 @@
+# coding: utf-8
+
+import os
+import pickle
+
+import matplotlib
+import pandas as pd
+
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import timeit
+import sklearn
+import argparse
+import cv2
+import numpy as np
+import torch
+from skimage import transform as trans
+from backbones import get_model
+from sklearn.metrics import roc_curve, auc
+
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from pathlib import Path
+
+import sys
+import warnings
+
+sys.path.insert(0, "../")
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser(description='do ijb test')
+# general
+parser.add_argument('--model-prefix', default='', help='path to load model.')
+parser.add_argument('--image-path', default='', type=str, help='')
+parser.add_argument('--result-dir', default='.', type=str, help='')
+parser.add_argument('--batch-size', default=128, type=int, help='')
+parser.add_argument('--network', default='iresnet50', type=str, help='')
+parser.add_argument('--job', default='insightface', type=str, help='job name')
+parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+args = parser.parse_args()
+
+target = args.target
+model_path = args.model_prefix
+image_path = args.image_path
+result_dir = args.result_dir
+gpu_id = None
+use_norm_score = True # if Ture, TestMode(N1)
+use_detector_score = True # if Ture, TestMode(D1)
+use_flip_test = True # if Ture, TestMode(F1)
+job = args.job
+batch_size = args.batch_size
+
+
+class Embedding(object):
+ def __init__(self, prefix, data_shape, batch_size=1):
+ image_size = (112, 112)
+ self.image_size = image_size
+ weight = torch.load(prefix)
+ resnet = get_model(args.network, dropout=0, fp16=False).cuda()
+ resnet.load_state_dict(weight)
+ model = torch.nn.DataParallel(resnet)
+ self.model = model
+ self.model.eval()
+ src = np.array([
+ [30.2946, 51.6963],
+ [65.5318, 51.5014],
+ [48.0252, 71.7366],
+ [33.5493, 92.3655],
+ [62.7299, 92.2041]], dtype=np.float32)
+ src[:, 0] += 8.0
+ self.src = src
+ self.batch_size = batch_size
+ self.data_shape = data_shape
+
+ def get(self, rimg, landmark):
+
+ assert landmark.shape[0] == 68 or landmark.shape[0] == 5
+ assert landmark.shape[1] == 2
+ if landmark.shape[0] == 68:
+ landmark5 = np.zeros((5, 2), dtype=np.float32)
+ landmark5[0] = (landmark[36] + landmark[39]) / 2
+ landmark5[1] = (landmark[42] + landmark[45]) / 2
+ landmark5[2] = landmark[30]
+ landmark5[3] = landmark[48]
+ landmark5[4] = landmark[54]
+ else:
+ landmark5 = landmark
+ tform = trans.SimilarityTransform()
+ tform.estimate(landmark5, self.src)
+ M = tform.params[0:2, :]
+ img = cv2.warpAffine(rimg,
+ M, (self.image_size[1], self.image_size[0]),
+ borderValue=0.0)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_flip = np.fliplr(img)
+ img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB
+ img_flip = np.transpose(img_flip, (2, 0, 1))
+ input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
+ input_blob[0] = img
+ input_blob[1] = img_flip
+ return input_blob
+
+ @torch.no_grad()
+ def forward_db(self, batch_data):
+ imgs = torch.Tensor(batch_data).cuda()
+ imgs.div_(255).sub_(0.5).div_(0.5)
+ feat = self.model(imgs)
+ feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
+ return feat.cpu().numpy()
+
+
+# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[]
+def divideIntoNstrand(listTemp, n):
+ twoList = [[] for i in range(n)]
+ for i, e in enumerate(listTemp):
+ twoList[i % n].append(e)
+ return twoList
+
+
+def read_template_media_list(path):
+ # ijb_meta = np.loadtxt(path, dtype=str)
+ ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+ templates = ijb_meta[:, 1].astype(np.int)
+ medias = ijb_meta[:, 2].astype(np.int)
+ return templates, medias
+
+
+# In[ ]:
+
+
+def read_template_pair_list(path):
+ # pairs = np.loadtxt(path, dtype=str)
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ # print(pairs.shape)
+ # print(pairs[:, 0].astype(np.int))
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+# In[ ]:
+
+
+def read_image_feature(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# In[ ]:
+
+
+def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
+ batch_size = args.batch_size
+ data_shape = (3, 112, 112)
+
+ files = files_list
+ print('files:', len(files))
+ rare_size = len(files) % batch_size
+ faceness_scores = []
+ batch = 0
+ img_feats = np.empty((len(files), 1024), dtype=np.float32)
+
+ batch_data = np.empty((2 * batch_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, batch_size)
+ for img_index, each_line in enumerate(files[:len(files) - rare_size]):
+ name_lmk_score = each_line.strip().split(' ')
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
+ dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+
+ batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
+ batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
+ if (img_index + 1) % batch_size == 0:
+ print('batch', batch)
+ img_feats[batch * batch_size:batch * batch_size +
+ batch_size][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+
+ batch_data = np.empty((2 * rare_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, rare_size)
+ for img_index, each_line in enumerate(files[len(files) - rare_size:]):
+ name_lmk_score = each_line.strip().split(' ')
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
+ dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+ batch_data[2 * img_index][:] = input_blob[0]
+ batch_data[2 * img_index + 1][:] = input_blob[1]
+ if (img_index + 1) % rare_size == 0:
+ print('batch', batch)
+ img_feats[len(files) -
+ rare_size:][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
+ # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
+ return img_feats, faceness_scores
+
+
+# In[ ]:
+
+
+def image2template_feature(img_feats=None, templates=None, medias=None):
+ # ==========================================================
+ # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
+ # 2. compute media feature.
+ # 3. compute template feature.
+ # ==========================================================
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+
+ for count_template, uqt in enumerate(unique_templates):
+
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias,
+ return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [
+ np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
+ ]
+ media_norm_feats = np.array(media_norm_feats)
+ # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print('Finish Calculating {} template features.'.format(
+ count_template))
+ # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
+ template_norm_feats = sklearn.preprocessing.normalize(template_feats)
+ # print(template_norm_feats.shape)
+ return template_norm_feats, unique_templates
+
+
+# In[ ]:
+
+
+def verification(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ # ==========================================================
+ # Compute set-to-set Similarity Score.
+ # ==========================================================
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [
+ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+ ]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+# In[ ]:
+def verification2(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [
+ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+ ]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def read_score(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# # Step1: Load Meta Data
+
+# In[ ]:
+
+assert target == 'IJBC' or target == 'IJBB'
+
+# =============================================================
+# load image and template relationships for template feature embedding
+# tid --> template id, mid --> media id
+# format:
+# image_name tid mid
+# =============================================================
+start = timeit.default_timer()
+templates, medias = read_template_media_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_face_tid_mid.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+
+# =============================================================
+# load template pairs for template-to-template verification
+# tid : template id, label : 1/0
+# format:
+# tid_1 tid_2 label
+# =============================================================
+start = timeit.default_timer()
+p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_template_pair_label.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# # Step 2: Get Image Features
+
+# In[ ]:
+
+# =============================================================
+# load image features
+# format:
+# img_feats: [image_num x feats_dim] (227630, 512)
+# =============================================================
+start = timeit.default_timer()
+img_path = '%s/loose_crop' % image_path
+img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
+img_list = open(img_list_path)
+files = img_list.readlines()
+# files_list = divideIntoNstrand(files, rank_size)
+files_list = files
+
+# img_feats
+# for i in range(rank_size):
+img_feats, faceness_scores = get_image_feature(img_path, files_list,
+ model_path, 0, gpu_id)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
+ img_feats.shape[1]))
+
+# # Step3: Get Template Features
+
+# In[ ]:
+
+# =============================================================
+# compute template features from image features.
+# =============================================================
+start = timeit.default_timer()
+# ==========================================================
+# Norm feature before aggregation into template feature?
+# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
+# ==========================================================
+# 1. FaceScore (Feature Norm)
+# 2. FaceScore (Detector)
+
+if use_flip_test:
+ # concat --- F1
+ # img_input_feats = img_feats
+ # add --- F2
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] //
+ 2] + img_feats[:, img_feats.shape[1] // 2:]
+else:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
+
+if use_norm_score:
+ img_input_feats = img_input_feats
+else:
+ # normalise features to remove norm information
+ img_input_feats = img_input_feats / np.sqrt(
+ np.sum(img_input_feats ** 2, -1, keepdims=True))
+
+if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+else:
+ img_input_feats = img_input_feats
+
+template_norm_feats, unique_templates = image2template_feature(
+ img_input_feats, templates, medias)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# # Step 4: Get Template Similarity Scores
+
+# In[ ]:
+
+# =============================================================
+# compute verification scores between template pairs.
+# =============================================================
+start = timeit.default_timer()
+score = verification(template_norm_feats, unique_templates, p1, p2)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+save_path = os.path.join(result_dir, args.job)
+# save_path = result_dir + '/%s_result' % target
+
+if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
+np.save(score_save_file, score)
+
+# # Step 5: Get ROC Curves and TPR@FPR Table
+
+# In[ ]:
+
+files = [score_save_file]
+methods = []
+scores = []
+for file in files:
+ methods.append(Path(file).stem)
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+ zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(fpr,
+ tpr,
+ color=colours[method],
+ lw=1,
+ label=('[%s (AUC = %0.4f %%)]' %
+ (method.split('-')[-1], roc_auc * 100)))
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower()))
+print(tpr_fpr_table)
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/inference.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..1929d4abb640d040398dda57b491b9bd96deac9d
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/inference.py
@@ -0,0 +1,35 @@
+import argparse
+
+import cv2
+import numpy as np
+import torch
+
+from backbones import get_model
+
+
+@torch.no_grad()
+def inference(weight, name, img):
+ if img is None:
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
+ else:
+ img = cv2.imread(img)
+ img = cv2.resize(img, (112, 112))
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+ img.div_(255).sub_(0.5).div_(0.5)
+ net = get_model(name, fp16=False)
+ net.load_state_dict(torch.load(weight))
+ net.eval()
+ feat = net(img).numpy()
+ print(feat)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
+ parser.add_argument('--network', type=str, default='r50', help='backbone network')
+ parser.add_argument('--weight', type=str, default='')
+ parser.add_argument('--img', type=str, default=None)
+ args = parser.parse_args()
+ inference(args.weight, args.network, args.img)
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/losses.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bfdd8c6b7f6b0d465928f19c554e62340e5ad7b
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/losses.py
@@ -0,0 +1,42 @@
+import torch
+from torch import nn
+
+
+def get_loss(name):
+ if name == "cosface":
+ return CosFace()
+ elif name == "arcface":
+ return ArcFace()
+ else:
+ raise ValueError()
+
+
+class CosFace(nn.Module):
+ def __init__(self, s=64.0, m=0.40):
+ super(CosFace, self).__init__()
+ self.s = s
+ self.m = m
+
+ def forward(self, cosine, label):
+ index = torch.where(label != -1)[0]
+ m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
+ m_hot.scatter_(1, label[index, None], self.m)
+ cosine[index] -= m_hot
+ ret = cosine * self.s
+ return ret
+
+
+class ArcFace(nn.Module):
+ def __init__(self, s=64.0, m=0.5):
+ super(ArcFace, self).__init__()
+ self.s = s
+ self.m = m
+
+ def forward(self, cosine: torch.Tensor, label):
+ index = torch.where(label != -1)[0]
+ m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
+ m_hot.scatter_(1, label[index, None], self.m)
+ cosine.acos_()
+ cosine[index] += m_hot
+ cosine.cos_().mul_(self.s)
+ return cosine
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_helper.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a01a46621dc0ea695bd903de5d1e212d424c860
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_helper.py
@@ -0,0 +1,250 @@
+from __future__ import division
+import datetime
+import os
+import os.path as osp
+import glob
+import numpy as np
+import cv2
+import sys
+import onnxruntime
+import onnx
+import argparse
+from onnx import numpy_helper
+from insightface.data import get_image
+
+class ArcFaceORT:
+ def __init__(self, model_path, cpu=False):
+ self.model_path = model_path
+ # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider"
+ self.providers = ['CPUExecutionProvider'] if cpu else None
+
+ #input_size is (w,h), return error message, return None if success
+ def check(self, track='cfat', test_img = None):
+ #default is cfat
+ max_model_size_mb=1024
+ max_feat_dim=512
+ max_time_cost=15
+ if track.startswith('ms1m'):
+ max_model_size_mb=1024
+ max_feat_dim=512
+ max_time_cost=10
+ elif track.startswith('glint'):
+ max_model_size_mb=1024
+ max_feat_dim=1024
+ max_time_cost=20
+ elif track.startswith('cfat'):
+ max_model_size_mb = 1024
+ max_feat_dim = 512
+ max_time_cost = 15
+ elif track.startswith('unconstrained'):
+ max_model_size_mb=1024
+ max_feat_dim=1024
+ max_time_cost=30
+ else:
+ return "track not found"
+
+ if not os.path.exists(self.model_path):
+ return "model_path not exists"
+ if not os.path.isdir(self.model_path):
+ return "model_path should be directory"
+ onnx_files = []
+ for _file in os.listdir(self.model_path):
+ if _file.endswith('.onnx'):
+ onnx_files.append(osp.join(self.model_path, _file))
+ if len(onnx_files)==0:
+ return "do not have onnx files"
+ self.model_file = sorted(onnx_files)[-1]
+ print('use onnx-model:', self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print('input-shape:', input_shape)
+ if len(input_shape)!=4:
+ return "length of input_shape should be 4"
+ if not isinstance(input_shape[0], str):
+ #return "input_shape[0] should be str to support batch-inference"
+ print('reset input-shape[0] to None')
+ model = onnx.load(self.model_file)
+ model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
+ new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx')
+ onnx.save(model, new_model_file)
+ self.model_file = new_model_file
+ print('use new onnx-model:', self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print('new-input-shape:', input_shape)
+
+ self.image_size = tuple(input_shape[2:4][::-1])
+ #print('image_size:', self.image_size)
+ input_name = input_cfg.name
+ outputs = session.get_outputs()
+ output_names = []
+ for o in outputs:
+ output_names.append(o.name)
+ #print(o.name, o.shape)
+ if len(output_names)!=1:
+ return "number of output nodes should be 1"
+ self.session = session
+ self.input_name = input_name
+ self.output_names = output_names
+ #print(self.output_names)
+ model = onnx.load(self.model_file)
+ graph = model.graph
+ if len(graph.node)<8:
+ return "too small onnx graph"
+
+ input_size = (112,112)
+ self.crop = None
+ if track=='cfat':
+ crop_file = osp.join(self.model_path, 'crop.txt')
+ if osp.exists(crop_file):
+ lines = open(crop_file,'r').readlines()
+ if len(lines)!=6:
+ return "crop.txt should contain 6 lines"
+ lines = [int(x) for x in lines]
+ self.crop = lines[:4]
+ input_size = tuple(lines[4:6])
+ if input_size!=self.image_size:
+ return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size)
+
+ self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024)
+ if self.model_size_mb > max_model_size_mb:
+ return "max model size exceed, given %.3f-MB"%self.model_size_mb
+
+ input_mean = None
+ input_std = None
+ if track=='cfat':
+ pn_file = osp.join(self.model_path, 'pixel_norm.txt')
+ if osp.exists(pn_file):
+ lines = open(pn_file,'r').readlines()
+ if len(lines)!=2:
+ return "pixel_norm.txt should contain 2 lines"
+ input_mean = float(lines[0])
+ input_std = float(lines[1])
+ if input_mean is not None or input_std is not None:
+ if input_mean is None or input_std is None:
+ return "please set input_mean and input_std simultaneously"
+ else:
+ find_sub = False
+ find_mul = False
+ for nid, node in enumerate(graph.node[:8]):
+ print(nid, node.name)
+ if node.name.startswith('Sub') or node.name.startswith('_minus'):
+ find_sub = True
+ if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'):
+ find_mul = True
+ if find_sub and find_mul:
+ print("find sub and mul")
+ #mxnet arcface model
+ input_mean = 0.0
+ input_std = 1.0
+ else:
+ input_mean = 127.5
+ input_std = 127.5
+ self.input_mean = input_mean
+ self.input_std = input_std
+ for initn in graph.initializer:
+ weight_array = numpy_helper.to_array(initn)
+ dt = weight_array.dtype
+ if dt.itemsize<4:
+ return 'invalid weight type - (%s:%s)' % (initn.name, dt.name)
+ if test_img is None:
+ test_img = get_image('Tom_Hanks_54745')
+ test_img = cv2.resize(test_img, self.image_size)
+ else:
+ test_img = cv2.resize(test_img, self.image_size)
+ feat, cost = self.benchmark(test_img)
+ batch_result = self.check_batch(test_img)
+ batch_result_sum = float(np.sum(batch_result))
+ if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum:
+ print(batch_result)
+ print(batch_result_sum)
+ return "batch result output contains NaN!"
+
+ if len(feat.shape) < 2:
+ return "the shape of the feature must be two, but get {}".format(str(feat.shape))
+
+ if feat.shape[1] > max_feat_dim:
+ return "max feat dim exceed, given %d"%feat.shape[1]
+ self.feat_dim = feat.shape[1]
+ cost_ms = cost*1000
+ if cost_ms>max_time_cost:
+ return "max time cost exceed, given %.4f"%cost_ms
+ self.cost_ms = cost_ms
+ print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std))
+ return None
+
+ def check_batch(self, img):
+ if not isinstance(img, list):
+ imgs = [img, ] * 32
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :]
+ if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]:
+ nimg = cv2.resize(nimg, self.image_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(
+ images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size,
+ mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ return net_out
+
+
+ def meta_info(self):
+ return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms}
+
+
+ def forward(self, imgs):
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ input_size = self.image_size
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
+ if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
+ return net_out
+
+ def benchmark(self, img):
+ input_size = self.image_size
+ if self.crop is not None:
+ nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
+ if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ img = nimg
+ blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ costs = []
+ for _ in range(50):
+ ta = datetime.datetime.now()
+ net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
+ tb = datetime.datetime.now()
+ cost = (tb-ta).total_seconds()
+ costs.append(cost)
+ costs = sorted(costs)
+ cost = costs[5]
+ return net_out, cost
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='')
+ # general
+ parser.add_argument('workdir', help='submitted work dir', type=str)
+ parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat')
+ args = parser.parse_args()
+ handler = ArcFaceORT(args.workdir)
+ err = handler.check(args.track)
+ print('err:', err)
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_ijbc.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa96b96745e23d4d6642d99f71456c10af5e4e4e
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_ijbc.py
@@ -0,0 +1,267 @@
+import argparse
+import os
+import pickle
+import timeit
+
+import cv2
+import mxnet as mx
+import numpy as np
+import pandas as pd
+import prettytable
+import skimage.transform
+from sklearn.metrics import roc_curve
+from sklearn.preprocessing import normalize
+
+from onnx_helper import ArcFaceORT
+
+SRC = np.array(
+ [
+ [30.2946, 51.6963],
+ [65.5318, 51.5014],
+ [48.0252, 71.7366],
+ [33.5493, 92.3655],
+ [62.7299, 92.2041]]
+ , dtype=np.float32)
+SRC[:, 0] += 8.0
+
+
+class AlignedDataSet(mx.gluon.data.Dataset):
+ def __init__(self, root, lines, align=True):
+ self.lines = lines
+ self.root = root
+ self.align = align
+
+ def __len__(self):
+ return len(self.lines)
+
+ def __getitem__(self, idx):
+ each_line = self.lines[idx]
+ name_lmk_score = each_line.strip().split(' ')
+ name = os.path.join(self.root, name_lmk_score[0])
+ img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)
+ landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))
+ st = skimage.transform.SimilarityTransform()
+ st.estimate(landmark5, SRC)
+ img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)
+ img_1 = np.expand_dims(img, 0)
+ img_2 = np.expand_dims(np.fliplr(img), 0)
+ output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
+ output = np.transpose(output, (0, 3, 1, 2))
+ output = mx.nd.array(output)
+ return output
+
+
+def extract(model_root, dataset):
+ model = ArcFaceORT(model_path=model_root)
+ model.check()
+ feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))
+
+ def batchify_fn(data):
+ return mx.nd.concat(*data, dim=0)
+
+ data_loader = mx.gluon.data.DataLoader(
+ dataset, 128, last_batch='keep', num_workers=4,
+ thread_pool=True, prefetch=16, batchify_fn=batchify_fn)
+ num_iter = 0
+ for batch in data_loader:
+ batch = batch.asnumpy()
+ batch = (batch - model.input_mean) / model.input_std
+ feat = model.session.run(model.output_names, {model.input_name: batch})[0]
+ feat = np.reshape(feat, (-1, model.feat_dim * 2))
+ feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat
+ num_iter += 1
+ if num_iter % 50 == 0:
+ print(num_iter)
+ return feat_mat
+
+
+def read_template_media_list(path):
+ ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+ templates = ijb_meta[:, 1].astype(np.int)
+ medias = ijb_meta[:, 2].astype(np.int)
+ return templates, medias
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+def read_image_feature(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+def image2template_feature(img_feats=None,
+ templates=None,
+ medias=None):
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+ for count_template, uqt in enumerate(unique_templates):
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ]
+ media_norm_feats = np.array(media_norm_feats)
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print('Finish Calculating {} template features.'.format(
+ count_template))
+ template_norm_feats = normalize(template_feats)
+ return template_norm_feats, unique_templates
+
+
+def verification(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),))
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000
+ sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def verification2(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def main(args):
+ use_norm_score = True # if Ture, TestMode(N1)
+ use_detector_score = True # if Ture, TestMode(D1)
+ use_flip_test = True # if Ture, TestMode(F1)
+ assert args.target == 'IJBC' or args.target == 'IJBB'
+
+ start = timeit.default_timer()
+ templates, medias = read_template_media_list(
+ os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower()))
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % args.image_path,
+ '%s_template_pair_label.txt' % args.target.lower()))
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ img_path = '%s/loose_crop' % args.image_path
+ img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower())
+ img_list = open(img_list_path)
+ files = img_list.readlines()
+ dataset = AlignedDataSet(root=img_path, lines=files, align=True)
+ img_feats = extract(args.model_root, dataset)
+
+ faceness_scores = []
+ for each_line in files:
+ name_lmk_score = each_line.split()
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+ print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1]))
+ start = timeit.default_timer()
+
+ if use_flip_test:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:]
+ else:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
+
+ if use_norm_score:
+ img_input_feats = img_input_feats
+ else:
+ img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True))
+
+ if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+ else:
+ img_input_feats = img_input_feats
+
+ template_norm_feats, unique_templates = image2template_feature(
+ img_input_feats, templates, medias)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ score = verification(template_norm_feats, unique_templates, p1, p2)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+ save_path = os.path.join(args.result_dir, "{}_result".format(args.target))
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root))
+ np.save(score_save_file, score)
+ files = [score_save_file]
+ methods = []
+ scores = []
+ for file in files:
+ methods.append(os.path.basename(file))
+ scores.append(np.load(file))
+ methods = np.array(methods)
+ scores = dict(zip(methods, scores))
+ x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+ tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels])
+ for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr)
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, args.target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+ print(tpr_fpr_table)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='do ijb test')
+ # general
+ parser.add_argument('--model-root', default='', help='path to load model.')
+ parser.add_argument('--image-path', default='', type=str, help='')
+ parser.add_argument('--result-dir', default='.', type=str, help='')
+ parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+ main(parser.parse_args())
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/partial_fc.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/partial_fc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0286dd437319c920ecb61f4eb3a32333dcf49eb
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/partial_fc.py
@@ -0,0 +1,222 @@
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+from torch.nn import Module
+from torch.nn.functional import normalize, linear
+from torch.nn.parameter import Parameter
+
+
+class PartialFC(Module):
+ """
+ Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
+ Partial FC: Training 10 Million Identities on a Single Machine
+ See the original paper:
+ https://arxiv.org/abs/2010.05222
+ """
+
+ @torch.no_grad()
+ def __init__(self, rank, local_rank, world_size, batch_size, resume,
+ margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
+ """
+ rank: int
+ Unique process(GPU) ID from 0 to world_size - 1.
+ local_rank: int
+ Unique process(GPU) ID within the server from 0 to 7.
+ world_size: int
+ Number of GPU.
+ batch_size: int
+ Batch size on current rank(GPU).
+ resume: bool
+ Select whether to restore the weight of softmax.
+ margin_softmax: callable
+ A function of margin softmax, eg: cosface, arcface.
+ num_classes: int
+ The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size,
+ required.
+ sample_rate: float
+ The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling
+ can greatly speed up training, and reduce a lot of GPU memory, default is 1.0.
+ embedding_size: int
+ The feature dimension, default is 512.
+ prefix: str
+ Path for save checkpoint, default is './'.
+ """
+ super(PartialFC, self).__init__()
+ #
+ self.num_classes: int = num_classes
+ self.rank: int = rank
+ self.local_rank: int = local_rank
+ self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
+ self.world_size: int = world_size
+ self.batch_size: int = batch_size
+ self.margin_softmax: callable = margin_softmax
+ self.sample_rate: float = sample_rate
+ self.embedding_size: int = embedding_size
+ self.prefix: str = prefix
+ self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
+ self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+
+ self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
+ self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))
+
+ if resume:
+ try:
+ self.weight: torch.Tensor = torch.load(self.weight_name)
+ self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
+ if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
+ raise IndexError
+ logging.info("softmax weight resume successfully!")
+ logging.info("softmax weight mom resume successfully!")
+ except (FileNotFoundError, KeyError, IndexError):
+ self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
+ self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
+ logging.info("softmax weight init!")
+ logging.info("softmax weight mom init!")
+ else:
+ self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
+ self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
+ logging.info("softmax weight init successfully!")
+ logging.info("softmax weight mom init successfully!")
+ self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)
+
+ self.index = None
+ if int(self.sample_rate) == 1:
+ self.update = lambda: 0
+ self.sub_weight = Parameter(self.weight)
+ self.sub_weight_mom = self.weight_mom
+ else:
+ self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))
+
+ def save_params(self):
+ """ Save softmax weight for each rank on prefix
+ """
+ torch.save(self.weight.data, self.weight_name)
+ torch.save(self.weight_mom, self.weight_mom_name)
+
+ @torch.no_grad()
+ def sample(self, total_label):
+ """
+ Sample all positive class centers in each rank, and random select neg class centers to filling a fixed
+ `num_sample`.
+
+ total_label: tensor
+ Label after all gather, which cross all GPUs.
+ """
+ index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
+ total_label[~index_positive] = -1
+ total_label[index_positive] -= self.class_start
+ if int(self.sample_rate) != 1:
+ positive = torch.unique(total_label[index_positive], sorted=True)
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local], device=self.device)
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1]
+ index = index.sort()[0]
+ else:
+ index = positive
+ self.index = index
+ total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
+ self.sub_weight = Parameter(self.weight[index])
+ self.sub_weight_mom = self.weight_mom[index]
+
+ def forward(self, total_features, norm_weight):
+ """ Partial fc forward, `logits = X * sample(W)`
+ """
+ torch.cuda.current_stream().wait_stream(self.stream)
+ logits = linear(total_features, norm_weight)
+ return logits
+
+ @torch.no_grad()
+ def update(self):
+ """ Set updated weight and weight_mom to memory bank.
+ """
+ self.weight_mom[self.index] = self.sub_weight_mom
+ self.weight[self.index] = self.sub_weight
+
+ def prepare(self, label, optimizer):
+ """
+ get sampled class centers for cal softmax.
+
+ label: tensor
+ Label tensor on each rank.
+ optimizer: opt
+ Optimizer for partial fc, which need to get weight mom.
+ """
+ with torch.cuda.stream(self.stream):
+ total_label = torch.zeros(
+ size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
+ dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
+ self.sample(total_label)
+ optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
+ optimizer.param_groups[-1]['params'][0] = self.sub_weight
+ optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
+ norm_weight = normalize(self.sub_weight)
+ return total_label, norm_weight
+
+ def forward_backward(self, label, features, optimizer):
+ """
+ Partial fc forward and backward with model parallel
+
+ label: tensor
+ Label tensor on each rank(GPU)
+ features: tensor
+ Features tensor on each rank(GPU)
+ optimizer: optimizer
+ Optimizer for partial fc
+
+ Returns:
+ --------
+ x_grad: tensor
+ The gradient of features.
+ loss_v: tensor
+ Loss value for cross entropy.
+ """
+ total_label, norm_weight = self.prepare(label, optimizer)
+ total_features = torch.zeros(
+ size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
+ dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
+ total_features.requires_grad = True
+
+ logits = self.forward(total_features, norm_weight)
+ logits = self.margin_softmax(logits, total_label)
+
+ with torch.no_grad():
+ max_fc = torch.max(logits, dim=1, keepdim=True)[0]
+ dist.all_reduce(max_fc, dist.ReduceOp.MAX)
+
+ # calculate exp(logits) and all-reduce
+ logits_exp = torch.exp(logits - max_fc)
+ logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
+ dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
+
+ # calculate prob
+ logits_exp.div_(logits_sum_exp)
+
+ # get one-hot
+ grad = logits_exp
+ index = torch.where(total_label != -1)[0]
+ one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
+ one_hot.scatter_(1, total_label[index, None], 1)
+
+ # calculate loss
+ loss = torch.zeros(grad.size()[0], 1, device=grad.device)
+ loss[index] = grad[index].gather(1, total_label[index, None])
+ dist.all_reduce(loss, dist.ReduceOp.SUM)
+ loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
+
+ # calculate grad
+ grad[index] -= one_hot
+ grad.div_(self.batch_size * self.world_size)
+
+ logits.backward(grad)
+ if total_features.grad is not None:
+ total_features.grad.detach_()
+ x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
+ # feature gradient all-reduce
+ dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
+ x_grad = x_grad * self.world_size
+ # backward backbone
+ return x_grad, loss_v
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/requirement.txt b/sadtalker_audio2pose/src/face3d/models/arcface_torch/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..99aef673e30b99cbe56ce82a564c1df9df24ba21
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/requirement.txt
@@ -0,0 +1,5 @@
+tensorboard
+easydict
+mxnet
+onnx
+sklearn
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/run.sh b/sadtalker_audio2pose/src/face3d/models/arcface_torch/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..67b25fd63ef3921733d81d5be844aacc5a5c84ed
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/run.sh
@@ -0,0 +1,2 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
+ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/torch2onnx.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/torch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..458660df7cc7f9a567aaf492c45f232e776a9ef0
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/torch2onnx.py
@@ -0,0 +1,59 @@
+import numpy as np
+import onnx
+import torch
+
+
+def convert_onnx(net, path_module, output, opset=11, simplify=False):
+ assert isinstance(net, torch.nn.Module)
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = img.astype(np.float)
+ img = (img / 255. - 0.5) / 0.5 # torch style norm
+ img = img.transpose((2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+
+ weight = torch.load(path_module)
+ net.load_state_dict(weight)
+ net.eval()
+ torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
+ model = onnx.load(output)
+ graph = model.graph
+ graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
+ if simplify:
+ from onnxsim import simplify
+ model, check = simplify(model)
+ assert check, "Simplified ONNX model could not be validated"
+ onnx.save(model, output)
+
+
+if __name__ == '__main__':
+ import os
+ import argparse
+ from backbones import get_model
+
+ parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx')
+ parser.add_argument('input', type=str, help='input backbone.pth file or path')
+ parser.add_argument('--output', type=str, default=None, help='output onnx path')
+ parser.add_argument('--network', type=str, default=None, help='backbone network')
+ parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify')
+ args = parser.parse_args()
+ input_file = args.input
+ if os.path.isdir(input_file):
+ input_file = os.path.join(input_file, "backbone.pth")
+ assert os.path.exists(input_file)
+ model_name = os.path.basename(os.path.dirname(input_file)).lower()
+ params = model_name.split("_")
+ if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
+ if args.network is None:
+ args.network = params[2]
+ assert args.network is not None
+ print(args)
+ backbone_onnx = get_model(args.network, dropout=0)
+
+ output_path = args.output
+ if output_path is None:
+ output_path = os.path.join(os.path.dirname(__file__), 'onnx')
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+ assert os.path.isdir(output_path)
+ output_file = os.path.join(output_path, "%s.onnx" % model_name)
+ convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/train.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c5491de9af8fc7a2f3d0648c53b89584864f20e
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/train.py
@@ -0,0 +1,141 @@
+import argparse
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torch.utils.data.distributed
+from torch.nn.utils import clip_grad_norm_
+
+import losses
+from backbones import get_model
+from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX
+from partial_fc import PartialFC
+from utils.utils_amp import MaxClipGradScaler
+from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
+from utils.utils_config import get_config
+from utils.utils_logging import AverageMeter, init_logging
+
+
+def main(args):
+ cfg = get_config(args.config)
+ try:
+ world_size = int(os.environ['WORLD_SIZE'])
+ rank = int(os.environ['RANK'])
+ dist.init_process_group('nccl')
+ except KeyError:
+ world_size = 1
+ rank = 0
+ dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size)
+
+ local_rank = args.local_rank
+ torch.cuda.set_device(local_rank)
+ os.makedirs(cfg.output, exist_ok=True)
+ init_logging(rank, cfg.output)
+
+ if cfg.rec == "synthetic":
+ train_set = SyntheticDataset(local_rank=local_rank)
+ else:
+ train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
+
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
+ train_loader = DataLoaderX(
+ local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
+ sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
+ backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)
+
+ if cfg.resume:
+ try:
+ backbone_pth = os.path.join(cfg.output, "backbone.pth")
+ backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
+ if rank == 0:
+ logging.info("backbone resume successfully!")
+ except (FileNotFoundError, KeyError, IndexError, RuntimeError):
+ if rank == 0:
+ logging.info("resume fail, backbone init successfully!")
+
+ backbone = torch.nn.parallel.DistributedDataParallel(
+ module=backbone, broadcast_buffers=False, device_ids=[local_rank])
+ backbone.train()
+ margin_softmax = losses.get_loss(cfg.loss)
+ module_partial_fc = PartialFC(
+ rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,
+ batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
+ sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)
+
+ opt_backbone = torch.optim.SGD(
+ params=[{'params': backbone.parameters()}],
+ lr=cfg.lr / 512 * cfg.batch_size * world_size,
+ momentum=0.9, weight_decay=cfg.weight_decay)
+ opt_pfc = torch.optim.SGD(
+ params=[{'params': module_partial_fc.parameters()}],
+ lr=cfg.lr / 512 * cfg.batch_size * world_size,
+ momentum=0.9, weight_decay=cfg.weight_decay)
+
+ num_image = len(train_set)
+ total_batch_size = cfg.batch_size * world_size
+ cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
+ cfg.total_step = num_image // total_batch_size * cfg.num_epoch
+
+ def lr_step_func(current_step):
+ cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]
+ if current_step < cfg.warmup_step:
+ return current_step / cfg.warmup_step
+ else:
+ return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])
+
+ scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
+ optimizer=opt_backbone, lr_lambda=lr_step_func)
+ scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
+ optimizer=opt_pfc, lr_lambda=lr_step_func)
+
+ for key, value in cfg.items():
+ num_space = 25 - len(key)
+ logging.info(": " + key + " " * num_space + str(value))
+
+ val_target = cfg.val_targets
+ callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
+ callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
+ callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)
+
+ loss = AverageMeter()
+ start_epoch = 0
+ global_step = 0
+ grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
+ for epoch in range(start_epoch, cfg.num_epoch):
+ train_sampler.set_epoch(epoch)
+ for step, (img, label) in enumerate(train_loader):
+ global_step += 1
+ features = F.normalize(backbone(img))
+ x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
+ if cfg.fp16:
+ features.backward(grad_amp.scale(x_grad))
+ grad_amp.unscale_(opt_backbone)
+ clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+ grad_amp.step(opt_backbone)
+ grad_amp.update()
+ else:
+ features.backward(x_grad)
+ clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+ opt_backbone.step()
+
+ opt_pfc.step()
+ module_partial_fc.update()
+ opt_backbone.zero_grad()
+ opt_pfc.zero_grad()
+ loss.update(loss_v, 1)
+ callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
+ callback_verification(global_step, backbone)
+ scheduler_backbone.step()
+ scheduler_pfc.step()
+ callback_checkpoint(global_step, backbone, module_partial_fc)
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ torch.backends.cudnn.benchmark = True
+ parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
+ parser.add_argument('config', type=str, help='py config file')
+ parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
+ main(parser.parse_args())
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/__init__.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/plot.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fce6cc0ae526d5aebc8e7a1550300ceae3a2034
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/plot.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+
+import os
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from sklearn.metrics import roc_curve, auc
+
+image_path = "/data/anxiang/IJB_release/IJBC"
+files = [
+ "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
+]
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_template_pair_label.txt' % 'ijbc'))
+
+methods = []
+scores = []
+for file in files:
+ methods.append(file.split('/')[-2])
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+ zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(fpr,
+ tpr,
+ color=colours[method],
+ lw=1,
+ label=('[%s (AUC = %0.4f %%)]' %
+ (method.split('-')[-1], roc_auc * 100)))
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+print(tpr_fpr_table)
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_amp.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_amp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d5bcbb540ff8b04535e71c0057e124338df5bd
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_amp.py
@@ -0,0 +1,88 @@
+from typing import Dict, List
+
+import torch
+
+if torch.__version__ < '1.9':
+ Iterable = torch._six.container_abcs.Iterable
+else:
+ import collections
+
+ Iterable = collections.abc.Iterable
+from torch.cuda.amp import GradScaler
+
+
+class _MultiDeviceReplicator(object):
+ """
+ Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
+ """
+
+ def __init__(self, master_tensor: torch.Tensor) -> None:
+ assert master_tensor.is_cuda
+ self.master = master_tensor
+ self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+ def get(self, device) -> torch.Tensor:
+ retval = self._per_device_tensors.get(device, None)
+ if retval is None:
+ retval = self.master.to(device=device, non_blocking=True, copy=True)
+ self._per_device_tensors[device] = retval
+ return retval
+
+
+class MaxClipGradScaler(GradScaler):
+ def __init__(self, init_scale, max_scale: float, growth_interval=100):
+ GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
+ self.max_scale = max_scale
+
+ def scale_clip(self):
+ if self.get_scale() == self.max_scale:
+ self.set_growth_factor(1)
+ elif self.get_scale() < self.max_scale:
+ self.set_growth_factor(2)
+ elif self.get_scale() > self.max_scale:
+ self._scale.fill_(self.max_scale)
+ self.set_growth_factor(1)
+
+ def scale(self, outputs):
+ """
+ Multiplies ('scales') a tensor or list of tensors by the scale factor.
+
+ Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
+ unmodified.
+
+ Arguments:
+ outputs (Tensor or iterable of Tensors): Outputs to scale.
+ """
+ if not self._enabled:
+ return outputs
+ self.scale_clip()
+ # Short-circuit for the common case.
+ if isinstance(outputs, torch.Tensor):
+ assert outputs.is_cuda
+ if self._scale is None:
+ self._lazy_init_scale_growth_tracker(outputs.device)
+ assert self._scale is not None
+ return outputs * self._scale.to(device=outputs.device, non_blocking=True)
+
+ # Invoke the more complex machinery only if we're treating multiple outputs.
+ stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale
+
+ def apply_scale(val):
+ if isinstance(val, torch.Tensor):
+ assert val.is_cuda
+ if len(stash) == 0:
+ if self._scale is None:
+ self._lazy_init_scale_growth_tracker(val.device)
+ assert self._scale is not None
+ stash.append(_MultiDeviceReplicator(self._scale))
+ return val * stash[0].get(val.device)
+ elif isinstance(val, Iterable):
+ iterable = map(apply_scale, val)
+ if isinstance(val, list) or isinstance(val, tuple):
+ return type(val)(iterable)
+ else:
+ return iterable
+ else:
+ raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+
+ return apply_scale(outputs)
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..748923b36358bd118efa0532a6f512b6ca96ff34
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py
@@ -0,0 +1,117 @@
+import logging
+import os
+import time
+from typing import List
+
+import torch
+
+from eval import verification
+from utils.utils_logging import AverageMeter
+
+
+class CallBackVerification(object):
+ def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
+ self.frequent: int = frequent
+ self.rank: int = rank
+ self.highest_acc: float = 0.0
+ self.highest_acc_list: List[float] = [0.0] * len(val_targets)
+ self.ver_list: List[object] = []
+ self.ver_name_list: List[str] = []
+ if self.rank is 0:
+ self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
+
+ def ver_test(self, backbone: torch.nn.Module, global_step: int):
+ results = []
+ for i in range(len(self.ver_list)):
+ acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
+ self.ver_list[i], backbone, 10, 10)
+ logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
+ logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
+ if acc2 > self.highest_acc_list[i]:
+ self.highest_acc_list[i] = acc2
+ logging.info(
+ '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
+ results.append(acc2)
+
+ def init_dataset(self, val_targets, data_dir, image_size):
+ for name in val_targets:
+ path = os.path.join(data_dir, name + ".bin")
+ if os.path.exists(path):
+ data_set = verification.load_bin(path, image_size)
+ self.ver_list.append(data_set)
+ self.ver_name_list.append(name)
+
+ def __call__(self, num_update, backbone: torch.nn.Module):
+ if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0:
+ backbone.eval()
+ self.ver_test(backbone, num_update)
+ backbone.train()
+
+
+class CallBackLogging(object):
+ def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
+ self.frequent: int = frequent
+ self.rank: int = rank
+ self.time_start = time.time()
+ self.total_step: int = total_step
+ self.batch_size: int = batch_size
+ self.world_size: int = world_size
+ self.writer = writer
+
+ self.init = False
+ self.tic = 0
+
+ def __call__(self,
+ global_step: int,
+ loss: AverageMeter,
+ epoch: int,
+ fp16: bool,
+ learning_rate: float,
+ grad_scaler: torch.cuda.amp.GradScaler):
+ if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
+ if self.init:
+ try:
+ speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
+ speed_total = speed * self.world_size
+ except ZeroDivisionError:
+ speed_total = float('inf')
+
+ time_now = (time.time() - self.time_start) / 3600
+ time_total = time_now / ((global_step + 1) / self.total_step)
+ time_for_end = time_total - time_now
+ if self.writer is not None:
+ self.writer.add_scalar('time_for_end', time_for_end, global_step)
+ self.writer.add_scalar('learning_rate', learning_rate, global_step)
+ self.writer.add_scalar('loss', loss.avg, global_step)
+ if fp16:
+ msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \
+ "Fp16 Grad Scale: %2.f Required: %1.f hours" % (
+ speed_total, loss.avg, learning_rate, epoch, global_step,
+ grad_scaler.get_scale(), time_for_end
+ )
+ else:
+ msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \
+ "Required: %1.f hours" % (
+ speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end
+ )
+ logging.info(msg)
+ loss.reset()
+ self.tic = time.time()
+ else:
+ self.init = True
+ self.tic = time.time()
+
+
+class CallBackModelCheckpoint(object):
+ def __init__(self, rank, output="./"):
+ self.rank: int = rank
+ self.output: str = output
+
+ def __call__(self, global_step, backbone, partial_fc, ):
+ if global_step > 100 and self.rank == 0:
+ path_module = os.path.join(self.output, "backbone.pth")
+ torch.save(backbone.module.state_dict(), path_module)
+ logging.info("Pytorch Model Saved in '{}'".format(path_module))
+
+ if global_step > 100 and partial_fc is not None:
+ partial_fc.save_params()
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_config.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b60a1e5a2e860ce5511a2d3863c8b57a4df292d7
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_config.py
@@ -0,0 +1,16 @@
+import importlib
+import os.path as osp
+
+
+def get_config(config_file):
+ assert config_file.startswith('configs/'), 'config file setting must start with configs/'
+ temp_config_name = osp.basename(config_file)
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ config = importlib.import_module("configs.base")
+ cfg = config.config
+ config = importlib.import_module("configs.%s" % temp_module_name)
+ job_cfg = config.config
+ cfg.update(job_cfg)
+ if cfg.output is None:
+ cfg.output = osp.join('work_dirs', temp_module_name)
+ return cfg
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_logging.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b43b851c9e06230abd94c73a1f64cfa1b6f3ac
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_logging.py
@@ -0,0 +1,41 @@
+import logging
+import os
+import sys
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value
+ """
+
+ def __init__(self):
+ self.val = None
+ self.avg = None
+ self.sum = None
+ self.count = None
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def init_logging(rank, models_root):
+ if rank == 0:
+ log_root = logging.getLogger()
+ log_root.setLevel(logging.INFO)
+ formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
+ handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
+ handler_stream = logging.StreamHandler(sys.stdout)
+ handler_file.setFormatter(formatter)
+ handler_stream.setFormatter(formatter)
+ log_root.addHandler(handler_file)
+ log_root.addHandler(handler_stream)
+ log_root.info('rank_id: %d' % rank)
diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_os.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_os.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sadtalker_audio2pose/src/face3d/models/base_model.py b/sadtalker_audio2pose/src/face3d/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b975223f6148febfe32d20d63980583c97b61eb3
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/base_model.py
@@ -0,0 +1,316 @@
+"""This script defines the base network model for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+
+class BaseModel(ABC):
+ """This class is an abstract base class (ABC) for models.
+ To create a subclass, you need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- : unpack data from dataset and apply preprocessing.
+ -- : produce intermediate results.
+ -- : calculate losses, gradients, and update network weights.
+ -- : (optionally) add model-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the BaseModel class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ When creating your custom class, you need to implement your own initialization.
+ In this fucntion, you should first call
+ Then, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): specify the images that you want to display and save.
+ -- self.visual_names (str list): define networks used in our training.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ """
+ self.opt = opt
+ self.isTrain = False
+ self.device = torch.device('cpu')
+ self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
+ self.loss_names = []
+ self.model_names = []
+ self.visual_names = []
+ self.parallel_names = []
+ self.optimizers = []
+ self.image_paths = []
+ self.metric = 0 # used for learning rate policy 'plateau'
+
+ @staticmethod
+ def dict_grad_hook_factory(add_func=lambda x: x):
+ saved_dict = dict()
+
+ def hook_gen(name):
+ def grad_hook(grad):
+ saved_vals = add_func(grad)
+ saved_dict[name] = saved_vals
+ return grad_hook
+ return hook_gen, saved_dict
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new model-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input (dict): includes the data itself and its metadata information.
+ """
+ pass
+
+ @abstractmethod
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+ pass
+
+ @abstractmethod
+ def optimize_parameters(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ pass
+
+ def setup(self, opt):
+ """Load and print networks; create schedulers
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ if self.isTrain:
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+
+ if not self.isTrain or opt.continue_train:
+ load_suffix = opt.epoch
+ self.load_networks(load_suffix)
+
+
+ # self.print_networks(opt.verbose)
+
+ def parallelize(self, convert_sync_batchnorm=True):
+ if not self.opt.use_ddp:
+ for name in self.parallel_names:
+ if isinstance(name, str):
+ module = getattr(self, name)
+ setattr(self, name, module.to(self.device))
+ else:
+ for name in self.model_names:
+ if isinstance(name, str):
+ module = getattr(self, name)
+ if convert_sync_batchnorm:
+ module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
+ setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device),
+ device_ids=[self.device.index],
+ find_unused_parameters=True, broadcast_buffers=True))
+
+ # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
+ for name in self.parallel_names:
+ if isinstance(name, str) and name not in self.model_names:
+ module = getattr(self, name)
+ setattr(self, name, module.to(self.device))
+
+ # put state_dict of optimizer to gpu device
+ if self.opt.phase != 'test':
+ if self.opt.continue_train:
+ for optim in self.optimizers:
+ for state in optim.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(self.device)
+
+ def data_dependent_initialize(self, data):
+ pass
+
+ def train(self):
+ """Make models train mode"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ net.train()
+
+ def eval(self):
+ """Make models eval mode"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ net.eval()
+
+ def test(self):
+ """Forward function used in test time.
+
+ This function wraps function in no_grad() so we don't save intermediate steps for backprop
+ It also calls to produce additional visualization results
+ """
+ with torch.no_grad():
+ self.forward()
+ self.compute_visuals()
+
+ def compute_visuals(self):
+ """Calculate additional output images for visdom and HTML visualization"""
+ pass
+
+ def get_image_paths(self, name='A'):
+ """ Return image paths that are used to load current data"""
+ return self.image_paths if name =='A' else self.image_paths_B
+
+ def update_learning_rate(self):
+ """Update learning rates for all the networks; called at the end of every epoch"""
+ for scheduler in self.schedulers:
+ if self.opt.lr_policy == 'plateau':
+ scheduler.step(self.metric)
+ else:
+ scheduler.step()
+
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('learning rate = %.7f' % lr)
+
+ def get_current_visuals(self):
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+ visual_ret = OrderedDict()
+ for name in self.visual_names:
+ if isinstance(name, str):
+ visual_ret[name] = getattr(self, name)[:, :3, ...]
+ return visual_ret
+
+ def get_current_losses(self):
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
+ errors_ret = OrderedDict()
+ for name in self.loss_names:
+ if isinstance(name, str):
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
+ return errors_ret
+
+ def save_networks(self, epoch):
+ """Save all the networks to the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ if not os.path.isdir(self.save_dir):
+ os.makedirs(self.save_dir)
+
+ save_filename = 'epoch_%s.pth' % (epoch)
+ save_path = os.path.join(self.save_dir, save_filename)
+
+ save_dict = {}
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ if isinstance(net, torch.nn.DataParallel) or isinstance(net,
+ torch.nn.parallel.DistributedDataParallel):
+ net = net.module
+ save_dict[name] = net.state_dict()
+
+
+ for i, optim in enumerate(self.optimizers):
+ save_dict['opt_%02d'%i] = optim.state_dict()
+
+ for i, sched in enumerate(self.schedulers):
+ save_dict['sched_%02d'%i] = sched.state_dict()
+
+ torch.save(save_dict, save_path)
+
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+ key = keys[i]
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'running_mean' or key == 'running_var'):
+ if getattr(module, key) is None:
+ state_dict.pop('.'.join(keys))
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'num_batches_tracked'):
+ state_dict.pop('.'.join(keys))
+ else:
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+ def load_networks(self, epoch):
+ """Load all the networks from the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ if self.opt.isTrain and self.opt.pretrained_name is not None:
+ load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
+ else:
+ load_dir = self.save_dir
+ load_filename = 'epoch_%s.pth' % (epoch)
+ load_path = os.path.join(load_dir, load_filename)
+ state_dict = torch.load(load_path, map_location=self.device)
+ print('loading the model from %s' % load_path)
+
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ if isinstance(net, torch.nn.DataParallel):
+ net = net.module
+ net.load_state_dict(state_dict[name])
+
+ if self.opt.phase != 'test':
+ if self.opt.continue_train:
+ print('loading the optim from %s' % load_path)
+ for i, optim in enumerate(self.optimizers):
+ optim.load_state_dict(state_dict['opt_%02d'%i])
+
+ try:
+ print('loading the sched from %s' % load_path)
+ for i, sched in enumerate(self.schedulers):
+ sched.load_state_dict(state_dict['sched_%02d'%i])
+ except:
+ print('Failed to load schedulers, set schedulers according to epoch count manually')
+ for i, sched in enumerate(self.schedulers):
+ sched.last_epoch = self.opt.epoch_count - 1
+
+
+
+
+ def print_networks(self, verbose):
+ """Print the total number of parameters in the network and (if verbose) network architecture
+
+ Parameters:
+ verbose (bool) -- if verbose: print the network architecture
+ """
+ print('---------- Networks initialized -------------')
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ if verbose:
+ print(net)
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+ print('-----------------------------------------------')
+
+ def set_requires_grad(self, nets, requires_grad=False):
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
+ Parameters:
+ nets (network list) -- a list of networks
+ requires_grad (bool) -- whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+ def generate_visuals_for_evaluation(self, data, mode):
+ return {}
diff --git a/sadtalker_audio2pose/src/face3d/models/bfm.py b/sadtalker_audio2pose/src/face3d/models/bfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cecaf589befac790cf9c124737ba01e27bc29e6
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/bfm.py
@@ -0,0 +1,331 @@
+"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.io import loadmat
+from src.face3d.util.load_mats import transferBFM09
+import os
+
+def perspective_projection(focal, center):
+ # return p.T (N, 3) @ (3, 3)
+ return np.array([
+ focal, 0, center,
+ 0, focal, center,
+ 0, 0, 1
+ ]).reshape([3, 3]).astype(np.float32).transpose()
+
+class SH:
+ def __init__(self):
+ self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
+ self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
+
+
+
+class ParametricFaceModel:
+ def __init__(self,
+ bfm_folder='./BFM',
+ recenter=True,
+ camera_distance=10.,
+ init_lit=np.array([
+ 0.8, 0, 0, 0, 0, 0, 0, 0, 0
+ ]),
+ focal=1015.,
+ center=112.,
+ is_train=True,
+ default_name='BFM_model_front.mat'):
+
+ if not os.path.isfile(os.path.join(bfm_folder, default_name)):
+ transferBFM09(bfm_folder)
+
+ model = loadmat(os.path.join(bfm_folder, default_name))
+ # mean face shape. [3*N,1]
+ self.mean_shape = model['meanshape'].astype(np.float32)
+ # identity basis. [3*N,80]
+ self.id_base = model['idBase'].astype(np.float32)
+ # expression basis. [3*N,64]
+ self.exp_base = model['exBase'].astype(np.float32)
+ # mean face texture. [3*N,1] (0-255)
+ self.mean_tex = model['meantex'].astype(np.float32)
+ # texture basis. [3*N,80]
+ self.tex_base = model['texBase'].astype(np.float32)
+ # face indices for each vertex that lies in. starts from 0. [N,8]
+ self.point_buf = model['point_buf'].astype(np.int64) - 1
+ # vertex indices for each face. starts from 0. [F,3]
+ self.face_buf = model['tri'].astype(np.int64) - 1
+ # vertex indices for 68 landmarks. starts from 0. [68,1]
+ self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
+
+ if is_train:
+ # vertex indices for small face region to compute photometric error. starts from 0.
+ self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
+ # vertex indices for each face from small face region. starts from 0. [f,3]
+ self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
+ # vertex indices for pre-defined skin region to compute reflectance loss
+ self.skin_mask = np.squeeze(model['skinmask'])
+
+ if recenter:
+ mean_shape = self.mean_shape.reshape([-1, 3])
+ mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
+ self.mean_shape = mean_shape.reshape([-1, 1])
+
+ self.persc_proj = perspective_projection(focal, center)
+ self.device = 'cpu'
+ self.camera_distance = camera_distance
+ self.SH = SH()
+ self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
+
+
+ def to(self, device):
+ self.device = device
+ for key, value in self.__dict__.items():
+ if type(value).__module__ == np.__name__:
+ setattr(self, key, torch.tensor(value).to(device))
+
+
+ def compute_shape(self, id_coeff, exp_coeff):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
+ """
+ batch_size = id_coeff.shape[0]
+ id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
+ exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
+ face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
+ return face_shape.reshape([batch_size, -1, 3])
+
+
+ def compute_texture(self, tex_coeff, normalize=True):
+ """
+ Return:
+ face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
+
+ Parameters:
+ tex_coeff -- torch.tensor, size (B, 80)
+ """
+ batch_size = tex_coeff.shape[0]
+ face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
+ if normalize:
+ face_texture = face_texture / 255.
+ return face_texture.reshape([batch_size, -1, 3])
+
+
+ def compute_norm(self, face_shape):
+ """
+ Return:
+ vertex_norm -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+
+ v1 = face_shape[:, self.face_buf[:, 0]]
+ v2 = face_shape[:, self.face_buf[:, 1]]
+ v3 = face_shape[:, self.face_buf[:, 2]]
+ e1 = v1 - v2
+ e2 = v2 - v3
+ face_norm = torch.cross(e1, e2, dim=-1)
+ face_norm = F.normalize(face_norm, dim=-1, p=2)
+ face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
+
+ vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
+ vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
+ return vertex_norm
+
+
+ def compute_color(self, face_texture, face_norm, gamma):
+ """
+ Return:
+ face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
+
+ Parameters:
+ face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
+ face_norm -- torch.tensor, size (B, N, 3), rotated face normal
+ gamma -- torch.tensor, size (B, 27), SH coeffs
+ """
+ batch_size = gamma.shape[0]
+ v_num = face_texture.shape[1]
+ a, c = self.SH.a, self.SH.c
+ gamma = gamma.reshape([batch_size, 3, 9])
+ gamma = gamma + self.init_lit
+ gamma = gamma.permute(0, 2, 1)
+ Y = torch.cat([
+ a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
+ -a[1] * c[1] * face_norm[..., 1:2],
+ a[1] * c[1] * face_norm[..., 2:],
+ -a[1] * c[1] * face_norm[..., :1],
+ a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
+ -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
+ -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
+ ], dim=-1)
+ r = Y @ gamma[..., :1]
+ g = Y @ gamma[..., 1:2]
+ b = Y @ gamma[..., 2:]
+ face_color = torch.cat([r, g, b], dim=-1) * face_texture
+ return face_color
+
+
+ def compute_rotation(self, angles):
+ """
+ Return:
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
+
+ Parameters:
+ angles -- torch.tensor, size (B, 3), radian
+ """
+
+ batch_size = angles.shape[0]
+ ones = torch.ones([batch_size, 1]).to(self.device)
+ zeros = torch.zeros([batch_size, 1]).to(self.device)
+ x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
+
+ rot_x = torch.cat([
+ ones, zeros, zeros,
+ zeros, torch.cos(x), -torch.sin(x),
+ zeros, torch.sin(x), torch.cos(x)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_y = torch.cat([
+ torch.cos(y), zeros, torch.sin(y),
+ zeros, ones, zeros,
+ -torch.sin(y), zeros, torch.cos(y)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_z = torch.cat([
+ torch.cos(z), -torch.sin(z), zeros,
+ torch.sin(z), torch.cos(z), zeros,
+ zeros, zeros, ones
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot = rot_z @ rot_y @ rot_x
+ return rot.permute(0, 2, 1)
+
+
+ def to_camera(self, face_shape):
+ face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
+ return face_shape
+
+ def to_image(self, face_shape):
+ """
+ Return:
+ face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+ # to image_plane
+ face_proj = face_shape @ self.persc_proj
+ face_proj = face_proj[..., :2] / face_proj[..., 2:]
+
+ return face_proj
+
+
+ def transform(self, face_shape, rot, trans):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ rot -- torch.tensor, size (B, 3, 3)
+ trans -- torch.tensor, size (B, 3)
+ """
+ return face_shape @ rot + trans.unsqueeze(1)
+
+
+ def get_landmarks(self, face_proj):
+ """
+ Return:
+ face_lms -- torch.tensor, size (B, 68, 2)
+
+ Parameters:
+ face_proj -- torch.tensor, size (B, N, 2)
+ """
+ return face_proj[:, self.keypoints]
+
+ def split_coeff(self, coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+ coeffs -- torch.tensor, size (B, 256)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80: 144]
+ tex_coeffs = coeffs[:, 144: 224]
+ angles = coeffs[:, 224: 227]
+ gammas = coeffs[:, 227: 254]
+ translations = coeffs[:, 254:]
+ return {
+ 'id': id_coeffs,
+ 'exp': exp_coeffs,
+ 'tex': tex_coeffs,
+ 'angle': angles,
+ 'gamma': gammas,
+ 'trans': translations
+ }
+ def compute_for_render(self, coeffs):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ coef_dict = self.split_coeff(coeffs)
+ face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+ rotation = self.compute_rotation(coef_dict['angle'])
+
+
+ face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+ face_vertex = self.to_camera(face_shape_transformed)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(coef_dict['tex'])
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+ return face_vertex, face_texture, face_color, landmark
+
+ def compute_for_render_woRotation(self, coeffs):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ coef_dict = self.split_coeff(coeffs)
+ face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+ #rotation = self.compute_rotation(coef_dict['angle'])
+
+
+ #face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+ face_vertex = self.to_camera(face_shape)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(coef_dict['tex'])
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm # @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+ return face_vertex, face_texture, face_color, landmark
+
+
+if __name__ == '__main__':
+ transferBFM09()
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/face3d/models/facerecon_model.py b/sadtalker_audio2pose/src/face3d/models/facerecon_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a8a701f4771fc337aa9b456310f4af4a6f86a69
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/facerecon_model.py
@@ -0,0 +1,220 @@
+"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+from src.face3d.models.base_model import BaseModel
+from src.face3d.models import networks
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss
+from src.face3d.util import util
+from src.face3d.util.nvdiffrast import MeshRenderer
+# from src.face3d.util.preprocess import estimate_norm_torch
+
+import trimesh
+from scipy.io import savemat
+
+class FaceReconModel(BaseModel):
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train=False):
+ """ Configures options specific for CUT model
+ """
+ # net structure and parameters
+ parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
+ parser.add_argument('--init_path', type=str, default='./ckpts/sad_talker/init_model/resnet50-0676ba61.pth')
+ parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc')
+ parser.add_argument('--bfm_folder', type=str, default='./ckpts/sad_talker/BFM_Fitting/')
+ parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+ # renderer parameters
+ parser.add_argument('--focal', type=float, default=1015.)
+ parser.add_argument('--center', type=float, default=112.)
+ parser.add_argument('--camera_d', type=float, default=10.)
+ parser.add_argument('--z_near', type=float, default=5.)
+ parser.add_argument('--z_far', type=float, default=15.)
+
+ if is_train:
+ # training parameters
+ parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure')
+ parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth')
+ parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss')
+ parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face')
+
+
+ # augmentation parameters
+ parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels')
+ parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor')
+ parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree')
+
+ # loss weights
+ parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss')
+ parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss')
+ parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss')
+ parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss')
+ parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss')
+ parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss')
+ parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss')
+ parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss')
+ parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss')
+
+ opt, _ = parser.parse_known_args()
+ parser.set_defaults(
+ focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
+ )
+ if is_train:
+ parser.set_defaults(
+ use_crop_face=True, use_predef_M=False
+ )
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+
+ self.visual_names = ['output_vis']
+ self.model_names = ['net_recon']
+ self.parallel_names = self.model_names + ['renderer']
+
+ self.facemodel = ParametricFaceModel(
+ bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center,
+ is_train=self.isTrain, default_name=opt.bfm_model
+ )
+
+ fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
+ self.renderer = MeshRenderer(
+ rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center)
+ )
+
+ if self.isTrain:
+ self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc']
+
+ self.net_recog = networks.define_net_recog(
+ net_recog=opt.net_recog, pretrained_path=opt.net_recog_path
+ )
+ # loss func name: (compute_%s_loss) % loss_name
+ self.compute_feat_loss = perceptual_loss
+ self.comupte_color_loss = photo_loss
+ self.compute_lm_loss = landmark_loss
+ self.compute_reg_loss = reg_loss
+ self.compute_reflc_loss = reflectance_loss
+
+ self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)
+ self.optimizers = [self.optimizer]
+ self.parallel_names += ['net_recog']
+ # Our program will automatically call to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ self.input_img = input['imgs'].to(self.device)
+ self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None
+ self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
+ self.trans_m = input['M'].to(self.device) if 'M' in input else None
+ self.image_paths = input['im_paths'] if 'im_paths' in input else None
+
+ def forward(self, output_coeff, device):
+ self.facemodel.to(device)
+ self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
+ self.facemodel.compute_for_render(output_coeff)
+ self.pred_mask, _, self.pred_face = self.renderer(
+ self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)
+
+ self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)
+
+
+ def compute_losses(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+
+ assert self.net_recog.training == False
+ trans_m = self.trans_m
+ if not self.opt.use_predef_M:
+ trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])
+
+ pred_feat = self.net_recog(self.pred_face, trans_m)
+ gt_feat = self.net_recog(self.input_img, self.trans_m)
+ self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)
+
+ face_mask = self.pred_mask
+ if self.opt.use_crop_face:
+ face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)
+
+ face_mask = face_mask.detach()
+ self.loss_color = self.opt.w_color * self.comupte_color_loss(
+ self.pred_face, self.input_img, self.atten_mask * face_mask)
+
+ loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)
+ self.loss_reg = self.opt.w_reg * loss_reg
+ self.loss_gamma = self.opt.w_gamma * loss_gamma
+
+ self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)
+
+ self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)
+
+ self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \
+ + self.loss_lm + self.loss_reflc
+
+
+ def optimize_parameters(self, isTrain=True):
+ self.forward()
+ self.compute_losses()
+ """Update network weights; it will be called in every training iteration."""
+ if isTrain:
+ self.optimizer.zero_grad()
+ self.loss_all.backward()
+ self.optimizer.step()
+
+ def compute_visuals(self):
+ with torch.no_grad():
+ input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
+ output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
+ output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
+
+ if self.gt_lm is not None:
+ gt_lm_numpy = self.gt_lm.cpu().numpy()
+ pred_lm_numpy = self.pred_lm.detach().cpu().numpy()
+ output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b')
+ output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r')
+
+ output_vis_numpy = np.concatenate((input_img_numpy,
+ output_vis_numpy_raw, output_vis_numpy), axis=-2)
+ else:
+ output_vis_numpy = np.concatenate((input_img_numpy,
+ output_vis_numpy_raw), axis=-2)
+
+ self.output_vis = torch.tensor(
+ output_vis_numpy / 255., dtype=torch.float32
+ ).permute(0, 3, 1, 2).to(self.device)
+
+ def save_mesh(self, name):
+
+ recon_shape = self.pred_vertex # get reconstructed shape
+ recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space
+ recon_shape = recon_shape.cpu().numpy()[0]
+ recon_color = self.pred_color
+ recon_color = recon_color.cpu().numpy()[0]
+ tri = self.facemodel.face_buf.cpu().numpy()
+ mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8))
+ mesh.export(name)
+
+ def save_coeff(self,name):
+
+ pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict}
+ pred_lm = self.pred_lm.cpu().numpy()
+ pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate
+ pred_coeffs['lm68'] = pred_lm
+ savemat(name,pred_coeffs)
+
+
+
diff --git a/sadtalker_audio2pose/src/face3d/models/losses.py b/sadtalker_audio2pose/src/face3d/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..01d9da84f28d54e772bebd2385ae5a7fedd10f7d
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/losses.py
@@ -0,0 +1,113 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from kornia.geometry import warp_affine
+import torch.nn.functional as F
+
+def resize_n_crop(image, M, dsize=112):
+ # image: (b, c, h, w)
+ # M : (b, 2, 3)
+ return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
+
+### perceptual level loss
+class PerceptualLoss(nn.Module):
+ def __init__(self, recog_net, input_size=112):
+ super(PerceptualLoss, self).__init__()
+ self.recog_net = recog_net
+ self.preprocess = lambda x: 2 * x - 1
+ self.input_size=input_size
+ def forward(imageA, imageB, M):
+ """
+ 1 - cosine distance
+ Parameters:
+ imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
+ imageB --same as imageA
+ """
+
+ imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
+ imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
+
+ # freeze bn
+ self.recog_net.eval()
+
+ id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
+ id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
+ cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+ # assert torch.sum((cosine_d > 1).float()) == 0
+ return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+def perceptual_loss(id_featureA, id_featureB):
+ cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+ # assert torch.sum((cosine_d > 1).float()) == 0
+ return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+### image level loss
+def photo_loss(imageA, imageB, mask, eps=1e-6):
+ """
+ l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur)
+ Parameters:
+ imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
+ imageB --same as imageA
+ """
+ loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
+ loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
+ return loss
+
+def landmark_loss(predict_lm, gt_lm, weight=None):
+ """
+ weighted mse loss
+ Parameters:
+ predict_lm --torch.tensor (B, 68, 2)
+ gt_lm --torch.tensor (B, 68, 2)
+ weight --numpy.array (1, 68)
+ """
+ if not weight:
+ weight = np.ones([68])
+ weight[28:31] = 20
+ weight[-8:] = 20
+ weight = np.expand_dims(weight, 0)
+ weight = torch.tensor(weight).to(predict_lm.device)
+ loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
+ loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
+ return loss
+
+
+### regulization
+def reg_loss(coeffs_dict, opt=None):
+ """
+ l2 norm without the sqrt, from yu's implementation (mse)
+ tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
+ Parameters:
+ coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
+
+ """
+ # coefficient regularization to ensure plausible 3d faces
+ if opt:
+ w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
+ else:
+ w_id, w_exp, w_tex = 1, 1, 1, 1
+ creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \
+ w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
+ w_tex * torch.sum(coeffs_dict['tex'] ** 2)
+ creg_loss = creg_loss / coeffs_dict['id'].shape[0]
+
+ # gamma regularization to ensure a nearly-monochromatic light
+ gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
+ gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
+ gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
+
+ return creg_loss, gamma_loss
+
+def reflectance_loss(texture, mask):
+ """
+ minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo
+ Parameters:
+ texture --torch.tensor, (B, N, 3)
+ mask --torch.tensor, (N), 1 or 0
+
+ """
+ mask = mask.reshape([1, mask.shape[0], 1])
+ texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask)
+ loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))
+ return loss
+
diff --git a/sadtalker_audio2pose/src/face3d/models/networks.py b/sadtalker_audio2pose/src/face3d/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e69eba1ade2e6431e7e7fd526ea68b8f63e7152
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/networks.py
@@ -0,0 +1,521 @@
+"""This script defines deep neural networks for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch.nn.functional as F
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+import torch
+from torch import Tensor
+import torch.nn as nn
+try:
+ from torch.hub import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+from typing import Type, Any, Callable, Union, List, Optional
+from .arcface_torch.backbones import get_model
+from kornia.geometry import warp_affine
+
+def resize_n_crop(image, M, dsize=112):
+ # image: (b, c, h, w)
+ # M : (b, 2, 3)
+ return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
+
+def filter_state_dict(state_dict, remove_name='fc'):
+ new_state_dict = {}
+ for key in state_dict:
+ if remove_name in key:
+ continue
+ new_state_dict[key] = state_dict[key]
+ return new_state_dict
+
+def get_scheduler(optimizer, opt):
+ """Return a learning rate scheduler
+
+ Parameters:
+ optimizer -- the optimizer of the network
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+ See https://pytorch.org/docs/stable/optim.html for more details.
+ """
+ if opt.lr_policy == 'linear':
+ def lambda_rule(epoch):
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)
+ return lr_l
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+ elif opt.lr_policy == 'step':
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)
+ elif opt.lr_policy == 'plateau':
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+ elif opt.lr_policy == 'cosine':
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+ else:
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+ return scheduler
+
+
+def define_net_recon(net_recon, use_last_fc=False, init_path=None):
+ return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path)
+
+def define_net_recog(net_recog, pretrained_path=None):
+ net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path)
+ net.eval()
+ return net
+
+class ReconNetWrapper(nn.Module):
+ fc_dim=257
+ def __init__(self, net_recon, use_last_fc=False, init_path=None):
+ super(ReconNetWrapper, self).__init__()
+ self.use_last_fc = use_last_fc
+ if net_recon not in func_dict:
+ return NotImplementedError('network [%s] is not implemented', net_recon)
+ func, last_dim = func_dict[net_recon]
+ backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
+ if init_path and os.path.isfile(init_path):
+ state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))
+ backbone.load_state_dict(state_dict)
+ print("loading init net_recon %s from %s" %(net_recon, init_path))
+ self.backbone = backbone
+ if not use_last_fc:
+ self.final_layers = nn.ModuleList([
+ conv1x1(last_dim, 80, bias=True), # id layer
+ conv1x1(last_dim, 64, bias=True), # exp layer
+ conv1x1(last_dim, 80, bias=True), # tex layer
+ conv1x1(last_dim, 3, bias=True), # angle layer
+ conv1x1(last_dim, 27, bias=True), # gamma layer
+ conv1x1(last_dim, 2, bias=True), # tx, ty
+ conv1x1(last_dim, 1, bias=True) # tz
+ ])
+ for m in self.final_layers:
+ nn.init.constant_(m.weight, 0.)
+ nn.init.constant_(m.bias, 0.)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ if not self.use_last_fc:
+ output = []
+ for layer in self.final_layers:
+ output.append(layer(x))
+ x = torch.flatten(torch.cat(output, dim=1), 1)
+ return x
+
+
+class RecogNetWrapper(nn.Module):
+ def __init__(self, net_recog, pretrained_path=None, input_size=112):
+ super(RecogNetWrapper, self).__init__()
+ net = get_model(name=net_recog, fp16=False)
+ if pretrained_path:
+ state_dict = torch.load(pretrained_path, map_location='cpu')
+ net.load_state_dict(state_dict)
+ print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path))
+ for param in net.parameters():
+ param.requires_grad = False
+ self.net = net
+ self.preprocess = lambda x: 2 * x - 1
+ self.input_size=input_size
+
+ def forward(self, image, M):
+ image = self.preprocess(resize_n_crop(image, M, self.input_size))
+ id_feature = F.normalize(self.net(image), dim=-1, p=2)
+ return id_feature
+
+
+# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
+
+
+class BasicBlock(nn.Module):
+ expansion: int = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion: int = 4
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ zero_init_residual: bool = False,
+ use_last_fc: bool = False,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.use_last_fc = use_last_fc
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+ if self.use_last_fc:
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
+
+ def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
+ stride: int = 1, dilate: bool = False) -> nn.Sequential:
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ # See note [TorchScript super()]
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ if self.use_last_fc:
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self._forward_impl(x)
+
+
+def _resnet(
+ arch: str,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ pretrained: bool,
+ progress: bool,
+ **kwargs: Any
+) -> ResNet:
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" `_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" `_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+func_dict = {
+ 'resnet18': (resnet18, 512),
+ 'resnet50': (resnet50, 2048)
+}
diff --git a/sadtalker_audio2pose/src/face3d/models/template_model.py b/sadtalker_audio2pose/src/face3d/models/template_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..75860272a06312bfa4de382729dce5136a480a7f
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/models/template_model.py
@@ -0,0 +1,100 @@
+"""Model class template
+
+This module provides a template for users to implement custom models.
+You can specify '--model template' to use this model.
+The class name should be consistent with both the filename and its model option.
+The filename should be _dataset.py
+The class name should be Dataset.py
+It implements a simple image-to-image translation baseline based on regression loss.
+Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
+ min_ ||netG(data_A) - data_B||_1
+You need to implement the following functions:
+ : Add model-specific options and rewrite default values for existing options.
+ <__init__>: Initialize this model class.
+ : Unpack input data and perform data pre-processing.
+ : Run forward pass. This will be called by both and .
+ : Update network weights; it will be called in every training iteration.
+"""
+import numpy as np
+import torch
+from .base_model import BaseModel
+from . import networks
+
+
+class TemplateModel(BaseModel):
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Add new model-specific options and rewrite default values for existing options.
+
+ Parameters:
+ parser -- the option parser
+ is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
+ if is_train:
+ parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
+
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+ # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
+ self.loss_names = ['loss_G']
+ # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
+ self.visual_names = ['data_A', 'data_B', 'output']
+ # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
+ # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
+ self.model_names = ['G']
+ # define networks; you can use opt.isTrain to specify different behaviors for training and test.
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
+ if self.isTrain: # only defined during training time
+ # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
+ # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
+ self.criterionLoss = torch.nn.L1Loss()
+ # define and initialize optimizers. You can define one optimizer for each network.
+ # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers = [self.optimizer]
+
+ # Our program will automatically call to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B
+ self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
+ self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
+ self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
+
+ def forward(self):
+ """Run forward pass. This will be called by both functions and ."""
+ self.output = self.netG(self.data_A) # generate output image given the input data_A
+
+ def backward(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ # caculate the intermediate results if necessary; here self.output has been computed during function
+ # calculate loss given the input and intermediate results
+ self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
+ self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
+
+ def optimize_parameters(self):
+ """Update network weights; it will be called in every training iteration."""
+ self.forward() # first call forward to calculate intermediate results
+ self.optimizer.zero_grad() # clear network G's existing gradients
+ self.backward() # calculate gradients for network G
+ self.optimizer.step() # update gradients for network G
diff --git a/sadtalker_audio2pose/src/face3d/options/__init__.py b/sadtalker_audio2pose/src/face3d/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..06559aa558cf178b946c4523b28b098d1dfad606
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/options/__init__.py
@@ -0,0 +1 @@
+"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
diff --git a/sadtalker_audio2pose/src/face3d/options/base_options.py b/sadtalker_audio2pose/src/face3d/options/base_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..616a2e63f57e033a0a37e01a9b41babf93f6c3dd
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/options/base_options.py
@@ -0,0 +1,169 @@
+"""This script contains base options for Deep3DFaceRecon_pytorch
+"""
+
+import argparse
+import os
+from util import util
+import numpy as np
+import torch
+import face3d.models as models
+import face3d.data as data
+
+
+class BaseOptions():
+ """This class defines options used during both training and test time.
+
+ It also implements several helper functions such as parsing, printing, and saving the options.
+ It also gathers additional options defined in functions in both dataset class and model class.
+ """
+
+ def __init__(self, cmd_line=None):
+ """Reset the class; indicates the class hasn't been initailized"""
+ self.initialized = False
+ self.cmd_line = None
+ if cmd_line is not None:
+ self.cmd_line = cmd_line.split()
+
+ def initialize(self, parser):
+ """Define the common options that are used in both training and test."""
+ # basic parameters
+ parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models')
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
+ parser.add_argument('--checkpoints_dir', type=str, default='./ckpts/sad_talker', help='models are saved here')
+ parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization')
+ parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation')
+ parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel')
+ parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port')
+ parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses')
+ parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard')
+ parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation')
+
+ # model parameters
+ parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.')
+
+ # additional parameters
+ parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
+
+ self.initialized = True
+ return parser
+
+ def gather_options(self):
+ """Initialize our parser with basic options(only once).
+ Add additional model-specific and dataset-specific options.
+ These options are defined in the function
+ in model and dataset classes.
+ """
+ if not self.initialized: # check if it has been initialized
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser = self.initialize(parser)
+
+ # get the basic options
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args()
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line)
+
+ # set cuda visible devices
+ os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
+
+ # modify model-related parser options
+ model_name = opt.model
+ model_option_setter = models.get_option_setter(model_name)
+ parser = model_option_setter(parser, self.isTrain)
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args() # parse again with new defaults
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
+
+ # modify dataset-related parser options
+ if opt.dataset_mode:
+ dataset_name = opt.dataset_mode
+ dataset_option_setter = data.get_option_setter(dataset_name)
+ parser = dataset_option_setter(parser, self.isTrain)
+
+ # save and return the parser
+ self.parser = parser
+ if self.cmd_line is None:
+ return parser.parse_args()
+ else:
+ return parser.parse_args(self.cmd_line)
+
+ def print_options(self, opt):
+ """Print and save options
+
+ It will print both current options and default values(if different).
+ It will save options into a text file / [checkpoints_dir] / opt.txt
+ """
+ message = ''
+ message += '----------------- Options ---------------\n'
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = '\t[default: %s]' % str(default)
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+ message += '----------------- End -------------------'
+ print(message)
+
+ # save to the disk
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+ try:
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write(message)
+ opt_file.write('\n')
+ except PermissionError as error:
+ print("permission error {}".format(error))
+ pass
+
+ def parse(self):
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+ opt = self.gather_options()
+ opt.isTrain = self.isTrain # train or test
+
+ # process opt.suffix
+ if opt.suffix:
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+ opt.name = opt.name + suffix
+
+
+ # set gpu ids
+ str_ids = opt.gpu_ids.split(',')
+ gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ gpu_ids.append(id)
+ opt.world_size = len(gpu_ids)
+ # if len(opt.gpu_ids) > 0:
+ # torch.cuda.set_device(gpu_ids[0])
+ if opt.world_size == 1:
+ opt.use_ddp = False
+
+ if opt.phase != 'test':
+ # set continue_train automatically
+ if opt.pretrained_name is None:
+ model_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ else:
+ model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name)
+ if os.path.isdir(model_dir):
+ model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')]
+ if os.path.isdir(model_dir) and len(model_pths) != 0:
+ opt.continue_train= True
+
+ # update the latest epoch count
+ if opt.continue_train:
+ if opt.epoch == 'latest':
+ epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i]
+ if len(epoch_counts) != 0:
+ opt.epoch_count = max(epoch_counts) + 1
+ else:
+ opt.epoch_count = int(opt.epoch) + 1
+
+
+ self.print_options(opt)
+ self.opt = opt
+ return self.opt
diff --git a/sadtalker_audio2pose/src/face3d/options/inference_options.py b/sadtalker_audio2pose/src/face3d/options/inference_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b9466776e120e0fe3d164217df5071c2114cef
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/options/inference_options.py
@@ -0,0 +1,23 @@
+from face3d.options.base_options import BaseOptions
+
+
+class InferenceOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
+
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files')
+ parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients')
+ parser.add_argument('--save_split_files', action='store_true', help='save split files or not')
+ parser.add_argument('--inference_batch_size', type=int, default=8)
+
+ # Dropout and Batchnorm has different behavior during training and test.
+ self.isTrain = False
+ return parser
diff --git a/sadtalker_audio2pose/src/face3d/options/test_options.py b/sadtalker_audio2pose/src/face3d/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81c0c6eee0549e6fa8762dc4fc4b8573b887fe4
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/options/test_options.py
@@ -0,0 +1,21 @@
+"""This script contains the test options for Deep3DFaceRecon_pytorch
+"""
+
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
+ parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.')
+
+ # Dropout and Batchnorm has different behavior during training and test.
+ self.isTrain = False
+ return parser
diff --git a/sadtalker_audio2pose/src/face3d/options/train_options.py b/sadtalker_audio2pose/src/face3d/options/train_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..1100b0e35cc8ef563f41f6b8219510edbef53233
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/options/train_options.py
@@ -0,0 +1,53 @@
+"""This script contains the training options for Deep3DFaceRecon_pytorch
+"""
+
+from .base_options import BaseOptions
+from util import util
+
+class TrainOptions(BaseOptions):
+ """This class includes training options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser)
+ # dataset parameters
+ # for train
+ parser.add_argument('--data_root', type=str, default='./', help='dataset root')
+ parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set')
+ parser.add_argument('--batch_size', type=int, default=32)
+ parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]')
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
+ parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+ parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]')
+ parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation')
+
+ # for val
+ parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set')
+ parser.add_argument('--batch_size_val', type=int, default=32)
+
+
+ # visualization parameters
+ parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen')
+ parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+
+ # network saving and loading parameters
+ parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
+ parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
+ parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
+ parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
+ parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+ parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')
+ parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+ parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
+
+ # training parameters
+ parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate')
+ parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
+ parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
+ parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches')
+
+ self.isTrain = True
+ return parser
diff --git a/sadtalker_audio2pose/src/face3d/util/BBRegressorParam_r.mat b/sadtalker_audio2pose/src/face3d/util/BBRegressorParam_r.mat
new file mode 100644
index 0000000000000000000000000000000000000000..a0da99af145c400a5216d9f6fb251d9412565921
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/BBRegressorParam_r.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a5a07b8ce75a39d96b918dc0fc6e110a72e090da16f5f056a0ef7bfbc3f4560
+size 22019
diff --git a/sadtalker_audio2pose/src/face3d/util/__init__.py b/sadtalker_audio2pose/src/face3d/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c67833cc634a2ca310b883ae253b08687665f40
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/__init__.py
@@ -0,0 +1,3 @@
+"""This package includes a miscellaneous collection of useful helper functions."""
+from src.face3d.util import *
+
diff --git a/sadtalker_audio2pose/src/face3d/util/detect_lm68.py b/sadtalker_audio2pose/src/face3d/util/detect_lm68.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a2cfd22b342de5c872ff07fc1c2a9920c2985b7
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/detect_lm68.py
@@ -0,0 +1,106 @@
+import os
+import cv2
+import numpy as np
+from scipy.io import loadmat
+import tensorflow as tf
+from util.preprocess import align_for_lm
+from shutil import move
+
+mean_face = np.loadtxt('util/test_mean_face.txt')
+mean_face = mean_face.reshape([68, 2])
+
+def save_label(labels, save_path):
+ np.savetxt(save_path, labels)
+
+def draw_landmarks(img, landmark, save_name):
+ landmark = landmark
+ lm_img = np.zeros([img.shape[0], img.shape[1], 3])
+ lm_img[:] = img.astype(np.float32)
+ landmark = np.round(landmark).astype(np.int32)
+
+ for i in range(len(landmark)):
+ for j in range(-1, 1):
+ for k in range(-1, 1):
+ if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \
+ img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \
+ landmark[i, 0]+k > 0 and \
+ landmark[i, 0]+k < img.shape[1]:
+ lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k,
+ :] = np.array([0, 0, 255])
+ lm_img = lm_img.astype(np.uint8)
+
+ cv2.imwrite(save_name, lm_img)
+
+
+def load_data(img_name, txt_name):
+ return cv2.imread(img_name), np.loadtxt(txt_name)
+
+# create tensorflow graph for landmark detector
+def load_lm_graph(graph_filename):
+ with tf.gfile.GFile(graph_filename, 'rb') as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+
+ with tf.Graph().as_default() as graph:
+ tf.import_graph_def(graph_def, name='net')
+ img_224 = graph.get_tensor_by_name('net/input_imgs:0')
+ output_lm = graph.get_tensor_by_name('net/lm:0')
+ lm_sess = tf.Session(graph=graph)
+
+ return lm_sess,img_224,output_lm
+
+# landmark detection
+def detect_68p(img_path,sess,input_op,output_op):
+ print('detecting landmarks......')
+ names = [i for i in sorted(os.listdir(
+ img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
+ vis_path = os.path.join(img_path, 'vis')
+ remove_path = os.path.join(img_path, 'remove')
+ save_path = os.path.join(img_path, 'landmarks')
+ if not os.path.isdir(vis_path):
+ os.makedirs(vis_path)
+ if not os.path.isdir(remove_path):
+ os.makedirs(remove_path)
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ for i in range(0, len(names)):
+ name = names[i]
+ print('%05d' % (i), ' ', name)
+ full_image_name = os.path.join(img_path, name)
+ txt_name = '.'.join(name.split('.')[:-1]) + '.txt'
+ full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image
+
+ # if an image does not have detected 5 facial landmarks, remove it from the training list
+ if not os.path.isfile(full_txt_name):
+ move(full_image_name, os.path.join(remove_path, name))
+ continue
+
+ # load data
+ img, five_points = load_data(full_image_name, full_txt_name)
+ input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection
+
+ # if the alignment fails, remove corresponding image from the training list
+ if scale == 0:
+ move(full_txt_name, os.path.join(
+ remove_path, txt_name))
+ move(full_image_name, os.path.join(remove_path, name))
+ continue
+
+ # detect landmarks
+ input_img = np.reshape(
+ input_img, [1, 224, 224, 3]).astype(np.float32)
+ landmark = sess.run(
+ output_op, feed_dict={input_op: input_img})
+
+ # transform back to original image coordinate
+ landmark = landmark.reshape([68, 2]) + mean_face
+ landmark[:, 1] = 223 - landmark[:, 1]
+ landmark = landmark / scale
+ landmark[:, 0] = landmark[:, 0] + bbox[0]
+ landmark[:, 1] = landmark[:, 1] + bbox[1]
+ landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]
+
+ if i % 100 == 0:
+ draw_landmarks(img, landmark, os.path.join(vis_path, name))
+ save_label(landmark, os.path.join(save_path, txt_name))
diff --git a/sadtalker_audio2pose/src/face3d/util/generate_list.py b/sadtalker_audio2pose/src/face3d/util/generate_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe93fcc5c61fbc79f4cd004a8d1bdd10ece16eb
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/generate_list.py
@@ -0,0 +1,34 @@
+"""This script is to generate training list files for Deep3DFaceRecon_pytorch
+"""
+
+import os
+
+# save path to training data
+def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''):
+ save_path = os.path.join(save_folder, mode)
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+ with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in lms_list])
+
+ with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in imgs_list])
+
+ with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in msks_list])
+
+# check if the path is valid
+def check_list(rlms_list, rimgs_list, rmsks_list):
+ lms_list, imgs_list, msks_list = [], [], []
+ for i in range(len(rlms_list)):
+ flag = 'false'
+ lm_path = rlms_list[i]
+ im_path = rimgs_list[i]
+ msk_path = rmsks_list[i]
+ if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):
+ flag = 'true'
+ lms_list.append(rlms_list[i])
+ imgs_list.append(rimgs_list[i])
+ msks_list.append(rmsks_list[i])
+ print(i, rlms_list[i], flag)
+ return lms_list, imgs_list, msks_list
diff --git a/sadtalker_audio2pose/src/face3d/util/html.py b/sadtalker_audio2pose/src/face3d/util/html.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0c4e6a66ba5a34e30cee3beb13e21465c72ef38
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/html.py
@@ -0,0 +1,86 @@
+import dominate
+from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+import os
+
+
+class HTML:
+ """This HTML class allows us to save images and write texts into a single HTML file.
+
+ It consists of functions such as (add a text header to the HTML file),
+ (add a row of images to the HTML file), and (save the HTML to the disk).
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+ """
+
+ def __init__(self, web_dir, title, refresh=0):
+ """Initialize the HTML classes
+
+ Parameters:
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0:
+ with self.doc.head:
+ meta(http_equiv="refresh", content=str(refresh))
+
+ def get_image_dir(self):
+ """Return the directory that stores images"""
+ return self.img_dir
+
+ def add_header(self, text):
+ """Insert a header to the HTML file
+
+ Parameters:
+ text (str) -- the header text
+ """
+ with self.doc:
+ h3(text)
+
+ def add_images(self, ims, txts, links, width=400):
+ """add images to the HTML file
+
+ Parameters:
+ ims (str list) -- a list of image paths
+ txts (str list) -- a list of image names shown on the website
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
+ """
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
+ self.doc.add(self.t)
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
+ br()
+ p(txt)
+
+ def save(self):
+ """save the current content to the HMTL file"""
+ html_file = '%s/index.html' % self.web_dir
+ f = open(html_file, 'wt')
+ f.write(self.doc.render())
+ f.close()
+
+
+if __name__ == '__main__': # we show an example usage here.
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims, txts, links = [], [], []
+ for n in range(4):
+ ims.append('image_%d.png' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.png' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/sadtalker_audio2pose/src/face3d/util/load_mats.py b/sadtalker_audio2pose/src/face3d/util/load_mats.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7ea0a7877e80035883138415c102910d896bb61
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/load_mats.py
@@ -0,0 +1,120 @@
+"""This script is to load 3D face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from PIL import Image
+from scipy.io import loadmat, savemat
+from array import array
+import os.path as osp
+
+# load expression basis
+def LoadExpBasis(bfm_folder='BFM'):
+ n_vertex = 53215
+ Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb')
+ exp_dim = array('i')
+ exp_dim.fromfile(Expbin, 1)
+ expMU = array('f')
+ expPC = array('f')
+ expMU.fromfile(Expbin, 3*n_vertex)
+ expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex)
+ Expbin.close()
+
+ expPC = np.array(expPC)
+ expPC = np.reshape(expPC, [exp_dim[0], -1])
+ expPC = np.transpose(expPC)
+
+ expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt'))
+
+ return expPC, expEV
+
+
+# transfer original BFM09 to our face model
+def transferBFM09(bfm_folder='BFM'):
+ print('Transfer BFM09 to BFM_model_front......')
+ original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat'))
+ shapePC = original_BFM['shapePC'] # shape basis
+ shapeEV = original_BFM['shapeEV'] # corresponding eigen value
+ shapeMU = original_BFM['shapeMU'] # mean face
+ texPC = original_BFM['texPC'] # texture basis
+ texEV = original_BFM['texEV'] # eigen value
+ texMU = original_BFM['texMU'] # mean texture
+
+ expPC, expEV = LoadExpBasis(bfm_folder)
+
+ # transfer BFM09 to our face model
+
+ idBase = shapePC*np.reshape(shapeEV, [-1, 199])
+ idBase = idBase/1e5 # unify the scale to decimeter
+ idBase = idBase[:, :80] # use only first 80 basis
+
+ exBase = expPC*np.reshape(expEV, [-1, 79])
+ exBase = exBase/1e5 # unify the scale to decimeter
+ exBase = exBase[:, :64] # use only first 64 basis
+
+ texBase = texPC*np.reshape(texEV, [-1, 199])
+ texBase = texBase[:, :80] # use only first 80 basis
+
+ # our face model is cropped along face landmarks and contains only 35709 vertex.
+ # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex.
+ # thus we select corresponding vertex to get our face model.
+
+ index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat'))
+ index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215)
+
+ index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat'))
+ index_shape = index_shape['trimIndex'].astype(
+ np.int32) - 1 # starts from 0 (to 53490)
+ index_shape = index_shape[index_exp]
+
+ idBase = np.reshape(idBase, [-1, 3, 80])
+ idBase = idBase[index_shape, :, :]
+ idBase = np.reshape(idBase, [-1, 80])
+
+ texBase = np.reshape(texBase, [-1, 3, 80])
+ texBase = texBase[index_shape, :, :]
+ texBase = np.reshape(texBase, [-1, 80])
+
+ exBase = np.reshape(exBase, [-1, 3, 64])
+ exBase = exBase[index_exp, :, :]
+ exBase = np.reshape(exBase, [-1, 64])
+
+ meanshape = np.reshape(shapeMU, [-1, 3])/1e5
+ meanshape = meanshape[index_shape, :]
+ meanshape = np.reshape(meanshape, [1, -1])
+
+ meantex = np.reshape(texMU, [-1, 3])
+ meantex = meantex[index_shape, :]
+ meantex = np.reshape(meantex, [1, -1])
+
+ # other info contains triangles, region used for computing photometric loss,
+ # region used for skin texture regularization, and 68 landmarks index etc.
+ other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat'))
+ frontmask2_idx = other_info['frontmask2_idx']
+ skinmask = other_info['skinmask']
+ keypoints = other_info['keypoints']
+ point_buf = other_info['point_buf']
+ tri = other_info['tri']
+ tri_mask2 = other_info['tri_mask2']
+
+ # save our face model
+ savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase,
+ 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask})
+
+
+# load landmarks for standard face, which is used for image preprocessing
+def load_lm3d(bfm_folder):
+
+ Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat'))
+ Lm3D = Lm3D['lm']
+
+ # calculate 5 facial landmarks using 68 landmarks
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(
+ Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)
+ Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
+
+ return Lm3D
+
+
+if __name__ == '__main__':
+ transferBFM09()
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/face3d/util/nvdiffrast.py b/sadtalker_audio2pose/src/face3d/util/nvdiffrast.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b345db30085de501b6718ad5b49bb5f9144dd29
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/nvdiffrast.py
@@ -0,0 +1,126 @@
+"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
+ Attention, antialiasing step is missing in current version.
+"""
+import pytorch3d.ops
+import torch
+import torch.nn.functional as F
+import kornia
+from kornia.geometry.camera import pixel2cam
+import numpy as np
+from typing import List
+from scipy.io import loadmat
+from torch import nn
+
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import (
+ look_at_view_transform,
+ FoVPerspectiveCameras,
+ DirectionalLights,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRasterizer,
+ SoftPhongShader,
+ TexturesUV,
+)
+
+# def ndc_projection(x=0.1, n=1.0, f=50.0):
+# return np.array([[n/x, 0, 0, 0],
+# [ 0, n/-x, 0, 0],
+# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
+# [ 0, 0, -1, 0]]).astype(np.float32)
+
+class MeshRenderer(nn.Module):
+ def __init__(self,
+ rasterize_fov,
+ znear=0.1,
+ zfar=10,
+ rasterize_size=224):
+ super(MeshRenderer, self).__init__()
+
+ # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
+ # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
+ # torch.diag(torch.tensor([1., -1, -1, 1])))
+ self.rasterize_size = rasterize_size
+ self.fov = rasterize_fov
+ self.znear = znear
+ self.zfar = zfar
+
+ self.rasterizer = None
+
+ def forward(self, vertex, tri, feat=None):
+ """
+ Return:
+ mask -- torch.tensor, size (B, 1, H, W)
+ depth -- torch.tensor, size (B, 1, H, W)
+ features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
+
+ Parameters:
+ vertex -- torch.tensor, size (B, N, 3)
+ tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
+ feat(optional) -- torch.tensor, size (B, N ,C), features
+ """
+ device = vertex.device
+ rsize = int(self.rasterize_size)
+ # ndc_proj = self.ndc_proj.to(device)
+ # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
+ if vertex.shape[-1] == 3:
+ vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
+ vertex[..., 0] = -vertex[..., 0]
+
+
+ # vertex_ndc = vertex @ ndc_proj.t()
+ if self.rasterizer is None:
+ self.rasterizer = MeshRasterizer()
+ print("create rasterizer on device cuda:%d"%device.index)
+
+ # ranges = None
+ # if isinstance(tri, List) or len(tri.shape) == 3:
+ # vum = vertex_ndc.shape[1]
+ # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
+ # fstartidx = torch.cumsum(fnum, dim=0) - fnum
+ # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
+ # for i in range(tri.shape[0]):
+ # tri[i] = tri[i] + i*vum
+ # vertex_ndc = torch.cat(vertex_ndc, dim=0)
+ # tri = torch.cat(tri, dim=0)
+
+ # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3]
+ tri = tri.type(torch.int32).contiguous()
+
+ # rasterize
+ cameras = FoVPerspectiveCameras(
+ device=device,
+ fov=self.fov,
+ znear=self.znear,
+ zfar=self.zfar,
+ )
+
+ raster_settings = RasterizationSettings(
+ image_size=rsize
+ )
+
+ # print(vertex.shape, tri.shape)
+ mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1)))
+
+ fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
+ rast_out = fragments.pix_to_face.squeeze(-1)
+ depth = fragments.zbuf
+
+ # render depth
+ depth = depth.permute(0, 3, 1, 2)
+ mask = (rast_out > 0).float().unsqueeze(1)
+ depth = mask * depth
+
+
+ image = None
+ if feat is not None:
+ attributes = feat.reshape(-1,3)[mesh.faces_packed()]
+ image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
+ fragments.bary_coords,
+ attributes)
+ # print(image.shape)
+ image = image.squeeze(-2).permute(0, 3, 1, 2)
+ image = mask * image
+
+ return mask, depth, image
+
diff --git a/sadtalker_audio2pose/src/face3d/util/preprocess.py b/sadtalker_audio2pose/src/face3d/util/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..82b36443fe4c84c1ad6366897a8e7d4e8b63b2b6
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/preprocess.py
@@ -0,0 +1,134 @@
+"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from scipy.io import loadmat
+from PIL import Image
+import cv2
+import os
+from skimage import transform as trans
+import torch
+import warnings
+warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+# calculating least square problem for image alignment
+def POS(xp, x):
+ npts = xp.shape[1]
+
+ A = np.zeros([2*npts, 8])
+
+ A[0:2*npts-1:2, 0:3] = x.transpose()
+ A[0:2*npts-1:2, 3] = 1
+
+ A[1:2*npts:2, 4:7] = x.transpose()
+ A[1:2*npts:2, 7] = 1
+
+ b = np.reshape(xp.transpose(), [2*npts, 1])
+
+ k, _, _, _ = np.linalg.lstsq(A, b)
+
+ R1 = k[0:3]
+ R2 = k[4:7]
+ sTx = k[3]
+ sTy = k[7]
+ s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2
+ t = np.stack([sTx, sTy], axis=0)
+
+ return t, s
+
+# # resize and crop images for face reconstruction
+# def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
+# w0, h0 = img.size
+# w = (w0*s).astype(np.int32)
+# h = (h0*s).astype(np.int32)
+# left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
+# right = left + target_size
+# up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
+# below = up + target_size
+
+# img = img.resize((w, h), resample=Image.BICUBIC)
+# img = img.crop((left, up, right, below))
+
+# if mask is not None:
+# mask = mask.resize((w, h), resample=Image.BICUBIC)
+# mask = mask.crop((left, up, right, below))
+
+# lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
+# t[1] + h0/2], axis=1)*s
+# lm = lm - np.reshape(
+# np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
+
+# return img, lm, mask
+
+
+# resize and crop images for face reconstruction
+def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
+ w0, h0 = img.size
+ w = (w0*s).astype(np.int32)
+ h = (h0*s).astype(np.int32)
+ left = np.round(w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
+ right = left + target_size
+ up = np.round(h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
+ below = up + target_size
+
+ img = img.resize((w, h), resample=Image.BICUBIC)
+ img = img.crop((left, up, right, below))
+ # import pdb; pdb.set_trace()
+ if mask is not None:
+ mask = mask.resize((w, h), resample=Image.BICUBIC)
+ mask = mask.crop((left, up, right, below))
+
+ lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
+ t[1] + h0/2], axis=1)*s
+ lm = lm - np.reshape(
+ np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
+
+ # orig_left, orig_up, orig_crop_size = (left,up,target_size)/s
+
+ return img, lm, mask, left, up, target_size
+
+# utils for face reconstruction
+def extract_5p(lm):
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean(
+ lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0)
+ lm5p = lm5p[[1, 2, 0, 3, 4], :]
+ return lm5p
+
+# utils for face reconstruction
+def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.):
+ """
+ Return:
+ transparams --numpy.array (raw_W, raw_H, scale, tx, ty)
+ img_new --PIL.Image (target_size, target_size, 3)
+ lm_new --numpy.array (68, 2), y direction is opposite to v direction
+ mask_new --PIL.Image (target_size, target_size)
+
+ Parameters:
+ img --PIL.Image (raw_H, raw_W, 3)
+ lm --numpy.array (68, 2), y direction is opposite to v direction
+ lm3D --numpy.array (5, 3)
+ mask --PIL.Image (raw_H, raw_W, 3)
+ """
+
+ w0, h0 = img.size
+ if lm.shape[0] != 5:
+ lm5p = extract_5p(lm)
+ else:
+ lm5p = lm
+
+ # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
+ t, s = POS(lm5p.transpose(), lm3D.transpose())
+ s = rescale_factor/s
+
+ # processing the image
+
+ # processing the image
+ img_new, lm_new, mask_new, orig_left, orig_up, orig_crop_size = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
+ trans_params = np.array([w0, h0, s, t[0], t[1], orig_left, orig_up, orig_crop_size])
+ # img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
+ # trans_params = np.array([w0, h0, s, t[0], t[1]])
+
+ return trans_params, img_new, lm_new, mask_new
diff --git a/sadtalker_audio2pose/src/face3d/util/skin_mask.py b/sadtalker_audio2pose/src/face3d/util/skin_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed764759038f77b35d45448b344d4347498ca427
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/skin_mask.py
@@ -0,0 +1,125 @@
+"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch
+"""
+
+import math
+import numpy as np
+import os
+import cv2
+
+class GMM:
+ def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):
+ self.dim = dim # feature dimension
+ self.num = num # number of Gaussian components
+ self.w = w # weights of Gaussian components (a list of scalars)
+ self.mu= mu # mean of Gaussian components (a list of 1xdim vectors)
+ self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices)
+ self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars)
+ self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices)
+
+ self.factor = [0]*num
+ for i in range(self.num):
+ self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5
+
+ def likelihood(self, data):
+ assert(data.shape[1] == self.dim)
+ N = data.shape[0]
+ lh = np.zeros(N)
+
+ for i in range(self.num):
+ data_ = data - self.mu[i]
+
+ tmp = np.matmul(data_,self.cov_inv[i]) * data_
+ tmp = np.sum(tmp,axis=1)
+ power = -0.5 * tmp
+
+ p = np.array([math.exp(power[j]) for j in range(N)])
+ p = p/self.factor[i]
+ lh += p*self.w[i]
+
+ return lh
+
+
+def _rgb2ycbcr(rgb):
+ m = np.array([[65.481, 128.553, 24.966],
+ [-37.797, -74.203, 112],
+ [112, -93.786, -18.214]])
+ shape = rgb.shape
+ rgb = rgb.reshape((shape[0] * shape[1], 3))
+ ycbcr = np.dot(rgb, m.transpose() / 255.)
+ ycbcr[:, 0] += 16.
+ ycbcr[:, 1:] += 128.
+ return ycbcr.reshape(shape)
+
+
+def _bgr2ycbcr(bgr):
+ rgb = bgr[..., ::-1]
+ return _rgb2ycbcr(rgb)
+
+
+gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415]
+gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]),
+ np.array([150.19858, 105.18467, 155.51428]),
+ np.array([183.92976, 107.62468, 152.71820]),
+ np.array([114.90524, 113.59782, 151.38217])]
+gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.]
+gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]),
+ np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]),
+ np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]),
+ np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])]
+
+gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv)
+
+gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393]
+gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]),
+ np.array([110.91392, 125.52969, 130.19237]),
+ np.array([129.75864, 129.96107, 126.96808]),
+ np.array([112.29587, 128.85121, 129.05431])]
+gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63]
+gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]),
+ np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]),
+ np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]),
+ np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])]
+
+gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv)
+
+prior_skin = 0.8
+prior_nonskin = 1 - prior_skin
+
+
+# calculate skin attention mask
+def skinmask(imbgr):
+ im = _bgr2ycbcr(imbgr)
+
+ data = im.reshape((-1,3))
+
+ lh_skin = gmm_skin.likelihood(data)
+ lh_nonskin = gmm_nonskin.likelihood(data)
+
+ tmp1 = prior_skin * lh_skin
+ tmp2 = prior_nonskin * lh_nonskin
+ post_skin = tmp1 / (tmp1+tmp2) # posterior probability
+
+ post_skin = post_skin.reshape((im.shape[0],im.shape[1]))
+
+ post_skin = np.round(post_skin*255)
+ post_skin = post_skin.astype(np.uint8)
+ post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3
+
+ return post_skin
+
+
+def get_skin_mask(img_path):
+ print('generating skin masks......')
+ names = [i for i in sorted(os.listdir(
+ img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
+ save_path = os.path.join(img_path, 'mask')
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ for i in range(0, len(names)):
+ name = names[i]
+ print('%05d' % (i), ' ', name)
+ full_image_name = os.path.join(img_path, name)
+ img = cv2.imread(full_image_name).astype(np.float32)
+ skin_img = skinmask(img)
+ cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8))
diff --git a/sadtalker_audio2pose/src/face3d/util/test_mean_face.txt b/sadtalker_audio2pose/src/face3d/util/test_mean_face.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1637648acf5a61cbc71b317c845414bb16d0150c
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/test_mean_face.txt
@@ -0,0 +1,136 @@
+-5.228591537475585938e+01
+2.078247070312500000e-01
+-5.064269638061523438e+01
+-1.315765380859375000e+01
+-4.952939224243164062e+01
+-2.592591094970703125e+01
+-4.793047332763671875e+01
+-3.832135772705078125e+01
+-4.512159729003906250e+01
+-5.059623336791992188e+01
+-3.917720794677734375e+01
+-6.043736648559570312e+01
+-2.929953765869140625e+01
+-6.861183166503906250e+01
+-1.719801330566406250e+01
+-7.572736358642578125e+01
+-1.961936950683593750e+00
+-7.862001037597656250e+01
+1.467941284179687500e+01
+-7.607844543457031250e+01
+2.744073486328125000e+01
+-6.915261840820312500e+01
+3.855677795410156250e+01
+-5.950350570678710938e+01
+4.478240966796875000e+01
+-4.867547225952148438e+01
+4.714337158203125000e+01
+-3.800830078125000000e+01
+4.940315246582031250e+01
+-2.496297454833984375e+01
+5.117234802246093750e+01
+-1.241538238525390625e+01
+5.190507507324218750e+01
+8.244247436523437500e-01
+-4.150688934326171875e+01
+2.386329650878906250e+01
+-3.570307159423828125e+01
+3.017010498046875000e+01
+-2.790358734130859375e+01
+3.212951660156250000e+01
+-1.941773223876953125e+01
+3.156523132324218750e+01
+-1.138106536865234375e+01
+2.841992187500000000e+01
+5.993263244628906250e+00
+2.895182800292968750e+01
+1.343590545654296875e+01
+3.189880371093750000e+01
+2.203153991699218750e+01
+3.302221679687500000e+01
+2.992478942871093750e+01
+3.099150085449218750e+01
+3.628388977050781250e+01
+2.765748596191406250e+01
+-1.933914184570312500e+00
+1.405374145507812500e+01
+-2.153038024902343750e+00
+5.772636413574218750e+00
+-2.270050048828125000e+00
+-2.121643066406250000e+00
+-2.218330383300781250e+00
+-1.068978118896484375e+01
+-1.187252044677734375e+01
+-1.997912597656250000e+01
+-6.879402160644531250e+00
+-2.143579864501953125e+01
+-1.227821350097656250e+00
+-2.193494415283203125e+01
+4.623237609863281250e+00
+-2.152721405029296875e+01
+9.721397399902343750e+00
+-1.953671264648437500e+01
+-3.648714447021484375e+01
+9.811126708984375000e+00
+-3.130242919921875000e+01
+1.422447967529296875e+01
+-2.212834930419921875e+01
+1.493019866943359375e+01
+-1.500880432128906250e+01
+1.073588562011718750e+01
+-2.095037078857421875e+01
+9.054298400878906250e+00
+-3.050099182128906250e+01
+8.704177856445312500e+00
+1.173237609863281250e+01
+1.054329681396484375e+01
+1.856353759765625000e+01
+1.535009765625000000e+01
+2.893331909179687500e+01
+1.451992797851562500e+01
+3.452944946289062500e+01
+1.065280151367187500e+01
+2.875990295410156250e+01
+8.654792785644531250e+00
+1.942100524902343750e+01
+9.422447204589843750e+00
+-2.204488372802734375e+01
+-3.983994293212890625e+01
+-1.324458312988281250e+01
+-3.467377471923828125e+01
+-6.749649047851562500e+00
+-3.092894744873046875e+01
+-9.183349609375000000e-01
+-3.196458435058593750e+01
+4.220649719238281250e+00
+-3.090406036376953125e+01
+1.089889526367187500e+01
+-3.497008514404296875e+01
+1.874589538574218750e+01
+-4.065438079833984375e+01
+1.124106597900390625e+01
+-4.438417816162109375e+01
+5.181709289550781250e+00
+-4.649170684814453125e+01
+-1.158607482910156250e+00
+-4.680406951904296875e+01
+-7.918922424316406250e+00
+-4.671575164794921875e+01
+-1.452505493164062500e+01
+-4.416526031494140625e+01
+-2.005007171630859375e+01
+-3.997841644287109375e+01
+-1.054919433593750000e+01
+-3.849683380126953125e+01
+-1.051826477050781250e+00
+-3.794863128662109375e+01
+6.412681579589843750e+00
+-3.804645538330078125e+01
+1.627674865722656250e+01
+-4.039697265625000000e+01
+6.373878479003906250e+00
+-4.087213897705078125e+01
+-8.551712036132812500e-01
+-4.157129669189453125e+01
+-1.014953613281250000e+01
+-4.128469085693359375e+01
diff --git a/sadtalker_audio2pose/src/face3d/util/util.py b/sadtalker_audio2pose/src/face3d/util/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c7517ee66c8830a73fa86ab5e5c3513f11d869
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/util.py
@@ -0,0 +1,208 @@
+"""This script contains basic utilities for Deep3DFaceRecon_pytorch
+"""
+from __future__ import print_function
+import numpy as np
+import torch
+from PIL import Image
+import os
+import importlib
+import argparse
+from argparse import Namespace
+import torchvision
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def copyconf(default_opt, **kwargs):
+ conf = Namespace(**vars(default_opt))
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+ return conf
+
+def genvalconf(train_opt, **kwargs):
+ conf = Namespace(**vars(train_opt))
+ attr_dict = train_opt.__dict__
+ for key, value in attr_dict.items():
+ if 'val' in key and key.split('_')[0] in attr_dict:
+ setattr(conf, key.split('_')[0], value)
+
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+
+ return conf
+
+def find_class_in_module(target_cls_name, module):
+ target_cls_name = target_cls_name.replace('_', '').lower()
+ clslib = importlib.import_module(module)
+ cls = None
+ for name, clsobj in clslib.__dict__.items():
+ if name.lower() == target_cls_name:
+ cls = clsobj
+
+ assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
+
+ return cls
+
+
+def tensor2im(input_image, imtype=np.uint8):
+ """"Converts a Tensor array into a numpy image array.
+
+ Parameters:
+ input_image (tensor) -- the input image tensor array, range(0, 1)
+ imtype (type) -- the desired type of the converted numpy array
+ """
+ if not isinstance(input_image, np.ndarray):
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
+ image_tensor = input_image.data
+ else:
+ return input_image
+ image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array
+ if image_numpy.shape[0] == 1: # grayscale to RGB
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling
+ else: # if it is a numpy array, do nothing
+ image_numpy = input_image
+ return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+ """Calculate and print the mean of average absolute(gradients)
+
+ Parameters:
+ net (torch network) -- Torch network
+ name (str) -- the name of the network
+ """
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path, aspect_ratio=1.0):
+ """Save a numpy image to the disk
+
+ Parameters:
+ image_numpy (numpy array) -- input numpy array
+ image_path (str) -- the path of the image
+ """
+
+ image_pil = Image.fromarray(image_numpy)
+ h, w, _ = image_numpy.shape
+
+ if aspect_ratio is None:
+ pass
+ elif aspect_ratio > 1.0:
+ image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+ elif aspect_ratio < 1.0:
+ image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+ image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+ """Print the mean, min, max, median, std, and size of a numpy array
+
+ Parameters:
+ val (bool) -- if print the values of the numpy array
+ shp (bool) -- if print the shape of the numpy array
+ """
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ """create empty directories if they don't exist
+
+ Parameters:
+ paths (str list) -- a list of directory paths
+ """
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ """create a single empty directory if it didn't exist
+
+ Parameters:
+ path (str) -- a single directory path
+ """
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def correct_resize_label(t, size):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i, :1]
+ one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
+ one_np = one_np[:, :, 0]
+ one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
+ resized_t = torch.from_numpy(np.array(one_image)).long()
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+
+def correct_resize(t, size, mode=Image.BICUBIC):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i:i + 1]
+ one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
+ resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+def draw_landmarks(img, landmark, color='r', step=2):
+ """
+ Return:
+ img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255)
+
+
+ Parameters:
+ img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255)
+ landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction
+ color -- str, 'r' or 'b' (red or blue)
+ """
+ if color =='r':
+ c = np.array([255., 0, 0])
+ else:
+ c = np.array([0, 0, 255.])
+
+ _, H, W, _ = img.shape
+ img, landmark = img.copy(), landmark.copy()
+ landmark[..., 1] = H - 1 - landmark[..., 1]
+ landmark = np.round(landmark).astype(np.int32)
+ for i in range(landmark.shape[1]):
+ x, y = landmark[:, i, 0], landmark[:, i, 1]
+ for j in range(-step, step):
+ for k in range(-step, step):
+ u = np.clip(x + j, 0, W - 1)
+ v = np.clip(y + k, 0, H - 1)
+ for m in range(landmark.shape[0]):
+ img[m, v[m], u[m]] = c
+ return img
diff --git a/sadtalker_audio2pose/src/face3d/util/visualizer.py b/sadtalker_audio2pose/src/face3d/util/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4a8b755e054a4a34d003962a723ef189726a7a0
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/util/visualizer.py
@@ -0,0 +1,227 @@
+"""This script defines the visualizer for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import os
+import sys
+import ntpath
+import time
+from . import util, html
+from subprocess import Popen, PIPE
+from torch.utils.tensorboard import SummaryWriter
+
+def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+ """Save images to the disk.
+
+ Parameters:
+ webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
+ visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
+ image_path (str) -- the string is used to create image paths
+ aspect_ratio (float) -- the aspect ratio of saved images
+ width (int) -- the images will be resized to width x width
+
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+ """
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims, txts, links = [], [], []
+
+ for label, im_data in visuals.items():
+ im = util.tensor2im(im_data)
+ image_name = '%s/%s.png' % (label, name)
+ os.makedirs(os.path.join(image_dir, label), exist_ok=True)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ webpage.add_images(ims, txts, links, width=width)
+
+
+class Visualizer():
+ """This class includes several functions that can display/save images and print/save logging information.
+
+ It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+ """
+
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: create a tensorboard writer
+ Step 3: create an HTML object for saveing HTML filters
+ Step 4: create a logging file to store training losses
+ """
+ self.opt = opt # cache the option
+ self.use_html = opt.isTrain and not opt.no_html
+ self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name))
+ self.win_size = opt.display_winsize
+ self.name = opt.name
+ self.saved = False
+ if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+ def reset(self):
+ """Reset the self.saved status"""
+ self.saved = False
+
+
+ def display_current_results(self, visuals, total_iters, epoch, save_result):
+ """Display current results on tensorboad; save current results to an HTML file.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ total_iters (int) -- total iterations
+ epoch (int) - - the current epoch
+ save_result (bool) - - if save the current results to an HTML file
+ """
+ for label, image in visuals.items():
+ self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC')
+
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
+ self.saved = True
+ # save images to the disk
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+ util.save_image(image_numpy, img_path)
+
+ # update website
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims, txts, links = [], [], []
+
+ for label, image_numpy in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = 'epoch%.3d_%s.png' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ webpage.save()
+
+ def plot_current_losses(self, total_iters, losses):
+ # G_loss_collection = {}
+ # D_loss_collection = {}
+ # for name, value in losses.items():
+ # if 'G' in name or 'NCE' in name or 'idt' in name:
+ # G_loss_collection[name] = value
+ # else:
+ # D_loss_collection[name] = value
+ # self.writer.add_scalars('G_collec', G_loss_collection, total_iters)
+ # self.writer.add_scalars('D_collec', D_loss_collection, total_iters)
+ for name, value in losses.items():
+ self.writer.add_scalar(name, value, total_iters)
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message) # save the message
+
+
+class MyVisualizer:
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: create a tensorboard writer
+ Step 3: create an HTML object for saveing HTML filters
+ Step 4: create a logging file to store training losses
+ """
+ self.opt = opt # cache the optio
+ self.name = opt.name
+ self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results')
+
+ if opt.phase != 'test':
+ self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs'))
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+
+ def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None,
+ add_image=True):
+ """Display current results on tensorboad; save current results to an HTML file.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ total_iters (int) -- total iterations
+ epoch (int) - - the current epoch
+ dataset (str) - - 'train' or 'val' or 'test'
+ """
+ # if (not add_image) and (not save_results): return
+
+ for label, image in visuals.items():
+ for i in range(image.shape[0]):
+ image_numpy = util.tensor2im(image[i])
+ if add_image:
+ self.writer.add_image(label + '%s_%02d'%(dataset, i + count),
+ image_numpy, total_iters, dataformats='HWC')
+
+ if save_results:
+ save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters))
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ if name is not None:
+ img_path = os.path.join(save_path, '%s.png' % name)
+ else:
+ img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count))
+ util.save_image(image_numpy, img_path)
+
+
+ def plot_current_losses(self, total_iters, losses, dataset='train'):
+ for name, value in losses.items():
+ self.writer.add_scalar(name + '/%s'%dataset, value, total_iters)
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (
+ dataset, epoch, iters, t_comp, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message) # save the message
diff --git a/sadtalker_audio2pose/src/face3d/visualize.py b/sadtalker_audio2pose/src/face3d/visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb8791ec30fb8f748aefc82cf4385444754825a4
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/visualize.py
@@ -0,0 +1,133 @@
+# check the sync of 3dmm feature and the audio
+import shutil
+import cv2
+import numpy as np
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.facerecon_model import FaceReconModel
+import torch
+import subprocess, platform
+import scipy.io as scio
+from tqdm import tqdm
+
+
+def draw_landmarks(image, landmarks):
+ for i, point in enumerate(landmarks):
+ cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 255, 0), -1)
+ cv2.putText(image, str(i), (int(point[0]), int(point[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+ return image
+
+# draft
+def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, save_lmk_path, crop_info, extended_crop = False):
+
+ coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm']
+ info = scio.loadmat(first_frame_coeff)['trans_params'][0]
+ print(info)
+
+ coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm']
+
+ # print(coeff_pred.shape)
+ # print(coeff_pred[1:, 64:].shape)
+
+ if args.still:
+ coeff_pred[1:, 64:] = np.stack([coeff_pred[0, 64:]]*coeff_pred[1:, 64:].shape[0])
+
+ # assert False
+
+ coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257
+
+ coeff_full[:, 80:144] = coeff_pred[:, 0:64]
+ coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation
+ coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ tmp_video_path = '/tmp/face3dtmp.mp4'
+ facemodel = FaceReconModel(args)
+ im0 = cv2.imread(args.source_image)
+
+ video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224))
+
+ # since we resize the video, we first need to resize the landmark to the cropped size resolution
+ # then, we need to add it back to the original video
+ x_scale, y_scale = (ox2 - ox1)/256 , (oy2 - oy1)/256
+
+ W, H = im0.shape[0], im0.shape[1]
+
+ _, _, s, _, _, orig_left, orig_up, orig_crop_size =(info[0], info[1], info[2], info[3], info[4], info[5], info[6], info[7])
+ orig_left, orig_up, orig_crop_size = [int(x) for x in (orig_left, orig_up, orig_crop_size)]
+
+ landmark_scale = np.array([[x_scale, y_scale]])
+ landmark_shift = np.array([[orig_left, orig_up]])
+ landmark_shift2 = np.array([[ox1, oy1]])
+
+
+ landmarks = []
+
+ for k in tqdm(range(coeff_first.shape[0]), '1st:'):
+ cur_coeff_full = torch.tensor(coeff_first, device=device)
+
+ facemodel.forward(cur_coeff_full, device)
+
+ predicted_landmark = facemodel.pred_lm # TODO.
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
+
+ predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1]
+
+ predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2
+
+ landmarks.append(predicted_landmark)
+
+ print(orig_up, orig_left, orig_crop_size, s)
+
+ for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'):
+ cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device)
+
+ facemodel.forward(cur_coeff_full, device)
+
+ predicted_landmark = facemodel.pred_lm # TODO.
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
+
+ predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1]
+
+ predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2
+
+ landmarks.append(predicted_landmark)
+
+ rendered_img = facemodel.pred_face
+ rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0)
+ out_img = rendered_img[:, :, :3].astype(np.uint8)
+
+ video.write(np.uint8(out_img[:,:,::-1]))
+
+ video.release()
+
+ # visualize landmarks
+ video = cv2.VideoWriter(save_lmk_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (im0.shape[0], im0.shape[1]))
+
+ for k in tqdm(range(len(landmarks)), 'face3d vis:'):
+ # im = draw_landmarks(im0.copy(), landmarks[k])
+ im = draw_landmarks(np.uint8(np.ones_like(im0)*255), landmarks[k])
+ video.write(im)
+ video.release()
+
+ shutil.copyfile(args.source_image, save_lmk_path.replace('.mp4', '.png'))
+
+ np.save(save_lmk_path.replace('.mp4', '.npy'), landmarks)
+
+ command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path)
+ subprocess.call(command, shell=platform.system() != 'Windows')
+
diff --git a/sadtalker_audio2pose/src/face3d/visualize_old.py b/sadtalker_audio2pose/src/face3d/visualize_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4a37b388320344fd96b4778b60679440fe584c3
--- /dev/null
+++ b/sadtalker_audio2pose/src/face3d/visualize_old.py
@@ -0,0 +1,110 @@
+# check the sync of 3dmm feature and the audio
+import shutil
+import cv2
+import numpy as np
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.facerecon_model import FaceReconModel
+import torch
+import subprocess, platform
+import scipy.io as scio
+from tqdm import tqdm
+
+
+def draw_landmarks(image, landmarks):
+ for i, point in enumerate(landmarks):
+ cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 255, 0), -1)
+ cv2.putText(image, str(i), (int(point[0]), int(point[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+ return image
+
+# draft
+def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, save_lmk_path, crop_info, extended_crop = False):
+
+ coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm']
+ info = scio.loadmat(first_frame_coeff)['trans_params'][0]
+ print(info)
+
+ coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm']
+
+ coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257
+
+ coeff_full[:, 80:144] = coeff_pred[:, 0:64]
+ coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation
+ coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ tmp_video_path = '/tmp/face3dtmp.mp4'
+ facemodel = FaceReconModel(args)
+ im0 = cv2.imread(args.source_image)
+
+ video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224))
+
+ # since we resize the video, we first need to resize the landmark to the cropped size resolution
+ # then, we need to add it back to the original video
+ x_scale, y_scale = (ox2 - ox1)/256 , (oy2 - oy1)/256
+
+ W, H = im0.shape[0], im0.shape[1]
+
+ _, _, s, _, _, orig_left, orig_up, orig_crop_size =(info[0], info[1], info[2], info[3], info[4], info[5], info[6], info[7])
+ orig_left, orig_up, orig_crop_size = [int(x) for x in (orig_left, orig_up, orig_crop_size)]
+
+ landmark_scale = np.array([[x_scale, y_scale]])
+ landmark_shift = np.array([[orig_left, orig_up]])
+ landmark_shift2 = np.array([[ox1, oy1]])
+
+ landmarks = []
+
+ print(orig_up, orig_left, orig_crop_size, s)
+
+ for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'):
+ cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device)
+
+ facemodel.forward(cur_coeff_full, device)
+
+ predicted_landmark = facemodel.pred_lm # TODO.
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
+
+ predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1]
+
+ predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2
+
+ landmarks.append(predicted_landmark)
+
+ rendered_img = facemodel.pred_face
+ rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0)
+ out_img = rendered_img[:, :, :3].astype(np.uint8)
+
+ video.write(np.uint8(out_img[:,:,::-1]))
+
+ video.release()
+
+ # visualize landmarks
+ video = cv2.VideoWriter(save_lmk_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (im0.shape[0], im0.shape[1]))
+
+ for k in tqdm(range(len(landmarks)), 'face3d vis:'):
+ # im = draw_landmarks(im0.copy(), landmarks[k])
+ im = draw_landmarks(np.uint8(np.ones_like(im0)*255), landmarks[k])
+ video.write(im)
+ video.release()
+
+ shutil.copyfile(args.source_image, save_lmk_path.replace('.mp4', '.png'))
+
+ np.save(save_lmk_path.replace('.mp4', '.npy'), landmarks)
+
+ command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path)
+ subprocess.call(command, shell=platform.system() != 'Windows')
+
diff --git a/sadtalker_audio2pose/src/facerender/animate.py b/sadtalker_audio2pose/src/facerender/animate.py
new file mode 100644
index 0000000000000000000000000000000000000000..45fcb45edb4169166b851a066c8aaf08063ed1c6
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/animate.py
@@ -0,0 +1,261 @@
+import os
+import cv2
+import yaml
+import numpy as np
+import warnings
+from skimage import img_as_ubyte
+import safetensors
+import safetensors.torch
+warnings.filterwarnings('ignore')
+
+
+import imageio
+import torch
+import torchvision
+
+
+from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
+from src.facerender.modules.mapping import MappingNet
+from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
+from src.facerender.modules.make_animation import make_animation
+
+from pydub import AudioSegment
+from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
+from src.utils.paste_pic import paste_pic
+from src.utils.videoio import save_video_with_watermark
+
+try:
+ import webui # in webui
+ in_webui = True
+except:
+ in_webui = False
+
+class AnimateFromCoeff():
+
+ def __init__(self, sadtalker_path, device):
+
+ with open(sadtalker_path['facerender_yaml']) as f:
+ config = yaml.safe_load(f)
+
+ generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
+ **config['model_params']['common_params'])
+ kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
+ **config['model_params']['common_params'])
+ he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
+ **config['model_params']['common_params'])
+ mapping = MappingNet(**config['model_params']['mapping_params'])
+
+ generator.to(device)
+ kp_extractor.to(device)
+ he_estimator.to(device)
+ mapping.to(device)
+ for param in generator.parameters():
+ param.requires_grad = False
+ for param in kp_extractor.parameters():
+ param.requires_grad = False
+ for param in he_estimator.parameters():
+ param.requires_grad = False
+ for param in mapping.parameters():
+ param.requires_grad = False
+
+ if sadtalker_path is not None:
+ if 'checkpoint' in sadtalker_path: # use safe tensor
+ self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)
+ else:
+ self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
+ else:
+ raise AttributeError("Checkpoint should be specified for video head pose estimator.")
+
+ if sadtalker_path['mappingnet_checkpoint'] is not None:
+ self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
+ else:
+ raise AttributeError("Checkpoint should be specified for video head pose estimator.")
+
+ self.kp_extractor = kp_extractor
+ self.generator = generator
+ self.he_estimator = he_estimator
+ self.mapping = mapping
+
+ self.kp_extractor.eval()
+ self.generator.eval()
+ self.he_estimator.eval()
+ self.mapping.eval()
+
+ self.device = device
+
+ def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
+ kp_detector=None, he_estimator=None,
+ device="cpu"):
+
+ checkpoint = safetensors.torch.load_file(checkpoint_path)
+
+ if generator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'generator' in k:
+ x_generator[k.replace('generator.', '')] = v
+ generator.load_state_dict(x_generator)
+ if kp_detector is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'kp_extractor' in k:
+ x_generator[k.replace('kp_extractor.', '')] = v
+ kp_detector.load_state_dict(x_generator)
+ if he_estimator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'he_estimator' in k:
+ x_generator[k.replace('he_estimator.', '')] = v
+ he_estimator.load_state_dict(x_generator)
+
+ return None
+
+ def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
+ kp_detector=None, he_estimator=None, optimizer_generator=None,
+ optimizer_discriminator=None, optimizer_kp_detector=None,
+ optimizer_he_estimator=None, device="cpu"):
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if generator is not None:
+ generator.load_state_dict(checkpoint['generator'])
+ if kp_detector is not None:
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
+ if he_estimator is not None:
+ he_estimator.load_state_dict(checkpoint['he_estimator'])
+ if discriminator is not None:
+ try:
+ discriminator.load_state_dict(checkpoint['discriminator'])
+ except:
+ print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
+ if optimizer_generator is not None:
+ optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
+ if optimizer_discriminator is not None:
+ try:
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+ except RuntimeError as e:
+ print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
+ if optimizer_kp_detector is not None:
+ optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
+ if optimizer_he_estimator is not None:
+ optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
+
+ return checkpoint['epoch']
+
+ def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
+ optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if mapping is not None:
+ mapping.load_state_dict(checkpoint['mapping'])
+ if discriminator is not None:
+ discriminator.load_state_dict(checkpoint['discriminator'])
+ if optimizer_mapping is not None:
+ optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
+ if optimizer_discriminator is not None:
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+
+ return checkpoint['epoch']
+
+ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ if 'yaw_c_seq' in x:
+ yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor)
+ yaw_c_seq = x['yaw_c_seq'].to(self.device)
+ else:
+ yaw_c_seq = None
+ if 'pitch_c_seq' in x:
+ pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor)
+ pitch_c_seq = x['pitch_c_seq'].to(self.device)
+ else:
+ pitch_c_seq = None
+ if 'roll_c_seq' in x:
+ roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor)
+ roll_c_seq = x['roll_c_seq'].to(self.device)
+ else:
+ roll_c_seq = None
+
+ frame_num = x['frame_num']
+
+ predictions_video = make_animation(source_image, source_semantics, target_semantics,
+ self.generator, self.kp_extractor, self.he_estimator, self.mapping,
+ yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True)
+
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+ predictions_video = predictions_video[:frame_num]
+
+ video = []
+ for idx in range(predictions_video.shape[0]):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+ result = img_as_ubyte(video)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ original_size = crop_info[0]
+ if original_size:
+ result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+
+ video_name = x['video_name'] + '.mp4'
+ path = os.path.join(video_save_dir, 'temp_'+video_name)
+
+ # print(path)
+
+ imageio.mimsave(path, result, fps=float(25))
+
+ av_path = os.path.join(video_save_dir, video_name)
+ return_path = av_path
+
+ audio_path = x['audio_path']
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+ new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
+ start_time = 0
+ # cog will not keep the .mp3 filename
+ sound = AudioSegment.from_file(audio_path)
+ frames = frame_num
+ end_time = start_time + frames*1/25*1000
+ word1=sound.set_frame_rate(16000)
+ word = word1[start_time:end_time]
+ word.export(new_audio_path, format="wav")
+
+ save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name}')
+
+ if 'full' in preprocess.lower():
+ # only add watermark to the full image.
+ video_name_full = x['video_name'] + '_full.mp4'
+ full_video_path = os.path.join(video_save_dir, video_name_full)
+ return_path = full_video_path
+ paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
+ print(f'The generated video is named {video_save_dir}/{video_name_full}')
+ else:
+ full_video_path = av_path
+
+ #### paste back then enhancers
+ if enhancer:
+ video_name_enhancer = x['video_name'] + '_enhanced.mp4'
+ enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
+ av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
+ return_path = av_path_enhancer
+
+ try:
+ enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+ except:
+ enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+
+ save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
+
+
+ # os.remove(enhanced_path)
+
+ # os.remove(path)
+ # os.remove(new_audio_path)
+
+ return return_path
+
diff --git a/sadtalker_audio2pose/src/facerender/modules/dense_motion.py b/sadtalker_audio2pose/src/facerender/modules/dense_motion.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c30417870e79bc005ea47a8f383c3aa406df563
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/modules/dense_motion.py
@@ -0,0 +1,121 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+from src.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
+
+
+class DenseMotionNetwork(nn.Module):
+ """
+ Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
+ """
+
+ def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress,
+ estimate_occlusion_map=False):
+ super(DenseMotionNetwork, self).__init__()
+ # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks)
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks)
+
+ self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)
+
+ self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)
+ self.norm = BatchNorm3d(compress, affine=True)
+
+ if estimate_occlusion_map:
+ # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3)
+ self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
+ else:
+ self.occlusion = None
+
+ self.num_kp = num_kp
+
+
+ def create_sparse_motions(self, feature, kp_driving, kp_source):
+ bs, _, d, h, w = feature.shape
+ identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type())
+ identity_grid = identity_grid.view(1, 1, d, h, w, 3)
+ coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3)
+
+ # if 'jacobian' in kp_driving:
+ if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None:
+ jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
+ jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
+ jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1)
+ coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
+ coordinate_grid = coordinate_grid.squeeze(-1)
+
+
+ driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
+
+ #adding background feature
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
+ sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs num_kp+1 d h w 3
+
+ # sparse_motions = driving_to_source
+
+ return sparse_motions
+
+ def create_deformed_feature(self, feature, sparse_motions):
+ bs, _, d, h, w = feature.shape
+ feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
+ feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
+ sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) !!!!
+ sparse_deformed = F.grid_sample(feature_repeat, sparse_motions)
+ sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
+ return sparse_deformed
+
+ def create_heatmap_representations(self, feature, kp_driving, kp_source):
+ spatial_size = feature.shape[3:]
+ gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)
+ gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)
+ heatmap = gaussian_driving - gaussian_source
+
+ # adding background feature
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type())
+ heatmap = torch.cat([zeros, heatmap], dim=1)
+ heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
+ return heatmap
+
+ def forward(self, feature, kp_driving, kp_source):
+ bs, _, d, h, w = feature.shape
+
+ feature = self.compress(feature)
+ feature = self.norm(feature)
+ feature = F.relu(feature)
+
+ out_dict = dict()
+ sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)
+ deformed_feature = self.create_deformed_feature(feature, sparse_motion)
+
+ heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)
+
+ input_ = torch.cat([heatmap, deformed_feature], dim=2)
+ input_ = input_.view(bs, -1, d, h, w)
+
+ # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w)
+
+ prediction = self.hourglass(input_)
+
+
+ mask = self.mask(prediction)
+ mask = F.softmax(mask, dim=1)
+ out_dict['mask'] = mask
+ mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
+
+ zeros_mask = torch.zeros_like(mask)
+ mask = torch.where(mask < 1e-3, zeros_mask, mask)
+
+ sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
+ deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w)
+ deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
+
+ out_dict['deformation'] = deformation
+
+ if self.occlusion:
+ bs, c, d, h, w = prediction.shape
+ prediction = prediction.view(bs, -1, h, w)
+ occlusion_map = torch.sigmoid(self.occlusion(prediction))
+ out_dict['occlusion_map'] = occlusion_map
+
+ return out_dict
diff --git a/sadtalker_audio2pose/src/facerender/modules/discriminator.py b/sadtalker_audio2pose/src/facerender/modules/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc0a2b460d2175a958d7b230b7e5233d7d7c7f92
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/modules/discriminator.py
@@ -0,0 +1,90 @@
+from torch import nn
+import torch.nn.functional as F
+from facerender.modules.util import kp2gaussian
+import torch
+
+
+class DownBlock2d(nn.Module):
+ """
+ Simple block for processing video (encoder).
+ """
+
+ def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
+ super(DownBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
+
+ if sn:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ if norm:
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
+ else:
+ self.norm = None
+ self.pool = pool
+
+ def forward(self, x):
+ out = x
+ out = self.conv(out)
+ if self.norm:
+ out = self.norm(out)
+ out = F.leaky_relu(out, 0.2)
+ if self.pool:
+ out = F.avg_pool2d(out, (2, 2))
+ return out
+
+
+class Discriminator(nn.Module):
+ """
+ Discriminator similar to Pix2Pix
+ """
+
+ def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
+ sn=False, **kwargs):
+ super(Discriminator, self).__init__()
+
+ down_blocks = []
+ for i in range(num_blocks):
+ down_blocks.append(
+ DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
+
+ self.down_blocks = nn.ModuleList(down_blocks)
+ self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
+ if sn:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ def forward(self, x):
+ feature_maps = []
+ out = x
+
+ for down_block in self.down_blocks:
+ feature_maps.append(down_block(out))
+ out = feature_maps[-1]
+ prediction_map = self.conv(out)
+
+ return feature_maps, prediction_map
+
+
+class MultiScaleDiscriminator(nn.Module):
+ """
+ Multi-scale (scale) discriminator
+ """
+
+ def __init__(self, scales=(), **kwargs):
+ super(MultiScaleDiscriminator, self).__init__()
+ self.scales = scales
+ discs = {}
+ for scale in scales:
+ discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
+ self.discs = nn.ModuleDict(discs)
+
+ def forward(self, x):
+ out_dict = {}
+ for scale, disc in self.discs.items():
+ scale = str(scale).replace('-', '.')
+ key = 'prediction_' + scale
+ feature_maps, prediction_map = disc(x[key])
+ out_dict['feature_maps_' + scale] = feature_maps
+ out_dict['prediction_map_' + scale] = prediction_map
+ return out_dict
diff --git a/sadtalker_audio2pose/src/facerender/modules/generator.py b/sadtalker_audio2pose/src/facerender/modules/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b94dde7a37c5ddf0f74dd0317a5db3507ab0729
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/modules/generator.py
@@ -0,0 +1,255 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from src.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock
+from src.facerender.modules.dense_motion import DenseMotionNetwork
+
+
+class OcclusionAwareGenerator(nn.Module):
+ """
+ Generator follows NVIDIA architecture.
+ """
+
+ def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
+ num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
+ super(OcclusionAwareGenerator, self).__init__()
+
+ if dense_motion_params is not None:
+ self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
+ estimate_occlusion_map=estimate_occlusion_map,
+ **dense_motion_params)
+ else:
+ self.dense_motion_network = None
+
+ self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3))
+
+ down_blocks = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2 ** i))
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
+
+ self.reshape_channel = reshape_channel
+ self.reshape_depth = reshape_depth
+
+ self.resblocks_3d = torch.nn.Sequential()
+ for i in range(num_resblocks):
+ self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
+
+ out_features = block_expansion * (2 ** (num_down_blocks))
+ self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
+ self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
+
+ self.resblocks_2d = torch.nn.Sequential()
+ for i in range(num_resblocks):
+ self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1))
+
+ up_blocks = []
+ for i in range(num_down_blocks):
+ in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i)))
+ out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1)))
+ up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3))
+ self.estimate_occlusion_map = estimate_occlusion_map
+ self.image_channel = image_channel
+
+ def deform_input(self, inp, deformation):
+ _, d_old, h_old, w_old, _ = deformation.shape
+ _, _, d, h, w = inp.shape
+ if d_old != d or h_old != h or w_old != w:
+ deformation = deformation.permute(0, 4, 1, 2, 3)
+ deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
+ deformation = deformation.permute(0, 2, 3, 4, 1)
+ return F.grid_sample(inp, deformation)
+
+ def forward(self, source_image, kp_driving, kp_source):
+ # Encoding (downsampling) part
+ out = self.first(source_image)
+ for i in range(len(self.down_blocks)):
+ out = self.down_blocks[i](out)
+ out = self.second(out)
+ bs, c, h, w = out.shape
+ # print(out.shape)
+ feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w)
+ feature_3d = self.resblocks_3d(feature_3d)
+
+ # Transforming feature representation according to deformation and occlusion
+ output_dict = {}
+ if self.dense_motion_network is not None:
+ dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
+ kp_source=kp_source)
+ output_dict['mask'] = dense_motion['mask']
+
+ if 'occlusion_map' in dense_motion:
+ occlusion_map = dense_motion['occlusion_map']
+ output_dict['occlusion_map'] = occlusion_map
+ else:
+ occlusion_map = None
+ deformation = dense_motion['deformation']
+ out = self.deform_input(feature_3d, deformation)
+
+ bs, c, d, h, w = out.shape
+ out = out.view(bs, c*d, h, w)
+ out = self.third(out)
+ out = self.fourth(out)
+
+ if occlusion_map is not None:
+ if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
+ occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
+ out = out * occlusion_map
+
+ # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image
+
+ # Decoding part
+ out = self.resblocks_2d(out)
+ for i in range(len(self.up_blocks)):
+ out = self.up_blocks[i](out)
+ out = self.final(out)
+ out = F.sigmoid(out)
+
+ output_dict["prediction"] = out
+
+ return output_dict
+
+
+class SPADEDecoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+ ic = 256
+ oc = 64
+ norm_G = 'spadespectralinstance'
+ label_nc = 256
+
+ self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)
+ self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)
+ self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)
+ self.conv_img = nn.Conv2d(oc, 3, 3, padding=1)
+ self.up = nn.Upsample(scale_factor=2)
+
+ def forward(self, feature):
+ seg = feature
+ x = self.fc(feature)
+ x = self.G_middle_0(x, seg)
+ x = self.G_middle_1(x, seg)
+ x = self.G_middle_2(x, seg)
+ x = self.G_middle_3(x, seg)
+ x = self.G_middle_4(x, seg)
+ x = self.G_middle_5(x, seg)
+ x = self.up(x)
+ x = self.up_0(x, seg) # 256, 128, 128
+ x = self.up(x)
+ x = self.up_1(x, seg) # 64, 256, 256
+
+ x = self.conv_img(F.leaky_relu(x, 2e-1))
+ # x = torch.tanh(x)
+ x = F.sigmoid(x)
+
+ return x
+
+
+class OcclusionAwareSPADEGenerator(nn.Module):
+
+ def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
+ num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
+ super(OcclusionAwareSPADEGenerator, self).__init__()
+
+ if dense_motion_params is not None:
+ self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
+ estimate_occlusion_map=estimate_occlusion_map,
+ **dense_motion_params)
+ else:
+ self.dense_motion_network = None
+
+ self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
+
+ down_blocks = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2 ** i))
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
+
+ self.reshape_channel = reshape_channel
+ self.reshape_depth = reshape_depth
+
+ self.resblocks_3d = torch.nn.Sequential()
+ for i in range(num_resblocks):
+ self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
+
+ out_features = block_expansion * (2 ** (num_down_blocks))
+ self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
+ self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
+
+ self.estimate_occlusion_map = estimate_occlusion_map
+ self.image_channel = image_channel
+
+ self.decoder = SPADEDecoder()
+
+ def deform_input(self, inp, deformation):
+ _, d_old, h_old, w_old, _ = deformation.shape
+ _, _, d, h, w = inp.shape
+ if d_old != d or h_old != h or w_old != w:
+ deformation = deformation.permute(0, 4, 1, 2, 3)
+ deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
+ deformation = deformation.permute(0, 2, 3, 4, 1)
+ return F.grid_sample(inp, deformation)
+
+ def forward(self, source_image, kp_driving, kp_source):
+ # Encoding (downsampling) part
+ out = self.first(source_image)
+ for i in range(len(self.down_blocks)):
+ out = self.down_blocks[i](out)
+ out = self.second(out)
+ bs, c, h, w = out.shape
+ # print(out.shape)
+ feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w)
+ feature_3d = self.resblocks_3d(feature_3d)
+
+ # Transforming feature representation according to deformation and occlusion
+ output_dict = {}
+ if self.dense_motion_network is not None:
+ dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
+ kp_source=kp_source)
+ output_dict['mask'] = dense_motion['mask']
+
+ # import pdb; pdb.set_trace()
+
+ if 'occlusion_map' in dense_motion:
+ occlusion_map = dense_motion['occlusion_map']
+ output_dict['occlusion_map'] = occlusion_map
+ else:
+ occlusion_map = None
+ deformation = dense_motion['deformation']
+ out = self.deform_input(feature_3d, deformation)
+
+ bs, c, d, h, w = out.shape
+ out = out.view(bs, c*d, h, w)
+ out = self.third(out)
+ out = self.fourth(out)
+
+ # occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map)
+
+ if occlusion_map is not None:
+ if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
+ occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
+ out = out * occlusion_map
+
+ # Decoding part
+ out = self.decoder(out)
+
+ output_dict["prediction"] = out
+
+ return output_dict
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/facerender/modules/keypoint_detector.py b/sadtalker_audio2pose/src/facerender/modules/keypoint_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..e56800c7b1e94bb3cbf97200cd3f059ce9d29cf3
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/modules/keypoint_detector.py
@@ -0,0 +1,179 @@
+from torch import nn
+import torch
+import torch.nn.functional as F
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+from src.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck
+
+
+class KPDetector(nn.Module):
+ """
+ Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint.
+ """
+
+ def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth,
+ num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False):
+ super(KPDetector, self).__init__()
+
+ self.predictor = KPHourglass(block_expansion, in_features=image_channel,
+ max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks)
+
+ # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3)
+ self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1)
+
+ if estimate_jacobian:
+ self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
+ # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3)
+ self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1)
+ '''
+ initial as:
+ [[1 0 0]
+ [0 1 0]
+ [0 0 1]]
+ '''
+ self.jacobian.weight.data.zero_()
+ self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
+ else:
+ self.jacobian = None
+
+ self.temperature = temperature
+ self.scale_factor = scale_factor
+ if self.scale_factor != 1:
+ self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor)
+
+ def gaussian2kp(self, heatmap):
+ """
+ Extract the mean from a heatmap
+ """
+ shape = heatmap.shape
+ heatmap = heatmap.unsqueeze(-1)
+ grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
+ value = (heatmap * grid).sum(dim=(2, 3, 4))
+ kp = {'value': value}
+
+ return kp
+
+ def forward(self, x):
+ if self.scale_factor != 1:
+ x = self.down(x)
+
+ feature_map = self.predictor(x)
+ prediction = self.kp(feature_map)
+
+ final_shape = prediction.shape
+ heatmap = prediction.view(final_shape[0], final_shape[1], -1)
+ heatmap = F.softmax(heatmap / self.temperature, dim=2)
+ heatmap = heatmap.view(*final_shape)
+
+ out = self.gaussian2kp(heatmap)
+
+ if self.jacobian is not None:
+ jacobian_map = self.jacobian(feature_map)
+ jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2],
+ final_shape[3], final_shape[4])
+ heatmap = heatmap.unsqueeze(2)
+
+ jacobian = heatmap * jacobian_map
+ jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1)
+ jacobian = jacobian.sum(dim=-1)
+ jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3)
+ out['jacobian'] = jacobian
+
+ return out
+
+
+class HEEstimator(nn.Module):
+ """
+ Estimating head pose and expression.
+ """
+
+ def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True):
+ super(HEEstimator, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2)
+ self.norm1 = BatchNorm2d(block_expansion, affine=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1)
+ self.norm2 = BatchNorm2d(256, affine=True)
+
+ self.block1 = nn.Sequential()
+ for i in range(3):
+ self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1))
+
+ self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)
+ self.norm3 = BatchNorm2d(512, affine=True)
+ self.block2 = ResBottleneck(in_features=512, stride=2)
+
+ self.block3 = nn.Sequential()
+ for i in range(3):
+ self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1))
+
+ self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1)
+ self.norm4 = BatchNorm2d(1024, affine=True)
+ self.block4 = ResBottleneck(in_features=1024, stride=2)
+
+ self.block5 = nn.Sequential()
+ for i in range(5):
+ self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1))
+
+ self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1)
+ self.norm5 = BatchNorm2d(2048, affine=True)
+ self.block6 = ResBottleneck(in_features=2048, stride=2)
+
+ self.block7 = nn.Sequential()
+ for i in range(2):
+ self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1))
+
+ self.fc_roll = nn.Linear(2048, num_bins)
+ self.fc_pitch = nn.Linear(2048, num_bins)
+ self.fc_yaw = nn.Linear(2048, num_bins)
+
+ self.fc_t = nn.Linear(2048, 3)
+
+ self.fc_exp = nn.Linear(2048, 3*num_kp)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = F.relu(out)
+ out = self.maxpool(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+
+ out = self.block1(out)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+ out = F.relu(out)
+ out = self.block2(out)
+
+ out = self.block3(out)
+
+ out = self.conv4(out)
+ out = self.norm4(out)
+ out = F.relu(out)
+ out = self.block4(out)
+
+ out = self.block5(out)
+
+ out = self.conv5(out)
+ out = self.norm5(out)
+ out = F.relu(out)
+ out = self.block6(out)
+
+ out = self.block7(out)
+
+ out = F.adaptive_avg_pool2d(out, 1)
+ out = out.view(out.shape[0], -1)
+
+ yaw = self.fc_roll(out)
+ pitch = self.fc_pitch(out)
+ roll = self.fc_yaw(out)
+ t = self.fc_t(out)
+ exp = self.fc_exp(out)
+
+ return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
+
diff --git a/sadtalker_audio2pose/src/facerender/modules/make_animation.py b/sadtalker_audio2pose/src/facerender/modules/make_animation.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c8c53dcc04da8354d05c98c2bc0d88bf067fb2
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/modules/make_animation.py
@@ -0,0 +1,170 @@
+from scipy.spatial import ConvexHull
+import torch
+import torch.nn.functional as F
+import numpy as np
+from tqdm import tqdm
+
+def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
+ use_relative_movement=False, use_relative_jacobian=False):
+ if adapt_movement_scale:
+ source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
+ driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
+ adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+ else:
+ adapt_movement_scale = 1
+
+ kp_new = {k: v for k, v in kp_driving.items()}
+
+ if use_relative_movement:
+ kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
+ kp_value_diff *= adapt_movement_scale
+ kp_new['value'] = kp_value_diff + kp_source['value']
+
+ if use_relative_jacobian:
+ jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
+ kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
+
+ return kp_new
+
+def headpose_pred_to_degree(pred):
+ device = pred.device
+ idx_tensor = [idx for idx in range(66)]
+ idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device)
+ pred = F.softmax(pred)
+ degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
+ return degree
+
+def get_rotation_matrix(yaw, pitch, roll):
+ yaw = yaw / 180 * 3.14
+ pitch = pitch / 180 * 3.14
+ roll = roll / 180 * 3.14
+
+ roll = roll.unsqueeze(1)
+ pitch = pitch.unsqueeze(1)
+ yaw = yaw.unsqueeze(1)
+
+ pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
+ torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),
+ torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)
+ pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
+
+ yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),
+ torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
+ -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)
+ yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
+
+ roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),
+ torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),
+ torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)
+ roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
+
+ rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)
+
+ return rot_mat
+
+def keypoint_transformation(kp_canonical, he, wo_exp=False):
+ kp = kp_canonical['value'] # (bs, k, 3)
+ yaw, pitch, roll= he['yaw'], he['pitch'], he['roll']
+ yaw = headpose_pred_to_degree(yaw)
+ pitch = headpose_pred_to_degree(pitch)
+ roll = headpose_pred_to_degree(roll)
+
+ if 'yaw_in' in he:
+ yaw = he['yaw_in']
+ if 'pitch_in' in he:
+ pitch = he['pitch_in']
+ if 'roll_in' in he:
+ roll = he['roll_in']
+
+ rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3)
+
+ t, exp = he['t'], he['exp']
+ if wo_exp:
+ exp = exp*0
+
+ # keypoint rotation
+ kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
+
+ # keypoint translation
+ t[:, 0] = t[:, 0]*0
+ t[:, 2] = t[:, 2]*0
+ t = t.unsqueeze(1).repeat(1, kp.shape[1], 1)
+ kp_t = kp_rotated + t
+
+ # add expression deviation
+ exp = exp.view(exp.shape[0], -1, 3)
+ kp_transformed = kp_t + exp
+
+ return {'value': kp_transformed}
+
+
+
+def make_animation(source_image, source_semantics, target_semantics,
+ generator, kp_detector, he_estimator, mapping,
+ yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
+ use_exp=True, use_half=False):
+ with torch.no_grad():
+ predictions = []
+
+ kp_canonical = kp_detector(source_image)
+ he_source = mapping(source_semantics)
+ kp_source = keypoint_transformation(kp_canonical, he_source)
+
+ for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
+ # still check the dimension
+ # print(target_semantics.shape, source_semantics.shape)
+ target_semantics_frame = target_semantics[:, frame_idx]
+ he_driving = mapping(target_semantics_frame)
+ if yaw_c_seq is not None:
+ he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
+ if pitch_c_seq is not None:
+ he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
+ if roll_c_seq is not None:
+ he_driving['roll_in'] = roll_c_seq[:, frame_idx]
+
+ kp_driving = keypoint_transformation(kp_canonical, he_driving)
+
+ kp_norm = kp_driving
+ out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
+ '''
+ source_image_new = out['prediction'].squeeze(1)
+ kp_canonical_new = kp_detector(source_image_new)
+ he_source_new = he_estimator(source_image_new)
+ kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
+ kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
+ out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
+ '''
+ predictions.append(out['prediction'])
+ predictions_ts = torch.stack(predictions, dim=1)
+ return predictions_ts
+
+class AnimateModel(torch.nn.Module):
+ """
+ Merge all generator related updates into single model for better multi-gpu usage
+ """
+
+ def __init__(self, generator, kp_extractor, mapping):
+ super(AnimateModel, self).__init__()
+ self.kp_extractor = kp_extractor
+ self.generator = generator
+ self.mapping = mapping
+
+ self.kp_extractor.eval()
+ self.generator.eval()
+ self.mapping.eval()
+
+ def forward(self, x):
+
+ source_image = x['source_image']
+ source_semantics = x['source_semantics']
+ target_semantics = x['target_semantics']
+ yaw_c_seq = x['yaw_c_seq']
+ pitch_c_seq = x['pitch_c_seq']
+ roll_c_seq = x['roll_c_seq']
+
+ predictions_video = make_animation(source_image, source_semantics, target_semantics,
+ self.generator, self.kp_extractor,
+ self.mapping, use_exp = True,
+ yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq)
+
+ return predictions_video
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/facerender/modules/mapping.py b/sadtalker_audio2pose/src/facerender/modules/mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac98dd9e177b949f71f8f47029b66d67ece05b4
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/modules/mapping.py
@@ -0,0 +1,47 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MappingNet(nn.Module):
+ def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins):
+ super( MappingNet, self).__init__()
+
+ self.layer = layer
+ nonlinearity = nn.LeakyReLU(0.1)
+
+ self.first = nn.Sequential(
+ torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
+
+ for i in range(layer):
+ net = nn.Sequential(nonlinearity,
+ torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
+ setattr(self, 'encoder' + str(i), net)
+
+ self.pooling = nn.AdaptiveAvgPool1d(1)
+ self.output_nc = descriptor_nc
+
+ self.fc_roll = nn.Linear(descriptor_nc, num_bins)
+ self.fc_pitch = nn.Linear(descriptor_nc, num_bins)
+ self.fc_yaw = nn.Linear(descriptor_nc, num_bins)
+ self.fc_t = nn.Linear(descriptor_nc, 3)
+ self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp)
+
+ def forward(self, input_3dmm):
+ out = self.first(input_3dmm)
+ for i in range(self.layer):
+ model = getattr(self, 'encoder' + str(i))
+ out = model(out) + out[:,:,3:-3]
+ out = self.pooling(out)
+ out = out.view(out.shape[0], -1)
+ #print('out:', out.shape)
+
+ yaw = self.fc_yaw(out)
+ pitch = self.fc_pitch(out)
+ roll = self.fc_roll(out)
+ t = self.fc_t(out)
+ exp = self.fc_exp(out)
+
+ return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/facerender/modules/util.py b/sadtalker_audio2pose/src/facerender/modules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3bfb1f26427b491f032ca9952db41cdeb793d70
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/modules/util.py
@@ -0,0 +1,564 @@
+from torch import nn
+
+import torch.nn.functional as F
+import torch
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
+
+import torch.nn.utils.spectral_norm as spectral_norm
+
+
+def kp2gaussian(kp, spatial_size, kp_variance):
+ """
+ Transform a keypoint into gaussian like representation
+ """
+ mean = kp['value']
+
+ coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
+ number_of_leading_dimensions = len(mean.shape) - 1
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
+ coordinate_grid = coordinate_grid.view(*shape)
+ repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
+ coordinate_grid = coordinate_grid.repeat(*repeats)
+
+ # Preprocess kp shape
+ shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
+ mean = mean.view(*shape)
+
+ mean_sub = (coordinate_grid - mean)
+
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
+
+ return out
+
+def make_coordinate_grid_2d(spatial_size, type):
+ """
+ Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
+ """
+ h, w = spatial_size
+ x = torch.arange(w).type(type)
+ y = torch.arange(h).type(type)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+
+ return meshed
+
+
+def make_coordinate_grid(spatial_size, type):
+ d, h, w = spatial_size
+ x = torch.arange(w).type(type)
+ y = torch.arange(h).type(type)
+ z = torch.arange(d).type(type)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+ z = (2 * (z / (d - 1)) - 1)
+
+ yy = y.view(1, -1, 1).repeat(d, 1, w)
+ xx = x.view(1, 1, -1).repeat(d, h, 1)
+ zz = z.view(-1, 1, 1).repeat(1, h, w)
+
+ meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
+
+ return meshed
+
+
+class ResBottleneck(nn.Module):
+ def __init__(self, in_features, stride):
+ super(ResBottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1)
+ self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride)
+ self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1)
+ self.norm1 = BatchNorm2d(in_features//4, affine=True)
+ self.norm2 = BatchNorm2d(in_features//4, affine=True)
+ self.norm3 = BatchNorm2d(in_features, affine=True)
+
+ self.stride = stride
+ if self.stride != 1:
+ self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride)
+ self.norm4 = BatchNorm2d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv3(out)
+ out = self.norm3(out)
+ if self.stride != 1:
+ x = self.skip(x)
+ x = self.norm4(x)
+ out += x
+ out = F.relu(out)
+ return out
+
+
+class ResBlock2d(nn.Module):
+ """
+ Res block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock2d, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.norm1 = BatchNorm2d(in_features, affine=True)
+ self.norm2 = BatchNorm2d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.norm1(x)
+ out = F.relu(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out += x
+ return out
+
+
+class ResBlock3d(nn.Module):
+ """
+ Res block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock3d, self).__init__()
+ self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.norm1 = BatchNorm3d(in_features, affine=True)
+ self.norm2 = BatchNorm3d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.norm1(x)
+ out = F.relu(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out += x
+ return out
+
+
+class UpBlock2d(nn.Module):
+ """
+ Upsampling block for use in decoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(UpBlock2d, self).__init__()
+
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features, affine=True)
+
+ def forward(self, x):
+ out = F.interpolate(x, scale_factor=2)
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+class UpBlock3d(nn.Module):
+ """
+ Upsampling block for use in decoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(UpBlock3d, self).__init__()
+
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm3d(out_features, affine=True)
+
+ def forward(self, x):
+ # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear')
+ out = F.interpolate(x, scale_factor=(1, 2, 2))
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+
+class DownBlock2d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features, affine=True)
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = F.relu(out)
+ out = self.pool(out)
+ return out
+
+
+class DownBlock3d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock3d, self).__init__()
+ '''
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups, stride=(1, 2, 2))
+ '''
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm3d(out_features, affine=True)
+ self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = F.relu(out)
+ out = self.pool(out)
+ return out
+
+
+class SameBlock2d(nn.Module):
+ """
+ Simple block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
+ super(SameBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
+ kernel_size=kernel_size, padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features, affine=True)
+ if lrelu:
+ self.ac = nn.LeakyReLU()
+ else:
+ self.ac = nn.ReLU()
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = self.ac(out)
+ return out
+
+
+class Encoder(nn.Module):
+ """
+ Hourglass Encoder
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Encoder, self).__init__()
+
+ down_blocks = []
+ for i in range(num_blocks):
+ down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ kernel_size=3, padding=1))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ def forward(self, x):
+ outs = [x]
+ for down_block in self.down_blocks:
+ outs.append(down_block(outs[-1]))
+ return outs
+
+
+class Decoder(nn.Module):
+ """
+ Hourglass Decoder
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Decoder, self).__init__()
+
+ up_blocks = []
+
+ for i in range(num_blocks)[::-1]:
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
+ out_filters = min(max_features, block_expansion * (2 ** i))
+ up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
+
+ self.up_blocks = nn.ModuleList(up_blocks)
+ # self.out_filters = block_expansion
+ self.out_filters = block_expansion + in_features
+
+ self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
+ self.norm = BatchNorm3d(self.out_filters, affine=True)
+
+ def forward(self, x):
+ out = x.pop()
+ # for up_block in self.up_blocks[:-1]:
+ for up_block in self.up_blocks:
+ out = up_block(out)
+ skip = x.pop()
+ out = torch.cat([out, skip], dim=1)
+ # out = self.up_blocks[-1](out)
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+
+class Hourglass(nn.Module):
+ """
+ Hourglass architecture.
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Hourglass, self).__init__()
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
+ self.out_filters = self.decoder.out_filters
+
+ def forward(self, x):
+ return self.decoder(self.encoder(x))
+
+
+class KPHourglass(nn.Module):
+ """
+ Hourglass architecture.
+ """
+
+ def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256):
+ super(KPHourglass, self).__init__()
+
+ self.down_blocks = nn.Sequential()
+ for i in range(num_blocks):
+ self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ kernel_size=3, padding=1))
+
+ in_filters = min(max_features, block_expansion * (2 ** num_blocks))
+ self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1)
+
+ self.up_blocks = nn.Sequential()
+ for i in range(num_blocks):
+ in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i)))
+ out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1)))
+ self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
+
+ self.reshape_depth = reshape_depth
+ self.out_filters = out_filters
+
+ def forward(self, x):
+ out = self.down_blocks(x)
+ out = self.conv(out)
+ bs, c, h, w = out.shape
+ out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w)
+ out = self.up_blocks(out)
+
+ return out
+
+
+
+class AntiAliasInterpolation2d(nn.Module):
+ """
+ Band-limited downsampling, for better preservation of the input signal.
+ """
+ def __init__(self, channels, scale):
+ super(AntiAliasInterpolation2d, self).__init__()
+ sigma = (1 / scale - 1) / 2
+ kernel_size = 2 * round(sigma * 4) + 1
+ self.ka = kernel_size // 2
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
+
+ kernel_size = [kernel_size, kernel_size]
+ sigma = [sigma, sigma]
+ # The gaussian kernel is the product of the
+ # gaussian function of each dimension.
+ kernel = 1
+ meshgrids = torch.meshgrid(
+ [
+ torch.arange(size, dtype=torch.float32)
+ for size in kernel_size
+ ]
+ )
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+ mean = (size - 1) / 2
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ kernel = kernel / torch.sum(kernel)
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer('weight', kernel)
+ self.groups = channels
+ self.scale = scale
+ inv_scale = 1 / scale
+ self.int_inv_scale = int(inv_scale)
+
+ def forward(self, input):
+ if self.scale == 1.0:
+ return input
+
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
+ out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
+
+ return out
+
+
+class SPADE(nn.Module):
+ def __init__(self, norm_nc, label_nc):
+ super().__init__()
+
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+ nhidden = 128
+
+ self.mlp_shared = nn.Sequential(
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
+ nn.ReLU())
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
+
+ def forward(self, x, segmap):
+ normalized = self.param_free_norm(x)
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
+ actv = self.mlp_shared(segmap)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+ out = normalized * (1 + gamma) + beta
+ return out
+
+
+class SPADEResnetBlock(nn.Module):
+ def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
+ super().__init__()
+ # Attributes
+ self.learned_shortcut = (fin != fout)
+ fmiddle = min(fin, fout)
+ self.use_se = use_se
+ # create conv layers
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
+ if self.learned_shortcut:
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+ # apply spectral norm if specified
+ if 'spectral' in norm_G:
+ self.conv_0 = spectral_norm(self.conv_0)
+ self.conv_1 = spectral_norm(self.conv_1)
+ if self.learned_shortcut:
+ self.conv_s = spectral_norm(self.conv_s)
+ # define normalization layers
+ self.norm_0 = SPADE(fin, label_nc)
+ self.norm_1 = SPADE(fmiddle, label_nc)
+ if self.learned_shortcut:
+ self.norm_s = SPADE(fin, label_nc)
+
+ def forward(self, x, seg1):
+ x_s = self.shortcut(x, seg1)
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, seg1):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, seg1))
+ else:
+ x_s = x
+ return x_s
+
+ def actvn(self, x):
+ return F.leaky_relu(x, 2e-1)
+
+class audio2image(nn.Module):
+ def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params):
+ super().__init__()
+ # Attributes
+ self.generator = generator
+ self.kp_extractor = kp_extractor
+ self.he_estimator_video = he_estimator_video
+ self.he_estimator_audio = he_estimator_audio
+ self.train_params = train_params
+
+ def headpose_pred_to_degree(self, pred):
+ device = pred.device
+ idx_tensor = [idx for idx in range(66)]
+ idx_tensor = torch.FloatTensor(idx_tensor).to(device)
+ pred = F.softmax(pred)
+ degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
+
+ return degree
+
+ def get_rotation_matrix(self, yaw, pitch, roll):
+ yaw = yaw / 180 * 3.14
+ pitch = pitch / 180 * 3.14
+ roll = roll / 180 * 3.14
+
+ roll = roll.unsqueeze(1)
+ pitch = pitch.unsqueeze(1)
+ yaw = yaw.unsqueeze(1)
+
+ roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll),
+ torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll),
+ torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1)
+ roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
+
+ pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch),
+ torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch),
+ -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1)
+ pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
+
+ yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw),
+ torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw),
+ torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1)
+ yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
+
+ rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat)
+
+ return rot_mat
+
+ def keypoint_transformation(self, kp_canonical, he):
+ kp = kp_canonical['value'] # (bs, k, 3)
+ yaw, pitch, roll = he['yaw'], he['pitch'], he['roll']
+ t, exp = he['t'], he['exp']
+
+ yaw = self.headpose_pred_to_degree(yaw)
+ pitch = self.headpose_pred_to_degree(pitch)
+ roll = self.headpose_pred_to_degree(roll)
+
+ rot_mat = self.get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3)
+
+ # keypoint rotation
+ kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
+
+
+
+ # keypoint translation
+ t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1)
+ kp_t = kp_rotated + t
+
+ # add expression deviation
+ exp = exp.view(exp.shape[0], -1, 3)
+ kp_transformed = kp_t + exp
+
+ return {'value': kp_transformed}
+
+ def forward(self, source_image, target_audio):
+ pose_source = self.he_estimator_video(source_image)
+ pose_generated = self.he_estimator_audio(target_audio)
+ kp_canonical = self.kp_extractor(source_image)
+ kp_source = self.keypoint_transformation(kp_canonical, pose_source)
+ kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated)
+ generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated)
+ return generated
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/facerender/pirender/base_function.py b/sadtalker_audio2pose/src/facerender/pirender/base_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..650fb7de1b95fc34e4b7c17b2526c1f450a577a0
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/pirender/base_function.py
@@ -0,0 +1,368 @@
+import sys
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
+
+
+class LayerNorm2d(nn.Module):
+ def __init__(self, n_out, affine=True):
+ super(LayerNorm2d, self).__init__()
+ self.n_out = n_out
+ self.affine = affine
+
+ if self.affine:
+ self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
+ self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))
+
+ def forward(self, x):
+ normalized_shape = x.size()[1:]
+ if self.affine:
+ return F.layer_norm(x, normalized_shape, \
+ self.weight.expand(normalized_shape),
+ self.bias.expand(normalized_shape))
+
+ else:
+ return F.layer_norm(x, normalized_shape)
+
+class ADAINHourglass(nn.Module):
+ def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
+ super(ADAINHourglass, self).__init__()
+ self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
+ self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
+ self.output_nc = self.decoder.output_nc
+
+ def forward(self, x, z):
+ return self.decoder(self.encoder(x, z), z)
+
+
+
+class ADAINEncoder(nn.Module):
+ def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINEncoder, self).__init__()
+ self.layers = layers
+ self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
+ for i in range(layers):
+ in_channels = min(ngf * (2**i), img_f)
+ out_channels = min(ngf *(2**(i+1)), img_f)
+ model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
+ setattr(self, 'encoder' + str(i), model)
+ self.output_nc = out_channels
+
+ def forward(self, x, z):
+ out = self.input_layer(x)
+ out_list = [out]
+ for i in range(self.layers):
+ model = getattr(self, 'encoder' + str(i))
+ out = model(out, z)
+ out_list.append(out)
+ return out_list
+
+class ADAINDecoder(nn.Module):
+ """docstring for ADAINDecoder"""
+ def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
+ nonlinearity=nn.LeakyReLU(), use_spect=False):
+
+ super(ADAINDecoder, self).__init__()
+ self.encoder_layers = encoder_layers
+ self.decoder_layers = decoder_layers
+ self.skip_connect = skip_connect
+ use_transpose = True
+
+ for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
+ in_channels = min(ngf * (2**(i+1)), img_f)
+ in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
+ out_channels = min(ngf * (2**i), img_f)
+ model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
+ setattr(self, 'decoder' + str(i), model)
+
+ self.output_nc = out_channels*2 if self.skip_connect else out_channels
+
+ def forward(self, x, z):
+ out = x.pop() if self.skip_connect else x
+ for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
+ model = getattr(self, 'decoder' + str(i))
+ out = model(out, z)
+ out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
+ return out
+
+class ADAINEncoderBlock(nn.Module):
+ def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINEncoderBlock, self).__init__()
+ kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
+ kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
+ self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)
+
+
+ self.norm_0 = ADAIN(input_nc, feature_nc)
+ self.norm_1 = ADAIN(output_nc, feature_nc)
+ self.actvn = nonlinearity
+
+ def forward(self, x, z):
+ x = self.conv_0(self.actvn(self.norm_0(x, z)))
+ x = self.conv_1(self.actvn(self.norm_1(x, z)))
+ return x
+
+class ADAINDecoderBlock(nn.Module):
+ def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINDecoderBlock, self).__init__()
+ # Attributes
+ self.actvn = nonlinearity
+ hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc
+
+ kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1}
+ if use_transpose:
+ kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1}
+ else:
+ kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1}
+
+ # create conv layers
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
+ if use_transpose:
+ self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
+ self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
+ else:
+ self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
+ nn.Upsample(scale_factor=2))
+ self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
+ nn.Upsample(scale_factor=2))
+ # define normalization layers
+ self.norm_0 = ADAIN(input_nc, feature_nc)
+ self.norm_1 = ADAIN(hidden_nc, feature_nc)
+ self.norm_s = ADAIN(input_nc, feature_nc)
+
+ def forward(self, x, z):
+ x_s = self.shortcut(x, z)
+ dx = self.conv_0(self.actvn(self.norm_0(x, z)))
+ dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, z):
+ x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
+ return x_s
+
+
+def spectral_norm(module, use_spect=True):
+ """use spectral normal layer to stable the training process"""
+ if use_spect:
+ return SpectralNorm(module)
+ else:
+ return module
+
+
+class ADAIN(nn.Module):
+ def __init__(self, norm_nc, feature_nc):
+ super().__init__()
+
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+
+ nhidden = 128
+ use_bias=True
+
+ self.mlp_shared = nn.Sequential(
+ nn.Linear(feature_nc, nhidden, bias=use_bias),
+ nn.ReLU()
+ )
+ self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
+ self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)
+
+ def forward(self, x, feature):
+
+ # Part 1. generate parameter-free normalized activations
+ normalized = self.param_free_norm(x)
+
+ # Part 2. produce scaling and bias conditioned on feature
+ feature = feature.view(feature.size(0), -1)
+ actv = self.mlp_shared(feature)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+
+ # apply scale and bias
+ gamma = gamma.view(*gamma.size()[:2], 1,1)
+ beta = beta.view(*beta.size()[:2], 1,1)
+ out = normalized * (1 + gamma) + beta
+ return out
+
+
+class FineEncoder(nn.Module):
+ """docstring for Encoder"""
+ def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineEncoder, self).__init__()
+ self.layers = layers
+ self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
+ for i in range(layers):
+ in_channels = min(ngf*(2**i), img_f)
+ out_channels = min(ngf*(2**(i+1)), img_f)
+ model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+ setattr(self, 'down' + str(i), model)
+ self.output_nc = out_channels
+
+ def forward(self, x):
+ x = self.first(x)
+ out=[x]
+ for i in range(self.layers):
+ model = getattr(self, 'down'+str(i))
+ x = model(x)
+ out.append(x)
+ return out
+
+class FineDecoder(nn.Module):
+ """docstring for FineDecoder"""
+ def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineDecoder, self).__init__()
+ self.layers = layers
+ for i in range(layers)[::-1]:
+ in_channels = min(ngf*(2**(i+1)), img_f)
+ out_channels = min(ngf*(2**i), img_f)
+ up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+ res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
+ jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
+
+ setattr(self, 'up' + str(i), up)
+ setattr(self, 'res' + str(i), res)
+ setattr(self, 'jump' + str(i), jump)
+
+ self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')
+
+ self.output_nc = out_channels
+
+ def forward(self, x, z):
+ out = x.pop()
+ for i in range(self.layers)[::-1]:
+ res_model = getattr(self, 'res' + str(i))
+ up_model = getattr(self, 'up' + str(i))
+ jump_model = getattr(self, 'jump' + str(i))
+ out = res_model(out, z)
+ out = up_model(out)
+ out = jump_model(x.pop()) + out
+ out_image = self.final(out)
+ return out_image
+
+class FirstBlock2d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FirstBlock2d, self).__init__()
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+
+ if type(norm_layer) == type(None):
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class DownBlock2d(nn.Module):
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(DownBlock2d, self).__init__()
+
+
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+ pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+ if type(norm_layer) == type(None):
+ self.model = nn.Sequential(conv, nonlinearity, pool)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class UpBlock2d(nn.Module):
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(UpBlock2d, self).__init__()
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+ if type(norm_layer) == type(None):
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+ def forward(self, x):
+ out = self.model(F.interpolate(x, scale_factor=2))
+ return out
+
+class FineADAINResBlocks(nn.Module):
+ def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineADAINResBlocks, self).__init__()
+ self.num_block = num_block
+ for i in range(num_block):
+ model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
+ setattr(self, 'res'+str(i), model)
+
+ def forward(self, x, z):
+ for i in range(self.num_block):
+ model = getattr(self, 'res'+str(i))
+ x = model(x, z)
+ return x
+
+class Jump(nn.Module):
+ def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(Jump, self).__init__()
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+
+ if type(norm_layer) == type(None):
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class FineADAINResBlock2d(nn.Module):
+ """
+ Define an Residual block for different types
+ """
+ def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineADAINResBlock2d, self).__init__()
+
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+ self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+ self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+ self.norm1 = ADAIN(input_nc, feature_nc)
+ self.norm2 = ADAIN(input_nc, feature_nc)
+
+ self.actvn = nonlinearity
+
+
+ def forward(self, x, z):
+ dx = self.actvn(self.norm1(self.conv1(x), z))
+ dx = self.norm2(self.conv2(x), z)
+ out = dx + x
+ return out
+
+class FinalBlock2d(nn.Module):
+ """
+ Define the output layer
+ """
+ def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
+ super(FinalBlock2d, self).__init__()
+
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+
+ if tanh_or_sigmoid == 'sigmoid':
+ out_nonlinearity = nn.Sigmoid()
+ else:
+ out_nonlinearity = nn.Tanh()
+
+ self.model = nn.Sequential(conv, out_nonlinearity)
+ def forward(self, x):
+ out = self.model(x)
+ return out
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/facerender/pirender/config.py b/sadtalker_audio2pose/src/facerender/pirender/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..29dc2d1b9008dbf2dc3c0a307212471621bae8da
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/pirender/config.py
@@ -0,0 +1,211 @@
+import collections
+import functools
+import os
+import re
+
+import yaml
+
+class AttrDict(dict):
+ """Dict as attribute trick."""
+
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+ for key, value in self.__dict__.items():
+ if isinstance(value, dict):
+ self.__dict__[key] = AttrDict(value)
+ elif isinstance(value, (list, tuple)):
+ if isinstance(value[0], dict):
+ self.__dict__[key] = [AttrDict(item) for item in value]
+ else:
+ self.__dict__[key] = value
+
+ def yaml(self):
+ """Convert object to yaml dict and return."""
+ yaml_dict = {}
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttrDict):
+ yaml_dict[key] = value.yaml()
+ elif isinstance(value, list):
+ if isinstance(value[0], AttrDict):
+ new_l = []
+ for item in value:
+ new_l.append(item.yaml())
+ yaml_dict[key] = new_l
+ else:
+ yaml_dict[key] = value
+ else:
+ yaml_dict[key] = value
+ return yaml_dict
+
+ def __repr__(self):
+ """Print all variables."""
+ ret_str = []
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttrDict):
+ ret_str.append('{}:'.format(key))
+ child_ret_str = value.__repr__().split('\n')
+ for item in child_ret_str:
+ ret_str.append(' ' + item)
+ elif isinstance(value, list):
+ if isinstance(value[0], AttrDict):
+ ret_str.append('{}:'.format(key))
+ for item in value:
+ # Treat as AttrDict above.
+ child_ret_str = item.__repr__().split('\n')
+ for item in child_ret_str:
+ ret_str.append(' ' + item)
+ else:
+ ret_str.append('{}: {}'.format(key, value))
+ else:
+ ret_str.append('{}: {}'.format(key, value))
+ return '\n'.join(ret_str)
+
+
+class Config(AttrDict):
+ r"""Configuration class. This should include every human specifiable
+ hyperparameter values for your training."""
+
+ def __init__(self, filename=None, args=None, verbose=False, is_train=True):
+ super(Config, self).__init__()
+ # Set default parameters.
+ # Logging.
+
+ large_number = 1000000000
+ self.snapshot_save_iter = large_number
+ self.snapshot_save_epoch = large_number
+ self.snapshot_save_start_iter = 0
+ self.snapshot_save_start_epoch = 0
+ self.image_save_iter = large_number
+ self.eval_epoch = large_number
+ self.start_eval_epoch = large_number
+ self.eval_epoch = large_number
+ self.max_epoch = large_number
+ self.max_iter = large_number
+ self.logging_iter = 100
+ self.image_to_tensorboard=False
+ self.which_iter = 0 # args.which_iter
+ self.resume = False
+
+ self.checkpoints_dir = '/Users/shadowcun/Downloads/'
+ self.name = 'face'
+ self.phase = 'train' if is_train else 'test'
+
+ # Networks.
+ self.gen = AttrDict(type='generators.dummy')
+ self.dis = AttrDict(type='discriminators.dummy')
+
+ # Optimizers.
+ self.gen_optimizer = AttrDict(type='adam',
+ lr=0.0001,
+ adam_beta1=0.0,
+ adam_beta2=0.999,
+ eps=1e-8,
+ lr_policy=AttrDict(iteration_mode=False,
+ type='step',
+ step_size=large_number,
+ gamma=1))
+ self.dis_optimizer = AttrDict(type='adam',
+ lr=0.0001,
+ adam_beta1=0.0,
+ adam_beta2=0.999,
+ eps=1e-8,
+ lr_policy=AttrDict(iteration_mode=False,
+ type='step',
+ step_size=large_number,
+ gamma=1))
+ # Data.
+ self.data = AttrDict(name='dummy',
+ type='datasets.images',
+ num_workers=0)
+ self.test_data = AttrDict(name='dummy',
+ type='datasets.images',
+ num_workers=0,
+ test=AttrDict(is_lmdb=False,
+ roots='',
+ batch_size=1))
+ self.trainer = AttrDict(
+ model_average=False,
+ model_average_beta=0.9999,
+ model_average_start_iteration=1000,
+ model_average_batch_norm_estimation_iteration=30,
+ model_average_remove_sn=True,
+ image_to_tensorboard=False,
+ hparam_to_tensorboard=False,
+ distributed_data_parallel='pytorch',
+ delay_allreduce=True,
+ gan_relativistic=False,
+ gen_step=1,
+ dis_step=1)
+
+ # # Cudnn.
+ self.cudnn = AttrDict(deterministic=False,
+ benchmark=True)
+
+ # Others.
+ self.pretrained_weight = ''
+ self.inference_args = AttrDict()
+
+
+ # Update with given configurations.
+ assert os.path.exists(filename), 'File {} not exist.'.format(filename)
+ loader = yaml.SafeLoader
+ loader.add_implicit_resolver(
+ u'tag:yaml.org,2002:float',
+ re.compile(u'''^(?:
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+ |[-+]?\\.(?:inf|Inf|INF)
+ |\\.(?:nan|NaN|NAN))$''', re.X),
+ list(u'-+0123456789.'))
+ try:
+ with open(filename, 'r') as f:
+ cfg_dict = yaml.load(f, Loader=loader)
+ except EnvironmentError:
+ print('Please check the file with name of "%s"', filename)
+ recursive_update(self, cfg_dict)
+
+ # Put common opts in both gen and dis.
+ if 'common' in cfg_dict:
+ self.common = AttrDict(**cfg_dict['common'])
+ self.gen.common = self.common
+ self.dis.common = self.common
+
+
+ if verbose:
+ print(' config '.center(80, '-'))
+ print(self.__repr__())
+ print(''.center(80, '-'))
+
+
+def rsetattr(obj, attr, val):
+ """Recursively find object and set value"""
+ pre, _, post = attr.rpartition('.')
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+def rgetattr(obj, attr, *args):
+ """Recursively find object and return value"""
+
+ def _getattr(obj, attr):
+ r"""Get attribute."""
+ return getattr(obj, attr, *args)
+
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
+
+
+def recursive_update(d, u):
+ """Recursively update AttrDict d with AttrDict u"""
+ for key, value in u.items():
+ if isinstance(value, collections.abc.Mapping):
+ d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
+ elif isinstance(value, (list, tuple)):
+ if isinstance(value[0], dict):
+ d.__dict__[key] = [AttrDict(item) for item in value]
+ else:
+ d.__dict__[key] = value
+ else:
+ d.__dict__[key] = value
+ return d
diff --git a/sadtalker_audio2pose/src/facerender/pirender/face_model.py b/sadtalker_audio2pose/src/facerender/pirender/face_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f83e2fc5d8c66cf9bd2e2c5549773e11e0f8a44
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/pirender/face_model.py
@@ -0,0 +1,178 @@
+import functools
+import torch
+import torch.nn as nn
+from .base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
+
+def convert_flow_to_deformation(flow):
+ r"""convert flow fields to deformations.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ deformation (tensor): The deformation used for warpping
+ """
+ b,c,h,w = flow.shape
+ flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1)
+ grid = make_coordinate_grid(flow)
+ deformation = grid + flow_norm.permute(0,2,3,1)
+ return deformation
+
+def make_coordinate_grid(flow):
+ r"""obtain coordinate grid with the same size as the flow filed.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ grid (tensor): The grid with the same size as the input flow
+ """
+ b,c,h,w = flow.shape
+
+ x = torch.arange(w).to(flow)
+ y = torch.arange(h).to(flow)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+ meshed = meshed.expand(b, -1, -1, -1)
+ return meshed
+
+
+def warp_image(source_image, deformation):
+ r"""warp the input image according to the deformation
+
+ Args:
+ source_image (tensor): source images to be warpped
+ deformation (tensor): deformations used to warp the images; value in range (-1, 1)
+ Returns:
+ output (tensor): the warpped images
+ """
+ _, h_old, w_old, _ = deformation.shape
+ _, _, h, w = source_image.shape
+ if h_old != h or w_old != w:
+ deformation = deformation.permute(0, 3, 1, 2)
+ deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear')
+ deformation = deformation.permute(0, 2, 3, 1)
+ return torch.nn.functional.grid_sample(source_image, deformation)
+
+
+class FaceGenerator(nn.Module):
+ def __init__(
+ self,
+ mapping_net,
+ warpping_net,
+ editing_net,
+ common
+ ):
+ super(FaceGenerator, self).__init__()
+ self.mapping_net = MappingNet(**mapping_net)
+ self.warpping_net = WarpingNet(**warpping_net, **common)
+ self.editing_net = EditingNet(**editing_net, **common)
+
+ def forward(
+ self,
+ input_image,
+ driving_source,
+ stage=None
+ ):
+ if stage == 'warp':
+ descriptor = self.mapping_net(driving_source)
+ output = self.warpping_net(input_image, descriptor)
+ else:
+ descriptor = self.mapping_net(driving_source)
+ output = self.warpping_net(input_image, descriptor)
+ output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor)
+ return output
+
+class MappingNet(nn.Module):
+ def __init__(self, coeff_nc, descriptor_nc, layer):
+ super( MappingNet, self).__init__()
+
+ self.layer = layer
+ nonlinearity = nn.LeakyReLU(0.1)
+
+ self.first = nn.Sequential(
+ torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
+
+ for i in range(layer):
+ net = nn.Sequential(nonlinearity,
+ torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
+ setattr(self, 'encoder' + str(i), net)
+
+ self.pooling = nn.AdaptiveAvgPool1d(1)
+ self.output_nc = descriptor_nc
+
+ def forward(self, input_3dmm):
+ out = self.first(input_3dmm)
+ for i in range(self.layer):
+ model = getattr(self, 'encoder' + str(i))
+ out = model(out) + out[:,:,3:-3]
+ out = self.pooling(out)
+ return out
+
+class WarpingNet(nn.Module):
+ def __init__(
+ self,
+ image_nc,
+ descriptor_nc,
+ base_nc,
+ max_nc,
+ encoder_layer,
+ decoder_layer,
+ use_spect
+ ):
+ super( WarpingNet, self).__init__()
+
+ nonlinearity = nn.LeakyReLU(0.1)
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
+ kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect}
+
+ self.descriptor_nc = descriptor_nc
+ self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
+ max_nc, encoder_layer, decoder_layer, **kwargs)
+
+ self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc),
+ nonlinearity,
+ nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))
+
+ self.pool = nn.AdaptiveAvgPool2d(1)
+
+ def forward(self, input_image, descriptor):
+ final_output={}
+ output = self.hourglass(input_image, descriptor)
+ final_output['flow_field'] = self.flow_out(output)
+
+ deformation = convert_flow_to_deformation(final_output['flow_field'])
+ final_output['warp_image'] = warp_image(input_image, deformation)
+ return final_output
+
+
+class EditingNet(nn.Module):
+ def __init__(
+ self,
+ image_nc,
+ descriptor_nc,
+ layer,
+ base_nc,
+ max_nc,
+ num_res_blocks,
+ use_spect):
+ super(EditingNet, self).__init__()
+
+ nonlinearity = nn.LeakyReLU(0.1)
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
+ kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
+ self.descriptor_nc = descriptor_nc
+
+ # encoder part
+ self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs)
+ self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
+
+ def forward(self, input_image, warp_image, descriptor):
+ x = torch.cat([input_image, warp_image], 1)
+ x = self.encoder(x)
+ gen_image = self.decoder(x, descriptor)
+ return gen_image
diff --git a/sadtalker_audio2pose/src/facerender/pirender_animate.py b/sadtalker_audio2pose/src/facerender/pirender_animate.py
new file mode 100644
index 0000000000000000000000000000000000000000..07d4ccf0918f09dcfa422a85694bd17bf42d11ff
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/pirender_animate.py
@@ -0,0 +1,266 @@
+import os
+import uuid
+import cv2
+from tqdm import tqdm
+import yaml
+import numpy as np
+import warnings
+from skimage import img_as_ubyte
+import safetensors
+import safetensors.torch
+warnings.filterwarnings('ignore')
+
+
+import imageio
+import torch
+import torchvision
+
+from src.facerender.pirender.config import Config
+from src.facerender.pirender.face_model import FaceGenerator
+
+from pydub import AudioSegment
+from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
+from src.utils.paste_pic import paste_pic
+from src.utils.videoio import save_video_with_watermark
+from src.utils.flow_util import vis_flow
+from scipy.io import savemat,loadmat
+
+try:
+ import webui # in webui
+ in_webui = True
+except:
+ in_webui = False
+
+expession = loadmat('expression.mat')
+control_dict = {}
+for item in ['expression_center', 'expression_mouth', 'expression_eyebrow', 'expression_eyes']:
+ control_dict[item] = torch.tensor(expession[item])[0]
+
+class AnimateFromCoeff_PIRender():
+
+ def __init__(self, sadtalker_path, device):
+
+ opt = Config(sadtalker_path['pirender_yaml_path'], None, is_train=False)
+ opt.device = device
+ self.net_G_ema = FaceGenerator(**opt.gen.param).to(opt.device)
+ checkpoint_path = sadtalker_path['pirender_checkpoint']
+ checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+ self.net_G_ema.load_state_dict(checkpoint['net_G_ema'], strict=False)
+ print('load [net_G] and [net_G_ema] from {}'.format(checkpoint_path))
+ self.net_G = self.net_G_ema.eval()
+ self.device = device
+
+
+ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+
+ num = 16
+
+ # import pdb; pdb.set_trace()
+ # target_semantics_
+ current = target_semantics[0, 0, :64, 0]
+ for control_k in range(len(control_dict.keys())):
+ listx = list(control_dict.keys())
+ control_v = control_dict[listx[control_k]]
+ for i in range(num):
+ expression = (control_v-current)*i/(num-1)+current
+ target_semantics[:, (control_k*num + i):(control_k*num + i+1), :64, :] = expression[None, None, :, None]
+
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ frame_num = x['frame_num']
+
+ with torch.no_grad():
+ predictions_video = []
+ for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'):
+ predictions_video.append(self.net_G(source_image, target_semantics[:, i])['fake_image'])
+
+ predictions_video = torch.stack(predictions_video, dim=1)
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+
+ video = []
+ for idx in range(len(predictions_video)):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+ result = img_as_ubyte(video)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ original_size = crop_info[0]
+ if original_size:
+ result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+
+ video_name = x['video_name'] + '.mp4'
+ path = os.path.join(video_save_dir, 'temp_'+video_name)
+
+ imageio.mimsave(path, result, fps=float(25))
+
+ av_path = os.path.join(video_save_dir, video_name)
+ return_path = av_path
+
+ audio_path = x['audio_path']
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+ new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
+ start_time = 0
+ # cog will not keep the .mp3 filename
+ sound = AudioSegment.from_file(audio_path)
+ frames = frame_num
+ end_time = start_time + frames*1/25*1000
+ word1=sound.set_frame_rate(16000)
+ word = word1[start_time:end_time]
+ word.export(new_audio_path, format="wav")
+
+ save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name}')
+
+ if 'full' in preprocess.lower():
+ # only add watermark to the full image.
+ video_name_full = x['video_name'] + '_full.mp4'
+ full_video_path = os.path.join(video_save_dir, video_name_full)
+ return_path = full_video_path
+ paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
+ print(f'The generated video is named {video_save_dir}/{video_name_full}')
+ else:
+ full_video_path = av_path
+
+ #### paste back then enhancers
+ if enhancer:
+ video_name_enhancer = x['video_name'] + '_enhanced.mp4'
+ enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
+ av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
+ return_path = av_path_enhancer
+
+ try:
+ enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+ except:
+ enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+
+ save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
+ os.remove(enhanced_path)
+
+ os.remove(path)
+ os.remove(new_audio_path)
+
+ return return_path
+
+ def generate_flow(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+
+
+ num = 16
+
+ current = target_semantics[0, 0, :64, 0]
+ for control_k in range(len(control_dict.keys())):
+ listx = list(control_dict.keys())
+ control_v = control_dict[listx[control_k]]
+ for i in range(num):
+ expression = (control_v-current)*i/(num-1)+current
+ target_semantics[:, (control_k*num + i):(control_k*num + i+1), :64, :] = expression[None, None, :, None]
+
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ frame_num = x['frame_num']
+
+ with torch.no_grad():
+ predictions_video = []
+ for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'):
+ predictions_video.append(self.net_G(source_image, target_semantics[:, i])['flow_field'])
+
+ predictions_video = torch.stack(predictions_video, dim=1)
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+
+ video = []
+ for idx in range(len(predictions_video)):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+
+ results = np.stack(video, axis=0)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ # original_size = crop_info[0]
+ # if original_size:
+ # result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+ # results = np.stack(result, axis=0)
+
+ x_name = os.path.basename(pic_path)
+ save_name = os.path.join(video_save_dir, x_name + '.flo')
+ save_name_flow_vis = os.path.join(video_save_dir, x_name + '.mp4')
+
+ flow_full = paste_flow(results, pic_path, save_name, crop_info, extended_crop= True if 'ext' in preprocess.lower() else False)
+
+ flow_viz = []
+ for kk in range(flow_full.shape[0]):
+ tmp = vis_flow(flow_full[kk])
+ flow_viz.append(tmp)
+ flow_viz = np.stack(flow_viz)
+
+ torchvision.io.write_video(save_name_flow_vis, flow_viz, fps=20, video_codec='h264', options={'crf': '10'})
+
+ return save_name_flow_vis
+
+
+def paste_flow(flows, pic_path, save_name, crop_info, extended_crop=False):
+
+ if not os.path.isfile(pic_path):
+ raise ValueError('pic_path must be a valid path to video/image file')
+ elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_img = cv2.imread(pic_path)
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(pic_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ break
+ full_img = frame
+ frame_h = full_img.shape[0]
+ frame_w = full_img.shape[1]
+
+ # full images, we only use it as reference for zero init image.
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ # out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h))
+ # template = np.zeros((frame_h, frame_w, 2)) # full flows
+ out_tmp = []
+ for crop_frame in tqdm(flows, 'seamlessClone:'):
+ p = cv2.resize(crop_frame, (ox2-ox1, oy2 - oy1), interpolation=cv2.INTER_LANCZOS4)
+
+ gen_img = np.zeros((frame_h, frame_w, 2))
+ # gen_img = cv2.seamlessClone(p, template, mask, location, cv2.NORMAL_CLONE)
+ gen_img[oy1:oy2,ox1:ox2] = p
+ out_tmp.append(gen_img)
+
+ np.save(save_name, np.stack(out_tmp))
+ return np.stack(out_tmp)
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/facerender/pirender_animate_control.py b/sadtalker_audio2pose/src/facerender/pirender_animate_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c357f35577816c8d6731627afd505c6dd8efdca
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/pirender_animate_control.py
@@ -0,0 +1,251 @@
+import os
+import uuid
+import cv2
+from tqdm import tqdm
+import yaml
+import numpy as np
+import warnings
+from skimage import img_as_ubyte
+import safetensors
+import safetensors.torch
+warnings.filterwarnings('ignore')
+
+
+import imageio
+import torch
+import torchvision
+
+from src.facerender.pirender.config import Config
+from src.facerender.pirender.face_model import FaceGenerator
+
+from pydub import AudioSegment
+from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
+from src.utils.paste_pic import paste_pic
+from src.utils.videoio import save_video_with_watermark
+from src.utils.flow_util import vis_flow
+
+from scipy.io import savemat,loadmat
+
+try:
+ import webui # in webui
+ in_webui = True
+except:
+ in_webui = False
+
+expession = loadmat('expression.mat')
+control_dict = {}
+for item in ['expression_center', 'expression_mouth', 'expression_eyebrow', 'expression_eyes']:
+ control_dict[item] = torch.tensor(expession[item])[0]
+
+class AnimateFromCoeff_PIRender():
+
+ def __init__(self, sadtalker_path, device):
+
+ opt = Config(sadtalker_path['pirender_yaml_path'], None, is_train=False)
+ opt.device = device
+ self.net_G_ema = FaceGenerator(**opt.gen.param).to(opt.device)
+ checkpoint_path = sadtalker_path['pirender_checkpoint']
+ checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+ self.net_G_ema.load_state_dict(checkpoint['net_G_ema'], strict=False)
+ print('load [net_G] and [net_G_ema] from {}'.format(checkpoint_path))
+ self.net_G = self.net_G_ema.eval()
+ self.device = device
+
+
+ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+ num = 10
+
+ # target_semantics_
+ current = target_semantics['target_semantics_list'][0, :64, 0]
+ for control in control_dict:
+ for i in range(num):
+ expression = (control_dict[control]-current)*i/(num-1)+current
+ target_semantics['target_semantics_list'][:, :64, :] = expression[None, :, None]
+
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ frame_num = x['frame_num']
+
+ with torch.no_grad():
+ predictions_video = []
+ for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'):
+ predictions_video.append(self.net_G(source_image, target_semantics[:, i])['fake_image'])
+
+ predictions_video = torch.stack(predictions_video, dim=1)
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+
+ video = []
+ for idx in range(len(predictions_video)):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+ result = img_as_ubyte(video)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ original_size = crop_info[0]
+ if original_size:
+ result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+
+ video_name = x['video_name'] + '.mp4'
+ path = os.path.join(video_save_dir, 'temp_'+video_name)
+
+ imageio.mimsave(path, result, fps=float(25))
+
+ av_path = os.path.join(video_save_dir, video_name)
+ return_path = av_path
+
+ audio_path = x['audio_path']
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+ new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
+ start_time = 0
+ # cog will not keep the .mp3 filename
+ sound = AudioSegment.from_file(audio_path)
+ frames = frame_num
+ end_time = start_time + frames*1/25*1000
+ word1=sound.set_frame_rate(16000)
+ word = word1[start_time:end_time]
+ word.export(new_audio_path, format="wav")
+
+ save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name}')
+
+ if 'full' in preprocess.lower():
+ # only add watermark to the full image.
+ video_name_full = x['video_name'] + '_full.mp4'
+ full_video_path = os.path.join(video_save_dir, video_name_full)
+ return_path = full_video_path
+ paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
+ print(f'The generated video is named {video_save_dir}/{video_name_full}')
+ else:
+ full_video_path = av_path
+
+ #### paste back then enhancers
+ if enhancer:
+ video_name_enhancer = x['video_name'] + '_enhanced.mp4'
+ enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
+ av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
+ return_path = av_path_enhancer
+
+ try:
+ enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+ except:
+ enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+
+ save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
+ os.remove(enhanced_path)
+
+ os.remove(path)
+ os.remove(new_audio_path)
+
+ return return_path
+
+ def generate_flow(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ frame_num = x['frame_num']
+
+ with torch.no_grad():
+ predictions_video = []
+ for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'):
+ predictions_video.append(self.net_G(source_image, target_semantics[:, i])['flow_field'])
+
+ predictions_video = torch.stack(predictions_video, dim=1)
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+
+ video = []
+ for idx in range(len(predictions_video)):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+
+ results = np.stack(video, axis=0)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ # original_size = crop_info[0]
+ # if original_size:
+ # result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+ # results = np.stack(result, axis=0)
+
+ x_name = os.path.basename(pic_path)
+ save_name = os.path.join(video_save_dir, x_name + '.flo')
+ save_name_flow_vis = os.path.join(video_save_dir, x_name + '.mp4')
+
+ flow_full = paste_flow(results, pic_path, save_name, crop_info, extended_crop= True if 'ext' in preprocess.lower() else False)
+
+ flow_viz = []
+ for kk in range(flow_full.shape[0]):
+ tmp = vis_flow(flow_full[kk])
+ flow_viz.append(tmp)
+ flow_viz = np.stack(flow_viz)
+
+ torchvision.io.write_video(save_name_flow_vis, flow_viz, fps=20, video_codec='h264', options={'crf': '10'})
+
+ return save_name_flow_vis
+
+
+def paste_flow(flows, pic_path, save_name, crop_info, extended_crop=False):
+
+ if not os.path.isfile(pic_path):
+ raise ValueError('pic_path must be a valid path to video/image file')
+ elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_img = cv2.imread(pic_path)
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(pic_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ break
+ full_img = frame
+ frame_h = full_img.shape[0]
+ frame_w = full_img.shape[1]
+
+ # full images, we only use it as reference for zero init image.
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ # out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h))
+ # template = np.zeros((frame_h, frame_w, 2)) # full flows
+ out_tmp = []
+ for crop_frame in tqdm(flows, 'seamlessClone:'):
+ p = cv2.resize(crop_frame, (ox2-ox1, oy2 - oy1), interpolation=cv2.INTER_LANCZOS4)
+
+ gen_img = np.zeros((frame_h, frame_w, 2))
+ # gen_img = cv2.seamlessClone(p, template, mask, location, cv2.NORMAL_CLONE)
+ gen_img[oy1:oy2,ox1:ox2] = p
+ out_tmp.append(gen_img)
+
+ np.save(save_name, np.stack(out_tmp))
+ return np.stack(out_tmp)
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/facerender/sync_batchnorm/__init__.py b/sadtalker_audio2pose/src/facerender/sync_batchnorm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..48871cdcdc882c903501ecc6d70fcb1b50bd7e9f
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/sync_batchnorm/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# File : __init__.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .replicate import DataParallelWithCallback, patch_replication_callback
diff --git a/sadtalker_audio2pose/src/facerender/sync_batchnorm/batchnorm.py b/sadtalker_audio2pose/src/facerender/sync_batchnorm/batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4cc2ccd2f0c904cbe433fb6136f443f0fa86fa6
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/sync_batchnorm/batchnorm.py
@@ -0,0 +1,315 @@
+# -*- coding: utf-8 -*-
+# File : batchnorm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import collections
+
+import torch
+import torch.nn.functional as F
+
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
+
+from .comm import SyncMaster
+
+__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
+
+
+def _sum_ft(tensor):
+ """sum over the first and last dimention"""
+ return tensor.sum(dim=0).sum(dim=-1)
+
+
+def _unsqueeze_ft(tensor):
+ """add new dementions at the front and the tail"""
+ return tensor.unsqueeze(0).unsqueeze(-1)
+
+
+_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
+
+
+class _SynchronizedBatchNorm(_BatchNorm):
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+
+ self._sync_master = SyncMaster(self._data_parallel_master)
+
+ self._is_parallel = False
+ self._parallel_id = None
+ self._slave_pipe = None
+
+ def forward(self, input):
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+ if not (self._is_parallel and self.training):
+ return F.batch_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training, self.momentum, self.eps)
+
+ # Resize the input to (B, C, -1).
+ input_shape = input.size()
+ input = input.view(input.size(0), self.num_features, -1)
+
+ # Compute the sum and square-sum.
+ sum_size = input.size(0) * input.size(2)
+ input_sum = _sum_ft(input)
+ input_ssum = _sum_ft(input ** 2)
+
+ # Reduce-and-broadcast the statistics.
+ if self._parallel_id == 0:
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
+ else:
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
+
+ # Compute the output.
+ if self.affine:
+ # MJY:: Fuse the multiplication for speed.
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
+ else:
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
+
+ # Reshape it.
+ return output.view(input_shape)
+
+ def __data_parallel_replicate__(self, ctx, copy_id):
+ self._is_parallel = True
+ self._parallel_id = copy_id
+
+ # parallel_id == 0 means master device.
+ if self._parallel_id == 0:
+ ctx.sync_master = self._sync_master
+ else:
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+ def _data_parallel_master(self, intermediates):
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+
+ # Always using same "device order" makes the ReduceAdd operation faster.
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+ to_reduce = [i[1][:2] for i in intermediates]
+ to_reduce = [j for i in to_reduce for j in i] # flatten
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+ sum_size = sum([i[1].sum_size for i in intermediates])
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+ outputs = []
+ for i, rec in enumerate(intermediates):
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+ return outputs
+
+ def _compute_mean_std(self, sum_, ssum, size):
+ """Compute the mean and standard-deviation with sum and square-sum. This method
+ also maintains the moving average on the master device."""
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+ mean = sum_ / size
+ sumvar = ssum - sum_ * mean
+ unbias_var = sumvar / (size - 1)
+ bias_var = sumvar / size
+
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+ return mean, bias_var.clamp(self.eps) ** -0.5
+
+
+class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
+ mini-batch.
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalize the tensor on each device using
+ the statistics only on that device, which accelerated the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of size
+ `batch_size x num_features [x width]`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 2 and input.dim() != 3:
+ raise ValueError('expected 2D or 3D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
+ of 3d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalize the tensor on each device using
+ the statistics only on that device, which accelerated the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, H, W)`
+ - Output: :math:`(N, C, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 4:
+ raise ValueError('expected 4D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
+ of 4d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalize the tensor on each device using
+ the statistics only on that device, which accelerated the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
+ or Spatio-temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x depth x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, D, H, W)`
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 5:
+ raise ValueError('expected 5D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
diff --git a/sadtalker_audio2pose/src/facerender/sync_batchnorm/comm.py b/sadtalker_audio2pose/src/facerender/sync_batchnorm/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b66ec4aea213edf4330beda0a8c8b93d6db77a60
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/sync_batchnorm/comm.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# File : comm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+ def __init__(self):
+ self._result = None
+ self._lock = threading.Lock()
+ self._cond = threading.Condition(self._lock)
+
+ def put(self, result):
+ with self._lock:
+ assert self._result is None, 'Previous result has\'t been fetched.'
+ self._result = result
+ self._cond.notify()
+
+ def get(self):
+ with self._lock:
+ if self._result is None:
+ self._cond.wait()
+
+ res = self._result
+ self._result = None
+ return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+ """Pipe for master-slave communication."""
+
+ def run_slave(self, msg):
+ self.queue.put((self.identifier, msg))
+ ret = self.result.get()
+ self.queue.put(True)
+ return ret
+
+
+class SyncMaster(object):
+ """An abstract `SyncMaster` object.
+
+ - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
+ call `register(id)` and obtain an `SlavePipe` to communicate with the master.
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
+ and passed to a registered callback.
+ - After receiving the messages, the master device should gather the information and determine to message passed
+ back to each slave devices.
+ """
+
+ def __init__(self, master_callback):
+ """
+
+ Args:
+ master_callback: a callback to be invoked after having collected messages from slave devices.
+ """
+ self._master_callback = master_callback
+ self._queue = queue.Queue()
+ self._registry = collections.OrderedDict()
+ self._activated = False
+
+ def __getstate__(self):
+ return {'master_callback': self._master_callback}
+
+ def __setstate__(self, state):
+ self.__init__(state['master_callback'])
+
+ def register_slave(self, identifier):
+ """
+ Register an slave device.
+
+ Args:
+ identifier: an identifier, usually is the device id.
+
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+ """
+ if self._activated:
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
+ self._activated = False
+ self._registry.clear()
+ future = FutureResult()
+ self._registry[identifier] = _MasterRegistry(future)
+ return SlavePipe(identifier, self._queue, future)
+
+ def run_master(self, master_msg):
+ """
+ Main entry for the master device in each forward pass.
+ The messages were first collected from each devices (including the master device), and then
+ an callback will be invoked to compute the message to be sent back to each devices
+ (including the master device).
+
+ Args:
+ master_msg: the message that the master want to send to itself. This will be placed as the first
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+ Returns: the message to be sent back to the master device.
+
+ """
+ self._activated = True
+
+ intermediates = [(0, master_msg)]
+ for i in range(self.nr_slaves):
+ intermediates.append(self._queue.get())
+
+ results = self._master_callback(intermediates)
+ assert results[0][0] == 0, 'The first result should belongs to the master.'
+
+ for i, res in results:
+ if i == 0:
+ continue
+ self._registry[i].result.put(res)
+
+ for i in range(self.nr_slaves):
+ assert self._queue.get() is True
+
+ return results[0][1]
+
+ @property
+ def nr_slaves(self):
+ return len(self._registry)
diff --git a/sadtalker_audio2pose/src/facerender/sync_batchnorm/replicate.py b/sadtalker_audio2pose/src/facerender/sync_batchnorm/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b97380d9c5fbe75c4b3583d3668ccd6a2848699
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/sync_batchnorm/replicate.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File : replicate.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+ 'CallbackContext',
+ 'execute_replication_callbacks',
+ 'DataParallelWithCallback',
+ 'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+ pass
+
+
+def execute_replication_callbacks(modules):
+ """
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
+
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
+ (shared among multiple copies of this module on different devices).
+ Through this context, different copies can share some information.
+
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
+ of any slave copies.
+ """
+ master_copy = modules[0]
+ nr_modules = len(list(master_copy.modules()))
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+ for i, module in enumerate(modules):
+ for j, m in enumerate(module.modules()):
+ if hasattr(m, '__data_parallel_replicate__'):
+ m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+ """
+ Data Parallel with a replication callback.
+
+ An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
+ original `replicate` function.
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ # sync_bn.__data_parallel_replicate__ will be invoked.
+ """
+
+ def replicate(self, module, device_ids):
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+
+def patch_replication_callback(data_parallel):
+ """
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
+ Useful when you have customized `DataParallel` implementation.
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+ > patch_replication_callback(sync_bn)
+ # this is equivalent to
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ """
+
+ assert isinstance(data_parallel, DataParallel)
+
+ old_replicate = data_parallel.replicate
+
+ @functools.wraps(old_replicate)
+ def new_replicate(module, device_ids):
+ modules = old_replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+ data_parallel.replicate = new_replicate
diff --git a/sadtalker_audio2pose/src/facerender/sync_batchnorm/unittest.py b/sadtalker_audio2pose/src/facerender/sync_batchnorm/unittest.py
new file mode 100644
index 0000000000000000000000000000000000000000..9716d035495097fb086ec050ab0bc9b76b9d28a0
--- /dev/null
+++ b/sadtalker_audio2pose/src/facerender/sync_batchnorm/unittest.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# File : unittest.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import unittest
+
+import numpy as np
+from torch.autograd import Variable
+
+
+def as_numpy(v):
+ if isinstance(v, Variable):
+ v = v.data
+ return v.cpu().numpy()
+
+
+class TorchTestCase(unittest.TestCase):
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
+ npa, npb = as_numpy(a), as_numpy(b)
+ self.assertTrue(
+ np.allclose(npa, npb, atol=atol),
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
+ )
diff --git a/sadtalker_audio2pose/src/generate_batch.py b/sadtalker_audio2pose/src/generate_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fcaff51276d489aa76c15e4979864a4d4f74aa4
--- /dev/null
+++ b/sadtalker_audio2pose/src/generate_batch.py
@@ -0,0 +1,120 @@
+import os
+
+from tqdm import tqdm
+import torch
+import numpy as np
+import random
+import scipy.io as scio
+import src.utils.audio as audio
+
+def crop_pad_audio(wav, audio_length):
+ if len(wav) > audio_length:
+ wav = wav[:audio_length]
+ elif len(wav) < audio_length:
+ wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
+ return wav
+
+def parse_audio_length(audio_length, sr, fps):
+ bit_per_frames = sr / fps
+
+ num_frames = int(audio_length / bit_per_frames)
+ audio_length = int(num_frames * bit_per_frames)
+
+ return audio_length, num_frames
+
+def generate_blink_seq(num_frames):
+ ratio = np.zeros((num_frames,1))
+ frame_id = 0
+ while frame_id in range(num_frames):
+ start = 80
+ if frame_id+start+9<=num_frames - 1:
+ ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5]
+ frame_id = frame_id+start+9
+ else:
+ break
+ return ratio
+
+def generate_blink_seq_randomly(num_frames):
+ ratio = np.zeros((num_frames,1))
+ if num_frames<=20:
+ return ratio
+ frame_id = 0
+ while frame_id in range(num_frames):
+ start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70)))
+ if frame_id+start+5<=num_frames - 1:
+ ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5]
+ frame_id = frame_id+start+5
+ else:
+ break
+ return ratio
+
+def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
+
+ syncnet_mel_step_size = 16
+ fps = 25
+
+ pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+
+
+ if idlemode:
+ num_frames = int(length_of_audio * 25)
+ indiv_mels = np.zeros((num_frames, 80, 16))
+ else:
+ wav = audio.load_wav(audio_path, 16000)
+ wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
+ wav = crop_pad_audio(wav, wav_length)
+ orig_mel = audio.melspectrogram(wav).T
+ spec = orig_mel.copy() # nframes 80
+ indiv_mels = []
+
+ for i in tqdm(range(num_frames), 'mel:'):
+ start_frame_num = i-2
+ start_idx = int(80. * (start_frame_num / float(fps)))
+ end_idx = start_idx + syncnet_mel_step_size
+ seq = list(range(start_idx, end_idx))
+ seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
+ m = spec[seq, :]
+ indiv_mels.append(m.T)
+ indiv_mels = np.asarray(indiv_mels) # T 80 16
+
+ ratio = generate_blink_seq_randomly(num_frames) # T
+ source_semantics_path = first_coeff_path
+ source_semantics_dict = scio.loadmat(source_semantics_path)
+ ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
+ ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)
+
+ if ref_eyeblink_coeff_path is not None:
+ ratio[:num_frames] = 0
+ refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path)
+ refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64]
+ refeyeblink_num_frames = refeyeblink_coeff.shape[0]
+ if refeyeblink_num_frames frame_num:
+ new_degree_list = new_degree_list[:frame_num]
+ elif len(new_degree_list) < frame_num:
+ for _ in range(frame_num-len(new_degree_list)):
+ new_degree_list.append(new_degree_list[-1])
+ print(len(new_degree_list))
+ print(frame_num)
+
+ remainder = frame_num%batch_size
+ if remainder!=0:
+ for _ in range(batch_size-remainder):
+ new_degree_list.append(new_degree_list[-1])
+ new_degree_np = np.array(new_degree_list).reshape(batch_size, -1)
+ return new_degree_np
+
diff --git a/sadtalker_audio2pose/src/gradio_demo.py b/sadtalker_audio2pose/src/gradio_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a2399fc44704b544ef39bb908d32a21da9fae17
--- /dev/null
+++ b/sadtalker_audio2pose/src/gradio_demo.py
@@ -0,0 +1,170 @@
+import torch, uuid
+import os, sys, shutil, platform
+from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
+from src.utils.preprocess import CropAndExtract
+from src.test_audio2coeff import Audio2Coeff
+from src.facerender.animate import AnimateFromCoeff
+from src.generate_batch import get_data
+from src.generate_facerender_batch import get_facerender_data
+
+from src.utils.init_path import init_path
+
+from pydub import AudioSegment
+
+
+def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
+ mp3_file = AudioSegment.from_file(file=mp3_filename)
+ mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
+
+
+class SadTalker():
+
+ def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):
+
+ if torch.cuda.is_available():
+ device = "cuda"
+ elif platform.system() == 'Darwin': # macos
+ device = "mps"
+ else:
+ device = "cpu"
+
+ self.device = device
+
+ os.environ['TORCH_HOME']= checkpoint_path
+
+ self.checkpoint_path = checkpoint_path
+ self.config_path = config_path
+
+
+ def test(self, source_image, driven_audio, preprocess='crop',
+ still_mode=False, use_enhancer=False, batch_size=1, size=256,
+ pose_style = 0,
+ facerender='facevid2vid',
+ exp_scale=1.0,
+ use_ref_video = False,
+ ref_video = None,
+ ref_info = None,
+ use_idle_mode = False,
+ length_of_audio = 0, use_blink=True,
+ result_dir='./results/'):
+
+ self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess)
+ print(self.sadtalker_paths)
+
+ self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
+ self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)
+
+ if facerender == 'facevid2vid' and self.device != 'mps':
+ self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
+ elif facerender == 'pirender' or self.device == 'mps':
+ self.animate_from_coeff = AnimateFromCoeff_PIRender(self.sadtalker_paths, self.device)
+ facerender = 'pirender'
+ else:
+ raise(RuntimeError('Unknown model: {}'.format(facerender)))
+
+
+ time_tag = str(uuid.uuid4())
+ save_dir = os.path.join(result_dir, time_tag)
+ os.makedirs(save_dir, exist_ok=True)
+
+ input_dir = os.path.join(save_dir, 'input')
+ os.makedirs(input_dir, exist_ok=True)
+
+ print(source_image)
+ pic_path = os.path.join(input_dir, os.path.basename(source_image))
+ shutil.move(source_image, input_dir)
+
+ if driven_audio is not None and os.path.isfile(driven_audio):
+ audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
+
+ #### mp3 to wav
+ if '.mp3' in audio_path:
+ mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000)
+ audio_path = audio_path.replace('.mp3', '.wav')
+ else:
+ shutil.move(driven_audio, input_dir)
+
+ elif use_idle_mode:
+ audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path
+ from pydub import AudioSegment
+ one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio) #duration in milliseconds
+ one_sec_segment.export(audio_path, format="wav")
+ else:
+ print(use_ref_video, ref_info)
+ assert use_ref_video == True and ref_info == 'all'
+
+ if use_ref_video and ref_info == 'all': # full ref mode
+ ref_video_videoname = os.path.basename(ref_video)
+ audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')
+ print('new audiopath:',audio_path)
+ # if ref_video contains audio, set the audio from ref_video.
+ cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path)
+ os.system(cmd)
+
+ os.makedirs(save_dir, exist_ok=True)
+
+ #crop image and extract 3dmm from image
+ first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+ os.makedirs(first_frame_dir, exist_ok=True)
+ first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size)
+
+ if first_coeff_path is None:
+ raise AttributeError("No face is detected")
+
+ if use_ref_video:
+ print('using ref video for genreation')
+ ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
+ ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
+ os.makedirs(ref_video_frame_dir, exist_ok=True)
+ print('3DMM Extraction for the reference video providing pose')
+ ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
+ else:
+ ref_video_coeff_path = None
+
+ if use_ref_video:
+ if ref_info == 'pose':
+ ref_pose_coeff_path = ref_video_coeff_path
+ ref_eyeblink_coeff_path = None
+ elif ref_info == 'blink':
+ ref_pose_coeff_path = None
+ ref_eyeblink_coeff_path = ref_video_coeff_path
+ elif ref_info == 'pose+blink':
+ ref_pose_coeff_path = ref_video_coeff_path
+ ref_eyeblink_coeff_path = ref_video_coeff_path
+ elif ref_info == 'all':
+ ref_pose_coeff_path = None
+ ref_eyeblink_coeff_path = None
+ else:
+ raise('error in refinfo')
+ else:
+ ref_pose_coeff_path = None
+ ref_eyeblink_coeff_path = None
+
+ #audio2ceoff
+ if use_ref_video and ref_info == 'all':
+ coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+ else:
+ batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, \
+ idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?
+ coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+
+ #coeff2video
+ data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, \
+ preprocess=preprocess, size=size, expression_scale = exp_scale, facemodel=facerender)
+ return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
+ video_name = data['video_name']
+ print(f'The generated video is named {video_name} in {save_dir}')
+
+ del self.preprocess_model
+ del self.audio_to_coeff
+ del self.animate_from_coeff
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ import gc; gc.collect()
+
+ return return_path
+
+
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/test_audio2coeff.py b/sadtalker_audio2pose/src/test_audio2coeff.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0f5ca9195bbc980c93fa3e37c6d06cc32953aee
--- /dev/null
+++ b/sadtalker_audio2pose/src/test_audio2coeff.py
@@ -0,0 +1,123 @@
+import os
+import torch
+import numpy as np
+from scipy.io import savemat, loadmat
+from yacs.config import CfgNode as CN
+from scipy.signal import savgol_filter
+
+import safetensors
+import safetensors.torch
+
+from src.audio2pose_models.audio2pose import Audio2Pose
+from src.audio2exp_models.networks import SimpleWrapperV2
+from src.audio2exp_models.audio2exp import Audio2Exp
+from src.utils.safetensor_helper import load_x_from_safetensor
+
+def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"):
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if model is not None:
+ model.load_state_dict(checkpoint['model'])
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+
+ return checkpoint['epoch']
+
+class Audio2Coeff():
+
+ def __init__(self, sadtalker_path, device):
+ #load config
+ fcfg_pose = open(sadtalker_path['audio2pose_yaml_path'])
+ cfg_pose = CN.load_cfg(fcfg_pose)
+ cfg_pose.freeze()
+ fcfg_exp = open(sadtalker_path['audio2exp_yaml_path'])
+ cfg_exp = CN.load_cfg(fcfg_exp)
+ cfg_exp.freeze()
+
+ # load audio2pose_model
+ self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device)
+ self.audio2pose_model = self.audio2pose_model.to(device)
+ self.audio2pose_model.eval()
+ for param in self.audio2pose_model.parameters():
+ param.requires_grad = False
+
+ try:
+ if sadtalker_path['use_safetensor']:
+ checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose'))
+ else:
+ load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device)
+ except:
+ raise Exception("Failed in loading audio2pose_checkpoint")
+
+ # load audio2exp_model
+ netG = SimpleWrapperV2()
+ netG = netG.to(device)
+ for param in netG.parameters():
+ netG.requires_grad = False
+ netG.eval()
+ try:
+ if sadtalker_path['use_safetensor']:
+ checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp'))
+ else:
+ load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device)
+ except:
+ raise Exception("Failed in loading audio2exp_checkpoint")
+ self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False)
+ self.audio2exp_model = self.audio2exp_model.to(device)
+ for param in self.audio2exp_model.parameters():
+ param.requires_grad = False
+ self.audio2exp_model.eval()
+
+ self.device = device
+
+ def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None):
+
+ with torch.no_grad():
+ #test
+ results_dict_exp= self.audio2exp_model.test(batch)
+ exp_pred = results_dict_exp['exp_coeff_pred'] #bs T 64
+
+ #for class_id in range(1):
+ #class_id = 0#(i+10)%45
+ #class_id = random.randint(0,46) #46 styles can be selected
+ batch['class'] = torch.LongTensor([pose_style]).to(self.device)
+ results_dict_pose = self.audio2pose_model.test(batch)
+ pose_pred = results_dict_pose['pose_pred'] #bs T 6
+
+ pose_len = pose_pred.shape[1]
+ if pose_len<13:
+ pose_len = int((pose_len-1)/2)*2+1
+ pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device)
+ else:
+ pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device)
+
+ coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1) #bs T 70
+
+ coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy()
+
+ if ref_pose_coeff_path is not None:
+ coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path)
+
+ savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])),
+ {'coeff_3dmm': coeffs_pred_numpy})
+
+ return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name']))
+
+ def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path):
+ num_frames = coeffs_pred_numpy.shape[0]
+ refpose_coeff_dict = loadmat(ref_pose_coeff_path)
+ refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70]
+ refpose_num_frames = refpose_coeff.shape[0]
+ if refpose_num_frames= 0
+ if hp.symmetric_mels:
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
+ else:
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
+
+def _denormalize(D):
+ if hp.allow_clipping_in_normalization:
+ if hp.symmetric_mels:
+ return (((np.clip(D, -hp.max_abs_value,
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+ + hp.min_level_db)
+ else:
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
+
+ if hp.symmetric_mels:
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
+ else:
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
diff --git a/sadtalker_audio2pose/src/utils/croper.py b/sadtalker_audio2pose/src/utils/croper.py
new file mode 100644
index 0000000000000000000000000000000000000000..578372debdb8d2b99fe93d3d2ba2dfacf7cbb0ad
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/croper.py
@@ -0,0 +1,145 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import scipy
+import numpy as np
+from PIL import Image
+import torch
+from tqdm import tqdm
+from itertools import cycle
+
+from src.face3d.extract_kp_videos_safe import KeypointExtractor
+from facexlib.alignment import landmark_98_to_68
+
+import numpy as np
+from PIL import Image
+
+class Preprocesser:
+ def __init__(self, device='cuda'):
+ self.predictor = KeypointExtractor(device)
+
+ def get_landmark(self, img_np):
+ """get landmark with dlib
+ :return: np.array shape=(68, 2)
+ """
+ with torch.no_grad():
+ dets = self.predictor.det_net.detect_faces(img_np, 0.97)
+
+ if len(dets) == 0:
+ return None
+ det = dets[0]
+
+ img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :]
+ lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0]
+
+ #### keypoints to the original location
+ lm[:,0] += int(det[0])
+ lm[:,1] += int(det[1])
+
+ return lm
+
+ def align_face(self, img, lm, output_size=1024):
+ """
+ :param filepath: str
+ :return: PIL Image
+ """
+ lm_chin = lm[0: 17] # left-right
+ lm_eyebrow_left = lm[17: 22] # left-right
+ lm_eyebrow_right = lm[22: 27] # left-right
+ lm_nose = lm[27: 31] # top-down
+ lm_nostrils = lm[31: 36] # top-down
+ lm_eye_left = lm[36: 42] # left-clockwise
+ lm_eye_right = lm[42: 48] # left-clockwise
+ lm_mouth_outer = lm[48: 60] # left-clockwise
+ lm_mouth_inner = lm[60: 68] # left-clockwise
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] # Addition of binocular difference and double mouth difference
+ x /= np.hypot(*x) # hypot函数计算直角三角形的斜边长,用斜边长对三角形两条直边做归一化
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) # 双眼差和眼嘴差,选较大的作为基准尺度
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) # 定义四边形,以面部基准位置为中心上下左右平移得到四个顶点
+ qsize = np.hypot(*x) * 2 # 定义四边形的大小(边长),为基准尺度的2倍
+
+ # Shrink.
+ # 如果计算出的四边形太大了,就按比例缩小它
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+ img = img.resize(rsize, Image.ANTIALIAS)
+ quad /= shrink
+ qsize /= shrink
+ else:
+ rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1]))))
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+ min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ # img = img.crop(crop)
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+ max(pad[3] - img.size[1] + border, 0))
+ # if enable_padding and max(pad) > border - 4:
+ # pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ # img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ # h, w, _ = img.shape
+ # y, x, _ = np.ogrid[:h, :w, :1]
+ # mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+ # 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+ # blur = qsize * 0.02
+ # img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ # img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ # img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+ # quad += pad[:2]
+
+ # Transform.
+ quad = (quad + 0.5).flatten()
+ lx = max(min(quad[0], quad[2]), 0)
+ ly = max(min(quad[1], quad[7]), 0)
+ rx = min(max(quad[4], quad[6]), img.size[0])
+ ry = min(max(quad[3], quad[5]), img.size[0])
+
+ # Save aligned image.
+ return rsize, crop, [lx, ly, rx, ry]
+
+ def crop(self, img_np_list, still=False, xsize=512): # first frame for all video
+ # print(img_np_list)
+ img_np = img_np_list[0]
+ lm = self.get_landmark(img_np)
+
+ if lm is None:
+ raise 'can not detect the landmark from source image'
+ rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ for _i in range(len(img_np_list)):
+ _inp = img_np_list[_i]
+ _inp = cv2.resize(_inp, (rsize[0], rsize[1]))
+ _inp = _inp[cly:cry, clx:crx]
+ if not still:
+ _inp = _inp[ly:ry, lx:rx]
+ img_np_list[_i] = _inp
+ return img_np_list, crop, quad
+
diff --git a/sadtalker_audio2pose/src/utils/face_enhancer.py b/sadtalker_audio2pose/src/utils/face_enhancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2664560a1d7199e81f1a50093f29d02de91d4bcc
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/face_enhancer.py
@@ -0,0 +1,123 @@
+import os
+import torch
+
+from gfpgan import GFPGANer
+
+from tqdm import tqdm
+
+from src.utils.videoio import load_video_to_cv2
+
+import cv2
+
+
+class GeneratorWithLen(object):
+ """ From https://stackoverflow.com/a/7460929 """
+
+ def __init__(self, gen, length):
+ self.gen = gen
+ self.length = length
+
+ def __len__(self):
+ return self.length
+
+ def __iter__(self):
+ return self.gen
+
+def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
+ return list(gen)
+
+def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
+ """ Provide a generator with a __len__ method so that it can passed to functions that
+ call len()"""
+
+ if os.path.isfile(images): # handle video to images
+ # TODO: Create a generator version of load_video_to_cv2
+ images = load_video_to_cv2(images)
+
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
+ gen_with_len = GeneratorWithLen(gen, len(images))
+ return gen_with_len
+
+def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
+ """ Provide a generator function so that all of the enhanced images don't need
+ to be stored in memory at the same time. This can save tons of RAM compared to
+ the enhancer function. """
+
+ print('face enhancer....')
+ if not isinstance(images, list) and os.path.isfile(images): # handle video to images
+ images = load_video_to_cv2(images)
+
+ # ------------------------ set up GFPGAN restorer ------------------------
+ if method == 'gfpgan':
+ arch = 'clean'
+ channel_multiplier = 2
+ model_name = 'GFPGANv1.4'
+ url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
+ elif method == 'RestoreFormer':
+ arch = 'RestoreFormer'
+ channel_multiplier = 2
+ model_name = 'RestoreFormer'
+ url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
+ elif method == 'codeformer': # TODO:
+ arch = 'CodeFormer'
+ channel_multiplier = 2
+ model_name = 'CodeFormer'
+ url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+ else:
+ raise ValueError(f'Wrong model version {method}.')
+
+
+ # ------------------------ set up background upsampler ------------------------
+ if bg_upsampler == 'realesrgan':
+ if not torch.cuda.is_available(): # CPU
+ import warnings
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
+ 'If you really want to use it, please modify the corresponding codes.')
+ bg_upsampler = None
+ else:
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from realesrgan import RealESRGANer
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ bg_upsampler = RealESRGANer(
+ scale=2,
+ model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
+ model=model,
+ tile=400,
+ tile_pad=10,
+ pre_pad=0,
+ half=True) # need to set False in CPU mode
+ else:
+ bg_upsampler = None
+
+ # determine model paths
+ model_path = os.path.join('gfpgan/weights', model_name + '.pth')
+
+ if not os.path.isfile(model_path):
+ model_path = os.path.join('checkpoints', model_name + '.pth')
+
+ if not os.path.isfile(model_path):
+ # download pre-trained models from url
+ model_path = url
+
+ restorer = GFPGANer(
+ model_path=model_path,
+ upscale=2,
+ arch=arch,
+ channel_multiplier=channel_multiplier,
+ bg_upsampler=bg_upsampler)
+
+ # ------------------------ restore ------------------------
+ for idx in tqdm(range(len(images)), 'Face Enhancer:'):
+
+ img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)
+
+ # restore faces and background if necessary
+ cropped_faces, restored_faces, r_img = restorer.enhance(
+ img,
+ has_aligned=False,
+ only_center_face=False,
+ paste_back=True)
+
+ r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
+ yield r_img
diff --git a/sadtalker_audio2pose/src/utils/flow_util.py b/sadtalker_audio2pose/src/utils/flow_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f25046bab67cc8fbbb59efd02f48d7b6f22fc580
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/flow_util.py
@@ -0,0 +1,221 @@
+import torch
+import sys
+
+
+def convert_flow_to_deformation(flow):
+ r"""convert flow fields to deformations.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ deformation (tensor): The deformation used for warpping
+ """
+ b,c,h,w = flow.shape
+ flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1)
+ grid = make_coordinate_grid(flow)
+ # print(grid.shape, flow_norm.shape)
+ deformation = grid + flow_norm.permute(0,2,3,1)
+ return deformation
+
+def make_coordinate_grid(flow):
+ r"""obtain coordinate grid with the same size as the flow filed.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ grid (tensor): The grid with the same size as the input flow
+ """
+ b,c,h,w = flow.shape
+
+ x = torch.arange(w).to(flow)
+ y = torch.arange(h).to(flow)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+ meshed = meshed.expand(b, -1, -1, -1)
+ return meshed
+
+
+def warp_image(source_image, deformation):
+ r"""warp the input image according to the deformation
+
+ Args:
+ source_image (tensor): source images to be warpped
+ deformation (tensor): deformations used to warp the images; value in range (-1, 1)
+ Returns:
+ output (tensor): the warpped images
+ """
+ _, h_old, w_old, _ = deformation.shape
+ _, _, h, w = source_image.shape
+ if h_old != h or w_old != w:
+ deformation = deformation.permute(0, 3, 1, 2)
+ deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear')
+ deformation = deformation.permute(0, 2, 3, 1)
+ return torch.nn.functional.grid_sample(source_image, deformation)
+
+
+
+# visualize flow
+import numpy as np
+
+__all__ = ['load_flow', 'save_flow', 'vis_flow']
+
+
+def load_flow(path):
+ with open(path, 'rb') as f:
+ magic = float(np.fromfile(f, np.float32, count=1)[0])
+ if magic == 202021.25:
+ w, h = np.fromfile(f, np.int32, count=1)[0], np.fromfile(f, np.int32, count=1)[0]
+ data = np.fromfile(f, np.float32, count=h * w * 2)
+ data.resize((h, w, 2))
+ return data
+ return None
+
+
+def save_flow(path, flow):
+ magic = np.array([202021.25], np.float32)
+ h, w = flow.shape[:2]
+ h, w = np.array([h], np.int32), np.array([w], np.int32)
+
+ with open(path, 'wb') as f:
+ magic.tofile(f)
+ w.tofile(f)
+ h.tofile(f)
+ flow.tofile(f)
+
+
+
+def makeColorwheel():
+ # color encoding scheme
+
+ # adapted from the color circle idea described at
+ # http://members.shaw.ca/quadibloc/other/colint.htm
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+
+ colorwheel = np.zeros([ncols, 3]) # r g b
+
+ col = 0
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY, 1) / RY)
+ col += RY
+
+ # YG
+ colorwheel[col:YG + col, 0] = 255 - np.floor(255 * np.arange(0, YG, 1) / YG)
+ colorwheel[col:YG + col, 1] = 255
+ col += YG
+
+ # GC
+ colorwheel[col:GC + col, 1] = 255
+ colorwheel[col:GC + col, 2] = np.floor(255 * np.arange(0, GC, 1) / GC)
+ col += GC
+
+ # CB
+ colorwheel[col:CB + col, 1] = 255 - np.floor(255 * np.arange(0, CB, 1) / CB)
+ colorwheel[col:CB + col, 2] = 255
+ col += CB
+
+ # BM
+ colorwheel[col:BM + col, 2] = 255
+ colorwheel[col:BM + col, 0] = np.floor(255 * np.arange(0, BM, 1) / BM)
+ col += BM
+
+ # MR
+ colorwheel[col:MR + col, 2] = 255 - np.floor(255 * np.arange(0, MR, 1) / MR)
+ colorwheel[col:MR + col, 0] = 255
+ return colorwheel
+
+
+def computeColor(u, v):
+ colorwheel = makeColorwheel()
+ nan_u = np.isnan(u)
+ nan_v = np.isnan(v)
+ nan_u = np.where(nan_u)
+ nan_v = np.where(nan_v)
+
+ u[nan_u] = 0
+ u[nan_v] = 0
+ v[nan_u] = 0
+ v[nan_v] = 0
+
+ ncols = colorwheel.shape[0]
+ radius = np.sqrt(u ** 2 + v ** 2)
+ a = np.arctan2(-v, -u) / np.pi
+ fk = (a + 1) / 2 * (ncols - 1) # -1~1 maped to 1~ncols
+ k0 = fk.astype(np.uint8) # 1, 2, ..., ncols
+ k1 = k0 + 1
+ k1[k1 == ncols] = 0
+ f = fk - k0
+
+ img = np.empty([k1.shape[0], k1.shape[1], 3])
+ ncolors = colorwheel.shape[1]
+ for i in range(ncolors):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0] / 255
+ col1 = tmp[k1] / 255
+ col = (1 - f) * col0 + f * col1
+ idx = radius <= 1
+ col[idx] = 1 - radius[idx] * (1 - col[idx]) # increase saturation with radius
+ col[~idx] *= 0.75 # out of range
+ img[:, :, 2 - i] = np.floor(255 * col).astype(np.uint8)
+
+ return img.astype(np.uint8)
+
+
+def vis_flow(flow):
+ eps = sys.float_info.epsilon
+ UNKNOWN_FLOW_THRESH = 1e9
+ UNKNOWN_FLOW = 1e10
+
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ maxu = -999
+ maxv = -999
+
+ minu = 999
+ minv = 999
+
+ maxrad = -1
+ # fix unknown flow
+ greater_u = np.where(u > UNKNOWN_FLOW_THRESH)
+ greater_v = np.where(v > UNKNOWN_FLOW_THRESH)
+ u[greater_u] = 0
+ u[greater_v] = 0
+ v[greater_u] = 0
+ v[greater_v] = 0
+
+ maxu = max([maxu, np.amax(u)])
+ minu = min([minu, np.amin(u)])
+
+ maxv = max([maxv, np.amax(v)])
+ minv = min([minv, np.amin(v)])
+ rad = np.sqrt(np.multiply(u, u) + np.multiply(v, v))
+ maxrad = max([maxrad, np.amax(rad)])
+ # print('max flow: %.4f flow range: u = %.3f .. %.3f; v = %.3f .. %.3f\n' % (maxrad, minu, maxu, minv, maxv))
+
+ u = u / (maxrad + eps)
+ v = v / (maxrad + eps)
+ img = computeColor(u, v)
+ return img[:, :, [2, 1, 0]]
+
+
+def test_visualize_flow():
+ flow = load_flow('out.flo')
+ img = vis_flow(flow)
+
+ import cv2
+ cv2.imwrite("img.png", img)
diff --git a/sadtalker_audio2pose/src/utils/hparams.py b/sadtalker_audio2pose/src/utils/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..83c312d767c35b9adc988157243efc02129fdb84
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/hparams.py
@@ -0,0 +1,160 @@
+from glob import glob
+import os
+
+class HParams:
+ def __init__(self, **kwargs):
+ self.data = {}
+
+ for key, value in kwargs.items():
+ self.data[key] = value
+
+ def __getattr__(self, key):
+ if key not in self.data:
+ raise AttributeError("'HParams' object has no attribute %s" % key)
+ return self.data[key]
+
+ def set_hparam(self, key, value):
+ self.data[key] = value
+
+
+# Default hyperparameters
+hparams = HParams(
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
+ # network
+ rescale=True, # Whether to rescale audio prior to preprocessing
+ rescaling_max=0.9, # Rescaling value
+
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+ # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+ # Does not work if n_ffit is not multiple of hop_size!!
+ use_lws=False,
+
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i )
+
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
+
+ # Mel and Linear spectrograms normalization/scaling and clipping
+ signal_normalization=True,
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
+ symmetric_mels=True,
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
+ # faster and cleaner convergence)
+ max_abs_value=4.,
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
+ # be too big to avoid gradient explosion,
+ # not too small for fast convergence)
+ # Contribution by @begeekmyfriend
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
+ # levels. Also allows for better G&L phase reconstruction)
+ preemphasize=True, # whether to apply filter
+ preemphasis=0.97, # filter coefficient.
+
+ # Limits
+ min_level_db=-100,
+ ref_level_db=20,
+ fmin=55,
+ # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+ fmax=7600, # To be increased/reduced depending on data.
+
+ ###################### Our training parameters #################################
+ img_size=96,
+ fps=25,
+
+ batch_size=16,
+ initial_learning_rate=1e-4,
+ nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+ num_workers=20,
+ checkpoint_interval=3000,
+ eval_interval=3000,
+ writer_interval=300,
+ save_optimizer_state=True,
+
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
+ syncnet_batch_size=64,
+ syncnet_lr=1e-4,
+ syncnet_eval_interval=1000,
+ syncnet_checkpoint_interval=10000,
+
+ disc_wt=0.07,
+ disc_initial_learning_rate=1e-4,
+)
+
+
+
+# Default hyperparameters
+hparamsdebug = HParams(
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
+ # network
+ rescale=True, # Whether to rescale audio prior to preprocessing
+ rescaling_max=0.9, # Rescaling value
+
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+ # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+ # Does not work if n_ffit is not multiple of hop_size!!
+ use_lws=False,
+
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i )
+
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
+
+ # Mel and Linear spectrograms normalization/scaling and clipping
+ signal_normalization=True,
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
+ symmetric_mels=True,
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
+ # faster and cleaner convergence)
+ max_abs_value=4.,
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
+ # be too big to avoid gradient explosion,
+ # not too small for fast convergence)
+ # Contribution by @begeekmyfriend
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
+ # levels. Also allows for better G&L phase reconstruction)
+ preemphasize=True, # whether to apply filter
+ preemphasis=0.97, # filter coefficient.
+
+ # Limits
+ min_level_db=-100,
+ ref_level_db=20,
+ fmin=55,
+ # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+ fmax=7600, # To be increased/reduced depending on data.
+
+ ###################### Our training parameters #################################
+ img_size=96,
+ fps=25,
+
+ batch_size=2,
+ initial_learning_rate=1e-3,
+ nepochs=100000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+ num_workers=0,
+ checkpoint_interval=10000,
+ eval_interval=10,
+ writer_interval=5,
+ save_optimizer_state=True,
+
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
+ syncnet_batch_size=64,
+ syncnet_lr=1e-4,
+ syncnet_eval_interval=10000,
+ syncnet_checkpoint_interval=10000,
+
+ disc_wt=0.07,
+ disc_initial_learning_rate=1e-4,
+)
+
+
+def hparams_debug_string():
+ values = hparams.values()
+ hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
+ return "Hyperparameters:\n" + "\n".join(hp)
diff --git a/sadtalker_audio2pose/src/utils/init_path.py b/sadtalker_audio2pose/src/utils/init_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..65239fe3281798b2472f7ca0557a96157d9de930
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/init_path.py
@@ -0,0 +1,49 @@
+import os
+import glob
+
+def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'):
+
+ if old_version:
+ #### load all the checkpoint of `pth`
+ sadtalker_paths = {
+ 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
+ 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
+ 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
+ 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
+ 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
+ }
+
+ use_safetensor = False
+ elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))):
+ print('using safetensor as default')
+ sadtalker_paths = {
+ "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'),
+ }
+ use_safetensor = True
+ else:
+ print("WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. We run the old version of the checkpoint this time!")
+ use_safetensor = False
+
+ sadtalker_paths = {
+ 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
+ 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
+ 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
+ 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
+ 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
+ }
+
+ sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting'
+ sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml')
+ sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml')
+ sadtalker_paths['pirender_yaml_path'] = os.path.join(config_dir, 'facerender_pirender.yaml')
+ sadtalker_paths['pirender_checkpoint'] = os.path.join(checkpoint_dir, 'epoch_00190_iteration_000400000_checkpoint.pt')
+ sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml')
+
+ if 'full' in preprocess:
+ sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar')
+ sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml')
+ else:
+ sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar')
+ sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml')
+
+ return sadtalker_paths
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/utils/model2safetensor.py b/sadtalker_audio2pose/src/utils/model2safetensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5b76e3d67a06fdbf6646590d44b8c225bc73d79
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/model2safetensor.py
@@ -0,0 +1,141 @@
+import torch
+import yaml
+import os
+
+import safetensors
+from safetensors.torch import save_file
+from yacs.config import CfgNode as CN
+import sys
+
+sys.path.append('/apdcephfs/private_shadowcun/SadTalker')
+
+from src.face3d.models import networks
+
+from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
+from src.facerender.modules.mapping import MappingNet
+from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
+
+from src.audio2pose_models.audio2pose import Audio2Pose
+from src.audio2exp_models.networks import SimpleWrapperV2
+from src.test_audio2coeff import load_cpk
+
+size = 256
+############ face vid2vid
+config_path = os.path.join('src', 'config', 'facerender.yaml')
+current_root_path = '.'
+
+path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
+net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='')
+checkpoint = torch.load(path_of_net_recon_model, map_location='cpu')
+net_recon.load_state_dict(checkpoint['net_recon'])
+
+with open(config_path) as f:
+ config = yaml.safe_load(f)
+
+generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
+ **config['model_params']['common_params'])
+kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
+ **config['model_params']['common_params'])
+he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
+ **config['model_params']['common_params'])
+mapping = MappingNet(**config['model_params']['mapping_params'])
+
+def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None,
+ kp_detector=None, he_estimator=None, optimizer_generator=None,
+ optimizer_discriminator=None, optimizer_kp_detector=None,
+ optimizer_he_estimator=None, device="cpu"):
+
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if generator is not None:
+ generator.load_state_dict(checkpoint['generator'])
+ if kp_detector is not None:
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
+ if he_estimator is not None:
+ he_estimator.load_state_dict(checkpoint['he_estimator'])
+ if discriminator is not None:
+ try:
+ discriminator.load_state_dict(checkpoint['discriminator'])
+ except:
+ print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
+ if optimizer_generator is not None:
+ optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
+ if optimizer_discriminator is not None:
+ try:
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+ except RuntimeError as e:
+ print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
+ if optimizer_kp_detector is not None:
+ optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
+ if optimizer_he_estimator is not None:
+ optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
+
+ return checkpoint['epoch']
+
+
+def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None,
+ kp_detector=None, he_estimator=None,
+ device="cpu"):
+
+ checkpoint = safetensors.torch.load_file(checkpoint_path)
+
+ if generator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'generator' in k:
+ x_generator[k.replace('generator.', '')] = v
+ generator.load_state_dict(x_generator)
+ if kp_detector is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'kp_extractor' in k:
+ x_generator[k.replace('kp_extractor.', '')] = v
+ kp_detector.load_state_dict(x_generator)
+ if he_estimator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'he_estimator' in k:
+ x_generator[k.replace('he_estimator.', '')] = v
+ he_estimator.load_state_dict(x_generator)
+
+ return None
+
+free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar'
+load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
+
+wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')
+
+audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
+audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
+
+audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
+audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
+
+fcfg_pose = open(audio2pose_yaml_path)
+cfg_pose = CN.load_cfg(fcfg_pose)
+cfg_pose.freeze()
+audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint)
+audio2pose_model.eval()
+load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu')
+
+# load audio2exp_model
+netG = SimpleWrapperV2()
+netG.eval()
+load_cpk(audio2exp_checkpoint, model=netG, device='cpu')
+
+class SadTalker(torch.nn.Module):
+ def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon):
+ super(SadTalker, self).__init__()
+ self.kp_extractor = kp_extractor
+ self.generator = generator
+ self.audio2exp = netG
+ self.audio2pose = audio2pose
+ self.face_3drecon = face_3drecon
+
+
+model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon)
+
+# here, we want to convert it to safetensor
+save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors")
+
+### test
+load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None)
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/utils/paste_pic.py b/sadtalker_audio2pose/src/utils/paste_pic.py
new file mode 100644
index 0000000000000000000000000000000000000000..4da8952e6933698fec6c7cf35042cb5b1f0dcba5
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/paste_pic.py
@@ -0,0 +1,69 @@
+import cv2, os
+import numpy as np
+from tqdm import tqdm
+import uuid
+
+from src.utils.videoio import save_video_with_watermark
+
+def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False):
+
+ if not os.path.isfile(pic_path):
+ raise ValueError('pic_path must be a valid path to video/image file')
+ elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_img = cv2.imread(pic_path)
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(pic_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ break
+ full_img = frame
+ frame_h = full_img.shape[0]
+ frame_w = full_img.shape[1]
+
+ video_stream = cv2.VideoCapture(video_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ crop_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ crop_frames.append(frame)
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ tmp_path = str(uuid.uuid4())+'.mp4'
+ out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h))
+ for crop_frame in tqdm(crop_frames, 'seamlessClone:'):
+ p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1))
+
+ mask = 255*np.ones(p.shape, p.dtype)
+ location = ((ox1+ox2) // 2, (oy1+oy2) // 2)
+ gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE)
+ out_tmp.write(gen_img)
+
+ out_tmp.release()
+
+ save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False)
+ os.remove(tmp_path)
diff --git a/sadtalker_audio2pose/src/utils/preprocess.py b/sadtalker_audio2pose/src/utils/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..4956c00d273467f8a0c020312401158b06c4fecd
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/preprocess.py
@@ -0,0 +1,170 @@
+import numpy as np
+import cv2, os, sys, torch
+from tqdm import tqdm
+from PIL import Image
+
+# 3dmm extraction
+import safetensors
+import safetensors.torch
+from src.face3d.util.preprocess import align_img
+from src.face3d.util.load_mats import load_lm3d
+from src.face3d.models import networks
+
+from scipy.io import loadmat, savemat
+from src.utils.croper import Preprocesser
+
+
+import warnings
+
+from src.utils.safetensor_helper import load_x_from_safetensor
+warnings.filterwarnings("ignore")
+
+def split_coeff(coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+ coeffs -- torch.tensor, size (B, 256)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80: 144]
+ tex_coeffs = coeffs[:, 144: 224]
+ angles = coeffs[:, 224: 227]
+ gammas = coeffs[:, 227: 254]
+ translations = coeffs[:, 254:]
+ return {
+ 'id': id_coeffs,
+ 'exp': exp_coeffs,
+ 'tex': tex_coeffs,
+ 'angle': angles,
+ 'gamma': gammas,
+ 'trans': translations
+ }
+
+
+class CropAndExtract():
+ def __init__(self, sadtalker_path, device):
+
+ self.propress = Preprocesser(device)
+ self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)
+
+ if sadtalker_path['use_safetensor']:
+ checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon'))
+ else:
+ checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device))
+ self.net_recon.load_state_dict(checkpoint['net_recon'])
+
+ self.net_recon.eval()
+ self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting'])
+ self.device = device
+
+ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256):
+
+ pic_name = os.path.splitext(os.path.split(input_path)[-1])[0]
+
+ landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt')
+ coeff_path = os.path.join(save_dir, pic_name+'.mat')
+ png_path = os.path.join(save_dir, pic_name+'.png')
+
+ #load input
+ if not os.path.isfile(input_path):
+ raise ValueError('input_path must be a valid path to video/image file')
+ elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_frames = [cv2.imread(input_path)]
+ fps = 25
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(input_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ full_frames.append(frame)
+ if source_image_flag:
+ break
+
+ x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames]
+
+ #### crop images as the
+ if 'crop' in crop_or_resize.lower(): # default crop
+ x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+ elif 'full' in crop_or_resize.lower():
+ x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+ else: # resize mode
+ oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1]
+ crop_info = ((ox2 - ox1, oy2 - oy1), None, None)
+
+ frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames]
+ if len(frames_pil) == 0:
+ print('No face is detected in the input file')
+ return None, None
+
+ # save crop info
+ for frame in frames_pil:
+ cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
+
+ # 2. get the landmark according to the detected face.
+ if not os.path.isfile(landmarks_path):
+ lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path)
+ else:
+ print(' Using saved landmarks.')
+ lm = np.loadtxt(landmarks_path).astype(np.float32)
+ lm = lm.reshape([len(x_full_frames), -1, 2])
+
+ if not os.path.isfile(coeff_path):
+ # load 3dmm paramter generator from Deep3DFaceRecon_pytorch
+ video_coeffs, full_coeffs = [], []
+ for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'):
+ frame = frames_pil[idx]
+ W,H = frame.size
+ lm1 = lm[idx].reshape([-1, 2])
+
+ if np.mean(lm1) == -1:
+ lm1 = (self.lm3d_std[:, :2]+1)/2.
+ lm1 = np.concatenate(
+ [lm1[:, :1]*W, lm1[:, 1:2]*H], 1
+ )
+ else:
+ lm1[:, -1] = H - 1 - lm1[:, -1]
+
+ trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std)
+
+ trans_params_m = np.array([float(item) for item in np.hsplit(trans_params, len(trans_params))]).astype(np.float32)
+ im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0)
+
+ with torch.no_grad():
+ full_coeff = self.net_recon(im_t)
+ coeffs = split_coeff(full_coeff)
+
+ pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}
+
+ pred_coeff = np.concatenate([
+ pred_coeff['exp'],
+ pred_coeff['angle'],
+ pred_coeff['trans'],
+ trans_params_m[2:][None],
+ ], 1)
+ video_coeffs.append(pred_coeff)
+ full_coeffs.append(full_coeff.cpu().numpy())
+
+ semantic_npy = np.array(video_coeffs)[:,0]
+
+ savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0], 'trans_params': trans_params})
+
+ return coeff_path, png_path, crop_info
diff --git a/sadtalker_audio2pose/src/utils/preprocess_fromvideo.py b/sadtalker_audio2pose/src/utils/preprocess_fromvideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c4aad3eef558934d9974c3170b658cff88f568c
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/preprocess_fromvideo.py
@@ -0,0 +1,195 @@
+import numpy as np
+import cv2, os, sys, torch
+from tqdm import tqdm
+from PIL import Image
+
+# 3dmm extraction
+import safetensors
+import safetensors.torch
+from src.face3d.util.preprocess import align_img
+from src.face3d.util.load_mats import load_lm3d
+from src.face3d.models import networks
+
+from scipy.io import loadmat, savemat
+from src.utils.croper import Preprocesser
+
+
+import warnings
+
+from src.utils.safetensor_helper import load_x_from_safetensor
+warnings.filterwarnings("ignore")
+
+
+def smooth_3dmm_params(params, window_size=5):
+ # 创建一个新的数组来存储平滑后的参数
+ smoothed_params = np.zeros_like(params)
+
+ # 对每个参数进行平滑处理
+ for i in range(params.shape[1]):
+
+ # 在参数周围创建一个滑动窗口
+ window = np.ones(int(window_size))/float(window_size)
+ smoothed_param = np.convolve(params[:, i], window, 'same')
+
+ # 将平滑后的参数存储在新数组中
+ smoothed_params[:, i] = smoothed_param
+
+ return smoothed_params
+
+
+
+def split_coeff(coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+ coeffs -- torch.tensor, size (B, 256)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80: 144]
+ tex_coeffs = coeffs[:, 144: 224]
+ angles = coeffs[:, 224: 227]
+ gammas = coeffs[:, 227: 254]
+ translations = coeffs[:, 254:]
+ return {
+ 'id': id_coeffs,
+ 'exp': exp_coeffs,
+ 'tex': tex_coeffs,
+ 'angle': angles,
+ 'gamma': gammas,
+ 'trans': translations
+ }
+
+
+class CropAndExtract():
+ def __init__(self, sadtalker_path, device):
+
+ self.propress = Preprocesser(device)
+ self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)
+
+ if sadtalker_path['use_safetensor']:
+ checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon'))
+ else:
+ checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device))
+ self.net_recon.load_state_dict(checkpoint['net_recon'])
+
+ self.net_recon.eval()
+ self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting'])
+ self.device = device
+
+ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256, if_smooth=False):
+
+ pic_name = os.path.splitext(os.path.split(input_path)[-1])[0]
+
+ landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt')
+ coeff_path = os.path.join(save_dir, pic_name+'.mat')
+ png_path = os.path.join(save_dir, pic_name+'.png')
+
+ #load input
+ if not os.path.isfile(input_path):
+ raise ValueError('input_path must be a valid path to video/image file')
+ elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_frames = [cv2.imread(input_path)]
+ fps = 25
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(input_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ full_frames.append(frame)
+ if source_image_flag:
+ break
+
+ x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames]
+
+ # print(x_full_frames)
+
+ #### crop images as the
+ if 'crop' in crop_or_resize.lower(): # default crop
+ x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+ elif 'full' in crop_or_resize.lower():
+ x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+ else: # resize mode
+ oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1]
+ crop_info = ((ox2 - ox1, oy2 - oy1), None, None)
+
+ frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames]
+ if len(frames_pil) == 0:
+ print('No face is detected in the input file')
+ return None, None
+
+ # save crop info
+ for frame in frames_pil:
+ cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
+
+ # 2. get the landmark according to the detected face.
+ if not os.path.isfile(landmarks_path):
+ lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path)
+ else:
+ print(' Using saved landmarks.')
+ lm = np.loadtxt(landmarks_path).astype(np.float32)
+ lm = lm.reshape([len(x_full_frames), -1, 2])
+
+ if not os.path.isfile(coeff_path):
+ # load 3dmm paramter generator from Deep3DFaceRecon_pytorch
+ video_coeffs, full_coeffs = [], []
+ for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'):
+ frame = frames_pil[idx]
+ W,H = frame.size
+ lm1 = lm[idx].reshape([-1, 2])
+
+ if np.mean(lm1) == -1:
+ lm1 = (self.lm3d_std[:, :2]+1)/2.
+ lm1 = np.concatenate(
+ [lm1[:, :1]*W, lm1[:, 1:2]*H], 1
+ )
+ else:
+ lm1[:, -1] = H - 1 - lm1[:, -1]
+
+ trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std)
+
+ trans_params_m = np.array([float(item) for item in np.hsplit(trans_params, len(trans_params))]).astype(np.float32)
+ im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0)
+
+ with torch.no_grad():
+ full_coeff = self.net_recon(im_t)
+ coeffs = split_coeff(full_coeff)
+
+ pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}
+
+ pred_coeff = np.concatenate([
+ pred_coeff['exp'],
+ pred_coeff['angle'],
+ pred_coeff['trans'],
+ # trans_params_m[2:][None],
+ ], 1)
+ video_coeffs.append(pred_coeff)
+ full_coeffs.append(full_coeff.cpu().numpy())
+
+ semantic_npy = np.array(video_coeffs)[:,0]
+
+ if if_smooth:
+ # pass
+ semantic_npy[:, -6:] = smooth_3dmm_params(semantic_npy[:, -6:], window_size=10)
+
+ savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0], 'trans_params': trans_params})
+
+ return coeff_path, png_path, crop_info
diff --git a/sadtalker_audio2pose/src/utils/safetensor_helper.py b/sadtalker_audio2pose/src/utils/safetensor_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..164ed9621eba24e0b3050ca663fcb60123517158
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/safetensor_helper.py
@@ -0,0 +1,8 @@
+
+
+def load_x_from_safetensor(checkpoint, key):
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if key in k:
+ x_generator[k.replace(key+'.', '')] = v
+ return x_generator
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/utils/text2speech.py b/sadtalker_audio2pose/src/utils/text2speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0fe21daf74fcd01767b17378b7076c9dd424248
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/text2speech.py
@@ -0,0 +1,20 @@
+import os
+import tempfile
+from TTS.api import TTS
+
+
+class TTSTalker():
+ def __init__(self) -> None:
+ model_name = TTS.list_models()[0]
+ self.tts = TTS(model_name)
+
+ def test(self, text, language='en'):
+
+ tempf = tempfile.NamedTemporaryFile(
+ delete = False,
+ suffix = ('.'+'wav'),
+ )
+
+ self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name)
+
+ return tempf.name
\ No newline at end of file
diff --git a/sadtalker_audio2pose/src/utils/videoio.py b/sadtalker_audio2pose/src/utils/videoio.py
new file mode 100644
index 0000000000000000000000000000000000000000..d604ae5b098006f3e59cf3c0133779ffd1cc9d5a
--- /dev/null
+++ b/sadtalker_audio2pose/src/utils/videoio.py
@@ -0,0 +1,41 @@
+import shutil
+import uuid
+
+import os
+
+import cv2
+
+def load_video_to_cv2(input_path):
+ video_stream = cv2.VideoCapture(input_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ return full_frames
+
+def save_video_with_watermark(video, audio, save_path, watermark=False):
+ temp_file = str(uuid.uuid4())+'.mp4'
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec mpeg4 "%s"' % (video, audio, temp_file)
+ os.system(cmd)
+
+ if watermark is False:
+ shutil.move(temp_file, save_path)
+ else:
+ # watermark
+ try:
+ ##### check if stable-diffusion-webui
+ import webui
+ from modules import paths
+ watarmark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png"
+ except:
+ # get the root path of sadtalker.
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ watarmark_path = dir_path+"/../../docs/sadtalker_logo.png"
+
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path)
+ os.system(cmd)
+ os.remove(temp_file)
\ No newline at end of file
diff --git a/sadtalker_video2pose/.DS_Store b/sadtalker_video2pose/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..9b7747e8985cba181c6477ae341433be1ca71030
Binary files /dev/null and b/sadtalker_video2pose/.DS_Store differ
diff --git a/sadtalker_video2pose/inference.py b/sadtalker_video2pose/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..7167fdb44be0261ed35b2fa25f39e327190c62bf
--- /dev/null
+++ b/sadtalker_video2pose/inference.py
@@ -0,0 +1,170 @@
+from glob import glob
+import shutil
+import torch
+from time import strftime
+import os, sys, time
+from argparse import ArgumentParser
+import platform
+import scipy
+import numpy as np
+
+# from src.utils.preprocess import CropAndExtract
+from src.utils.preprocess_fromvideo import CropAndExtract
+from src.test_audio2coeff import Audio2Coeff
+from src.facerender.animate import AnimateFromCoeff
+from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
+from src.generate_batch import get_data
+from src.generate_facerender_batch import get_facerender_data
+from src.utils.init_path import init_path
+
+
+def main(args):
+ #torch.backends.cudnn.enabled = False
+
+
+
+ # args.facerender = 'pirender'
+
+
+
+ pic_path = args.source_image
+ # audio_path = args.driven_audio
+ save_dir = args.result_dir
+ os.makedirs(save_dir, exist_ok=True)
+ pose_style = args.pose_style
+ device = args.device
+ batch_size = args.batch_size
+ input_yaw_list = args.input_yaw
+ input_pitch_list = args.input_pitch
+ input_roll_list = args.input_roll
+ ref_eyeblink = args.ref_eyeblink
+ ref_pose = args.ref_pose
+
+ current_root_path = os.path.split(sys.argv[0])[0]
+
+ sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
+
+ #init model
+ preprocess_model = CropAndExtract(sadtalker_paths, device)
+
+ audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
+
+ if args.facerender == 'facevid2vid':
+ animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
+ elif args.facerender == 'pirender':
+ animate_from_coeff = AnimateFromCoeff_PIRender(sadtalker_paths, device)
+ else:
+ raise(RuntimeError('Unknown model: {}'.format(args.facerender)))
+
+ #crop image and extract 3dmm from image
+ first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+ os.makedirs(first_frame_dir, exist_ok=True)
+ print('3DMM Extraction for source image')
+ first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\
+ source_image_flag=True, pic_size=args.size)
+ if first_coeff_path is None:
+ print("Can't get the coeffs of the input")
+ return
+
+ if ref_eyeblink is not None:
+ ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
+ ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
+ os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
+ print('3DMM Extraction for the reference video providing eye blinking')
+ ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
+ else:
+ ref_eyeblink_coeff_path=None
+
+ if ref_pose is not None:
+ if ref_pose == ref_eyeblink:
+ ref_pose_coeff_path = ref_eyeblink_coeff_path
+ else:
+ ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
+ ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
+ os.makedirs(ref_pose_frame_dir, exist_ok=True)
+ print('3DMM Extraction for the reference video providing pose')
+ # print(ref_pose)
+ ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False, if_smooth=True)
+ else:
+ ref_pose_coeff_path=None
+
+ # #audio2ceoff
+ # batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
+ # coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+
+ # print(ref_pose_coeff_path)
+ # print(coeff_path)
+
+ # coeff_pred_video = scipy.io.loadmat(ref_pose_coeff_path)['coeff_3dmm']
+ # coeff_pred = scipy.io.loadmat(coeff_path)['coeff_3dmm']
+
+ # print(coeff_pred_video.shape)
+ # print(coeff_pred.shape)
+
+ coeff_path = ref_pose_coeff_path
+ # coeff_path = smooth_3dmm_params(ref_pose_coeff_path, window_size=3)
+
+
+
+ # assert False
+
+ # 3dface render
+ if args.face3dvis:
+ from src.face3d.visualize_fromvideo import gen_composed_video
+ gen_composed_video(args, device, first_coeff_path, coeff_path, \
+ os.path.join(save_dir, '3dface.mp4'), os.path.join(save_dir, 'landmarks.mp4'), crop_info, extended_crop= True if 'ext' in args.preprocess else False )
+ return
+
+
+if __name__ == '__main__':
+
+ parser = ArgumentParser()
+ # parser.add_argument("--driven_audio", default='./sadtalker_video2pose/dummy/bus_chinese.wav', help="path to driven audio")
+ parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image")
+ parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
+ parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
+ parser.add_argument("--checkpoint_dir", default='./ckpts/sad_talker', help="path to output")
+ parser.add_argument("--result_dir", default='./results', help="path to output")
+ parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
+ parser.add_argument("--batch_size", type=int, default=1, help="the batch size of facerender")
+ parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
+ parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender")
+ parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
+ parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
+ parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
+ parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
+ parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
+ parser.add_argument("--cpu", dest="cpu", action="store_true")
+ parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
+ parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion")
+ parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" )
+ parser.add_argument("--verbose",action="store_true", help="saving the intermedia output or not" )
+ parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" )
+ parser.add_argument("--facerender", default='facevid2vid', choices=['pirender', 'facevid2vid'] )
+
+
+ # net structure and parameters
+ parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
+ parser.add_argument('--init_path', type=str, default=None, help='Useless')
+ parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc')
+ parser.add_argument('--bfm_folder', type=str, default='./ckpts/sad_talker/BFM_Fitting/')
+ parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+ # default renderer parameters
+ parser.add_argument('--focal', type=float, default=1015.)
+ parser.add_argument('--center', type=float, default=112.)
+ parser.add_argument('--camera_d', type=float, default=10.)
+ parser.add_argument('--z_near', type=float, default=5.)
+ parser.add_argument('--z_far', type=float, default=15.)
+
+ args = parser.parse_args()
+
+ if torch.cuda.is_available() and not args.cpu:
+ args.device = "cuda"
+ elif platform.system() == 'Darwin' and args.facerender == 'pirender': # macos
+ args.device = "mps"
+ else:
+ args.device = "cpu"
+
+ main(args)
+
diff --git a/sadtalker_video2pose/src/.DS_Store b/sadtalker_video2pose/src/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..0f8fa60ef73513c0a5ddb5161310a66031c28262
Binary files /dev/null and b/sadtalker_video2pose/src/.DS_Store differ
diff --git a/sadtalker_video2pose/src/audio2exp_models/audio2exp.py b/sadtalker_video2pose/src/audio2exp_models/audio2exp.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1062ab6684df01e0b3c48b6b577cc8df0503c91
--- /dev/null
+++ b/sadtalker_video2pose/src/audio2exp_models/audio2exp.py
@@ -0,0 +1,41 @@
+from tqdm import tqdm
+import torch
+from torch import nn
+
+
+class Audio2Exp(nn.Module):
+ def __init__(self, netG, cfg, device, prepare_training_loss=False):
+ super(Audio2Exp, self).__init__()
+ self.cfg = cfg
+ self.device = device
+ self.netG = netG.to(device)
+
+ def test(self, batch):
+
+ mel_input = batch['indiv_mels'] # bs T 1 80 16
+ bs = mel_input.shape[0]
+ T = mel_input.shape[1]
+
+ exp_coeff_pred = []
+
+ for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
+
+ current_mel_input = mel_input[:,i:i+10]
+
+ #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
+ ref = batch['ref'][:, :, :64][:, i:i+10]
+ ratio = batch['ratio_gt'][:, i:i+10] #bs T
+
+ audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
+
+ curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
+
+ exp_coeff_pred += [curr_exp_coeff_pred]
+
+ # BS x T x 64
+ results_dict = {
+ 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
+ }
+ return results_dict
+
+
diff --git a/sadtalker_video2pose/src/audio2exp_models/networks.py b/sadtalker_video2pose/src/audio2exp_models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd77a2f48d7c00ce85fe2eefe3a3e820730fbb74
--- /dev/null
+++ b/sadtalker_video2pose/src/audio2exp_models/networks.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+ self.residual = residual
+ self.use_act = use_act
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ if self.residual:
+ out += x
+
+ if self.use_act:
+ return self.act(out)
+ else:
+ return out
+
+class SimpleWrapperV2(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
+ )
+
+ #### load the pre-trained audio_encoder
+ #self.audio_encoder = self.audio_encoder.to(device)
+ '''
+ wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
+ state_dict = self.audio_encoder.state_dict()
+
+ for k,v in wav2lip_state_dict.items():
+ if 'audio_encoder' in k:
+ print('init:', k)
+ state_dict[k.replace('module.audio_encoder.', '')] = v
+ self.audio_encoder.load_state_dict(state_dict)
+ '''
+
+ self.mapping1 = nn.Linear(512+64+1, 64)
+ #self.mapping2 = nn.Linear(30, 64)
+ #nn.init.constant_(self.mapping1.weight, 0.)
+ nn.init.constant_(self.mapping1.bias, 0.)
+
+ def forward(self, x, ref, ratio):
+ x = self.audio_encoder(x).view(x.size(0), -1)
+ ref_reshape = ref.reshape(x.size(0), -1)
+ ratio = ratio.reshape(x.size(0), -1)
+
+ y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
+ out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
+ return out
diff --git a/sadtalker_video2pose/src/audio2pose_models/audio2pose.py b/sadtalker_video2pose/src/audio2pose_models/audio2pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..53883adc508037294ba664d05d34e5459f1879f8
--- /dev/null
+++ b/sadtalker_video2pose/src/audio2pose_models/audio2pose.py
@@ -0,0 +1,94 @@
+import torch
+from torch import nn
+from src.audio2pose_models.cvae import CVAE
+from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
+from src.audio2pose_models.audio_encoder import AudioEncoder
+
+class Audio2Pose(nn.Module):
+ def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
+ super().__init__()
+ self.cfg = cfg
+ self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
+ self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
+ self.device = device
+
+ self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
+ self.audio_encoder.eval()
+ for param in self.audio_encoder.parameters():
+ param.requires_grad = False
+
+ self.netG = CVAE(cfg)
+ self.netD_motion = PoseSequenceDiscriminator(cfg)
+
+
+ def forward(self, x):
+
+ batch = {}
+ coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
+ batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
+ batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
+ batch['class'] = x['class'].squeeze(0).cuda() # bs
+ indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
+
+ # forward
+ audio_emb_list = []
+ audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
+ batch['audio_emb'] = audio_emb
+ batch = self.netG(batch)
+
+ pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
+ pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
+ pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
+
+ batch['pose_pred'] = pose_pred
+ batch['pose_gt'] = pose_gt
+
+ return batch
+
+ def test(self, x):
+
+ batch = {}
+ ref = x['ref'] #bs 1 70
+ batch['ref'] = x['ref'][:,0,-6:]
+ batch['class'] = x['class']
+ bs = ref.shape[0]
+
+ indiv_mels= x['indiv_mels'] # bs T 1 80 16
+ indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
+ num_frames = x['num_frames']
+ num_frames = int(num_frames) - 1
+
+ #
+ div = num_frames//self.seq_len
+ re = num_frames%self.seq_len
+ audio_emb_list = []
+ pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
+ device=batch['ref'].device)]
+
+ for i in range(div):
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
+ batch['z'] = z
+ audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
+ batch['audio_emb'] = audio_emb
+ batch = self.netG.test(batch)
+ pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
+
+ if re != 0:
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
+ batch['z'] = z
+ audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
+ if audio_emb.shape[1] != self.seq_len:
+ pad_dim = self.seq_len-audio_emb.shape[1]
+ pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
+ audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
+ batch['audio_emb'] = audio_emb
+ batch = self.netG.test(batch)
+ pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
+
+ pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
+ batch['pose_motion_pred'] = pose_motion_pred
+
+ pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
+
+ batch['pose_pred'] = pose_pred
+ return batch
diff --git a/sadtalker_video2pose/src/audio2pose_models/audio_encoder.py b/sadtalker_video2pose/src/audio2pose_models/audio_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0c165afbc25910cb66828d8676973fe727cb3a3
--- /dev/null
+++ b/sadtalker_video2pose/src/audio2pose_models/audio_encoder.py
@@ -0,0 +1,64 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+ self.residual = residual
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ if self.residual:
+ out += x
+ return self.act(out)
+
+class AudioEncoder(nn.Module):
+ def __init__(self, wav2lip_checkpoint, device):
+ super(AudioEncoder, self).__init__()
+
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+ #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
+ # state_dict = self.audio_encoder.state_dict()
+
+ # for k,v in wav2lip_state_dict.items():
+ # if 'audio_encoder' in k:
+ # state_dict[k.replace('module.audio_encoder.', '')] = v
+ # self.audio_encoder.load_state_dict(state_dict)
+
+
+ def forward(self, audio_sequences):
+ # audio_sequences = (B, T, 1, 80, 16)
+ B = audio_sequences.size(0)
+
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
+ dim = audio_embedding.shape[1]
+ audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
+
+ return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
diff --git a/sadtalker_video2pose/src/audio2pose_models/cvae.py b/sadtalker_video2pose/src/audio2pose_models/cvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..407b78894cde564dd3f2819772a84e8bb1de251d
--- /dev/null
+++ b/sadtalker_video2pose/src/audio2pose_models/cvae.py
@@ -0,0 +1,149 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from src.audio2pose_models.res_unet import ResUnet
+
+def class2onehot(idx, class_num):
+
+ assert torch.max(idx).item() < class_num
+ onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
+ onehot.scatter_(1, idx, 1)
+ return onehot
+
+class CVAE(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
+ decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
+ latent_size = cfg.MODEL.CVAE.LATENT_SIZE
+ num_classes = cfg.DATASET.NUM_CLASSES
+ audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
+ audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
+ seq_len = cfg.MODEL.CVAE.SEQ_LEN
+
+ self.latent_size = latent_size
+
+ self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len)
+ self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len)
+ def reparameterize(self, mu, logvar):
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ return mu + eps * std
+
+ def forward(self, batch):
+ batch = self.encoder(batch)
+ mu = batch['mu']
+ logvar = batch['logvar']
+ z = self.reparameterize(mu, logvar)
+ batch['z'] = z
+ return self.decoder(batch)
+
+ def test(self, batch):
+ '''
+ class_id = batch['class']
+ z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
+ batch['z'] = z
+ '''
+ return self.decoder(batch)
+
+class ENCODER(nn.Module):
+ def __init__(self, layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len):
+ super().__init__()
+
+ self.resunet = ResUnet()
+ self.num_classes = num_classes
+ self.seq_len = seq_len
+
+ self.MLP = nn.Sequential()
+ layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
+ for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
+ self.MLP.add_module(
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+
+ self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
+ self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+ def forward(self, batch):
+ class_id = batch['class']
+ pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
+ ref = batch['ref'] #bs 6
+ bs = pose_motion_gt.shape[0]
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
+
+ #pose encode
+ pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
+ pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
+
+ #audio mapping
+ print(audio_in.shape)
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
+ audio_out = audio_out.reshape(bs, -1)
+
+ class_bias = self.classbias[class_id] #bs latent_size
+ x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
+ x_out = self.MLP(x_in)
+
+ mu = self.linear_means(x_out)
+ logvar = self.linear_means(x_out) #bs latent_size
+
+ batch.update({'mu':mu, 'logvar':logvar})
+ return batch
+
+class DECODER(nn.Module):
+ def __init__(self, layer_sizes, latent_size, num_classes,
+ audio_emb_in_size, audio_emb_out_size, seq_len):
+ super().__init__()
+
+ self.resunet = ResUnet()
+ self.num_classes = num_classes
+ self.seq_len = seq_len
+
+ self.MLP = nn.Sequential()
+ input_size = latent_size + seq_len*audio_emb_out_size + 6
+ for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
+ self.MLP.add_module(
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+ if i+1 < len(layer_sizes):
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+ else:
+ self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
+
+ self.pose_linear = nn.Linear(6, 6)
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+ def forward(self, batch):
+
+ z = batch['z'] #bs latent_size
+ bs = z.shape[0]
+ class_id = batch['class']
+ ref = batch['ref'] #bs 6
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
+ #print('audio_in: ', audio_in[:, :, :10])
+
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
+ #print('audio_out: ', audio_out[:, :, :10])
+ audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
+ class_bias = self.classbias[class_id] #bs latent_size
+
+ z = z + class_bias
+ x_in = torch.cat([ref, z, audio_out], dim=-1)
+ x_out = self.MLP(x_in) # bs layer_sizes[-1]
+ x_out = x_out.reshape((bs, self.seq_len, -1))
+
+ #print('x_out: ', x_out)
+
+ pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
+
+ pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
+
+ batch.update({'pose_motion_pred':pose_motion_pred})
+ return batch
diff --git a/sadtalker_video2pose/src/audio2pose_models/discriminator.py b/sadtalker_video2pose/src/audio2pose_models/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f8ed6e36708d4a70227ff90109f56c6f73a17d2
--- /dev/null
+++ b/sadtalker_video2pose/src/audio2pose_models/discriminator.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class ConvNormRelu(nn.Module):
+ def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
+ kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
+ super().__init__()
+ if kernel_size is None:
+ if downsample:
+ kernel_size, stride, padding = 4, 2, 1
+ else:
+ kernel_size, stride, padding = 3, 1, 1
+
+ if conv_type == '2d':
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ bias=False,
+ )
+ if norm == 'BN':
+ self.norm = nn.BatchNorm2d(out_channels)
+ elif norm == 'IN':
+ self.norm = nn.InstanceNorm2d(out_channels)
+ else:
+ raise NotImplementedError
+ elif conv_type == '1d':
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ bias=False,
+ )
+ if norm == 'BN':
+ self.norm = nn.BatchNorm1d(out_channels)
+ elif norm == 'IN':
+ self.norm = nn.InstanceNorm1d(out_channels)
+ else:
+ raise NotImplementedError
+ nn.init.kaiming_normal_(self.conv.weight)
+
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ if isinstance(self.norm, nn.InstanceNorm1d):
+ x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
+ else:
+ x = self.norm(x)
+ x = self.act(x)
+ return x
+
+
+class PoseSequenceDiscriminator(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
+
+ self.seq = nn.Sequential(
+ ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
+ ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
+ ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
+ nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
+ )
+
+ def forward(self, x):
+ x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
+ x = self.seq(x)
+ x = x.squeeze(1)
+ return x
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/audio2pose_models/networks.py b/sadtalker_video2pose/src/audio2pose_models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..9212b49836d9221895993d1d490a476707599922
--- /dev/null
+++ b/sadtalker_video2pose/src/audio2pose_models/networks.py
@@ -0,0 +1,140 @@
+import torch.nn as nn
+import torch
+
+
+class ResidualConv(nn.Module):
+ def __init__(self, input_dim, output_dim, stride, padding):
+ super(ResidualConv, self).__init__()
+
+ self.conv_block = nn.Sequential(
+ nn.BatchNorm2d(input_dim),
+ nn.ReLU(),
+ nn.Conv2d(
+ input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
+ ),
+ nn.BatchNorm2d(output_dim),
+ nn.ReLU(),
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+ )
+ self.conv_skip = nn.Sequential(
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
+ nn.BatchNorm2d(output_dim),
+ )
+
+ def forward(self, x):
+
+ return self.conv_block(x) + self.conv_skip(x)
+
+
+class Upsample(nn.Module):
+ def __init__(self, input_dim, output_dim, kernel, stride):
+ super(Upsample, self).__init__()
+
+ self.upsample = nn.ConvTranspose2d(
+ input_dim, output_dim, kernel_size=kernel, stride=stride
+ )
+
+ def forward(self, x):
+ return self.upsample(x)
+
+
+class Squeeze_Excite_Block(nn.Module):
+ def __init__(self, channel, reduction=16):
+ super(Squeeze_Excite_Block, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel, bias=False),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y.expand_as(x)
+
+
+class ASPP(nn.Module):
+ def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
+ super(ASPP, self).__init__()
+
+ self.aspp_block1 = nn.Sequential(
+ nn.Conv2d(
+ in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(out_dims),
+ )
+ self.aspp_block2 = nn.Sequential(
+ nn.Conv2d(
+ in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(out_dims),
+ )
+ self.aspp_block3 = nn.Sequential(
+ nn.Conv2d(
+ in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
+ ),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(out_dims),
+ )
+
+ self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
+ self._init_weights()
+
+ def forward(self, x):
+ x1 = self.aspp_block1(x)
+ x2 = self.aspp_block2(x)
+ x3 = self.aspp_block3(x)
+ out = torch.cat([x1, x2, x3], dim=1)
+ return self.output(out)
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+
+class Upsample_(nn.Module):
+ def __init__(self, scale=2):
+ super(Upsample_, self).__init__()
+
+ self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
+
+ def forward(self, x):
+ return self.upsample(x)
+
+
+class AttentionBlock(nn.Module):
+ def __init__(self, input_encoder, input_decoder, output_dim):
+ super(AttentionBlock, self).__init__()
+
+ self.conv_encoder = nn.Sequential(
+ nn.BatchNorm2d(input_encoder),
+ nn.ReLU(),
+ nn.Conv2d(input_encoder, output_dim, 3, padding=1),
+ nn.MaxPool2d(2, 2),
+ )
+
+ self.conv_decoder = nn.Sequential(
+ nn.BatchNorm2d(input_decoder),
+ nn.ReLU(),
+ nn.Conv2d(input_decoder, output_dim, 3, padding=1),
+ )
+
+ self.conv_attn = nn.Sequential(
+ nn.BatchNorm2d(output_dim),
+ nn.ReLU(),
+ nn.Conv2d(output_dim, 1, 1),
+ )
+
+ def forward(self, x1, x2):
+ out = self.conv_encoder(x1) + self.conv_decoder(x2)
+ out = self.conv_attn(out)
+ return out * x2
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/audio2pose_models/res_unet.py b/sadtalker_video2pose/src/audio2pose_models/res_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..280404c2a2804038705f792dd800ddf707b75cf8
--- /dev/null
+++ b/sadtalker_video2pose/src/audio2pose_models/res_unet.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from src.audio2pose_models.networks import ResidualConv, Upsample
+
+
+class ResUnet(nn.Module):
+ def __init__(self, channel=1, filters=[32, 64, 128, 256]):
+ super(ResUnet, self).__init__()
+
+ self.input_layer = nn.Sequential(
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
+ nn.BatchNorm2d(filters[0]),
+ nn.ReLU(),
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
+ )
+ self.input_skip = nn.Sequential(
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
+ )
+
+ self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
+ self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
+
+ self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
+
+ self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
+ self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
+
+ self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
+ self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
+
+ self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
+ self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
+
+ self.output_layer = nn.Sequential(
+ nn.Conv2d(filters[0], 1, 1, 1),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ # Encode
+ x1 = self.input_layer(x) + self.input_skip(x)
+ x2 = self.residual_conv_1(x1)
+ x3 = self.residual_conv_2(x2)
+ # Bridge
+ x4 = self.bridge(x3)
+
+ # Decode
+ x4 = self.upsample_1(x4)
+ x5 = torch.cat([x4, x3], dim=1)
+
+ x6 = self.up_residual_conv1(x5)
+
+ x6 = self.upsample_2(x6)
+ x7 = torch.cat([x6, x2], dim=1)
+
+ x8 = self.up_residual_conv2(x7)
+
+ x8 = self.upsample_3(x8)
+ x9 = torch.cat([x8, x1], dim=1)
+
+ x10 = self.up_residual_conv3(x9)
+
+ output = self.output_layer(x10)
+
+ return output
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/config/auido2exp.yaml b/sadtalker_video2pose/src/config/auido2exp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7e0e8fbba267158d26a147c8cb2ec5acdd73f432
--- /dev/null
+++ b/sadtalker_video2pose/src/config/auido2exp.yaml
@@ -0,0 +1,58 @@
+DATASET:
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
+ TRAIN_BATCH_SIZE: 32
+ EVAL_BATCH_SIZE: 32
+ EXP: True
+ EXP_DIM: 64
+ FRAME_LEN: 32
+ COEFF_LEN: 73
+ NUM_CLASSES: 46
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
+ LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+ DEBUG: True
+ NUM_REPEATS: 2
+ T: 40
+
+
+MODEL:
+ FRAMEWORK: V2
+ AUDIOENCODER:
+ LEAKY_RELU: True
+ NORM: 'IN'
+ DISCRIMINATOR:
+ LEAKY_RELU: False
+ INPUT_CHANNELS: 6
+ CVAE:
+ AUDIO_EMB_IN_SIZE: 512
+ AUDIO_EMB_OUT_SIZE: 128
+ SEQ_LEN: 32
+ LATENT_SIZE: 256
+ ENCODER_LAYER_SIZES: [192, 1024]
+ DECODER_LAYER_SIZES: [1024, 192]
+
+
+TRAIN:
+ MAX_EPOCH: 300
+ GENERATOR:
+ LR: 2.0e-5
+ DISCRIMINATOR:
+ LR: 1.0e-5
+ LOSS:
+ W_FEAT: 0
+ W_COEFF_EXP: 2
+ W_LM: 1.0e-2
+ W_LM_MOUTH: 0
+ W_REG: 0
+ W_SYNC: 0
+ W_COLOR: 0
+ W_EXPRESSION: 0
+ W_LIPREADING: 0.01
+ W_LIPREADING_VV: 0
+ W_EYE_BLINK: 4
+
+TAG:
+ NAME: small_dataset
+
+
diff --git a/sadtalker_video2pose/src/config/auido2pose.yaml b/sadtalker_video2pose/src/config/auido2pose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7702414b11581ff99aef7a3187f0d0d1388ae3f3
--- /dev/null
+++ b/sadtalker_video2pose/src/config/auido2pose.yaml
@@ -0,0 +1,49 @@
+DATASET:
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
+ TRAIN_BATCH_SIZE: 64
+ EVAL_BATCH_SIZE: 1
+ EXP: True
+ EXP_DIM: 64
+ FRAME_LEN: 32
+ COEFF_LEN: 73
+ NUM_CLASSES: 46
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+ DEBUG: True
+
+
+MODEL:
+ AUDIOENCODER:
+ LEAKY_RELU: True
+ NORM: 'IN'
+ DISCRIMINATOR:
+ LEAKY_RELU: False
+ INPUT_CHANNELS: 6
+ CVAE:
+ AUDIO_EMB_IN_SIZE: 512
+ AUDIO_EMB_OUT_SIZE: 6
+ SEQ_LEN: 32
+ LATENT_SIZE: 64
+ ENCODER_LAYER_SIZES: [192, 128]
+ DECODER_LAYER_SIZES: [128, 192]
+
+
+TRAIN:
+ MAX_EPOCH: 150
+ GENERATOR:
+ LR: 1.0e-4
+ DISCRIMINATOR:
+ LR: 1.0e-4
+ LOSS:
+ LAMBDA_REG: 1
+ LAMBDA_LANDMARKS: 0
+ LAMBDA_VERTICES: 0
+ LAMBDA_GAN_MOTION: 0.7
+ LAMBDA_GAN_COEFF: 0
+ LAMBDA_KL: 1
+
+TAG:
+ NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
+
+
diff --git a/sadtalker_video2pose/src/config/facerender.yaml b/sadtalker_video2pose/src/config/facerender.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dd1e1ddfe265698e49dac4a6e103cba0aac4f3ce
--- /dev/null
+++ b/sadtalker_video2pose/src/config/facerender.yaml
@@ -0,0 +1,45 @@
+model_params:
+ common_params:
+ num_kp: 15
+ image_channel: 3
+ feature_channel: 32
+ estimate_jacobian: False # True
+ kp_detector_params:
+ temperature: 0.1
+ block_expansion: 32
+ max_features: 1024
+ scale_factor: 0.25 # 0.25
+ num_blocks: 5
+ reshape_channel: 16384 # 16384 = 1024 * 16
+ reshape_depth: 16
+ he_estimator_params:
+ block_expansion: 64
+ max_features: 2048
+ num_bins: 66
+ generator_params:
+ block_expansion: 64
+ max_features: 512
+ num_down_blocks: 2
+ reshape_channel: 32
+ reshape_depth: 16 # 512 = 32 * 16
+ num_resblocks: 6
+ estimate_occlusion_map: True
+ dense_motion_params:
+ block_expansion: 32
+ max_features: 1024
+ num_blocks: 5
+ reshape_depth: 16
+ compress: 4
+ discriminator_params:
+ scales: [1]
+ block_expansion: 32
+ max_features: 512
+ num_blocks: 4
+ sn: True
+ mapping_params:
+ coeff_nc: 70
+ descriptor_nc: 1024
+ layer: 3
+ num_kp: 15
+ num_bins: 66
+
diff --git a/sadtalker_video2pose/src/config/facerender_pirender.yaml b/sadtalker_video2pose/src/config/facerender_pirender.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f893b5d0a22f0546642c2d2bdafda88740c81138
--- /dev/null
+++ b/sadtalker_video2pose/src/config/facerender_pirender.yaml
@@ -0,0 +1,83 @@
+# How often do you want to log the training stats.
+# network_list:
+# gen: gen_optimizer
+# dis: dis_optimizer
+
+distributed: False
+image_to_tensorboard: True
+snapshot_save_iter: 40000
+snapshot_save_epoch: 20
+snapshot_save_start_iter: 20000
+snapshot_save_start_epoch: 10
+image_save_iter: 1000
+max_epoch: 200
+logging_iter: 100
+results_dir: ./eval_results
+
+gen_optimizer:
+ type: adam
+ lr: 0.0001
+ adam_beta1: 0.5
+ adam_beta2: 0.999
+ lr_policy:
+ iteration_mode: True
+ type: step
+ step_size: 300000
+ gamma: 0.2
+
+trainer:
+ type: trainers.face_trainer::FaceTrainer
+ pretrain_warp_iteration: 200000
+ loss_weight:
+ weight_perceptual_warp: 2.5
+ weight_perceptual_final: 4
+ vgg_param_warp:
+ network: vgg19
+ layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+ use_style_loss: False
+ num_scales: 4
+ vgg_param_final:
+ network: vgg19
+ layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+ use_style_loss: True
+ num_scales: 4
+ style_to_perceptual: 250
+ init:
+ type: 'normal'
+ gain: 0.02
+gen:
+ type: generators.face_model::FaceGenerator
+ param:
+ mapping_net:
+ coeff_nc: 73
+ descriptor_nc: 256
+ layer: 3
+ warpping_net:
+ encoder_layer: 5
+ decoder_layer: 3
+ base_nc: 32
+ editing_net:
+ layer: 3
+ num_res_blocks: 2
+ base_nc: 64
+ common:
+ image_nc: 3
+ descriptor_nc: 256
+ max_nc: 256
+ use_spect: False
+
+
+# Data options.
+data:
+ type: data.vox_dataset::VoxDataset
+ path: ./dataset/vox_lmdb
+ resolution: 256
+ semantic_radius: 13
+ train:
+ batch_size: 5
+ distributed: True
+ val:
+ batch_size: 8
+ distributed: True
+
+
diff --git a/sadtalker_video2pose/src/config/facerender_still.yaml b/sadtalker_video2pose/src/config/facerender_still.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d6b84181763caf7184a0769e53a7e419e2e3f604
--- /dev/null
+++ b/sadtalker_video2pose/src/config/facerender_still.yaml
@@ -0,0 +1,45 @@
+model_params:
+ common_params:
+ num_kp: 15
+ image_channel: 3
+ feature_channel: 32
+ estimate_jacobian: False # True
+ kp_detector_params:
+ temperature: 0.1
+ block_expansion: 32
+ max_features: 1024
+ scale_factor: 0.25 # 0.25
+ num_blocks: 5
+ reshape_channel: 16384 # 16384 = 1024 * 16
+ reshape_depth: 16
+ he_estimator_params:
+ block_expansion: 64
+ max_features: 2048
+ num_bins: 66
+ generator_params:
+ block_expansion: 64
+ max_features: 512
+ num_down_blocks: 2
+ reshape_channel: 32
+ reshape_depth: 16 # 512 = 32 * 16
+ num_resblocks: 6
+ estimate_occlusion_map: True
+ dense_motion_params:
+ block_expansion: 32
+ max_features: 1024
+ num_blocks: 5
+ reshape_depth: 16
+ compress: 4
+ discriminator_params:
+ scales: [1]
+ block_expansion: 32
+ max_features: 512
+ num_blocks: 4
+ sn: True
+ mapping_params:
+ coeff_nc: 73
+ descriptor_nc: 1024
+ layer: 3
+ num_kp: 15
+ num_bins: 66
+
diff --git a/sadtalker_video2pose/src/config/similarity_Lm3D_all.mat b/sadtalker_video2pose/src/config/similarity_Lm3D_all.mat
new file mode 100644
index 0000000000000000000000000000000000000000..9f5b0bd4ecffb926128a29cb1bbf9d9081c3d4e7
--- /dev/null
+++ b/sadtalker_video2pose/src/config/similarity_Lm3D_all.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53b83ce6e35c50ddc3e97603650cef4970320c157e75c241c844f29c1dcba65a
+size 994
diff --git a/sadtalker_video2pose/src/face3d/data/__init__.py b/sadtalker_video2pose/src/face3d/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..be2378c5877af8e749db18d8a67a382f3eb0912b
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/data/__init__.py
@@ -0,0 +1,116 @@
+"""This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+ -- <__len__>: return the size of dataset.
+ -- <__getitem__>: get a data point from data loader.
+ -- : (optionally) add dataset-specific options and set default options.
+
+Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+See our template dataset class 'template_dataset.py' for more details.
+"""
+import numpy as np
+import importlib
+import torch.utils.data
+from face3d.data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+ """Import the module "data/[dataset_name]_dataset.py".
+
+ In the file, the class called DatasetNameDataset() will
+ be instantiated. It has to be a subclass of BaseDataset,
+ and it is case-insensitive.
+ """
+ dataset_filename = "data." + dataset_name + "_dataset"
+ datasetlib = importlib.import_module(dataset_filename)
+
+ dataset = None
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+ for name, cls in datasetlib.__dict__.items():
+ if name.lower() == target_dataset_name.lower() \
+ and issubclass(cls, BaseDataset):
+ dataset = cls
+
+ if dataset is None:
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+ return dataset
+
+
+def get_option_setter(dataset_name):
+ """Return the static method of the dataset class."""
+ dataset_class = find_dataset_using_name(dataset_name)
+ return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt, rank=0):
+ """Create a dataset given the option.
+
+ This function wraps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from data import create_dataset
+ >>> dataset = create_dataset(opt)
+ """
+ data_loader = CustomDatasetDataLoader(opt, rank=rank)
+ dataset = data_loader.load_data()
+ return dataset
+
+class CustomDatasetDataLoader():
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+ def __init__(self, opt, rank=0):
+ """Initialize this class
+
+ Step 1: create a dataset instance given the name [dataset_mode]
+ Step 2: create a multi-threaded data loader.
+ """
+ self.opt = opt
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
+ self.dataset = dataset_class(opt)
+ self.sampler = None
+ print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
+ if opt.use_ddp and opt.isTrain:
+ world_size = opt.world_size
+ self.sampler = torch.utils.data.distributed.DistributedSampler(
+ self.dataset,
+ num_replicas=world_size,
+ rank=rank,
+ shuffle=not opt.serial_batches
+ )
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ sampler=self.sampler,
+ num_workers=int(opt.num_threads / world_size),
+ batch_size=int(opt.batch_size / world_size),
+ drop_last=True)
+ else:
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batch_size,
+ shuffle=(not opt.serial_batches) and opt.isTrain,
+ num_workers=int(opt.num_threads),
+ drop_last=True
+ )
+
+ def set_epoch(self, epoch):
+ self.dataset.current_epoch = epoch
+ if self.sampler is not None:
+ self.sampler.set_epoch(epoch)
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ """Return the number of data in the dataset"""
+ return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ """Return a batch of data"""
+ for i, data in enumerate(self.dataloader):
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
+ break
+ yield data
diff --git a/sadtalker_video2pose/src/face3d/data/base_dataset.py b/sadtalker_video2pose/src/face3d/data/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..34a7ea5024206e6e58c2f404ac6a1bf0987f5fd4
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/data/base_dataset.py
@@ -0,0 +1,125 @@
+"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+"""
+import random
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+from abc import ABC, abstractmethod
+
+
+class BaseDataset(data.Dataset, ABC):
+ """This class is an abstract base class (ABC) for datasets.
+
+ To create a subclass, you need to implement the following four functions:
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+ -- <__len__>: return the size of dataset.
+ -- <__getitem__>: get a data point.
+ -- : (optionally) add dataset-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the class; save the options in the class
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ self.opt = opt
+ # self.root = opt.dataroot
+ self.current_epoch = 0
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def __len__(self):
+ """Return the total number of images in the dataset."""
+ return 0
+
+ @abstractmethod
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index - - a random integer for data indexing
+
+ Returns:
+ a dictionary of data with their names. It ususally contains the data itself and its metadata information.
+ """
+ pass
+
+
+def get_transform(grayscale=False):
+ transform_list = []
+ if grayscale:
+ transform_list.append(transforms.Grayscale(1))
+ transform_list += [transforms.ToTensor()]
+ return transforms.Compose(transform_list)
+
+def get_affine_mat(opt, size):
+ shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
+ w, h = size
+
+ if 'shift' in opt.preprocess:
+ shift_pixs = int(opt.shift_pixs)
+ shift_x = random.randint(-shift_pixs, shift_pixs)
+ shift_y = random.randint(-shift_pixs, shift_pixs)
+ if 'scale' in opt.preprocess:
+ scale = 1 + opt.scale_delta * (2 * random.random() - 1)
+ if 'rot' in opt.preprocess:
+ rot_angle = opt.rot_angle * (2 * random.random() - 1)
+ rot_rad = -rot_angle * np.pi/180
+ if 'flip' in opt.preprocess:
+ flip = random.random() > 0.5
+
+ shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
+ flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
+ shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
+ rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
+ scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
+ shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
+
+ affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
+ affine_inv = np.linalg.inv(affine)
+ return affine, affine_inv, flip
+
+def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
+ return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
+
+def apply_lm_affine(landmark, affine, flip, size):
+ _, h = size
+ lm = landmark.copy()
+ lm[:, 1] = h - 1 - lm[:, 1]
+ lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
+ lm = lm @ np.transpose(affine)
+ lm[:, :2] = lm[:, :2] / lm[:, 2:]
+ lm = lm[:, :2]
+ lm[:, 1] = h - 1 - lm[:, 1]
+ if flip:
+ lm_ = lm.copy()
+ lm_[:17] = lm[16::-1]
+ lm_[17:22] = lm[26:21:-1]
+ lm_[22:27] = lm[21:16:-1]
+ lm_[31:36] = lm[35:30:-1]
+ lm_[36:40] = lm[45:41:-1]
+ lm_[40:42] = lm[47:45:-1]
+ lm_[42:46] = lm[39:35:-1]
+ lm_[46:48] = lm[41:39:-1]
+ lm_[48:55] = lm[54:47:-1]
+ lm_[55:60] = lm[59:54:-1]
+ lm_[60:65] = lm[64:59:-1]
+ lm_[65:68] = lm[67:64:-1]
+ lm = lm_
+ return lm
diff --git a/sadtalker_video2pose/src/face3d/data/flist_dataset.py b/sadtalker_video2pose/src/face3d/data/flist_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b49caa8020f8e9aedb73a839b7112320cad68a
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/data/flist_dataset.py
@@ -0,0 +1,125 @@
+"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
+"""
+
+import os.path
+from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
+from data.image_folder import make_dataset
+from PIL import Image
+import random
+import util.util as util
+import numpy as np
+import json
+import torch
+from scipy.io import loadmat, savemat
+import pickle
+from util.preprocess import align_img, estimate_norm
+from util.load_mats import load_lm3d
+
+
+def default_flist_reader(flist):
+ """
+ flist format: impath label\nimpath label\n ...(same to caffe's filelist)
+ """
+ imlist = []
+ with open(flist, 'r') as rf:
+ for line in rf.readlines():
+ impath = line.strip()
+ imlist.append(impath)
+
+ return imlist
+
+def jason_flist_reader(flist):
+ with open(flist, 'r') as fp:
+ info = json.load(fp)
+ return info
+
+def parse_label(label):
+ return torch.tensor(np.array(label).astype(np.float32))
+
+
+class FlistDataset(BaseDataset):
+ """
+ It requires one directories to host training images '/path/to/data/train'
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
+ """
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ BaseDataset.__init__(self, opt)
+
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
+
+ msk_names = default_flist_reader(opt.flist)
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
+
+ self.size = len(self.msk_paths)
+ self.opt = opt
+
+ self.name = 'train' if opt.isTrain else 'val'
+ if '_' in opt.flist:
+ self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
+
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index (int) -- a random integer for data indexing
+
+ Returns a dictionary that contains A, B, A_paths and B_paths
+ img (tensor) -- an image in the input domain
+ msk (tensor) -- its corresponding attention mask
+ lm (tensor) -- its corresponding 3d landmarks
+ im_paths (str) -- image paths
+ aug_flag (bool) -- a flag used to tell whether its raw or augmented
+ """
+ msk_path = self.msk_paths[index % self.size] # make sure index is within then range
+ img_path = msk_path.replace('mask/', '')
+ lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
+
+ raw_img = Image.open(img_path).convert('RGB')
+ raw_msk = Image.open(msk_path).convert('RGB')
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
+
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
+
+ aug_flag = self.opt.use_aug and self.opt.isTrain
+ if aug_flag:
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
+
+ _, H = img.size
+ M = estimate_norm(lm, H)
+ transform = get_transform()
+ img_tensor = transform(img)
+ msk_tensor = transform(msk)[:1, ...]
+ lm_tensor = parse_label(lm)
+ M_tensor = parse_label(M)
+
+
+ return {'imgs': img_tensor,
+ 'lms': lm_tensor,
+ 'msks': msk_tensor,
+ 'M': M_tensor,
+ 'im_paths': img_path,
+ 'aug_flag': aug_flag,
+ 'dataset': self.name}
+
+ def _augmentation(self, img, lm, opt, msk=None):
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
+ img = apply_img_affine(img, affine_inv)
+ lm = apply_lm_affine(lm, affine, flip, img.size)
+ if msk is not None:
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
+ return img, lm, msk
+
+
+
+
+ def __len__(self):
+ """Return the total number of images in the dataset.
+ """
+ return self.size
diff --git a/sadtalker_video2pose/src/face3d/data/image_folder.py b/sadtalker_video2pose/src/face3d/data/image_folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ef069029b0db1fc40b9b5f9a6f52a48c1cd162
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/data/image_folder.py
@@ -0,0 +1,66 @@
+"""A modified image folder class
+
+We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
+so that this class can load images from both current directory and its subdirectories.
+"""
+import numpy as np
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+ '.tif', '.TIF', '.tiff', '.TIFF',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir, max_dataset_size=float("inf")):
+ images = []
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
+
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+ return images[:min(max_dataset_size, len(images))]
+
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+ def __init__(self, root, transform=None, return_paths=False,
+ loader=default_loader):
+ imgs = make_dataset(root)
+ if len(imgs) == 0:
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+
+ self.root = root
+ self.imgs = imgs
+ self.transform = transform
+ self.return_paths = return_paths
+ self.loader = loader
+
+ def __getitem__(self, index):
+ path = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.return_paths:
+ return img, path
+ else:
+ return img
+
+ def __len__(self):
+ return len(self.imgs)
diff --git a/sadtalker_video2pose/src/face3d/data/template_dataset.py b/sadtalker_video2pose/src/face3d/data/template_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..693b6b09085ad424e53f26e0938b61eea30ed644
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/data/template_dataset.py
@@ -0,0 +1,75 @@
+"""Dataset class template
+
+This module provides a template for users to implement custom datasets.
+You can specify '--dataset_mode template' to use this dataset.
+The class name should be consistent with both the filename and its dataset_mode option.
+The filename should be _dataset.py
+The class name should be Dataset.py
+You need to implement the following functions:
+ -- : Add dataset-specific options and rewrite default values for existing options.
+ -- <__init__>: Initialize this dataset class.
+ -- <__getitem__>: Return a data point and its metadata information.
+ -- <__len__>: Return the number of images.
+"""
+from data.base_dataset import BaseDataset, get_transform
+# from data.image_folder import make_dataset
+# from PIL import Image
+
+
+class TemplateDataset(BaseDataset):
+ """A template dataset class for you to implement custom datasets."""
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ A few things can be done here.
+ - save the options (have been done in BaseDataset)
+ - get image paths and meta information of the dataset.
+ - define the image transformation.
+ """
+ # save the option and dataset root
+ BaseDataset.__init__(self, opt)
+ # get the image paths of your dataset;
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
+ # define the default transform function. You can use ; You can also define your custom transform function
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index -- a random integer for data indexing
+
+ Returns:
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
+
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
+ Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
+ Step 4: return a data point as a dictionary.
+ """
+ path = 'temp' # needs to be a string
+ data_A = None # needs to be a tensor
+ data_B = None # needs to be a tensor
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
+
+ def __len__(self):
+ """Return the total number of images."""
+ return len(self.image_paths)
diff --git a/sadtalker_video2pose/src/face3d/extract_kp_videos.py b/sadtalker_video2pose/src/face3d/extract_kp_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..68dd79badafd406113ee85cde83492b6c7c66a9b
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/extract_kp_videos.py
@@ -0,0 +1,108 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import face_alignment
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from itertools import cycle
+
+from torch.multiprocessing import Pool, Process, set_start_method
+
+class KeypointExtractor():
+ def __init__(self, device):
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
+ device=device)
+
+ def extract_keypoint(self, images, name=None, info=True):
+ if isinstance(images, list):
+ keypoints = []
+ if info:
+ i_range = tqdm(images,desc='landmark Det:')
+ else:
+ i_range = images
+
+ for image in i_range:
+ current_kp = self.extract_keypoint(image)
+ if np.mean(current_kp) == -1 and keypoints:
+ keypoints.append(keypoints[-1])
+ else:
+ keypoints.append(current_kp[None])
+
+ keypoints = np.concatenate(keypoints, 0)
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+ else:
+ while True:
+ try:
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
+ break
+ except RuntimeError as e:
+ if str(e).startswith('CUDA'):
+ print("Warning: out of memory, sleep for 1s")
+ time.sleep(1)
+ else:
+ print(e)
+ break
+ except TypeError:
+ print('No face detected in this image')
+ shape = [68, 2]
+ keypoints = -1. * np.ones(shape)
+ break
+ if name is not None:
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+
+def read_video(filename):
+ frames = []
+ cap = cv2.VideoCapture(filename)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frame = Image.fromarray(frame)
+ frames.append(frame)
+ else:
+ break
+ cap.release()
+ return frames
+
+def run(data):
+ filename, opt, device = data
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
+ kp_extractor = KeypointExtractor()
+ images = read_video(filename)
+ name = filename.split('/')[-2:]
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+ kp_extractor.extract_keypoint(
+ images,
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
+ )
+
+if __name__ == '__main__':
+ set_start_method('spawn')
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+ parser.add_argument('--device_ids', type=str, default='0,1')
+ parser.add_argument('--workers', type=int, default=4)
+
+ opt = parser.parse_args()
+ filenames = list()
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+ extensions = VIDEO_EXTENSIONS
+
+ for ext in extensions:
+ os.listdir(f'{opt.input_dir}')
+ print(f'{opt.input_dir}/*.{ext}')
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
+ print('Total number of videos:', len(filenames))
+ pool = Pool(opt.workers)
+ args_list = cycle([opt])
+ device_ids = opt.device_ids.split(",")
+ device_ids = cycle(device_ids)
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+ None
diff --git a/sadtalker_video2pose/src/face3d/extract_kp_videos_safe.py b/sadtalker_video2pose/src/face3d/extract_kp_videos_safe.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bc08ec9b5b48d7d7ecb53a018d24065461a4347
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/extract_kp_videos_safe.py
@@ -0,0 +1,145 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import numpy as np
+from PIL import Image
+import torch
+from tqdm import tqdm
+from itertools import cycle
+from torch.multiprocessing import Pool, Process, set_start_method
+
+from facexlib.alignment import landmark_98_to_68
+from facexlib.detection import init_detection_model
+
+from facexlib.utils import load_file_from_url
+from facexlib.alignment.awing_arch import FAN
+
+def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
+ if model_name == 'awing_fan':
+ model = FAN(num_modules=4, num_landmarks=98, device=device)
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(
+ url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
+ model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
+ model.eval()
+ model = model.to(device)
+ return model
+
+
+class KeypointExtractor():
+ def __init__(self, device='cuda'):
+
+ root_path = './ckpts/gfpgan'
+
+ self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
+ self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
+
+ def extract_keypoint(self, images, name=None, info=True):
+ if isinstance(images, list):
+ keypoints = []
+ if info:
+ i_range = tqdm(images,desc='landmark Det:')
+ else:
+ i_range = images
+
+ for image in i_range:
+ current_kp = self.extract_keypoint(image)
+ # current_kp = self.detector.get_landmarks(np.array(image))
+ if np.mean(current_kp) == -1 and keypoints:
+ keypoints.append(keypoints[-1])
+ else:
+ keypoints.append(current_kp[None])
+
+ keypoints = np.concatenate(keypoints, 0)
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+ else:
+ while True:
+ try:
+ with torch.no_grad():
+ # face detection -> face alignment.
+ img = np.array(images)
+ bboxes = self.det_net.detect_faces(images, 0.97)
+
+ bboxes = bboxes[0]
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
+
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
+
+ #### keypoints to the original location
+ keypoints[:,0] += int(bboxes[0])
+ keypoints[:,1] += int(bboxes[1])
+
+ break
+ except RuntimeError as e:
+ if str(e).startswith('CUDA'):
+ print("Warning: out of memory, sleep for 1s")
+ time.sleep(1)
+ else:
+ print(e)
+ break
+ except TypeError:
+ print('No face detected in this image')
+ shape = [68, 2]
+ keypoints = -1. * np.ones(shape)
+ break
+ if name is not None:
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+
+def read_video(filename):
+ frames = []
+ cap = cv2.VideoCapture(filename)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frame = Image.fromarray(frame)
+ frames.append(frame)
+ else:
+ break
+ cap.release()
+ return frames
+
+def run(data):
+ filename, opt, device = data
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
+ kp_extractor = KeypointExtractor()
+ images = read_video(filename)
+ name = filename.split('/')[-2:]
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+ kp_extractor.extract_keypoint(
+ images,
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
+ )
+
+if __name__ == '__main__':
+ set_start_method('spawn')
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+ parser.add_argument('--device_ids', type=str, default='0,1')
+ parser.add_argument('--workers', type=int, default=4)
+
+ opt = parser.parse_args()
+ filenames = list()
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+ extensions = VIDEO_EXTENSIONS
+
+ for ext in extensions:
+ os.listdir(f'{opt.input_dir}')
+ print(f'{opt.input_dir}/*.{ext}')
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
+ print('Total number of videos:', len(filenames))
+ pool = Pool(opt.workers)
+ args_list = cycle([opt])
+ device_ids = opt.device_ids.split(",")
+ device_ids = cycle(device_ids)
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+ None
diff --git a/sadtalker_video2pose/src/face3d/models/__init__.py b/sadtalker_video2pose/src/face3d/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef6b5e399254bd42850f3385878f35d4acf90852
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- : unpack data from dataset and apply preprocessing.
+ -- : produce intermediate results.
+ -- : calculate loss, gradients, and update network weights.
+ -- : (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from src.face3d.models.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+ """Import the module "models/[model_name]_model.py".
+
+ In the file, the class called DatasetNameModel() will
+ be instantiated. It has to be a subclass of BaseModel,
+ and it is case-insensitive.
+ """
+ model_filename = "face3d.models." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+ model = None
+ target_model_name = model_name.replace('_', '') + 'model'
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() \
+ and issubclass(cls, BaseModel):
+ model = cls
+
+ if model is None:
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+ exit(0)
+
+ return model
+
+
+def get_option_setter(model_name):
+ """Return the static method of the model class."""
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+def create_model(opt):
+ """Create a model given the option.
+
+ This function warps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from models import create_model
+ >>> model = create_model(opt)
+ """
+ model = find_model_using_name(opt.model)
+ instance = model(opt)
+ print("model [%s] was created" % type(instance).__name__)
+ return instance
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/README.md b/sadtalker_video2pose/src/face3d/models/arcface_torch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cc7f1d45f2f5e4b752c42dc81d3e2879c1459c6e
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/README.md
@@ -0,0 +1,164 @@
+# Distributed Arcface Training in Pytorch
+
+This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions
+identity on a single server.
+
+## Requirements
+
+- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
+- `pip install -r requirements.txt`.
+- Download the dataset
+ from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
+ .
+
+## How to Training
+
+To train a model, run `train.py` with the path to the configs:
+
+### 1. Single node, 8 GPUs:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
+```
+
+### 2. Multiple nodes, each node 8 GPUs:
+
+Node 0:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
+```
+
+Node 1:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
+```
+
+### 3.Training resnet2060 with 8 GPUs:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
+```
+
+## Model Zoo
+
+- The models are available for non-commercial research purposes only.
+- All models can be found in here.
+- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
+- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
+
+### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
+
+ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face
+recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities.
+As the result, we can evaluate the FAIR performance for different algorithms.
+
+For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The
+globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
+
+For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4).
+Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
+There are totally 13,928 positive pairs and 96,983,824 negative pairs.
+
+| Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
+| :---: | :--- | :--- | :--- |:--- |:--- |
+| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
+| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
+| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
+| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
+| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
+| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
+| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
+| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
+| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
+| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
+
+### Performance on IJB-C and Verification Datasets
+
+| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
+| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
+| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
+| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
+| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
+| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
+| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
+| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
+| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
+| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
+| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
+
+[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.)
+
+
+## [Speed Benchmark](docs/speed_benchmark.md)
+
+**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of
+classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same
+accuracy with several times faster training performance and smaller GPU memory.
+Partial FC is a sparse variant of the model parallel architecture for large sacle face recognition. Partial FC use a
+sparse softmax, where each batch dynamicly sample a subset of class centers for training. In each iteration, only a
+sparse part of the parameters will be updated, which can reduce a lot of GPU memory and calculations. With Partial FC,
+we can scale trainset of 29 millions identities, the largest to date. Partial FC also supports multi-machine distributed
+training and mixed precision training.
+
+![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
+
+More details see
+[speed_benchmark.md](docs/speed_benchmark.md) in docs.
+
+### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
+
+`-` means training failed because of gpu memory limitations.
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|1400000 | **1672** | 3043 | 4738 |
+|5500000 | **-** | **1389** | 3975 |
+|8000000 | **-** | **-** | 3565 |
+|16000000 | **-** | **-** | 2679 |
+|29000000 | **-** | **-** | **1855** |
+
+### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 7358 | 5306 | 4868 |
+|1400000 | 32252 | 11178 | 6056 |
+|5500000 | **-** | 32188 | 9854 |
+|8000000 | **-** | **-** | 12310 |
+|16000000 | **-** | **-** | 19950 |
+|29000000 | **-** | **-** | 32324 |
+
+## Evaluation ICCV2021-MFR and IJB-C
+
+More details see [eval.md](docs/eval.md) in docs.
+
+## Test
+
+We tested many versions of PyTorch. Please create an issue if you are having trouble.
+
+- [x] torch 1.6.0
+- [x] torch 1.7.1
+- [x] torch 1.8.0
+- [x] torch 1.9.0
+
+## Citation
+
+```
+@inproceedings{deng2019arcface,
+ title={Arcface: Additive angular margin loss for deep face recognition},
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={4690--4699},
+ year={2019}
+}
+@inproceedings{an2020partical_fc,
+ title={Partial FC: Training 10 Million Identities on a Single Machine},
+ author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
+ Zhang, Debing and Fu Ying},
+ booktitle={Arxiv 2010.05222},
+ year={2020}
+}
+```
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/__init__.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5650187b4fdea84c5a23e0445440901690ab682a
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/__init__.py
@@ -0,0 +1,25 @@
+from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
+from .mobilefacenet import get_mbf
+
+
+def get_model(name, **kwargs):
+ # resnet
+ if name == "r18":
+ return iresnet18(False, **kwargs)
+ elif name == "r34":
+ return iresnet34(False, **kwargs)
+ elif name == "r50":
+ return iresnet50(False, **kwargs)
+ elif name == "r100":
+ return iresnet100(False, **kwargs)
+ elif name == "r200":
+ return iresnet200(False, **kwargs)
+ elif name == "r2060":
+ from .iresnet2060 import iresnet2060
+ return iresnet2060(False, **kwargs)
+ elif name == "mbf":
+ fp16 = kwargs.get("fp16", False)
+ num_features = kwargs.get("num_features", 512)
+ return get_mbf(fp16=fp16, num_features=num_features)
+ else:
+ raise ValueError()
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29f5f2bfbd444273717c4bc8aa20ba7edd08f80
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet.py
@@ -0,0 +1,187 @@
+import torch
+from torch import nn
+
+__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+ def __init__(self,
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+ super(IResNet, self).__init__()
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block,
+ 128,
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block,
+ 256,
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block,
+ 512,
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+ )
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet18(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
+ progress, **kwargs)
+
+
+def iresnet34(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet50(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet100(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet200(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
+ progress, **kwargs)
+
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..39bb4335716b653bd5924e20d616d825ef48339f
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py
@@ -0,0 +1,176 @@
+import torch
+from torch import nn
+
+assert torch.__version__ >= "1.8.1"
+from torch.utils.checkpoint import checkpoint_sequential
+
+__all__ = ['iresnet2060']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+
+ def __init__(self,
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+ super(IResNet, self).__init__()
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block,
+ 128,
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block,
+ 256,
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block,
+ 512,
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+ )
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation))
+
+ return nn.Sequential(*layers)
+
+ def checkpoint(self, func, num_seg, x):
+ if self.training:
+ return checkpoint_sequential(func, num_seg, x)
+ else:
+ return func(x)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.checkpoint(self.layer2, 20, x)
+ x = self.checkpoint(self.layer3, 100, x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet2060(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02c6c1e4fa6a6ddf09f5b01dec96971427cb110
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py
@@ -0,0 +1,130 @@
+'''
+Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
+Original author cavalleria
+'''
+
+import torch.nn as nn
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
+import torch
+
+
+class Flatten(Module):
+ def forward(self, x):
+ return x.view(x.size(0), -1)
+
+
+class ConvBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(ConvBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
+ BatchNorm2d(num_features=out_c),
+ PReLU(num_parameters=out_c)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class LinearBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(LinearBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
+ BatchNorm2d(num_features=out_c)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class DepthWise(Module):
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
+ super(DepthWise, self).__init__()
+ self.residual = residual
+ self.layers = nn.Sequential(
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
+ )
+
+ def forward(self, x):
+ short_cut = None
+ if self.residual:
+ short_cut = x
+ x = self.layers(x)
+ if self.residual:
+ output = short_cut + x
+ else:
+ output = x
+ return output
+
+
+class Residual(Module):
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
+ super(Residual, self).__init__()
+ modules = []
+ for _ in range(num_block):
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
+ self.layers = Sequential(*modules)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class GDC(Module):
+ def __init__(self, embedding_size):
+ super(GDC, self).__init__()
+ self.layers = nn.Sequential(
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
+ Flatten(),
+ Linear(512, embedding_size, bias=False),
+ BatchNorm1d(embedding_size))
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class MobileFaceNet(Module):
+ def __init__(self, fp16=False, num_features=512):
+ super(MobileFaceNet, self).__init__()
+ scale = 2
+ self.fp16 = fp16
+ self.layers = nn.Sequential(
+ ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
+ ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
+ DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
+ Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
+ Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
+ Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ )
+ self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
+ self.features = GDC(num_features)
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.layers(x)
+ x = self.conv_sep(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def get_mbf(fp16, num_features):
+ return MobileFaceNet(fp16, num_features)
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bee7cb4236e8b842a1bd1e8c26de7a11df0bf43
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 300 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf7df5f04e2509e5dcc14adebbb9302a18f03f2b
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 300 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/__init__.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/base.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f98c62fed44afde276dcbacecd9da0a8f474963c
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/base.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = "ms1mv3_arcface_r50"
+
+config.dataset = "ms1m-retinaface-t1"
+config.embedding_size = 512
+config.sample_rate = 1
+config.fp16 = False
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+if config.dataset == "emore":
+ config.rec = "/train_tmp/faces_emore"
+ config.num_classes = 85742
+ config.num_image = 5822653
+ config.num_epoch = 16
+ config.warmup_epoch = -1
+ config.decay_epoch = [8, 14, ]
+ config.val_targets = ["lfw", ]
+
+elif config.dataset == "ms1m-retinaface-t1":
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
+ config.num_classes = 93431
+ config.num_image = 5179510
+ config.num_epoch = 25
+ config.warmup_epoch = -1
+ config.decay_epoch = [11, 17, 22]
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
+
+elif config.dataset == "glint360k":
+ config.rec = "/train_tmp/glint360k"
+ config.num_classes = 360232
+ config.num_image = 17091657
+ config.num_epoch = 20
+ config.warmup_epoch = -1
+ config.decay_epoch = [8, 12, 15, 18]
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
+
+elif config.dataset == "webface":
+ config.rec = "/train_tmp/faces_webface_112x112"
+ config.num_classes = 10572
+ config.num_image = "forget"
+ config.num_epoch = 34
+ config.warmup_epoch = -1
+ config.decay_epoch = [20, 28, 32]
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ee5e8d96249d57196df43418f6fda4ab339877
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 2e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8f8ef745c0efb9d5ea67409edc8c904def8a9d9
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..473b59a954fffcaddca132fb6e0f32cbe70c70f4
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c22ff0c82cc98bbbe81c9a1c26c9b3fc186105
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r34"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ecbfda06730e3842e7b347db366e82f0714912f
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..47c87a99867db55c7f689574c331c14cda23ea96
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 2e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 20, 25]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aeb851b05ea22e01da87b3d387812f0253989f8
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..8693e67080dac7e7b84da08a62df326c7b12d465
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r2060"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 64
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py
new file mode 100644
index 0000000000000000000000000000000000000000..52bff483db179045c0e3acc8e2975477182b0756
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r34"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..de81ffdd84edd6fcea7fcb4d3594db031b9e4e26
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/speed.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/speed.py
new file mode 100644
index 0000000000000000000000000000000000000000..c172f9d44d39b534f2253630471e91cf78e6fba7
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/speed.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 100 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/dataset.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bead250243237c650fa3138f6aa172d4f98535f
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/dataset.py
@@ -0,0 +1,124 @@
+import numbers
+import os
+import queue as Queue
+import threading
+
+import mxnet as mx
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+
+class BackgroundGenerator(threading.Thread):
+ def __init__(self, generator, local_rank, max_prefetch=6):
+ super(BackgroundGenerator, self).__init__()
+ self.queue = Queue.Queue(max_prefetch)
+ self.generator = generator
+ self.local_rank = local_rank
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ torch.cuda.set_device(self.local_rank)
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def next(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __next__(self):
+ return self.next()
+
+ def __iter__(self):
+ return self
+
+
+class DataLoaderX(DataLoader):
+
+ def __init__(self, local_rank, **kwargs):
+ super(DataLoaderX, self).__init__(**kwargs)
+ self.stream = torch.cuda.Stream(local_rank)
+ self.local_rank = local_rank
+
+ def __iter__(self):
+ self.iter = super(DataLoaderX, self).__iter__()
+ self.iter = BackgroundGenerator(self.iter, self.local_rank)
+ self.preload()
+ return self
+
+ def preload(self):
+ self.batch = next(self.iter, None)
+ if self.batch is None:
+ return None
+ with torch.cuda.stream(self.stream):
+ for k in range(len(self.batch)):
+ self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+
+ def __next__(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ if batch is None:
+ raise StopIteration
+ self.preload()
+ return batch
+
+
+class MXFaceDataset(Dataset):
+ def __init__(self, root_dir, local_rank):
+ super(MXFaceDataset, self).__init__()
+ self.transform = transforms.Compose(
+ [transforms.ToPILImage(),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ])
+ self.root_dir = root_dir
+ self.local_rank = local_rank
+ path_imgrec = os.path.join(root_dir, 'train.rec')
+ path_imgidx = os.path.join(root_dir, 'train.idx')
+ self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
+ s = self.imgrec.read_idx(0)
+ header, _ = mx.recordio.unpack(s)
+ if header.flag > 0:
+ self.header0 = (int(header.label[0]), int(header.label[1]))
+ self.imgidx = np.array(range(1, int(header.label[0])))
+ else:
+ self.imgidx = np.array(list(self.imgrec.keys))
+
+ def __getitem__(self, index):
+ idx = self.imgidx[index]
+ s = self.imgrec.read_idx(idx)
+ header, img = mx.recordio.unpack(s)
+ label = header.label
+ if not isinstance(label, numbers.Number):
+ label = label[0]
+ label = torch.tensor(label, dtype=torch.long)
+ sample = mx.image.imdecode(img).asnumpy()
+ if self.transform is not None:
+ sample = self.transform(sample)
+ return sample, label
+
+ def __len__(self):
+ return len(self.imgidx)
+
+
+class SyntheticDataset(Dataset):
+ def __init__(self, local_rank):
+ super(SyntheticDataset, self).__init__()
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).squeeze(0).float()
+ img = ((img / 255) - 0.5) / 0.5
+ self.img = img
+ self.label = 1
+
+ def __getitem__(self, index):
+ return self.img, self.label
+
+ def __len__(self):
+ return 1000000
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/eval.md b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/eval.md
new file mode 100644
index 0000000000000000000000000000000000000000..4d29c855fc6e4245ed264216c1f96ab2efc57248
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/eval.md
@@ -0,0 +1,31 @@
+## Eval on ICCV2021-MFR
+
+coming soon.
+
+
+## Eval IJBC
+You can eval ijbc with pytorch or onnx.
+
+
+1. Eval IJBC With Onnx
+```shell
+CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
+```
+
+2. Eval IJBC With Pytorch
+```shell
+CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
+--model-prefix ms1mv3_arcface_r50/backbone.pth \
+--image-path IJB_release/IJBC \
+--result-dir ms1mv3_arcface_r50 \
+--batch-size 128 \
+--job ms1mv3_arcface_r50 \
+--target IJBC \
+--network iresnet50
+```
+
+## Inference
+
+```shell
+python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
+```
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/install.md b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..b1b770a0d93dac1f160185b5bbf4da2f414f21f6
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/install.md
@@ -0,0 +1,51 @@
+## v1.8.0
+### Linux and Windows
+```shell
+# CUDA 11.0
+pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 10.2
+pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
+
+# CPU only
+pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+```
+
+
+## v1.7.1
+### Linux and Windows
+```shell
+# CUDA 11.0
+pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 10.2
+pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
+
+# CUDA 10.1
+pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 9.2
+pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU only
+pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+```
+
+
+## v1.6.0
+
+### Linux and Windows
+```shell
+# CUDA 10.2
+pip install torch==1.6.0 torchvision==0.7.0
+
+# CUDA 10.1
+pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 9.2
+pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU only
+pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+```
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/modelzoo.md b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/modelzoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md
new file mode 100644
index 0000000000000000000000000000000000000000..d54904587df4e13784dc68d5709b4d7d97490890
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md
@@ -0,0 +1,93 @@
+## Test Training Speed
+
+- Test Commands
+
+You need to use the following two commands to test the Partial FC training performance.
+The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50,
+batch size is 1024.
+```shell
+# Model Parallel
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
+# Partial FC 0.1
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
+```
+
+- GPU Memory
+
+```
+# (Model Parallel) gpustat -i
+[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
+[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
+[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
+[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
+[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
+[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
+[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
+[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
+
+# (Partial FC 0.1) gpustat -i
+[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │·······················
+[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │·······················
+[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │·······················
+[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │·······················
+[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │·······················
+[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │·······················
+[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │·······················
+[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │·······················
+```
+
+- Training Speed
+
+```python
+# (Model Parallel) trainging.log
+Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+
+# (Partial FC 0.1) trainging.log
+Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+```
+
+In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel,
+and the training speed is 2.5 times faster than the model parallel.
+
+
+## Speed Benchmark
+
+1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|250000 | 4047 | 4521 | 4976 |
+|500000 | 3087 | 4013 | 4900 |
+|1000000 | 2090 | 3449 | 4803 |
+|1400000 | 1672 | 3043 | 4738 |
+|2000000 | - | 2593 | 4626 |
+|4000000 | - | 1748 | 4208 |
+|5500000 | - | 1389 | 3975 |
+|8000000 | - | - | 3565 |
+|16000000 | - | - | 2679 |
+|29000000 | - | - | 1855 |
+
+2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 7358 | 5306 | 4868 |
+|250000 | 9940 | 5826 | 5004 |
+|500000 | 14220 | 7114 | 5202 |
+|1000000 | 23708 | 9966 | 5620 |
+|1400000 | 32252 | 11178 | 6056 |
+|2000000 | - | 13978 | 6472 |
+|4000000 | - | 23238 | 8284 |
+|5500000 | - | 32188 | 9854 |
+|8000000 | - | - | 12310 |
+|16000000 | - | - | 19950 |
+|29000000 | - | - | 32324 |
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/eval/__init__.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/eval/verification.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/eval/verification.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b1f5618184effae64895847af1a65d43d2e4418
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/eval/verification.py
@@ -0,0 +1,407 @@
+"""Helper for evaluation on the Labeled Faces in the Wild dataset
+"""
+
+# MIT License
+#
+# Copyright (c) 2016 David Sandberg
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import datetime
+import os
+import pickle
+
+import mxnet as mx
+import numpy as np
+import sklearn
+import torch
+from mxnet import ndarray as nd
+from scipy import interpolate
+from sklearn.decomposition import PCA
+from sklearn.model_selection import KFold
+
+
+class LFold:
+ def __init__(self, n_splits=2, shuffle=False):
+ self.n_splits = n_splits
+ if self.n_splits > 1:
+ self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
+
+ def split(self, indices):
+ if self.n_splits > 1:
+ return self.k_fold.split(indices)
+ else:
+ return [(indices, indices)]
+
+
+def calculate_roc(thresholds,
+ embeddings1,
+ embeddings2,
+ actual_issame,
+ nrof_folds=10,
+ pca=0):
+ assert (embeddings1.shape[0] == embeddings2.shape[0])
+ assert (embeddings1.shape[1] == embeddings2.shape[1])
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ tprs = np.zeros((nrof_folds, nrof_thresholds))
+ fprs = np.zeros((nrof_folds, nrof_thresholds))
+ accuracy = np.zeros((nrof_folds))
+ indices = np.arange(nrof_pairs)
+
+ if pca == 0:
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+ if pca > 0:
+ print('doing pca on', fold_idx)
+ embed1_train = embeddings1[train_set]
+ embed2_train = embeddings2[train_set]
+ _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
+ pca_model = PCA(n_components=pca)
+ pca_model.fit(_embed_train)
+ embed1 = pca_model.transform(embeddings1)
+ embed2 = pca_model.transform(embeddings2)
+ embed1 = sklearn.preprocessing.normalize(embed1)
+ embed2 = sklearn.preprocessing.normalize(embed2)
+ diff = np.subtract(embed1, embed2)
+ dist = np.sum(np.square(diff), 1)
+
+ # Find the best threshold for the fold
+ acc_train = np.zeros((nrof_thresholds))
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, _, acc_train[threshold_idx] = calculate_accuracy(
+ threshold, dist[train_set], actual_issame[train_set])
+ best_threshold_index = np.argmax(acc_train)
+ for threshold_idx, threshold in enumerate(thresholds):
+ tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
+ threshold, dist[test_set],
+ actual_issame[test_set])
+ _, _, accuracy[fold_idx] = calculate_accuracy(
+ thresholds[best_threshold_index], dist[test_set],
+ actual_issame[test_set])
+
+ tpr = np.mean(tprs, 0)
+ fpr = np.mean(fprs, 0)
+ return tpr, fpr, accuracy
+
+
+def calculate_accuracy(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ tp = np.sum(np.logical_and(predict_issame, actual_issame))
+ fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ tn = np.sum(
+ np.logical_and(np.logical_not(predict_issame),
+ np.logical_not(actual_issame)))
+ fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
+
+ tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
+ fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
+ acc = float(tp + tn) / dist.size
+ return tpr, fpr, acc
+
+
+def calculate_val(thresholds,
+ embeddings1,
+ embeddings2,
+ actual_issame,
+ far_target,
+ nrof_folds=10):
+ assert (embeddings1.shape[0] == embeddings2.shape[0])
+ assert (embeddings1.shape[1] == embeddings2.shape[1])
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ val = np.zeros(nrof_folds)
+ far = np.zeros(nrof_folds)
+
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+ indices = np.arange(nrof_pairs)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+
+ # Find the threshold that gives FAR = far_target
+ far_train = np.zeros(nrof_thresholds)
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, far_train[threshold_idx] = calculate_val_far(
+ threshold, dist[train_set], actual_issame[train_set])
+ if np.max(far_train) >= far_target:
+ f = interpolate.interp1d(far_train, thresholds, kind='slinear')
+ threshold = f(far_target)
+ else:
+ threshold = 0.0
+
+ val[fold_idx], far[fold_idx] = calculate_val_far(
+ threshold, dist[test_set], actual_issame[test_set])
+
+ val_mean = np.mean(val)
+ far_mean = np.mean(far)
+ val_std = np.std(val)
+ return val_mean, val_std, far_mean
+
+
+def calculate_val_far(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
+ false_accept = np.sum(
+ np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ n_same = np.sum(actual_issame)
+ n_diff = np.sum(np.logical_not(actual_issame))
+ # print(true_accept, false_accept)
+ # print(n_same, n_diff)
+ val = float(true_accept) / float(n_same)
+ far = float(false_accept) / float(n_diff)
+ return val, far
+
+
+def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
+ # Calculate evaluation metrics
+ thresholds = np.arange(0, 4, 0.01)
+ embeddings1 = embeddings[0::2]
+ embeddings2 = embeddings[1::2]
+ tpr, fpr, accuracy = calculate_roc(thresholds,
+ embeddings1,
+ embeddings2,
+ np.asarray(actual_issame),
+ nrof_folds=nrof_folds,
+ pca=pca)
+ thresholds = np.arange(0, 4, 0.001)
+ val, val_std, far = calculate_val(thresholds,
+ embeddings1,
+ embeddings2,
+ np.asarray(actual_issame),
+ 1e-3,
+ nrof_folds=nrof_folds)
+ return tpr, fpr, accuracy, val, val_std, far
+
+@torch.no_grad()
+def load_bin(path, image_size):
+ try:
+ with open(path, 'rb') as f:
+ bins, issame_list = pickle.load(f) # py2
+ except UnicodeDecodeError as e:
+ with open(path, 'rb') as f:
+ bins, issame_list = pickle.load(f, encoding='bytes') # py3
+ data_list = []
+ for flip in [0, 1]:
+ data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
+ data_list.append(data)
+ for idx in range(len(issame_list) * 2):
+ _bin = bins[idx]
+ img = mx.image.imdecode(_bin)
+ if img.shape[1] != image_size[0]:
+ img = mx.image.resize_short(img, image_size[0])
+ img = nd.transpose(img, axes=(2, 0, 1))
+ for flip in [0, 1]:
+ if flip == 1:
+ img = mx.ndarray.flip(data=img, axis=2)
+ data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
+ if idx % 1000 == 0:
+ print('loading bin', idx)
+ print(data_list[0].shape)
+ return data_list, issame_list
+
+@torch.no_grad()
+def test(data_set, backbone, batch_size, nfolds=10):
+ print('testing verification..')
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+ _data = data[bb - batch_size: bb]
+ time0 = datetime.datetime.now()
+ img = ((_data / 255) - 0.5) / 0.5
+ net_out: torch.Tensor = backbone(img)
+ _embeddings = net_out.detach().cpu().numpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+
+ _xnorm = 0.0
+ _xnorm_cnt = 0
+ for embed in embeddings_list:
+ for i in range(embed.shape[0]):
+ _em = embed[i]
+ _norm = np.linalg.norm(_em)
+ _xnorm += _norm
+ _xnorm_cnt += 1
+ _xnorm /= _xnorm_cnt
+
+ acc1 = 0.0
+ std1 = 0.0
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ print(embeddings.shape)
+ print('infer time', time_consumed)
+ _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
+ acc2, std2 = np.mean(accuracy), np.std(accuracy)
+ return acc1, std1, acc2, std2, _xnorm, embeddings_list
+
+
+def dumpR(data_set,
+ backbone,
+ batch_size,
+ name='',
+ data_extra=None,
+ label_shape=None):
+ print('dump verification embedding..')
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+
+ _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
+ time0 = datetime.datetime.now()
+ if data_extra is None:
+ db = mx.io.DataBatch(data=(_data,), label=(_label,))
+ else:
+ db = mx.io.DataBatch(data=(_data, _data_extra),
+ label=(_label,))
+ model.forward(db, is_train=False)
+ net_out = model.get_outputs()
+ _embeddings = net_out[0].asnumpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ actual_issame = np.asarray(issame_list)
+ outname = os.path.join('temp.bin')
+ with open(outname, 'wb') as f:
+ pickle.dump((embeddings, issame_list),
+ f,
+ protocol=pickle.HIGHEST_PROTOCOL)
+
+
+# if __name__ == '__main__':
+#
+# parser = argparse.ArgumentParser(description='do verification')
+# # general
+# parser.add_argument('--data-dir', default='', help='')
+# parser.add_argument('--model',
+# default='../model/softmax,50',
+# help='path to load model.')
+# parser.add_argument('--target',
+# default='lfw,cfp_ff,cfp_fp,agedb_30',
+# help='test targets.')
+# parser.add_argument('--gpu', default=0, type=int, help='gpu id')
+# parser.add_argument('--batch-size', default=32, type=int, help='')
+# parser.add_argument('--max', default='', type=str, help='')
+# parser.add_argument('--mode', default=0, type=int, help='')
+# parser.add_argument('--nfolds', default=10, type=int, help='')
+# args = parser.parse_args()
+# image_size = [112, 112]
+# print('image_size', image_size)
+# ctx = mx.gpu(args.gpu)
+# nets = []
+# vec = args.model.split(',')
+# prefix = args.model.split(',')[0]
+# epochs = []
+# if len(vec) == 1:
+# pdir = os.path.dirname(prefix)
+# for fname in os.listdir(pdir):
+# if not fname.endswith('.params'):
+# continue
+# _file = os.path.join(pdir, fname)
+# if _file.startswith(prefix):
+# epoch = int(fname.split('.')[0].split('-')[1])
+# epochs.append(epoch)
+# epochs = sorted(epochs, reverse=True)
+# if len(args.max) > 0:
+# _max = [int(x) for x in args.max.split(',')]
+# assert len(_max) == 2
+# if len(epochs) > _max[1]:
+# epochs = epochs[_max[0]:_max[1]]
+#
+# else:
+# epochs = [int(x) for x in vec[1].split('|')]
+# print('model number', len(epochs))
+# time0 = datetime.datetime.now()
+# for epoch in epochs:
+# print('loading', prefix, epoch)
+# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
+# all_layers = sym.get_internals()
+# sym = all_layers['fc1_output']
+# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
+# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
+# image_size[1]))])
+# model.set_params(arg_params, aux_params)
+# nets.append(model)
+# time_now = datetime.datetime.now()
+# diff = time_now - time0
+# print('model loading time', diff.total_seconds())
+#
+# ver_list = []
+# ver_name_list = []
+# for name in args.target.split(','):
+# path = os.path.join(args.data_dir, name + ".bin")
+# if os.path.exists(path):
+# print('loading.. ', name)
+# data_set = load_bin(path, image_size)
+# ver_list.append(data_set)
+# ver_name_list.append(name)
+#
+# if args.mode == 0:
+# for i in range(len(ver_list)):
+# results = []
+# for model in nets:
+# acc1, std1, acc2, std2, xnorm, embeddings_list = test(
+# ver_list[i], model, args.batch_size, args.nfolds)
+# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
+# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
+# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
+# results.append(acc2)
+# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
+# elif args.mode == 1:
+# raise ValueError
+# else:
+# model = nets[0]
+# dumpR(ver_list[0], model, args.batch_size, args.target)
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/eval_ijbc.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/eval_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..64844c4723a88b4b160d2fee9a7b626b987981d9
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/eval_ijbc.py
@@ -0,0 +1,483 @@
+# coding: utf-8
+
+import os
+import pickle
+
+import matplotlib
+import pandas as pd
+
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import timeit
+import sklearn
+import argparse
+import cv2
+import numpy as np
+import torch
+from skimage import transform as trans
+from backbones import get_model
+from sklearn.metrics import roc_curve, auc
+
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from pathlib import Path
+
+import sys
+import warnings
+
+sys.path.insert(0, "../")
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser(description='do ijb test')
+# general
+parser.add_argument('--model-prefix', default='', help='path to load model.')
+parser.add_argument('--image-path', default='', type=str, help='')
+parser.add_argument('--result-dir', default='.', type=str, help='')
+parser.add_argument('--batch-size', default=128, type=int, help='')
+parser.add_argument('--network', default='iresnet50', type=str, help='')
+parser.add_argument('--job', default='insightface', type=str, help='job name')
+parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+args = parser.parse_args()
+
+target = args.target
+model_path = args.model_prefix
+image_path = args.image_path
+result_dir = args.result_dir
+gpu_id = None
+use_norm_score = True # if Ture, TestMode(N1)
+use_detector_score = True # if Ture, TestMode(D1)
+use_flip_test = True # if Ture, TestMode(F1)
+job = args.job
+batch_size = args.batch_size
+
+
+class Embedding(object):
+ def __init__(self, prefix, data_shape, batch_size=1):
+ image_size = (112, 112)
+ self.image_size = image_size
+ weight = torch.load(prefix)
+ resnet = get_model(args.network, dropout=0, fp16=False).cuda()
+ resnet.load_state_dict(weight)
+ model = torch.nn.DataParallel(resnet)
+ self.model = model
+ self.model.eval()
+ src = np.array([
+ [30.2946, 51.6963],
+ [65.5318, 51.5014],
+ [48.0252, 71.7366],
+ [33.5493, 92.3655],
+ [62.7299, 92.2041]], dtype=np.float32)
+ src[:, 0] += 8.0
+ self.src = src
+ self.batch_size = batch_size
+ self.data_shape = data_shape
+
+ def get(self, rimg, landmark):
+
+ assert landmark.shape[0] == 68 or landmark.shape[0] == 5
+ assert landmark.shape[1] == 2
+ if landmark.shape[0] == 68:
+ landmark5 = np.zeros((5, 2), dtype=np.float32)
+ landmark5[0] = (landmark[36] + landmark[39]) / 2
+ landmark5[1] = (landmark[42] + landmark[45]) / 2
+ landmark5[2] = landmark[30]
+ landmark5[3] = landmark[48]
+ landmark5[4] = landmark[54]
+ else:
+ landmark5 = landmark
+ tform = trans.SimilarityTransform()
+ tform.estimate(landmark5, self.src)
+ M = tform.params[0:2, :]
+ img = cv2.warpAffine(rimg,
+ M, (self.image_size[1], self.image_size[0]),
+ borderValue=0.0)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_flip = np.fliplr(img)
+ img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB
+ img_flip = np.transpose(img_flip, (2, 0, 1))
+ input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
+ input_blob[0] = img
+ input_blob[1] = img_flip
+ return input_blob
+
+ @torch.no_grad()
+ def forward_db(self, batch_data):
+ imgs = torch.Tensor(batch_data).cuda()
+ imgs.div_(255).sub_(0.5).div_(0.5)
+ feat = self.model(imgs)
+ feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
+ return feat.cpu().numpy()
+
+
+# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[]
+def divideIntoNstrand(listTemp, n):
+ twoList = [[] for i in range(n)]
+ for i, e in enumerate(listTemp):
+ twoList[i % n].append(e)
+ return twoList
+
+
+def read_template_media_list(path):
+ # ijb_meta = np.loadtxt(path, dtype=str)
+ ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+ templates = ijb_meta[:, 1].astype(np.int)
+ medias = ijb_meta[:, 2].astype(np.int)
+ return templates, medias
+
+
+# In[ ]:
+
+
+def read_template_pair_list(path):
+ # pairs = np.loadtxt(path, dtype=str)
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ # print(pairs.shape)
+ # print(pairs[:, 0].astype(np.int))
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+# In[ ]:
+
+
+def read_image_feature(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# In[ ]:
+
+
+def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
+ batch_size = args.batch_size
+ data_shape = (3, 112, 112)
+
+ files = files_list
+ print('files:', len(files))
+ rare_size = len(files) % batch_size
+ faceness_scores = []
+ batch = 0
+ img_feats = np.empty((len(files), 1024), dtype=np.float32)
+
+ batch_data = np.empty((2 * batch_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, batch_size)
+ for img_index, each_line in enumerate(files[:len(files) - rare_size]):
+ name_lmk_score = each_line.strip().split(' ')
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
+ dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+
+ batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
+ batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
+ if (img_index + 1) % batch_size == 0:
+ print('batch', batch)
+ img_feats[batch * batch_size:batch * batch_size +
+ batch_size][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+
+ batch_data = np.empty((2 * rare_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, rare_size)
+ for img_index, each_line in enumerate(files[len(files) - rare_size:]):
+ name_lmk_score = each_line.strip().split(' ')
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
+ dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+ batch_data[2 * img_index][:] = input_blob[0]
+ batch_data[2 * img_index + 1][:] = input_blob[1]
+ if (img_index + 1) % rare_size == 0:
+ print('batch', batch)
+ img_feats[len(files) -
+ rare_size:][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
+ # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
+ return img_feats, faceness_scores
+
+
+# In[ ]:
+
+
+def image2template_feature(img_feats=None, templates=None, medias=None):
+ # ==========================================================
+ # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
+ # 2. compute media feature.
+ # 3. compute template feature.
+ # ==========================================================
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+
+ for count_template, uqt in enumerate(unique_templates):
+
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias,
+ return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [
+ np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
+ ]
+ media_norm_feats = np.array(media_norm_feats)
+ # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print('Finish Calculating {} template features.'.format(
+ count_template))
+ # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
+ template_norm_feats = sklearn.preprocessing.normalize(template_feats)
+ # print(template_norm_feats.shape)
+ return template_norm_feats, unique_templates
+
+
+# In[ ]:
+
+
+def verification(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ # ==========================================================
+ # Compute set-to-set Similarity Score.
+ # ==========================================================
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [
+ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+ ]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+# In[ ]:
+def verification2(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [
+ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+ ]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def read_score(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# # Step1: Load Meta Data
+
+# In[ ]:
+
+assert target == 'IJBC' or target == 'IJBB'
+
+# =============================================================
+# load image and template relationships for template feature embedding
+# tid --> template id, mid --> media id
+# format:
+# image_name tid mid
+# =============================================================
+start = timeit.default_timer()
+templates, medias = read_template_media_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_face_tid_mid.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+
+# =============================================================
+# load template pairs for template-to-template verification
+# tid : template id, label : 1/0
+# format:
+# tid_1 tid_2 label
+# =============================================================
+start = timeit.default_timer()
+p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_template_pair_label.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# # Step 2: Get Image Features
+
+# In[ ]:
+
+# =============================================================
+# load image features
+# format:
+# img_feats: [image_num x feats_dim] (227630, 512)
+# =============================================================
+start = timeit.default_timer()
+img_path = '%s/loose_crop' % image_path
+img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
+img_list = open(img_list_path)
+files = img_list.readlines()
+# files_list = divideIntoNstrand(files, rank_size)
+files_list = files
+
+# img_feats
+# for i in range(rank_size):
+img_feats, faceness_scores = get_image_feature(img_path, files_list,
+ model_path, 0, gpu_id)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
+ img_feats.shape[1]))
+
+# # Step3: Get Template Features
+
+# In[ ]:
+
+# =============================================================
+# compute template features from image features.
+# =============================================================
+start = timeit.default_timer()
+# ==========================================================
+# Norm feature before aggregation into template feature?
+# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
+# ==========================================================
+# 1. FaceScore (Feature Norm)
+# 2. FaceScore (Detector)
+
+if use_flip_test:
+ # concat --- F1
+ # img_input_feats = img_feats
+ # add --- F2
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] //
+ 2] + img_feats[:, img_feats.shape[1] // 2:]
+else:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
+
+if use_norm_score:
+ img_input_feats = img_input_feats
+else:
+ # normalise features to remove norm information
+ img_input_feats = img_input_feats / np.sqrt(
+ np.sum(img_input_feats ** 2, -1, keepdims=True))
+
+if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+else:
+ img_input_feats = img_input_feats
+
+template_norm_feats, unique_templates = image2template_feature(
+ img_input_feats, templates, medias)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# # Step 4: Get Template Similarity Scores
+
+# In[ ]:
+
+# =============================================================
+# compute verification scores between template pairs.
+# =============================================================
+start = timeit.default_timer()
+score = verification(template_norm_feats, unique_templates, p1, p2)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+save_path = os.path.join(result_dir, args.job)
+# save_path = result_dir + '/%s_result' % target
+
+if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
+np.save(score_save_file, score)
+
+# # Step 5: Get ROC Curves and TPR@FPR Table
+
+# In[ ]:
+
+files = [score_save_file]
+methods = []
+scores = []
+for file in files:
+ methods.append(Path(file).stem)
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+ zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(fpr,
+ tpr,
+ color=colours[method],
+ lw=1,
+ label=('[%s (AUC = %0.4f %%)]' %
+ (method.split('-')[-1], roc_auc * 100)))
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower()))
+print(tpr_fpr_table)
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/inference.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..1929d4abb640d040398dda57b491b9bd96deac9d
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/inference.py
@@ -0,0 +1,35 @@
+import argparse
+
+import cv2
+import numpy as np
+import torch
+
+from backbones import get_model
+
+
+@torch.no_grad()
+def inference(weight, name, img):
+ if img is None:
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
+ else:
+ img = cv2.imread(img)
+ img = cv2.resize(img, (112, 112))
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+ img.div_(255).sub_(0.5).div_(0.5)
+ net = get_model(name, fp16=False)
+ net.load_state_dict(torch.load(weight))
+ net.eval()
+ feat = net(img).numpy()
+ print(feat)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
+ parser.add_argument('--network', type=str, default='r50', help='backbone network')
+ parser.add_argument('--weight', type=str, default='')
+ parser.add_argument('--img', type=str, default=None)
+ args = parser.parse_args()
+ inference(args.weight, args.network, args.img)
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/losses.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bfdd8c6b7f6b0d465928f19c554e62340e5ad7b
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/losses.py
@@ -0,0 +1,42 @@
+import torch
+from torch import nn
+
+
+def get_loss(name):
+ if name == "cosface":
+ return CosFace()
+ elif name == "arcface":
+ return ArcFace()
+ else:
+ raise ValueError()
+
+
+class CosFace(nn.Module):
+ def __init__(self, s=64.0, m=0.40):
+ super(CosFace, self).__init__()
+ self.s = s
+ self.m = m
+
+ def forward(self, cosine, label):
+ index = torch.where(label != -1)[0]
+ m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
+ m_hot.scatter_(1, label[index, None], self.m)
+ cosine[index] -= m_hot
+ ret = cosine * self.s
+ return ret
+
+
+class ArcFace(nn.Module):
+ def __init__(self, s=64.0, m=0.5):
+ super(ArcFace, self).__init__()
+ self.s = s
+ self.m = m
+
+ def forward(self, cosine: torch.Tensor, label):
+ index = torch.where(label != -1)[0]
+ m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
+ m_hot.scatter_(1, label[index, None], self.m)
+ cosine.acos_()
+ cosine[index] += m_hot
+ cosine.cos_().mul_(self.s)
+ return cosine
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_helper.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a01a46621dc0ea695bd903de5d1e212d424c860
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_helper.py
@@ -0,0 +1,250 @@
+from __future__ import division
+import datetime
+import os
+import os.path as osp
+import glob
+import numpy as np
+import cv2
+import sys
+import onnxruntime
+import onnx
+import argparse
+from onnx import numpy_helper
+from insightface.data import get_image
+
+class ArcFaceORT:
+ def __init__(self, model_path, cpu=False):
+ self.model_path = model_path
+ # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider"
+ self.providers = ['CPUExecutionProvider'] if cpu else None
+
+ #input_size is (w,h), return error message, return None if success
+ def check(self, track='cfat', test_img = None):
+ #default is cfat
+ max_model_size_mb=1024
+ max_feat_dim=512
+ max_time_cost=15
+ if track.startswith('ms1m'):
+ max_model_size_mb=1024
+ max_feat_dim=512
+ max_time_cost=10
+ elif track.startswith('glint'):
+ max_model_size_mb=1024
+ max_feat_dim=1024
+ max_time_cost=20
+ elif track.startswith('cfat'):
+ max_model_size_mb = 1024
+ max_feat_dim = 512
+ max_time_cost = 15
+ elif track.startswith('unconstrained'):
+ max_model_size_mb=1024
+ max_feat_dim=1024
+ max_time_cost=30
+ else:
+ return "track not found"
+
+ if not os.path.exists(self.model_path):
+ return "model_path not exists"
+ if not os.path.isdir(self.model_path):
+ return "model_path should be directory"
+ onnx_files = []
+ for _file in os.listdir(self.model_path):
+ if _file.endswith('.onnx'):
+ onnx_files.append(osp.join(self.model_path, _file))
+ if len(onnx_files)==0:
+ return "do not have onnx files"
+ self.model_file = sorted(onnx_files)[-1]
+ print('use onnx-model:', self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print('input-shape:', input_shape)
+ if len(input_shape)!=4:
+ return "length of input_shape should be 4"
+ if not isinstance(input_shape[0], str):
+ #return "input_shape[0] should be str to support batch-inference"
+ print('reset input-shape[0] to None')
+ model = onnx.load(self.model_file)
+ model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
+ new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx')
+ onnx.save(model, new_model_file)
+ self.model_file = new_model_file
+ print('use new onnx-model:', self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print('new-input-shape:', input_shape)
+
+ self.image_size = tuple(input_shape[2:4][::-1])
+ #print('image_size:', self.image_size)
+ input_name = input_cfg.name
+ outputs = session.get_outputs()
+ output_names = []
+ for o in outputs:
+ output_names.append(o.name)
+ #print(o.name, o.shape)
+ if len(output_names)!=1:
+ return "number of output nodes should be 1"
+ self.session = session
+ self.input_name = input_name
+ self.output_names = output_names
+ #print(self.output_names)
+ model = onnx.load(self.model_file)
+ graph = model.graph
+ if len(graph.node)<8:
+ return "too small onnx graph"
+
+ input_size = (112,112)
+ self.crop = None
+ if track=='cfat':
+ crop_file = osp.join(self.model_path, 'crop.txt')
+ if osp.exists(crop_file):
+ lines = open(crop_file,'r').readlines()
+ if len(lines)!=6:
+ return "crop.txt should contain 6 lines"
+ lines = [int(x) for x in lines]
+ self.crop = lines[:4]
+ input_size = tuple(lines[4:6])
+ if input_size!=self.image_size:
+ return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size)
+
+ self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024)
+ if self.model_size_mb > max_model_size_mb:
+ return "max model size exceed, given %.3f-MB"%self.model_size_mb
+
+ input_mean = None
+ input_std = None
+ if track=='cfat':
+ pn_file = osp.join(self.model_path, 'pixel_norm.txt')
+ if osp.exists(pn_file):
+ lines = open(pn_file,'r').readlines()
+ if len(lines)!=2:
+ return "pixel_norm.txt should contain 2 lines"
+ input_mean = float(lines[0])
+ input_std = float(lines[1])
+ if input_mean is not None or input_std is not None:
+ if input_mean is None or input_std is None:
+ return "please set input_mean and input_std simultaneously"
+ else:
+ find_sub = False
+ find_mul = False
+ for nid, node in enumerate(graph.node[:8]):
+ print(nid, node.name)
+ if node.name.startswith('Sub') or node.name.startswith('_minus'):
+ find_sub = True
+ if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'):
+ find_mul = True
+ if find_sub and find_mul:
+ print("find sub and mul")
+ #mxnet arcface model
+ input_mean = 0.0
+ input_std = 1.0
+ else:
+ input_mean = 127.5
+ input_std = 127.5
+ self.input_mean = input_mean
+ self.input_std = input_std
+ for initn in graph.initializer:
+ weight_array = numpy_helper.to_array(initn)
+ dt = weight_array.dtype
+ if dt.itemsize<4:
+ return 'invalid weight type - (%s:%s)' % (initn.name, dt.name)
+ if test_img is None:
+ test_img = get_image('Tom_Hanks_54745')
+ test_img = cv2.resize(test_img, self.image_size)
+ else:
+ test_img = cv2.resize(test_img, self.image_size)
+ feat, cost = self.benchmark(test_img)
+ batch_result = self.check_batch(test_img)
+ batch_result_sum = float(np.sum(batch_result))
+ if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum:
+ print(batch_result)
+ print(batch_result_sum)
+ return "batch result output contains NaN!"
+
+ if len(feat.shape) < 2:
+ return "the shape of the feature must be two, but get {}".format(str(feat.shape))
+
+ if feat.shape[1] > max_feat_dim:
+ return "max feat dim exceed, given %d"%feat.shape[1]
+ self.feat_dim = feat.shape[1]
+ cost_ms = cost*1000
+ if cost_ms>max_time_cost:
+ return "max time cost exceed, given %.4f"%cost_ms
+ self.cost_ms = cost_ms
+ print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std))
+ return None
+
+ def check_batch(self, img):
+ if not isinstance(img, list):
+ imgs = [img, ] * 32
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :]
+ if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]:
+ nimg = cv2.resize(nimg, self.image_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(
+ images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size,
+ mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ return net_out
+
+
+ def meta_info(self):
+ return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms}
+
+
+ def forward(self, imgs):
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ input_size = self.image_size
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
+ if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
+ return net_out
+
+ def benchmark(self, img):
+ input_size = self.image_size
+ if self.crop is not None:
+ nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
+ if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ img = nimg
+ blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ costs = []
+ for _ in range(50):
+ ta = datetime.datetime.now()
+ net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
+ tb = datetime.datetime.now()
+ cost = (tb-ta).total_seconds()
+ costs.append(cost)
+ costs = sorted(costs)
+ cost = costs[5]
+ return net_out, cost
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='')
+ # general
+ parser.add_argument('workdir', help='submitted work dir', type=str)
+ parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat')
+ args = parser.parse_args()
+ handler = ArcFaceORT(args.workdir)
+ err = handler.check(args.track)
+ print('err:', err)
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_ijbc.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa96b96745e23d4d6642d99f71456c10af5e4e4e
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_ijbc.py
@@ -0,0 +1,267 @@
+import argparse
+import os
+import pickle
+import timeit
+
+import cv2
+import mxnet as mx
+import numpy as np
+import pandas as pd
+import prettytable
+import skimage.transform
+from sklearn.metrics import roc_curve
+from sklearn.preprocessing import normalize
+
+from onnx_helper import ArcFaceORT
+
+SRC = np.array(
+ [
+ [30.2946, 51.6963],
+ [65.5318, 51.5014],
+ [48.0252, 71.7366],
+ [33.5493, 92.3655],
+ [62.7299, 92.2041]]
+ , dtype=np.float32)
+SRC[:, 0] += 8.0
+
+
+class AlignedDataSet(mx.gluon.data.Dataset):
+ def __init__(self, root, lines, align=True):
+ self.lines = lines
+ self.root = root
+ self.align = align
+
+ def __len__(self):
+ return len(self.lines)
+
+ def __getitem__(self, idx):
+ each_line = self.lines[idx]
+ name_lmk_score = each_line.strip().split(' ')
+ name = os.path.join(self.root, name_lmk_score[0])
+ img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)
+ landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))
+ st = skimage.transform.SimilarityTransform()
+ st.estimate(landmark5, SRC)
+ img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)
+ img_1 = np.expand_dims(img, 0)
+ img_2 = np.expand_dims(np.fliplr(img), 0)
+ output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
+ output = np.transpose(output, (0, 3, 1, 2))
+ output = mx.nd.array(output)
+ return output
+
+
+def extract(model_root, dataset):
+ model = ArcFaceORT(model_path=model_root)
+ model.check()
+ feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))
+
+ def batchify_fn(data):
+ return mx.nd.concat(*data, dim=0)
+
+ data_loader = mx.gluon.data.DataLoader(
+ dataset, 128, last_batch='keep', num_workers=4,
+ thread_pool=True, prefetch=16, batchify_fn=batchify_fn)
+ num_iter = 0
+ for batch in data_loader:
+ batch = batch.asnumpy()
+ batch = (batch - model.input_mean) / model.input_std
+ feat = model.session.run(model.output_names, {model.input_name: batch})[0]
+ feat = np.reshape(feat, (-1, model.feat_dim * 2))
+ feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat
+ num_iter += 1
+ if num_iter % 50 == 0:
+ print(num_iter)
+ return feat_mat
+
+
+def read_template_media_list(path):
+ ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+ templates = ijb_meta[:, 1].astype(np.int)
+ medias = ijb_meta[:, 2].astype(np.int)
+ return templates, medias
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+def read_image_feature(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+def image2template_feature(img_feats=None,
+ templates=None,
+ medias=None):
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+ for count_template, uqt in enumerate(unique_templates):
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ]
+ media_norm_feats = np.array(media_norm_feats)
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print('Finish Calculating {} template features.'.format(
+ count_template))
+ template_norm_feats = normalize(template_feats)
+ return template_norm_feats, unique_templates
+
+
+def verification(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),))
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000
+ sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def verification2(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation
+ sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def main(args):
+ use_norm_score = True # if Ture, TestMode(N1)
+ use_detector_score = True # if Ture, TestMode(D1)
+ use_flip_test = True # if Ture, TestMode(F1)
+ assert args.target == 'IJBC' or args.target == 'IJBB'
+
+ start = timeit.default_timer()
+ templates, medias = read_template_media_list(
+ os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower()))
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % args.image_path,
+ '%s_template_pair_label.txt' % args.target.lower()))
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ img_path = '%s/loose_crop' % args.image_path
+ img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower())
+ img_list = open(img_list_path)
+ files = img_list.readlines()
+ dataset = AlignedDataSet(root=img_path, lines=files, align=True)
+ img_feats = extract(args.model_root, dataset)
+
+ faceness_scores = []
+ for each_line in files:
+ name_lmk_score = each_line.split()
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+ print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1]))
+ start = timeit.default_timer()
+
+ if use_flip_test:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:]
+ else:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
+
+ if use_norm_score:
+ img_input_feats = img_input_feats
+ else:
+ img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True))
+
+ if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+ else:
+ img_input_feats = img_input_feats
+
+ template_norm_feats, unique_templates = image2template_feature(
+ img_input_feats, templates, medias)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ score = verification(template_norm_feats, unique_templates, p1, p2)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+ save_path = os.path.join(args.result_dir, "{}_result".format(args.target))
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root))
+ np.save(score_save_file, score)
+ files = [score_save_file]
+ methods = []
+ scores = []
+ for file in files:
+ methods.append(os.path.basename(file))
+ scores.append(np.load(file))
+ methods = np.array(methods)
+ scores = dict(zip(methods, scores))
+ x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+ tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels])
+ for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr)
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, args.target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+ print(tpr_fpr_table)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='do ijb test')
+ # general
+ parser.add_argument('--model-root', default='', help='path to load model.')
+ parser.add_argument('--image-path', default='', type=str, help='')
+ parser.add_argument('--result-dir', default='.', type=str, help='')
+ parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+ main(parser.parse_args())
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/partial_fc.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/partial_fc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0286dd437319c920ecb61f4eb3a32333dcf49eb
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/partial_fc.py
@@ -0,0 +1,222 @@
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+from torch.nn import Module
+from torch.nn.functional import normalize, linear
+from torch.nn.parameter import Parameter
+
+
+class PartialFC(Module):
+ """
+ Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
+ Partial FC: Training 10 Million Identities on a Single Machine
+ See the original paper:
+ https://arxiv.org/abs/2010.05222
+ """
+
+ @torch.no_grad()
+ def __init__(self, rank, local_rank, world_size, batch_size, resume,
+ margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
+ """
+ rank: int
+ Unique process(GPU) ID from 0 to world_size - 1.
+ local_rank: int
+ Unique process(GPU) ID within the server from 0 to 7.
+ world_size: int
+ Number of GPU.
+ batch_size: int
+ Batch size on current rank(GPU).
+ resume: bool
+ Select whether to restore the weight of softmax.
+ margin_softmax: callable
+ A function of margin softmax, eg: cosface, arcface.
+ num_classes: int
+ The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size,
+ required.
+ sample_rate: float
+ The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling
+ can greatly speed up training, and reduce a lot of GPU memory, default is 1.0.
+ embedding_size: int
+ The feature dimension, default is 512.
+ prefix: str
+ Path for save checkpoint, default is './'.
+ """
+ super(PartialFC, self).__init__()
+ #
+ self.num_classes: int = num_classes
+ self.rank: int = rank
+ self.local_rank: int = local_rank
+ self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
+ self.world_size: int = world_size
+ self.batch_size: int = batch_size
+ self.margin_softmax: callable = margin_softmax
+ self.sample_rate: float = sample_rate
+ self.embedding_size: int = embedding_size
+ self.prefix: str = prefix
+ self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
+ self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+
+ self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
+ self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))
+
+ if resume:
+ try:
+ self.weight: torch.Tensor = torch.load(self.weight_name)
+ self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
+ if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
+ raise IndexError
+ logging.info("softmax weight resume successfully!")
+ logging.info("softmax weight mom resume successfully!")
+ except (FileNotFoundError, KeyError, IndexError):
+ self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
+ self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
+ logging.info("softmax weight init!")
+ logging.info("softmax weight mom init!")
+ else:
+ self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
+ self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
+ logging.info("softmax weight init successfully!")
+ logging.info("softmax weight mom init successfully!")
+ self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)
+
+ self.index = None
+ if int(self.sample_rate) == 1:
+ self.update = lambda: 0
+ self.sub_weight = Parameter(self.weight)
+ self.sub_weight_mom = self.weight_mom
+ else:
+ self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))
+
+ def save_params(self):
+ """ Save softmax weight for each rank on prefix
+ """
+ torch.save(self.weight.data, self.weight_name)
+ torch.save(self.weight_mom, self.weight_mom_name)
+
+ @torch.no_grad()
+ def sample(self, total_label):
+ """
+ Sample all positive class centers in each rank, and random select neg class centers to filling a fixed
+ `num_sample`.
+
+ total_label: tensor
+ Label after all gather, which cross all GPUs.
+ """
+ index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
+ total_label[~index_positive] = -1
+ total_label[index_positive] -= self.class_start
+ if int(self.sample_rate) != 1:
+ positive = torch.unique(total_label[index_positive], sorted=True)
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local], device=self.device)
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1]
+ index = index.sort()[0]
+ else:
+ index = positive
+ self.index = index
+ total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
+ self.sub_weight = Parameter(self.weight[index])
+ self.sub_weight_mom = self.weight_mom[index]
+
+ def forward(self, total_features, norm_weight):
+ """ Partial fc forward, `logits = X * sample(W)`
+ """
+ torch.cuda.current_stream().wait_stream(self.stream)
+ logits = linear(total_features, norm_weight)
+ return logits
+
+ @torch.no_grad()
+ def update(self):
+ """ Set updated weight and weight_mom to memory bank.
+ """
+ self.weight_mom[self.index] = self.sub_weight_mom
+ self.weight[self.index] = self.sub_weight
+
+ def prepare(self, label, optimizer):
+ """
+ get sampled class centers for cal softmax.
+
+ label: tensor
+ Label tensor on each rank.
+ optimizer: opt
+ Optimizer for partial fc, which need to get weight mom.
+ """
+ with torch.cuda.stream(self.stream):
+ total_label = torch.zeros(
+ size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
+ dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
+ self.sample(total_label)
+ optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
+ optimizer.param_groups[-1]['params'][0] = self.sub_weight
+ optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
+ norm_weight = normalize(self.sub_weight)
+ return total_label, norm_weight
+
+ def forward_backward(self, label, features, optimizer):
+ """
+ Partial fc forward and backward with model parallel
+
+ label: tensor
+ Label tensor on each rank(GPU)
+ features: tensor
+ Features tensor on each rank(GPU)
+ optimizer: optimizer
+ Optimizer for partial fc
+
+ Returns:
+ --------
+ x_grad: tensor
+ The gradient of features.
+ loss_v: tensor
+ Loss value for cross entropy.
+ """
+ total_label, norm_weight = self.prepare(label, optimizer)
+ total_features = torch.zeros(
+ size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
+ dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
+ total_features.requires_grad = True
+
+ logits = self.forward(total_features, norm_weight)
+ logits = self.margin_softmax(logits, total_label)
+
+ with torch.no_grad():
+ max_fc = torch.max(logits, dim=1, keepdim=True)[0]
+ dist.all_reduce(max_fc, dist.ReduceOp.MAX)
+
+ # calculate exp(logits) and all-reduce
+ logits_exp = torch.exp(logits - max_fc)
+ logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
+ dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
+
+ # calculate prob
+ logits_exp.div_(logits_sum_exp)
+
+ # get one-hot
+ grad = logits_exp
+ index = torch.where(total_label != -1)[0]
+ one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
+ one_hot.scatter_(1, total_label[index, None], 1)
+
+ # calculate loss
+ loss = torch.zeros(grad.size()[0], 1, device=grad.device)
+ loss[index] = grad[index].gather(1, total_label[index, None])
+ dist.all_reduce(loss, dist.ReduceOp.SUM)
+ loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
+
+ # calculate grad
+ grad[index] -= one_hot
+ grad.div_(self.batch_size * self.world_size)
+
+ logits.backward(grad)
+ if total_features.grad is not None:
+ total_features.grad.detach_()
+ x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
+ # feature gradient all-reduce
+ dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
+ x_grad = x_grad * self.world_size
+ # backward backbone
+ return x_grad, loss_v
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/requirement.txt b/sadtalker_video2pose/src/face3d/models/arcface_torch/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..99aef673e30b99cbe56ce82a564c1df9df24ba21
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/requirement.txt
@@ -0,0 +1,5 @@
+tensorboard
+easydict
+mxnet
+onnx
+sklearn
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/run.sh b/sadtalker_video2pose/src/face3d/models/arcface_torch/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..67b25fd63ef3921733d81d5be844aacc5a5c84ed
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/run.sh
@@ -0,0 +1,2 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
+ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/torch2onnx.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/torch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..458660df7cc7f9a567aaf492c45f232e776a9ef0
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/torch2onnx.py
@@ -0,0 +1,59 @@
+import numpy as np
+import onnx
+import torch
+
+
+def convert_onnx(net, path_module, output, opset=11, simplify=False):
+ assert isinstance(net, torch.nn.Module)
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = img.astype(np.float)
+ img = (img / 255. - 0.5) / 0.5 # torch style norm
+ img = img.transpose((2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+
+ weight = torch.load(path_module)
+ net.load_state_dict(weight)
+ net.eval()
+ torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
+ model = onnx.load(output)
+ graph = model.graph
+ graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
+ if simplify:
+ from onnxsim import simplify
+ model, check = simplify(model)
+ assert check, "Simplified ONNX model could not be validated"
+ onnx.save(model, output)
+
+
+if __name__ == '__main__':
+ import os
+ import argparse
+ from backbones import get_model
+
+ parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx')
+ parser.add_argument('input', type=str, help='input backbone.pth file or path')
+ parser.add_argument('--output', type=str, default=None, help='output onnx path')
+ parser.add_argument('--network', type=str, default=None, help='backbone network')
+ parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify')
+ args = parser.parse_args()
+ input_file = args.input
+ if os.path.isdir(input_file):
+ input_file = os.path.join(input_file, "backbone.pth")
+ assert os.path.exists(input_file)
+ model_name = os.path.basename(os.path.dirname(input_file)).lower()
+ params = model_name.split("_")
+ if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
+ if args.network is None:
+ args.network = params[2]
+ assert args.network is not None
+ print(args)
+ backbone_onnx = get_model(args.network, dropout=0)
+
+ output_path = args.output
+ if output_path is None:
+ output_path = os.path.join(os.path.dirname(__file__), 'onnx')
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+ assert os.path.isdir(output_path)
+ output_file = os.path.join(output_path, "%s.onnx" % model_name)
+ convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/train.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c5491de9af8fc7a2f3d0648c53b89584864f20e
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/train.py
@@ -0,0 +1,141 @@
+import argparse
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torch.utils.data.distributed
+from torch.nn.utils import clip_grad_norm_
+
+import losses
+from backbones import get_model
+from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX
+from partial_fc import PartialFC
+from utils.utils_amp import MaxClipGradScaler
+from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
+from utils.utils_config import get_config
+from utils.utils_logging import AverageMeter, init_logging
+
+
+def main(args):
+ cfg = get_config(args.config)
+ try:
+ world_size = int(os.environ['WORLD_SIZE'])
+ rank = int(os.environ['RANK'])
+ dist.init_process_group('nccl')
+ except KeyError:
+ world_size = 1
+ rank = 0
+ dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size)
+
+ local_rank = args.local_rank
+ torch.cuda.set_device(local_rank)
+ os.makedirs(cfg.output, exist_ok=True)
+ init_logging(rank, cfg.output)
+
+ if cfg.rec == "synthetic":
+ train_set = SyntheticDataset(local_rank=local_rank)
+ else:
+ train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
+
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
+ train_loader = DataLoaderX(
+ local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
+ sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
+ backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)
+
+ if cfg.resume:
+ try:
+ backbone_pth = os.path.join(cfg.output, "backbone.pth")
+ backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
+ if rank == 0:
+ logging.info("backbone resume successfully!")
+ except (FileNotFoundError, KeyError, IndexError, RuntimeError):
+ if rank == 0:
+ logging.info("resume fail, backbone init successfully!")
+
+ backbone = torch.nn.parallel.DistributedDataParallel(
+ module=backbone, broadcast_buffers=False, device_ids=[local_rank])
+ backbone.train()
+ margin_softmax = losses.get_loss(cfg.loss)
+ module_partial_fc = PartialFC(
+ rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,
+ batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
+ sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)
+
+ opt_backbone = torch.optim.SGD(
+ params=[{'params': backbone.parameters()}],
+ lr=cfg.lr / 512 * cfg.batch_size * world_size,
+ momentum=0.9, weight_decay=cfg.weight_decay)
+ opt_pfc = torch.optim.SGD(
+ params=[{'params': module_partial_fc.parameters()}],
+ lr=cfg.lr / 512 * cfg.batch_size * world_size,
+ momentum=0.9, weight_decay=cfg.weight_decay)
+
+ num_image = len(train_set)
+ total_batch_size = cfg.batch_size * world_size
+ cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
+ cfg.total_step = num_image // total_batch_size * cfg.num_epoch
+
+ def lr_step_func(current_step):
+ cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]
+ if current_step < cfg.warmup_step:
+ return current_step / cfg.warmup_step
+ else:
+ return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])
+
+ scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
+ optimizer=opt_backbone, lr_lambda=lr_step_func)
+ scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
+ optimizer=opt_pfc, lr_lambda=lr_step_func)
+
+ for key, value in cfg.items():
+ num_space = 25 - len(key)
+ logging.info(": " + key + " " * num_space + str(value))
+
+ val_target = cfg.val_targets
+ callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
+ callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
+ callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)
+
+ loss = AverageMeter()
+ start_epoch = 0
+ global_step = 0
+ grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
+ for epoch in range(start_epoch, cfg.num_epoch):
+ train_sampler.set_epoch(epoch)
+ for step, (img, label) in enumerate(train_loader):
+ global_step += 1
+ features = F.normalize(backbone(img))
+ x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
+ if cfg.fp16:
+ features.backward(grad_amp.scale(x_grad))
+ grad_amp.unscale_(opt_backbone)
+ clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+ grad_amp.step(opt_backbone)
+ grad_amp.update()
+ else:
+ features.backward(x_grad)
+ clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+ opt_backbone.step()
+
+ opt_pfc.step()
+ module_partial_fc.update()
+ opt_backbone.zero_grad()
+ opt_pfc.zero_grad()
+ loss.update(loss_v, 1)
+ callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
+ callback_verification(global_step, backbone)
+ scheduler_backbone.step()
+ scheduler_pfc.step()
+ callback_checkpoint(global_step, backbone, module_partial_fc)
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ torch.backends.cudnn.benchmark = True
+ parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
+ parser.add_argument('config', type=str, help='py config file')
+ parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
+ main(parser.parse_args())
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/__init__.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/plot.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fce6cc0ae526d5aebc8e7a1550300ceae3a2034
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/plot.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+
+import os
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from sklearn.metrics import roc_curve, auc
+
+image_path = "/data/anxiang/IJB_release/IJBC"
+files = [
+ "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
+]
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ t1 = pairs[:, 0].astype(np.int)
+ t2 = pairs[:, 1].astype(np.int)
+ label = pairs[:, 2].astype(np.int)
+ return t1, t2, label
+
+
+p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_template_pair_label.txt' % 'ijbc'))
+
+methods = []
+scores = []
+for file in files:
+ methods.append(file.split('/')[-2])
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+ zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(fpr,
+ tpr,
+ color=colours[method],
+ lw=1,
+ label=('[%s (AUC = %0.4f %%)]' %
+ (method.split('-')[-1], roc_auc * 100)))
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+print(tpr_fpr_table)
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_amp.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_amp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d5bcbb540ff8b04535e71c0057e124338df5bd
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_amp.py
@@ -0,0 +1,88 @@
+from typing import Dict, List
+
+import torch
+
+if torch.__version__ < '1.9':
+ Iterable = torch._six.container_abcs.Iterable
+else:
+ import collections
+
+ Iterable = collections.abc.Iterable
+from torch.cuda.amp import GradScaler
+
+
+class _MultiDeviceReplicator(object):
+ """
+ Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
+ """
+
+ def __init__(self, master_tensor: torch.Tensor) -> None:
+ assert master_tensor.is_cuda
+ self.master = master_tensor
+ self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+ def get(self, device) -> torch.Tensor:
+ retval = self._per_device_tensors.get(device, None)
+ if retval is None:
+ retval = self.master.to(device=device, non_blocking=True, copy=True)
+ self._per_device_tensors[device] = retval
+ return retval
+
+
+class MaxClipGradScaler(GradScaler):
+ def __init__(self, init_scale, max_scale: float, growth_interval=100):
+ GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
+ self.max_scale = max_scale
+
+ def scale_clip(self):
+ if self.get_scale() == self.max_scale:
+ self.set_growth_factor(1)
+ elif self.get_scale() < self.max_scale:
+ self.set_growth_factor(2)
+ elif self.get_scale() > self.max_scale:
+ self._scale.fill_(self.max_scale)
+ self.set_growth_factor(1)
+
+ def scale(self, outputs):
+ """
+ Multiplies ('scales') a tensor or list of tensors by the scale factor.
+
+ Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
+ unmodified.
+
+ Arguments:
+ outputs (Tensor or iterable of Tensors): Outputs to scale.
+ """
+ if not self._enabled:
+ return outputs
+ self.scale_clip()
+ # Short-circuit for the common case.
+ if isinstance(outputs, torch.Tensor):
+ assert outputs.is_cuda
+ if self._scale is None:
+ self._lazy_init_scale_growth_tracker(outputs.device)
+ assert self._scale is not None
+ return outputs * self._scale.to(device=outputs.device, non_blocking=True)
+
+ # Invoke the more complex machinery only if we're treating multiple outputs.
+ stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale
+
+ def apply_scale(val):
+ if isinstance(val, torch.Tensor):
+ assert val.is_cuda
+ if len(stash) == 0:
+ if self._scale is None:
+ self._lazy_init_scale_growth_tracker(val.device)
+ assert self._scale is not None
+ stash.append(_MultiDeviceReplicator(self._scale))
+ return val * stash[0].get(val.device)
+ elif isinstance(val, Iterable):
+ iterable = map(apply_scale, val)
+ if isinstance(val, list) or isinstance(val, tuple):
+ return type(val)(iterable)
+ else:
+ return iterable
+ else:
+ raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+
+ return apply_scale(outputs)
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..748923b36358bd118efa0532a6f512b6ca96ff34
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py
@@ -0,0 +1,117 @@
+import logging
+import os
+import time
+from typing import List
+
+import torch
+
+from eval import verification
+from utils.utils_logging import AverageMeter
+
+
+class CallBackVerification(object):
+ def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
+ self.frequent: int = frequent
+ self.rank: int = rank
+ self.highest_acc: float = 0.0
+ self.highest_acc_list: List[float] = [0.0] * len(val_targets)
+ self.ver_list: List[object] = []
+ self.ver_name_list: List[str] = []
+ if self.rank is 0:
+ self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
+
+ def ver_test(self, backbone: torch.nn.Module, global_step: int):
+ results = []
+ for i in range(len(self.ver_list)):
+ acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
+ self.ver_list[i], backbone, 10, 10)
+ logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
+ logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
+ if acc2 > self.highest_acc_list[i]:
+ self.highest_acc_list[i] = acc2
+ logging.info(
+ '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
+ results.append(acc2)
+
+ def init_dataset(self, val_targets, data_dir, image_size):
+ for name in val_targets:
+ path = os.path.join(data_dir, name + ".bin")
+ if os.path.exists(path):
+ data_set = verification.load_bin(path, image_size)
+ self.ver_list.append(data_set)
+ self.ver_name_list.append(name)
+
+ def __call__(self, num_update, backbone: torch.nn.Module):
+ if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0:
+ backbone.eval()
+ self.ver_test(backbone, num_update)
+ backbone.train()
+
+
+class CallBackLogging(object):
+ def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
+ self.frequent: int = frequent
+ self.rank: int = rank
+ self.time_start = time.time()
+ self.total_step: int = total_step
+ self.batch_size: int = batch_size
+ self.world_size: int = world_size
+ self.writer = writer
+
+ self.init = False
+ self.tic = 0
+
+ def __call__(self,
+ global_step: int,
+ loss: AverageMeter,
+ epoch: int,
+ fp16: bool,
+ learning_rate: float,
+ grad_scaler: torch.cuda.amp.GradScaler):
+ if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
+ if self.init:
+ try:
+ speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
+ speed_total = speed * self.world_size
+ except ZeroDivisionError:
+ speed_total = float('inf')
+
+ time_now = (time.time() - self.time_start) / 3600
+ time_total = time_now / ((global_step + 1) / self.total_step)
+ time_for_end = time_total - time_now
+ if self.writer is not None:
+ self.writer.add_scalar('time_for_end', time_for_end, global_step)
+ self.writer.add_scalar('learning_rate', learning_rate, global_step)
+ self.writer.add_scalar('loss', loss.avg, global_step)
+ if fp16:
+ msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \
+ "Fp16 Grad Scale: %2.f Required: %1.f hours" % (
+ speed_total, loss.avg, learning_rate, epoch, global_step,
+ grad_scaler.get_scale(), time_for_end
+ )
+ else:
+ msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \
+ "Required: %1.f hours" % (
+ speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end
+ )
+ logging.info(msg)
+ loss.reset()
+ self.tic = time.time()
+ else:
+ self.init = True
+ self.tic = time.time()
+
+
+class CallBackModelCheckpoint(object):
+ def __init__(self, rank, output="./"):
+ self.rank: int = rank
+ self.output: str = output
+
+ def __call__(self, global_step, backbone, partial_fc, ):
+ if global_step > 100 and self.rank == 0:
+ path_module = os.path.join(self.output, "backbone.pth")
+ torch.save(backbone.module.state_dict(), path_module)
+ logging.info("Pytorch Model Saved in '{}'".format(path_module))
+
+ if global_step > 100 and partial_fc is not None:
+ partial_fc.save_params()
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_config.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b60a1e5a2e860ce5511a2d3863c8b57a4df292d7
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_config.py
@@ -0,0 +1,16 @@
+import importlib
+import os.path as osp
+
+
+def get_config(config_file):
+ assert config_file.startswith('configs/'), 'config file setting must start with configs/'
+ temp_config_name = osp.basename(config_file)
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ config = importlib.import_module("configs.base")
+ cfg = config.config
+ config = importlib.import_module("configs.%s" % temp_module_name)
+ job_cfg = config.config
+ cfg.update(job_cfg)
+ if cfg.output is None:
+ cfg.output = osp.join('work_dirs', temp_module_name)
+ return cfg
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_logging.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b43b851c9e06230abd94c73a1f64cfa1b6f3ac
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_logging.py
@@ -0,0 +1,41 @@
+import logging
+import os
+import sys
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value
+ """
+
+ def __init__(self):
+ self.val = None
+ self.avg = None
+ self.sum = None
+ self.count = None
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def init_logging(rank, models_root):
+ if rank == 0:
+ log_root = logging.getLogger()
+ log_root.setLevel(logging.INFO)
+ formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
+ handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
+ handler_stream = logging.StreamHandler(sys.stdout)
+ handler_file.setFormatter(formatter)
+ handler_stream.setFormatter(formatter)
+ log_root.addHandler(handler_file)
+ log_root.addHandler(handler_stream)
+ log_root.info('rank_id: %d' % rank)
diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_os.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_os.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sadtalker_video2pose/src/face3d/models/base_model.py b/sadtalker_video2pose/src/face3d/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b975223f6148febfe32d20d63980583c97b61eb3
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/base_model.py
@@ -0,0 +1,316 @@
+"""This script defines the base network model for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+
+class BaseModel(ABC):
+ """This class is an abstract base class (ABC) for models.
+ To create a subclass, you need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- : unpack data from dataset and apply preprocessing.
+ -- : produce intermediate results.
+ -- : calculate losses, gradients, and update network weights.
+ -- : (optionally) add model-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the BaseModel class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ When creating your custom class, you need to implement your own initialization.
+ In this fucntion, you should first call
+ Then, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): specify the images that you want to display and save.
+ -- self.visual_names (str list): define networks used in our training.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ """
+ self.opt = opt
+ self.isTrain = False
+ self.device = torch.device('cpu')
+ self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
+ self.loss_names = []
+ self.model_names = []
+ self.visual_names = []
+ self.parallel_names = []
+ self.optimizers = []
+ self.image_paths = []
+ self.metric = 0 # used for learning rate policy 'plateau'
+
+ @staticmethod
+ def dict_grad_hook_factory(add_func=lambda x: x):
+ saved_dict = dict()
+
+ def hook_gen(name):
+ def grad_hook(grad):
+ saved_vals = add_func(grad)
+ saved_dict[name] = saved_vals
+ return grad_hook
+ return hook_gen, saved_dict
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new model-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input (dict): includes the data itself and its metadata information.
+ """
+ pass
+
+ @abstractmethod
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+ pass
+
+ @abstractmethod
+ def optimize_parameters(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ pass
+
+ def setup(self, opt):
+ """Load and print networks; create schedulers
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ if self.isTrain:
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+
+ if not self.isTrain or opt.continue_train:
+ load_suffix = opt.epoch
+ self.load_networks(load_suffix)
+
+
+ # self.print_networks(opt.verbose)
+
+ def parallelize(self, convert_sync_batchnorm=True):
+ if not self.opt.use_ddp:
+ for name in self.parallel_names:
+ if isinstance(name, str):
+ module = getattr(self, name)
+ setattr(self, name, module.to(self.device))
+ else:
+ for name in self.model_names:
+ if isinstance(name, str):
+ module = getattr(self, name)
+ if convert_sync_batchnorm:
+ module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
+ setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device),
+ device_ids=[self.device.index],
+ find_unused_parameters=True, broadcast_buffers=True))
+
+ # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
+ for name in self.parallel_names:
+ if isinstance(name, str) and name not in self.model_names:
+ module = getattr(self, name)
+ setattr(self, name, module.to(self.device))
+
+ # put state_dict of optimizer to gpu device
+ if self.opt.phase != 'test':
+ if self.opt.continue_train:
+ for optim in self.optimizers:
+ for state in optim.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(self.device)
+
+ def data_dependent_initialize(self, data):
+ pass
+
+ def train(self):
+ """Make models train mode"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ net.train()
+
+ def eval(self):
+ """Make models eval mode"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ net.eval()
+
+ def test(self):
+ """Forward function used in test time.
+
+ This function wraps function in no_grad() so we don't save intermediate steps for backprop
+ It also calls to produce additional visualization results
+ """
+ with torch.no_grad():
+ self.forward()
+ self.compute_visuals()
+
+ def compute_visuals(self):
+ """Calculate additional output images for visdom and HTML visualization"""
+ pass
+
+ def get_image_paths(self, name='A'):
+ """ Return image paths that are used to load current data"""
+ return self.image_paths if name =='A' else self.image_paths_B
+
+ def update_learning_rate(self):
+ """Update learning rates for all the networks; called at the end of every epoch"""
+ for scheduler in self.schedulers:
+ if self.opt.lr_policy == 'plateau':
+ scheduler.step(self.metric)
+ else:
+ scheduler.step()
+
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('learning rate = %.7f' % lr)
+
+ def get_current_visuals(self):
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+ visual_ret = OrderedDict()
+ for name in self.visual_names:
+ if isinstance(name, str):
+ visual_ret[name] = getattr(self, name)[:, :3, ...]
+ return visual_ret
+
+ def get_current_losses(self):
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
+ errors_ret = OrderedDict()
+ for name in self.loss_names:
+ if isinstance(name, str):
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
+ return errors_ret
+
+ def save_networks(self, epoch):
+ """Save all the networks to the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ if not os.path.isdir(self.save_dir):
+ os.makedirs(self.save_dir)
+
+ save_filename = 'epoch_%s.pth' % (epoch)
+ save_path = os.path.join(self.save_dir, save_filename)
+
+ save_dict = {}
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ if isinstance(net, torch.nn.DataParallel) or isinstance(net,
+ torch.nn.parallel.DistributedDataParallel):
+ net = net.module
+ save_dict[name] = net.state_dict()
+
+
+ for i, optim in enumerate(self.optimizers):
+ save_dict['opt_%02d'%i] = optim.state_dict()
+
+ for i, sched in enumerate(self.schedulers):
+ save_dict['sched_%02d'%i] = sched.state_dict()
+
+ torch.save(save_dict, save_path)
+
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+ key = keys[i]
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'running_mean' or key == 'running_var'):
+ if getattr(module, key) is None:
+ state_dict.pop('.'.join(keys))
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'num_batches_tracked'):
+ state_dict.pop('.'.join(keys))
+ else:
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+ def load_networks(self, epoch):
+ """Load all the networks from the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ if self.opt.isTrain and self.opt.pretrained_name is not None:
+ load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
+ else:
+ load_dir = self.save_dir
+ load_filename = 'epoch_%s.pth' % (epoch)
+ load_path = os.path.join(load_dir, load_filename)
+ state_dict = torch.load(load_path, map_location=self.device)
+ print('loading the model from %s' % load_path)
+
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ if isinstance(net, torch.nn.DataParallel):
+ net = net.module
+ net.load_state_dict(state_dict[name])
+
+ if self.opt.phase != 'test':
+ if self.opt.continue_train:
+ print('loading the optim from %s' % load_path)
+ for i, optim in enumerate(self.optimizers):
+ optim.load_state_dict(state_dict['opt_%02d'%i])
+
+ try:
+ print('loading the sched from %s' % load_path)
+ for i, sched in enumerate(self.schedulers):
+ sched.load_state_dict(state_dict['sched_%02d'%i])
+ except:
+ print('Failed to load schedulers, set schedulers according to epoch count manually')
+ for i, sched in enumerate(self.schedulers):
+ sched.last_epoch = self.opt.epoch_count - 1
+
+
+
+
+ def print_networks(self, verbose):
+ """Print the total number of parameters in the network and (if verbose) network architecture
+
+ Parameters:
+ verbose (bool) -- if verbose: print the network architecture
+ """
+ print('---------- Networks initialized -------------')
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ if verbose:
+ print(net)
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+ print('-----------------------------------------------')
+
+ def set_requires_grad(self, nets, requires_grad=False):
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
+ Parameters:
+ nets (network list) -- a list of networks
+ requires_grad (bool) -- whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+ def generate_visuals_for_evaluation(self, data, mode):
+ return {}
diff --git a/sadtalker_video2pose/src/face3d/models/bfm.py b/sadtalker_video2pose/src/face3d/models/bfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cecaf589befac790cf9c124737ba01e27bc29e6
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/bfm.py
@@ -0,0 +1,331 @@
+"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.io import loadmat
+from src.face3d.util.load_mats import transferBFM09
+import os
+
+def perspective_projection(focal, center):
+ # return p.T (N, 3) @ (3, 3)
+ return np.array([
+ focal, 0, center,
+ 0, focal, center,
+ 0, 0, 1
+ ]).reshape([3, 3]).astype(np.float32).transpose()
+
+class SH:
+ def __init__(self):
+ self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
+ self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
+
+
+
+class ParametricFaceModel:
+ def __init__(self,
+ bfm_folder='./BFM',
+ recenter=True,
+ camera_distance=10.,
+ init_lit=np.array([
+ 0.8, 0, 0, 0, 0, 0, 0, 0, 0
+ ]),
+ focal=1015.,
+ center=112.,
+ is_train=True,
+ default_name='BFM_model_front.mat'):
+
+ if not os.path.isfile(os.path.join(bfm_folder, default_name)):
+ transferBFM09(bfm_folder)
+
+ model = loadmat(os.path.join(bfm_folder, default_name))
+ # mean face shape. [3*N,1]
+ self.mean_shape = model['meanshape'].astype(np.float32)
+ # identity basis. [3*N,80]
+ self.id_base = model['idBase'].astype(np.float32)
+ # expression basis. [3*N,64]
+ self.exp_base = model['exBase'].astype(np.float32)
+ # mean face texture. [3*N,1] (0-255)
+ self.mean_tex = model['meantex'].astype(np.float32)
+ # texture basis. [3*N,80]
+ self.tex_base = model['texBase'].astype(np.float32)
+ # face indices for each vertex that lies in. starts from 0. [N,8]
+ self.point_buf = model['point_buf'].astype(np.int64) - 1
+ # vertex indices for each face. starts from 0. [F,3]
+ self.face_buf = model['tri'].astype(np.int64) - 1
+ # vertex indices for 68 landmarks. starts from 0. [68,1]
+ self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
+
+ if is_train:
+ # vertex indices for small face region to compute photometric error. starts from 0.
+ self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
+ # vertex indices for each face from small face region. starts from 0. [f,3]
+ self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
+ # vertex indices for pre-defined skin region to compute reflectance loss
+ self.skin_mask = np.squeeze(model['skinmask'])
+
+ if recenter:
+ mean_shape = self.mean_shape.reshape([-1, 3])
+ mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
+ self.mean_shape = mean_shape.reshape([-1, 1])
+
+ self.persc_proj = perspective_projection(focal, center)
+ self.device = 'cpu'
+ self.camera_distance = camera_distance
+ self.SH = SH()
+ self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
+
+
+ def to(self, device):
+ self.device = device
+ for key, value in self.__dict__.items():
+ if type(value).__module__ == np.__name__:
+ setattr(self, key, torch.tensor(value).to(device))
+
+
+ def compute_shape(self, id_coeff, exp_coeff):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
+ """
+ batch_size = id_coeff.shape[0]
+ id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
+ exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
+ face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
+ return face_shape.reshape([batch_size, -1, 3])
+
+
+ def compute_texture(self, tex_coeff, normalize=True):
+ """
+ Return:
+ face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
+
+ Parameters:
+ tex_coeff -- torch.tensor, size (B, 80)
+ """
+ batch_size = tex_coeff.shape[0]
+ face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
+ if normalize:
+ face_texture = face_texture / 255.
+ return face_texture.reshape([batch_size, -1, 3])
+
+
+ def compute_norm(self, face_shape):
+ """
+ Return:
+ vertex_norm -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+
+ v1 = face_shape[:, self.face_buf[:, 0]]
+ v2 = face_shape[:, self.face_buf[:, 1]]
+ v3 = face_shape[:, self.face_buf[:, 2]]
+ e1 = v1 - v2
+ e2 = v2 - v3
+ face_norm = torch.cross(e1, e2, dim=-1)
+ face_norm = F.normalize(face_norm, dim=-1, p=2)
+ face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
+
+ vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
+ vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
+ return vertex_norm
+
+
+ def compute_color(self, face_texture, face_norm, gamma):
+ """
+ Return:
+ face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
+
+ Parameters:
+ face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
+ face_norm -- torch.tensor, size (B, N, 3), rotated face normal
+ gamma -- torch.tensor, size (B, 27), SH coeffs
+ """
+ batch_size = gamma.shape[0]
+ v_num = face_texture.shape[1]
+ a, c = self.SH.a, self.SH.c
+ gamma = gamma.reshape([batch_size, 3, 9])
+ gamma = gamma + self.init_lit
+ gamma = gamma.permute(0, 2, 1)
+ Y = torch.cat([
+ a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
+ -a[1] * c[1] * face_norm[..., 1:2],
+ a[1] * c[1] * face_norm[..., 2:],
+ -a[1] * c[1] * face_norm[..., :1],
+ a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
+ -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
+ -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
+ ], dim=-1)
+ r = Y @ gamma[..., :1]
+ g = Y @ gamma[..., 1:2]
+ b = Y @ gamma[..., 2:]
+ face_color = torch.cat([r, g, b], dim=-1) * face_texture
+ return face_color
+
+
+ def compute_rotation(self, angles):
+ """
+ Return:
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
+
+ Parameters:
+ angles -- torch.tensor, size (B, 3), radian
+ """
+
+ batch_size = angles.shape[0]
+ ones = torch.ones([batch_size, 1]).to(self.device)
+ zeros = torch.zeros([batch_size, 1]).to(self.device)
+ x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
+
+ rot_x = torch.cat([
+ ones, zeros, zeros,
+ zeros, torch.cos(x), -torch.sin(x),
+ zeros, torch.sin(x), torch.cos(x)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_y = torch.cat([
+ torch.cos(y), zeros, torch.sin(y),
+ zeros, ones, zeros,
+ -torch.sin(y), zeros, torch.cos(y)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_z = torch.cat([
+ torch.cos(z), -torch.sin(z), zeros,
+ torch.sin(z), torch.cos(z), zeros,
+ zeros, zeros, ones
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot = rot_z @ rot_y @ rot_x
+ return rot.permute(0, 2, 1)
+
+
+ def to_camera(self, face_shape):
+ face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
+ return face_shape
+
+ def to_image(self, face_shape):
+ """
+ Return:
+ face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+ # to image_plane
+ face_proj = face_shape @ self.persc_proj
+ face_proj = face_proj[..., :2] / face_proj[..., 2:]
+
+ return face_proj
+
+
+ def transform(self, face_shape, rot, trans):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ rot -- torch.tensor, size (B, 3, 3)
+ trans -- torch.tensor, size (B, 3)
+ """
+ return face_shape @ rot + trans.unsqueeze(1)
+
+
+ def get_landmarks(self, face_proj):
+ """
+ Return:
+ face_lms -- torch.tensor, size (B, 68, 2)
+
+ Parameters:
+ face_proj -- torch.tensor, size (B, N, 2)
+ """
+ return face_proj[:, self.keypoints]
+
+ def split_coeff(self, coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+ coeffs -- torch.tensor, size (B, 256)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80: 144]
+ tex_coeffs = coeffs[:, 144: 224]
+ angles = coeffs[:, 224: 227]
+ gammas = coeffs[:, 227: 254]
+ translations = coeffs[:, 254:]
+ return {
+ 'id': id_coeffs,
+ 'exp': exp_coeffs,
+ 'tex': tex_coeffs,
+ 'angle': angles,
+ 'gamma': gammas,
+ 'trans': translations
+ }
+ def compute_for_render(self, coeffs):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ coef_dict = self.split_coeff(coeffs)
+ face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+ rotation = self.compute_rotation(coef_dict['angle'])
+
+
+ face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+ face_vertex = self.to_camera(face_shape_transformed)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(coef_dict['tex'])
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+ return face_vertex, face_texture, face_color, landmark
+
+ def compute_for_render_woRotation(self, coeffs):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ coef_dict = self.split_coeff(coeffs)
+ face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+ #rotation = self.compute_rotation(coef_dict['angle'])
+
+
+ #face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+ face_vertex = self.to_camera(face_shape)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(coef_dict['tex'])
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm # @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+ return face_vertex, face_texture, face_color, landmark
+
+
+if __name__ == '__main__':
+ transferBFM09()
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/face3d/models/facerecon_model.py b/sadtalker_video2pose/src/face3d/models/facerecon_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..58a836a45a05fa192591cca5cf684783a6fb8533
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/facerecon_model.py
@@ -0,0 +1,220 @@
+"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+from src.face3d.models.base_model import BaseModel
+from src.face3d.models import networks
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss
+from src.face3d.util import util
+from src.face3d.util.nvdiffrast import MeshRenderer
+# from src.face3d.util.preprocess import estimate_norm_torch
+
+import trimesh
+from scipy.io import savemat
+
+class FaceReconModel(BaseModel):
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train=False):
+ """ Configures options specific for CUT model
+ """
+ # net structure and parameters
+ parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
+ parser.add_argument('--init_path', type=str, default='./ckpts/sad_talkers/init_model/resnet50-0676ba61.pth')
+ parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc')
+ parser.add_argument('--bfm_folder', type=str, default='./ckpts/sad_talkers/BFM_Fitting/')
+ parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+ # renderer parameters
+ parser.add_argument('--focal', type=float, default=1015.)
+ parser.add_argument('--center', type=float, default=112.)
+ parser.add_argument('--camera_d', type=float, default=10.)
+ parser.add_argument('--z_near', type=float, default=5.)
+ parser.add_argument('--z_far', type=float, default=15.)
+
+ if is_train:
+ # training parameters
+ parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure')
+ parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth')
+ parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss')
+ parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face')
+
+
+ # augmentation parameters
+ parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels')
+ parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor')
+ parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree')
+
+ # loss weights
+ parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss')
+ parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss')
+ parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss')
+ parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss')
+ parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss')
+ parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss')
+ parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss')
+ parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss')
+ parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss')
+
+ opt, _ = parser.parse_known_args()
+ parser.set_defaults(
+ focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
+ )
+ if is_train:
+ parser.set_defaults(
+ use_crop_face=True, use_predef_M=False
+ )
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+
+ self.visual_names = ['output_vis']
+ self.model_names = ['net_recon']
+ self.parallel_names = self.model_names + ['renderer']
+
+ self.facemodel = ParametricFaceModel(
+ bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center,
+ is_train=self.isTrain, default_name=opt.bfm_model
+ )
+
+ fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
+ self.renderer = MeshRenderer(
+ rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center)
+ )
+
+ if self.isTrain:
+ self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc']
+
+ self.net_recog = networks.define_net_recog(
+ net_recog=opt.net_recog, pretrained_path=opt.net_recog_path
+ )
+ # loss func name: (compute_%s_loss) % loss_name
+ self.compute_feat_loss = perceptual_loss
+ self.comupte_color_loss = photo_loss
+ self.compute_lm_loss = landmark_loss
+ self.compute_reg_loss = reg_loss
+ self.compute_reflc_loss = reflectance_loss
+
+ self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)
+ self.optimizers = [self.optimizer]
+ self.parallel_names += ['net_recog']
+ # Our program will automatically call to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ self.input_img = input['imgs'].to(self.device)
+ self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None
+ self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
+ self.trans_m = input['M'].to(self.device) if 'M' in input else None
+ self.image_paths = input['im_paths'] if 'im_paths' in input else None
+
+ def forward(self, output_coeff, device):
+ self.facemodel.to(device)
+ self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
+ self.facemodel.compute_for_render(output_coeff)
+ self.pred_mask, _, self.pred_face = self.renderer(
+ self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)
+
+ self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)
+
+
+ def compute_losses(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+
+ assert self.net_recog.training == False
+ trans_m = self.trans_m
+ if not self.opt.use_predef_M:
+ trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])
+
+ pred_feat = self.net_recog(self.pred_face, trans_m)
+ gt_feat = self.net_recog(self.input_img, self.trans_m)
+ self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)
+
+ face_mask = self.pred_mask
+ if self.opt.use_crop_face:
+ face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)
+
+ face_mask = face_mask.detach()
+ self.loss_color = self.opt.w_color * self.comupte_color_loss(
+ self.pred_face, self.input_img, self.atten_mask * face_mask)
+
+ loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)
+ self.loss_reg = self.opt.w_reg * loss_reg
+ self.loss_gamma = self.opt.w_gamma * loss_gamma
+
+ self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)
+
+ self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)
+
+ self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \
+ + self.loss_lm + self.loss_reflc
+
+
+ def optimize_parameters(self, isTrain=True):
+ self.forward()
+ self.compute_losses()
+ """Update network weights; it will be called in every training iteration."""
+ if isTrain:
+ self.optimizer.zero_grad()
+ self.loss_all.backward()
+ self.optimizer.step()
+
+ def compute_visuals(self):
+ with torch.no_grad():
+ input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
+ output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
+ output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
+
+ if self.gt_lm is not None:
+ gt_lm_numpy = self.gt_lm.cpu().numpy()
+ pred_lm_numpy = self.pred_lm.detach().cpu().numpy()
+ output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b')
+ output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r')
+
+ output_vis_numpy = np.concatenate((input_img_numpy,
+ output_vis_numpy_raw, output_vis_numpy), axis=-2)
+ else:
+ output_vis_numpy = np.concatenate((input_img_numpy,
+ output_vis_numpy_raw), axis=-2)
+
+ self.output_vis = torch.tensor(
+ output_vis_numpy / 255., dtype=torch.float32
+ ).permute(0, 3, 1, 2).to(self.device)
+
+ def save_mesh(self, name):
+
+ recon_shape = self.pred_vertex # get reconstructed shape
+ recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space
+ recon_shape = recon_shape.cpu().numpy()[0]
+ recon_color = self.pred_color
+ recon_color = recon_color.cpu().numpy()[0]
+ tri = self.facemodel.face_buf.cpu().numpy()
+ mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8))
+ mesh.export(name)
+
+ def save_coeff(self,name):
+
+ pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict}
+ pred_lm = self.pred_lm.cpu().numpy()
+ pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate
+ pred_coeffs['lm68'] = pred_lm
+ savemat(name,pred_coeffs)
+
+
+
diff --git a/sadtalker_video2pose/src/face3d/models/losses.py b/sadtalker_video2pose/src/face3d/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..01d9da84f28d54e772bebd2385ae5a7fedd10f7d
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/losses.py
@@ -0,0 +1,113 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from kornia.geometry import warp_affine
+import torch.nn.functional as F
+
+def resize_n_crop(image, M, dsize=112):
+ # image: (b, c, h, w)
+ # M : (b, 2, 3)
+ return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
+
+### perceptual level loss
+class PerceptualLoss(nn.Module):
+ def __init__(self, recog_net, input_size=112):
+ super(PerceptualLoss, self).__init__()
+ self.recog_net = recog_net
+ self.preprocess = lambda x: 2 * x - 1
+ self.input_size=input_size
+ def forward(imageA, imageB, M):
+ """
+ 1 - cosine distance
+ Parameters:
+ imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
+ imageB --same as imageA
+ """
+
+ imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
+ imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
+
+ # freeze bn
+ self.recog_net.eval()
+
+ id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
+ id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
+ cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+ # assert torch.sum((cosine_d > 1).float()) == 0
+ return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+def perceptual_loss(id_featureA, id_featureB):
+ cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+ # assert torch.sum((cosine_d > 1).float()) == 0
+ return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+### image level loss
+def photo_loss(imageA, imageB, mask, eps=1e-6):
+ """
+ l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur)
+ Parameters:
+ imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
+ imageB --same as imageA
+ """
+ loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
+ loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
+ return loss
+
+def landmark_loss(predict_lm, gt_lm, weight=None):
+ """
+ weighted mse loss
+ Parameters:
+ predict_lm --torch.tensor (B, 68, 2)
+ gt_lm --torch.tensor (B, 68, 2)
+ weight --numpy.array (1, 68)
+ """
+ if not weight:
+ weight = np.ones([68])
+ weight[28:31] = 20
+ weight[-8:] = 20
+ weight = np.expand_dims(weight, 0)
+ weight = torch.tensor(weight).to(predict_lm.device)
+ loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
+ loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
+ return loss
+
+
+### regulization
+def reg_loss(coeffs_dict, opt=None):
+ """
+ l2 norm without the sqrt, from yu's implementation (mse)
+ tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
+ Parameters:
+ coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
+
+ """
+ # coefficient regularization to ensure plausible 3d faces
+ if opt:
+ w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
+ else:
+ w_id, w_exp, w_tex = 1, 1, 1, 1
+ creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \
+ w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
+ w_tex * torch.sum(coeffs_dict['tex'] ** 2)
+ creg_loss = creg_loss / coeffs_dict['id'].shape[0]
+
+ # gamma regularization to ensure a nearly-monochromatic light
+ gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
+ gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
+ gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
+
+ return creg_loss, gamma_loss
+
+def reflectance_loss(texture, mask):
+ """
+ minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo
+ Parameters:
+ texture --torch.tensor, (B, N, 3)
+ mask --torch.tensor, (N), 1 or 0
+
+ """
+ mask = mask.reshape([1, mask.shape[0], 1])
+ texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask)
+ loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))
+ return loss
+
diff --git a/sadtalker_video2pose/src/face3d/models/networks.py b/sadtalker_video2pose/src/face3d/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e69eba1ade2e6431e7e7fd526ea68b8f63e7152
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/networks.py
@@ -0,0 +1,521 @@
+"""This script defines deep neural networks for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch.nn.functional as F
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+import torch
+from torch import Tensor
+import torch.nn as nn
+try:
+ from torch.hub import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+from typing import Type, Any, Callable, Union, List, Optional
+from .arcface_torch.backbones import get_model
+from kornia.geometry import warp_affine
+
+def resize_n_crop(image, M, dsize=112):
+ # image: (b, c, h, w)
+ # M : (b, 2, 3)
+ return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
+
+def filter_state_dict(state_dict, remove_name='fc'):
+ new_state_dict = {}
+ for key in state_dict:
+ if remove_name in key:
+ continue
+ new_state_dict[key] = state_dict[key]
+ return new_state_dict
+
+def get_scheduler(optimizer, opt):
+ """Return a learning rate scheduler
+
+ Parameters:
+ optimizer -- the optimizer of the network
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+ See https://pytorch.org/docs/stable/optim.html for more details.
+ """
+ if opt.lr_policy == 'linear':
+ def lambda_rule(epoch):
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)
+ return lr_l
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+ elif opt.lr_policy == 'step':
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)
+ elif opt.lr_policy == 'plateau':
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+ elif opt.lr_policy == 'cosine':
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+ else:
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+ return scheduler
+
+
+def define_net_recon(net_recon, use_last_fc=False, init_path=None):
+ return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path)
+
+def define_net_recog(net_recog, pretrained_path=None):
+ net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path)
+ net.eval()
+ return net
+
+class ReconNetWrapper(nn.Module):
+ fc_dim=257
+ def __init__(self, net_recon, use_last_fc=False, init_path=None):
+ super(ReconNetWrapper, self).__init__()
+ self.use_last_fc = use_last_fc
+ if net_recon not in func_dict:
+ return NotImplementedError('network [%s] is not implemented', net_recon)
+ func, last_dim = func_dict[net_recon]
+ backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
+ if init_path and os.path.isfile(init_path):
+ state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))
+ backbone.load_state_dict(state_dict)
+ print("loading init net_recon %s from %s" %(net_recon, init_path))
+ self.backbone = backbone
+ if not use_last_fc:
+ self.final_layers = nn.ModuleList([
+ conv1x1(last_dim, 80, bias=True), # id layer
+ conv1x1(last_dim, 64, bias=True), # exp layer
+ conv1x1(last_dim, 80, bias=True), # tex layer
+ conv1x1(last_dim, 3, bias=True), # angle layer
+ conv1x1(last_dim, 27, bias=True), # gamma layer
+ conv1x1(last_dim, 2, bias=True), # tx, ty
+ conv1x1(last_dim, 1, bias=True) # tz
+ ])
+ for m in self.final_layers:
+ nn.init.constant_(m.weight, 0.)
+ nn.init.constant_(m.bias, 0.)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ if not self.use_last_fc:
+ output = []
+ for layer in self.final_layers:
+ output.append(layer(x))
+ x = torch.flatten(torch.cat(output, dim=1), 1)
+ return x
+
+
+class RecogNetWrapper(nn.Module):
+ def __init__(self, net_recog, pretrained_path=None, input_size=112):
+ super(RecogNetWrapper, self).__init__()
+ net = get_model(name=net_recog, fp16=False)
+ if pretrained_path:
+ state_dict = torch.load(pretrained_path, map_location='cpu')
+ net.load_state_dict(state_dict)
+ print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path))
+ for param in net.parameters():
+ param.requires_grad = False
+ self.net = net
+ self.preprocess = lambda x: 2 * x - 1
+ self.input_size=input_size
+
+ def forward(self, image, M):
+ image = self.preprocess(resize_n_crop(image, M, self.input_size))
+ id_feature = F.normalize(self.net(image), dim=-1, p=2)
+ return id_feature
+
+
+# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
+
+
+class BasicBlock(nn.Module):
+ expansion: int = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion: int = 4
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ zero_init_residual: bool = False,
+ use_last_fc: bool = False,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.use_last_fc = use_last_fc
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+ if self.use_last_fc:
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
+
+ def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
+ stride: int = 1, dilate: bool = False) -> nn.Sequential:
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ # See note [TorchScript super()]
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ if self.use_last_fc:
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self._forward_impl(x)
+
+
+def _resnet(
+ arch: str,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ pretrained: bool,
+ progress: bool,
+ **kwargs: Any
+) -> ResNet:
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" `_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" `_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+func_dict = {
+ 'resnet18': (resnet18, 512),
+ 'resnet50': (resnet50, 2048)
+}
diff --git a/sadtalker_video2pose/src/face3d/models/template_model.py b/sadtalker_video2pose/src/face3d/models/template_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..75860272a06312bfa4de382729dce5136a480a7f
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/models/template_model.py
@@ -0,0 +1,100 @@
+"""Model class template
+
+This module provides a template for users to implement custom models.
+You can specify '--model template' to use this model.
+The class name should be consistent with both the filename and its model option.
+The filename should be _dataset.py
+The class name should be Dataset.py
+It implements a simple image-to-image translation baseline based on regression loss.
+Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
+ min_ ||netG(data_A) - data_B||_1
+You need to implement the following functions:
+ : Add model-specific options and rewrite default values for existing options.
+ <__init__>: Initialize this model class.
+ : Unpack input data and perform data pre-processing.
+ : Run forward pass. This will be called by both and .
+ : Update network weights; it will be called in every training iteration.
+"""
+import numpy as np
+import torch
+from .base_model import BaseModel
+from . import networks
+
+
+class TemplateModel(BaseModel):
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Add new model-specific options and rewrite default values for existing options.
+
+ Parameters:
+ parser -- the option parser
+ is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
+ if is_train:
+ parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
+
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+ # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
+ self.loss_names = ['loss_G']
+ # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
+ self.visual_names = ['data_A', 'data_B', 'output']
+ # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
+ # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
+ self.model_names = ['G']
+ # define networks; you can use opt.isTrain to specify different behaviors for training and test.
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
+ if self.isTrain: # only defined during training time
+ # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
+ # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
+ self.criterionLoss = torch.nn.L1Loss()
+ # define and initialize optimizers. You can define one optimizer for each network.
+ # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers = [self.optimizer]
+
+ # Our program will automatically call to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B
+ self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
+ self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
+ self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
+
+ def forward(self):
+ """Run forward pass. This will be called by both functions and ."""
+ self.output = self.netG(self.data_A) # generate output image given the input data_A
+
+ def backward(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ # caculate the intermediate results if necessary; here self.output has been computed during function
+ # calculate loss given the input and intermediate results
+ self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
+ self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
+
+ def optimize_parameters(self):
+ """Update network weights; it will be called in every training iteration."""
+ self.forward() # first call forward to calculate intermediate results
+ self.optimizer.zero_grad() # clear network G's existing gradients
+ self.backward() # calculate gradients for network G
+ self.optimizer.step() # update gradients for network G
diff --git a/sadtalker_video2pose/src/face3d/options/__init__.py b/sadtalker_video2pose/src/face3d/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..06559aa558cf178b946c4523b28b098d1dfad606
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/options/__init__.py
@@ -0,0 +1 @@
+"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
diff --git a/sadtalker_video2pose/src/face3d/options/base_options.py b/sadtalker_video2pose/src/face3d/options/base_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a6db3f776b11a3946eaed1a41aae732ff3a15d9
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/options/base_options.py
@@ -0,0 +1,169 @@
+"""This script contains base options for Deep3DFaceRecon_pytorch
+"""
+
+import argparse
+import os
+from util import util
+import numpy as np
+import torch
+import face3d.models as models
+import face3d.data as data
+
+
+class BaseOptions():
+ """This class defines options used during both training and test time.
+
+ It also implements several helper functions such as parsing, printing, and saving the options.
+ It also gathers additional options defined in functions in both dataset class and model class.
+ """
+
+ def __init__(self, cmd_line=None):
+ """Reset the class; indicates the class hasn't been initailized"""
+ self.initialized = False
+ self.cmd_line = None
+ if cmd_line is not None:
+ self.cmd_line = cmd_line.split()
+
+ def initialize(self, parser):
+ """Define the common options that are used in both training and test."""
+ # basic parameters
+ parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models')
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
+ parser.add_argument('--checkpoints_dir', type=str, default='./ckpts/sad_talkers', help='models are saved here')
+ parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization')
+ parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation')
+ parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel')
+ parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port')
+ parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses')
+ parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard')
+ parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation')
+
+ # model parameters
+ parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.')
+
+ # additional parameters
+ parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
+
+ self.initialized = True
+ return parser
+
+ def gather_options(self):
+ """Initialize our parser with basic options(only once).
+ Add additional model-specific and dataset-specific options.
+ These options are defined in the function
+ in model and dataset classes.
+ """
+ if not self.initialized: # check if it has been initialized
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser = self.initialize(parser)
+
+ # get the basic options
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args()
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line)
+
+ # set cuda visible devices
+ os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
+
+ # modify model-related parser options
+ model_name = opt.model
+ model_option_setter = models.get_option_setter(model_name)
+ parser = model_option_setter(parser, self.isTrain)
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args() # parse again with new defaults
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
+
+ # modify dataset-related parser options
+ if opt.dataset_mode:
+ dataset_name = opt.dataset_mode
+ dataset_option_setter = data.get_option_setter(dataset_name)
+ parser = dataset_option_setter(parser, self.isTrain)
+
+ # save and return the parser
+ self.parser = parser
+ if self.cmd_line is None:
+ return parser.parse_args()
+ else:
+ return parser.parse_args(self.cmd_line)
+
+ def print_options(self, opt):
+ """Print and save options
+
+ It will print both current options and default values(if different).
+ It will save options into a text file / [checkpoints_dir] / opt.txt
+ """
+ message = ''
+ message += '----------------- Options ---------------\n'
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = '\t[default: %s]' % str(default)
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+ message += '----------------- End -------------------'
+ print(message)
+
+ # save to the disk
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+ try:
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write(message)
+ opt_file.write('\n')
+ except PermissionError as error:
+ print("permission error {}".format(error))
+ pass
+
+ def parse(self):
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+ opt = self.gather_options()
+ opt.isTrain = self.isTrain # train or test
+
+ # process opt.suffix
+ if opt.suffix:
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+ opt.name = opt.name + suffix
+
+
+ # set gpu ids
+ str_ids = opt.gpu_ids.split(',')
+ gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ gpu_ids.append(id)
+ opt.world_size = len(gpu_ids)
+ # if len(opt.gpu_ids) > 0:
+ # torch.cuda.set_device(gpu_ids[0])
+ if opt.world_size == 1:
+ opt.use_ddp = False
+
+ if opt.phase != 'test':
+ # set continue_train automatically
+ if opt.pretrained_name is None:
+ model_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ else:
+ model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name)
+ if os.path.isdir(model_dir):
+ model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')]
+ if os.path.isdir(model_dir) and len(model_pths) != 0:
+ opt.continue_train= True
+
+ # update the latest epoch count
+ if opt.continue_train:
+ if opt.epoch == 'latest':
+ epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i]
+ if len(epoch_counts) != 0:
+ opt.epoch_count = max(epoch_counts) + 1
+ else:
+ opt.epoch_count = int(opt.epoch) + 1
+
+
+ self.print_options(opt)
+ self.opt = opt
+ return self.opt
diff --git a/sadtalker_video2pose/src/face3d/options/inference_options.py b/sadtalker_video2pose/src/face3d/options/inference_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b9466776e120e0fe3d164217df5071c2114cef
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/options/inference_options.py
@@ -0,0 +1,23 @@
+from face3d.options.base_options import BaseOptions
+
+
+class InferenceOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
+
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files')
+ parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients')
+ parser.add_argument('--save_split_files', action='store_true', help='save split files or not')
+ parser.add_argument('--inference_batch_size', type=int, default=8)
+
+ # Dropout and Batchnorm has different behavior during training and test.
+ self.isTrain = False
+ return parser
diff --git a/sadtalker_video2pose/src/face3d/options/test_options.py b/sadtalker_video2pose/src/face3d/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81c0c6eee0549e6fa8762dc4fc4b8573b887fe4
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/options/test_options.py
@@ -0,0 +1,21 @@
+"""This script contains the test options for Deep3DFaceRecon_pytorch
+"""
+
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
+ parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.')
+
+ # Dropout and Batchnorm has different behavior during training and test.
+ self.isTrain = False
+ return parser
diff --git a/sadtalker_video2pose/src/face3d/options/train_options.py b/sadtalker_video2pose/src/face3d/options/train_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..1100b0e35cc8ef563f41f6b8219510edbef53233
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/options/train_options.py
@@ -0,0 +1,53 @@
+"""This script contains the training options for Deep3DFaceRecon_pytorch
+"""
+
+from .base_options import BaseOptions
+from util import util
+
+class TrainOptions(BaseOptions):
+ """This class includes training options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser)
+ # dataset parameters
+ # for train
+ parser.add_argument('--data_root', type=str, default='./', help='dataset root')
+ parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set')
+ parser.add_argument('--batch_size', type=int, default=32)
+ parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]')
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
+ parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+ parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]')
+ parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation')
+
+ # for val
+ parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set')
+ parser.add_argument('--batch_size_val', type=int, default=32)
+
+
+ # visualization parameters
+ parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen')
+ parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+
+ # network saving and loading parameters
+ parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
+ parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
+ parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
+ parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
+ parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+ parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')
+ parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+ parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
+
+ # training parameters
+ parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate')
+ parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
+ parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
+ parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches')
+
+ self.isTrain = True
+ return parser
diff --git a/sadtalker_video2pose/src/face3d/util/BBRegressorParam_r.mat b/sadtalker_video2pose/src/face3d/util/BBRegressorParam_r.mat
new file mode 100644
index 0000000000000000000000000000000000000000..a0da99af145c400a5216d9f6fb251d9412565921
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/BBRegressorParam_r.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a5a07b8ce75a39d96b918dc0fc6e110a72e090da16f5f056a0ef7bfbc3f4560
+size 22019
diff --git a/sadtalker_video2pose/src/face3d/util/__init__.py b/sadtalker_video2pose/src/face3d/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c67833cc634a2ca310b883ae253b08687665f40
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/__init__.py
@@ -0,0 +1,3 @@
+"""This package includes a miscellaneous collection of useful helper functions."""
+from src.face3d.util import *
+
diff --git a/sadtalker_video2pose/src/face3d/util/detect_lm68.py b/sadtalker_video2pose/src/face3d/util/detect_lm68.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a2cfd22b342de5c872ff07fc1c2a9920c2985b7
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/detect_lm68.py
@@ -0,0 +1,106 @@
+import os
+import cv2
+import numpy as np
+from scipy.io import loadmat
+import tensorflow as tf
+from util.preprocess import align_for_lm
+from shutil import move
+
+mean_face = np.loadtxt('util/test_mean_face.txt')
+mean_face = mean_face.reshape([68, 2])
+
+def save_label(labels, save_path):
+ np.savetxt(save_path, labels)
+
+def draw_landmarks(img, landmark, save_name):
+ landmark = landmark
+ lm_img = np.zeros([img.shape[0], img.shape[1], 3])
+ lm_img[:] = img.astype(np.float32)
+ landmark = np.round(landmark).astype(np.int32)
+
+ for i in range(len(landmark)):
+ for j in range(-1, 1):
+ for k in range(-1, 1):
+ if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \
+ img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \
+ landmark[i, 0]+k > 0 and \
+ landmark[i, 0]+k < img.shape[1]:
+ lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k,
+ :] = np.array([0, 0, 255])
+ lm_img = lm_img.astype(np.uint8)
+
+ cv2.imwrite(save_name, lm_img)
+
+
+def load_data(img_name, txt_name):
+ return cv2.imread(img_name), np.loadtxt(txt_name)
+
+# create tensorflow graph for landmark detector
+def load_lm_graph(graph_filename):
+ with tf.gfile.GFile(graph_filename, 'rb') as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+
+ with tf.Graph().as_default() as graph:
+ tf.import_graph_def(graph_def, name='net')
+ img_224 = graph.get_tensor_by_name('net/input_imgs:0')
+ output_lm = graph.get_tensor_by_name('net/lm:0')
+ lm_sess = tf.Session(graph=graph)
+
+ return lm_sess,img_224,output_lm
+
+# landmark detection
+def detect_68p(img_path,sess,input_op,output_op):
+ print('detecting landmarks......')
+ names = [i for i in sorted(os.listdir(
+ img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
+ vis_path = os.path.join(img_path, 'vis')
+ remove_path = os.path.join(img_path, 'remove')
+ save_path = os.path.join(img_path, 'landmarks')
+ if not os.path.isdir(vis_path):
+ os.makedirs(vis_path)
+ if not os.path.isdir(remove_path):
+ os.makedirs(remove_path)
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ for i in range(0, len(names)):
+ name = names[i]
+ print('%05d' % (i), ' ', name)
+ full_image_name = os.path.join(img_path, name)
+ txt_name = '.'.join(name.split('.')[:-1]) + '.txt'
+ full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image
+
+ # if an image does not have detected 5 facial landmarks, remove it from the training list
+ if not os.path.isfile(full_txt_name):
+ move(full_image_name, os.path.join(remove_path, name))
+ continue
+
+ # load data
+ img, five_points = load_data(full_image_name, full_txt_name)
+ input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection
+
+ # if the alignment fails, remove corresponding image from the training list
+ if scale == 0:
+ move(full_txt_name, os.path.join(
+ remove_path, txt_name))
+ move(full_image_name, os.path.join(remove_path, name))
+ continue
+
+ # detect landmarks
+ input_img = np.reshape(
+ input_img, [1, 224, 224, 3]).astype(np.float32)
+ landmark = sess.run(
+ output_op, feed_dict={input_op: input_img})
+
+ # transform back to original image coordinate
+ landmark = landmark.reshape([68, 2]) + mean_face
+ landmark[:, 1] = 223 - landmark[:, 1]
+ landmark = landmark / scale
+ landmark[:, 0] = landmark[:, 0] + bbox[0]
+ landmark[:, 1] = landmark[:, 1] + bbox[1]
+ landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]
+
+ if i % 100 == 0:
+ draw_landmarks(img, landmark, os.path.join(vis_path, name))
+ save_label(landmark, os.path.join(save_path, txt_name))
diff --git a/sadtalker_video2pose/src/face3d/util/generate_list.py b/sadtalker_video2pose/src/face3d/util/generate_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe93fcc5c61fbc79f4cd004a8d1bdd10ece16eb
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/generate_list.py
@@ -0,0 +1,34 @@
+"""This script is to generate training list files for Deep3DFaceRecon_pytorch
+"""
+
+import os
+
+# save path to training data
+def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''):
+ save_path = os.path.join(save_folder, mode)
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+ with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in lms_list])
+
+ with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in imgs_list])
+
+ with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in msks_list])
+
+# check if the path is valid
+def check_list(rlms_list, rimgs_list, rmsks_list):
+ lms_list, imgs_list, msks_list = [], [], []
+ for i in range(len(rlms_list)):
+ flag = 'false'
+ lm_path = rlms_list[i]
+ im_path = rimgs_list[i]
+ msk_path = rmsks_list[i]
+ if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):
+ flag = 'true'
+ lms_list.append(rlms_list[i])
+ imgs_list.append(rimgs_list[i])
+ msks_list.append(rmsks_list[i])
+ print(i, rlms_list[i], flag)
+ return lms_list, imgs_list, msks_list
diff --git a/sadtalker_video2pose/src/face3d/util/html.py b/sadtalker_video2pose/src/face3d/util/html.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0c4e6a66ba5a34e30cee3beb13e21465c72ef38
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/html.py
@@ -0,0 +1,86 @@
+import dominate
+from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+import os
+
+
+class HTML:
+ """This HTML class allows us to save images and write texts into a single HTML file.
+
+ It consists of functions such as (add a text header to the HTML file),
+ (add a row of images to the HTML file), and (save the HTML to the disk).
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+ """
+
+ def __init__(self, web_dir, title, refresh=0):
+ """Initialize the HTML classes
+
+ Parameters:
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0:
+ with self.doc.head:
+ meta(http_equiv="refresh", content=str(refresh))
+
+ def get_image_dir(self):
+ """Return the directory that stores images"""
+ return self.img_dir
+
+ def add_header(self, text):
+ """Insert a header to the HTML file
+
+ Parameters:
+ text (str) -- the header text
+ """
+ with self.doc:
+ h3(text)
+
+ def add_images(self, ims, txts, links, width=400):
+ """add images to the HTML file
+
+ Parameters:
+ ims (str list) -- a list of image paths
+ txts (str list) -- a list of image names shown on the website
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
+ """
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
+ self.doc.add(self.t)
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
+ br()
+ p(txt)
+
+ def save(self):
+ """save the current content to the HMTL file"""
+ html_file = '%s/index.html' % self.web_dir
+ f = open(html_file, 'wt')
+ f.write(self.doc.render())
+ f.close()
+
+
+if __name__ == '__main__': # we show an example usage here.
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims, txts, links = [], [], []
+ for n in range(4):
+ ims.append('image_%d.png' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.png' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/sadtalker_video2pose/src/face3d/util/load_mats.py b/sadtalker_video2pose/src/face3d/util/load_mats.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7ea0a7877e80035883138415c102910d896bb61
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/load_mats.py
@@ -0,0 +1,120 @@
+"""This script is to load 3D face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from PIL import Image
+from scipy.io import loadmat, savemat
+from array import array
+import os.path as osp
+
+# load expression basis
+def LoadExpBasis(bfm_folder='BFM'):
+ n_vertex = 53215
+ Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb')
+ exp_dim = array('i')
+ exp_dim.fromfile(Expbin, 1)
+ expMU = array('f')
+ expPC = array('f')
+ expMU.fromfile(Expbin, 3*n_vertex)
+ expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex)
+ Expbin.close()
+
+ expPC = np.array(expPC)
+ expPC = np.reshape(expPC, [exp_dim[0], -1])
+ expPC = np.transpose(expPC)
+
+ expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt'))
+
+ return expPC, expEV
+
+
+# transfer original BFM09 to our face model
+def transferBFM09(bfm_folder='BFM'):
+ print('Transfer BFM09 to BFM_model_front......')
+ original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat'))
+ shapePC = original_BFM['shapePC'] # shape basis
+ shapeEV = original_BFM['shapeEV'] # corresponding eigen value
+ shapeMU = original_BFM['shapeMU'] # mean face
+ texPC = original_BFM['texPC'] # texture basis
+ texEV = original_BFM['texEV'] # eigen value
+ texMU = original_BFM['texMU'] # mean texture
+
+ expPC, expEV = LoadExpBasis(bfm_folder)
+
+ # transfer BFM09 to our face model
+
+ idBase = shapePC*np.reshape(shapeEV, [-1, 199])
+ idBase = idBase/1e5 # unify the scale to decimeter
+ idBase = idBase[:, :80] # use only first 80 basis
+
+ exBase = expPC*np.reshape(expEV, [-1, 79])
+ exBase = exBase/1e5 # unify the scale to decimeter
+ exBase = exBase[:, :64] # use only first 64 basis
+
+ texBase = texPC*np.reshape(texEV, [-1, 199])
+ texBase = texBase[:, :80] # use only first 80 basis
+
+ # our face model is cropped along face landmarks and contains only 35709 vertex.
+ # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex.
+ # thus we select corresponding vertex to get our face model.
+
+ index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat'))
+ index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215)
+
+ index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat'))
+ index_shape = index_shape['trimIndex'].astype(
+ np.int32) - 1 # starts from 0 (to 53490)
+ index_shape = index_shape[index_exp]
+
+ idBase = np.reshape(idBase, [-1, 3, 80])
+ idBase = idBase[index_shape, :, :]
+ idBase = np.reshape(idBase, [-1, 80])
+
+ texBase = np.reshape(texBase, [-1, 3, 80])
+ texBase = texBase[index_shape, :, :]
+ texBase = np.reshape(texBase, [-1, 80])
+
+ exBase = np.reshape(exBase, [-1, 3, 64])
+ exBase = exBase[index_exp, :, :]
+ exBase = np.reshape(exBase, [-1, 64])
+
+ meanshape = np.reshape(shapeMU, [-1, 3])/1e5
+ meanshape = meanshape[index_shape, :]
+ meanshape = np.reshape(meanshape, [1, -1])
+
+ meantex = np.reshape(texMU, [-1, 3])
+ meantex = meantex[index_shape, :]
+ meantex = np.reshape(meantex, [1, -1])
+
+ # other info contains triangles, region used for computing photometric loss,
+ # region used for skin texture regularization, and 68 landmarks index etc.
+ other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat'))
+ frontmask2_idx = other_info['frontmask2_idx']
+ skinmask = other_info['skinmask']
+ keypoints = other_info['keypoints']
+ point_buf = other_info['point_buf']
+ tri = other_info['tri']
+ tri_mask2 = other_info['tri_mask2']
+
+ # save our face model
+ savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase,
+ 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask})
+
+
+# load landmarks for standard face, which is used for image preprocessing
+def load_lm3d(bfm_folder):
+
+ Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat'))
+ Lm3D = Lm3D['lm']
+
+ # calculate 5 facial landmarks using 68 landmarks
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(
+ Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)
+ Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
+
+ return Lm3D
+
+
+if __name__ == '__main__':
+ transferBFM09()
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/face3d/util/nvdiffrast.py b/sadtalker_video2pose/src/face3d/util/nvdiffrast.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b345db30085de501b6718ad5b49bb5f9144dd29
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/nvdiffrast.py
@@ -0,0 +1,126 @@
+"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
+ Attention, antialiasing step is missing in current version.
+"""
+import pytorch3d.ops
+import torch
+import torch.nn.functional as F
+import kornia
+from kornia.geometry.camera import pixel2cam
+import numpy as np
+from typing import List
+from scipy.io import loadmat
+from torch import nn
+
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import (
+ look_at_view_transform,
+ FoVPerspectiveCameras,
+ DirectionalLights,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRasterizer,
+ SoftPhongShader,
+ TexturesUV,
+)
+
+# def ndc_projection(x=0.1, n=1.0, f=50.0):
+# return np.array([[n/x, 0, 0, 0],
+# [ 0, n/-x, 0, 0],
+# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
+# [ 0, 0, -1, 0]]).astype(np.float32)
+
+class MeshRenderer(nn.Module):
+ def __init__(self,
+ rasterize_fov,
+ znear=0.1,
+ zfar=10,
+ rasterize_size=224):
+ super(MeshRenderer, self).__init__()
+
+ # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
+ # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
+ # torch.diag(torch.tensor([1., -1, -1, 1])))
+ self.rasterize_size = rasterize_size
+ self.fov = rasterize_fov
+ self.znear = znear
+ self.zfar = zfar
+
+ self.rasterizer = None
+
+ def forward(self, vertex, tri, feat=None):
+ """
+ Return:
+ mask -- torch.tensor, size (B, 1, H, W)
+ depth -- torch.tensor, size (B, 1, H, W)
+ features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
+
+ Parameters:
+ vertex -- torch.tensor, size (B, N, 3)
+ tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
+ feat(optional) -- torch.tensor, size (B, N ,C), features
+ """
+ device = vertex.device
+ rsize = int(self.rasterize_size)
+ # ndc_proj = self.ndc_proj.to(device)
+ # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
+ if vertex.shape[-1] == 3:
+ vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
+ vertex[..., 0] = -vertex[..., 0]
+
+
+ # vertex_ndc = vertex @ ndc_proj.t()
+ if self.rasterizer is None:
+ self.rasterizer = MeshRasterizer()
+ print("create rasterizer on device cuda:%d"%device.index)
+
+ # ranges = None
+ # if isinstance(tri, List) or len(tri.shape) == 3:
+ # vum = vertex_ndc.shape[1]
+ # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
+ # fstartidx = torch.cumsum(fnum, dim=0) - fnum
+ # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
+ # for i in range(tri.shape[0]):
+ # tri[i] = tri[i] + i*vum
+ # vertex_ndc = torch.cat(vertex_ndc, dim=0)
+ # tri = torch.cat(tri, dim=0)
+
+ # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3]
+ tri = tri.type(torch.int32).contiguous()
+
+ # rasterize
+ cameras = FoVPerspectiveCameras(
+ device=device,
+ fov=self.fov,
+ znear=self.znear,
+ zfar=self.zfar,
+ )
+
+ raster_settings = RasterizationSettings(
+ image_size=rsize
+ )
+
+ # print(vertex.shape, tri.shape)
+ mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1)))
+
+ fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
+ rast_out = fragments.pix_to_face.squeeze(-1)
+ depth = fragments.zbuf
+
+ # render depth
+ depth = depth.permute(0, 3, 1, 2)
+ mask = (rast_out > 0).float().unsqueeze(1)
+ depth = mask * depth
+
+
+ image = None
+ if feat is not None:
+ attributes = feat.reshape(-1,3)[mesh.faces_packed()]
+ image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
+ fragments.bary_coords,
+ attributes)
+ # print(image.shape)
+ image = image.squeeze(-2).permute(0, 3, 1, 2)
+ image = mask * image
+
+ return mask, depth, image
+
diff --git a/sadtalker_video2pose/src/face3d/util/preprocess.py b/sadtalker_video2pose/src/face3d/util/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..82b36443fe4c84c1ad6366897a8e7d4e8b63b2b6
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/preprocess.py
@@ -0,0 +1,134 @@
+"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from scipy.io import loadmat
+from PIL import Image
+import cv2
+import os
+from skimage import transform as trans
+import torch
+import warnings
+warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+# calculating least square problem for image alignment
+def POS(xp, x):
+ npts = xp.shape[1]
+
+ A = np.zeros([2*npts, 8])
+
+ A[0:2*npts-1:2, 0:3] = x.transpose()
+ A[0:2*npts-1:2, 3] = 1
+
+ A[1:2*npts:2, 4:7] = x.transpose()
+ A[1:2*npts:2, 7] = 1
+
+ b = np.reshape(xp.transpose(), [2*npts, 1])
+
+ k, _, _, _ = np.linalg.lstsq(A, b)
+
+ R1 = k[0:3]
+ R2 = k[4:7]
+ sTx = k[3]
+ sTy = k[7]
+ s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2
+ t = np.stack([sTx, sTy], axis=0)
+
+ return t, s
+
+# # resize and crop images for face reconstruction
+# def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
+# w0, h0 = img.size
+# w = (w0*s).astype(np.int32)
+# h = (h0*s).astype(np.int32)
+# left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
+# right = left + target_size
+# up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
+# below = up + target_size
+
+# img = img.resize((w, h), resample=Image.BICUBIC)
+# img = img.crop((left, up, right, below))
+
+# if mask is not None:
+# mask = mask.resize((w, h), resample=Image.BICUBIC)
+# mask = mask.crop((left, up, right, below))
+
+# lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
+# t[1] + h0/2], axis=1)*s
+# lm = lm - np.reshape(
+# np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
+
+# return img, lm, mask
+
+
+# resize and crop images for face reconstruction
+def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
+ w0, h0 = img.size
+ w = (w0*s).astype(np.int32)
+ h = (h0*s).astype(np.int32)
+ left = np.round(w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
+ right = left + target_size
+ up = np.round(h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
+ below = up + target_size
+
+ img = img.resize((w, h), resample=Image.BICUBIC)
+ img = img.crop((left, up, right, below))
+ # import pdb; pdb.set_trace()
+ if mask is not None:
+ mask = mask.resize((w, h), resample=Image.BICUBIC)
+ mask = mask.crop((left, up, right, below))
+
+ lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
+ t[1] + h0/2], axis=1)*s
+ lm = lm - np.reshape(
+ np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
+
+ # orig_left, orig_up, orig_crop_size = (left,up,target_size)/s
+
+ return img, lm, mask, left, up, target_size
+
+# utils for face reconstruction
+def extract_5p(lm):
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean(
+ lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0)
+ lm5p = lm5p[[1, 2, 0, 3, 4], :]
+ return lm5p
+
+# utils for face reconstruction
+def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.):
+ """
+ Return:
+ transparams --numpy.array (raw_W, raw_H, scale, tx, ty)
+ img_new --PIL.Image (target_size, target_size, 3)
+ lm_new --numpy.array (68, 2), y direction is opposite to v direction
+ mask_new --PIL.Image (target_size, target_size)
+
+ Parameters:
+ img --PIL.Image (raw_H, raw_W, 3)
+ lm --numpy.array (68, 2), y direction is opposite to v direction
+ lm3D --numpy.array (5, 3)
+ mask --PIL.Image (raw_H, raw_W, 3)
+ """
+
+ w0, h0 = img.size
+ if lm.shape[0] != 5:
+ lm5p = extract_5p(lm)
+ else:
+ lm5p = lm
+
+ # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
+ t, s = POS(lm5p.transpose(), lm3D.transpose())
+ s = rescale_factor/s
+
+ # processing the image
+
+ # processing the image
+ img_new, lm_new, mask_new, orig_left, orig_up, orig_crop_size = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
+ trans_params = np.array([w0, h0, s, t[0], t[1], orig_left, orig_up, orig_crop_size])
+ # img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
+ # trans_params = np.array([w0, h0, s, t[0], t[1]])
+
+ return trans_params, img_new, lm_new, mask_new
diff --git a/sadtalker_video2pose/src/face3d/util/skin_mask.py b/sadtalker_video2pose/src/face3d/util/skin_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed764759038f77b35d45448b344d4347498ca427
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/skin_mask.py
@@ -0,0 +1,125 @@
+"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch
+"""
+
+import math
+import numpy as np
+import os
+import cv2
+
+class GMM:
+ def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):
+ self.dim = dim # feature dimension
+ self.num = num # number of Gaussian components
+ self.w = w # weights of Gaussian components (a list of scalars)
+ self.mu= mu # mean of Gaussian components (a list of 1xdim vectors)
+ self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices)
+ self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars)
+ self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices)
+
+ self.factor = [0]*num
+ for i in range(self.num):
+ self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5
+
+ def likelihood(self, data):
+ assert(data.shape[1] == self.dim)
+ N = data.shape[0]
+ lh = np.zeros(N)
+
+ for i in range(self.num):
+ data_ = data - self.mu[i]
+
+ tmp = np.matmul(data_,self.cov_inv[i]) * data_
+ tmp = np.sum(tmp,axis=1)
+ power = -0.5 * tmp
+
+ p = np.array([math.exp(power[j]) for j in range(N)])
+ p = p/self.factor[i]
+ lh += p*self.w[i]
+
+ return lh
+
+
+def _rgb2ycbcr(rgb):
+ m = np.array([[65.481, 128.553, 24.966],
+ [-37.797, -74.203, 112],
+ [112, -93.786, -18.214]])
+ shape = rgb.shape
+ rgb = rgb.reshape((shape[0] * shape[1], 3))
+ ycbcr = np.dot(rgb, m.transpose() / 255.)
+ ycbcr[:, 0] += 16.
+ ycbcr[:, 1:] += 128.
+ return ycbcr.reshape(shape)
+
+
+def _bgr2ycbcr(bgr):
+ rgb = bgr[..., ::-1]
+ return _rgb2ycbcr(rgb)
+
+
+gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415]
+gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]),
+ np.array([150.19858, 105.18467, 155.51428]),
+ np.array([183.92976, 107.62468, 152.71820]),
+ np.array([114.90524, 113.59782, 151.38217])]
+gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.]
+gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]),
+ np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]),
+ np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]),
+ np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])]
+
+gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv)
+
+gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393]
+gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]),
+ np.array([110.91392, 125.52969, 130.19237]),
+ np.array([129.75864, 129.96107, 126.96808]),
+ np.array([112.29587, 128.85121, 129.05431])]
+gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63]
+gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]),
+ np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]),
+ np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]),
+ np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])]
+
+gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv)
+
+prior_skin = 0.8
+prior_nonskin = 1 - prior_skin
+
+
+# calculate skin attention mask
+def skinmask(imbgr):
+ im = _bgr2ycbcr(imbgr)
+
+ data = im.reshape((-1,3))
+
+ lh_skin = gmm_skin.likelihood(data)
+ lh_nonskin = gmm_nonskin.likelihood(data)
+
+ tmp1 = prior_skin * lh_skin
+ tmp2 = prior_nonskin * lh_nonskin
+ post_skin = tmp1 / (tmp1+tmp2) # posterior probability
+
+ post_skin = post_skin.reshape((im.shape[0],im.shape[1]))
+
+ post_skin = np.round(post_skin*255)
+ post_skin = post_skin.astype(np.uint8)
+ post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3
+
+ return post_skin
+
+
+def get_skin_mask(img_path):
+ print('generating skin masks......')
+ names = [i for i in sorted(os.listdir(
+ img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
+ save_path = os.path.join(img_path, 'mask')
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ for i in range(0, len(names)):
+ name = names[i]
+ print('%05d' % (i), ' ', name)
+ full_image_name = os.path.join(img_path, name)
+ img = cv2.imread(full_image_name).astype(np.float32)
+ skin_img = skinmask(img)
+ cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8))
diff --git a/sadtalker_video2pose/src/face3d/util/test_mean_face.txt b/sadtalker_video2pose/src/face3d/util/test_mean_face.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1637648acf5a61cbc71b317c845414bb16d0150c
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/test_mean_face.txt
@@ -0,0 +1,136 @@
+-5.228591537475585938e+01
+2.078247070312500000e-01
+-5.064269638061523438e+01
+-1.315765380859375000e+01
+-4.952939224243164062e+01
+-2.592591094970703125e+01
+-4.793047332763671875e+01
+-3.832135772705078125e+01
+-4.512159729003906250e+01
+-5.059623336791992188e+01
+-3.917720794677734375e+01
+-6.043736648559570312e+01
+-2.929953765869140625e+01
+-6.861183166503906250e+01
+-1.719801330566406250e+01
+-7.572736358642578125e+01
+-1.961936950683593750e+00
+-7.862001037597656250e+01
+1.467941284179687500e+01
+-7.607844543457031250e+01
+2.744073486328125000e+01
+-6.915261840820312500e+01
+3.855677795410156250e+01
+-5.950350570678710938e+01
+4.478240966796875000e+01
+-4.867547225952148438e+01
+4.714337158203125000e+01
+-3.800830078125000000e+01
+4.940315246582031250e+01
+-2.496297454833984375e+01
+5.117234802246093750e+01
+-1.241538238525390625e+01
+5.190507507324218750e+01
+8.244247436523437500e-01
+-4.150688934326171875e+01
+2.386329650878906250e+01
+-3.570307159423828125e+01
+3.017010498046875000e+01
+-2.790358734130859375e+01
+3.212951660156250000e+01
+-1.941773223876953125e+01
+3.156523132324218750e+01
+-1.138106536865234375e+01
+2.841992187500000000e+01
+5.993263244628906250e+00
+2.895182800292968750e+01
+1.343590545654296875e+01
+3.189880371093750000e+01
+2.203153991699218750e+01
+3.302221679687500000e+01
+2.992478942871093750e+01
+3.099150085449218750e+01
+3.628388977050781250e+01
+2.765748596191406250e+01
+-1.933914184570312500e+00
+1.405374145507812500e+01
+-2.153038024902343750e+00
+5.772636413574218750e+00
+-2.270050048828125000e+00
+-2.121643066406250000e+00
+-2.218330383300781250e+00
+-1.068978118896484375e+01
+-1.187252044677734375e+01
+-1.997912597656250000e+01
+-6.879402160644531250e+00
+-2.143579864501953125e+01
+-1.227821350097656250e+00
+-2.193494415283203125e+01
+4.623237609863281250e+00
+-2.152721405029296875e+01
+9.721397399902343750e+00
+-1.953671264648437500e+01
+-3.648714447021484375e+01
+9.811126708984375000e+00
+-3.130242919921875000e+01
+1.422447967529296875e+01
+-2.212834930419921875e+01
+1.493019866943359375e+01
+-1.500880432128906250e+01
+1.073588562011718750e+01
+-2.095037078857421875e+01
+9.054298400878906250e+00
+-3.050099182128906250e+01
+8.704177856445312500e+00
+1.173237609863281250e+01
+1.054329681396484375e+01
+1.856353759765625000e+01
+1.535009765625000000e+01
+2.893331909179687500e+01
+1.451992797851562500e+01
+3.452944946289062500e+01
+1.065280151367187500e+01
+2.875990295410156250e+01
+8.654792785644531250e+00
+1.942100524902343750e+01
+9.422447204589843750e+00
+-2.204488372802734375e+01
+-3.983994293212890625e+01
+-1.324458312988281250e+01
+-3.467377471923828125e+01
+-6.749649047851562500e+00
+-3.092894744873046875e+01
+-9.183349609375000000e-01
+-3.196458435058593750e+01
+4.220649719238281250e+00
+-3.090406036376953125e+01
+1.089889526367187500e+01
+-3.497008514404296875e+01
+1.874589538574218750e+01
+-4.065438079833984375e+01
+1.124106597900390625e+01
+-4.438417816162109375e+01
+5.181709289550781250e+00
+-4.649170684814453125e+01
+-1.158607482910156250e+00
+-4.680406951904296875e+01
+-7.918922424316406250e+00
+-4.671575164794921875e+01
+-1.452505493164062500e+01
+-4.416526031494140625e+01
+-2.005007171630859375e+01
+-3.997841644287109375e+01
+-1.054919433593750000e+01
+-3.849683380126953125e+01
+-1.051826477050781250e+00
+-3.794863128662109375e+01
+6.412681579589843750e+00
+-3.804645538330078125e+01
+1.627674865722656250e+01
+-4.039697265625000000e+01
+6.373878479003906250e+00
+-4.087213897705078125e+01
+-8.551712036132812500e-01
+-4.157129669189453125e+01
+-1.014953613281250000e+01
+-4.128469085693359375e+01
diff --git a/sadtalker_video2pose/src/face3d/util/util.py b/sadtalker_video2pose/src/face3d/util/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c7517ee66c8830a73fa86ab5e5c3513f11d869
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/util.py
@@ -0,0 +1,208 @@
+"""This script contains basic utilities for Deep3DFaceRecon_pytorch
+"""
+from __future__ import print_function
+import numpy as np
+import torch
+from PIL import Image
+import os
+import importlib
+import argparse
+from argparse import Namespace
+import torchvision
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def copyconf(default_opt, **kwargs):
+ conf = Namespace(**vars(default_opt))
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+ return conf
+
+def genvalconf(train_opt, **kwargs):
+ conf = Namespace(**vars(train_opt))
+ attr_dict = train_opt.__dict__
+ for key, value in attr_dict.items():
+ if 'val' in key and key.split('_')[0] in attr_dict:
+ setattr(conf, key.split('_')[0], value)
+
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+
+ return conf
+
+def find_class_in_module(target_cls_name, module):
+ target_cls_name = target_cls_name.replace('_', '').lower()
+ clslib = importlib.import_module(module)
+ cls = None
+ for name, clsobj in clslib.__dict__.items():
+ if name.lower() == target_cls_name:
+ cls = clsobj
+
+ assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
+
+ return cls
+
+
+def tensor2im(input_image, imtype=np.uint8):
+ """"Converts a Tensor array into a numpy image array.
+
+ Parameters:
+ input_image (tensor) -- the input image tensor array, range(0, 1)
+ imtype (type) -- the desired type of the converted numpy array
+ """
+ if not isinstance(input_image, np.ndarray):
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
+ image_tensor = input_image.data
+ else:
+ return input_image
+ image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array
+ if image_numpy.shape[0] == 1: # grayscale to RGB
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling
+ else: # if it is a numpy array, do nothing
+ image_numpy = input_image
+ return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+ """Calculate and print the mean of average absolute(gradients)
+
+ Parameters:
+ net (torch network) -- Torch network
+ name (str) -- the name of the network
+ """
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path, aspect_ratio=1.0):
+ """Save a numpy image to the disk
+
+ Parameters:
+ image_numpy (numpy array) -- input numpy array
+ image_path (str) -- the path of the image
+ """
+
+ image_pil = Image.fromarray(image_numpy)
+ h, w, _ = image_numpy.shape
+
+ if aspect_ratio is None:
+ pass
+ elif aspect_ratio > 1.0:
+ image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+ elif aspect_ratio < 1.0:
+ image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+ image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+ """Print the mean, min, max, median, std, and size of a numpy array
+
+ Parameters:
+ val (bool) -- if print the values of the numpy array
+ shp (bool) -- if print the shape of the numpy array
+ """
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ """create empty directories if they don't exist
+
+ Parameters:
+ paths (str list) -- a list of directory paths
+ """
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ """create a single empty directory if it didn't exist
+
+ Parameters:
+ path (str) -- a single directory path
+ """
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def correct_resize_label(t, size):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i, :1]
+ one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
+ one_np = one_np[:, :, 0]
+ one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
+ resized_t = torch.from_numpy(np.array(one_image)).long()
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+
+def correct_resize(t, size, mode=Image.BICUBIC):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i:i + 1]
+ one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
+ resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+def draw_landmarks(img, landmark, color='r', step=2):
+ """
+ Return:
+ img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255)
+
+
+ Parameters:
+ img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255)
+ landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction
+ color -- str, 'r' or 'b' (red or blue)
+ """
+ if color =='r':
+ c = np.array([255., 0, 0])
+ else:
+ c = np.array([0, 0, 255.])
+
+ _, H, W, _ = img.shape
+ img, landmark = img.copy(), landmark.copy()
+ landmark[..., 1] = H - 1 - landmark[..., 1]
+ landmark = np.round(landmark).astype(np.int32)
+ for i in range(landmark.shape[1]):
+ x, y = landmark[:, i, 0], landmark[:, i, 1]
+ for j in range(-step, step):
+ for k in range(-step, step):
+ u = np.clip(x + j, 0, W - 1)
+ v = np.clip(y + k, 0, H - 1)
+ for m in range(landmark.shape[0]):
+ img[m, v[m], u[m]] = c
+ return img
diff --git a/sadtalker_video2pose/src/face3d/util/visualizer.py b/sadtalker_video2pose/src/face3d/util/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4a8b755e054a4a34d003962a723ef189726a7a0
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/util/visualizer.py
@@ -0,0 +1,227 @@
+"""This script defines the visualizer for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import os
+import sys
+import ntpath
+import time
+from . import util, html
+from subprocess import Popen, PIPE
+from torch.utils.tensorboard import SummaryWriter
+
+def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+ """Save images to the disk.
+
+ Parameters:
+ webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
+ visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
+ image_path (str) -- the string is used to create image paths
+ aspect_ratio (float) -- the aspect ratio of saved images
+ width (int) -- the images will be resized to width x width
+
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+ """
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims, txts, links = [], [], []
+
+ for label, im_data in visuals.items():
+ im = util.tensor2im(im_data)
+ image_name = '%s/%s.png' % (label, name)
+ os.makedirs(os.path.join(image_dir, label), exist_ok=True)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ webpage.add_images(ims, txts, links, width=width)
+
+
+class Visualizer():
+ """This class includes several functions that can display/save images and print/save logging information.
+
+ It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+ """
+
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: create a tensorboard writer
+ Step 3: create an HTML object for saveing HTML filters
+ Step 4: create a logging file to store training losses
+ """
+ self.opt = opt # cache the option
+ self.use_html = opt.isTrain and not opt.no_html
+ self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name))
+ self.win_size = opt.display_winsize
+ self.name = opt.name
+ self.saved = False
+ if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+ def reset(self):
+ """Reset the self.saved status"""
+ self.saved = False
+
+
+ def display_current_results(self, visuals, total_iters, epoch, save_result):
+ """Display current results on tensorboad; save current results to an HTML file.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ total_iters (int) -- total iterations
+ epoch (int) - - the current epoch
+ save_result (bool) - - if save the current results to an HTML file
+ """
+ for label, image in visuals.items():
+ self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC')
+
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
+ self.saved = True
+ # save images to the disk
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+ util.save_image(image_numpy, img_path)
+
+ # update website
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims, txts, links = [], [], []
+
+ for label, image_numpy in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = 'epoch%.3d_%s.png' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ webpage.save()
+
+ def plot_current_losses(self, total_iters, losses):
+ # G_loss_collection = {}
+ # D_loss_collection = {}
+ # for name, value in losses.items():
+ # if 'G' in name or 'NCE' in name or 'idt' in name:
+ # G_loss_collection[name] = value
+ # else:
+ # D_loss_collection[name] = value
+ # self.writer.add_scalars('G_collec', G_loss_collection, total_iters)
+ # self.writer.add_scalars('D_collec', D_loss_collection, total_iters)
+ for name, value in losses.items():
+ self.writer.add_scalar(name, value, total_iters)
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message) # save the message
+
+
+class MyVisualizer:
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: create a tensorboard writer
+ Step 3: create an HTML object for saveing HTML filters
+ Step 4: create a logging file to store training losses
+ """
+ self.opt = opt # cache the optio
+ self.name = opt.name
+ self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results')
+
+ if opt.phase != 'test':
+ self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs'))
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+
+ def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None,
+ add_image=True):
+ """Display current results on tensorboad; save current results to an HTML file.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ total_iters (int) -- total iterations
+ epoch (int) - - the current epoch
+ dataset (str) - - 'train' or 'val' or 'test'
+ """
+ # if (not add_image) and (not save_results): return
+
+ for label, image in visuals.items():
+ for i in range(image.shape[0]):
+ image_numpy = util.tensor2im(image[i])
+ if add_image:
+ self.writer.add_image(label + '%s_%02d'%(dataset, i + count),
+ image_numpy, total_iters, dataformats='HWC')
+
+ if save_results:
+ save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters))
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ if name is not None:
+ img_path = os.path.join(save_path, '%s.png' % name)
+ else:
+ img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count))
+ util.save_image(image_numpy, img_path)
+
+
+ def plot_current_losses(self, total_iters, losses, dataset='train'):
+ for name, value in losses.items():
+ self.writer.add_scalar(name + '/%s'%dataset, value, total_iters)
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (
+ dataset, epoch, iters, t_comp, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message) # save the message
diff --git a/sadtalker_video2pose/src/face3d/visualize.py b/sadtalker_video2pose/src/face3d/visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb8791ec30fb8f748aefc82cf4385444754825a4
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/visualize.py
@@ -0,0 +1,133 @@
+# check the sync of 3dmm feature and the audio
+import shutil
+import cv2
+import numpy as np
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.facerecon_model import FaceReconModel
+import torch
+import subprocess, platform
+import scipy.io as scio
+from tqdm import tqdm
+
+
+def draw_landmarks(image, landmarks):
+ for i, point in enumerate(landmarks):
+ cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 255, 0), -1)
+ cv2.putText(image, str(i), (int(point[0]), int(point[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+ return image
+
+# draft
+def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, save_lmk_path, crop_info, extended_crop = False):
+
+ coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm']
+ info = scio.loadmat(first_frame_coeff)['trans_params'][0]
+ print(info)
+
+ coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm']
+
+ # print(coeff_pred.shape)
+ # print(coeff_pred[1:, 64:].shape)
+
+ if args.still:
+ coeff_pred[1:, 64:] = np.stack([coeff_pred[0, 64:]]*coeff_pred[1:, 64:].shape[0])
+
+ # assert False
+
+ coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257
+
+ coeff_full[:, 80:144] = coeff_pred[:, 0:64]
+ coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation
+ coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ tmp_video_path = '/tmp/face3dtmp.mp4'
+ facemodel = FaceReconModel(args)
+ im0 = cv2.imread(args.source_image)
+
+ video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224))
+
+ # since we resize the video, we first need to resize the landmark to the cropped size resolution
+ # then, we need to add it back to the original video
+ x_scale, y_scale = (ox2 - ox1)/256 , (oy2 - oy1)/256
+
+ W, H = im0.shape[0], im0.shape[1]
+
+ _, _, s, _, _, orig_left, orig_up, orig_crop_size =(info[0], info[1], info[2], info[3], info[4], info[5], info[6], info[7])
+ orig_left, orig_up, orig_crop_size = [int(x) for x in (orig_left, orig_up, orig_crop_size)]
+
+ landmark_scale = np.array([[x_scale, y_scale]])
+ landmark_shift = np.array([[orig_left, orig_up]])
+ landmark_shift2 = np.array([[ox1, oy1]])
+
+
+ landmarks = []
+
+ for k in tqdm(range(coeff_first.shape[0]), '1st:'):
+ cur_coeff_full = torch.tensor(coeff_first, device=device)
+
+ facemodel.forward(cur_coeff_full, device)
+
+ predicted_landmark = facemodel.pred_lm # TODO.
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
+
+ predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1]
+
+ predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2
+
+ landmarks.append(predicted_landmark)
+
+ print(orig_up, orig_left, orig_crop_size, s)
+
+ for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'):
+ cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device)
+
+ facemodel.forward(cur_coeff_full, device)
+
+ predicted_landmark = facemodel.pred_lm # TODO.
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
+
+ predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1]
+
+ predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2
+
+ landmarks.append(predicted_landmark)
+
+ rendered_img = facemodel.pred_face
+ rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0)
+ out_img = rendered_img[:, :, :3].astype(np.uint8)
+
+ video.write(np.uint8(out_img[:,:,::-1]))
+
+ video.release()
+
+ # visualize landmarks
+ video = cv2.VideoWriter(save_lmk_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (im0.shape[0], im0.shape[1]))
+
+ for k in tqdm(range(len(landmarks)), 'face3d vis:'):
+ # im = draw_landmarks(im0.copy(), landmarks[k])
+ im = draw_landmarks(np.uint8(np.ones_like(im0)*255), landmarks[k])
+ video.write(im)
+ video.release()
+
+ shutil.copyfile(args.source_image, save_lmk_path.replace('.mp4', '.png'))
+
+ np.save(save_lmk_path.replace('.mp4', '.npy'), landmarks)
+
+ command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path)
+ subprocess.call(command, shell=platform.system() != 'Windows')
+
diff --git a/sadtalker_video2pose/src/face3d/visualize_fromvideo.py b/sadtalker_video2pose/src/face3d/visualize_fromvideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d74872695739df70ce9009351b7cd78a8cb779
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/visualize_fromvideo.py
@@ -0,0 +1,133 @@
+# check the sync of 3dmm feature and the audio
+import shutil
+import cv2
+import numpy as np
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.facerecon_model import FaceReconModel
+import torch
+import subprocess, platform
+import scipy.io as scio
+from tqdm import tqdm
+
+
+def draw_landmarks(image, landmarks):
+ for i, point in enumerate(landmarks):
+ cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 255, 0), -1)
+ cv2.putText(image, str(i), (int(point[0]), int(point[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+ return image
+
+# draft
+def gen_composed_video(args, device, first_frame_coeff, coeff_path, save_path, save_lmk_path, crop_info, extended_crop = False):
+
+ coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm']
+ info = scio.loadmat(first_frame_coeff)['trans_params'][0]
+ print(info)
+
+ coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm']
+
+ # print(coeff_pred.shape)
+ # print(coeff_pred[1:, 64:].shape)
+
+ if args.still:
+ coeff_pred[1:, 64:] = np.stack([coeff_pred[0, 64:]]*coeff_pred[1:, 64:].shape[0])
+
+ # assert False
+
+ coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257
+
+ coeff_full[:, 80:144] = coeff_pred[:, 0:64]
+ coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation
+ coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ # tmp_video_path = '/tmp/face3dtmp.mp4'
+ facemodel = FaceReconModel(args)
+ im0 = cv2.imread(args.source_image)
+
+ video = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224))
+
+ # since we resize the video, we first need to resize the landmark to the cropped size resolution
+ # then, we need to add it back to the original video
+ x_scale, y_scale = (ox2 - ox1)/256 , (oy2 - oy1)/256
+
+ W, H = im0.shape[0], im0.shape[1]
+
+ _, _, s, _, _, orig_left, orig_up, orig_crop_size =(info[0], info[1], info[2], info[3], info[4], info[5], info[6], info[7])
+ orig_left, orig_up, orig_crop_size = [int(x) for x in (orig_left, orig_up, orig_crop_size)]
+
+ landmark_scale = np.array([[x_scale, y_scale]])
+ landmark_shift = np.array([[orig_left, orig_up]])
+ landmark_shift2 = np.array([[ox1, oy1]])
+
+
+ landmarks = []
+
+ for k in tqdm(range(coeff_first.shape[0]), '1st:'):
+ cur_coeff_full = torch.tensor(coeff_first, device=device)
+
+ facemodel.forward(cur_coeff_full, device)
+
+ predicted_landmark = facemodel.pred_lm # TODO.
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
+
+ predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1]
+
+ predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2
+
+ landmarks.append(predicted_landmark)
+
+ print(orig_up, orig_left, orig_crop_size, s)
+
+ for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'):
+ cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device)
+
+ facemodel.forward(cur_coeff_full, device)
+
+ predicted_landmark = facemodel.pred_lm # TODO.
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
+
+ predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1]
+
+ predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2
+
+ landmarks.append(predicted_landmark)
+
+ rendered_img = facemodel.pred_face
+ rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0)
+ out_img = rendered_img[:, :, :3].astype(np.uint8)
+
+ video.write(np.uint8(out_img[:,:,::-1]))
+
+ video.release()
+
+ # visualize landmarks
+ video = cv2.VideoWriter(save_lmk_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (im0.shape[0], im0.shape[1]))
+
+ for k in tqdm(range(len(landmarks)), 'face3d vis:'):
+ # im = draw_landmarks(im0.copy(), landmarks[k])
+ im = draw_landmarks(np.uint8(np.ones_like(im0)*255), landmarks[k])
+ video.write(im)
+ video.release()
+
+ shutil.copyfile(args.source_image, save_lmk_path.replace('.mp4', '.png'))
+
+ np.save(save_lmk_path.replace('.mp4', '.npy'), landmarks)
+
+ # command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path)
+ # subprocess.call(command, shell=platform.system() != 'Windows')
+
diff --git a/sadtalker_video2pose/src/face3d/visualize_old.py b/sadtalker_video2pose/src/face3d/visualize_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4a37b388320344fd96b4778b60679440fe584c3
--- /dev/null
+++ b/sadtalker_video2pose/src/face3d/visualize_old.py
@@ -0,0 +1,110 @@
+# check the sync of 3dmm feature and the audio
+import shutil
+import cv2
+import numpy as np
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.facerecon_model import FaceReconModel
+import torch
+import subprocess, platform
+import scipy.io as scio
+from tqdm import tqdm
+
+
+def draw_landmarks(image, landmarks):
+ for i, point in enumerate(landmarks):
+ cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 255, 0), -1)
+ cv2.putText(image, str(i), (int(point[0]), int(point[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+ return image
+
+# draft
+def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, save_lmk_path, crop_info, extended_crop = False):
+
+ coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm']
+ info = scio.loadmat(first_frame_coeff)['trans_params'][0]
+ print(info)
+
+ coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm']
+
+ coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257
+
+ coeff_full[:, 80:144] = coeff_pred[:, 0:64]
+ coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation
+ coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ tmp_video_path = '/tmp/face3dtmp.mp4'
+ facemodel = FaceReconModel(args)
+ im0 = cv2.imread(args.source_image)
+
+ video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224))
+
+ # since we resize the video, we first need to resize the landmark to the cropped size resolution
+ # then, we need to add it back to the original video
+ x_scale, y_scale = (ox2 - ox1)/256 , (oy2 - oy1)/256
+
+ W, H = im0.shape[0], im0.shape[1]
+
+ _, _, s, _, _, orig_left, orig_up, orig_crop_size =(info[0], info[1], info[2], info[3], info[4], info[5], info[6], info[7])
+ orig_left, orig_up, orig_crop_size = [int(x) for x in (orig_left, orig_up, orig_crop_size)]
+
+ landmark_scale = np.array([[x_scale, y_scale]])
+ landmark_shift = np.array([[orig_left, orig_up]])
+ landmark_shift2 = np.array([[ox1, oy1]])
+
+ landmarks = []
+
+ print(orig_up, orig_left, orig_crop_size, s)
+
+ for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'):
+ cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device)
+
+ facemodel.forward(cur_coeff_full, device)
+
+ predicted_landmark = facemodel.pred_lm # TODO.
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
+
+ predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1]
+
+ predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2
+
+ landmarks.append(predicted_landmark)
+
+ rendered_img = facemodel.pred_face
+ rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0)
+ out_img = rendered_img[:, :, :3].astype(np.uint8)
+
+ video.write(np.uint8(out_img[:,:,::-1]))
+
+ video.release()
+
+ # visualize landmarks
+ video = cv2.VideoWriter(save_lmk_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (im0.shape[0], im0.shape[1]))
+
+ for k in tqdm(range(len(landmarks)), 'face3d vis:'):
+ # im = draw_landmarks(im0.copy(), landmarks[k])
+ im = draw_landmarks(np.uint8(np.ones_like(im0)*255), landmarks[k])
+ video.write(im)
+ video.release()
+
+ shutil.copyfile(args.source_image, save_lmk_path.replace('.mp4', '.png'))
+
+ np.save(save_lmk_path.replace('.mp4', '.npy'), landmarks)
+
+ command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path)
+ subprocess.call(command, shell=platform.system() != 'Windows')
+
diff --git a/sadtalker_video2pose/src/facerender/animate.py b/sadtalker_video2pose/src/facerender/animate.py
new file mode 100644
index 0000000000000000000000000000000000000000..45fcb45edb4169166b851a066c8aaf08063ed1c6
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/animate.py
@@ -0,0 +1,261 @@
+import os
+import cv2
+import yaml
+import numpy as np
+import warnings
+from skimage import img_as_ubyte
+import safetensors
+import safetensors.torch
+warnings.filterwarnings('ignore')
+
+
+import imageio
+import torch
+import torchvision
+
+
+from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
+from src.facerender.modules.mapping import MappingNet
+from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
+from src.facerender.modules.make_animation import make_animation
+
+from pydub import AudioSegment
+from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
+from src.utils.paste_pic import paste_pic
+from src.utils.videoio import save_video_with_watermark
+
+try:
+ import webui # in webui
+ in_webui = True
+except:
+ in_webui = False
+
+class AnimateFromCoeff():
+
+ def __init__(self, sadtalker_path, device):
+
+ with open(sadtalker_path['facerender_yaml']) as f:
+ config = yaml.safe_load(f)
+
+ generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
+ **config['model_params']['common_params'])
+ kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
+ **config['model_params']['common_params'])
+ he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
+ **config['model_params']['common_params'])
+ mapping = MappingNet(**config['model_params']['mapping_params'])
+
+ generator.to(device)
+ kp_extractor.to(device)
+ he_estimator.to(device)
+ mapping.to(device)
+ for param in generator.parameters():
+ param.requires_grad = False
+ for param in kp_extractor.parameters():
+ param.requires_grad = False
+ for param in he_estimator.parameters():
+ param.requires_grad = False
+ for param in mapping.parameters():
+ param.requires_grad = False
+
+ if sadtalker_path is not None:
+ if 'checkpoint' in sadtalker_path: # use safe tensor
+ self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)
+ else:
+ self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
+ else:
+ raise AttributeError("Checkpoint should be specified for video head pose estimator.")
+
+ if sadtalker_path['mappingnet_checkpoint'] is not None:
+ self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
+ else:
+ raise AttributeError("Checkpoint should be specified for video head pose estimator.")
+
+ self.kp_extractor = kp_extractor
+ self.generator = generator
+ self.he_estimator = he_estimator
+ self.mapping = mapping
+
+ self.kp_extractor.eval()
+ self.generator.eval()
+ self.he_estimator.eval()
+ self.mapping.eval()
+
+ self.device = device
+
+ def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
+ kp_detector=None, he_estimator=None,
+ device="cpu"):
+
+ checkpoint = safetensors.torch.load_file(checkpoint_path)
+
+ if generator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'generator' in k:
+ x_generator[k.replace('generator.', '')] = v
+ generator.load_state_dict(x_generator)
+ if kp_detector is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'kp_extractor' in k:
+ x_generator[k.replace('kp_extractor.', '')] = v
+ kp_detector.load_state_dict(x_generator)
+ if he_estimator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'he_estimator' in k:
+ x_generator[k.replace('he_estimator.', '')] = v
+ he_estimator.load_state_dict(x_generator)
+
+ return None
+
+ def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
+ kp_detector=None, he_estimator=None, optimizer_generator=None,
+ optimizer_discriminator=None, optimizer_kp_detector=None,
+ optimizer_he_estimator=None, device="cpu"):
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if generator is not None:
+ generator.load_state_dict(checkpoint['generator'])
+ if kp_detector is not None:
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
+ if he_estimator is not None:
+ he_estimator.load_state_dict(checkpoint['he_estimator'])
+ if discriminator is not None:
+ try:
+ discriminator.load_state_dict(checkpoint['discriminator'])
+ except:
+ print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
+ if optimizer_generator is not None:
+ optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
+ if optimizer_discriminator is not None:
+ try:
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+ except RuntimeError as e:
+ print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
+ if optimizer_kp_detector is not None:
+ optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
+ if optimizer_he_estimator is not None:
+ optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
+
+ return checkpoint['epoch']
+
+ def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
+ optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if mapping is not None:
+ mapping.load_state_dict(checkpoint['mapping'])
+ if discriminator is not None:
+ discriminator.load_state_dict(checkpoint['discriminator'])
+ if optimizer_mapping is not None:
+ optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
+ if optimizer_discriminator is not None:
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+
+ return checkpoint['epoch']
+
+ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ if 'yaw_c_seq' in x:
+ yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor)
+ yaw_c_seq = x['yaw_c_seq'].to(self.device)
+ else:
+ yaw_c_seq = None
+ if 'pitch_c_seq' in x:
+ pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor)
+ pitch_c_seq = x['pitch_c_seq'].to(self.device)
+ else:
+ pitch_c_seq = None
+ if 'roll_c_seq' in x:
+ roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor)
+ roll_c_seq = x['roll_c_seq'].to(self.device)
+ else:
+ roll_c_seq = None
+
+ frame_num = x['frame_num']
+
+ predictions_video = make_animation(source_image, source_semantics, target_semantics,
+ self.generator, self.kp_extractor, self.he_estimator, self.mapping,
+ yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True)
+
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+ predictions_video = predictions_video[:frame_num]
+
+ video = []
+ for idx in range(predictions_video.shape[0]):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+ result = img_as_ubyte(video)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ original_size = crop_info[0]
+ if original_size:
+ result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+
+ video_name = x['video_name'] + '.mp4'
+ path = os.path.join(video_save_dir, 'temp_'+video_name)
+
+ # print(path)
+
+ imageio.mimsave(path, result, fps=float(25))
+
+ av_path = os.path.join(video_save_dir, video_name)
+ return_path = av_path
+
+ audio_path = x['audio_path']
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+ new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
+ start_time = 0
+ # cog will not keep the .mp3 filename
+ sound = AudioSegment.from_file(audio_path)
+ frames = frame_num
+ end_time = start_time + frames*1/25*1000
+ word1=sound.set_frame_rate(16000)
+ word = word1[start_time:end_time]
+ word.export(new_audio_path, format="wav")
+
+ save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name}')
+
+ if 'full' in preprocess.lower():
+ # only add watermark to the full image.
+ video_name_full = x['video_name'] + '_full.mp4'
+ full_video_path = os.path.join(video_save_dir, video_name_full)
+ return_path = full_video_path
+ paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
+ print(f'The generated video is named {video_save_dir}/{video_name_full}')
+ else:
+ full_video_path = av_path
+
+ #### paste back then enhancers
+ if enhancer:
+ video_name_enhancer = x['video_name'] + '_enhanced.mp4'
+ enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
+ av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
+ return_path = av_path_enhancer
+
+ try:
+ enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+ except:
+ enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+
+ save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
+
+
+ # os.remove(enhanced_path)
+
+ # os.remove(path)
+ # os.remove(new_audio_path)
+
+ return return_path
+
diff --git a/sadtalker_video2pose/src/facerender/modules/dense_motion.py b/sadtalker_video2pose/src/facerender/modules/dense_motion.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c30417870e79bc005ea47a8f383c3aa406df563
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/modules/dense_motion.py
@@ -0,0 +1,121 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+from src.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
+
+
+class DenseMotionNetwork(nn.Module):
+ """
+ Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
+ """
+
+ def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress,
+ estimate_occlusion_map=False):
+ super(DenseMotionNetwork, self).__init__()
+ # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks)
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks)
+
+ self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)
+
+ self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)
+ self.norm = BatchNorm3d(compress, affine=True)
+
+ if estimate_occlusion_map:
+ # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3)
+ self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
+ else:
+ self.occlusion = None
+
+ self.num_kp = num_kp
+
+
+ def create_sparse_motions(self, feature, kp_driving, kp_source):
+ bs, _, d, h, w = feature.shape
+ identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type())
+ identity_grid = identity_grid.view(1, 1, d, h, w, 3)
+ coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3)
+
+ # if 'jacobian' in kp_driving:
+ if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None:
+ jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
+ jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
+ jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1)
+ coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
+ coordinate_grid = coordinate_grid.squeeze(-1)
+
+
+ driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
+
+ #adding background feature
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
+ sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs num_kp+1 d h w 3
+
+ # sparse_motions = driving_to_source
+
+ return sparse_motions
+
+ def create_deformed_feature(self, feature, sparse_motions):
+ bs, _, d, h, w = feature.shape
+ feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
+ feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
+ sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) !!!!
+ sparse_deformed = F.grid_sample(feature_repeat, sparse_motions)
+ sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
+ return sparse_deformed
+
+ def create_heatmap_representations(self, feature, kp_driving, kp_source):
+ spatial_size = feature.shape[3:]
+ gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)
+ gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)
+ heatmap = gaussian_driving - gaussian_source
+
+ # adding background feature
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type())
+ heatmap = torch.cat([zeros, heatmap], dim=1)
+ heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
+ return heatmap
+
+ def forward(self, feature, kp_driving, kp_source):
+ bs, _, d, h, w = feature.shape
+
+ feature = self.compress(feature)
+ feature = self.norm(feature)
+ feature = F.relu(feature)
+
+ out_dict = dict()
+ sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)
+ deformed_feature = self.create_deformed_feature(feature, sparse_motion)
+
+ heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)
+
+ input_ = torch.cat([heatmap, deformed_feature], dim=2)
+ input_ = input_.view(bs, -1, d, h, w)
+
+ # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w)
+
+ prediction = self.hourglass(input_)
+
+
+ mask = self.mask(prediction)
+ mask = F.softmax(mask, dim=1)
+ out_dict['mask'] = mask
+ mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
+
+ zeros_mask = torch.zeros_like(mask)
+ mask = torch.where(mask < 1e-3, zeros_mask, mask)
+
+ sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
+ deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w)
+ deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
+
+ out_dict['deformation'] = deformation
+
+ if self.occlusion:
+ bs, c, d, h, w = prediction.shape
+ prediction = prediction.view(bs, -1, h, w)
+ occlusion_map = torch.sigmoid(self.occlusion(prediction))
+ out_dict['occlusion_map'] = occlusion_map
+
+ return out_dict
diff --git a/sadtalker_video2pose/src/facerender/modules/discriminator.py b/sadtalker_video2pose/src/facerender/modules/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc0a2b460d2175a958d7b230b7e5233d7d7c7f92
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/modules/discriminator.py
@@ -0,0 +1,90 @@
+from torch import nn
+import torch.nn.functional as F
+from facerender.modules.util import kp2gaussian
+import torch
+
+
+class DownBlock2d(nn.Module):
+ """
+ Simple block for processing video (encoder).
+ """
+
+ def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
+ super(DownBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
+
+ if sn:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ if norm:
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
+ else:
+ self.norm = None
+ self.pool = pool
+
+ def forward(self, x):
+ out = x
+ out = self.conv(out)
+ if self.norm:
+ out = self.norm(out)
+ out = F.leaky_relu(out, 0.2)
+ if self.pool:
+ out = F.avg_pool2d(out, (2, 2))
+ return out
+
+
+class Discriminator(nn.Module):
+ """
+ Discriminator similar to Pix2Pix
+ """
+
+ def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
+ sn=False, **kwargs):
+ super(Discriminator, self).__init__()
+
+ down_blocks = []
+ for i in range(num_blocks):
+ down_blocks.append(
+ DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
+
+ self.down_blocks = nn.ModuleList(down_blocks)
+ self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
+ if sn:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ def forward(self, x):
+ feature_maps = []
+ out = x
+
+ for down_block in self.down_blocks:
+ feature_maps.append(down_block(out))
+ out = feature_maps[-1]
+ prediction_map = self.conv(out)
+
+ return feature_maps, prediction_map
+
+
+class MultiScaleDiscriminator(nn.Module):
+ """
+ Multi-scale (scale) discriminator
+ """
+
+ def __init__(self, scales=(), **kwargs):
+ super(MultiScaleDiscriminator, self).__init__()
+ self.scales = scales
+ discs = {}
+ for scale in scales:
+ discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
+ self.discs = nn.ModuleDict(discs)
+
+ def forward(self, x):
+ out_dict = {}
+ for scale, disc in self.discs.items():
+ scale = str(scale).replace('-', '.')
+ key = 'prediction_' + scale
+ feature_maps, prediction_map = disc(x[key])
+ out_dict['feature_maps_' + scale] = feature_maps
+ out_dict['prediction_map_' + scale] = prediction_map
+ return out_dict
diff --git a/sadtalker_video2pose/src/facerender/modules/generator.py b/sadtalker_video2pose/src/facerender/modules/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b94dde7a37c5ddf0f74dd0317a5db3507ab0729
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/modules/generator.py
@@ -0,0 +1,255 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from src.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock
+from src.facerender.modules.dense_motion import DenseMotionNetwork
+
+
+class OcclusionAwareGenerator(nn.Module):
+ """
+ Generator follows NVIDIA architecture.
+ """
+
+ def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
+ num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
+ super(OcclusionAwareGenerator, self).__init__()
+
+ if dense_motion_params is not None:
+ self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
+ estimate_occlusion_map=estimate_occlusion_map,
+ **dense_motion_params)
+ else:
+ self.dense_motion_network = None
+
+ self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3))
+
+ down_blocks = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2 ** i))
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
+
+ self.reshape_channel = reshape_channel
+ self.reshape_depth = reshape_depth
+
+ self.resblocks_3d = torch.nn.Sequential()
+ for i in range(num_resblocks):
+ self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
+
+ out_features = block_expansion * (2 ** (num_down_blocks))
+ self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
+ self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
+
+ self.resblocks_2d = torch.nn.Sequential()
+ for i in range(num_resblocks):
+ self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1))
+
+ up_blocks = []
+ for i in range(num_down_blocks):
+ in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i)))
+ out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1)))
+ up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3))
+ self.estimate_occlusion_map = estimate_occlusion_map
+ self.image_channel = image_channel
+
+ def deform_input(self, inp, deformation):
+ _, d_old, h_old, w_old, _ = deformation.shape
+ _, _, d, h, w = inp.shape
+ if d_old != d or h_old != h or w_old != w:
+ deformation = deformation.permute(0, 4, 1, 2, 3)
+ deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
+ deformation = deformation.permute(0, 2, 3, 4, 1)
+ return F.grid_sample(inp, deformation)
+
+ def forward(self, source_image, kp_driving, kp_source):
+ # Encoding (downsampling) part
+ out = self.first(source_image)
+ for i in range(len(self.down_blocks)):
+ out = self.down_blocks[i](out)
+ out = self.second(out)
+ bs, c, h, w = out.shape
+ # print(out.shape)
+ feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w)
+ feature_3d = self.resblocks_3d(feature_3d)
+
+ # Transforming feature representation according to deformation and occlusion
+ output_dict = {}
+ if self.dense_motion_network is not None:
+ dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
+ kp_source=kp_source)
+ output_dict['mask'] = dense_motion['mask']
+
+ if 'occlusion_map' in dense_motion:
+ occlusion_map = dense_motion['occlusion_map']
+ output_dict['occlusion_map'] = occlusion_map
+ else:
+ occlusion_map = None
+ deformation = dense_motion['deformation']
+ out = self.deform_input(feature_3d, deformation)
+
+ bs, c, d, h, w = out.shape
+ out = out.view(bs, c*d, h, w)
+ out = self.third(out)
+ out = self.fourth(out)
+
+ if occlusion_map is not None:
+ if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
+ occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
+ out = out * occlusion_map
+
+ # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image
+
+ # Decoding part
+ out = self.resblocks_2d(out)
+ for i in range(len(self.up_blocks)):
+ out = self.up_blocks[i](out)
+ out = self.final(out)
+ out = F.sigmoid(out)
+
+ output_dict["prediction"] = out
+
+ return output_dict
+
+
+class SPADEDecoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+ ic = 256
+ oc = 64
+ norm_G = 'spadespectralinstance'
+ label_nc = 256
+
+ self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)
+ self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+ self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)
+ self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)
+ self.conv_img = nn.Conv2d(oc, 3, 3, padding=1)
+ self.up = nn.Upsample(scale_factor=2)
+
+ def forward(self, feature):
+ seg = feature
+ x = self.fc(feature)
+ x = self.G_middle_0(x, seg)
+ x = self.G_middle_1(x, seg)
+ x = self.G_middle_2(x, seg)
+ x = self.G_middle_3(x, seg)
+ x = self.G_middle_4(x, seg)
+ x = self.G_middle_5(x, seg)
+ x = self.up(x)
+ x = self.up_0(x, seg) # 256, 128, 128
+ x = self.up(x)
+ x = self.up_1(x, seg) # 64, 256, 256
+
+ x = self.conv_img(F.leaky_relu(x, 2e-1))
+ # x = torch.tanh(x)
+ x = F.sigmoid(x)
+
+ return x
+
+
+class OcclusionAwareSPADEGenerator(nn.Module):
+
+ def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
+ num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
+ super(OcclusionAwareSPADEGenerator, self).__init__()
+
+ if dense_motion_params is not None:
+ self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
+ estimate_occlusion_map=estimate_occlusion_map,
+ **dense_motion_params)
+ else:
+ self.dense_motion_network = None
+
+ self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
+
+ down_blocks = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2 ** i))
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
+
+ self.reshape_channel = reshape_channel
+ self.reshape_depth = reshape_depth
+
+ self.resblocks_3d = torch.nn.Sequential()
+ for i in range(num_resblocks):
+ self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
+
+ out_features = block_expansion * (2 ** (num_down_blocks))
+ self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
+ self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
+
+ self.estimate_occlusion_map = estimate_occlusion_map
+ self.image_channel = image_channel
+
+ self.decoder = SPADEDecoder()
+
+ def deform_input(self, inp, deformation):
+ _, d_old, h_old, w_old, _ = deformation.shape
+ _, _, d, h, w = inp.shape
+ if d_old != d or h_old != h or w_old != w:
+ deformation = deformation.permute(0, 4, 1, 2, 3)
+ deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
+ deformation = deformation.permute(0, 2, 3, 4, 1)
+ return F.grid_sample(inp, deformation)
+
+ def forward(self, source_image, kp_driving, kp_source):
+ # Encoding (downsampling) part
+ out = self.first(source_image)
+ for i in range(len(self.down_blocks)):
+ out = self.down_blocks[i](out)
+ out = self.second(out)
+ bs, c, h, w = out.shape
+ # print(out.shape)
+ feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w)
+ feature_3d = self.resblocks_3d(feature_3d)
+
+ # Transforming feature representation according to deformation and occlusion
+ output_dict = {}
+ if self.dense_motion_network is not None:
+ dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
+ kp_source=kp_source)
+ output_dict['mask'] = dense_motion['mask']
+
+ # import pdb; pdb.set_trace()
+
+ if 'occlusion_map' in dense_motion:
+ occlusion_map = dense_motion['occlusion_map']
+ output_dict['occlusion_map'] = occlusion_map
+ else:
+ occlusion_map = None
+ deformation = dense_motion['deformation']
+ out = self.deform_input(feature_3d, deformation)
+
+ bs, c, d, h, w = out.shape
+ out = out.view(bs, c*d, h, w)
+ out = self.third(out)
+ out = self.fourth(out)
+
+ # occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map)
+
+ if occlusion_map is not None:
+ if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
+ occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
+ out = out * occlusion_map
+
+ # Decoding part
+ out = self.decoder(out)
+
+ output_dict["prediction"] = out
+
+ return output_dict
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/facerender/modules/keypoint_detector.py b/sadtalker_video2pose/src/facerender/modules/keypoint_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..e56800c7b1e94bb3cbf97200cd3f059ce9d29cf3
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/modules/keypoint_detector.py
@@ -0,0 +1,179 @@
+from torch import nn
+import torch
+import torch.nn.functional as F
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+from src.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck
+
+
+class KPDetector(nn.Module):
+ """
+ Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint.
+ """
+
+ def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth,
+ num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False):
+ super(KPDetector, self).__init__()
+
+ self.predictor = KPHourglass(block_expansion, in_features=image_channel,
+ max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks)
+
+ # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3)
+ self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1)
+
+ if estimate_jacobian:
+ self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
+ # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3)
+ self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1)
+ '''
+ initial as:
+ [[1 0 0]
+ [0 1 0]
+ [0 0 1]]
+ '''
+ self.jacobian.weight.data.zero_()
+ self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
+ else:
+ self.jacobian = None
+
+ self.temperature = temperature
+ self.scale_factor = scale_factor
+ if self.scale_factor != 1:
+ self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor)
+
+ def gaussian2kp(self, heatmap):
+ """
+ Extract the mean from a heatmap
+ """
+ shape = heatmap.shape
+ heatmap = heatmap.unsqueeze(-1)
+ grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
+ value = (heatmap * grid).sum(dim=(2, 3, 4))
+ kp = {'value': value}
+
+ return kp
+
+ def forward(self, x):
+ if self.scale_factor != 1:
+ x = self.down(x)
+
+ feature_map = self.predictor(x)
+ prediction = self.kp(feature_map)
+
+ final_shape = prediction.shape
+ heatmap = prediction.view(final_shape[0], final_shape[1], -1)
+ heatmap = F.softmax(heatmap / self.temperature, dim=2)
+ heatmap = heatmap.view(*final_shape)
+
+ out = self.gaussian2kp(heatmap)
+
+ if self.jacobian is not None:
+ jacobian_map = self.jacobian(feature_map)
+ jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2],
+ final_shape[3], final_shape[4])
+ heatmap = heatmap.unsqueeze(2)
+
+ jacobian = heatmap * jacobian_map
+ jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1)
+ jacobian = jacobian.sum(dim=-1)
+ jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3)
+ out['jacobian'] = jacobian
+
+ return out
+
+
+class HEEstimator(nn.Module):
+ """
+ Estimating head pose and expression.
+ """
+
+ def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True):
+ super(HEEstimator, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2)
+ self.norm1 = BatchNorm2d(block_expansion, affine=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1)
+ self.norm2 = BatchNorm2d(256, affine=True)
+
+ self.block1 = nn.Sequential()
+ for i in range(3):
+ self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1))
+
+ self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)
+ self.norm3 = BatchNorm2d(512, affine=True)
+ self.block2 = ResBottleneck(in_features=512, stride=2)
+
+ self.block3 = nn.Sequential()
+ for i in range(3):
+ self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1))
+
+ self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1)
+ self.norm4 = BatchNorm2d(1024, affine=True)
+ self.block4 = ResBottleneck(in_features=1024, stride=2)
+
+ self.block5 = nn.Sequential()
+ for i in range(5):
+ self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1))
+
+ self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1)
+ self.norm5 = BatchNorm2d(2048, affine=True)
+ self.block6 = ResBottleneck(in_features=2048, stride=2)
+
+ self.block7 = nn.Sequential()
+ for i in range(2):
+ self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1))
+
+ self.fc_roll = nn.Linear(2048, num_bins)
+ self.fc_pitch = nn.Linear(2048, num_bins)
+ self.fc_yaw = nn.Linear(2048, num_bins)
+
+ self.fc_t = nn.Linear(2048, 3)
+
+ self.fc_exp = nn.Linear(2048, 3*num_kp)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = F.relu(out)
+ out = self.maxpool(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+
+ out = self.block1(out)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+ out = F.relu(out)
+ out = self.block2(out)
+
+ out = self.block3(out)
+
+ out = self.conv4(out)
+ out = self.norm4(out)
+ out = F.relu(out)
+ out = self.block4(out)
+
+ out = self.block5(out)
+
+ out = self.conv5(out)
+ out = self.norm5(out)
+ out = F.relu(out)
+ out = self.block6(out)
+
+ out = self.block7(out)
+
+ out = F.adaptive_avg_pool2d(out, 1)
+ out = out.view(out.shape[0], -1)
+
+ yaw = self.fc_roll(out)
+ pitch = self.fc_pitch(out)
+ roll = self.fc_yaw(out)
+ t = self.fc_t(out)
+ exp = self.fc_exp(out)
+
+ return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
+
diff --git a/sadtalker_video2pose/src/facerender/modules/make_animation.py b/sadtalker_video2pose/src/facerender/modules/make_animation.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c8c53dcc04da8354d05c98c2bc0d88bf067fb2
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/modules/make_animation.py
@@ -0,0 +1,170 @@
+from scipy.spatial import ConvexHull
+import torch
+import torch.nn.functional as F
+import numpy as np
+from tqdm import tqdm
+
+def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
+ use_relative_movement=False, use_relative_jacobian=False):
+ if adapt_movement_scale:
+ source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
+ driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
+ adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+ else:
+ adapt_movement_scale = 1
+
+ kp_new = {k: v for k, v in kp_driving.items()}
+
+ if use_relative_movement:
+ kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
+ kp_value_diff *= adapt_movement_scale
+ kp_new['value'] = kp_value_diff + kp_source['value']
+
+ if use_relative_jacobian:
+ jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
+ kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
+
+ return kp_new
+
+def headpose_pred_to_degree(pred):
+ device = pred.device
+ idx_tensor = [idx for idx in range(66)]
+ idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device)
+ pred = F.softmax(pred)
+ degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
+ return degree
+
+def get_rotation_matrix(yaw, pitch, roll):
+ yaw = yaw / 180 * 3.14
+ pitch = pitch / 180 * 3.14
+ roll = roll / 180 * 3.14
+
+ roll = roll.unsqueeze(1)
+ pitch = pitch.unsqueeze(1)
+ yaw = yaw.unsqueeze(1)
+
+ pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
+ torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),
+ torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)
+ pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
+
+ yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),
+ torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
+ -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)
+ yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
+
+ roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),
+ torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),
+ torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)
+ roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
+
+ rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)
+
+ return rot_mat
+
+def keypoint_transformation(kp_canonical, he, wo_exp=False):
+ kp = kp_canonical['value'] # (bs, k, 3)
+ yaw, pitch, roll= he['yaw'], he['pitch'], he['roll']
+ yaw = headpose_pred_to_degree(yaw)
+ pitch = headpose_pred_to_degree(pitch)
+ roll = headpose_pred_to_degree(roll)
+
+ if 'yaw_in' in he:
+ yaw = he['yaw_in']
+ if 'pitch_in' in he:
+ pitch = he['pitch_in']
+ if 'roll_in' in he:
+ roll = he['roll_in']
+
+ rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3)
+
+ t, exp = he['t'], he['exp']
+ if wo_exp:
+ exp = exp*0
+
+ # keypoint rotation
+ kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
+
+ # keypoint translation
+ t[:, 0] = t[:, 0]*0
+ t[:, 2] = t[:, 2]*0
+ t = t.unsqueeze(1).repeat(1, kp.shape[1], 1)
+ kp_t = kp_rotated + t
+
+ # add expression deviation
+ exp = exp.view(exp.shape[0], -1, 3)
+ kp_transformed = kp_t + exp
+
+ return {'value': kp_transformed}
+
+
+
+def make_animation(source_image, source_semantics, target_semantics,
+ generator, kp_detector, he_estimator, mapping,
+ yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
+ use_exp=True, use_half=False):
+ with torch.no_grad():
+ predictions = []
+
+ kp_canonical = kp_detector(source_image)
+ he_source = mapping(source_semantics)
+ kp_source = keypoint_transformation(kp_canonical, he_source)
+
+ for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
+ # still check the dimension
+ # print(target_semantics.shape, source_semantics.shape)
+ target_semantics_frame = target_semantics[:, frame_idx]
+ he_driving = mapping(target_semantics_frame)
+ if yaw_c_seq is not None:
+ he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
+ if pitch_c_seq is not None:
+ he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
+ if roll_c_seq is not None:
+ he_driving['roll_in'] = roll_c_seq[:, frame_idx]
+
+ kp_driving = keypoint_transformation(kp_canonical, he_driving)
+
+ kp_norm = kp_driving
+ out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
+ '''
+ source_image_new = out['prediction'].squeeze(1)
+ kp_canonical_new = kp_detector(source_image_new)
+ he_source_new = he_estimator(source_image_new)
+ kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
+ kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
+ out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
+ '''
+ predictions.append(out['prediction'])
+ predictions_ts = torch.stack(predictions, dim=1)
+ return predictions_ts
+
+class AnimateModel(torch.nn.Module):
+ """
+ Merge all generator related updates into single model for better multi-gpu usage
+ """
+
+ def __init__(self, generator, kp_extractor, mapping):
+ super(AnimateModel, self).__init__()
+ self.kp_extractor = kp_extractor
+ self.generator = generator
+ self.mapping = mapping
+
+ self.kp_extractor.eval()
+ self.generator.eval()
+ self.mapping.eval()
+
+ def forward(self, x):
+
+ source_image = x['source_image']
+ source_semantics = x['source_semantics']
+ target_semantics = x['target_semantics']
+ yaw_c_seq = x['yaw_c_seq']
+ pitch_c_seq = x['pitch_c_seq']
+ roll_c_seq = x['roll_c_seq']
+
+ predictions_video = make_animation(source_image, source_semantics, target_semantics,
+ self.generator, self.kp_extractor,
+ self.mapping, use_exp = True,
+ yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq)
+
+ return predictions_video
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/facerender/modules/mapping.py b/sadtalker_video2pose/src/facerender/modules/mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac98dd9e177b949f71f8f47029b66d67ece05b4
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/modules/mapping.py
@@ -0,0 +1,47 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MappingNet(nn.Module):
+ def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins):
+ super( MappingNet, self).__init__()
+
+ self.layer = layer
+ nonlinearity = nn.LeakyReLU(0.1)
+
+ self.first = nn.Sequential(
+ torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
+
+ for i in range(layer):
+ net = nn.Sequential(nonlinearity,
+ torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
+ setattr(self, 'encoder' + str(i), net)
+
+ self.pooling = nn.AdaptiveAvgPool1d(1)
+ self.output_nc = descriptor_nc
+
+ self.fc_roll = nn.Linear(descriptor_nc, num_bins)
+ self.fc_pitch = nn.Linear(descriptor_nc, num_bins)
+ self.fc_yaw = nn.Linear(descriptor_nc, num_bins)
+ self.fc_t = nn.Linear(descriptor_nc, 3)
+ self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp)
+
+ def forward(self, input_3dmm):
+ out = self.first(input_3dmm)
+ for i in range(self.layer):
+ model = getattr(self, 'encoder' + str(i))
+ out = model(out) + out[:,:,3:-3]
+ out = self.pooling(out)
+ out = out.view(out.shape[0], -1)
+ #print('out:', out.shape)
+
+ yaw = self.fc_yaw(out)
+ pitch = self.fc_pitch(out)
+ roll = self.fc_roll(out)
+ t = self.fc_t(out)
+ exp = self.fc_exp(out)
+
+ return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/facerender/modules/util.py b/sadtalker_video2pose/src/facerender/modules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3bfb1f26427b491f032ca9952db41cdeb793d70
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/modules/util.py
@@ -0,0 +1,564 @@
+from torch import nn
+
+import torch.nn.functional as F
+import torch
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
+
+import torch.nn.utils.spectral_norm as spectral_norm
+
+
+def kp2gaussian(kp, spatial_size, kp_variance):
+ """
+ Transform a keypoint into gaussian like representation
+ """
+ mean = kp['value']
+
+ coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
+ number_of_leading_dimensions = len(mean.shape) - 1
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
+ coordinate_grid = coordinate_grid.view(*shape)
+ repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
+ coordinate_grid = coordinate_grid.repeat(*repeats)
+
+ # Preprocess kp shape
+ shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
+ mean = mean.view(*shape)
+
+ mean_sub = (coordinate_grid - mean)
+
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
+
+ return out
+
+def make_coordinate_grid_2d(spatial_size, type):
+ """
+ Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
+ """
+ h, w = spatial_size
+ x = torch.arange(w).type(type)
+ y = torch.arange(h).type(type)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+
+ return meshed
+
+
+def make_coordinate_grid(spatial_size, type):
+ d, h, w = spatial_size
+ x = torch.arange(w).type(type)
+ y = torch.arange(h).type(type)
+ z = torch.arange(d).type(type)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+ z = (2 * (z / (d - 1)) - 1)
+
+ yy = y.view(1, -1, 1).repeat(d, 1, w)
+ xx = x.view(1, 1, -1).repeat(d, h, 1)
+ zz = z.view(-1, 1, 1).repeat(1, h, w)
+
+ meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
+
+ return meshed
+
+
+class ResBottleneck(nn.Module):
+ def __init__(self, in_features, stride):
+ super(ResBottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1)
+ self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride)
+ self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1)
+ self.norm1 = BatchNorm2d(in_features//4, affine=True)
+ self.norm2 = BatchNorm2d(in_features//4, affine=True)
+ self.norm3 = BatchNorm2d(in_features, affine=True)
+
+ self.stride = stride
+ if self.stride != 1:
+ self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride)
+ self.norm4 = BatchNorm2d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv3(out)
+ out = self.norm3(out)
+ if self.stride != 1:
+ x = self.skip(x)
+ x = self.norm4(x)
+ out += x
+ out = F.relu(out)
+ return out
+
+
+class ResBlock2d(nn.Module):
+ """
+ Res block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock2d, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.norm1 = BatchNorm2d(in_features, affine=True)
+ self.norm2 = BatchNorm2d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.norm1(x)
+ out = F.relu(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out += x
+ return out
+
+
+class ResBlock3d(nn.Module):
+ """
+ Res block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock3d, self).__init__()
+ self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.norm1 = BatchNorm3d(in_features, affine=True)
+ self.norm2 = BatchNorm3d(in_features, affine=True)
+
+ def forward(self, x):
+ out = self.norm1(x)
+ out = F.relu(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = F.relu(out)
+ out = self.conv2(out)
+ out += x
+ return out
+
+
+class UpBlock2d(nn.Module):
+ """
+ Upsampling block for use in decoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(UpBlock2d, self).__init__()
+
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features, affine=True)
+
+ def forward(self, x):
+ out = F.interpolate(x, scale_factor=2)
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+class UpBlock3d(nn.Module):
+ """
+ Upsampling block for use in decoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(UpBlock3d, self).__init__()
+
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm3d(out_features, affine=True)
+
+ def forward(self, x):
+ # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear')
+ out = F.interpolate(x, scale_factor=(1, 2, 2))
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+
+class DownBlock2d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features, affine=True)
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = F.relu(out)
+ out = self.pool(out)
+ return out
+
+
+class DownBlock3d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock3d, self).__init__()
+ '''
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups, stride=(1, 2, 2))
+ '''
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm3d(out_features, affine=True)
+ self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = F.relu(out)
+ out = self.pool(out)
+ return out
+
+
+class SameBlock2d(nn.Module):
+ """
+ Simple block, preserve spatial resolution.
+ """
+
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
+ super(SameBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
+ kernel_size=kernel_size, padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features, affine=True)
+ if lrelu:
+ self.ac = nn.LeakyReLU()
+ else:
+ self.ac = nn.ReLU()
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = self.ac(out)
+ return out
+
+
+class Encoder(nn.Module):
+ """
+ Hourglass Encoder
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Encoder, self).__init__()
+
+ down_blocks = []
+ for i in range(num_blocks):
+ down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ kernel_size=3, padding=1))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ def forward(self, x):
+ outs = [x]
+ for down_block in self.down_blocks:
+ outs.append(down_block(outs[-1]))
+ return outs
+
+
+class Decoder(nn.Module):
+ """
+ Hourglass Decoder
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Decoder, self).__init__()
+
+ up_blocks = []
+
+ for i in range(num_blocks)[::-1]:
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
+ out_filters = min(max_features, block_expansion * (2 ** i))
+ up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
+
+ self.up_blocks = nn.ModuleList(up_blocks)
+ # self.out_filters = block_expansion
+ self.out_filters = block_expansion + in_features
+
+ self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
+ self.norm = BatchNorm3d(self.out_filters, affine=True)
+
+ def forward(self, x):
+ out = x.pop()
+ # for up_block in self.up_blocks[:-1]:
+ for up_block in self.up_blocks:
+ out = up_block(out)
+ skip = x.pop()
+ out = torch.cat([out, skip], dim=1)
+ # out = self.up_blocks[-1](out)
+ out = self.conv(out)
+ out = self.norm(out)
+ out = F.relu(out)
+ return out
+
+
+class Hourglass(nn.Module):
+ """
+ Hourglass architecture.
+ """
+
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Hourglass, self).__init__()
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
+ self.out_filters = self.decoder.out_filters
+
+ def forward(self, x):
+ return self.decoder(self.encoder(x))
+
+
+class KPHourglass(nn.Module):
+ """
+ Hourglass architecture.
+ """
+
+ def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256):
+ super(KPHourglass, self).__init__()
+
+ self.down_blocks = nn.Sequential()
+ for i in range(num_blocks):
+ self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ kernel_size=3, padding=1))
+
+ in_filters = min(max_features, block_expansion * (2 ** num_blocks))
+ self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1)
+
+ self.up_blocks = nn.Sequential()
+ for i in range(num_blocks):
+ in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i)))
+ out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1)))
+ self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
+
+ self.reshape_depth = reshape_depth
+ self.out_filters = out_filters
+
+ def forward(self, x):
+ out = self.down_blocks(x)
+ out = self.conv(out)
+ bs, c, h, w = out.shape
+ out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w)
+ out = self.up_blocks(out)
+
+ return out
+
+
+
+class AntiAliasInterpolation2d(nn.Module):
+ """
+ Band-limited downsampling, for better preservation of the input signal.
+ """
+ def __init__(self, channels, scale):
+ super(AntiAliasInterpolation2d, self).__init__()
+ sigma = (1 / scale - 1) / 2
+ kernel_size = 2 * round(sigma * 4) + 1
+ self.ka = kernel_size // 2
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
+
+ kernel_size = [kernel_size, kernel_size]
+ sigma = [sigma, sigma]
+ # The gaussian kernel is the product of the
+ # gaussian function of each dimension.
+ kernel = 1
+ meshgrids = torch.meshgrid(
+ [
+ torch.arange(size, dtype=torch.float32)
+ for size in kernel_size
+ ]
+ )
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+ mean = (size - 1) / 2
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ kernel = kernel / torch.sum(kernel)
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer('weight', kernel)
+ self.groups = channels
+ self.scale = scale
+ inv_scale = 1 / scale
+ self.int_inv_scale = int(inv_scale)
+
+ def forward(self, input):
+ if self.scale == 1.0:
+ return input
+
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
+ out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
+
+ return out
+
+
+class SPADE(nn.Module):
+ def __init__(self, norm_nc, label_nc):
+ super().__init__()
+
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+ nhidden = 128
+
+ self.mlp_shared = nn.Sequential(
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
+ nn.ReLU())
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
+
+ def forward(self, x, segmap):
+ normalized = self.param_free_norm(x)
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
+ actv = self.mlp_shared(segmap)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+ out = normalized * (1 + gamma) + beta
+ return out
+
+
+class SPADEResnetBlock(nn.Module):
+ def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
+ super().__init__()
+ # Attributes
+ self.learned_shortcut = (fin != fout)
+ fmiddle = min(fin, fout)
+ self.use_se = use_se
+ # create conv layers
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
+ if self.learned_shortcut:
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+ # apply spectral norm if specified
+ if 'spectral' in norm_G:
+ self.conv_0 = spectral_norm(self.conv_0)
+ self.conv_1 = spectral_norm(self.conv_1)
+ if self.learned_shortcut:
+ self.conv_s = spectral_norm(self.conv_s)
+ # define normalization layers
+ self.norm_0 = SPADE(fin, label_nc)
+ self.norm_1 = SPADE(fmiddle, label_nc)
+ if self.learned_shortcut:
+ self.norm_s = SPADE(fin, label_nc)
+
+ def forward(self, x, seg1):
+ x_s = self.shortcut(x, seg1)
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, seg1):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, seg1))
+ else:
+ x_s = x
+ return x_s
+
+ def actvn(self, x):
+ return F.leaky_relu(x, 2e-1)
+
+class audio2image(nn.Module):
+ def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params):
+ super().__init__()
+ # Attributes
+ self.generator = generator
+ self.kp_extractor = kp_extractor
+ self.he_estimator_video = he_estimator_video
+ self.he_estimator_audio = he_estimator_audio
+ self.train_params = train_params
+
+ def headpose_pred_to_degree(self, pred):
+ device = pred.device
+ idx_tensor = [idx for idx in range(66)]
+ idx_tensor = torch.FloatTensor(idx_tensor).to(device)
+ pred = F.softmax(pred)
+ degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
+
+ return degree
+
+ def get_rotation_matrix(self, yaw, pitch, roll):
+ yaw = yaw / 180 * 3.14
+ pitch = pitch / 180 * 3.14
+ roll = roll / 180 * 3.14
+
+ roll = roll.unsqueeze(1)
+ pitch = pitch.unsqueeze(1)
+ yaw = yaw.unsqueeze(1)
+
+ roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll),
+ torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll),
+ torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1)
+ roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
+
+ pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch),
+ torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch),
+ -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1)
+ pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
+
+ yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw),
+ torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw),
+ torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1)
+ yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
+
+ rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat)
+
+ return rot_mat
+
+ def keypoint_transformation(self, kp_canonical, he):
+ kp = kp_canonical['value'] # (bs, k, 3)
+ yaw, pitch, roll = he['yaw'], he['pitch'], he['roll']
+ t, exp = he['t'], he['exp']
+
+ yaw = self.headpose_pred_to_degree(yaw)
+ pitch = self.headpose_pred_to_degree(pitch)
+ roll = self.headpose_pred_to_degree(roll)
+
+ rot_mat = self.get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3)
+
+ # keypoint rotation
+ kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
+
+
+
+ # keypoint translation
+ t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1)
+ kp_t = kp_rotated + t
+
+ # add expression deviation
+ exp = exp.view(exp.shape[0], -1, 3)
+ kp_transformed = kp_t + exp
+
+ return {'value': kp_transformed}
+
+ def forward(self, source_image, target_audio):
+ pose_source = self.he_estimator_video(source_image)
+ pose_generated = self.he_estimator_audio(target_audio)
+ kp_canonical = self.kp_extractor(source_image)
+ kp_source = self.keypoint_transformation(kp_canonical, pose_source)
+ kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated)
+ generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated)
+ return generated
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/facerender/pirender/base_function.py b/sadtalker_video2pose/src/facerender/pirender/base_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..650fb7de1b95fc34e4b7c17b2526c1f450a577a0
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/pirender/base_function.py
@@ -0,0 +1,368 @@
+import sys
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
+
+
+class LayerNorm2d(nn.Module):
+ def __init__(self, n_out, affine=True):
+ super(LayerNorm2d, self).__init__()
+ self.n_out = n_out
+ self.affine = affine
+
+ if self.affine:
+ self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
+ self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))
+
+ def forward(self, x):
+ normalized_shape = x.size()[1:]
+ if self.affine:
+ return F.layer_norm(x, normalized_shape, \
+ self.weight.expand(normalized_shape),
+ self.bias.expand(normalized_shape))
+
+ else:
+ return F.layer_norm(x, normalized_shape)
+
+class ADAINHourglass(nn.Module):
+ def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
+ super(ADAINHourglass, self).__init__()
+ self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
+ self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
+ self.output_nc = self.decoder.output_nc
+
+ def forward(self, x, z):
+ return self.decoder(self.encoder(x, z), z)
+
+
+
+class ADAINEncoder(nn.Module):
+ def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINEncoder, self).__init__()
+ self.layers = layers
+ self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
+ for i in range(layers):
+ in_channels = min(ngf * (2**i), img_f)
+ out_channels = min(ngf *(2**(i+1)), img_f)
+ model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
+ setattr(self, 'encoder' + str(i), model)
+ self.output_nc = out_channels
+
+ def forward(self, x, z):
+ out = self.input_layer(x)
+ out_list = [out]
+ for i in range(self.layers):
+ model = getattr(self, 'encoder' + str(i))
+ out = model(out, z)
+ out_list.append(out)
+ return out_list
+
+class ADAINDecoder(nn.Module):
+ """docstring for ADAINDecoder"""
+ def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
+ nonlinearity=nn.LeakyReLU(), use_spect=False):
+
+ super(ADAINDecoder, self).__init__()
+ self.encoder_layers = encoder_layers
+ self.decoder_layers = decoder_layers
+ self.skip_connect = skip_connect
+ use_transpose = True
+
+ for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
+ in_channels = min(ngf * (2**(i+1)), img_f)
+ in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
+ out_channels = min(ngf * (2**i), img_f)
+ model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
+ setattr(self, 'decoder' + str(i), model)
+
+ self.output_nc = out_channels*2 if self.skip_connect else out_channels
+
+ def forward(self, x, z):
+ out = x.pop() if self.skip_connect else x
+ for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
+ model = getattr(self, 'decoder' + str(i))
+ out = model(out, z)
+ out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
+ return out
+
+class ADAINEncoderBlock(nn.Module):
+ def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINEncoderBlock, self).__init__()
+ kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
+ kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
+ self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)
+
+
+ self.norm_0 = ADAIN(input_nc, feature_nc)
+ self.norm_1 = ADAIN(output_nc, feature_nc)
+ self.actvn = nonlinearity
+
+ def forward(self, x, z):
+ x = self.conv_0(self.actvn(self.norm_0(x, z)))
+ x = self.conv_1(self.actvn(self.norm_1(x, z)))
+ return x
+
+class ADAINDecoderBlock(nn.Module):
+ def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINDecoderBlock, self).__init__()
+ # Attributes
+ self.actvn = nonlinearity
+ hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc
+
+ kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1}
+ if use_transpose:
+ kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1}
+ else:
+ kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1}
+
+ # create conv layers
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
+ if use_transpose:
+ self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
+ self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
+ else:
+ self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
+ nn.Upsample(scale_factor=2))
+ self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
+ nn.Upsample(scale_factor=2))
+ # define normalization layers
+ self.norm_0 = ADAIN(input_nc, feature_nc)
+ self.norm_1 = ADAIN(hidden_nc, feature_nc)
+ self.norm_s = ADAIN(input_nc, feature_nc)
+
+ def forward(self, x, z):
+ x_s = self.shortcut(x, z)
+ dx = self.conv_0(self.actvn(self.norm_0(x, z)))
+ dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, z):
+ x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
+ return x_s
+
+
+def spectral_norm(module, use_spect=True):
+ """use spectral normal layer to stable the training process"""
+ if use_spect:
+ return SpectralNorm(module)
+ else:
+ return module
+
+
+class ADAIN(nn.Module):
+ def __init__(self, norm_nc, feature_nc):
+ super().__init__()
+
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+
+ nhidden = 128
+ use_bias=True
+
+ self.mlp_shared = nn.Sequential(
+ nn.Linear(feature_nc, nhidden, bias=use_bias),
+ nn.ReLU()
+ )
+ self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
+ self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)
+
+ def forward(self, x, feature):
+
+ # Part 1. generate parameter-free normalized activations
+ normalized = self.param_free_norm(x)
+
+ # Part 2. produce scaling and bias conditioned on feature
+ feature = feature.view(feature.size(0), -1)
+ actv = self.mlp_shared(feature)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+
+ # apply scale and bias
+ gamma = gamma.view(*gamma.size()[:2], 1,1)
+ beta = beta.view(*beta.size()[:2], 1,1)
+ out = normalized * (1 + gamma) + beta
+ return out
+
+
+class FineEncoder(nn.Module):
+ """docstring for Encoder"""
+ def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineEncoder, self).__init__()
+ self.layers = layers
+ self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
+ for i in range(layers):
+ in_channels = min(ngf*(2**i), img_f)
+ out_channels = min(ngf*(2**(i+1)), img_f)
+ model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+ setattr(self, 'down' + str(i), model)
+ self.output_nc = out_channels
+
+ def forward(self, x):
+ x = self.first(x)
+ out=[x]
+ for i in range(self.layers):
+ model = getattr(self, 'down'+str(i))
+ x = model(x)
+ out.append(x)
+ return out
+
+class FineDecoder(nn.Module):
+ """docstring for FineDecoder"""
+ def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineDecoder, self).__init__()
+ self.layers = layers
+ for i in range(layers)[::-1]:
+ in_channels = min(ngf*(2**(i+1)), img_f)
+ out_channels = min(ngf*(2**i), img_f)
+ up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+ res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
+ jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
+
+ setattr(self, 'up' + str(i), up)
+ setattr(self, 'res' + str(i), res)
+ setattr(self, 'jump' + str(i), jump)
+
+ self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')
+
+ self.output_nc = out_channels
+
+ def forward(self, x, z):
+ out = x.pop()
+ for i in range(self.layers)[::-1]:
+ res_model = getattr(self, 'res' + str(i))
+ up_model = getattr(self, 'up' + str(i))
+ jump_model = getattr(self, 'jump' + str(i))
+ out = res_model(out, z)
+ out = up_model(out)
+ out = jump_model(x.pop()) + out
+ out_image = self.final(out)
+ return out_image
+
+class FirstBlock2d(nn.Module):
+ """
+ Downsampling block for use in encoder.
+ """
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FirstBlock2d, self).__init__()
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+
+ if type(norm_layer) == type(None):
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class DownBlock2d(nn.Module):
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(DownBlock2d, self).__init__()
+
+
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+ pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+ if type(norm_layer) == type(None):
+ self.model = nn.Sequential(conv, nonlinearity, pool)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class UpBlock2d(nn.Module):
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(UpBlock2d, self).__init__()
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+ if type(norm_layer) == type(None):
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+ def forward(self, x):
+ out = self.model(F.interpolate(x, scale_factor=2))
+ return out
+
+class FineADAINResBlocks(nn.Module):
+ def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineADAINResBlocks, self).__init__()
+ self.num_block = num_block
+ for i in range(num_block):
+ model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
+ setattr(self, 'res'+str(i), model)
+
+ def forward(self, x, z):
+ for i in range(self.num_block):
+ model = getattr(self, 'res'+str(i))
+ x = model(x, z)
+ return x
+
+class Jump(nn.Module):
+ def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(Jump, self).__init__()
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+
+ if type(norm_layer) == type(None):
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class FineADAINResBlock2d(nn.Module):
+ """
+ Define an Residual block for different types
+ """
+ def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineADAINResBlock2d, self).__init__()
+
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+ self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+ self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+ self.norm1 = ADAIN(input_nc, feature_nc)
+ self.norm2 = ADAIN(input_nc, feature_nc)
+
+ self.actvn = nonlinearity
+
+
+ def forward(self, x, z):
+ dx = self.actvn(self.norm1(self.conv1(x), z))
+ dx = self.norm2(self.conv2(x), z)
+ out = dx + x
+ return out
+
+class FinalBlock2d(nn.Module):
+ """
+ Define the output layer
+ """
+ def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
+ super(FinalBlock2d, self).__init__()
+
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+
+ if tanh_or_sigmoid == 'sigmoid':
+ out_nonlinearity = nn.Sigmoid()
+ else:
+ out_nonlinearity = nn.Tanh()
+
+ self.model = nn.Sequential(conv, out_nonlinearity)
+ def forward(self, x):
+ out = self.model(x)
+ return out
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/facerender/pirender/config.py b/sadtalker_video2pose/src/facerender/pirender/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..29dc2d1b9008dbf2dc3c0a307212471621bae8da
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/pirender/config.py
@@ -0,0 +1,211 @@
+import collections
+import functools
+import os
+import re
+
+import yaml
+
+class AttrDict(dict):
+ """Dict as attribute trick."""
+
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+ for key, value in self.__dict__.items():
+ if isinstance(value, dict):
+ self.__dict__[key] = AttrDict(value)
+ elif isinstance(value, (list, tuple)):
+ if isinstance(value[0], dict):
+ self.__dict__[key] = [AttrDict(item) for item in value]
+ else:
+ self.__dict__[key] = value
+
+ def yaml(self):
+ """Convert object to yaml dict and return."""
+ yaml_dict = {}
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttrDict):
+ yaml_dict[key] = value.yaml()
+ elif isinstance(value, list):
+ if isinstance(value[0], AttrDict):
+ new_l = []
+ for item in value:
+ new_l.append(item.yaml())
+ yaml_dict[key] = new_l
+ else:
+ yaml_dict[key] = value
+ else:
+ yaml_dict[key] = value
+ return yaml_dict
+
+ def __repr__(self):
+ """Print all variables."""
+ ret_str = []
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttrDict):
+ ret_str.append('{}:'.format(key))
+ child_ret_str = value.__repr__().split('\n')
+ for item in child_ret_str:
+ ret_str.append(' ' + item)
+ elif isinstance(value, list):
+ if isinstance(value[0], AttrDict):
+ ret_str.append('{}:'.format(key))
+ for item in value:
+ # Treat as AttrDict above.
+ child_ret_str = item.__repr__().split('\n')
+ for item in child_ret_str:
+ ret_str.append(' ' + item)
+ else:
+ ret_str.append('{}: {}'.format(key, value))
+ else:
+ ret_str.append('{}: {}'.format(key, value))
+ return '\n'.join(ret_str)
+
+
+class Config(AttrDict):
+ r"""Configuration class. This should include every human specifiable
+ hyperparameter values for your training."""
+
+ def __init__(self, filename=None, args=None, verbose=False, is_train=True):
+ super(Config, self).__init__()
+ # Set default parameters.
+ # Logging.
+
+ large_number = 1000000000
+ self.snapshot_save_iter = large_number
+ self.snapshot_save_epoch = large_number
+ self.snapshot_save_start_iter = 0
+ self.snapshot_save_start_epoch = 0
+ self.image_save_iter = large_number
+ self.eval_epoch = large_number
+ self.start_eval_epoch = large_number
+ self.eval_epoch = large_number
+ self.max_epoch = large_number
+ self.max_iter = large_number
+ self.logging_iter = 100
+ self.image_to_tensorboard=False
+ self.which_iter = 0 # args.which_iter
+ self.resume = False
+
+ self.checkpoints_dir = '/Users/shadowcun/Downloads/'
+ self.name = 'face'
+ self.phase = 'train' if is_train else 'test'
+
+ # Networks.
+ self.gen = AttrDict(type='generators.dummy')
+ self.dis = AttrDict(type='discriminators.dummy')
+
+ # Optimizers.
+ self.gen_optimizer = AttrDict(type='adam',
+ lr=0.0001,
+ adam_beta1=0.0,
+ adam_beta2=0.999,
+ eps=1e-8,
+ lr_policy=AttrDict(iteration_mode=False,
+ type='step',
+ step_size=large_number,
+ gamma=1))
+ self.dis_optimizer = AttrDict(type='adam',
+ lr=0.0001,
+ adam_beta1=0.0,
+ adam_beta2=0.999,
+ eps=1e-8,
+ lr_policy=AttrDict(iteration_mode=False,
+ type='step',
+ step_size=large_number,
+ gamma=1))
+ # Data.
+ self.data = AttrDict(name='dummy',
+ type='datasets.images',
+ num_workers=0)
+ self.test_data = AttrDict(name='dummy',
+ type='datasets.images',
+ num_workers=0,
+ test=AttrDict(is_lmdb=False,
+ roots='',
+ batch_size=1))
+ self.trainer = AttrDict(
+ model_average=False,
+ model_average_beta=0.9999,
+ model_average_start_iteration=1000,
+ model_average_batch_norm_estimation_iteration=30,
+ model_average_remove_sn=True,
+ image_to_tensorboard=False,
+ hparam_to_tensorboard=False,
+ distributed_data_parallel='pytorch',
+ delay_allreduce=True,
+ gan_relativistic=False,
+ gen_step=1,
+ dis_step=1)
+
+ # # Cudnn.
+ self.cudnn = AttrDict(deterministic=False,
+ benchmark=True)
+
+ # Others.
+ self.pretrained_weight = ''
+ self.inference_args = AttrDict()
+
+
+ # Update with given configurations.
+ assert os.path.exists(filename), 'File {} not exist.'.format(filename)
+ loader = yaml.SafeLoader
+ loader.add_implicit_resolver(
+ u'tag:yaml.org,2002:float',
+ re.compile(u'''^(?:
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+ |[-+]?\\.(?:inf|Inf|INF)
+ |\\.(?:nan|NaN|NAN))$''', re.X),
+ list(u'-+0123456789.'))
+ try:
+ with open(filename, 'r') as f:
+ cfg_dict = yaml.load(f, Loader=loader)
+ except EnvironmentError:
+ print('Please check the file with name of "%s"', filename)
+ recursive_update(self, cfg_dict)
+
+ # Put common opts in both gen and dis.
+ if 'common' in cfg_dict:
+ self.common = AttrDict(**cfg_dict['common'])
+ self.gen.common = self.common
+ self.dis.common = self.common
+
+
+ if verbose:
+ print(' config '.center(80, '-'))
+ print(self.__repr__())
+ print(''.center(80, '-'))
+
+
+def rsetattr(obj, attr, val):
+ """Recursively find object and set value"""
+ pre, _, post = attr.rpartition('.')
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+def rgetattr(obj, attr, *args):
+ """Recursively find object and return value"""
+
+ def _getattr(obj, attr):
+ r"""Get attribute."""
+ return getattr(obj, attr, *args)
+
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
+
+
+def recursive_update(d, u):
+ """Recursively update AttrDict d with AttrDict u"""
+ for key, value in u.items():
+ if isinstance(value, collections.abc.Mapping):
+ d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
+ elif isinstance(value, (list, tuple)):
+ if isinstance(value[0], dict):
+ d.__dict__[key] = [AttrDict(item) for item in value]
+ else:
+ d.__dict__[key] = value
+ else:
+ d.__dict__[key] = value
+ return d
diff --git a/sadtalker_video2pose/src/facerender/pirender/face_model.py b/sadtalker_video2pose/src/facerender/pirender/face_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f83e2fc5d8c66cf9bd2e2c5549773e11e0f8a44
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/pirender/face_model.py
@@ -0,0 +1,178 @@
+import functools
+import torch
+import torch.nn as nn
+from .base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
+
+def convert_flow_to_deformation(flow):
+ r"""convert flow fields to deformations.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ deformation (tensor): The deformation used for warpping
+ """
+ b,c,h,w = flow.shape
+ flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1)
+ grid = make_coordinate_grid(flow)
+ deformation = grid + flow_norm.permute(0,2,3,1)
+ return deformation
+
+def make_coordinate_grid(flow):
+ r"""obtain coordinate grid with the same size as the flow filed.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ grid (tensor): The grid with the same size as the input flow
+ """
+ b,c,h,w = flow.shape
+
+ x = torch.arange(w).to(flow)
+ y = torch.arange(h).to(flow)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+ meshed = meshed.expand(b, -1, -1, -1)
+ return meshed
+
+
+def warp_image(source_image, deformation):
+ r"""warp the input image according to the deformation
+
+ Args:
+ source_image (tensor): source images to be warpped
+ deformation (tensor): deformations used to warp the images; value in range (-1, 1)
+ Returns:
+ output (tensor): the warpped images
+ """
+ _, h_old, w_old, _ = deformation.shape
+ _, _, h, w = source_image.shape
+ if h_old != h or w_old != w:
+ deformation = deformation.permute(0, 3, 1, 2)
+ deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear')
+ deformation = deformation.permute(0, 2, 3, 1)
+ return torch.nn.functional.grid_sample(source_image, deformation)
+
+
+class FaceGenerator(nn.Module):
+ def __init__(
+ self,
+ mapping_net,
+ warpping_net,
+ editing_net,
+ common
+ ):
+ super(FaceGenerator, self).__init__()
+ self.mapping_net = MappingNet(**mapping_net)
+ self.warpping_net = WarpingNet(**warpping_net, **common)
+ self.editing_net = EditingNet(**editing_net, **common)
+
+ def forward(
+ self,
+ input_image,
+ driving_source,
+ stage=None
+ ):
+ if stage == 'warp':
+ descriptor = self.mapping_net(driving_source)
+ output = self.warpping_net(input_image, descriptor)
+ else:
+ descriptor = self.mapping_net(driving_source)
+ output = self.warpping_net(input_image, descriptor)
+ output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor)
+ return output
+
+class MappingNet(nn.Module):
+ def __init__(self, coeff_nc, descriptor_nc, layer):
+ super( MappingNet, self).__init__()
+
+ self.layer = layer
+ nonlinearity = nn.LeakyReLU(0.1)
+
+ self.first = nn.Sequential(
+ torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
+
+ for i in range(layer):
+ net = nn.Sequential(nonlinearity,
+ torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
+ setattr(self, 'encoder' + str(i), net)
+
+ self.pooling = nn.AdaptiveAvgPool1d(1)
+ self.output_nc = descriptor_nc
+
+ def forward(self, input_3dmm):
+ out = self.first(input_3dmm)
+ for i in range(self.layer):
+ model = getattr(self, 'encoder' + str(i))
+ out = model(out) + out[:,:,3:-3]
+ out = self.pooling(out)
+ return out
+
+class WarpingNet(nn.Module):
+ def __init__(
+ self,
+ image_nc,
+ descriptor_nc,
+ base_nc,
+ max_nc,
+ encoder_layer,
+ decoder_layer,
+ use_spect
+ ):
+ super( WarpingNet, self).__init__()
+
+ nonlinearity = nn.LeakyReLU(0.1)
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
+ kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect}
+
+ self.descriptor_nc = descriptor_nc
+ self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
+ max_nc, encoder_layer, decoder_layer, **kwargs)
+
+ self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc),
+ nonlinearity,
+ nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))
+
+ self.pool = nn.AdaptiveAvgPool2d(1)
+
+ def forward(self, input_image, descriptor):
+ final_output={}
+ output = self.hourglass(input_image, descriptor)
+ final_output['flow_field'] = self.flow_out(output)
+
+ deformation = convert_flow_to_deformation(final_output['flow_field'])
+ final_output['warp_image'] = warp_image(input_image, deformation)
+ return final_output
+
+
+class EditingNet(nn.Module):
+ def __init__(
+ self,
+ image_nc,
+ descriptor_nc,
+ layer,
+ base_nc,
+ max_nc,
+ num_res_blocks,
+ use_spect):
+ super(EditingNet, self).__init__()
+
+ nonlinearity = nn.LeakyReLU(0.1)
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
+ kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
+ self.descriptor_nc = descriptor_nc
+
+ # encoder part
+ self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs)
+ self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
+
+ def forward(self, input_image, warp_image, descriptor):
+ x = torch.cat([input_image, warp_image], 1)
+ x = self.encoder(x)
+ gen_image = self.decoder(x, descriptor)
+ return gen_image
diff --git a/sadtalker_video2pose/src/facerender/pirender_animate.py b/sadtalker_video2pose/src/facerender/pirender_animate.py
new file mode 100644
index 0000000000000000000000000000000000000000..07d4ccf0918f09dcfa422a85694bd17bf42d11ff
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/pirender_animate.py
@@ -0,0 +1,266 @@
+import os
+import uuid
+import cv2
+from tqdm import tqdm
+import yaml
+import numpy as np
+import warnings
+from skimage import img_as_ubyte
+import safetensors
+import safetensors.torch
+warnings.filterwarnings('ignore')
+
+
+import imageio
+import torch
+import torchvision
+
+from src.facerender.pirender.config import Config
+from src.facerender.pirender.face_model import FaceGenerator
+
+from pydub import AudioSegment
+from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
+from src.utils.paste_pic import paste_pic
+from src.utils.videoio import save_video_with_watermark
+from src.utils.flow_util import vis_flow
+from scipy.io import savemat,loadmat
+
+try:
+ import webui # in webui
+ in_webui = True
+except:
+ in_webui = False
+
+expession = loadmat('expression.mat')
+control_dict = {}
+for item in ['expression_center', 'expression_mouth', 'expression_eyebrow', 'expression_eyes']:
+ control_dict[item] = torch.tensor(expession[item])[0]
+
+class AnimateFromCoeff_PIRender():
+
+ def __init__(self, sadtalker_path, device):
+
+ opt = Config(sadtalker_path['pirender_yaml_path'], None, is_train=False)
+ opt.device = device
+ self.net_G_ema = FaceGenerator(**opt.gen.param).to(opt.device)
+ checkpoint_path = sadtalker_path['pirender_checkpoint']
+ checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+ self.net_G_ema.load_state_dict(checkpoint['net_G_ema'], strict=False)
+ print('load [net_G] and [net_G_ema] from {}'.format(checkpoint_path))
+ self.net_G = self.net_G_ema.eval()
+ self.device = device
+
+
+ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+
+ num = 16
+
+ # import pdb; pdb.set_trace()
+ # target_semantics_
+ current = target_semantics[0, 0, :64, 0]
+ for control_k in range(len(control_dict.keys())):
+ listx = list(control_dict.keys())
+ control_v = control_dict[listx[control_k]]
+ for i in range(num):
+ expression = (control_v-current)*i/(num-1)+current
+ target_semantics[:, (control_k*num + i):(control_k*num + i+1), :64, :] = expression[None, None, :, None]
+
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ frame_num = x['frame_num']
+
+ with torch.no_grad():
+ predictions_video = []
+ for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'):
+ predictions_video.append(self.net_G(source_image, target_semantics[:, i])['fake_image'])
+
+ predictions_video = torch.stack(predictions_video, dim=1)
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+
+ video = []
+ for idx in range(len(predictions_video)):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+ result = img_as_ubyte(video)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ original_size = crop_info[0]
+ if original_size:
+ result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+
+ video_name = x['video_name'] + '.mp4'
+ path = os.path.join(video_save_dir, 'temp_'+video_name)
+
+ imageio.mimsave(path, result, fps=float(25))
+
+ av_path = os.path.join(video_save_dir, video_name)
+ return_path = av_path
+
+ audio_path = x['audio_path']
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+ new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
+ start_time = 0
+ # cog will not keep the .mp3 filename
+ sound = AudioSegment.from_file(audio_path)
+ frames = frame_num
+ end_time = start_time + frames*1/25*1000
+ word1=sound.set_frame_rate(16000)
+ word = word1[start_time:end_time]
+ word.export(new_audio_path, format="wav")
+
+ save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name}')
+
+ if 'full' in preprocess.lower():
+ # only add watermark to the full image.
+ video_name_full = x['video_name'] + '_full.mp4'
+ full_video_path = os.path.join(video_save_dir, video_name_full)
+ return_path = full_video_path
+ paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
+ print(f'The generated video is named {video_save_dir}/{video_name_full}')
+ else:
+ full_video_path = av_path
+
+ #### paste back then enhancers
+ if enhancer:
+ video_name_enhancer = x['video_name'] + '_enhanced.mp4'
+ enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
+ av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
+ return_path = av_path_enhancer
+
+ try:
+ enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+ except:
+ enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+
+ save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
+ os.remove(enhanced_path)
+
+ os.remove(path)
+ os.remove(new_audio_path)
+
+ return return_path
+
+ def generate_flow(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+
+
+ num = 16
+
+ current = target_semantics[0, 0, :64, 0]
+ for control_k in range(len(control_dict.keys())):
+ listx = list(control_dict.keys())
+ control_v = control_dict[listx[control_k]]
+ for i in range(num):
+ expression = (control_v-current)*i/(num-1)+current
+ target_semantics[:, (control_k*num + i):(control_k*num + i+1), :64, :] = expression[None, None, :, None]
+
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ frame_num = x['frame_num']
+
+ with torch.no_grad():
+ predictions_video = []
+ for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'):
+ predictions_video.append(self.net_G(source_image, target_semantics[:, i])['flow_field'])
+
+ predictions_video = torch.stack(predictions_video, dim=1)
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+
+ video = []
+ for idx in range(len(predictions_video)):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+
+ results = np.stack(video, axis=0)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ # original_size = crop_info[0]
+ # if original_size:
+ # result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+ # results = np.stack(result, axis=0)
+
+ x_name = os.path.basename(pic_path)
+ save_name = os.path.join(video_save_dir, x_name + '.flo')
+ save_name_flow_vis = os.path.join(video_save_dir, x_name + '.mp4')
+
+ flow_full = paste_flow(results, pic_path, save_name, crop_info, extended_crop= True if 'ext' in preprocess.lower() else False)
+
+ flow_viz = []
+ for kk in range(flow_full.shape[0]):
+ tmp = vis_flow(flow_full[kk])
+ flow_viz.append(tmp)
+ flow_viz = np.stack(flow_viz)
+
+ torchvision.io.write_video(save_name_flow_vis, flow_viz, fps=20, video_codec='h264', options={'crf': '10'})
+
+ return save_name_flow_vis
+
+
+def paste_flow(flows, pic_path, save_name, crop_info, extended_crop=False):
+
+ if not os.path.isfile(pic_path):
+ raise ValueError('pic_path must be a valid path to video/image file')
+ elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_img = cv2.imread(pic_path)
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(pic_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ break
+ full_img = frame
+ frame_h = full_img.shape[0]
+ frame_w = full_img.shape[1]
+
+ # full images, we only use it as reference for zero init image.
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ # out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h))
+ # template = np.zeros((frame_h, frame_w, 2)) # full flows
+ out_tmp = []
+ for crop_frame in tqdm(flows, 'seamlessClone:'):
+ p = cv2.resize(crop_frame, (ox2-ox1, oy2 - oy1), interpolation=cv2.INTER_LANCZOS4)
+
+ gen_img = np.zeros((frame_h, frame_w, 2))
+ # gen_img = cv2.seamlessClone(p, template, mask, location, cv2.NORMAL_CLONE)
+ gen_img[oy1:oy2,ox1:ox2] = p
+ out_tmp.append(gen_img)
+
+ np.save(save_name, np.stack(out_tmp))
+ return np.stack(out_tmp)
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/facerender/pirender_animate_control.py b/sadtalker_video2pose/src/facerender/pirender_animate_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c357f35577816c8d6731627afd505c6dd8efdca
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/pirender_animate_control.py
@@ -0,0 +1,251 @@
+import os
+import uuid
+import cv2
+from tqdm import tqdm
+import yaml
+import numpy as np
+import warnings
+from skimage import img_as_ubyte
+import safetensors
+import safetensors.torch
+warnings.filterwarnings('ignore')
+
+
+import imageio
+import torch
+import torchvision
+
+from src.facerender.pirender.config import Config
+from src.facerender.pirender.face_model import FaceGenerator
+
+from pydub import AudioSegment
+from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
+from src.utils.paste_pic import paste_pic
+from src.utils.videoio import save_video_with_watermark
+from src.utils.flow_util import vis_flow
+
+from scipy.io import savemat,loadmat
+
+try:
+ import webui # in webui
+ in_webui = True
+except:
+ in_webui = False
+
+expession = loadmat('expression.mat')
+control_dict = {}
+for item in ['expression_center', 'expression_mouth', 'expression_eyebrow', 'expression_eyes']:
+ control_dict[item] = torch.tensor(expession[item])[0]
+
+class AnimateFromCoeff_PIRender():
+
+ def __init__(self, sadtalker_path, device):
+
+ opt = Config(sadtalker_path['pirender_yaml_path'], None, is_train=False)
+ opt.device = device
+ self.net_G_ema = FaceGenerator(**opt.gen.param).to(opt.device)
+ checkpoint_path = sadtalker_path['pirender_checkpoint']
+ checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+ self.net_G_ema.load_state_dict(checkpoint['net_G_ema'], strict=False)
+ print('load [net_G] and [net_G_ema] from {}'.format(checkpoint_path))
+ self.net_G = self.net_G_ema.eval()
+ self.device = device
+
+
+ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+ num = 10
+
+ # target_semantics_
+ current = target_semantics['target_semantics_list'][0, :64, 0]
+ for control in control_dict:
+ for i in range(num):
+ expression = (control_dict[control]-current)*i/(num-1)+current
+ target_semantics['target_semantics_list'][:, :64, :] = expression[None, :, None]
+
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ frame_num = x['frame_num']
+
+ with torch.no_grad():
+ predictions_video = []
+ for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'):
+ predictions_video.append(self.net_G(source_image, target_semantics[:, i])['fake_image'])
+
+ predictions_video = torch.stack(predictions_video, dim=1)
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+
+ video = []
+ for idx in range(len(predictions_video)):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+ result = img_as_ubyte(video)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ original_size = crop_info[0]
+ if original_size:
+ result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+
+ video_name = x['video_name'] + '.mp4'
+ path = os.path.join(video_save_dir, 'temp_'+video_name)
+
+ imageio.mimsave(path, result, fps=float(25))
+
+ av_path = os.path.join(video_save_dir, video_name)
+ return_path = av_path
+
+ audio_path = x['audio_path']
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+ new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
+ start_time = 0
+ # cog will not keep the .mp3 filename
+ sound = AudioSegment.from_file(audio_path)
+ frames = frame_num
+ end_time = start_time + frames*1/25*1000
+ word1=sound.set_frame_rate(16000)
+ word = word1[start_time:end_time]
+ word.export(new_audio_path, format="wav")
+
+ save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name}')
+
+ if 'full' in preprocess.lower():
+ # only add watermark to the full image.
+ video_name_full = x['video_name'] + '_full.mp4'
+ full_video_path = os.path.join(video_save_dir, video_name_full)
+ return_path = full_video_path
+ paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
+ print(f'The generated video is named {video_save_dir}/{video_name_full}')
+ else:
+ full_video_path = av_path
+
+ #### paste back then enhancers
+ if enhancer:
+ video_name_enhancer = x['video_name'] + '_enhanced.mp4'
+ enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
+ av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
+ return_path = av_path_enhancer
+
+ try:
+ enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+ except:
+ enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+
+ save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
+ os.remove(enhanced_path)
+
+ os.remove(path)
+ os.remove(new_audio_path)
+
+ return return_path
+
+ def generate_flow(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+
+ source_image=x['source_image'].type(torch.FloatTensor)
+ source_semantics=x['source_semantics'].type(torch.FloatTensor)
+ target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
+ source_image=source_image.to(self.device)
+ source_semantics=source_semantics.to(self.device)
+ target_semantics=target_semantics.to(self.device)
+ frame_num = x['frame_num']
+
+ with torch.no_grad():
+ predictions_video = []
+ for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'):
+ predictions_video.append(self.net_G(source_image, target_semantics[:, i])['flow_field'])
+
+ predictions_video = torch.stack(predictions_video, dim=1)
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+
+ video = []
+ for idx in range(len(predictions_video)):
+ image = predictions_video[idx]
+ image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+ video.append(image)
+
+ results = np.stack(video, axis=0)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ # original_size = crop_info[0]
+ # if original_size:
+ # result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
+ # results = np.stack(result, axis=0)
+
+ x_name = os.path.basename(pic_path)
+ save_name = os.path.join(video_save_dir, x_name + '.flo')
+ save_name_flow_vis = os.path.join(video_save_dir, x_name + '.mp4')
+
+ flow_full = paste_flow(results, pic_path, save_name, crop_info, extended_crop= True if 'ext' in preprocess.lower() else False)
+
+ flow_viz = []
+ for kk in range(flow_full.shape[0]):
+ tmp = vis_flow(flow_full[kk])
+ flow_viz.append(tmp)
+ flow_viz = np.stack(flow_viz)
+
+ torchvision.io.write_video(save_name_flow_vis, flow_viz, fps=20, video_codec='h264', options={'crf': '10'})
+
+ return save_name_flow_vis
+
+
+def paste_flow(flows, pic_path, save_name, crop_info, extended_crop=False):
+
+ if not os.path.isfile(pic_path):
+ raise ValueError('pic_path must be a valid path to video/image file')
+ elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_img = cv2.imread(pic_path)
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(pic_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ break
+ full_img = frame
+ frame_h = full_img.shape[0]
+ frame_w = full_img.shape[1]
+
+ # full images, we only use it as reference for zero init image.
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ # out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h))
+ # template = np.zeros((frame_h, frame_w, 2)) # full flows
+ out_tmp = []
+ for crop_frame in tqdm(flows, 'seamlessClone:'):
+ p = cv2.resize(crop_frame, (ox2-ox1, oy2 - oy1), interpolation=cv2.INTER_LANCZOS4)
+
+ gen_img = np.zeros((frame_h, frame_w, 2))
+ # gen_img = cv2.seamlessClone(p, template, mask, location, cv2.NORMAL_CLONE)
+ gen_img[oy1:oy2,ox1:ox2] = p
+ out_tmp.append(gen_img)
+
+ np.save(save_name, np.stack(out_tmp))
+ return np.stack(out_tmp)
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/facerender/sync_batchnorm/__init__.py b/sadtalker_video2pose/src/facerender/sync_batchnorm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..48871cdcdc882c903501ecc6d70fcb1b50bd7e9f
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/sync_batchnorm/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# File : __init__.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .replicate import DataParallelWithCallback, patch_replication_callback
diff --git a/sadtalker_video2pose/src/facerender/sync_batchnorm/batchnorm.py b/sadtalker_video2pose/src/facerender/sync_batchnorm/batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4cc2ccd2f0c904cbe433fb6136f443f0fa86fa6
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/sync_batchnorm/batchnorm.py
@@ -0,0 +1,315 @@
+# -*- coding: utf-8 -*-
+# File : batchnorm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import collections
+
+import torch
+import torch.nn.functional as F
+
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
+
+from .comm import SyncMaster
+
+__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
+
+
+def _sum_ft(tensor):
+ """sum over the first and last dimention"""
+ return tensor.sum(dim=0).sum(dim=-1)
+
+
+def _unsqueeze_ft(tensor):
+ """add new dementions at the front and the tail"""
+ return tensor.unsqueeze(0).unsqueeze(-1)
+
+
+_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
+
+
+class _SynchronizedBatchNorm(_BatchNorm):
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+
+ self._sync_master = SyncMaster(self._data_parallel_master)
+
+ self._is_parallel = False
+ self._parallel_id = None
+ self._slave_pipe = None
+
+ def forward(self, input):
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+ if not (self._is_parallel and self.training):
+ return F.batch_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training, self.momentum, self.eps)
+
+ # Resize the input to (B, C, -1).
+ input_shape = input.size()
+ input = input.view(input.size(0), self.num_features, -1)
+
+ # Compute the sum and square-sum.
+ sum_size = input.size(0) * input.size(2)
+ input_sum = _sum_ft(input)
+ input_ssum = _sum_ft(input ** 2)
+
+ # Reduce-and-broadcast the statistics.
+ if self._parallel_id == 0:
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
+ else:
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
+
+ # Compute the output.
+ if self.affine:
+ # MJY:: Fuse the multiplication for speed.
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
+ else:
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
+
+ # Reshape it.
+ return output.view(input_shape)
+
+ def __data_parallel_replicate__(self, ctx, copy_id):
+ self._is_parallel = True
+ self._parallel_id = copy_id
+
+ # parallel_id == 0 means master device.
+ if self._parallel_id == 0:
+ ctx.sync_master = self._sync_master
+ else:
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+ def _data_parallel_master(self, intermediates):
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+
+ # Always using same "device order" makes the ReduceAdd operation faster.
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+ to_reduce = [i[1][:2] for i in intermediates]
+ to_reduce = [j for i in to_reduce for j in i] # flatten
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+ sum_size = sum([i[1].sum_size for i in intermediates])
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+ outputs = []
+ for i, rec in enumerate(intermediates):
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+ return outputs
+
+ def _compute_mean_std(self, sum_, ssum, size):
+ """Compute the mean and standard-deviation with sum and square-sum. This method
+ also maintains the moving average on the master device."""
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+ mean = sum_ / size
+ sumvar = ssum - sum_ * mean
+ unbias_var = sumvar / (size - 1)
+ bias_var = sumvar / size
+
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+ return mean, bias_var.clamp(self.eps) ** -0.5
+
+
+class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
+ mini-batch.
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalize the tensor on each device using
+ the statistics only on that device, which accelerated the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of size
+ `batch_size x num_features [x width]`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 2 and input.dim() != 3:
+ raise ValueError('expected 2D or 3D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
+ of 3d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalize the tensor on each device using
+ the statistics only on that device, which accelerated the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, H, W)`
+ - Output: :math:`(N, C, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 4:
+ raise ValueError('expected 4D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
+ of 4d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+ training, PyTorch's implementation normalize the tensor on each device using
+ the statistics only on that device, which accelerated the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
+ or Spatio-temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x depth x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, D, H, W)`
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 5:
+ raise ValueError('expected 5D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
diff --git a/sadtalker_video2pose/src/facerender/sync_batchnorm/comm.py b/sadtalker_video2pose/src/facerender/sync_batchnorm/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b66ec4aea213edf4330beda0a8c8b93d6db77a60
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/sync_batchnorm/comm.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# File : comm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+ def __init__(self):
+ self._result = None
+ self._lock = threading.Lock()
+ self._cond = threading.Condition(self._lock)
+
+ def put(self, result):
+ with self._lock:
+ assert self._result is None, 'Previous result has\'t been fetched.'
+ self._result = result
+ self._cond.notify()
+
+ def get(self):
+ with self._lock:
+ if self._result is None:
+ self._cond.wait()
+
+ res = self._result
+ self._result = None
+ return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+ """Pipe for master-slave communication."""
+
+ def run_slave(self, msg):
+ self.queue.put((self.identifier, msg))
+ ret = self.result.get()
+ self.queue.put(True)
+ return ret
+
+
+class SyncMaster(object):
+ """An abstract `SyncMaster` object.
+
+ - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
+ call `register(id)` and obtain an `SlavePipe` to communicate with the master.
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
+ and passed to a registered callback.
+ - After receiving the messages, the master device should gather the information and determine to message passed
+ back to each slave devices.
+ """
+
+ def __init__(self, master_callback):
+ """
+
+ Args:
+ master_callback: a callback to be invoked after having collected messages from slave devices.
+ """
+ self._master_callback = master_callback
+ self._queue = queue.Queue()
+ self._registry = collections.OrderedDict()
+ self._activated = False
+
+ def __getstate__(self):
+ return {'master_callback': self._master_callback}
+
+ def __setstate__(self, state):
+ self.__init__(state['master_callback'])
+
+ def register_slave(self, identifier):
+ """
+ Register an slave device.
+
+ Args:
+ identifier: an identifier, usually is the device id.
+
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+ """
+ if self._activated:
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
+ self._activated = False
+ self._registry.clear()
+ future = FutureResult()
+ self._registry[identifier] = _MasterRegistry(future)
+ return SlavePipe(identifier, self._queue, future)
+
+ def run_master(self, master_msg):
+ """
+ Main entry for the master device in each forward pass.
+ The messages were first collected from each devices (including the master device), and then
+ an callback will be invoked to compute the message to be sent back to each devices
+ (including the master device).
+
+ Args:
+ master_msg: the message that the master want to send to itself. This will be placed as the first
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+ Returns: the message to be sent back to the master device.
+
+ """
+ self._activated = True
+
+ intermediates = [(0, master_msg)]
+ for i in range(self.nr_slaves):
+ intermediates.append(self._queue.get())
+
+ results = self._master_callback(intermediates)
+ assert results[0][0] == 0, 'The first result should belongs to the master.'
+
+ for i, res in results:
+ if i == 0:
+ continue
+ self._registry[i].result.put(res)
+
+ for i in range(self.nr_slaves):
+ assert self._queue.get() is True
+
+ return results[0][1]
+
+ @property
+ def nr_slaves(self):
+ return len(self._registry)
diff --git a/sadtalker_video2pose/src/facerender/sync_batchnorm/replicate.py b/sadtalker_video2pose/src/facerender/sync_batchnorm/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b97380d9c5fbe75c4b3583d3668ccd6a2848699
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/sync_batchnorm/replicate.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File : replicate.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+ 'CallbackContext',
+ 'execute_replication_callbacks',
+ 'DataParallelWithCallback',
+ 'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+ pass
+
+
+def execute_replication_callbacks(modules):
+ """
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
+
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
+ (shared among multiple copies of this module on different devices).
+ Through this context, different copies can share some information.
+
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
+ of any slave copies.
+ """
+ master_copy = modules[0]
+ nr_modules = len(list(master_copy.modules()))
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+ for i, module in enumerate(modules):
+ for j, m in enumerate(module.modules()):
+ if hasattr(m, '__data_parallel_replicate__'):
+ m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+ """
+ Data Parallel with a replication callback.
+
+ An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
+ original `replicate` function.
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ # sync_bn.__data_parallel_replicate__ will be invoked.
+ """
+
+ def replicate(self, module, device_ids):
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+
+def patch_replication_callback(data_parallel):
+ """
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
+ Useful when you have customized `DataParallel` implementation.
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+ > patch_replication_callback(sync_bn)
+ # this is equivalent to
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ """
+
+ assert isinstance(data_parallel, DataParallel)
+
+ old_replicate = data_parallel.replicate
+
+ @functools.wraps(old_replicate)
+ def new_replicate(module, device_ids):
+ modules = old_replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+ data_parallel.replicate = new_replicate
diff --git a/sadtalker_video2pose/src/facerender/sync_batchnorm/unittest.py b/sadtalker_video2pose/src/facerender/sync_batchnorm/unittest.py
new file mode 100644
index 0000000000000000000000000000000000000000..9716d035495097fb086ec050ab0bc9b76b9d28a0
--- /dev/null
+++ b/sadtalker_video2pose/src/facerender/sync_batchnorm/unittest.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# File : unittest.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import unittest
+
+import numpy as np
+from torch.autograd import Variable
+
+
+def as_numpy(v):
+ if isinstance(v, Variable):
+ v = v.data
+ return v.cpu().numpy()
+
+
+class TorchTestCase(unittest.TestCase):
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
+ npa, npb = as_numpy(a), as_numpy(b)
+ self.assertTrue(
+ np.allclose(npa, npb, atol=atol),
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
+ )
diff --git a/sadtalker_video2pose/src/generate_batch.py b/sadtalker_video2pose/src/generate_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fcaff51276d489aa76c15e4979864a4d4f74aa4
--- /dev/null
+++ b/sadtalker_video2pose/src/generate_batch.py
@@ -0,0 +1,120 @@
+import os
+
+from tqdm import tqdm
+import torch
+import numpy as np
+import random
+import scipy.io as scio
+import src.utils.audio as audio
+
+def crop_pad_audio(wav, audio_length):
+ if len(wav) > audio_length:
+ wav = wav[:audio_length]
+ elif len(wav) < audio_length:
+ wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
+ return wav
+
+def parse_audio_length(audio_length, sr, fps):
+ bit_per_frames = sr / fps
+
+ num_frames = int(audio_length / bit_per_frames)
+ audio_length = int(num_frames * bit_per_frames)
+
+ return audio_length, num_frames
+
+def generate_blink_seq(num_frames):
+ ratio = np.zeros((num_frames,1))
+ frame_id = 0
+ while frame_id in range(num_frames):
+ start = 80
+ if frame_id+start+9<=num_frames - 1:
+ ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5]
+ frame_id = frame_id+start+9
+ else:
+ break
+ return ratio
+
+def generate_blink_seq_randomly(num_frames):
+ ratio = np.zeros((num_frames,1))
+ if num_frames<=20:
+ return ratio
+ frame_id = 0
+ while frame_id in range(num_frames):
+ start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70)))
+ if frame_id+start+5<=num_frames - 1:
+ ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5]
+ frame_id = frame_id+start+5
+ else:
+ break
+ return ratio
+
+def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
+
+ syncnet_mel_step_size = 16
+ fps = 25
+
+ pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+
+
+ if idlemode:
+ num_frames = int(length_of_audio * 25)
+ indiv_mels = np.zeros((num_frames, 80, 16))
+ else:
+ wav = audio.load_wav(audio_path, 16000)
+ wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
+ wav = crop_pad_audio(wav, wav_length)
+ orig_mel = audio.melspectrogram(wav).T
+ spec = orig_mel.copy() # nframes 80
+ indiv_mels = []
+
+ for i in tqdm(range(num_frames), 'mel:'):
+ start_frame_num = i-2
+ start_idx = int(80. * (start_frame_num / float(fps)))
+ end_idx = start_idx + syncnet_mel_step_size
+ seq = list(range(start_idx, end_idx))
+ seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
+ m = spec[seq, :]
+ indiv_mels.append(m.T)
+ indiv_mels = np.asarray(indiv_mels) # T 80 16
+
+ ratio = generate_blink_seq_randomly(num_frames) # T
+ source_semantics_path = first_coeff_path
+ source_semantics_dict = scio.loadmat(source_semantics_path)
+ ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
+ ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)
+
+ if ref_eyeblink_coeff_path is not None:
+ ratio[:num_frames] = 0
+ refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path)
+ refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64]
+ refeyeblink_num_frames = refeyeblink_coeff.shape[0]
+ if refeyeblink_num_frames frame_num:
+ new_degree_list = new_degree_list[:frame_num]
+ elif len(new_degree_list) < frame_num:
+ for _ in range(frame_num-len(new_degree_list)):
+ new_degree_list.append(new_degree_list[-1])
+ print(len(new_degree_list))
+ print(frame_num)
+
+ remainder = frame_num%batch_size
+ if remainder!=0:
+ for _ in range(batch_size-remainder):
+ new_degree_list.append(new_degree_list[-1])
+ new_degree_np = np.array(new_degree_list).reshape(batch_size, -1)
+ return new_degree_np
+
diff --git a/sadtalker_video2pose/src/gradio_demo.py b/sadtalker_video2pose/src/gradio_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a2399fc44704b544ef39bb908d32a21da9fae17
--- /dev/null
+++ b/sadtalker_video2pose/src/gradio_demo.py
@@ -0,0 +1,170 @@
+import torch, uuid
+import os, sys, shutil, platform
+from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
+from src.utils.preprocess import CropAndExtract
+from src.test_audio2coeff import Audio2Coeff
+from src.facerender.animate import AnimateFromCoeff
+from src.generate_batch import get_data
+from src.generate_facerender_batch import get_facerender_data
+
+from src.utils.init_path import init_path
+
+from pydub import AudioSegment
+
+
+def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
+ mp3_file = AudioSegment.from_file(file=mp3_filename)
+ mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
+
+
+class SadTalker():
+
+ def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):
+
+ if torch.cuda.is_available():
+ device = "cuda"
+ elif platform.system() == 'Darwin': # macos
+ device = "mps"
+ else:
+ device = "cpu"
+
+ self.device = device
+
+ os.environ['TORCH_HOME']= checkpoint_path
+
+ self.checkpoint_path = checkpoint_path
+ self.config_path = config_path
+
+
+ def test(self, source_image, driven_audio, preprocess='crop',
+ still_mode=False, use_enhancer=False, batch_size=1, size=256,
+ pose_style = 0,
+ facerender='facevid2vid',
+ exp_scale=1.0,
+ use_ref_video = False,
+ ref_video = None,
+ ref_info = None,
+ use_idle_mode = False,
+ length_of_audio = 0, use_blink=True,
+ result_dir='./results/'):
+
+ self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess)
+ print(self.sadtalker_paths)
+
+ self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
+ self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)
+
+ if facerender == 'facevid2vid' and self.device != 'mps':
+ self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
+ elif facerender == 'pirender' or self.device == 'mps':
+ self.animate_from_coeff = AnimateFromCoeff_PIRender(self.sadtalker_paths, self.device)
+ facerender = 'pirender'
+ else:
+ raise(RuntimeError('Unknown model: {}'.format(facerender)))
+
+
+ time_tag = str(uuid.uuid4())
+ save_dir = os.path.join(result_dir, time_tag)
+ os.makedirs(save_dir, exist_ok=True)
+
+ input_dir = os.path.join(save_dir, 'input')
+ os.makedirs(input_dir, exist_ok=True)
+
+ print(source_image)
+ pic_path = os.path.join(input_dir, os.path.basename(source_image))
+ shutil.move(source_image, input_dir)
+
+ if driven_audio is not None and os.path.isfile(driven_audio):
+ audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
+
+ #### mp3 to wav
+ if '.mp3' in audio_path:
+ mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000)
+ audio_path = audio_path.replace('.mp3', '.wav')
+ else:
+ shutil.move(driven_audio, input_dir)
+
+ elif use_idle_mode:
+ audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path
+ from pydub import AudioSegment
+ one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio) #duration in milliseconds
+ one_sec_segment.export(audio_path, format="wav")
+ else:
+ print(use_ref_video, ref_info)
+ assert use_ref_video == True and ref_info == 'all'
+
+ if use_ref_video and ref_info == 'all': # full ref mode
+ ref_video_videoname = os.path.basename(ref_video)
+ audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')
+ print('new audiopath:',audio_path)
+ # if ref_video contains audio, set the audio from ref_video.
+ cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path)
+ os.system(cmd)
+
+ os.makedirs(save_dir, exist_ok=True)
+
+ #crop image and extract 3dmm from image
+ first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+ os.makedirs(first_frame_dir, exist_ok=True)
+ first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size)
+
+ if first_coeff_path is None:
+ raise AttributeError("No face is detected")
+
+ if use_ref_video:
+ print('using ref video for genreation')
+ ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
+ ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
+ os.makedirs(ref_video_frame_dir, exist_ok=True)
+ print('3DMM Extraction for the reference video providing pose')
+ ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
+ else:
+ ref_video_coeff_path = None
+
+ if use_ref_video:
+ if ref_info == 'pose':
+ ref_pose_coeff_path = ref_video_coeff_path
+ ref_eyeblink_coeff_path = None
+ elif ref_info == 'blink':
+ ref_pose_coeff_path = None
+ ref_eyeblink_coeff_path = ref_video_coeff_path
+ elif ref_info == 'pose+blink':
+ ref_pose_coeff_path = ref_video_coeff_path
+ ref_eyeblink_coeff_path = ref_video_coeff_path
+ elif ref_info == 'all':
+ ref_pose_coeff_path = None
+ ref_eyeblink_coeff_path = None
+ else:
+ raise('error in refinfo')
+ else:
+ ref_pose_coeff_path = None
+ ref_eyeblink_coeff_path = None
+
+ #audio2ceoff
+ if use_ref_video and ref_info == 'all':
+ coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+ else:
+ batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, \
+ idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?
+ coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+
+ #coeff2video
+ data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, \
+ preprocess=preprocess, size=size, expression_scale = exp_scale, facemodel=facerender)
+ return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
+ video_name = data['video_name']
+ print(f'The generated video is named {video_name} in {save_dir}')
+
+ del self.preprocess_model
+ del self.audio_to_coeff
+ del self.animate_from_coeff
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ import gc; gc.collect()
+
+ return return_path
+
+
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/test_audio2coeff.py b/sadtalker_video2pose/src/test_audio2coeff.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0f5ca9195bbc980c93fa3e37c6d06cc32953aee
--- /dev/null
+++ b/sadtalker_video2pose/src/test_audio2coeff.py
@@ -0,0 +1,123 @@
+import os
+import torch
+import numpy as np
+from scipy.io import savemat, loadmat
+from yacs.config import CfgNode as CN
+from scipy.signal import savgol_filter
+
+import safetensors
+import safetensors.torch
+
+from src.audio2pose_models.audio2pose import Audio2Pose
+from src.audio2exp_models.networks import SimpleWrapperV2
+from src.audio2exp_models.audio2exp import Audio2Exp
+from src.utils.safetensor_helper import load_x_from_safetensor
+
+def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"):
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if model is not None:
+ model.load_state_dict(checkpoint['model'])
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+
+ return checkpoint['epoch']
+
+class Audio2Coeff():
+
+ def __init__(self, sadtalker_path, device):
+ #load config
+ fcfg_pose = open(sadtalker_path['audio2pose_yaml_path'])
+ cfg_pose = CN.load_cfg(fcfg_pose)
+ cfg_pose.freeze()
+ fcfg_exp = open(sadtalker_path['audio2exp_yaml_path'])
+ cfg_exp = CN.load_cfg(fcfg_exp)
+ cfg_exp.freeze()
+
+ # load audio2pose_model
+ self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device)
+ self.audio2pose_model = self.audio2pose_model.to(device)
+ self.audio2pose_model.eval()
+ for param in self.audio2pose_model.parameters():
+ param.requires_grad = False
+
+ try:
+ if sadtalker_path['use_safetensor']:
+ checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose'))
+ else:
+ load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device)
+ except:
+ raise Exception("Failed in loading audio2pose_checkpoint")
+
+ # load audio2exp_model
+ netG = SimpleWrapperV2()
+ netG = netG.to(device)
+ for param in netG.parameters():
+ netG.requires_grad = False
+ netG.eval()
+ try:
+ if sadtalker_path['use_safetensor']:
+ checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp'))
+ else:
+ load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device)
+ except:
+ raise Exception("Failed in loading audio2exp_checkpoint")
+ self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False)
+ self.audio2exp_model = self.audio2exp_model.to(device)
+ for param in self.audio2exp_model.parameters():
+ param.requires_grad = False
+ self.audio2exp_model.eval()
+
+ self.device = device
+
+ def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None):
+
+ with torch.no_grad():
+ #test
+ results_dict_exp= self.audio2exp_model.test(batch)
+ exp_pred = results_dict_exp['exp_coeff_pred'] #bs T 64
+
+ #for class_id in range(1):
+ #class_id = 0#(i+10)%45
+ #class_id = random.randint(0,46) #46 styles can be selected
+ batch['class'] = torch.LongTensor([pose_style]).to(self.device)
+ results_dict_pose = self.audio2pose_model.test(batch)
+ pose_pred = results_dict_pose['pose_pred'] #bs T 6
+
+ pose_len = pose_pred.shape[1]
+ if pose_len<13:
+ pose_len = int((pose_len-1)/2)*2+1
+ pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device)
+ else:
+ pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device)
+
+ coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1) #bs T 70
+
+ coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy()
+
+ if ref_pose_coeff_path is not None:
+ coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path)
+
+ savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])),
+ {'coeff_3dmm': coeffs_pred_numpy})
+
+ return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name']))
+
+ def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path):
+ num_frames = coeffs_pred_numpy.shape[0]
+ refpose_coeff_dict = loadmat(ref_pose_coeff_path)
+ refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70]
+ refpose_num_frames = refpose_coeff.shape[0]
+ if refpose_num_frames= 0
+ if hp.symmetric_mels:
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
+ else:
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
+
+def _denormalize(D):
+ if hp.allow_clipping_in_normalization:
+ if hp.symmetric_mels:
+ return (((np.clip(D, -hp.max_abs_value,
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+ + hp.min_level_db)
+ else:
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
+
+ if hp.symmetric_mels:
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
+ else:
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
diff --git a/sadtalker_video2pose/src/utils/croper.py b/sadtalker_video2pose/src/utils/croper.py
new file mode 100644
index 0000000000000000000000000000000000000000..578372debdb8d2b99fe93d3d2ba2dfacf7cbb0ad
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/croper.py
@@ -0,0 +1,145 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import scipy
+import numpy as np
+from PIL import Image
+import torch
+from tqdm import tqdm
+from itertools import cycle
+
+from src.face3d.extract_kp_videos_safe import KeypointExtractor
+from facexlib.alignment import landmark_98_to_68
+
+import numpy as np
+from PIL import Image
+
+class Preprocesser:
+ def __init__(self, device='cuda'):
+ self.predictor = KeypointExtractor(device)
+
+ def get_landmark(self, img_np):
+ """get landmark with dlib
+ :return: np.array shape=(68, 2)
+ """
+ with torch.no_grad():
+ dets = self.predictor.det_net.detect_faces(img_np, 0.97)
+
+ if len(dets) == 0:
+ return None
+ det = dets[0]
+
+ img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :]
+ lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0]
+
+ #### keypoints to the original location
+ lm[:,0] += int(det[0])
+ lm[:,1] += int(det[1])
+
+ return lm
+
+ def align_face(self, img, lm, output_size=1024):
+ """
+ :param filepath: str
+ :return: PIL Image
+ """
+ lm_chin = lm[0: 17] # left-right
+ lm_eyebrow_left = lm[17: 22] # left-right
+ lm_eyebrow_right = lm[22: 27] # left-right
+ lm_nose = lm[27: 31] # top-down
+ lm_nostrils = lm[31: 36] # top-down
+ lm_eye_left = lm[36: 42] # left-clockwise
+ lm_eye_right = lm[42: 48] # left-clockwise
+ lm_mouth_outer = lm[48: 60] # left-clockwise
+ lm_mouth_inner = lm[60: 68] # left-clockwise
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] # Addition of binocular difference and double mouth difference
+ x /= np.hypot(*x) # hypot函数计算直角三角形的斜边长,用斜边长对三角形两条直边做归一化
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) # 双眼差和眼嘴差,选较大的作为基准尺度
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) # 定义四边形,以面部基准位置为中心上下左右平移得到四个顶点
+ qsize = np.hypot(*x) * 2 # 定义四边形的大小(边长),为基准尺度的2倍
+
+ # Shrink.
+ # 如果计算出的四边形太大了,就按比例缩小它
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+ img = img.resize(rsize, Image.ANTIALIAS)
+ quad /= shrink
+ qsize /= shrink
+ else:
+ rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1]))))
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+ min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ # img = img.crop(crop)
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+ max(pad[3] - img.size[1] + border, 0))
+ # if enable_padding and max(pad) > border - 4:
+ # pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ # img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ # h, w, _ = img.shape
+ # y, x, _ = np.ogrid[:h, :w, :1]
+ # mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+ # 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+ # blur = qsize * 0.02
+ # img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ # img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ # img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+ # quad += pad[:2]
+
+ # Transform.
+ quad = (quad + 0.5).flatten()
+ lx = max(min(quad[0], quad[2]), 0)
+ ly = max(min(quad[1], quad[7]), 0)
+ rx = min(max(quad[4], quad[6]), img.size[0])
+ ry = min(max(quad[3], quad[5]), img.size[0])
+
+ # Save aligned image.
+ return rsize, crop, [lx, ly, rx, ry]
+
+ def crop(self, img_np_list, still=False, xsize=512): # first frame for all video
+ # print(img_np_list)
+ img_np = img_np_list[0]
+ lm = self.get_landmark(img_np)
+
+ if lm is None:
+ raise 'can not detect the landmark from source image'
+ rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ for _i in range(len(img_np_list)):
+ _inp = img_np_list[_i]
+ _inp = cv2.resize(_inp, (rsize[0], rsize[1]))
+ _inp = _inp[cly:cry, clx:crx]
+ if not still:
+ _inp = _inp[ly:ry, lx:rx]
+ img_np_list[_i] = _inp
+ return img_np_list, crop, quad
+
diff --git a/sadtalker_video2pose/src/utils/face_enhancer.py b/sadtalker_video2pose/src/utils/face_enhancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2664560a1d7199e81f1a50093f29d02de91d4bcc
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/face_enhancer.py
@@ -0,0 +1,123 @@
+import os
+import torch
+
+from gfpgan import GFPGANer
+
+from tqdm import tqdm
+
+from src.utils.videoio import load_video_to_cv2
+
+import cv2
+
+
+class GeneratorWithLen(object):
+ """ From https://stackoverflow.com/a/7460929 """
+
+ def __init__(self, gen, length):
+ self.gen = gen
+ self.length = length
+
+ def __len__(self):
+ return self.length
+
+ def __iter__(self):
+ return self.gen
+
+def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
+ return list(gen)
+
+def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
+ """ Provide a generator with a __len__ method so that it can passed to functions that
+ call len()"""
+
+ if os.path.isfile(images): # handle video to images
+ # TODO: Create a generator version of load_video_to_cv2
+ images = load_video_to_cv2(images)
+
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
+ gen_with_len = GeneratorWithLen(gen, len(images))
+ return gen_with_len
+
+def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
+ """ Provide a generator function so that all of the enhanced images don't need
+ to be stored in memory at the same time. This can save tons of RAM compared to
+ the enhancer function. """
+
+ print('face enhancer....')
+ if not isinstance(images, list) and os.path.isfile(images): # handle video to images
+ images = load_video_to_cv2(images)
+
+ # ------------------------ set up GFPGAN restorer ------------------------
+ if method == 'gfpgan':
+ arch = 'clean'
+ channel_multiplier = 2
+ model_name = 'GFPGANv1.4'
+ url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
+ elif method == 'RestoreFormer':
+ arch = 'RestoreFormer'
+ channel_multiplier = 2
+ model_name = 'RestoreFormer'
+ url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
+ elif method == 'codeformer': # TODO:
+ arch = 'CodeFormer'
+ channel_multiplier = 2
+ model_name = 'CodeFormer'
+ url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+ else:
+ raise ValueError(f'Wrong model version {method}.')
+
+
+ # ------------------------ set up background upsampler ------------------------
+ if bg_upsampler == 'realesrgan':
+ if not torch.cuda.is_available(): # CPU
+ import warnings
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
+ 'If you really want to use it, please modify the corresponding codes.')
+ bg_upsampler = None
+ else:
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from realesrgan import RealESRGANer
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ bg_upsampler = RealESRGANer(
+ scale=2,
+ model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
+ model=model,
+ tile=400,
+ tile_pad=10,
+ pre_pad=0,
+ half=True) # need to set False in CPU mode
+ else:
+ bg_upsampler = None
+
+ # determine model paths
+ model_path = os.path.join('gfpgan/weights', model_name + '.pth')
+
+ if not os.path.isfile(model_path):
+ model_path = os.path.join('checkpoints', model_name + '.pth')
+
+ if not os.path.isfile(model_path):
+ # download pre-trained models from url
+ model_path = url
+
+ restorer = GFPGANer(
+ model_path=model_path,
+ upscale=2,
+ arch=arch,
+ channel_multiplier=channel_multiplier,
+ bg_upsampler=bg_upsampler)
+
+ # ------------------------ restore ------------------------
+ for idx in tqdm(range(len(images)), 'Face Enhancer:'):
+
+ img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)
+
+ # restore faces and background if necessary
+ cropped_faces, restored_faces, r_img = restorer.enhance(
+ img,
+ has_aligned=False,
+ only_center_face=False,
+ paste_back=True)
+
+ r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
+ yield r_img
diff --git a/sadtalker_video2pose/src/utils/flow_util.py b/sadtalker_video2pose/src/utils/flow_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f25046bab67cc8fbbb59efd02f48d7b6f22fc580
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/flow_util.py
@@ -0,0 +1,221 @@
+import torch
+import sys
+
+
+def convert_flow_to_deformation(flow):
+ r"""convert flow fields to deformations.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ deformation (tensor): The deformation used for warpping
+ """
+ b,c,h,w = flow.shape
+ flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1)
+ grid = make_coordinate_grid(flow)
+ # print(grid.shape, flow_norm.shape)
+ deformation = grid + flow_norm.permute(0,2,3,1)
+ return deformation
+
+def make_coordinate_grid(flow):
+ r"""obtain coordinate grid with the same size as the flow filed.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ grid (tensor): The grid with the same size as the input flow
+ """
+ b,c,h,w = flow.shape
+
+ x = torch.arange(w).to(flow)
+ y = torch.arange(h).to(flow)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+ meshed = meshed.expand(b, -1, -1, -1)
+ return meshed
+
+
+def warp_image(source_image, deformation):
+ r"""warp the input image according to the deformation
+
+ Args:
+ source_image (tensor): source images to be warpped
+ deformation (tensor): deformations used to warp the images; value in range (-1, 1)
+ Returns:
+ output (tensor): the warpped images
+ """
+ _, h_old, w_old, _ = deformation.shape
+ _, _, h, w = source_image.shape
+ if h_old != h or w_old != w:
+ deformation = deformation.permute(0, 3, 1, 2)
+ deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear')
+ deformation = deformation.permute(0, 2, 3, 1)
+ return torch.nn.functional.grid_sample(source_image, deformation)
+
+
+
+# visualize flow
+import numpy as np
+
+__all__ = ['load_flow', 'save_flow', 'vis_flow']
+
+
+def load_flow(path):
+ with open(path, 'rb') as f:
+ magic = float(np.fromfile(f, np.float32, count=1)[0])
+ if magic == 202021.25:
+ w, h = np.fromfile(f, np.int32, count=1)[0], np.fromfile(f, np.int32, count=1)[0]
+ data = np.fromfile(f, np.float32, count=h * w * 2)
+ data.resize((h, w, 2))
+ return data
+ return None
+
+
+def save_flow(path, flow):
+ magic = np.array([202021.25], np.float32)
+ h, w = flow.shape[:2]
+ h, w = np.array([h], np.int32), np.array([w], np.int32)
+
+ with open(path, 'wb') as f:
+ magic.tofile(f)
+ w.tofile(f)
+ h.tofile(f)
+ flow.tofile(f)
+
+
+
+def makeColorwheel():
+ # color encoding scheme
+
+ # adapted from the color circle idea described at
+ # http://members.shaw.ca/quadibloc/other/colint.htm
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+
+ colorwheel = np.zeros([ncols, 3]) # r g b
+
+ col = 0
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY, 1) / RY)
+ col += RY
+
+ # YG
+ colorwheel[col:YG + col, 0] = 255 - np.floor(255 * np.arange(0, YG, 1) / YG)
+ colorwheel[col:YG + col, 1] = 255
+ col += YG
+
+ # GC
+ colorwheel[col:GC + col, 1] = 255
+ colorwheel[col:GC + col, 2] = np.floor(255 * np.arange(0, GC, 1) / GC)
+ col += GC
+
+ # CB
+ colorwheel[col:CB + col, 1] = 255 - np.floor(255 * np.arange(0, CB, 1) / CB)
+ colorwheel[col:CB + col, 2] = 255
+ col += CB
+
+ # BM
+ colorwheel[col:BM + col, 2] = 255
+ colorwheel[col:BM + col, 0] = np.floor(255 * np.arange(0, BM, 1) / BM)
+ col += BM
+
+ # MR
+ colorwheel[col:MR + col, 2] = 255 - np.floor(255 * np.arange(0, MR, 1) / MR)
+ colorwheel[col:MR + col, 0] = 255
+ return colorwheel
+
+
+def computeColor(u, v):
+ colorwheel = makeColorwheel()
+ nan_u = np.isnan(u)
+ nan_v = np.isnan(v)
+ nan_u = np.where(nan_u)
+ nan_v = np.where(nan_v)
+
+ u[nan_u] = 0
+ u[nan_v] = 0
+ v[nan_u] = 0
+ v[nan_v] = 0
+
+ ncols = colorwheel.shape[0]
+ radius = np.sqrt(u ** 2 + v ** 2)
+ a = np.arctan2(-v, -u) / np.pi
+ fk = (a + 1) / 2 * (ncols - 1) # -1~1 maped to 1~ncols
+ k0 = fk.astype(np.uint8) # 1, 2, ..., ncols
+ k1 = k0 + 1
+ k1[k1 == ncols] = 0
+ f = fk - k0
+
+ img = np.empty([k1.shape[0], k1.shape[1], 3])
+ ncolors = colorwheel.shape[1]
+ for i in range(ncolors):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0] / 255
+ col1 = tmp[k1] / 255
+ col = (1 - f) * col0 + f * col1
+ idx = radius <= 1
+ col[idx] = 1 - radius[idx] * (1 - col[idx]) # increase saturation with radius
+ col[~idx] *= 0.75 # out of range
+ img[:, :, 2 - i] = np.floor(255 * col).astype(np.uint8)
+
+ return img.astype(np.uint8)
+
+
+def vis_flow(flow):
+ eps = sys.float_info.epsilon
+ UNKNOWN_FLOW_THRESH = 1e9
+ UNKNOWN_FLOW = 1e10
+
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ maxu = -999
+ maxv = -999
+
+ minu = 999
+ minv = 999
+
+ maxrad = -1
+ # fix unknown flow
+ greater_u = np.where(u > UNKNOWN_FLOW_THRESH)
+ greater_v = np.where(v > UNKNOWN_FLOW_THRESH)
+ u[greater_u] = 0
+ u[greater_v] = 0
+ v[greater_u] = 0
+ v[greater_v] = 0
+
+ maxu = max([maxu, np.amax(u)])
+ minu = min([minu, np.amin(u)])
+
+ maxv = max([maxv, np.amax(v)])
+ minv = min([minv, np.amin(v)])
+ rad = np.sqrt(np.multiply(u, u) + np.multiply(v, v))
+ maxrad = max([maxrad, np.amax(rad)])
+ # print('max flow: %.4f flow range: u = %.3f .. %.3f; v = %.3f .. %.3f\n' % (maxrad, minu, maxu, minv, maxv))
+
+ u = u / (maxrad + eps)
+ v = v / (maxrad + eps)
+ img = computeColor(u, v)
+ return img[:, :, [2, 1, 0]]
+
+
+def test_visualize_flow():
+ flow = load_flow('out.flo')
+ img = vis_flow(flow)
+
+ import cv2
+ cv2.imwrite("img.png", img)
diff --git a/sadtalker_video2pose/src/utils/hparams.py b/sadtalker_video2pose/src/utils/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..83c312d767c35b9adc988157243efc02129fdb84
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/hparams.py
@@ -0,0 +1,160 @@
+from glob import glob
+import os
+
+class HParams:
+ def __init__(self, **kwargs):
+ self.data = {}
+
+ for key, value in kwargs.items():
+ self.data[key] = value
+
+ def __getattr__(self, key):
+ if key not in self.data:
+ raise AttributeError("'HParams' object has no attribute %s" % key)
+ return self.data[key]
+
+ def set_hparam(self, key, value):
+ self.data[key] = value
+
+
+# Default hyperparameters
+hparams = HParams(
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
+ # network
+ rescale=True, # Whether to rescale audio prior to preprocessing
+ rescaling_max=0.9, # Rescaling value
+
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+ # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+ # Does not work if n_ffit is not multiple of hop_size!!
+ use_lws=False,
+
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i )
+
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
+
+ # Mel and Linear spectrograms normalization/scaling and clipping
+ signal_normalization=True,
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
+ symmetric_mels=True,
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
+ # faster and cleaner convergence)
+ max_abs_value=4.,
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
+ # be too big to avoid gradient explosion,
+ # not too small for fast convergence)
+ # Contribution by @begeekmyfriend
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
+ # levels. Also allows for better G&L phase reconstruction)
+ preemphasize=True, # whether to apply filter
+ preemphasis=0.97, # filter coefficient.
+
+ # Limits
+ min_level_db=-100,
+ ref_level_db=20,
+ fmin=55,
+ # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+ fmax=7600, # To be increased/reduced depending on data.
+
+ ###################### Our training parameters #################################
+ img_size=96,
+ fps=25,
+
+ batch_size=16,
+ initial_learning_rate=1e-4,
+ nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+ num_workers=20,
+ checkpoint_interval=3000,
+ eval_interval=3000,
+ writer_interval=300,
+ save_optimizer_state=True,
+
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
+ syncnet_batch_size=64,
+ syncnet_lr=1e-4,
+ syncnet_eval_interval=1000,
+ syncnet_checkpoint_interval=10000,
+
+ disc_wt=0.07,
+ disc_initial_learning_rate=1e-4,
+)
+
+
+
+# Default hyperparameters
+hparamsdebug = HParams(
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
+ # network
+ rescale=True, # Whether to rescale audio prior to preprocessing
+ rescaling_max=0.9, # Rescaling value
+
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+ # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+ # Does not work if n_ffit is not multiple of hop_size!!
+ use_lws=False,
+
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i )
+
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
+
+ # Mel and Linear spectrograms normalization/scaling and clipping
+ signal_normalization=True,
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
+ symmetric_mels=True,
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
+ # faster and cleaner convergence)
+ max_abs_value=4.,
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
+ # be too big to avoid gradient explosion,
+ # not too small for fast convergence)
+ # Contribution by @begeekmyfriend
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
+ # levels. Also allows for better G&L phase reconstruction)
+ preemphasize=True, # whether to apply filter
+ preemphasis=0.97, # filter coefficient.
+
+ # Limits
+ min_level_db=-100,
+ ref_level_db=20,
+ fmin=55,
+ # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+ fmax=7600, # To be increased/reduced depending on data.
+
+ ###################### Our training parameters #################################
+ img_size=96,
+ fps=25,
+
+ batch_size=2,
+ initial_learning_rate=1e-3,
+ nepochs=100000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+ num_workers=0,
+ checkpoint_interval=10000,
+ eval_interval=10,
+ writer_interval=5,
+ save_optimizer_state=True,
+
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
+ syncnet_batch_size=64,
+ syncnet_lr=1e-4,
+ syncnet_eval_interval=10000,
+ syncnet_checkpoint_interval=10000,
+
+ disc_wt=0.07,
+ disc_initial_learning_rate=1e-4,
+)
+
+
+def hparams_debug_string():
+ values = hparams.values()
+ hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
+ return "Hyperparameters:\n" + "\n".join(hp)
diff --git a/sadtalker_video2pose/src/utils/init_path.py b/sadtalker_video2pose/src/utils/init_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..65239fe3281798b2472f7ca0557a96157d9de930
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/init_path.py
@@ -0,0 +1,49 @@
+import os
+import glob
+
+def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'):
+
+ if old_version:
+ #### load all the checkpoint of `pth`
+ sadtalker_paths = {
+ 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
+ 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
+ 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
+ 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
+ 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
+ }
+
+ use_safetensor = False
+ elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))):
+ print('using safetensor as default')
+ sadtalker_paths = {
+ "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'),
+ }
+ use_safetensor = True
+ else:
+ print("WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. We run the old version of the checkpoint this time!")
+ use_safetensor = False
+
+ sadtalker_paths = {
+ 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
+ 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
+ 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
+ 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
+ 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
+ }
+
+ sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting'
+ sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml')
+ sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml')
+ sadtalker_paths['pirender_yaml_path'] = os.path.join(config_dir, 'facerender_pirender.yaml')
+ sadtalker_paths['pirender_checkpoint'] = os.path.join(checkpoint_dir, 'epoch_00190_iteration_000400000_checkpoint.pt')
+ sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml')
+
+ if 'full' in preprocess:
+ sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar')
+ sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml')
+ else:
+ sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar')
+ sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml')
+
+ return sadtalker_paths
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/utils/model2safetensor.py b/sadtalker_video2pose/src/utils/model2safetensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5b76e3d67a06fdbf6646590d44b8c225bc73d79
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/model2safetensor.py
@@ -0,0 +1,141 @@
+import torch
+import yaml
+import os
+
+import safetensors
+from safetensors.torch import save_file
+from yacs.config import CfgNode as CN
+import sys
+
+sys.path.append('/apdcephfs/private_shadowcun/SadTalker')
+
+from src.face3d.models import networks
+
+from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
+from src.facerender.modules.mapping import MappingNet
+from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
+
+from src.audio2pose_models.audio2pose import Audio2Pose
+from src.audio2exp_models.networks import SimpleWrapperV2
+from src.test_audio2coeff import load_cpk
+
+size = 256
+############ face vid2vid
+config_path = os.path.join('src', 'config', 'facerender.yaml')
+current_root_path = '.'
+
+path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
+net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='')
+checkpoint = torch.load(path_of_net_recon_model, map_location='cpu')
+net_recon.load_state_dict(checkpoint['net_recon'])
+
+with open(config_path) as f:
+ config = yaml.safe_load(f)
+
+generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
+ **config['model_params']['common_params'])
+kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
+ **config['model_params']['common_params'])
+he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
+ **config['model_params']['common_params'])
+mapping = MappingNet(**config['model_params']['mapping_params'])
+
+def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None,
+ kp_detector=None, he_estimator=None, optimizer_generator=None,
+ optimizer_discriminator=None, optimizer_kp_detector=None,
+ optimizer_he_estimator=None, device="cpu"):
+
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+ if generator is not None:
+ generator.load_state_dict(checkpoint['generator'])
+ if kp_detector is not None:
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
+ if he_estimator is not None:
+ he_estimator.load_state_dict(checkpoint['he_estimator'])
+ if discriminator is not None:
+ try:
+ discriminator.load_state_dict(checkpoint['discriminator'])
+ except:
+ print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
+ if optimizer_generator is not None:
+ optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
+ if optimizer_discriminator is not None:
+ try:
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+ except RuntimeError as e:
+ print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
+ if optimizer_kp_detector is not None:
+ optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
+ if optimizer_he_estimator is not None:
+ optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
+
+ return checkpoint['epoch']
+
+
+def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None,
+ kp_detector=None, he_estimator=None,
+ device="cpu"):
+
+ checkpoint = safetensors.torch.load_file(checkpoint_path)
+
+ if generator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'generator' in k:
+ x_generator[k.replace('generator.', '')] = v
+ generator.load_state_dict(x_generator)
+ if kp_detector is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'kp_extractor' in k:
+ x_generator[k.replace('kp_extractor.', '')] = v
+ kp_detector.load_state_dict(x_generator)
+ if he_estimator is not None:
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if 'he_estimator' in k:
+ x_generator[k.replace('he_estimator.', '')] = v
+ he_estimator.load_state_dict(x_generator)
+
+ return None
+
+free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar'
+load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
+
+wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')
+
+audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
+audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
+
+audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
+audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
+
+fcfg_pose = open(audio2pose_yaml_path)
+cfg_pose = CN.load_cfg(fcfg_pose)
+cfg_pose.freeze()
+audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint)
+audio2pose_model.eval()
+load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu')
+
+# load audio2exp_model
+netG = SimpleWrapperV2()
+netG.eval()
+load_cpk(audio2exp_checkpoint, model=netG, device='cpu')
+
+class SadTalker(torch.nn.Module):
+ def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon):
+ super(SadTalker, self).__init__()
+ self.kp_extractor = kp_extractor
+ self.generator = generator
+ self.audio2exp = netG
+ self.audio2pose = audio2pose
+ self.face_3drecon = face_3drecon
+
+
+model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon)
+
+# here, we want to convert it to safetensor
+save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors")
+
+### test
+load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None)
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/utils/paste_pic.py b/sadtalker_video2pose/src/utils/paste_pic.py
new file mode 100644
index 0000000000000000000000000000000000000000..4da8952e6933698fec6c7cf35042cb5b1f0dcba5
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/paste_pic.py
@@ -0,0 +1,69 @@
+import cv2, os
+import numpy as np
+from tqdm import tqdm
+import uuid
+
+from src.utils.videoio import save_video_with_watermark
+
+def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False):
+
+ if not os.path.isfile(pic_path):
+ raise ValueError('pic_path must be a valid path to video/image file')
+ elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_img = cv2.imread(pic_path)
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(pic_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ break
+ full_img = frame
+ frame_h = full_img.shape[0]
+ frame_w = full_img.shape[1]
+
+ video_stream = cv2.VideoCapture(video_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ crop_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ crop_frames.append(frame)
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ tmp_path = str(uuid.uuid4())+'.mp4'
+ out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h))
+ for crop_frame in tqdm(crop_frames, 'seamlessClone:'):
+ p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1))
+
+ mask = 255*np.ones(p.shape, p.dtype)
+ location = ((ox1+ox2) // 2, (oy1+oy2) // 2)
+ gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE)
+ out_tmp.write(gen_img)
+
+ out_tmp.release()
+
+ save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False)
+ os.remove(tmp_path)
diff --git a/sadtalker_video2pose/src/utils/preprocess.py b/sadtalker_video2pose/src/utils/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..4956c00d273467f8a0c020312401158b06c4fecd
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/preprocess.py
@@ -0,0 +1,170 @@
+import numpy as np
+import cv2, os, sys, torch
+from tqdm import tqdm
+from PIL import Image
+
+# 3dmm extraction
+import safetensors
+import safetensors.torch
+from src.face3d.util.preprocess import align_img
+from src.face3d.util.load_mats import load_lm3d
+from src.face3d.models import networks
+
+from scipy.io import loadmat, savemat
+from src.utils.croper import Preprocesser
+
+
+import warnings
+
+from src.utils.safetensor_helper import load_x_from_safetensor
+warnings.filterwarnings("ignore")
+
+def split_coeff(coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+ coeffs -- torch.tensor, size (B, 256)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80: 144]
+ tex_coeffs = coeffs[:, 144: 224]
+ angles = coeffs[:, 224: 227]
+ gammas = coeffs[:, 227: 254]
+ translations = coeffs[:, 254:]
+ return {
+ 'id': id_coeffs,
+ 'exp': exp_coeffs,
+ 'tex': tex_coeffs,
+ 'angle': angles,
+ 'gamma': gammas,
+ 'trans': translations
+ }
+
+
+class CropAndExtract():
+ def __init__(self, sadtalker_path, device):
+
+ self.propress = Preprocesser(device)
+ self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)
+
+ if sadtalker_path['use_safetensor']:
+ checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon'))
+ else:
+ checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device))
+ self.net_recon.load_state_dict(checkpoint['net_recon'])
+
+ self.net_recon.eval()
+ self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting'])
+ self.device = device
+
+ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256):
+
+ pic_name = os.path.splitext(os.path.split(input_path)[-1])[0]
+
+ landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt')
+ coeff_path = os.path.join(save_dir, pic_name+'.mat')
+ png_path = os.path.join(save_dir, pic_name+'.png')
+
+ #load input
+ if not os.path.isfile(input_path):
+ raise ValueError('input_path must be a valid path to video/image file')
+ elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_frames = [cv2.imread(input_path)]
+ fps = 25
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(input_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ full_frames.append(frame)
+ if source_image_flag:
+ break
+
+ x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames]
+
+ #### crop images as the
+ if 'crop' in crop_or_resize.lower(): # default crop
+ x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+ elif 'full' in crop_or_resize.lower():
+ x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+ else: # resize mode
+ oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1]
+ crop_info = ((ox2 - ox1, oy2 - oy1), None, None)
+
+ frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames]
+ if len(frames_pil) == 0:
+ print('No face is detected in the input file')
+ return None, None
+
+ # save crop info
+ for frame in frames_pil:
+ cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
+
+ # 2. get the landmark according to the detected face.
+ if not os.path.isfile(landmarks_path):
+ lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path)
+ else:
+ print(' Using saved landmarks.')
+ lm = np.loadtxt(landmarks_path).astype(np.float32)
+ lm = lm.reshape([len(x_full_frames), -1, 2])
+
+ if not os.path.isfile(coeff_path):
+ # load 3dmm paramter generator from Deep3DFaceRecon_pytorch
+ video_coeffs, full_coeffs = [], []
+ for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'):
+ frame = frames_pil[idx]
+ W,H = frame.size
+ lm1 = lm[idx].reshape([-1, 2])
+
+ if np.mean(lm1) == -1:
+ lm1 = (self.lm3d_std[:, :2]+1)/2.
+ lm1 = np.concatenate(
+ [lm1[:, :1]*W, lm1[:, 1:2]*H], 1
+ )
+ else:
+ lm1[:, -1] = H - 1 - lm1[:, -1]
+
+ trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std)
+
+ trans_params_m = np.array([float(item) for item in np.hsplit(trans_params, len(trans_params))]).astype(np.float32)
+ im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0)
+
+ with torch.no_grad():
+ full_coeff = self.net_recon(im_t)
+ coeffs = split_coeff(full_coeff)
+
+ pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}
+
+ pred_coeff = np.concatenate([
+ pred_coeff['exp'],
+ pred_coeff['angle'],
+ pred_coeff['trans'],
+ trans_params_m[2:][None],
+ ], 1)
+ video_coeffs.append(pred_coeff)
+ full_coeffs.append(full_coeff.cpu().numpy())
+
+ semantic_npy = np.array(video_coeffs)[:,0]
+
+ savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0], 'trans_params': trans_params})
+
+ return coeff_path, png_path, crop_info
diff --git a/sadtalker_video2pose/src/utils/preprocess_fromvideo.py b/sadtalker_video2pose/src/utils/preprocess_fromvideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e6c34055e557b6b39c5c8c1a5fd08842d17f57
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/preprocess_fromvideo.py
@@ -0,0 +1,195 @@
+import numpy as np
+import cv2, os, sys, torch
+from tqdm import tqdm
+from PIL import Image
+
+# 3dmm extraction
+import safetensors
+import safetensors.torch
+from src.face3d.util.preprocess import align_img
+from src.face3d.util.load_mats import load_lm3d
+from src.face3d.models import networks
+
+from scipy.io import loadmat, savemat
+from src.utils.croper import Preprocesser
+
+
+import warnings
+
+from src.utils.safetensor_helper import load_x_from_safetensor
+warnings.filterwarnings("ignore")
+
+
+def smooth_3dmm_params(params, window_size=5):
+ # 创建一个新的数组来存储平滑后的参数
+ smoothed_params = np.zeros_like(params)
+
+ # 对每个参数进行平滑处理
+ for i in range(params.shape[1]):
+
+ # 在参数周围创建一个滑动窗口
+ window = np.ones(int(window_size))/float(window_size)
+ smoothed_param = np.convolve(params[:, i], window, 'same')
+
+ # 将平滑后的参数存储在新数组中
+ smoothed_params[:, i] = smoothed_param
+
+ return smoothed_params
+
+
+
+def split_coeff(coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+ coeffs -- torch.tensor, size (B, 256)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80: 144]
+ tex_coeffs = coeffs[:, 144: 224]
+ angles = coeffs[:, 224: 227]
+ gammas = coeffs[:, 227: 254]
+ translations = coeffs[:, 254:]
+ return {
+ 'id': id_coeffs,
+ 'exp': exp_coeffs,
+ 'tex': tex_coeffs,
+ 'angle': angles,
+ 'gamma': gammas,
+ 'trans': translations
+ }
+
+
+class CropAndExtract():
+ def __init__(self, sadtalker_path, device):
+
+ self.propress = Preprocesser(device)
+ self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)
+
+ if sadtalker_path['use_safetensor']:
+ checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint'])
+ self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon'))
+ else:
+ checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device))
+ self.net_recon.load_state_dict(checkpoint['net_recon'])
+
+ self.net_recon.eval()
+ self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting'])
+ self.device = device
+
+ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256, if_smooth=False):
+
+ pic_name = os.path.splitext(os.path.split(input_path)[-1])[0]
+
+ landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt')
+ coeff_path = os.path.join(save_dir, pic_name+'.mat')
+ png_path = os.path.join(save_dir, pic_name+'.png')
+
+ #load input
+ if not os.path.isfile(input_path):
+ raise ValueError('input_path must be a valid path to video/image file')
+ elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_frames = [cv2.imread(input_path)]
+ fps = 25
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(input_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ full_frames.append(frame)
+ if source_image_flag:
+ break
+
+ x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames]
+
+ # print(x_full_frames)
+
+ #### crop images as the
+ if 'crop' in crop_or_resize.lower(): # default crop
+ x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+ elif 'full' in crop_or_resize.lower():
+ x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512)
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+ else: # resize mode
+ oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1]
+ crop_info = ((ox2 - ox1, oy2 - oy1), None, None)
+
+ frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames]
+ if len(frames_pil) == 0:
+ print('No face is detected in the input file')
+ return None, None
+
+ # save crop info
+ for frame in frames_pil:
+ cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
+
+ # 2. get the landmark according to the detected face.
+ if not os.path.isfile(landmarks_path):
+ lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path)
+ else:
+ print(' Using saved landmarks.')
+ lm = np.loadtxt(landmarks_path).astype(np.float32)
+ lm = lm.reshape([len(x_full_frames), -1, 2])
+
+ if not os.path.isfile(coeff_path):
+ # load 3dmm paramter generator from Deep3DFaceRecon_pytorch
+ video_coeffs, full_coeffs = [], []
+ for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'):
+ frame = frames_pil[idx]
+ W,H = frame.size
+ lm1 = lm[idx].reshape([-1, 2])
+
+ if np.mean(lm1) == -1:
+ lm1 = (self.lm3d_std[:, :2]+1)/2.
+ lm1 = np.concatenate(
+ [lm1[:, :1]*W, lm1[:, 1:2]*H], 1
+ )
+ else:
+ lm1[:, -1] = H - 1 - lm1[:, -1]
+
+ trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std)
+
+ trans_params_m = np.array([float(item) for item in np.hsplit(trans_params, len(trans_params))]).astype(np.float32)
+ im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0)
+
+ with torch.no_grad():
+ full_coeff = self.net_recon(im_t)
+ coeffs = split_coeff(full_coeff)
+
+ pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}
+
+ pred_coeff = np.concatenate([
+ pred_coeff['exp'],
+ pred_coeff['angle'],
+ pred_coeff['trans'],
+ # trans_params_m[2:][None],
+ ], 1)
+ video_coeffs.append(pred_coeff)
+ full_coeffs.append(full_coeff.cpu().numpy())
+
+ semantic_npy = np.array(video_coeffs)[:,0]
+
+ if if_smooth:
+ # pass
+ semantic_npy[:, -6:] = smooth_3dmm_params(semantic_npy[:, -6:], window_size=3)
+
+ savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0], 'trans_params': trans_params})
+
+ return coeff_path, png_path, crop_info
diff --git a/sadtalker_video2pose/src/utils/safetensor_helper.py b/sadtalker_video2pose/src/utils/safetensor_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..164ed9621eba24e0b3050ca663fcb60123517158
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/safetensor_helper.py
@@ -0,0 +1,8 @@
+
+
+def load_x_from_safetensor(checkpoint, key):
+ x_generator = {}
+ for k,v in checkpoint.items():
+ if key in k:
+ x_generator[k.replace(key+'.', '')] = v
+ return x_generator
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/utils/text2speech.py b/sadtalker_video2pose/src/utils/text2speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0fe21daf74fcd01767b17378b7076c9dd424248
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/text2speech.py
@@ -0,0 +1,20 @@
+import os
+import tempfile
+from TTS.api import TTS
+
+
+class TTSTalker():
+ def __init__(self) -> None:
+ model_name = TTS.list_models()[0]
+ self.tts = TTS(model_name)
+
+ def test(self, text, language='en'):
+
+ tempf = tempfile.NamedTemporaryFile(
+ delete = False,
+ suffix = ('.'+'wav'),
+ )
+
+ self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name)
+
+ return tempf.name
\ No newline at end of file
diff --git a/sadtalker_video2pose/src/utils/videoio.py b/sadtalker_video2pose/src/utils/videoio.py
new file mode 100644
index 0000000000000000000000000000000000000000..d604ae5b098006f3e59cf3c0133779ffd1cc9d5a
--- /dev/null
+++ b/sadtalker_video2pose/src/utils/videoio.py
@@ -0,0 +1,41 @@
+import shutil
+import uuid
+
+import os
+
+import cv2
+
+def load_video_to_cv2(input_path):
+ video_stream = cv2.VideoCapture(input_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ return full_frames
+
+def save_video_with_watermark(video, audio, save_path, watermark=False):
+ temp_file = str(uuid.uuid4())+'.mp4'
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec mpeg4 "%s"' % (video, audio, temp_file)
+ os.system(cmd)
+
+ if watermark is False:
+ shutil.move(temp_file, save_path)
+ else:
+ # watermark
+ try:
+ ##### check if stable-diffusion-webui
+ import webui
+ from modules import paths
+ watarmark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png"
+ except:
+ # get the root path of sadtalker.
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ watarmark_path = dir_path+"/../../docs/sadtalker_logo.png"
+
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path)
+ os.system(cmd)
+ os.remove(temp_file)
\ No newline at end of file
diff --git a/utils/flow_viz.py b/utils/flow_viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c0a357d91e785127b2b9513b2a6951f4ceaf1e
--- /dev/null
+++ b/utils/flow_viz.py
@@ -0,0 +1,291 @@
+# MIT License
+#
+# Copyright (c) 2018 Tom Runia
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to conditions.
+#
+# Author: Tom Runia
+# Date Created: 2018-08-03
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from PIL import Image
+import torch
+
+
+def make_colorwheel():
+ '''
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ '''
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = np.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
+ col = col + RY
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
+ colorwheel[col:col + YG, 1] = 255
+ col = col + YG
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
+ col = col + GC
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
+ colorwheel[col:col + CB, 2] = 255
+ col = col + CB
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
+ col = col + BM
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
+ colorwheel[col:col + MR, 0] = 255
+ return colorwheel
+
+
+def flow_compute_color(u, v, convert_to_bgr=False):
+ '''
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ :param u: np.ndarray, input horizontal flow
+ :param v: np.ndarray, input vertical flow
+ :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
+ :return:
+ '''
+
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+
+ colorwheel = make_colorwheel() # shape [55x3]
+ ncols = colorwheel.shape[0]
+
+ rad = np.sqrt(np.square(u) + np.square(v))
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a + 1) / 2 * (ncols - 1) + 1
+ k0 = np.floor(fk).astype(np.int32)
+ k1 = k0 + 1
+ k1[k1 == ncols] = 1
+ f = fk - k0
+
+ for i in range(colorwheel.shape[1]):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1 - f) * col0 + f * col1
+
+ idx = (rad <= 1)
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range?
+
+ # Note the 2-i => BGR instead of RGB
+ ch_idx = 2 - i if convert_to_bgr else i
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
+
+ return flow_image
+
+
+def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
+ '''
+ Expects a two dimensional flow image of shape [H,W,2]
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ :param flow_uv: np.ndarray of shape [H,W,2]
+ :param clip_flow: float, maximum clipping value for flow
+ :return:
+ '''
+
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+
+ if clip_flow is not None:
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
+
+ u = flow_uv[:, :, 0]
+ v = flow_uv[:, :, 1]
+
+ rad = np.sqrt(np.square(u) + np.square(v))
+ rad_max = np.max(rad)
+
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+
+ return flow_compute_color(u, v, convert_to_bgr)
+
+
+UNKNOWN_FLOW_THRESH = 1e7
+SMALLFLOW = 0.0
+LARGEFLOW = 1e8
+
+
+def make_color_wheel():
+ """
+ Generate color wheel according Middlebury color code
+ :return: Color wheel
+ """
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+
+ colorwheel = np.zeros([ncols, 3])
+
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
+ col += RY
+
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
+ colorwheel[col:col + YG, 1] = 255
+ col += YG
+
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
+ col += GC
+
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
+ colorwheel[col:col + CB, 2] = 255
+ col += CB
+
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
+ col += + BM
+
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+ colorwheel[col:col + MR, 0] = 255
+
+ return colorwheel
+
+
+def compute_color(u, v):
+ """
+ compute optical flow color map
+ :param u: optical flow horizontal map
+ :param v: optical flow vertical map
+ :return: optical flow in color code
+ """
+ [h, w] = u.shape
+ img = np.zeros([h, w, 3])
+ nanIdx = np.isnan(u) | np.isnan(v)
+ u[nanIdx] = 0
+ v[nanIdx] = 0
+
+ colorwheel = make_color_wheel()
+ ncols = np.size(colorwheel, 0)
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a + 1) / 2 * (ncols - 1) + 1
+
+ k0 = np.floor(fk).astype(int)
+
+ k1 = k0 + 1
+ k1[k1 == ncols + 1] = 1
+ f = fk - k0
+
+ for i in range(0, np.size(colorwheel, 1)):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0 - 1] / 255
+ col1 = tmp[k1 - 1] / 255
+ col = (1 - f) * col0 + f * col1
+
+ idx = rad <= 1
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ notidx = np.logical_not(idx)
+
+ col[notidx] *= 0.75
+ img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
+
+ return img
+
+
+# from https://github.com/gengshan-y/VCN
+def flow_to_image(flow):
+ """
+ Convert flow into middlebury color code image
+ :param flow: optical flow map
+ :return: optical flow image in middlebury color
+ """
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ # maxu = -999.
+ # maxv = -999.
+ # minu = 999.
+ # minv = 999.
+
+ idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
+ u[idxUnknow] = 0
+ v[idxUnknow] = 0
+
+ # maxu = max(maxu, np.max(u))
+ # minu = min(minu, np.min(u))
+
+ # maxv = max(maxv, np.max(v))
+ # minv = min(minv, np.min(v))
+
+ rad = torch.sqrt(u ** 2 + v ** 2)
+ maxrad = max(-1, torch.max(rad).cpu().numpy())
+
+ u = u / (maxrad + np.finfo(float).eps)
+ v = v / (maxrad + np.finfo(float).eps)
+
+ img = compute_color(u.cpu().numpy(), v.cpu().numpy())
+
+ idx = np.repeat(idxUnknow[:, :, np.newaxis].cpu().numpy(), 3, axis=2)
+ img[idx] = 0
+
+ return np.uint8(img)
+
+
+def save_vis_flow_tofile(flow, output_path):
+ vis_flow = flow_to_image(flow)
+ Image.fromarray(vis_flow).save(output_path)
+
+
+def flow_tensor_to_image(flow):
+ """Used for tensorboard visualization"""
+ flow = flow.permute(1, 2, 0) # [H, W, 2]
+ flow = flow.detach().cpu().numpy()
+ flow = flow_to_image(flow) # [H, W, 3]
+ flow = np.transpose(flow, (2, 0, 1)) # [3, H, W]
+
+ return flow
diff --git a/utils/scheduling_euler_discrete_karras_fix.py b/utils/scheduling_euler_discrete_karras_fix.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de68461afb061e2bc5efb3efeb8e54c81b09ca6
--- /dev/null
+++ b/utils/scheduling_euler_discrete_karras_fix.py
@@ -0,0 +1,556 @@
+# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+import torch.nn.functional as F
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
+class EulerDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Euler scheduler.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ interpolation_type(`str`, defaults to `"linear"`, *optional*):
+ The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be on of
+ `"linear"` or `"log_linear"`.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
+ the sigmas are determined according to a sequence of noise levels {σi}.
+ timestep_spacing (`str`, defaults to `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+ Diffusion.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ interpolation_type: str = "linear",
+ use_karras_sigmas: Optional[bool] = False,
+ sigma_min: Optional[float] = None,
+ sigma_max: Optional[float] = None,
+ timestep_spacing: str = "linspace",
+ timestep_type: str = "discrete", # can be "discrete" or "continuous"
+ steps_offset: int = 0,
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ if rescale_betas_zero_snr:
+ # Close to 0 without being 0 so first sigma is not inf
+ # FP16 smallest positive subnormal works well here
+ self.alphas_cumprod[-1] = 2**-24
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+
+ sigmas = sigmas[::-1].copy()
+
+ if self.use_karras_sigmas:
+ log_sigmas = np.log(sigmas)
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_train_timesteps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
+
+ # setable values
+ self.num_inference_steps = None
+
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
+ if timestep_type == "continuous" and prediction_type == "v_prediction":
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
+ else:
+ self.timesteps = torch.from_numpy(timesteps.astype(np.float32))
+
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+ self.is_scale_input_called = False
+ self.use_karras_sigmas = use_karras_sigmas
+
+ self._step_index = None
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
+ return max_sigma
+
+ return (max_sigma**2 + 1) ** 0.5
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increae 1 after each scheduler step.
+ """
+ return self._step_index
+
+ def scale_model_input(
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ sigma = self.sigmas[self.step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+
+ self.is_scale_input_called = True
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
+ ::-1
+ ].copy()
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ log_sigmas = np.log(sigmas)
+
+ if self.config.interpolation_type == "linear":
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ elif self.config.interpolation_type == "log_linear":
+ sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
+ else:
+ raise ValueError(
+ f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
+ " 'linear' or 'log_linear'"
+ )
+
+ if self.use_karras_sigmas:
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
+ if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
+ else:
+ self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
+
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+ self._step_index = None
+
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ def _init_step_index(self, timestep):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+
+ index_candidates = (self.timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ if len(index_candidates) > 1:
+ step_index = index_candidates[1]
+ else:
+ step_index = index_candidates[0]
+
+ self._step_index = step_index.item()
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ s_churn: float = 0.0,
+ s_tmin: float = 0.0,
+ s_tmax: float = float("inf"),
+ s_noise: float = 1.0,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ s_churn (`float`):
+ s_tmin (`float`):
+ s_tmax (`float`):
+ s_noise (`float`, defaults to 1.0):
+ Scaling factor for noise added to the sample.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
+ tuple.
+
+ Returns:
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if not self.is_scale_input_called:
+ logger.warning(
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+ "See `StableDiffusionPipeline` for a usage example."
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+
+ sigma = self.sigmas[self.step_index]
+
+ gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
+
+ noise = randn_tensor(
+ model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+ )
+
+ eps = noise * s_noise
+ sigma_hat = sigma * (gamma + 1)
+
+ if gamma > 0:
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
+ # backwards compatibility
+ if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "epsilon":
+ pred_original_sample = sample - sigma_hat * model_output
+ elif self.config.prediction_type == "v_prediction":
+ # denoised = model_output * c_out + input * c_skip
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma_hat
+
+ dt = self.sigmas[self.step_index + 1] - sigma_hat
+
+ prev_sample = sample + derivative * dt
+
+ # Cast sample back to model compatible dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b296648f28598cd5f8d0fd9b0613b9173e1b9aad
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,269 @@
+# -*- coding:utf-8 -*-
+import os
+import sys
+import shutil
+import logging
+import colorlog
+from tqdm import tqdm
+import time
+import yaml
+import random
+import importlib
+from PIL import Image
+from warnings import simplefilter
+import imageio
+import math
+import collections
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.optim import Adam
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
+from einops import rearrange, repeat
+import torch.distributed as dist
+from torchvision import datasets, transforms, utils
+
+logging.getLogger().setLevel(logging.WARNING)
+simplefilter(action='ignore', category=FutureWarning)
+
+def get_logger(filename=None):
+ """
+ examples:
+ logger = get_logger('try_logging.txt')
+
+ logger.debug("Do something.")
+ logger.info("Start print log.")
+ logger.warning("Something maybe fail.")
+ try:
+ raise ValueError()
+ except ValueError:
+ logger.error("Error", exc_info=True)
+
+ tips:
+ DO NOT logger.inf(some big tensors since color may not helpful.)
+ """
+ logger = logging.getLogger('utils')
+ level = logging.DEBUG
+ logger.setLevel(level=level)
+ # Use propagate to avoid multiple loggings.
+ logger.propagate = False
+ # Remove %(levelname)s since we have colorlog to represent levelname.
+ format_str = '[%(asctime)s <%(filename)s:%(lineno)d> %(funcName)s] %(message)s'
+
+ streamHandler = logging.StreamHandler()
+ streamHandler.setLevel(level)
+ coloredFormatter = colorlog.ColoredFormatter(
+ '%(log_color)s' + format_str,
+ datefmt='%Y-%m-%d %H:%M:%S',
+ reset=True,
+ log_colors={
+ 'DEBUG': 'cyan',
+ # 'INFO': 'white',
+ 'WARNING': 'yellow',
+ 'ERROR': 'red',
+ 'CRITICAL': 'reg,bg_white',
+ }
+ )
+
+ streamHandler.setFormatter(coloredFormatter)
+ logger.addHandler(streamHandler)
+
+ if filename:
+ fileHandler = logging.FileHandler(filename)
+ fileHandler.setLevel(level)
+ formatter = logging.Formatter(format_str)
+ fileHandler.setFormatter(formatter)
+ logger.addHandler(fileHandler)
+
+ # Fix multiple logging for torch.distributed
+ try:
+ class UniqueLogger:
+ def __init__(self, logger):
+ self.logger = logger
+ self.local_rank = torch.distributed.get_rank()
+
+ def info(self, msg, *args, **kwargs):
+ if self.local_rank == 0:
+ return self.logger.info(msg, *args, **kwargs)
+
+ def warning(self, msg, *args, **kwargs):
+ if self.local_rank == 0:
+ return self.logger.warning(msg, *args, **kwargs)
+
+ logger = UniqueLogger(logger)
+ # AssertionError for gpu with no distributed
+ # AttributeError for no gpu.
+ except Exception:
+ pass
+ return logger
+
+
+logger = get_logger()
+
+def split_filename(filename):
+ absname = os.path.abspath(filename)
+ dirname, basename = os.path.split(absname)
+ split_tmp = basename.rsplit('.', maxsplit=1)
+ if len(split_tmp) == 2:
+ rootname, extname = split_tmp
+ elif len(split_tmp) == 1:
+ rootname = split_tmp[0]
+ extname = None
+ else:
+ raise ValueError("programming error!")
+ return dirname, rootname, extname
+
+def data2file(data, filename, type=None, override=False, printable=False, **kwargs):
+ dirname, rootname, extname = split_filename(filename)
+ print_did_not_save_flag = True
+ if type:
+ extname = type
+ if not os.path.exists(dirname):
+ os.makedirs(dirname, exist_ok=True)
+
+ if not os.path.exists(filename) or override:
+ if extname in ['jpg', 'png', 'jpeg']:
+ utils.save_image(data, filename, **kwargs)
+ elif extname == 'gif':
+ imageio.mimsave(filename, data, format='GIF', duration=kwargs.get('duration'), loop=0)
+ elif extname == 'txt':
+ if kwargs is None:
+ kwargs = {}
+ max_step = kwargs.get('max_step')
+ if max_step is None:
+ max_step = np.Infinity
+
+ with open(filename, 'w', encoding='utf-8') as f:
+ for i, e in enumerate(data):
+ if i < max_step:
+ f.write(str(e) + '\n')
+ else:
+ break
+ else:
+ raise ValueError('Do not support this type')
+ if printable: logger.info('Saved data to %s' % os.path.abspath(filename))
+ else:
+ if print_did_not_save_flag: logger.info(
+ 'Did not save data to %s because file exists and override is False' % os.path.abspath(
+ filename))
+
+
+def file2data(filename, type=None, printable=True, **kwargs):
+ dirname, rootname, extname = split_filename(filename)
+ print_load_flag = True
+ if type:
+ extname = type
+
+ if extname in ['pth', 'ckpt']:
+ data = torch.load(filename, map_location=kwargs.get('map_location'))
+ elif extname == 'txt':
+ top = kwargs.get('top', None)
+ with open(filename, encoding='utf-8') as f:
+ if top:
+ data = [f.readline() for _ in range(top)]
+ else:
+ data = [e for e in f.read().split('\n') if e]
+ elif extname == 'yaml':
+ with open(filename, 'r') as f:
+ data = yaml.load(f)
+ else:
+ raise ValueError('type can only support h5, npy, json, txt')
+ if printable:
+ if print_load_flag:
+ logger.info('Loaded data from %s' % os.path.abspath(filename))
+ return data
+
+
+def ensure_dirname(dirname, override=False):
+ if os.path.exists(dirname) and override:
+ logger.info('Removing dirname: %s' % os.path.abspath(dirname))
+ try:
+ shutil.rmtree(dirname)
+ except OSError as e:
+ raise ValueError('Failed to delete %s because %s' % (dirname, e))
+
+ if not os.path.exists(dirname):
+ logger.info('Making dirname: %s' % os.path.abspath(dirname))
+ os.makedirs(dirname, exist_ok=True)
+
+
+def import_filename(filename):
+ spec = importlib.util.spec_from_file_location("mymodule", filename)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[spec.name] = module
+ spec.loader.exec_module(module)
+ return module
+
+
+def adaptively_load_state_dict(target, state_dict):
+ target_dict = target.state_dict()
+
+ try:
+ common_dict = {k: v for k, v in state_dict.items() if k in target_dict and v.size() == target_dict[k].size()}
+ except Exception as e:
+ logger.warning('load error %s', e)
+ common_dict = {k: v for k, v in state_dict.items() if k in target_dict}
+
+ if 'param_groups' in common_dict and common_dict['param_groups'][0]['params'] != \
+ target.state_dict()['param_groups'][0]['params']:
+ logger.warning('Detected mismatch params, auto adapte state_dict to current')
+ common_dict['param_groups'][0]['params'] = target.state_dict()['param_groups'][0]['params']
+ target_dict.update(common_dict)
+ target.load_state_dict(target_dict)
+
+ missing_keys = [k for k in target_dict.keys() if k not in common_dict]
+ unexpected_keys = [k for k in state_dict.keys() if k not in common_dict]
+
+ if len(unexpected_keys) != 0:
+ logger.warning(
+ f"Some weights of state_dict were not used in target: {unexpected_keys}"
+ )
+ if len(missing_keys) != 0:
+ logger.warning(
+ f"Some weights of state_dict are missing used in target {missing_keys}"
+ )
+ if len(unexpected_keys) == 0 and len(missing_keys) == 0:
+ logger.warning("Strictly Loaded state_dict.")
+
+def set_seed(seed=42):
+ random.seed(seed)
+ os.environ['PYHTONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+
+def image2pil(filename):
+ return Image.open(filename)
+
+
+def image2arr(filename):
+ pil = image2pil(filename)
+ return pil2arr(pil)
+
+
+# 格式转换
+def pil2arr(pil):
+ if isinstance(pil, list):
+ arr = np.array(
+ [np.array(e.convert('RGB').getdata(), dtype=np.uint8).reshape(e.size[1], e.size[0], 3) for e in pil])
+ else:
+ arr = np.array(pil)
+ return arr
+
+
+def arr2pil(arr):
+ if arr.ndim == 3:
+ return Image.fromarray(arr.astype('uint8'), 'RGB')
+ elif arr.ndim == 4:
+ return [Image.fromarray(e.astype('uint8'), 'RGB') for e in list(arr)]
+ else:
+ raise ValueError('arr must has ndim of 3 or 4, but got %s' % arr.ndim)
+
+def notebook_show(*images):
+ from IPython.display import Image
+ from IPython.display import display
+ display(*[Image(e) for e in images])
\ No newline at end of file