# File: show / SHOW / cvt_exp_dim_tool.py
# Page metadata captured together with the file (not part of the program):
#   camenduru's picture | thanks to show ❤ | 3bbb319 | raw | history blame | No virus | 7.14 kB
from dataclasses import dataclass
import tyro
import trimesh
import mmcv
from easydict import EasyDict
from loguru import logger
from tqdm import tqdm
import torch.nn as nn
import numpy as np
from pathlib import Path
import torch
import smplx
# Keyword flags that load_smplx_model always forwards to smplx.create.
# Every ``create_<name>`` option listed here is switched on.
DEFAULT_SMPLX_CONFIG = {
    f'create_{param_name}': True
    for param_name in (
        'global_orient',
        'body_pose',
        'betas',
        'left_hand_pose',
        'right_hand_pose',
        'expression',
        'jaw_pose',
        'leye_pose',
        'reye_pose',
        'transl',
    )
}
def load_smplx_model(device='cuda', **kwargs):
    """Build a body model with ``smplx.create`` and move it to ``device``.

    Args:
        device: Target device handed to the model's ``.to`` call.
        **kwargs: Extra keyword arguments forwarded to ``smplx.create`` in
            addition to ``DEFAULT_SMPLX_CONFIG``. A key that also appears in
            ``DEFAULT_SMPLX_CONFIG`` raises ``TypeError`` (duplicate keyword).

    Returns:
        The object returned by ``smplx.create(...).to(device=device)``.
    """
    model = smplx.create(**DEFAULT_SMPLX_CONFIG, **kwargs)
    return model.to(device=device)
@dataclass
class Config:
    """Command-line options of the expression-dimension conversion script.

    Every option carries a type annotation: ``@dataclass`` only turns
    *annotated* class attributes into fields, so without annotations the
    generated ``__init__`` takes no arguments and the values cannot be
    overridden (neither by ``Config(lr=...)`` nor by a CLI parser that
    inspects the dataclass fields).
    """

    # Learning rate handed to the L-BFGS optimizer.
    lr: float = 0.1
    # Name of the torch device the models and tensors are placed on.
    device: str = 'cuda'
    # Number of expression coefficients of the model that is being fitted.
    target_exp_dim: int = 75
    # Pickle file with the source parameters (read through mmcv.load).
    pkl_path: str = 'C:/Users/lithiumice/Downloads/mica_all.pkl'
    # Model file passed to smplx.create as ``model_path``.
    model_path: str = '../models/smplx/SMPLX_MALE_shape2019_exp2020.npz'
# ---------------------------------------------------------------------------
# Script set-up: parse options, load the source parameters and prepare the
# shared model configuration plus the output buffer.
# ---------------------------------------------------------------------------
args = tyro.cli(Config)
device = torch.device(args.device)
dtype = torch.float32

# Keyword arguments shared by every model built through load_smplx_model.
comm_cfg = EasyDict(
    dtype=dtype,
    model_path=args.model_path,
    use_face_contour=True,
    flat_hand_mean=False,
    use_hands=True,
    use_face=True,
)

# NOTE(review): a .pkl file is unpickled here, which can execute arbitrary
# code - only point pkl_path at files from a trusted source.
all_var = mmcv.load(args.pkl_path)
if isinstance(all_var, list):
    # A list payload is reduced to its first entry.
    all_var = all_var[0]
all_var = EasyDict(all_var)

comm_cfg.use_pca = True
batch_size = comm_cfg.batch_size = all_var.batch_size
# The number of hand PCA components is taken from the stored left-hand pose.
num_pca_comps = all_var.left_hand_pose.shape[1]
comm_cfg.num_pca_comps = num_pca_comps

# (start_frame, end_frame) segments to process; currently one segment that
# spans the whole sequence.
st_et_list = [(0, all_var.batch_size)]

