import torch
import torch.nn as nn
from einops import rearrange

from .block import Block
from .conv import CausalConv3d
from .normalize import Normalize
from .ops import video_to_image


class LinearAttention(Block):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        # linear attention: aggregate the values against softmax-normalized
        # keys first, then apply the (dim_head, dim_head) context to the queries
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)


class LinAttnBlock(LinearAttention):
    """Single-head linear attention, to match AttnBlock usage."""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)


class AttnBlock3D(Block):
    """Spatial self-attention applied per frame of a 5D (b, c, t, h, w) input."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention per frame: fold time into the batch dimension
        b, c, t, h, w = q.shape
        q = rearrange(q, "b c t h w -> (b t) (h w) c")
        k = rearrange(k, "b c t h w -> (b t) c (h w)")
        v = rearrange(v, "b c t h w -> (b t) c (h w)")
        w_ = torch.bmm(q, k)  # (b t, hw, hw); w_[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        w_ = w_.permute(0, 2, 1)  # (b t, hw, hw); first hw of k, second of q
        h_ = torch.bmm(v, w_)  # (b t, c, hw); h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = rearrange(h_, "(b t) c (h w) -> b c t h w", b=b, h=h, w=w)

        h_ = self.proj_out(h_)
        return x + h_


class AttnBlock(Block):
    """Standard spatial self-attention on 4D (b, c, h, w) inputs."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    @video_to_image
    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b, hw, c
        k = k.reshape(b, c, h * w)  # b, c, hw
        w_ = torch.bmm(q, k)  # (b, hw, hw); w_[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # (b, hw, hw); first hw of k, second of q
        h_ = torch.bmm(v, w_)  # (b, c, hw); h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)
        return x + h_


class TemporalAttnBlock(Block):
    """Self-attention along the time axis, applied independently per pixel."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1,
            padding=0
        )
        self.v = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv3d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention over time: fold the spatial dims into the batch
        b, c, t, h, w = q.shape
        q = rearrange(q, "b c t h w -> (b h w) t c")
        k = rearrange(k, "b c t h w -> (b h w) c t")
        v = rearrange(v, "b c t h w -> (b h w) c t")
        w_ = torch.bmm(q, k)  # (b h w, t, t)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)  # (b h w, c, t)
        h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)

        h_ = self.proj_out(h_)
        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in [
        "vanilla",
        "linear",
        "none",
        "vanilla3D",
    ], f"attn_type {attn_type} unknown"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "vanilla3D":
        return AttnBlock3D(in_channels)
    elif attn_type == "none":
        return nn.Identity()
    else:
        return LinAttnBlock(in_channels)
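

# Minimal shape-check sketch, illustrative only (not part of the original
# module). It assumes the file is executed inside its package so the relative
# imports resolve, that Normalize / CausalConv3d are shape-preserving (the
# residual connections above already require this), and that video_to_image
# lets AttnBlock consume 5D video by folding time into the batch. All tensor
# sizes are arbitrary example values.
if __name__ == "__main__":
    image = torch.randn(2, 64, 16, 16)      # (b, c, h, w)
    video = torch.randn(2, 64, 4, 16, 16)   # (b, c, t, h, w)

    for attn_type, x in [
        ("vanilla", video),    # 5D, assuming video_to_image handles the fold
        ("linear", image),
        ("none", image),
        ("vanilla3D", video),
    ]:
        attn = make_attn(64, attn_type=attn_type)
        out = attn(x)
        # every block is residual (or identity), so shapes must round-trip
        assert out.shape == x.shape, (attn_type, out.shape)
        print(f"{attn_type}: ok {tuple(out.shape)}")

    # TemporalAttnBlock is not reachable through make_attn; check it directly
    assert TemporalAttnBlock(64)(video).shape == video.shape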