Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from encoding import get_encoder | |
from .renderer import NeRFRenderer | |
class Conv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> activation, with an optional identity skip.

    Args:
        cin, cout, kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        residual: if True, the block input is added to the normalised output
            before the activation (requires matching input/output shapes).
        leakyReLU: use ``LeakyReLU(0.02)`` instead of ``ReLU``.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, leakyReLU=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        conv = nn.Conv2d(cin, cout, kernel_size, stride, padding)
        norm = nn.BatchNorm2d(cout)
        self.conv_block = nn.Sequential(conv, norm)
        self.act = nn.LeakyReLU(0.02) if leakyReLU else nn.ReLU()
        self.residual = residual

    def forward(self, x):
        """Apply conv + norm, optionally add the input, then activate."""
        features = self.conv_block(x)
        if self.residual:
            features += x  # identity skip connection
        return self.act(features)
class AudioEncoder(nn.Module):
    """Stack of ``Conv2d`` blocks that reduces a 2D audio feature map to a vector.

    The final two layers use padding 0, and ``forward`` squeezes spatial
    dims 2 and 3, so the map is expected to collapse to 1x1 (input size
    assumed to be chosen accordingly -- TODO confirm against the caller).
    """

    # (cin, cout, kernel_size, stride, padding, residual) for each layer, in order.
    _LAYERS = (
        (1, 32, 3, 1, 1, False),
        (32, 32, 3, 1, 1, True),
        (32, 32, 3, 1, 1, True),
        (32, 64, 3, (3, 1), 1, False),
        (64, 64, 3, 1, 1, True),
        (64, 64, 3, 1, 1, True),
        (64, 128, 3, 3, 1, False),
        (128, 128, 3, 1, 1, True),
        (128, 128, 3, 1, 1, True),
        (128, 256, 3, (3, 2), 1, False),
        (256, 256, 3, 1, 1, True),
        (256, 512, 3, 1, 0, False),
        (512, 512, 1, 1, 0, False),
    )

    def __init__(self):
        super().__init__()
        blocks = [
            Conv2d(cin, cout, kernel_size=k, stride=s, padding=p, residual=res)
            for cin, cout, k, s, p, res in self._LAYERS
        ]
        self.audio_encoder = nn.Sequential(*blocks)

    def forward(self, x):
        """Encode ``x`` and drop the (expected singleton) spatial dimensions."""
        encoded = self.audio_encoder(x)
        return encoded.squeeze(2).squeeze(2)
