import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.config import cfg
from .embedder import get_embedder


class Nerf(nn.Module):
def __init__(self,
D=8,
W=256,
input_ch=3,
input_ch_views=3,
skips=[4],
use_viewdirs=False):
"""
"""
super(Nerf, self).__init__()
self.D = D
self.W = W
self.input_ch = input_ch
self.input_ch_views = input_ch_views
self.skips = skips
self.use_viewdirs = use_viewdirs
        # Trunk MLP; the layer that follows each index in `skips` takes the
        # embedded input concatenated back in (the NeRF skip connection).
        self.pts_linears = nn.ModuleList([nn.Linear(input_ch, W)] + [
            nn.Linear(W, W) if i not in self.skips else
            nn.Linear(W + input_ch, W) for i in range(D - 1)
        ])
### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
self.views_linears = nn.ModuleList(
[nn.Linear(input_ch_views + W, W // 2)])
### Implementation according to the paper
# self.views_linears = nn.ModuleList(
# [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
        # The density head is required unconditionally by forward(); the
        # feature and RGB heads are only built for the view-dependent branch.
        self.alpha_linear = nn.Linear(W, 1)
        if self.use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.rgb_linear = nn.Linear(W // 2, 3)

    def forward(self, x):
        """Runs the trunk on embedded points and returns the raw density (alpha)."""
        input_pts = x
        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)
        alpha = self.alpha_linear(h)
        return alpha

    def load_weights_from_keras(self, weights):
assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
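        # Assumed layout (a sketch, not verified against a specific checkpoint):
        # `weights` is a flat list alternating kernel/bias for, in order, the
        # trunk layers, feature_linear, views_linears[0], rgb_linear and
        # alpha_linear; kernels are transposed from Keras' (in, out) layout
        # to PyTorch's (out, in).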
# Load pts_linears
for i in range(self.D):
idx_pts_linears = 2 * i
self.pts_linears[i].weight.data = torch.from_numpy(
np.transpose(weights[idx_pts_linears]))
self.pts_linears[i].bias.data = torch.from_numpy(
np.transpose(weights[idx_pts_linears + 1]))
# Load feature_linear
idx_feature_linear = 2 * self.D
self.feature_linear.weight.data = torch.from_numpy(
np.transpose(weights[idx_feature_linear]))
self.feature_linear.bias.data = torch.from_numpy(
np.transpose(weights[idx_feature_linear + 1]))
# Load views_linears
idx_views_linears = 2 * self.D + 2
self.views_linears[0].weight.data = torch.from_numpy(
np.transpose(weights[idx_views_linears]))
self.views_linears[0].bias.data = torch.from_numpy(
np.transpose(weights[idx_views_linears + 1]))
# Load rgb_linear
        idx_rgb_linear = 2 * self.D + 4
        self.rgb_linear.weight.data = torch.from_numpy(
            np.transpose(weights[idx_rgb_linear]))
        self.rgb_linear.bias.data = torch.from_numpy(
            np.transpose(weights[idx_rgb_linear + 1]))
# Load alpha_linear
idx_alpha_linear = 2 * self.D + 6
self.alpha_linear.weight.data = torch.from_numpy(
np.transpose(weights[idx_alpha_linear]))
self.alpha_linear.bias.data = torch.from_numpy(
            np.transpose(weights[idx_alpha_linear + 1]))


class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
        # Positional-encoding embedders for 3D points and view directions.
        self.embed_fn, input_ch = get_embedder(cfg.xyz_res)
        self.embeddirs_fn, input_ch_views = get_embedder(cfg.view_res)
        skips = [4]
self.model = Nerf(D=cfg.netdepth,
W=cfg.netwidth,
input_ch=input_ch,
skips=skips,
input_ch_views=input_ch_views,
use_viewdirs=cfg.use_viewdirs)
# self.model_fine = Nerf(D=cfg.netdepth_fine,
# W=cfg.netwidth_fine,
# input_ch=input_ch,
# skips=skips,
# input_ch_views=input_ch_views,
# use_viewdirs=cfg.use_viewdirs)

    def batchify(self, fn, chunk):
        """Constructs a version of 'fn' that applies to smaller batches.

        Evaluating every sample at once can exhaust GPU memory, so the
        returned callable feeds `fn` at most `chunk` rows at a time and
        concatenates the partial outputs.
        """
        def ret(inputs):
            return torch.cat([fn(inputs[i:i + chunk])
                              for i in range(0, inputs.shape[0], chunk)], 0)
        return ret
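    # Usage sketch (hypothetical chunk size): batchify(self.model, 1024 * 64)
    # yields a callable that runs self.model on 65536 embedded samples at a
    # time, trading a little speed for a bounded peak memory footprint.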

    def forward(self, inputs, model=''):
        """Prepares inputs and applies network 'fn'."""
        if model == 'fine':
            # NOTE: requires self.model_fine, which is commented out in __init__.
            fn = self.model_fine
        else:
            fn = self.model
inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
embedded = self.embed_fn(inputs_flat)
outputs_flat = self.batchify(fn, cfg.netchunk)(embedded)
outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
return outputs
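

if __name__ == "__main__":
    # Minimal smoke test; a sketch only. It assumes cfg has been populated
    # (xyz_res, view_res, netdepth, netwidth, netchunk, use_viewdirs) and that
    # the module is run inside its package, e.g. `python -m <package>.<module>`.
    net = Network()
    pts = torch.rand(2, 64, 3)   # [n_rays, n_samples, xyz], hypothetical shape
    sigma = net(pts)             # raw density; expected shape [2, 64, 1]
    print(sigma.shape)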