'''
This file is modified from the TiledVAE attn.py so that StableSR can save a significant amount of VRAM.
'''
import math
import torch
from modules import shared, sd_hijack
from modules.sd_hijack_optimizations import get_available_vram, get_xformers_flash_attention_op, sub_quad_attention

try:
    import xformers
    import xformers.ops
except ImportError:
    pass


def get_attn_func():
    method = sd_hijack.model_hijack.optimization_method
    if method is None:
        return attn_forward
    method = method.lower()
    # The method should be one of the following:
    # ['none', 'sdp-no-mem', 'sdp', 'xformers', 'sub-quadratic', 'v1', 'invokeai', 'doggettx']
    if method not in ['none', 'sdp-no-mem', 'sdp', 'xformers', 'sub-quadratic', 'v1', 'invokeai', 'doggettx']:
        print(f"[StableSR] Warning: Unknown attention optimization method {method}. Please try to update the extension.")
        return attn_forward
    if method == 'none':
        return attn_forward
    elif method == 'xformers':
        return xformers_attnblock_forward
    elif method == 'sdp-no-mem':
        return sdp_no_mem_attnblock_forward
    elif method == 'sdp':
        return sdp_attnblock_forward
    elif method == 'sub-quadratic':
        return sub_quad_attnblock_forward
    elif method == 'doggettx':
        return cross_attention_attnblock_forward
    return attn_forward
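
# Usage sketch (illustrative only, not executed here; `block` is a hypothetical
# AttnBlock-like module with norm/q/k/v/proj_out layers and is not defined in
# this file). The selected kernel expects already-projected q, k, v of shape
# (b, hw, c); normalization and the residual are the caller's responsibility:
#
#   attn_func = get_attn_func()
#   h = block.norm(x)
#   q, k, v = block.q(h), block.k(h), block.v(h)
#   b, c, height, width = q.shape
#   q = q.reshape(b, c, height * width).permute(0, 2, 1)   # b, hw, c
#   k = k.reshape(b, c, height * width).permute(0, 2, 1)
#   v = v.reshape(b, c, height * width).permute(0, 2, 1)
#   out = attn_func(q, k, v)                                # b, hw, c
#   out = out.permute(0, 2, 1).reshape(b, c, height, width)
#   out = x + block.proj_out(out)                           # residual added by the caller
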
# The following functions are all copied from modules.sd_hijack_optimizations
# However, the residual & normalization are removed and computed separately.
def attn_forward(q, k, v):
    # compute attention
    # q: b,hw,c
    k = k.permute(0, 2, 1)  # b,c,hw
    c = k.shape[1]
    w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
    w_ = w_ * (int(c)**(-0.5))
    w_ = torch.nn.functional.softmax(w_, dim=2)

    # attend to values
    v = v.permute(0, 2, 1)  # b,c,hw
    w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
    # b,c,hw (hw of q)    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
    h_ = torch.bmm(v, w_)
    return h_.permute(0, 2, 1)
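
# Sanity-check sketch (illustrative, assuming small dummy tensors on CPU): for
# inputs of shape (b, hw, c), the naive kernel above should match PyTorch's
# fused scaled_dot_product_attention up to floating-point tolerance.
#
#   q, k, v = (torch.randn(1, 64, 32) for _ in range(3))
#   ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
#   assert torch.allclose(attn_forward(q, k, v), ref, atol=1e-4)
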
def xformers_attnblock_forward(q, k, v):
    return xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))


def cross_attention_attnblock_forward(q, k, v):
    # compute attention
    k = k.permute(0, 2, 1)  # b,c,hw
    v = v.permute(0, 2, 1)  # b,c,hw
    c = k.shape[1]
    h_ = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        del w3

        # b,c,hw (hw of q)    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_[:, :, i:end] = torch.bmm(v, w4)
        del w4

    return h_.permute(0, 2, 1)
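
# Worked example of the slicing heuristic above (illustrative numbers): for a
# 64x64 latent, hw = 4096, so the fp16 attention matrix takes
# 1 * 4096 * 4096 * 2 bytes = 32 MiB and mem_required = 80 MiB after the 2.5x
# safety factor. With only 50 MiB of VRAM free,
# steps = 2**ceil(log2(80 / 50)) = 2 and slice_size = 4096 // 2 = 2048 query
# rows per chunk. If hw is not divisible by steps, the code falls back to a
# single full-size slice.
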
def sdp_no_mem_attnblock_forward(q, k, v):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return sdp_attnblock_forward(q, k, v)


def sdp_attnblock_forward(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)


def sub_quad_attnblock_forward(q, k, v):
    return sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=True)