StableVideo / stablevideo / implicit_neural_networks.py
wchai's picture
init
4c9c42b
raw
history blame
2.84 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# code taken from https://github.com/ykasten/layered-neural-atlases
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` contribute; frozen
    parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
def positionalEncoding_vec(in_tensor, b):
    """Map coordinates to sin/cos Fourier features.

    Args:
        in_tensor: 2-D tensor ``(batch, d)`` of coordinates.
        b: 1-D tensor of ``F`` frequencies.

    Returns:
        Tensor of shape ``(batch, 2 * d * F)``. Features are grouped by
        frequency: for frequency ``b[k]`` the group is
        ``sin(in_tensor * b[k])`` (``d`` values) followed by
        ``cos(in_tensor * b[k])`` (``d`` values).
    """
    # angles[n, i, k] = in_tensor[n, i] * b[k]  -> (batch, d, F)
    angles = torch.einsum("ij, k -> ijk", in_tensor, b)
    # Stack sines above cosines along the coordinate axis -> (batch, 2d, F)
    sin_cos = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    n_rows = sin_cos.size(0)
    # Move the frequency axis ahead of the sin/cos axis, then flatten each row.
    per_frequency = sin_cos.permute(0, 2, 1).contiguous()
    return per_frequency.view(n_rows, -1)
class IMLP(nn.Module):
    """Coordinate MLP with optional Fourier positional encoding and skip inputs.

    The network is ``num_layers`` linear layers, the last one being the output
    layer. ReLU is applied before every layer except the first; before each
    layer whose index is in ``skip_layers`` the (encoded) input is
    concatenated to the hidden activations. The output is optionally passed
    through ``tanh`` and then through a softmax over the last dimension.

    Args:
        input_dim: Number of input coordinates per sample.
        output_dim: Number of output features.
        hidden_dim: Width of every hidden layer.
        use_positional: If True, inputs are encoded with
            ``positionalEncoding_vec`` using frequencies ``2**j * pi`` for
            ``j in range(positional_dim)``; inputs must then be 2-D
            ``(batch, input_dim)``.
        positional_dim: Number of encoding frequencies.
        skip_layers: Indices of layers that also receive the encoded input.
            Index 0 has no effect (the first layer already sees the input).
        num_layers: Total number of linear layers, including the output layer.
        verbose: If True, print the trainable parameter count after building.
        use_tanh: Apply ``tanh`` to the output.
        apply_softmax: Apply softmax over the last dimension of the output.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim=256,
        use_positional=True,
        positional_dim=10,
        skip_layers=(4, 6),
        num_layers=8,  # includes the output layer
        verbose=True,
        use_tanh=True,
        apply_softmax=False,
    ):
        super().__init__()
        self.verbose = verbose
        self.use_tanh = use_tanh
        self.apply_softmax = apply_softmax
        # Private copy: instances never share (or mutate) a default/caller list.
        self.skip_layers = list(skip_layers)
        self.num_layers = num_layers
        self.positional_dim = positional_dim
        self.use_positional = use_positional

        if apply_softmax:
            # Explicit dim avoids the implicit-dim deprecation warning; for the
            # 2-D (batch, features) outputs of this model it equals dim=1.
            self.softmax = nn.Softmax(dim=-1)

        if use_positional:
            encoding_dimensions = 2 * input_dim * positional_dim
            # Non-persistent buffer: follows module.to(...)/.double() but is not
            # written to state_dict, so previously saved checkpoints still load.
            self.register_buffer(
                "b",
                torch.tensor([(2 ** j) * np.pi for j in range(positional_dim)]),
                persistent=False,
            )
        else:
            encoding_dimensions = input_dim

        self.hidden = nn.ModuleList()
        for i in range(num_layers):
            if i == 0:
                in_features = encoding_dimensions
            elif i in self.skip_layers:
                in_features = hidden_dim + encoding_dimensions
            else:
                in_features = hidden_dim
            # The last layer projects to the output size.
            out_features = output_dim if i == num_layers - 1 else hidden_dim
            self.hidden.append(nn.Linear(in_features, out_features, bias=True))

        if self.verbose:
            print(f"Model has {count_parameters(self)} params")

    def forward(self, x):
        """Evaluate the network.

        Args:
            x: Input coordinates. Must be 2-D ``(batch, input_dim)`` when
                ``use_positional`` is True; otherwise any tensor whose last
                dimension is ``input_dim``.

        Returns:
            Tensor with ``output_dim`` features along the last dimension.
        """
        if self.use_positional:
            if self.b.device != x.device:
                # Safety net if the module was not moved with .to(); the buffer
                # is normally already on the input's device.
                self.b = self.b.to(x.device)
            x = positionalEncoding_vec(x, self.b)
        # Detached copy of the (encoded) input re-injected at skip layers; because
        # it is detached, no gradient flows back to x through skip connections.
        skip_input = x.detach().clone()
        for i, layer in enumerate(self.hidden):
            if i > 0:
                x = F.relu(x)
                # Only layers i > 0 were sized for the extra skip features.
                if i in self.skip_layers:
                    x = torch.cat((x, skip_input), dim=-1)
            x = layer(x)
        if self.use_tanh:
            x = torch.tanh(x)
        if self.apply_softmax:
            x = self.softmax(x)
        return x