# One output row per frame: 165 pose columns followed by the re-fitted
# expression coefficients. The row count comes from ``batch_size`` - the same
# value that defines the segment above - because ``all_var`` is a dict-like
# EasyDict and has no array ``shape`` of its own.
save_data = np.zeros(
    (all_var.batch_size, 165 + args.target_exp_dim), dtype=np.float32)
# ---------------------------------------------------------------------------
# Re-fit the expression coefficients segment by segment.
#
# For every segment the source model (original expression size) produces
# reference vertices; the target model (args.target_exp_dim coefficients) is
# then optimised with L-BFGS so that its selected vertices match them.
# ---------------------------------------------------------------------------
import pickle

from SHOW.utils.metric import MeterBuffer

# Loop-invariant assets are loaded once, before the segment loop.
# Index array used to pick a subset of the model vertices; the file name
# suggests it maps SMPL-X vertices onto the FLAME head mesh.
smlx2flame_idx = torch.from_numpy(
    np.load('../data/SMPL-X__FLAME_vertex_ids.npy')
).to(dtype=torch.long, device=device)

# NOTE(review): pickle.load can execute arbitrary code; only use a trusted
# generic_model.pkl.
with open('../data/generic_model.pkl', 'rb') as f:
    flame_template = pickle.load(f, encoding='latin1')
# Only the ``f`` entry is needed; it is kept as ``faces`` for mesh export.
faces = flame_template['f']


def rel_change(prev_val, curr_val):
    """Return the decrease from ``prev_val`` to ``curr_val``, normalised by
    the larger of ``|prev_val|``, ``|curr_val|`` and 1."""
    return (prev_val - curr_val) / max([np.abs(prev_val), np.abs(curr_val), 1])


