from collections import OrderedDict
from copy import copy
from functools import partial
import os
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from einops import einsum, rearrange, repeat
from torch import Tensor, nn

from config import AutoConfig
from point_pe import point_position_encoding


class PositionalEncoding(nn.Module):
    """Sinusoidal position encoding for continuous coordinates (no learnable parameters)."""

    def __init__(self, max_steps=1000, features=32, periods=10000):
        super().__init__()
        self.pe = partial(
            point_position_encoding,
            max_steps=max_steps,
            features=features,
            periods=periods,
        )

    @torch.no_grad()
    def forward(self, x):
        # Positional features are treated as constants, hence no_grad.
        return self.pe(x)


def coords_mlp(
    in_dim,
    out_dim,
    hidden_dim=256,
    depth=3,
    act_fn=nn.GELU,
    max_steps=100,
    features=32,
    periods=100,
    fi_act_fn=nn.Identity,
):
    """Build an MLP over position-encoded coordinates; fi_act_fn is the final activation."""
    assert depth >= 2
    modules = []
    modules.append(
        PositionalEncoding(max_steps=max_steps, features=features, periods=periods)
    )
    # The position encoding expands each input dimension into `features` sin/cos pairs.
    in_dim = in_dim * features * 2
    for i in range(depth - 1):
        modules.append(nn.Linear(in_dim if i == 0 else hidden_dim, hidden_dim))
        modules.append(act_fn())
    modules.append(nn.Linear(hidden_dim, out_dim))
    modules.append(fi_act_fn())
    return nn.Sequential(*modules)


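# Illustrative sketch (not part of the original module): assuming `point_position_encoding`
# maps (N, in_dim) coordinates to (N, in_dim * features * 2) sin/cos features, a call like
#
#     mlp = coords_mlp(in_dim=3, out_dim=16, hidden_dim=256, depth=3)
#     out = mlp(torch.rand(100, 3))  # expected shape: (100, 16)
#
# would run the encoding, then `depth - 1` hidden Linear+GELU layers, then a final
# Linear followed by `fi_act_fn` (Identity by default).
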
class CachedCoordsMLP(nn.Module):
    """Coordinate MLP that caches its full output while it is frozen or in eval mode."""

    def __init__(self, in_dim, out_dim, hidden_dim=256, depth=3, act_fn=nn.Identity):
        super().__init__()
        self.mlp = coords_mlp(
            in_dim, out_dim, hidden_dim=hidden_dim, depth=depth, fi_act_fn=act_fn
        )
        self.cache = None

    def forward(self, coords, voxel_indices):
        if self.training and self.is_req_grad:
            # Weights may change every step, so invalidate the cache and recompute
            # only the requested voxels with gradients enabled.
            self.cache = None
            return self.mlp(coords[voxel_indices])
        else:
            with torch.no_grad():
                if self.cache is None:
                    self.cache = self.mlp(coords)
                return self.cache[voxel_indices]

    @property
    def is_req_grad(self):
        return next(self.parameters()).requires_grad


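# Illustrative sketch (assumed shapes, not part of the original module): with 3-D voxel
# coordinates `coords` of shape (n_voxels, 3),
#
#     mlp = CachedCoordsMLP(in_dim=3, out_dim=8)
#     mlp.eval()
#     y1 = mlp(coords, torch.arange(10))  # first call fills `mlp.cache` for all voxels
#     y2 = mlp(coords, torch.arange(10))  # served from the cache, shape (10, 8)
#
# In training mode with trainable weights, the cache is bypassed and only the indexed
# coordinates are evaluated.
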
def build_coords_mlp(
    cfg: AutoConfig, in_dim, out_dim, act_fn=partial(nn.Softmax, dim=-1)
):
    return CachedCoordsMLP(
        in_dim,
        out_dim,
        hidden_dim=cfg.MODEL.COORDS_MLP.WIDTH,
        depth=cfg.MODEL.COORDS_MLP.DEPTH,
        act_fn=act_fn,
    )


class CoordsFreeWeights(nn.Module):
    """Per-voxel weights learned directly, i.e. without conditioning on coordinates."""

    def __init__(self, out_dim, n):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(n, out_dim))

    def forward(self, coords, voxel_indices=..., *args, **kwargs):
        # `coords` is accepted only to match the CachedCoordsMLP interface.
        w = self.weight[voxel_indices]
        return w

    @property
    def is_req_grad(self):
        return next(self.parameters()).requires_grad


def build_coords_free_weights(cfg: AutoConfig, out_dim, n):
    return CoordsFreeWeights(out_dim, n)


class VoxelNonShareLinearWeight(nn.Module):
    """One independent linear readout (weight and bias) per voxel."""

    def __init__(self, d_model, n_voxels, **kwargs):
        super().__init__()
        # nn.Linear is used purely for its default initialization.
        dummy = nn.Linear(d_model, n_voxels)
        self.weight = nn.Parameter(dummy.weight)
        self.bias = nn.Parameter(dummy.bias)

    def forward(self, coords, voxel_indices=..., *args, **kwargs):
        w = self.weight[voxel_indices]
        b = self.bias[voxel_indices]
        return w, b


class CoordsMLPLinearWeight(nn.Module):
    """Linear readout whose per-voxel weights are predicted from voxel coordinates."""

    def __init__(self, d_model, n_voxels, in_dim=3, hidden_dim=256, depth=3, **kwargs):
        super().__init__()
        self.w_mlp = CachedCoordsMLP(
            in_dim, d_model, hidden_dim=hidden_dim, depth=depth
        )
        self.b = nn.Parameter(torch.zeros(n_voxels))

    def forward(self, coords, voxel_indices=..., *args, **kwargs):
        w = self.w_mlp(coords, voxel_indices)
        b = self.b[voxel_indices]
        return w, b


def build_voxelouts_weight(cfg: AutoConfig, n_voxels, d_model):
    kwargs = {
        "d_model": d_model,
        "n_voxels": n_voxels,
        "in_dim": cfg.POSITION_ENCODING.IN_DIM,
        "hidden_dim": cfg.MODEL.COORDS_MLP.WIDTH,
        "depth": cfg.MODEL.COORDS_MLP.DEPTH,
    }
    if cfg.MODEL.VOXEL_OUTS.SHARED.USE:
        # Shared mode: weights come from a coordinate MLP with its own width/depth.
        kwargs["hidden_dim"] = cfg.MODEL.VOXEL_OUTS.SHARED.MLP.WIDTH
        kwargs["depth"] = cfg.MODEL.VOXEL_OUTS.SHARED.MLP.DEPTH
        return CoordsMLPLinearWeight(**kwargs)
    else:
        return VoxelNonShareLinearWeight(**kwargs)


class LinearBlock(nn.Module):
    """Per-voxel linear readout: averages the weighted features, then adds a bias."""

    def __init__(self, in_planes, n):
        super().__init__()
        dummy = nn.Linear(in_planes, n)
        self.weight = nn.Parameter(dummy.weight.unsqueeze(0))  # (1, n, in_planes)
        self.bias = nn.Parameter(dummy.bias.unsqueeze(0))  # (1, n)

    def forward(self, x, voxel_indices=None):
        voxel_indices = ... if voxel_indices is None else voxel_indices
        out = (x * self.weight[:, voxel_indices, :]).mean(dim=-1)
        out += self.bias[:, voxel_indices]
        return out


class VoxelOutBlock(nn.Module):
    """Stack of per-voxel linear layers; the final layer maps each voxel to a scalar."""

    def __init__(self, in_planes, n, planes=32, depth=1):
        super().__init__()
        planes = in_planes if planes is None else planes
        self.weight = nn.ParameterList()
        self.bias = nn.ParameterList()
        self.act = nn.GELU()

        self.depth = depth
        for i in range(depth):
            # Hidden layers have `planes` output channels; the last layer has one.
            o = planes if i < depth - 1 else 1
            weight = []
            bias = []
            for j in range(o):
                # nn.Linear is used only to obtain sensibly initialized weights.
                dummy = nn.Linear(
                    in_planes if i == 0 else planes,
                    n,
                )
                weight.append(dummy.weight.unsqueeze(0).clone())
                bias.append(dummy.bias.unsqueeze(0).clone())
            weight = torch.cat(weight, dim=0)
            bias = torch.cat(bias, dim=0)
            weight = rearrange(weight, "o n i -> n i o", n=n, o=o)  # (n, in, out)
            bias = rearrange(bias, "o n -> n o", n=n, o=o)  # (n, out)
            self.weight.append(nn.Parameter(weight))
            self.bias.append(nn.Parameter(bias))

    def forward(self, x, voxel_indices=None):
        voxel_indices = ... if voxel_indices is None else voxel_indices
        for ww, bb in zip(self.weight, self.bias):
            w = ww[voxel_indices]
            b = bb[voxel_indices]
            x = einsum(x, w, "b n i, n i o -> b n o")
            # Normalize by the input width so the scale matches a mean rather than a sum.
            x /= w.shape[1]
            x += b[None, ...]
            if x.shape[-1] != 1:
                x = self.act(x)
        x = x.squeeze(-1)
        return x


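# Illustrative sketch (assumed shapes, not part of the original module): with features
# of shape (batch, n_voxels, in_planes),
#
#     head = VoxelOutBlock(in_planes=32, n=1000, depth=1)
#     y = head(torch.rand(4, 1000, 32))  # expected shape: (4, 1000)
#
# Each layer applies an independent (in, out) matrix per voxel via
# einsum("b n i, n i o -> b n o"); the trailing singleton channel is squeezed away.
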
class NeuronProjector(nn.Module):
    """Maps voxel coordinates to 2-D sampling locations (one grid per backbone layer)
    plus a per-voxel softmax gate over layers."""

    def __init__(
        self,
        cfg: AutoConfig,
        layer_list: List[str],
        neuron_coords: Tensor,
        act_fn=nn.GELU,
    ):
        super().__init__()
        self.cfg = cfg
        self.layer_list = layer_list
        self.neuron_coords = neuron_coords
        self.neuron_coords.requires_grad = False
        self.act_fn = act_fn

        self.projectors = nn.ModuleDict()
        self.eye_shifters = nn.ModuleDict()

        # ModuleDict keys may not contain dots, so layer names are sanitized.
        if self.cfg.MODEL.NEURON_PROJECTOR.SEPARATE_LAYERS:
            for layer in self.layer_list:
                k = layer.replace(".", "_")
                self.projectors[k] = self.build_neuron_projector(
                    neuron_coords.shape[-1]
                )
                self.eye_shifters[k] = self.build_eye_shifter()
        else:
            # One projector and one eye shifter shared by every layer.
            shared_projector = self.build_neuron_projector(neuron_coords.shape[-1])
            shared_eye_shifter = self.build_eye_shifter()
            for layer in self.layer_list:
                k = layer.replace(".", "_")
                self.projectors[k] = shared_projector
                self.eye_shifters[k] = shared_eye_shifter

        self.layer_gate = self.build_layer_gate(
            neuron_coords.shape[-1], len(layer_list)
        )

    def forward(self, batch_size, eye_coords=None, voxel_indices=None):
        # Skip autograd bookkeeping when the corresponding sub-modules are frozen.
        if next(self.projectors.parameters()).requires_grad:
            grids, coord_inp, (reg_mu1, reg_mu2, reg_mu3) = self._forward(
                batch_size, eye_coords, voxel_indices
            )
        else:
            with torch.no_grad():
                grids, coord_inp, (reg_mu1, reg_mu2, reg_mu3) = self._forward(
                    batch_size, eye_coords, voxel_indices
                )

        if next(self.layer_gate.parameters()).requires_grad:
            gate = self.layer_gate(coord_inp)
        else:
            with torch.no_grad():
                gate = self.layer_gate(coord_inp)

        return grids, gate, (reg_mu1, reg_mu2, reg_mu3)

    def _forward(
        self,
        batch_size,
        eye_coords=None,
        voxel_indices=None,
    ):
        # neuron_coords is a plain tensor (not a registered buffer), so move it lazily.
        if self.neuron_coords.device != self.device:
            self.neuron_coords = self.neuron_coords.to(self.device)

        voxel_indices = ... if voxel_indices is None else voxel_indices
        coord_inp = self.neuron_coords[voxel_indices]

        grids = {}
        for layer in self.layer_list:
            k = layer.replace(".", "_")

            mu = self.projectors[k](coord_inp)

            if self.training and next(self.projectors.parameters()).requires_grad:
                # reg_mu1: discourage collapsed (overlapping) sampling locations,
                # reg_mu2: keep locations near the grid center,
                # reg_mu3: keep the population mean centered.
                # Note: only the values from the last layer in layer_list are returned.
                reg_mu1 = torch.cdist(mu, mu, p=2)
                reg_mu1 = 1.0 / (reg_mu1 + 1e-3)
                reg_mu1 = reg_mu1.mean()
                reg_mu2 = torch.sqrt((mu**2).sum(dim=-1)).mean()
                reg_mu3 = mu[:, 0].mean() ** 2 + mu[:, 1].mean() ** 2
            else:
                reg_mu1 = torch.tensor(0.0)
                reg_mu2 = torch.tensor(0.0)
                reg_mu3 = torch.tensor(0.0)

            mu = repeat(mu, "n c -> b n c", b=batch_size)

            if self.training:
                # Jitter the sampling locations during training.
                norm = torch.normal(
                    0,
                    torch.ones_like(mu) * self.cfg.MODEL.NEURON_PROJECTOR.SIGMA_SCALE,
                )
                mu = mu + norm

            if eye_coords is not None:
                shift = self.eye_shifters[k](eye_coords)
                shift = repeat(shift, "b c -> b n c", n=mu.shape[1])
                mu += shift

            # Grid for F.grid_sample: (batch, n_voxels, 1, 2) in [-1, 1].
            grid = rearrange(mu, "b n (d c) -> b n d c", d=1, c=2)

            grids[layer] = grid

        return grids, coord_inp, (reg_mu1, reg_mu2, reg_mu3)

    def build_layer_gate(self, location_dim, num_layers):
        depth = self.cfg.MODEL.LAYER_GATE.DEPTH
        width = self.cfg.MODEL.LAYER_GATE.WIDTH
        assert depth >= 2
        modules = []
        for i in range(depth - 1):
            modules.append(nn.Linear(location_dim if i == 0 else width, width))
            modules.append(self.act_fn())
        output_dim = num_layers
        modules.append(nn.Linear(width, output_dim))
        modules.append(nn.Softmax(dim=-1))
        return nn.Sequential(*modules)

    def build_neuron_projector(self, location_dim, output_dim=None, final_act=nn.Tanh):
        depth = self.cfg.MODEL.NEURON_PROJECTOR.DEPTH
        width = self.cfg.MODEL.NEURON_PROJECTOR.WIDTH
        assert depth >= 2
        modules = []
        for i in range(depth - 1):
            modules.append(nn.Linear(location_dim if i == 0 else width, width))
            modules.append(self.act_fn())
        # Tanh keeps the predicted (x, y) locations inside grid_sample's [-1, 1] range.
        output_dim = 2 if output_dim is None else output_dim
        modules.append(nn.Linear(width, output_dim))
        modules.append(final_act())
        return nn.Sequential(*modules)

    def build_eye_shifter(self):
        return nn.Sequential(nn.Linear(2, 8), nn.SiLU(), nn.Linear(8, 2), nn.Tanh())

    @property
    def device(self):
        return next(self.parameters()).device


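# Illustrative sketch (assumed config fields and shapes, not part of the original module):
# for neuron_coords of shape (n_voxels, 3) and a layer_list of length L,
#
#     projector = NeuronProjector(cfg, layer_list, neuron_coords)
#     grids, gate, reg_mu = projector(batch_size=4)
#     # grids[layer] -> (4, n_voxels, 1, 2) sampling grid for F.grid_sample
#     # gate         -> (n_voxels, L) softmax weights over layers
#     # reg_mu       -> three scalar regularizers (zeros when frozen or in eval mode)
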
class TopyNeck(nn.Module):
    """Readout neck: per-subject neuron projectors sample backbone feature maps, and
    per-voxel output heads map the sampled features to voxel responses."""

    def __init__(
        self,
        cfg: AutoConfig,
        in_c_dict: Dict[str, int],
        num_voxel_dict: Dict[str, int],
        neuron_coords_dict: Dict[str, Tensor],
        act_fn=nn.SiLU,
    ):
        super().__init__()
        self.cfg = cfg
        self.in_c_dict = in_c_dict
        self.layer_list = list(self.in_c_dict.keys())
        self.act_fn = act_fn
        self.num_voxel_dict = num_voxel_dict
        self.neuron_coords_dict = neuron_coords_dict
        for k in self.neuron_coords_dict.keys():
            self.neuron_coords_dict[k].requires_grad = False
        self.num_neuron_latent = self.cfg.MODEL.NEURON_PROJECTOR.NUM_NEURON_LATENT
        assert self.num_neuron_latent == 1
        self.subject_list = list(self.num_voxel_dict.keys())

        self.planes = self.cfg.MODEL.NECK.CONV_HEAD.WIDTH

        self.neuron_projectors = nn.ModuleDict()
        self.layer_gates = nn.ModuleDict()
        self.mean_method = self.cfg.MODEL.LAYER_GATE.MEAN

        self.voxel_outs = nn.ModuleDict()

        for subject in self.subject_list:
            self.add_subject(subject, self.neuron_coords_dict[subject], overwrite=True)

        self.previous_layer_requires_grad = False

    def add_subject(
        self,
        subject,
        neuron_coords,
        overwrite=False,
        use_linear=True,
        nonlinear_depth=3,
        nonlinear_planes=32,
    ):
        """Register (or re-register) a subject's neuron projector and voxel output head."""
        if subject in self.subject_list and not overwrite:
            return
        if subject not in self.subject_list:
            self.subject_list.append(subject)

        neuron_coords.requires_grad = False
        num_voxels = neuron_coords.shape[0]
        num_layers = len(self.layer_list)
        self.num_voxel_dict[subject] = num_voxels
        self.neuron_coords_dict[subject] = neuron_coords

        self.neuron_projectors[subject] = NeuronProjector(
            self.cfg, self.layer_list, neuron_coords
        )

        if use_linear:
            # Linear readout: a single per-voxel linear layer.
            self.voxel_outs[subject] = VoxelOutBlock(
                self.planes,
                self.num_voxel_dict[subject],
                depth=1,
            )
        else:
            # Nonlinear readout: a small per-voxel MLP.
            self.voxel_outs[subject] = VoxelOutBlock(
                self.planes,
                self.num_voxel_dict[subject],
                depth=nonlinear_depth,
                planes=nonlinear_planes,
            )

    def _forward_i(
        self,
        x,
        x_shift,
        indices,
        subject_id,
        session_id,
        eye_coords,
        voxel_indices=None,
        chunk_size=8000,
    ):
        eye_coords = eye_coords[indices] if eye_coords is not None else None

        b = len(indices)
        d = self.num_neuron_latent

        def _grid_y(voxel_indices):
            # Project voxel coordinates to per-layer sampling grids and layer gates.
            grids, gate, reg_mu = self.neuron_projectors[subject_id](
                b, eye_coords, voxel_indices
            )

            out_ys = None

            for i, (k, v) in enumerate(x.items()):
                w = gate[:, i]
                w = rearrange(w, "n -> 1 1 n 1")
                grid = grids[k]
                # Sample layer k's feature map at each voxel's location: (b, c, n, 1).
                out_y = F.grid_sample(
                    v[indices],
                    grid,
                    mode="bilinear",
                    padding_mode="zeros",
                    align_corners=False,
                )

                if self.mean_method == "mean":
                    if (
                        not self.cfg.MODEL.LAYER_GATE.SKIP
                        and self.cfg.OPTIMIZER.GATE_REGULARIZER < 100
                    ):
                        out_y = out_y * w
                    if out_ys is None:
                        out_ys = out_y
                    else:
                        out_ys += out_y
                elif self.mean_method == "geometric_mean":
                    raise NotImplementedError("don't use geometric mean")
                    out_y = out_y**w
                    if out_ys is None:
                        out_ys = out_y
                    else:
                        out_ys *= out_y
                else:
                    raise NotImplementedError

            out_ys = out_ys * (1 / len(x))
            return out_ys, gate, reg_mu

        def divide_chunks(seq, n):
            chunks = []
            for i in range(0, len(seq), n):
                chunks.append(seq[i : i + n])
            return chunks

        def forward_one_chunk(voxel_indices, grad_flag):
            if grad_flag:
                y, gate_weights, reg_mu = _grid_y(voxel_indices)
            else:
                with torch.no_grad():
                    y, gate_weights, reg_mu = _grid_y(voxel_indices)
            # (b, c, n, 1) -> (b, n, c) before the per-voxel readout head.
            y = rearrange(y, "b c n d -> b n (c d)")
            out = self.voxel_outs[subject_id](y, voxel_indices)
            return out, gate_weights, reg_mu

        if voxel_indices is ... or voxel_indices is None:
            voxel_indices = torch.arange(
                self.num_voxel_dict[subject_id], device=x[list(x.keys())[0]].device
            )

        # Process voxels in chunks to bound peak memory.
        voxel_index_chunks = divide_chunks(voxel_indices, chunk_size)

        grad_flag = self.training and (
            next(
                self.neuron_projectors[subject_id].projectors.parameters()
            ).requires_grad
            or next(
                self.neuron_projectors[subject_id].layer_gate.parameters()
            ).requires_grad
            or next(self.voxel_outs[subject_id].parameters()).requires_grad
            or self.previous_layer_requires_grad
        )
        if not grad_flag:
            outs = []
            for vi in voxel_index_chunks:
                out, gate_weights, reg_mu = forward_one_chunk(vi, grad_flag)
                outs.append(out)
            out = (
                torch.cat(outs, dim=1)
                if len(outs) > 0
                else torch.tensor([0 for _ in range(b)])
            )
            reg_gate = torch.tensor(0.0)
            reg_mu = (torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0))
        else:
            outs = []
            gate_weights = []
            reg_mus = []
            for vi in voxel_index_chunks:
                out, gate_weight, reg_mu = forward_one_chunk(vi, grad_flag)
                outs.append(out)
                gate_weights.append(gate_weight)
                reg_mus.append(reg_mu)
            out = (
                torch.cat(outs, dim=1)
                if len(outs) > 0
                else torch.tensor([0 for _ in range(b)])
            )
            gate_weights = torch.cat(gate_weights, dim=0)

            def entropy(x):
                # sum p * log(p), i.e. negative entropy; minimizing it spreads the gate.
                return (x * x.log()).sum(dim=1).mean()

            reg_gate = entropy(gate_weights)

            reg_mu1 = torch.stack([x[0] for x in reg_mus], dim=0).mean()
            reg_mu2 = torch.stack([x[1] for x in reg_mus], dim=0).mean()
            reg_mu3 = torch.stack([x[2] for x in reg_mus], dim=0).mean()
            reg_mu = (reg_mu1, reg_mu2, reg_mu3)

        reg_p_mu_shift = [0.0] * b

        return out, reg_gate, reg_mu, reg_p_mu_shift

    def forward(
        self,
        x: Dict[str, Tensor],
        subject_ids: List[str],
        session_ids: Optional[List[str]] = None,
        eye_coords: Optional[List[Tensor]] = None,
        voxel_indices_dict: Optional[Dict[str, Tensor]] = None,
        x_shift=None,
    ) -> Tuple[List[Tensor], List[Tensor], Any]:
        # Heuristic: if the last two dims differ, assume channels-last and convert.
        for k, v in x.items():
            if v.shape[-1] != v.shape[-2]:
                x[k] = rearrange(v, "b h w c -> b c h w")

        if session_ids is None:
            # Treat a missing session list as one shared session.
            session_ids = ["0"] * len(subject_ids)
        if isinstance(subject_ids, list):
            subject_ids = np.array(subject_ids)
        if isinstance(session_ids, list):
            session_ids = np.array(session_ids)

        out = [None for _ in range(len(subject_ids))]
        reg = [0.0 for _ in range(len(subject_ids))]
        unique_subject_ids = np.unique(subject_ids)
        for i_sub in unique_subject_ids:
            indices1 = subject_ids == i_sub
            indices1 = np.where(indices1)[0]
            unique_session_ids = np.unique(session_ids[indices1])
            for i_sess in unique_session_ids:
                indices2 = session_ids[indices1] == i_sess
                indices2 = np.where(indices2)[0]
                indices = indices1[indices2]
                i_out, i_reg_gate, i_reg_mu, reg_p_mu_shift = self._forward_i(
                    x,
                    x_shift,
                    indices,
                    i_sub,
                    i_sess,
                    eye_coords,
                    voxel_indices=voxel_indices_dict[i_sub]
                    if voxel_indices_dict is not None
                    else None,
                )
                for i, idx in enumerate(indices):
                    out[idx] = i_out[i]
                    # A GATE_REGULARIZER weight >= 100 acts as a sentinel that disables
                    # the gate term; the mu terms are applied either way.
                    i_reg = (
                        (
                            i_reg_gate * self.cfg.OPTIMIZER.GATE_REGULARIZER
                            if self.cfg.OPTIMIZER.GATE_REGULARIZER < 100
                            else 0.0
                        )
                        + i_reg_mu[0] * self.cfg.OPTIMIZER.MU_REGULARIZER_PDIST
                        + i_reg_mu[1] * self.cfg.OPTIMIZER.MU_REGULARIZER_PCENTER
                        + i_reg_mu[2] * self.cfg.OPTIMIZER.MU_REGULARIZER_MCENTER
                        + reg_p_mu_shift[i] * self.cfg.OPTIMIZER.P_MU_SHIFT_REGULARIZER
                    )
                    reg[idx] = i_reg
        return out, reg, x_shift

    @property
    def device(self):
        return next(self.parameters()).device
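

# Illustrative call-flow sketch (assumed config fields and shapes, not part of the
# original module): given backbone feature maps x = {layer: (B, C, H, W)} whose channel
# count matches cfg.MODEL.NECK.CONV_HEAD.WIDTH,
#
#     neck = TopyNeck(cfg, in_c_dict, num_voxel_dict, neuron_coords_dict)
#     out, reg, x_shift = neck(x, subject_ids, session_ids)
#     # out[i] -> predicted responses for sample i, shape (n_voxels of that subject,)
#     # reg[i] -> scalar regularization term for sample i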