# Temporal attention that pools a window of per-frame audio features
class AudioAttNet(nn.Module):
    """Attention pooling of ``seq_len`` audio feature vectors into one vector.

    A small Conv1d stack scores each frame, a Linear+Softmax turns the scores
    into weights that sum to 1 over the window, and the output is the weighted
    sum of the input frames.

    Args:
        dim_aud: channel width of each per-frame audio feature.
        seq_len: number of frames in the window being pooled.
    """

    def __init__(self, dim_aud=64, seq_len=8):
        super(AudioAttNet, self).__init__()
        self.seq_len = seq_len
        self.dim_aud = dim_aud
        self.attentionConvNet = nn.Sequential(  # b x subspace_dim x seq_len
            nn.Conv1d(self.dim_aud, 16, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True)
        )
        # Maps the seq_len per-frame scores to weights normalised over frames.
        self.attentionNet = nn.Sequential(
            nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Pool ``x`` of shape [B, seq_len, dim_aud] into [B, dim_aud].

        Any batch size B is accepted; each batch element is pooled
        independently (B == 1 matches the original single-sample behaviour).
        """
        batch = x.shape[0]
        y = x.permute(0, 2, 1)  # [B, dim_aud, seq_len]
        y = self.attentionConvNet(y)  # [B, 1, seq_len]
        y = self.attentionNet(y.view(batch, self.seq_len)).view(batch, self.seq_len, 1)
        return torch.sum(y * x, dim=1)  # [B, dim_aud]
# Audio feature extractor | |
class AudioNet(nn.Module):
    """Conv1d + MLP encoder for a window of per-frame audio features.

    Args:
        dim_in: feature channels per frame (e.g. 29 for DeepSpeech).
        dim_aud: output embedding width.
        win_size: number of frames, centred in the input, that are encoded.
    """

    def __init__(self, dim_in=29, dim_aud=64, win_size=16):
        super(AudioNet, self).__init__()
        self.win_size = win_size
        self.dim_aud = dim_aud
        self.encoder_conv = nn.Sequential(  # n x 29 x 16
            nn.Conv1d(dim_in, 32, kernel_size=3, stride=2, padding=1, bias=True),  # n x 32 x 8
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(32, 32, kernel_size=3, stride=2, padding=1, bias=True),  # n x 32 x 4
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1, bias=True),  # n x 64 x 2
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),  # n x 64 x 1
            nn.LeakyReLU(0.02, True),
        )
        self.encoder_fc1 = nn.Sequential(
            nn.Linear(64, 64),
            nn.LeakyReLU(0.02, True),
            nn.Linear(64, dim_aud),
        )

    def forward(self, x):
        """Encode the centred ``win_size`` frames of ``x``.

        Args:
            x: tensor of shape [N, dim_in, T]. For T == 16 this selects the
                same frames as the previous hard-coded ``8 +/- win_size/2``.

        Returns:
            Tensor of shape [N, dim_aud].

        Raises:
            ValueError: if ``win_size`` is larger than the number of frames T
                (the old code silently sliced with a negative start index).
        """
        num_frames = x.shape[-1]
        half_w = self.win_size // 2
        center = num_frames // 2
        if half_w > center:
            raise ValueError(
                f"win_size={self.win_size} does not fit in an input of {num_frames} frames"
            )
        x = x[:, :, center - half_w:center + half_w]
        x = self.encoder_conv(x).squeeze(-1)
        x = self.encoder_fc1(x)
        return x
class MLP(nn.Module):
    """Bias-free fully connected stack with ReLU between layers.

    Layer i maps ``dim_in`` (first layer) or ``dim_hidden`` to ``dim_out``
    (last layer) or ``dim_hidden``. No activation follows the last layer.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        last = num_layers - 1
        self.net = nn.ModuleList(
            nn.Linear(
                dim_in if idx == 0 else dim_hidden,
                dim_out if idx == last else dim_hidden,
                bias=False,
            )
            for idx in range(num_layers)
        )

    def forward(self, x):
        """Run ``x`` through every layer, applying ReLU after all but the last."""
        last = self.num_layers - 1
        for idx, layer in enumerate(self.net):
            x = layer(x)
            if idx < last:
                x = F.relu(x, inplace=True)
        return x
# Audio feature extractor | |
class AudioNet_ave(nn.Module):
    """Three-layer MLP projecting 512-d audio features down to ``dim_aud``.

    ``dim_in`` is accepted for signature compatibility with ``AudioNet`` but
    is not used: the input width is fixed at 512. ``win_size`` is stored only.
    """

    def __init__(self, dim_in=29, dim_aud=64, win_size=16):
        super().__init__()
        self.win_size = win_size
        self.dim_aud = dim_aud
        in_w, mid_w, low_w = 512, 256, 128
        self.encoder_fc1 = nn.Sequential(
            nn.Linear(in_w, mid_w),
            nn.LeakyReLU(0.02, True),
            nn.Linear(mid_w, low_w),
            nn.LeakyReLU(0.02, True),
            nn.Linear(low_w, dim_aud),
        )

    def forward(self, x):
        """Project ``x``, swap its first two axes and drop a leading singleton.

        Requires a 3-D input. For example, [N, 1, 512] becomes [N, dim_aud].
        Presumably callers pass that layout -- TODO confirm.
        """
        projected = self.encoder_fc1(x)
        swapped = projected.permute(1, 0, 2)
        return swapped.squeeze(0)
class NeRFNetwork(NeRFRenderer):
    """Audio-conditioned tri-plane NeRF for the head, with an optional torso branch.

    Head branch:
    - A 3D point is encoded by three 2D hash grids on its xy, yz and xz
      projections.
    - The audio embedding is re-weighted per point by ``aud_ch_att_net``.
    - The result, plus an optional eye feature, goes through ``sigma_net``
      (density + geometry features) and then ``color_net``.

    Torso branch (when ``self.torso``): a 2D deformation network followed by
    a 2D colour/alpha network.

    Attributes such as ``bound``, ``exp_eye``, ``individual_dim``,
    ``individual_dim_torso``, ``torso`` and ``train_camera`` are not set here.
    They are presumably provided by ``NeRFRenderer.__init__`` -- confirm there.
    """

    def __init__(self,
                 opt,
                 audio_dim=32,
                 # torso net (hard coded for now)
                 ):
        super().__init__(opt)

        # audio embedding
        self.emb = self.opt.emb

        # Per-frame audio feature width, chosen by the ASR model name.
        # Unrecognised names (including 'ave') fall back to 32.
        if 'esperanto' in self.opt.asr_model:
            self.audio_in_dim = 44
        elif 'deepspeech' in self.opt.asr_model:
            self.audio_in_dim = 29
        elif 'hubert' in self.opt.asr_model:
            self.audio_in_dim = 1024
        else:
            self.audio_in_dim = 32

        if self.emb:
            self.embedding = nn.Embedding(self.audio_in_dim, self.audio_in_dim)

        # audio network
        self.audio_dim = audio_dim
        if self.opt.asr_model == 'ave':
            self.audio_net = AudioNet_ave(self.audio_in_dim, self.audio_dim)
        else:
            self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim)

        self.att = self.opt.att
        if self.att > 0:
            self.audio_att_net = AudioAttNet(self.audio_dim)

        # DYNAMIC PART: one 2D hash grid per axis-aligned plane.
        self.num_levels = 12
        self.level_dim = 1
        self.encoder_xy, self.in_dim_xy = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
        self.encoder_yz, self.in_dim_yz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
        self.encoder_xz, self.in_dim_xz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)

        self.in_dim = self.in_dim_xy + self.in_dim_yz + self.in_dim_xz

        ## sigma network
        self.num_layers = 3
        self.hidden_dim = 64
        self.geo_feat_dim = 64
        self.eye_att_net = MLP(self.in_dim, 1, 16, 2)
        self.eye_dim = 1 if self.exp_eye else 0
        # Output channel 0 is the (log-)density; the other geo_feat_dim
        # channels are geometry features consumed by the color net.
        self.sigma_net = MLP(self.in_dim + self.audio_dim + self.eye_dim, 1 + self.geo_feat_dim, self.hidden_dim, self.num_layers)

        ## color network
        self.num_layers_color = 2
        self.hidden_dim_color = 64
        self.encoder_dir, self.in_dim_dir = get_encoder('spherical_harmonics')
        self.color_net = MLP(self.in_dim_dir + self.geo_feat_dim + self.individual_dim, 3, self.hidden_dim_color, self.num_layers_color)

        self.unc_net = MLP(self.in_dim, 1, 32, 2)

        # Per-point channel attention over the audio embedding.
        self.aud_ch_att_net = MLP(self.in_dim, self.audio_dim, 64, 2)

        self.testing = False

        if self.torso:
            # torso deform network
            self.register_parameter('anchor_points',
                                    nn.Parameter(torch.tensor([[0.01, 0.01, 0.1, 1], [-0.1, -0.1, 0.1, 1], [0.1, -0.1, 0.1, 1]])))
            self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('frequency', input_dim=2, multires=8)
            # self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=1, base_resolution=16, log2_hashmap_size=16, desired_resolution=512)
            self.anchor_encoder, self.anchor_in_dim = get_encoder('frequency', input_dim=6, multires=3)
            self.torso_deform_net = MLP(self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 2, 32, 3)

            # torso color network
            self.torso_encoder, self.torso_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048)
            self.torso_net = MLP(self.torso_in_dim + self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 4, 32, 3)

    def forward_torso(self, x, poses, c=None):
        """Render the torso at 2D positions ``x``.

        Args:
            x: [N, 2] in [-1, 1].
            poses: [1, 4, 4] head poses.
            c: [1, ind_dim] individual code, or None.

        Returns:
            (alpha [N, 1], color [N, 3], dx [N, 2]) where dx is the predicted
            2D deformation applied to ``x``.
        """
        # test: shrink x
        x = x * self.opt.torso_shrink

        # deformation-based
        # Transform the anchors by the inverse head pose, then divide xy by the
        # homogeneous w and by z to get 2D anchor coordinates.
        wrapped_anchor = self.anchor_points[None, ...] @ poses.permute(0, 2, 1).inverse()
        wrapped_anchor = (wrapped_anchor[:, :, :2] / wrapped_anchor[:, :, 3, None] / wrapped_anchor[:, :, 2, None]).view(1, -1)
        # print(wrapped_anchor)
        # enc_pose = self.pose_encoder(poses)
        enc_anchor = self.anchor_encoder(wrapped_anchor)
        enc_x = self.torso_deform_encoder(x)

        if c is not None:
            h = torch.cat([enc_x, enc_anchor.repeat(x.shape[0], 1), c.repeat(x.shape[0], 1)], dim=-1)
        else:
            h = torch.cat([enc_x, enc_anchor.repeat(x.shape[0], 1)], dim=-1)

        dx = self.torso_deform_net(h)

        x = (x + dx).clamp(-1, 1)

        x = self.torso_encoder(x, bound=1)

        # h = torch.cat([x, h, enc_a.repeat(x.shape[0], 1)], dim=-1)
        h = torch.cat([x, h], dim=-1)

        h = self.torso_net(h)

        # Sigmoid stretched to (-0.001, 1.001) so that exactly 0 and 1 are reachable.
        alpha = torch.sigmoid(h[..., :1]) * (1 + 2 * 0.001) - 0.001
        color = torch.sigmoid(h[..., 1:]) * (1 + 2 * 0.001) - 0.001

        return alpha, color, dx

    @staticmethod
    def split_xyz(x):
        """Split [N, 3] points into their (x, y), (y, z) and (x, z) coordinate pairs.

        Declared as a static method: it is called as ``self.split_xyz(xyz)``,
        which raised TypeError when ``split_xyz`` was a plain function taking
        only ``x``.
        """
        xy, yz, xz = x[:, :-1], x[:, 1:], torch.cat([x[:, :1], x[:, -1:]], dim=-1)
        return xy, yz, xz

    def encode_x(self, xyz, bound):
        """Encode [N, 3] points in [-bound, bound] with the three plane hash grids.

        Returns the concatenation of the xy, yz and xz features ([N, in_dim]).
        """
        xy, yz, xz = self.split_xyz(xyz)
        feat_xy = self.encoder_xy(xy, bound=bound)
        feat_yz = self.encoder_yz(yz, bound=bound)
        feat_xz = self.encoder_xz(xz, bound=bound)

        return torch.cat([feat_xy, feat_yz, feat_xz], dim=-1)

    def encode_audio(self, a):
        """Embed raw audio features; returns None when ``a`` is None.

        ``a``: [1, 29, 16] or [8, 29, 16] audio features from deepspeech. If
        ``self.emb`` is set, ``a`` should instead be [1, 16] or [8, 16] indices.
        """
        # fix audio traininig
        if a is None:
            return None

        if self.emb:
            a = self.embedding(a).transpose(-1, -2).contiguous()  # [1/8, 29, 16]

        enc_a = self.audio_net(a)  # [1/8, 64]

        if self.att > 0:
            enc_a = self.audio_att_net(enc_a.unsqueeze(0))  # [1, 64]

        return enc_a

    def predict_uncertainty(self, unc_inp):
        """Predict one (pre-activation) uncertainty value per point: [N, 1].

        When testing or when ``opt.unc_loss`` is off, zeros of that same
        [N, 1] shape are returned. Previously the zeros had the full
        [N, in_dim] shape of ``unc_inp``, which disagreed with ``unc_net``.
        """
        if self.testing or not self.opt.unc_loss:
            unc = torch.zeros_like(unc_inp[..., :1])
        else:
            # detach: the uncertainty head does not back-propagate into the encoders.
            unc = self.unc_net(unc_inp.detach())

        return unc

    def forward(self, x, d, enc_a, c, e=None):
        """Query density, colour and auxiliary outputs for sample points.

        Args:
            x: [N, 3] positions in [-bound, bound].
            d: [N, 3] view directions, normalised in [-1, 1].
            enc_a: [1, aud_dim] audio embedding.
            c: [1, ind_dim] individual code, or None.
            e: [1, 1] eye feature, or None.

        Returns:
            (sigma, color, ambient_aud, ambient_eye, uncertainty[..., None]).
            ambient_eye is None when ``e`` is None.
        """
        enc_x = self.encode_x(x, bound=self.bound)

        sigma_result = self.density(x, enc_a, e, enc_x)
        sigma = sigma_result['sigma']
        geo_feat = sigma_result['geo_feat']
        aud_ch_att = sigma_result['ambient_aud']
        eye_att = sigma_result['ambient_eye']

        # color
        enc_d = self.encoder_dir(d)

        if c is not None:
            h = torch.cat([enc_d, geo_feat, c.repeat(x.shape[0], 1)], dim=-1)
        else:
            h = torch.cat([enc_d, geo_feat], dim=-1)

        h_color = self.color_net(h)
        # Sigmoid stretched to (-0.001, 1.001) so that exactly 0 and 1 are reachable.
        color = torch.sigmoid(h_color) * (1 + 2 * 0.001) - 0.001

        # softplus == log(1 + exp(u)), but computed stably: the explicit form
        # overflows to inf for large u (around u > 11 in fp16, u > 88 in fp32).
        uncertainty = F.softplus(self.predict_uncertainty(enc_x))

        return sigma, color, aud_ch_att, eye_att, uncertainty[..., None]

    def density(self, x, enc_a, e=None, enc_x=None):
        """Compute density and geometry features for [N, 3] points in [-bound, bound].

        Returns a dict with:
        - 'sigma': exp of sigma_net channel 0.
        - 'geo_feat': the remaining sigma_net channels.
        - 'ambient_aud': L2 norm of the per-point audio attention, [N, 1].
        - 'ambient_eye': per-point eye attention, or None when ``e`` is None.
        """
        if enc_x is None:
            enc_x = self.encode_x(x, bound=self.bound)

        enc_a = enc_a.repeat(enc_x.shape[0], 1)
        aud_ch_att = self.aud_ch_att_net(enc_x)
        enc_w = enc_a * aud_ch_att

        # Defined on both branches; previously unbound when e is None, which
        # made the return statement raise UnboundLocalError.
        eye_att = None
        if e is not None:
            # e = self.encoder_eye(e)
            eye_att = torch.sigmoid(self.eye_att_net(enc_x))
            e = e * eye_att
            # e = e.repeat(enc_x.shape[0], 1)
            h = torch.cat([enc_x, enc_w, e], dim=-1)
        else:
            h = torch.cat([enc_x, enc_w], dim=-1)

        h = self.sigma_net(h)

        sigma = torch.exp(h[..., 0])
        geo_feat = h[..., 1:]

        return {
            'sigma': sigma,
            'geo_feat': geo_feat,
            'ambient_aud': aud_ch_att.norm(dim=-1, keepdim=True),
            'ambient_eye': eye_att,
        }

    # optimizer utils
    def get_params(self, lr, lr_net, wd=0):
        """Build optimizer parameter groups.

        In torso mode only the torso modules (and torso codes) are returned.
        Otherwise the head modules are returned: grid encoders use ``lr``,
        MLPs use ``lr_net`` and weight decay ``wd``.
        """
        # ONLY train torso
        if self.torso:
            params = [
                {'params': self.torso_encoder.parameters(), 'lr': lr},
                {'params': self.torso_deform_encoder.parameters(), 'lr': lr, 'weight_decay': wd},
                {'params': self.torso_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
                {'params': self.torso_deform_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
                {'params': self.anchor_points, 'lr': lr_net, 'weight_decay': wd}
            ]

            if self.individual_dim_torso > 0:
                params.append({'params': self.individual_codes_torso, 'lr': lr_net, 'weight_decay': wd})

            return params

        params = [
            {'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd},

            {'params': self.encoder_xy.parameters(), 'lr': lr},
            {'params': self.encoder_yz.parameters(), 'lr': lr},
            {'params': self.encoder_xz.parameters(), 'lr': lr},
            # {'params': self.encoder_xyz.parameters(), 'lr': lr},

            {'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
            {'params': self.color_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
        ]
        if self.att > 0:
            params.append({'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': 0.0001})
        if self.emb:
            params.append({'params': self.embedding.parameters(), 'lr': lr})
        if self.individual_dim > 0:
            params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd})
        if self.train_camera:
            params.append({'params': self.camera_dT, 'lr': 1e-5, 'weight_decay': 0})
            params.append({'params': self.camera_dR, 'lr': 1e-5, 'weight_decay': 0})

        params.append({'params': self.aud_ch_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
        params.append({'params': self.unc_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
        params.append({'params': self.eye_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})

        return params