for opt_idx, (start_frame, end_frame) in enumerate(st_et_list):
    print(f'opt_idx:{opt_idx}, start_frame:{start_frame}, end_frame:{end_frame})')
    comm_cfg.batch_size = batch_size = end_frame - start_frame

    # Slice into a separate name: rebinding ``all_var`` here would make any
    # later segment slice already-sliced data.
    seg_var = {
        k: v[start_frame:end_frame]
        for k, v in all_var.items()
        if isinstance(v, np.ndarray)
    }
    # Shape and translation are replaced by zeros for the fit.
    seg_var['betas'] = np.zeros((batch_size, 100), dtype=np.float32)
    seg_var['transl'] = np.zeros((batch_size, 3), dtype=np.float32)
    seg_var = EasyDict(seg_var)

    # load_smplx_model defaults to 'cuda'; pass the selected device so the
    # model is never placed on CUDA when another device was requested.
    body_model = load_smplx_model(
        device=device,
        **comm_cfg,
        num_betas=seg_var.betas.shape[1],
        num_expression_coeffs=seg_var.expression.shape[1],
    )
    target_body_model = load_smplx_model(
        device=device,
        **comm_cfg,
        num_betas=seg_var.betas.shape[1],
        num_expression_coeffs=args.target_exp_dim,
    )

    param_sources = dict(
        body_pose=seg_var.body_pose_axis,
        betas=seg_var.betas,
        global_orient=seg_var.global_orient,
        transl=seg_var.transl,
        left_hand_pose=seg_var.left_hand_pose,
        right_hand_pose=seg_var.right_hand_pose,
        jaw_pose=seg_var.jaw_pose,
        leye_pose=seg_var.leye_pose,
        reye_pose=seg_var.reye_pose,
        expression=seg_var.expression,
    )
    # Cast to the model dtype as well, so arrays stored in another precision
    # do not clash with the float32 model.
    smplx_params = {
        key: torch.from_numpy(val).to(device=device, dtype=dtype)
        for key, val in param_sources.items()
    }

    # Reference vertices from the source model. They must not carry gradients
    # (important!), hence the no_grad context.
    with torch.no_grad():
        model_output = body_model(return_verts=True, **smplx_params)
        target_vertices_ = model_output.vertices[:, smlx2flame_idx]

    # The expression is the only optimised quantity; it starts from zeros and
    # has the target dimensionality.
    smplx_params['expression'] = nn.Parameter(
        torch.zeros(batch_size, args.target_exp_dim, dtype=dtype, device=device))
    opt_params = [smplx_params['expression']]

    optimizer = torch.optim.LBFGS(
        opt_params,
        lr=args.lr,
        max_iter=20,
        history_size=10,
        tolerance_grad=1e-05,
    )
    loss_fn = torch.nn.MSELoss()
    max_iter = 50
    loss_history = []
    pbar = tqdm(range(max_iter), desc='Optimizing expression parameters')

    # Outer-iteration counter shown in the progress bar; initialised here so
    # that ``closure`` never depends on a name that is bound only later.
    step = 0

    def closure():
        """Evaluate the scaled vertex MSE and back-propagate it (L-BFGS closure)."""
        optimizer.zero_grad()
        fitted_output = target_body_model(return_verts=True, **smplx_params)
        vertices_ = fitted_output.vertices[:, smlx2flame_idx]
        # The 1e6 factor only rescales the loss value.
        loss = 1000 * 1000 * loss_fn(vertices_, target_vertices_)
        pbar.set_description(f'Iteration {step}, loss {loss.item()}')
        loss_history.append(loss.item())
        loss.backward()
        return loss

    prev_loss = None
    meter = MeterBuffer(window_size=6)
    for step in pbar:
        loss = optimizer.step(closure)

        # Stop when the windowed average relative loss change is negligible.
        if step > 1 and prev_loss is not None:
            loss_rel_change = rel_change(prev_loss, loss.item())
            meter.update({'rel': loss_rel_change})
            if meter['rel'].avg <= 1e-09:
                logger.warning('rel exit')
                break

        # Stop when the largest gradient *magnitude* is tiny. The absolute
        # value is taken before the max so that large negative entries are
        # not overlooked.
        if all(var.grad.abs().max().item() < 1e-06
               for var in opt_params if var.grad is not None):
            logger.warning('small grad')
            break

        prev_loss = loss.item()
    pbar.close()

    # Row layout written to save_data (column ranges assume 3-dim joints,
    # a 63-dim body pose and 45-dim hand poses - TODO confirm for PCA hands):
    #   jaw_pose        [:, 0:3]
    #   leye_pose       [:, 3:6]
    #   reye_pose       [:, 6:9]
    #   global_orient   [:, 9:12]
    #   body_pose_axis  [:, 12:75]
    #   left_hand_pose  [:, 75:120]
    #   right_hand_pose [:, 120:165]
    #   expression      [:, 165:]
    save_keys = (
        'jaw_pose',
        'leye_pose',
        'reye_pose',
        'global_orient',
        'body_pose',
        'left_hand_pose',
        'right_hand_pose',
        'expression',
    )
    save_data[start_frame:end_frame] = np.concatenate(
        [smplx_params[key].detach().cpu().numpy() for key in save_keys],
        axis=1)
# ---------------------------------------------------------------------------
# Report the result size, dump one fitted and one reference mesh for visual
# comparison, then store the converted parameters.
# ---------------------------------------------------------------------------
print(f'save_data.shape:{save_data.shape}')

# Evaluate the target model once more with the final parameters and keep the
# same vertex subset that was used during fitting.
model_output = target_body_model(return_verts=True, **smplx_params)
vertices_ = model_output.vertices[:, smlx2flame_idx]

# Only the first frame of each vertex batch is written out.
mesh_exports = (
    (vertices_, f'target_mesh_{args.target_exp_dim}.obj'),
    (target_vertices_, f'ref_mesh_{args.target_exp_dim}.obj'),
)
for batch_vertices, obj_name in mesh_exports:
    first_frame = batch_vertices[0].detach().cpu().numpy()
    trimesh.Trimesh(first_frame, faces).export(obj_name)

np.save(f'rich_expdim{args.target_exp_dim}.npy', save_data)