Spaces:
Paused
Paused
# CrossAttn precision handling | |
import os | |
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") | |
import torch | |
from torch import nn | |
from torch import einsum | |
from einops import rearrange, repeat | |
import torch | |
from torch import nn | |
from typing import Optional, Any | |
from ...patches import router | |
class CrossAttention(nn.Module):
    """Multi-head (cross-)attention layer.

    Queries are projected from ``x``; keys and values are projected from
    ``context`` (or from ``x`` itself when no context is given). Scores are
    scaled by ``dim_head ** -0.5``, softmax-normalised per query, used to mix
    the values, and the concatenated heads are projected back to
    ``query_dim`` followed by dropout.

    Args:
        query_dim: Size of the last dimension of ``x``.
        context_dim: Size of the last dimension of ``context``; falls back to
            ``query_dim`` when falsy.
        heads: Number of attention heads.
        dim_head: Per-head feature size (inner width is ``heads * dim_head``).
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = context_dim or query_dim
        self.scale = dim_head**-0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (self-attention when it is None).

        Args:
            x: Query input of shape ``(b, n, query_dim)`` (the ``"b n (h d)"``
                rearrange below requires a 3-D projection).
            context: Key/value input of shape ``(b, m, context_dim)``, or None
                to attend over ``x`` itself.
            mask: Optional mask with a leading batch dimension; all other
                dimensions are flattened to the key length. Positions where
                the mask is False are excluded from attention. It must be a
                boolean tensor, since it is inverted with ``~`` and passed to
                ``masked_fill_``.

        Returns:
            Tensor of shape ``(b, n, query_dim)``.
        """
        h = self.heads
        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)
        # Split the inner dimension into heads and fold the heads into the batch.
        q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))

        if _ATTN_PRECISION == "fp32":
            # Force the score computation to fp32 (with CUDA autocast disabled)
            # to avoid overflowing in half precision.
            with torch.autocast(enabled=False, device_type="cuda"):
                q, k = q.float(), k.float()
                sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        else:
            sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        del q, k  # release the query/key tensors before the softmax

        if mask is not None:
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            # Broadcast the per-batch mask across heads and query positions.
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # The softmax runs in sim's precision (fp32 on the fp32 path). The
        # probabilities are then cast to v's dtype so the value product does
        # not mix fp32 and fp16 operands when no autocast context is active.
        # einsum raises on mismatched dtypes.
        sim = sim.softmax(dim=-1)
        out = einsum("b i j, b j d -> b i d", sim.to(v.dtype), v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
class PatchedCrossAttention(nn.Module):
    """Cross-attention layer whose forward computation is supplied by ``router``.

    The constructor only builds the parameters: bias-free query/key/value
    projections to ``heads * dim_head`` features, and an output projection
    back to ``query_dim`` followed by dropout. ``forward`` hands the module
    and its inputs to ``router.attention_forward``.
    """

    # Adapted from:
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        # Plain (non-module) attributes.
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head**-0.5
        # Starts as None; NOTE(review): presumably read/set by the router's
        # attention implementation — confirm.
        self.attention_op: Optional[Any] = None

        projected = heads * dim_head
        kv_in = query_dim if not context_dim else context_dim

        # Submodules are registered in the order q, k, v, out, so the
        # parameter/state_dict ordering stays stable.
        self.to_q = nn.Linear(query_dim, projected, bias=False)
        self.to_k = nn.Linear(kv_in, projected, bias=False)
        self.to_v = nn.Linear(kv_in, projected, bias=False)
        out_layers = [nn.Linear(projected, query_dim), nn.Dropout(dropout)]
        self.to_out = nn.Sequential(*out_layers)

    def forward(self, x, context=None, mask=None):
        """Delegate to ``router.attention_forward(self, x, context, mask)``.

        The arguments are passed through positionally and unchanged.
        NOTE(review): the meaning of ``context``/``mask`` is defined by the
        router implementation — confirm there.
        """
        return router.attention_forward(self, x, context, mask)