LinB203
m
a220803
raw
history blame
5.81 kB
import torch.nn as nn
from .normalize import Normalize
from .conv import CausalConv3d
import torch
import numpy as np
from einops import rearrange
from .block import Block
from .ops import video_to_image
class LinearAttention(Block):
    """Multi-head linear attention over the spatial positions of a 4-D input.

    Keys are softmax-normalised over the flattened spatial axis. They are
    then contracted with the values into a per-head ``(dim_head, dim_head)``
    context, which the queries read from. A single 1x1 convolution produces
    q, k and v; another 1x1 convolution projects back to ``dim`` channels.

    Args:
        dim: Number of input (and output) channels.
        heads: Number of attention heads.
        dim_head: Channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        inner = heads * dim_head
        # One 1x1 conv yields q, k and v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, inner * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner, dim, 1)

    def forward(self, x):
        # x is unpacked as (batch, channels, height, width); only the spatial
        # sizes are needed to restore the layout at the end.
        _, _, height, width = x.shape
        stacked = self.to_qkv(x)
        query, key, value = rearrange(
            stacked,
            "b (qkv heads c) h w -> qkv b heads c (h w)",
            heads=self.heads,
            qkv=3,
        )
        # Normalise keys over the flattened spatial positions.
        key = key.softmax(dim=-1)
        # context[b, h, d, e] = sum_n key[b, h, d, n] * value[b, h, e, n]
        context = torch.einsum("bhdn,bhen->bhde", key, value)
        attended = torch.einsum("bhde,bhdn->bhen", context, query)
        merged = rearrange(
            attended,
            "b heads c (h w) -> b (heads c) h w",
            heads=self.heads,
            h=height,
            w=width,
        )
        return self.to_out(merged)
class LinAttnBlock(LinearAttention):
    """Single-head ``LinearAttention`` built from ``in_channels`` alone.

    It has the same one-argument constructor as ``AttnBlock``, so the two are
    interchangeable at construction sites. The single head spans all
    ``in_channels`` channels.
    """

    def __init__(self, in_channels):
        super().__init__(in_channels, heads=1, dim_head=in_channels)
class AttnBlock3D(Block):
    """Residual single-head spatial self-attention applied to every frame.

    The input is unpacked as ``(batch, channels, time, height, width)``.
    q, k, v and the output projection are 1x1x1 ``CausalConv3d`` layers.
    Attention is computed among the ``height * width`` positions of each
    frame independently, and the projected result is added back to the input.

    Args:
        in_channels: Number of channels of the input (and output).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, t, h, w = q.shape
        # Move time next to batch before flattening, so each (b*t) slice holds
        # exactly one frame with all of its channels. A plain
        # reshape(b * t, c, h * w) on a (b, c, t, h, w) tensor would instead
        # interleave channel and time indices, mixing frames with channels.
        q = rearrange(q, "b c t h w -> (b t) (h w) c")
        k = rearrange(k, "b c t h w -> (b t) c (h w)")
        v = rearrange(v, "b c t h w -> (b t) c (h w)")

        # Scaled dot-product scores: w_[n, i, j] = sum_c q[n, i, c] * k[n, c, j]
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # Attend to values: h_[n, c, j] = sum_i v[n, c, i] * w_[n, j, i]
        h_ = torch.bmm(v, w_.permute(0, 2, 1))
        # Undo the per-frame flattening, restoring the (b, c, t, h, w) layout.
        h_ = rearrange(h_, "(b t) c (h w) -> b c t h w", b=b, t=t, h=h, w=w)
        h_ = self.proj_out(h_)
        return x + h_
class AttnBlock(Block):
    """Residual single-head self-attention over the spatial positions of a 4-D input.

    q, k, v and the output projection are 1x1 ``Conv2d`` layers. ``forward``
    is wrapped by ``video_to_image`` from ``.ops``.
    NOTE(review): that decorator presumably lets 5-D video tensors through
    by folding frames into the batch; confirm in ``.ops``.

    Args:
        in_channels: Number of channels of the input (and output).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        conv_kwargs = dict(kernel_size=1, stride=1, padding=0)
        # Registration order (q, k, v, proj_out) is kept unchanged.
        self.q = torch.nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.k = torch.nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.v = torch.nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, **conv_kwargs)

    @video_to_image
    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, height, width = query.shape
        tokens = height * width
        # Queries become (B, HW, C); keys and values stay (B, C, HW).
        query = query.reshape(batch, channels, tokens).permute(0, 2, 1)
        key = key.reshape(batch, channels, tokens)
        value = value.reshape(batch, channels, tokens)

        # scores[b, i, j] = sum_c query[b, i, c] * key[b, c, j], then scaled.
        scores = torch.bmm(query, key) * (int(channels) ** (-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # out[b, c, j] = sum_i value[b, c, i] * weights[b, j, i]
        out = torch.bmm(value, weights.permute(0, 2, 1))
        out = self.proj_out(out.reshape(batch, channels, height, width))
        return x + out
class TemporalAttnBlock(Block):
    """Residual single-head self-attention along the time axis.

    The input is unpacked as ``(batch, channels, time, height, width)``. Every
    spatial location is treated as an independent sequence of ``time`` tokens.
    q, k, v and the output projection are 1x1x1 ``Conv3d`` layers.
    NOTE(review): no mask is applied, so each time step attends to every
    other time step (past and future).

    Args:
        in_channels: Number of channels of the input (and output).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        conv_kwargs = dict(kernel_size=1, stride=1, padding=0)
        # Registration order (q, k, v, proj_out) is kept unchanged.
        self.q = torch.nn.Conv3d(in_channels, in_channels, **conv_kwargs)
        self.k = torch.nn.Conv3d(in_channels, in_channels, **conv_kwargs)
        self.v = torch.nn.Conv3d(in_channels, in_channels, **conv_kwargs)
        self.proj_out = torch.nn.Conv3d(in_channels, in_channels, **conv_kwargs)

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        _, channels, _, height, width = query.shape
        # Fold spatial positions into the batch: one time sequence per pixel.
        query = rearrange(query, "b c t h w -> (b h w) t c")
        key = rearrange(key, "b c t h w -> (b h w) c t")
        value = rearrange(value, "b c t h w -> (b h w) c t")

        scores = torch.bmm(query, key) * (int(channels) ** (-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        out = torch.bmm(value, weights.permute(0, 2, 1))
        out = rearrange(out, "(b h w) c t -> b c t h w", h=height, w=width)
        return x + self.proj_out(out)
def make_attn(in_channels, attn_type="vanilla"):
    """Build the attention block selected by ``attn_type``.

    Args:
        in_channels: Channel count passed to the chosen block.
        attn_type: One of ``"vanilla"`` (``AttnBlock``), ``"vanilla3D"``
            (``AttnBlock3D``), ``"linear"`` (``LinAttnBlock``) or ``"none"``
            (``nn.Identity``, which ignores ``in_channels``).

    Returns:
        The constructed block.

    Raises:
        ValueError: If ``attn_type`` is not one of the supported values.
    """
    # Validate with an explicit raise. An ``assert`` is stripped under
    # ``python -O``, after which unknown types fell through to LinAttnBlock.
    if attn_type not in ("vanilla", "linear", "none", "vanilla3D"):
        raise ValueError(f"attn_type {attn_type} unknown")
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    if attn_type == "vanilla3D":
        return AttnBlock3D(in_channels)
    if attn_type == "none":
        return nn.Identity(in_channels)
    # Only "linear" remains after validation.
    return LinAttnBlock(in_channels)