import torch
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange

pair = lambda x: x if isinstance(x, tuple) else (x, x)


class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: x + fn(LayerNorm(x))."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x


def FeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
    """Two-layer MLP with GELU; `dense` selects token mixing (Conv1d) or channel mixing (Linear)."""
    inner_dim = int(dim * expansion_factor)
    return nn.Sequential(
        dense(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(inner_dim, dim),
        nn.Dropout(dropout)
    )


class MappingSub2W(nn.Module):
    """Maps region sub-latents [B, N, 34, dim] to latent codes [B, N, dim] with MLP-Mixer style blocks."""

    def __init__(self, N=8, dim=512, depth=6, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
        super(MappingSub2W, self).__init__()
        num_patches = N * 34
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
        self.layer = nn.Sequential(
            Rearrange('b c h w -> b (c h) w'),      # [B, N, 34, dim] -> [B, N*34, dim]
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),   # token mixing
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))        # channel mixing
            ) for _ in range(depth)],
            nn.LayerNorm(dim),
            Rearrange('b c h -> b h c'),            # [B, N*34, dim] -> [B, dim, N*34]
            nn.Linear(34 * N, 34 * N),
            nn.LayerNorm(34 * N),
            nn.GELU(),
            nn.Linear(34 * N, N),                   # project the N*34 tokens down to N
            Rearrange('b h c -> b c h')             # [B, dim, N] -> [B, N, dim]
        )

    def forward(self, x):
        return self.layer(x)


class MappingW2Sub(nn.Module):
    """Maps latent codes [B, N, dim] back to region sub-latents [B, N, 34, dim] via a reparameterized Gaussian."""

    def __init__(self, N=8, dim=512, depth=8, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
        super(MappingW2Sub, self).__init__()
        self.N = N
        num_patches = N * 34
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
        self.layer = nn.Sequential(
            Rearrange('b c h -> b h c'),            # [B, N, dim] -> [B, dim, N]
            nn.Linear(N, num_patches),              # expand N tokens to N*34 tokens
            Rearrange('b h c -> b c h'),            # [B, dim, N*34] -> [B, N*34, dim]
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim)
        )
        self.mu_fc = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(2)],
            nn.LayerNorm(dim),
            nn.Tanh(),
            Rearrange('b c h -> b (c h)')           # flatten to [B, N*34*dim]
        )
        self.var_fc = nn.Sequential(                # predicts the log-variance
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(2)],
            nn.LayerNorm(dim),
            nn.Tanh(),
            Rearrange('b c h -> b (c h)')
        )

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) using N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x):
        f = self.layer(x)
        mu = self.mu_fc(f)
        log_var = self.var_fc(f)
        z = self.reparameterize(mu, log_var)
        z = rearrange(z, 'a (b c d) -> a b c d', b=self.N, c=34)   # [B, N*34*dim] -> [B, N, 34, dim]
        return (rearrange(mu, 'a (b c d) -> a b c d', b=self.N, c=34),
                rearrange(log_var, 'a (b c d) -> a b c d', b=self.N, c=34),
                z)
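
# ---------------------------------------------------------------------------
# Shape sketch for the two mapping networks above (a minimal illustration,
# assuming the default N=8, dim=512 and a batch of 2; these values are not
# prescribed anywhere in this file):
#
#   sub = torch.randn(2, 8, 34, 512)            # [B, N, 34, dim] region sub-latents
#   w = MappingSub2W(N=8)(sub)                  # -> [B, N, dim]
#   mu, log_var, z = MappingW2Sub(N=8)(w)       # each -> [B, N, 34, dim]
# ---------------------------------------------------------------------------
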
class HeadEncoder(nn.Module):
    """Encodes one masked region feature map into N latent tokens (mu, log-variance, and a sample)."""

    def __init__(self, N=8, dim=512, depth=2, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
        super(HeadEncoder, self).__init__()
        channels = [32, 64, 64, 64]
        self.N = N
        num_patches = N
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
        self.s1 = nn.Sequential(
            nn.Conv2d(channels[0], channels[1], kernel_size=5, padding=2, stride=2),
            nn.BatchNorm2d(channels[1]),
            nn.LeakyReLU(),
            nn.Conv2d(channels[1], channels[2], kernel_size=5, padding=2, stride=2),
            nn.BatchNorm2d(channels[2]),
            nn.LeakyReLU(),
            nn.Conv2d(channels[2], channels[3], kernel_size=5, padding=2, stride=2),
            nn.BatchNorm2d(channels[3]),
            nn.LeakyReLU()
        )
        self.mlp1 = nn.Linear(channels[3] * 8 * 8, 512)   # expects an 8x8 feature map after the three stride-2 convs
        self.up_N = nn.Linear(1, N)                       # learned projection from one feature vector to N tokens
        self.mu_fc = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim),
            nn.Tanh()
        )
        self.var_fc = nn.Sequential(                      # predicts the log-variance
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim),
            nn.Tanh()
        )

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) using N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x):
        feature = self.s1(x)
        s2 = torch.flatten(feature, start_dim=1)
        s2 = self.mlp1(s2).unsqueeze(2)                   # [B, 512, 1]
        s2 = self.up_N(s2)                                # [B, 512, N]
        s2 = rearrange(s2, 'b h c -> b c h')              # [B, N, 512]
        mu = self.mu_fc(s2)
        log_var = self.var_fc(s2)
        z = self.reparameterize(mu, log_var)
        return mu, log_var, z


class RegionEncoder(nn.Module):
    """Encodes an image into per-region latents by masking shared features and running one HeadEncoder per region."""

    def __init__(self, N=8):
        super(RegionEncoder, self).__init__()
        channels = [8, 16, 32, 32, 64, 64]
        self.s1 = nn.Conv2d(3, channels[0], kernel_size=3, padding=1, stride=2)
        self.s2 = nn.Sequential(
            nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(channels[1]),
            nn.LeakyReLU(),
            nn.Conv2d(channels[1], channels[2], kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(channels[2]),
            nn.LeakyReLU()
        )
        self.heads = nn.ModuleList()
        for i in range(34):
            self.heads.append(HeadEncoder(N=N))

    def forward(self, x, all_mask=None):
        s1 = self.s1(x)
        s2 = self.s2(s1)
        result = []
        mus = []
        log_vars = []
        for i, head in enumerate(self.heads):
            m = all_mask[:, i, :].unsqueeze(1)            # [B, 1, H, W] mask for region i
            mu, log_var, z = head(s2 * m)
            result.append(z.unsqueeze(2))
            mus.append(mu.unsqueeze(2))
            log_vars.append(log_var.unsqueeze(2))
        # Stack the 34 region heads along dim 2: each output is [B, N, 34, 512].
        return torch.cat(mus, dim=2), torch.cat(log_vars, dim=2), torch.cat(result, dim=2)
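

if __name__ == '__main__':
    # Minimal smoke test, included as a sketch rather than part of the model
    # code itself. The input resolution (512x512 RGB) and the 64x64 mask
    # resolution are assumptions inferred from the layer sizes above
    # (HeadEncoder.mlp1 expects an 8x8 feature map); adjust them to your data.
    B, N = 2, 8
    img = torch.randn(B, 3, 512, 512)
    masks = torch.ones(B, 34, 64, 64)                     # one mask per region, at the s2 feature resolution

    region_enc = RegionEncoder(N=N)
    mu, log_var, sub = region_enc(img, all_mask=masks)    # each: [B, N, 34, 512]

    sub2w = MappingSub2W(N=N)
    w = sub2w(sub)                                        # [B, N, 512]

    w2sub = MappingW2Sub(N=N)
    mu_r, log_var_r, sub_r = w2sub(w)                     # each: [B, N, 34, 512]

    print(sub.shape, w.shape, sub_r.shape)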