Huage001 committed
Commit 202bdbf
Parent: db87f0c

Create src/linfusion/attention.py

Files changed (1)
  src/linfusion/attention.py +94 -0
src/linfusion/attention.py ADDED
@@ -0,0 +1,94 @@
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention

try:
    from fla.ops.linear_attn import chunk_linear_attn

    FLA_ENABLE = True
except ImportError:
    print("Warning: FLA is not installed; falling back to the plain PyTorch linear-attention path.")
    FLA_ENABLE = False


def get_none_linear_projection(query_dim, mid_dim=None):
    # If mid_dim is None, the hidden width defaults to query_dim.
    # If mid_dim is -1, no non-linear projection is used and an identity module is returned.
    return (
        torch.nn.Sequential(
            torch.nn.Linear(query_dim, mid_dim or query_dim),
            torch.nn.LayerNorm(mid_dim or query_dim),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(mid_dim or query_dim, query_dim),
        )
        if mid_dim != -1
        else torch.nn.Identity()
    )


class GeneralizedLinearAttention(Attention):
    def __init__(self, *args, projection_mid_dim=None, **kwargs):
        """
        Args:
            query_dim: the dimension of the query.
            out_dim: the dimension of the output.
            dim_head: the dimension of each head (dim_head * num_heads = query_dim).
            projection_mid_dim: the dimension of the intermediate layer in the
                non-linear projection.
                If `None`, the dimension is the same as the query dimension.
                If `-1`, no non-linear projection is used and the identity is returned.
        """
        super().__init__(*args, **kwargs)
        self.add_non_linear_model(projection_mid_dim)

    @staticmethod
    def from_attention_instance(attention_instance, projection_mid_dim=None):
        assert isinstance(attention_instance, Attention)
        # The constructor argument (query_dim=128) is a placeholder; the state of the
        # source attention module is copied over immediately below.
        new_instance = GeneralizedLinearAttention(128)
        new_instance.__dict__ = attention_instance.__dict__
        new_instance.add_non_linear_model(mid_dim=projection_mid_dim)
        return new_instance

    def add_non_linear_model(self, mid_dim=None, **kwargs):
        query_dim = self.to_q.weight.shape[0]
        self.to_q_ = get_none_linear_projection(query_dim, mid_dim, **kwargs)
        self.to_k_ = get_none_linear_projection(query_dim, mid_dim, **kwargs)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        **kwargs,
    ):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        _, sequence_length, _ = hidden_states.shape

        # Residual non-linear projections applied before the usual linear Q/K projections.
        query = self.to_q(hidden_states + self.to_q_(hidden_states))
        key = self.to_k(encoder_hidden_states + self.to_k_(encoder_hidden_states))
        value = self.to_v(encoder_hidden_states)

        query = self.head_to_batch_dim(query)
        key = self.head_to_batch_dim(key)
        value = self.head_to_batch_dim(value)

        # Non-negative feature map so the attention weights stay positive.
        query = F.elu(query) + 1.0
        key = F.elu(key) + 1.0

        if FLA_ENABLE and False:
            # The fused FLA kernel is disabled until a bug in its integration is fixed.
            raise NotImplementedError
        else:
            # Normalizer: q . mean(k) over the sequence, with a small epsilon for stability.
            z = query @ key.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-4
            # Key-value summary; the two 1/sqrt(sequence_length) factors together scale k^T v
            # by 1/sequence_length.
            kv = (key.transpose(-2, -1) * (sequence_length**-0.5)) @ (
                value * (sequence_length**-0.5)
            )
            hidden_states = query @ kv / z

        hidden_states = self.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        return hidden_states
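Usage note (not part of the commit): a minimal sketch of wrapping an existing diffusers Attention module via from_attention_instance. It assumes `diffusers` is installed; the dimensions and the import path `linfusion.attention` are illustrative assumptions, not taken from the repository.

import torch
from diffusers.models.attention_processor import Attention
from linfusion.attention import GeneralizedLinearAttention  # assumed import path

# Build a plain attention block, then convert it to generalized linear attention.
attn = Attention(query_dim=320, heads=8, dim_head=40)
linear_attn = GeneralizedLinearAttention.from_attention_instance(attn)

hidden_states = torch.randn(2, 64, 320)   # (batch, sequence, query_dim)
out = linear_attn(hidden_states)          # self-attention; out.shape == (2, 64, 320)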
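For intuition, the fallback branch computes normalized linear attention: each output row is approximately (sum_j (q_i . k_j) v_j) / (sum_j q_i . k_j), up to the 1e-4 stabilizer, obtained in O(N * d^2) by associating k^T v first instead of forming the N x N score matrix. A quick numerical sanity check of that equivalence (again not part of the commit, self-attention case only):

import torch

torch.manual_seed(0)
batch, seq_len, dim = 2, 16, 8
query = torch.nn.functional.elu(torch.randn(batch, seq_len, dim)) + 1.0
key = torch.nn.functional.elu(torch.randn(batch, seq_len, dim)) + 1.0
value = torch.randn(batch, seq_len, dim)

# Linear form used in forward(): O(seq_len * dim^2)
z = query @ key.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-4
kv = (key.transpose(-2, -1) * seq_len**-0.5) @ (value * seq_len**-0.5)
out_linear = query @ kv / z

# Equivalent quadratic form: row-normalized (query @ key^T) applied to value, O(seq_len^2 * dim)
scores = query @ key.transpose(-2, -1)  # non-negative thanks to elu(x) + 1
out_quadratic = (scores / (scores.sum(-1, keepdim=True) + seq_len * 1e-4)) @ value

print(torch.allclose(out_linear, out_quadratic, atol=1e-5))  # expected: True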