import torch
import random

from utils.commons.base_task import BaseTask
from utils.commons.dataset_utils import data_loader
from utils.commons.hparams import hparams
from utils.commons.tensor_utils import tensors_to_scalars
from utils.nn.schedulers import CosineSchedule, NoneSchedule
from utils.nn.model_utils import print_arch, num_params
from utils.commons.ckpt_utils import load_ckpt

from modules.syncnet.models import LandmarkHubertSyncNet
from tasks.os_avatar.dataset_utils.syncnet_dataset import SyncNet_Dataset
from data_util.face3d_helper import Face3DHelper

class ScheduleForSyncNet(NoneSchedule):
    """Stepwise exponential LR decay with a fixed floor."""
    def __init__(self, optimizer, lr):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr
        self.step(0)

    def step(self, num_updates):
        # Decay the base LR every `lr_decay_interval` updates, clamped to 5e-5.
        lr = self.constant_lr * hparams['lr_decay_rate'] ** (num_updates // hparams['lr_decay_interval'])
        self.lr = max(lr, 5e-5)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr
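
# Worked example of the schedule (hypothetical values, for illustration only):
# with lr=1e-4, lr_decay_rate=0.5 and lr_decay_interval=40_000, the LR is 1e-4
# for updates 0..39_999, 5e-5 for updates 40_000..79_999, and then stays at the
# 5e-5 floor (since 0.25e-4 < 5e-5) for all later updates.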

class SyncNetTask(BaseTask):
    def __init__(self, hparams_=None):
        # Allow overriding the module-level hparams, e.g. when this task is
        # instantiated by a parent task with its own config.
        global hparams
        if hparams_ is not None:
            hparams = hparams_
        self.hparams = hparams
        super().__init__()
        self.dataset_cls = SyncNet_Dataset

    def on_train_start(self):
        for n, m in self.model.named_children():
            num_params(m, model_name=n)

    def build_model(self):
        if self.hparams is not None:
            hparams = self.hparams
        self.face3d_helper = Face3DHelper(use_gpu=False, keypoint_mode='lm68')
        keypoint_mode = hparams.get('syncnet_keypoint_mode', 'lip')
        if keypoint_mode == 'lip':
            lm_dim = 20 * 3  # lip part (landmarks 48..67) of idexp_lm3d
        elif keypoint_mode == 'lm68':
            lm_dim = 68 * 3  # full 68-landmark set
        elif keypoint_mode == 'centered_lip':
            lm_dim = 20 * 3  # lip landmarks, re-centered around the mean lip shape
        elif keypoint_mode == 'centered_lip2d':
            lm_dim = 20 * 2  # re-centered lip landmarks, xy only
        elif keypoint_mode == 'lm468':
            lm_dim = 468 * 3  # full MediaPipe 468-landmark set
            self.face3d_helper = Face3DHelper(use_gpu=False, keypoint_mode='mediapipe')
        else:
            raise ValueError(f"Unknown syncnet_keypoint_mode: {keypoint_mode}")

        if hparams['audio_type'] == 'hubert':
            audio_dim = 1024  # hubert features
        elif hparams['audio_type'] == 'mfcc':
            audio_dim = 13  # mfcc features
        elif hparams['audio_type'] == 'mel':
            audio_dim = 80  # mel-spectrogram bins
        else:
            raise ValueError(f"Unknown audio_type: {hparams['audio_type']}")

        self.model = LandmarkHubertSyncNet(
            lm_dim, audio_dim,
            num_layers_per_block=hparams['syncnet_num_layers_per_block'],
            base_hid_size=hparams['syncnet_base_hid_size'],
            out_dim=hparams['syncnet_out_hid_size'])
        print_arch(self.model)
        if hparams.get('init_from_ckpt', '') != '':
            ckpt_dir = hparams.get('init_from_ckpt', '')
            load_ckpt(self.model, ckpt_dir, model_name='model', strict=False)
        return self.model
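
    # Config keys consumed above (names taken from this file; values illustrative):
    #   syncnet_keypoint_mode: lip | lm68 | centered_lip | centered_lip2d | lm468
    #   audio_type:            hubert | mfcc | mel
    #   syncnet_num_layers_per_block, syncnet_base_hid_size, syncnet_out_hid_size
    #   init_from_ckpt:        optional checkpoint dir/path to warm-start from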

    def build_optimizer(self, model):
        if self.hparams is not None:
            hparams = self.hparams
        self.optimizer = optimizer = torch.optim.Adam(
            model.parameters(),
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']))
        return optimizer

    # Alternative: a cosine schedule instead of stepwise decay.
    # def build_scheduler(self, optimizer):
    #     return CosineSchedule(optimizer, hparams['lr'], warmup_updates=0, total_updates=40_0000)
    def build_scheduler(self, optimizer):
        return ScheduleForSyncNet(optimizer, hparams['lr'])

    def train_dataloader(self):
        train_dataset = self.dataset_cls(prefix='train')
        self.train_dl = train_dataset.get_dataloader()
        return self.train_dl

    def val_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

    def test_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

    ##########################
    # training and validation
    ##########################
    def run_model(self, sample, infer=False, batch_size=1024):
        """
        Sample synced/unsynced (mouth-landmark, audio) clip pairs from a batch and
        compute the SyncNet loss.
        :param sample: a batch of data
        :param infer: bool, run in inference mode (positive pairs only)
        :param batch_size: number of clip pairs to draw from the batch
        :return:
            if not infer:
                return losses, model_out
            if infer:
                return model_out
        """
        if self.hparams is not None:
            hparams = self.hparams
        if sample is None or len(sample) == 0:
            return None
        model_out = {}
        if 'idexp_lm3d' not in sample:
            # Reconstruct identity-excluded expression landmarks from 3DMM coefficients.
            with torch.no_grad():
                b, t, _ = sample['exp'].shape
                idexp_lm3d = self.face3d_helper.reconstruct_idexp_lm3d(sample['id'], sample['exp']).reshape([b, t, -1, 3])
        else:
            b, t, *_ = sample['idexp_lm3d'].shape
            idexp_lm3d = sample['idexp_lm3d']
        keypoint_mode = hparams.get('syncnet_keypoint_mode', 'lip')
        if keypoint_mode == 'lip':
            mouth_lm = idexp_lm3d[:, :, 48:68, :].reshape([b, t, 20 * 3])  # [b, t, 60]
        elif keypoint_mode == 'centered_lip':
            mouth_lm = idexp_lm3d[:, :, 48:68].reshape([b, t, 20, 3])  # [b, t, 20, 3]
            # idexp_lm3d stores 10x-scaled offsets from the mean shape: undo the scale,
            # add the mean lip shape, subtract its centroid to center the lips, then
            # restore the 10x scale.
            mean_mouth_lm = self.face3d_helper.key_mean_shape[48:68]
            mouth_lm = mouth_lm / 10 + mean_mouth_lm.reshape([1, 1, 20, 3]) - mean_mouth_lm.reshape([1, 1, 20, 3]).mean(dim=-2)
            mouth_lm = mouth_lm.reshape([b, t, 20 * 3]) * 10
        elif keypoint_mode == 'centered_lip2d':
            mouth_lm = idexp_lm3d[:, :, 48:68].reshape([b, t, 20, 3])  # [b, t, 20, 3]
            mean_mouth_lm = self.face3d_helper.key_mean_shape[48:68]
            mouth_lm = mouth_lm / 10 + mean_mouth_lm.reshape([1, 1, 20, 3]) - mean_mouth_lm.reshape([1, 1, 20, 3]).mean(dim=-2)
            mouth_lm = mouth_lm[..., :2]  # drop z, keep xy only
            mouth_lm = mouth_lm.reshape([b, t, 20 * 2]) * 10
        elif keypoint_mode == 'lm68':
            mouth_lm = idexp_lm3d.reshape([b, t, 68 * 3])
        elif keypoint_mode == 'lm468':
            mouth_lm = idexp_lm3d.reshape([b, t, 468 * 3])
        else:
            raise ValueError(f"Unknown syncnet_keypoint_mode: {keypoint_mode}")
        if hparams['audio_type'] == 'hubert':
            mel = sample['hubert']  # [b, 2t, 1024]
        elif hparams['audio_type'] == 'mfcc':
            mel = sample['mfcc'] / 100  # [b, 2t, 13]
        elif hparams['audio_type'] == 'mel':
            mel = sample['mel']  # [b, 2t, 80]
        y_mask = sample['y_mask']  # [B, T]
        y_len = y_mask.sum(dim=1).min().item()  # shortest valid length in the batch
        len_mouth_slice = 5  # 5 video frames = 0.2 s, an appropriate window for sync checking
        len_mel_slice = len_mouth_slice * 2  # audio features run at 2x the video frame rate
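        # Slicing sketch (illustrative): at 25 fps video with 50 Hz audio features,
        # a mouth window [t_start, t_start+5) is synced with the audio window
        # [2*t_start, 2*t_start+10). E.g. t_start=3 -> video frames 3..7, audio
        # frames 6..15; a negative with offset=+4 instead takes audio frames 14..23.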
        if infer:
            phase_ratio_dict = {
                'pos': 1.0,
            }
        else:
            phase_ratio_dict = {
                'pos': 0.4,
                'neg_same_people_small_offset_ratio': 0.3,
                'neg_same_people_large_offset_ratio': 0.2,
                'neg_diff_people_random_offset_ratio': 0.1,
            }
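
        # Negative-sampling scheme: 40% in-sync positives; 30% / 20% negatives that
        # keep the speaker but shift the audio by a small (2-5 frame) or large
        # (5-10 frame) offset; 10% negatives that pair a mouth clip with audio from
        # a different clip in the batch at a random offset.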
        mouth_lst, mel_lst, label_lst = [], [], []
        for phase_key, phase_ratio in phase_ratio_dict.items():
            num_samples = int(batch_size * phase_ratio)
            if phase_key == 'pos':
                phase_mel_lst = []
                phase_mouth_lst = []
                # Each iteration contributes one slice per batch item.
                num_iters = max(1, num_samples // len(mouth_lm))
                for i in range(num_iters):
                    # Aligned slices: mouth at [t, t+5), audio at [2t, 2t+10).
                    t_start = random.randint(0, y_len - len_mouth_slice - 1)
                    phase_mouth = mouth_lm[:, t_start: t_start + len_mouth_slice]
                    assert phase_mouth.shape[1] == len_mouth_slice
                    phase_mel = mel[:, t_start * 2: t_start * 2 + len_mel_slice]
                    phase_mouth_lst.append(phase_mouth)
                    phase_mel_lst.append(phase_mel)
                phase_mouth = torch.cat(phase_mouth_lst)
                phase_mel = torch.cat(phase_mel_lst)
                mouth_lst.append(phase_mouth)
                mel_lst.append(phase_mel)
                label_lst.append(torch.ones([len(phase_mel)]))  # 1 denotes positive samples
            elif phase_key in ['neg_same_people_small_offset_ratio', 'neg_same_people_large_offset_ratio']:
                phase_mel_lst = []
                phase_mouth_lst = []
                num_iters = max(1, num_samples // len(mouth_lm))
                for i in range(num_iters):
                    # Draw a nonzero temporal offset between the mouth and audio windows.
                    if phase_key == 'neg_same_people_small_offset_ratio':
                        offset = random.choice([random.randint(-5, -2), random.randint(2, 5)])
                    elif phase_key == 'neg_same_people_large_offset_ratio':
                        offset = random.choice([random.randint(-10, -5), random.randint(5, 10)])
                    else:
                        raise ValueError()
                    # Keep both the mouth window and the offset audio window in-bounds.
                    if offset < 0:
                        t_start = random.randint(-offset, y_len - len_mouth_slice - 1)
                    else:
                        t_start = random.randint(0, y_len - len_mouth_slice - 1 - offset)
                    phase_mouth = mouth_lm[:, t_start: t_start + len_mouth_slice]
                    assert phase_mouth.shape[1] == len_mouth_slice
                    phase_mel = mel[:, (t_start + offset) * 2: (t_start + offset) * 2 + len_mel_slice]
                    phase_mouth_lst.append(phase_mouth)
                    phase_mel_lst.append(phase_mel)
                phase_mouth = torch.cat(phase_mouth_lst)
                phase_mel = torch.cat(phase_mel_lst)
                mouth_lst.append(phase_mouth)
                mel_lst.append(phase_mel)
                label_lst.append(torch.zeros([len(phase_mel)]))  # 0 denotes negative samples
            elif phase_key == 'neg_diff_people_random_offset_ratio':
                phase_mel_lst = []
                phase_mouth_lst = []
                num_iters = max(1, num_samples // len(mouth_lm))
                for i in range(num_iters):
                    offset = random.randint(-10, 10)
                    if offset < 0:
                        t_start = random.randint(-offset, y_len - len_mouth_slice - 1)
                    else:
                        t_start = random.randint(0, y_len - len_mouth_slice - 1 - offset)
                    phase_mouth = mouth_lm[:, t_start: t_start + len_mouth_slice]
                    assert phase_mouth.shape[1] == len_mouth_slice
                    # Shuffle the batch dim so each mouth clip is paired with audio
                    # from a (most likely) different clip.
                    sample_idx = list(range(len(mouth_lm)))
                    random.shuffle(sample_idx)
                    phase_mel = mel[sample_idx, (t_start + offset) * 2: (t_start + offset) * 2 + len_mel_slice]
                    phase_mouth_lst.append(phase_mouth)
                    phase_mel_lst.append(phase_mel)
                phase_mouth = torch.cat(phase_mouth_lst)
                phase_mel = torch.cat(phase_mel_lst)
                mouth_lst.append(phase_mouth)
                mel_lst.append(phase_mel)
                label_lst.append(torch.zeros([len(phase_mel)]))  # 0 denotes negative samples
        mel_clips = torch.cat(mel_lst)
        mouth_clips = torch.cat(mouth_lst)
        labels = torch.cat(label_lst).float().to(mel_clips.device)
        audio_embedding, mouth_embedding = self.model(mel_clips, mouth_clips)
        sync_loss, cosine_sim = self.model.cal_sync_loss(audio_embedding, mouth_embedding, labels, reduction='mean')
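        # `cal_sync_loss` is defined on LandmarkHubertSyncNet. A minimal sketch of
        # the usual Wav2Lip-style formulation (an assumption for illustration, not
        # necessarily the exact implementation in modules/syncnet/models):
        #   sim = F.cosine_similarity(audio_embedding, mouth_embedding)  # [B], in [-1, 1]
        #   loss = F.binary_cross_entropy((sim + 1) / 2, labels)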
        if not infer:
            losses_out = {}
            model_out = {}
            losses_out['sync_loss'] = sync_loss
            losses_out['batch_size'] = len(mel_clips)
            model_out['cosine_sim'] = cosine_sim
            return losses_out, model_out
        else:
            model_out['sync_loss'] = sync_loss
            model_out['batch_size'] = len(mel_clips)
            return model_out

    def _training_step(self, sample, batch_idx, optimizer_idx):
        ret = self.run_model(sample, infer=False, batch_size=hparams['syncnet_num_clip_pairs'])
        if ret is None:
            return None
        loss_output, model_out = ret
        loss_weights = {}
        total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items()
                          if isinstance(v, torch.Tensor) and v.requires_grad])
        return total_loss, loss_output

    def validation_start(self):
        pass

    def validation_step(self, sample, batch_idx):
        outputs = {}
        outputs['losses'], model_out = self.run_model(sample, infer=False, batch_size=8000)
        outputs = tensors_to_scalars(outputs)
        return outputs

    def validation_end(self, outputs):
        return super().validation_end(outputs)

    #####################
    # Testing
    #####################
    def test_start(self):
        pass

    def test_step(self, sample, batch_idx):
        pass

    def test_end(self, outputs):
        pass
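

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: hparams have already been loaded via
# utils.commons.hparams.set_hparams with a syncnet config, and the binarized
# dataset exists). Training normally goes through the project's trainer entry
# point; this only exercises model/optimizer/scheduler construction.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    task = SyncNetTask()
    model = task.build_model()
    optimizer = task.build_optimizer(model)
    scheduler = task.build_scheduler(optimizer)
    print('initial lr:', optimizer.param_groups[0]['lr'])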