from functools import partial
import logging
from typing import Dict, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn, Tensor

from config import AutoConfig
from backbone import (
    SubjectTimeEmbed,
    build_backbone,
    AdaLNLoRADiNOv2ViT,
    build_backbone_prev,
    build_time_emd,
)
from blocks import (
    PreviousFeatureMLPs,
    SubjectPreviousFrameCompress,
    build_class_token_mlp_prev,
    build_conv_blocks,
    build_class_token_mlp,
    DictConvBlocks,
    ClassTokenMLPs,
    build_ftr_compress,
    build_prev_compress,
    build_prev_feat_mlp,
)
from behav_embed import build_behavior_embed, SubjectBehaviorEmbed
from config_utils import load_from_yaml
from topyneck import (
    CoordsFreeWeights,
    build_coords_free_weights,
    build_coords_mlp,
    CachedCoordsMLP,
    build_voxelouts_weight,
    CoordsMLPLinearWeight,
    VoxelNonShareLinearWeight,
)


def get_coords():
    """Load fsaverage7 sphere coordinates for both hemispheres and shift the
    right hemisphere along x so the two spheres occupy disjoint x-ranges."""
    import nilearn
    from nilearn import datasets, surface

    fsaverage = nilearn.datasets.fetch_surf_fsaverage("fsaverage7")
    lh_coords, lh_faces = nilearn.surface.load_surf_mesh(fsaverage["sphere_left"])
    rh_coords, rh_faces = nilearn.surface.load_surf_mesh(fsaverage["sphere_right"])
    # pad the left hemisphere's x-range by 50%, then slide the right
    # hemisphere so it starts where the padded range ends
    lh_xmin, lh_xmax = np.min(lh_coords[:, 0]), np.max(lh_coords[:, 0])
    lh_xmax = lh_xmin + (lh_xmax - lh_xmin) * 1.5
    rh_xmin, rh_xmax = np.min(rh_coords[:, 0]), np.max(rh_coords[:, 0])
    if rh_xmin < lh_xmax:
        rh_coords[:, 0] += lh_xmax - rh_xmin
    coords = np.concatenate((lh_coords, rh_coords), axis=0)
    coords = torch.tensor(coords)
    return coords
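
# %%
# A minimal sketch (not part of the original pipeline) of the hemisphere
# offset performed in `get_coords`: the right-hemisphere x coordinates are
# shifted so the two spheres sit side by side with a 50% gap, keeping a
# single (x, y, z) coordinate space valid for all 327684 vertices. The
# synthetic two-point "spheres" below are assumptions used purely for
# illustration; `_demo_hemisphere_offset` is a hypothetical helper.
def _demo_hemisphere_offset():
    lh = np.array([[-100.0, 0.0, 0.0], [0.0, 0.0, 0.0]])  # fake LH, x in [-100, 0]
    rh = np.array([[-100.0, 0.0, 0.0], [0.0, 0.0, 0.0]])  # fake RH, same x-range
    lh_xmin, lh_xmax = lh[:, 0].min(), lh[:, 0].max()
    lh_xmax = lh_xmin + (lh_xmax - lh_xmin) * 1.5  # pad by 50%, as in get_coords
    rh_xmin = rh[:, 0].min()
    if rh_xmin < lh_xmax:
        rh[:, 0] += lh_xmax - rh_xmin  # slide RH past the padded LH range
    assert rh[:, 0].min() >= lh_xmax  # hemispheres no longer overlap in x
    return np.concatenate((lh, rh), axis=0)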

# %%
class BrainEncodingModel(nn.Module):
    def __init__(
        self,
        cfg: AutoConfig,
    ):
        super().__init__()

        # fsaverage7: 163842 vertices per hemisphere, 327684 total per subject
        n_voxel_dict = {
            "subj01": 327684,
            "subj02": 327684,
            "subj03": 327684,
            "subj04": 327684,
            "subj05": 327684,
            "subj06": 327684,
            "subj07": 327684,
            "subj08": 327684,
        }
        self.subject_list = list(n_voxel_dict.keys())

        self.layers = cfg.MODEL.BACKBONE.LAYERS
        self.layers_small = cfg.MODEL.BACKBONE_SMALL.LAYERS
        self.n_layers = len(self.layers)

        r = cfg.MODEL.WIDTH_RATIO
        cfg.MODEL.CONV_HEAD.WIDTH = int(cfg.MODEL.CONV_HEAD.WIDTH * r)

        self.cfg = cfg

        # behavior is not used, just a placeholder
        self.behav_embed: SubjectBehaviorEmbed = build_behavior_embed(cfg)

        self.backbone: AdaLNLoRADiNOv2ViT = build_backbone(cfg)
        self.conv_blocks: DictConvBlocks = build_conv_blocks(cfg)
        self.cls_blocks: ClassTokenMLPs = build_class_token_mlp(cfg)

        def build_each_subject(fn, subject_list):
            return nn.ModuleDict({subject: fn() for subject in subject_list})

        # [327684, 3], for layer selector and retina mapper
        self.coords = get_coords()
        self.coords = nn.Parameter(self.coords, requires_grad=False)

        self.layer_selector: Dict[str, CachedCoordsMLP] = build_each_subject(
            partial(
                build_coords_mlp,
                cfg=cfg,
                in_dim=cfg.POSITION_ENCODING.IN_DIM,
                out_dim=self.n_layers,
                act_fn=partial(nn.Softmax, dim=-1),
            ),
            self.subject_list,
        )
        self.retina_mapper: Dict[str, CachedCoordsMLP] = build_each_subject(
            partial(
                build_coords_mlp,
                cfg=cfg,
                in_dim=cfg.POSITION_ENCODING.IN_DIM,
                out_dim=2,
                act_fn=nn.Tanh,
            ),
            self.subject_list,
        )
        self.mu_sigma = cfg.MODEL.RETINA_MAPPER.CONSTANT_SIGMA

        # voxel-wise output
        d_model = self.cfg.MODEL.CONV_HEAD.WIDTH
        self.n_voxel_dict = n_voxel_dict
        self.d_model = d_model
        self.voxel_outs_weight: Dict[
            str, Union[VoxelNonShareLinearWeight, CoordsMLPLinearWeight]
        ] = nn.ModuleDict(
            {
                subject: build_voxelouts_weight(
                    cfg, self.n_voxel_dict[subject], self.d_model
                )
                for subject in self.subject_list
            }
        )

    def forward(
        self,
        x: Tensor,  # [B, C, H, W]
        subject: str,
        voxel_indices: Optional[Tensor] = None,
        chunk_size=4096,
    ):
        bsz = x.shape[0]
        device = x.device
        dtype = x.dtype

        # bhv is not used, just a placeholder
        bhv = torch.zeros(
            (bsz, self.cfg.MODEL.COND.IN_DIM), device=device, dtype=dtype
        )  # [B, D_B=35]
        c = self.behav_embed(bhv, subject=subject)  # [B, D_C]

        x_retina_grid, x_cls_dict = self.backbone.get_intermediate_layers(
            x, n=self.layers, c=c
        )
        x_retina_grid = self.conv_blocks(x_retina_grid)
        x_cls_dict = self.cls_blocks(x_cls_dict)
        x_cls = torch.stack(list(x_cls_dict.values()), dim=-1)  # [B, D, 4]

        #############################
        ### voxel-wise prediction ###
        #############################
        # divide voxels into chunks to avoid OOM
        coords = self.coords
        n_voxels = coords.shape[0]
        # `...` means "all voxels"; use identity checks, since `==` on a
        # Tensor does not produce a Python bool
        if voxel_indices is None or voxel_indices is Ellipsis:
            voxel_indices = torch.arange(n_voxels, device=coords.device)
        voxel_indices_chunks = torch.split(voxel_indices, chunk_size)
        out_ys, reg_layers = [], []
        for voxel_indices_chunk in voxel_indices_chunks:
            out_y, reg_layer = self._forward_voxels(
                x_retina_grid,
                x_cls,
                subject,
                coords,
                voxel_indices_chunk,
                bsz,
                device,
                dtype,
            )
            out_ys.append(out_y)
            reg_layers.append(reg_layer)
        out_y = torch.cat(out_ys, dim=1)  # [B, N]
        reg_layer = torch.cat(reg_layers, dim=0).mean()  # [1]

        # if self.training:
        #     return out_y, reg_layer
        # else:
        return out_y

    def _forward_voxels(
        self,
        x_retina_grid: Dict[str, Tensor],  # {layer: [B, D, H/k, W/k], ...}
        x_cls: Tensor,  # [B, D, 4]
        subject: str,
        coords: Tensor,
        voxel_indices: Tensor,
        bsz,
        device,
        dtype,
    ):
        N = len(voxel_indices)

        ## Layer Selector
        w_layer = self.layer_selector[subject](coords, voxel_indices)  # [N, 4]

        # regularization: sum(w * log w) is the negative entropy of the
        # per-voxel layer weights
        def entropy(x):
            return (x * x.log()).sum(dim=1)

        if self.training and next(self.layer_selector.parameters()).requires_grad:
            reg_layer = entropy(w_layer)  # [N]
        else:
            reg_layer = torch.zeros_like(w_layer[:, 0])  # [N]

        # per-voxel weighted sum of the per-layer class tokens
        x_cls = repeat(x_cls, "b d l -> b n d l", n=1)
        _w_layer = repeat(w_layer, "n l -> b n d l", b=1, d=1)
        x_cls = (x_cls * _w_layer).sum(dim=-1)  # [B, N, D]

        ## Retina Mapper
        mu = self.retina_mapper[subject](coords, voxel_indices)  # [N, 2]
        mu = mu * (1 - self.mu_sigma)
        if self.training:
            # jitter the sampling locations with Gaussian noise during training
            norm = torch.normal(0, torch.ones_like(mu) * self.mu_sigma)
            mu = mu + norm
        bsz = x_cls.shape[0]
        mu = repeat(mu, "n d -> b n d", b=bsz)
        mu = rearrange(mu, "b n (d c) -> b n d c", d=1, c=2)

        if self.cfg.EXPERIMENTAL.USE_LAYER_SELECTOR:
            _w_layer = repeat(w_layer, "n l -> b n l", b=1)
        x_retina = None  # [B, N, D]
        for i, layer in enumerate(self.layers):
            x = x_retina_grid[str(layer)]
            _x_retina = F.grid_sample(
                x,
                mu,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )  # [B, C, N, D] (C=D_model, D=1, N=N_voxels)
            _x_retina = rearrange(_x_retina, "b c n d -> b n (c d)")
            if self.cfg.EXPERIMENTAL.USE_LAYER_SELECTOR:
                _x_retina = _x_retina * _w_layer[:, :, i : i + 1]
            if x_retina is None:
                x_retina = _x_retina
            else:
                x_retina += _x_retina
        # x_retina: [B, N, D]

        x_y = x_retina + x_cls  # [B, N, D]

        # T=0
        w, b = self.voxel_outs_weight[subject](coords, voxel_indices)  # [N, D], [N]
        out_y = (x_y * w.unsqueeze(0)).mean(-1) + b.unsqueeze(0)  # [B, N]

        return out_y, reg_layer  # [B, N], [N]
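
# %%
# Hedged usage sketch, not part of the original file: how the model might be
# constructed and queried. The config path, input resolution, and batch size
# are assumptions; `load_from_yaml`, `BrainEncodingModel`, and the subject
# keys are the real entry points defined above.
if __name__ == "__main__":
    cfg = load_from_yaml("config.yaml")  # hypothetical config path
    model = BrainEncodingModel(cfg).eval()
    x = torch.randn(2, 3, 224, 224)  # assumed DINOv2-style RGB input
    with torch.no_grad():
        # predict the first 1000 fsaverage vertices for subject 1;
        # passing voxel_indices=None (or ...) would predict all 327684
        y = model(x, subject="subj01", voxel_indices=torch.arange(1000))
    print(y.shape)  # expected: torch.Size([2, 1000])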