import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.config import cfg
from .embedder import get_embedder


class Nerf(nn.Module):
    def __init__(self,
                 D=8,
                 W=256,
                 input_ch=3,
                 input_ch_views=3,
                 output_ch=4,
                 skips=[4],
                 use_viewdirs=False):
        """Standard NeRF MLP.

        D and W are the depth and width of the trunk; input_ch and
        input_ch_views are the channel counts of the embedded positions
        and view directions; skips lists the trunk layers whose input is
        re-concatenated with the embedded positions.
        """
        super(Nerf, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs

        self.pts_linears = nn.ModuleList([nn.Linear(input_ch, W)] + [
            nn.Linear(W, W) if i not in
            self.skips else nn.Linear(W + input_ch, W) for i in range(D - 1)
        ])
        ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
        self.views_linears = nn.ModuleList(
            [nn.Linear(input_ch_views + W, W // 2)])

        ### Implementation according to the paper
        # self.views_linears = nn.ModuleList(
        #     [nn.Linear(input_ch_views + W, W // 2)] +
        #     [nn.Linear(W // 2, W // 2) for i in range(D // 2)])

        if self.use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W // 2, 3)
        else:
            # Single head predicting rgb + alpha when view directions
            # are not used.
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, x):
        input_pts, input_views = torch.split(
            x, [self.input_ch, self.input_ch_views], dim=-1)

        # Trunk over embedded positions, with skip connections back to
        # the input at the layers listed in self.skips.
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        if self.use_viewdirs:
            # Density depends on position only; color is additionally
            # conditioned on the embedded view direction.
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)

        return outputs

    def load_weights_from_keras(self, weights):
        """Loads weights exported from the official Keras/TF model.

        `weights` is expected to alternate kernel/bias pairs in the
        order: pts_linears, feature_linear, views_linears, rgb_linear,
        alpha_linear.
        """
        assert self.use_viewdirs, "Not implemented if use_viewdirs=False"

        # Load pts_linears
        for i in range(self.D):
            idx_pts_linears = 2 * i
            self.pts_linears[i].weight.data = torch.from_numpy(
                np.transpose(weights[idx_pts_linears]))
            self.pts_linears[i].bias.data = torch.from_numpy(
                np.transpose(weights[idx_pts_linears + 1]))

        # Load feature_linear
        idx_feature_linear = 2 * self.D
        self.feature_linear.weight.data = torch.from_numpy(
            np.transpose(weights[idx_feature_linear]))
        self.feature_linear.bias.data = torch.from_numpy(
            np.transpose(weights[idx_feature_linear + 1]))

        # Load views_linears
        idx_views_linears = 2 * self.D + 2
        self.views_linears[0].weight.data = torch.from_numpy(
            np.transpose(weights[idx_views_linears]))
        self.views_linears[0].bias.data = torch.from_numpy(
            np.transpose(weights[idx_views_linears + 1]))

        # Load rgb_linear
        idx_rgb_linear = 2 * self.D + 4
        self.rgb_linear.weight.data = torch.from_numpy(
            np.transpose(weights[idx_rgb_linear]))
        self.rgb_linear.bias.data = torch.from_numpy(
            np.transpose(weights[idx_rgb_linear + 1]))

        # Load alpha_linear
        idx_alpha_linear = 2 * self.D + 6
        self.alpha_linear.weight.data = torch.from_numpy(
            np.transpose(weights[idx_alpha_linear]))
        self.alpha_linear.bias.data = torch.from_numpy(
            np.transpose(weights[idx_alpha_linear + 1]))
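

# Usage sketch for the raw MLP (illustrative, not part of the pipeline;
# the channel sizes assume the standard positional encoding with 10
# frequencies for xyz, 3 + 3*2*10 = 63, and 4 for views, 3 + 3*2*4 = 27):
#
#   model = Nerf(D=8, W=256, input_ch=63, input_ch_views=27,
#                skips=[4], use_viewdirs=True)
#   out = model(torch.rand(1024, 63 + 27))   # -> [1024, 4] = rgb + alpha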


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.embed_fn, input_ch = get_embedder(cfg.xyz_res)
        self.embeddirs_fn, input_ch_views = get_embedder(cfg.view_res)
        skips = [4]

        # Coarse and fine networks for NeRF's hierarchical sampling.
        self.model = Nerf(D=cfg.netdepth,
                          W=cfg.netwidth,
                          input_ch=input_ch,
                          skips=skips,
                          input_ch_views=input_ch_views,
                          use_viewdirs=cfg.use_viewdirs)
        self.model_fine = Nerf(D=cfg.netdepth_fine,
                               W=cfg.netwidth_fine,
                               input_ch=input_ch,
                               skips=skips,
                               input_ch_views=input_ch_views,
                               use_viewdirs=cfg.use_viewdirs)

    def batchify(self, fn, chunk):
        """Constructs a version of 'fn' that applies to smaller batches."""
        def ret(inputs):
            return torch.cat([
                fn(inputs[i:i + chunk])
                for i in range(0, inputs.shape[0], chunk)
            ], 0)
        return ret
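
    # Note (illustrative): batchify(fn, cfg.netchunk) evaluates fn on at
    # most `netchunk` embedded points at a time and concatenates the
    # results, e.g. a [4096, C] input with chunk=1024 runs as four forward
    # passes, bounding peak memory when a ray batch is large.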

    def forward(self, inputs, viewdirs, model=''):
        """Prepares inputs and applies network 'fn'."""
        if model == 'fine':
            fn = self.model_fine
        else:
            fn = self.model

        # Flatten points, embed positions and (broadcast) view directions.
        inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
        embedded = self.embed_fn(inputs_flat)

        input_dirs = viewdirs[:, None].expand(inputs.shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = self.embeddirs_fn(input_dirs_flat)
        embedded = torch.cat([embedded, embedded_dirs], -1)

        outputs_flat = self.batchify(fn, cfg.netchunk)(embedded)
        outputs = torch.reshape(
            outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
        return outputs
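

if __name__ == "__main__":
    # Smoke test (a sketch, independent of cfg): channel sizes assume the
    # standard positional encoding, 3 + 3*2*10 = 63 channels for xyz and
    # 3 + 3*2*4 = 27 for view directions.
    for use_viewdirs in (True, False):
        net = Nerf(D=8, W=256, input_ch=63, input_ch_views=27,
                   skips=[4], use_viewdirs=use_viewdirs)
        out = net(torch.rand(16, 63 + 27))
        assert out.shape == (16, 4), out.shape  # rgb (3) + alpha (1)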