Fabrice-TIERCELIN committed on
Commit
49f01ea
·
verified ·
1 Parent(s): c6ac070

Upload test_attention.py

Browse files
Files changed (1) hide show
  1. tests/test_attention.py +180 -180
tests/test_attention.py CHANGED
@@ -1,180 +1,180 @@
1
- import torch
2
- import sys
3
- import os
4
- current_dir = os.path.dirname(os.path.abspath(__file__))
5
- project_root = os.path.dirname(current_dir)
6
- sys.path.append(project_root)
7
-
8
- from hyvideo.modules.attenion import attention
9
- from xfuser.core.long_ctx_attention import xFuserLongContextAttention
10
- from xfuser.core.distributed import (
11
- init_distributed_environment,
12
- initialize_model_parallel,
13
- # initialize_runtime_state,
14
- )
15
-
16
def init_dist(backend="nccl"):
    """Initialise the xfuser distributed state used by the attention tests.

    Reads ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` from the environment,
    binds this process to its local CUDA device and sets up a hybrid
    sequence-parallel layout (Ulysses x ring) that spans every rank.

    Args:
        backend: Kept for backward compatibility. It is currently unused:
            ``init_distributed_environment`` is called without it.

    Returns:
        The tuple ``(rank, world_size)`` parsed from the environment.

    Raises:
        KeyError: If one of the three environment variables is missing.
        ValueError: If one of them cannot be parsed as an integer.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    print(
        f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}"
    )

    torch.cuda.set_device(local_rank)
    init_distributed_environment(rank=rank, world_size=world_size)

    # Hybrid sequence-parallel config: ulysses=2, ring=world_size // 2.
    # That split only multiplies back to world_size when world_size is even;
    # for an odd world size (including 1) fall back to pure ring parallelism
    # so that ulysses_degree * ring_degree == sequence_parallel_degree.
    ulysses_degree = 2 if world_size % 2 == 0 else 1
    ring_degree = world_size // ulysses_degree
    initialize_model_parallel(
        sequence_parallel_degree=world_size,
        ring_degree=ring_degree,
        ulysses_degree=ulysses_degree,
    )

    return rank, world_size
43
-
44
def test_mm_double_stream_block_attention(rank, world_size):
    """Compare xfuser long-context attention with the reference ``attention``.

    Image and text q/k/v are created as separate
    ``(batch, seq, heads, head_dim)`` tensors and broadcast from rank 0 so
    every rank holds identical data. They are then passed
    (a) concatenated along the sequence axis to ``attention`` and
    (b) as main + joint tensors to ``xFuserLongContextAttention``.
    Both outputs must have the same shape and match within
    ``rtol=atol=1e-3``.

    Requires an initialised process group and one CUDA device per rank.

    Args:
        rank: Global rank of this process. Only used as the device index
            when ``LOCAL_RANK`` is not set in the environment.
        world_size: Number of participating processes. Unused here; kept so
            both tests share one signature.

    Raises:
        AssertionError: If the shapes or the values of the outputs differ.
    """
    # The global rank equals the local device index only on a single node,
    # so prefer the launcher-provided LOCAL_RANK.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    device = torch.device("cuda", local_rank)
    dtype = torch.bfloat16
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128
    total_len = seq_len_img + seq_len_txt  # 119056

    def rand_tokens(seq_len):
        # One random (batch, seq, heads, head_dim) tensor on this rank's GPU.
        return torch.randn(batch_size, seq_len, heads_num, head_dim, device=device, dtype=dtype)

    img_q, img_k, img_v = (rand_tokens(seq_len_img) for _ in range(3))
    txt_q, txt_k, txt_v = (rand_tokens(seq_len_txt) for _ in range(3))

    with torch.no_grad():
        # Make every rank use rank 0's random data.
        for tensor in (img_q, img_k, img_v, txt_q, txt_k, txt_v):
            torch.distributed.broadcast(tensor, src=0)

        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)

        # Keep the sequence offsets on the same device as q/k/v instead of a
        # hard-coded 'cuda:0', which is the wrong GPU on every other rank.
        # NOTE(review): 118811 looks like seq_len_img + 11 valid text tokens - confirm.
        cu_seqlens_q = torch.tensor([0, 118811, total_len], device=device, dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, total_len], device=device, dtype=torch.int32)
        max_seqlen_q = total_len
        max_seqlen_kv = total_len
        mode = "torch"  # "torch", "vanilla", "flash"

        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size,
        )

        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            img_q,
            img_k,
            img_v,
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=txt_q,
            joint_tensor_key=txt_k,
            joint_tensor_value=txt_v,
            joint_strategy="rear",
        )

        # Merge the last two axes so the layout matches the reference output.
        b, s, _, _ = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

        assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"

        torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
        print("test_mm_double_stream_block_attention Passed")
111
-
112
def test_mm_single_stream_block_attention(rank, world_size):
    """Compare xfuser long-context attention with ``attention`` on a fused v.

    q and k are built from separate image and text tensors, while v is one
    tensor covering the whole (image + text) sequence. All inputs are
    broadcast from rank 0 so every rank holds identical data. The reference
    ``attention`` gets the full sequences; ``xFuserLongContextAttention``
    gets the image part as main tensors and the trailing text part as joint
    tensors. Both outputs must have the same shape and match within
    ``rtol=atol=1e-3``.

    Requires an initialised process group and one CUDA device per rank.

    Args:
        rank: Global rank of this process. Only used as the device index
            when ``LOCAL_RANK`` is not set in the environment.
        world_size: Number of participating processes. Unused here; kept so
            both tests share one signature.

    Raises:
        AssertionError: If the shapes or the values of the outputs differ.
    """
    # The global rank equals the local device index only on a single node,
    # so prefer the launcher-provided LOCAL_RANK.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    device = torch.device("cuda", local_rank)
    dtype = torch.bfloat16
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128
    total_len = seq_len_img + seq_len_txt  # 119056

    def rand_tokens(seq_len):
        # One random (batch, seq, heads, head_dim) tensor on this rank's GPU.
        return torch.randn(batch_size, seq_len, heads_num, head_dim, device=device, dtype=dtype)

    with torch.no_grad():
        img_q = rand_tokens(seq_len_img)
        img_k = rand_tokens(seq_len_img)
        txt_q = rand_tokens(seq_len_txt)
        txt_k = rand_tokens(seq_len_txt)
        v = rand_tokens(total_len)

        # Make every rank use rank 0's random data.
        for tensor in (img_q, img_k, txt_q, txt_k, v):
            torch.distributed.broadcast(tensor, src=0)

        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)

        # Keep the sequence offsets on the same device as q/k/v instead of a
        # hard-coded 'cuda:0', which is the wrong GPU on every other rank.
        # NOTE(review): 118811 looks like seq_len_img + 11 valid text tokens - confirm.
        cu_seqlens_q = torch.tensor([0, 118811, total_len], device=device, dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, total_len], device=device, dtype=torch.int32)
        max_seqlen_q = total_len
        max_seqlen_kv = total_len
        mode = "torch"  # "torch", "vanilla", "flash"

        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size,
        )

        # The text tokens sit at the end of the sequence: everything before
        # them is the main input, the tail is passed as the joint tensors.
        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            q[:, :-seq_len_txt, :, :],
            k[:, :-seq_len_txt, :, :],
            v[:, :-seq_len_txt, :, :],
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=q[:, -seq_len_txt:, :, :],
            joint_tensor_key=k[:, -seq_len_txt:, :, :],
            joint_tensor_value=v[:, -seq_len_txt:, :, :],
            joint_strategy="rear",
        )

        # Merge the last two axes so the layout matches the reference output.
        b, s, _, _ = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

        assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"

        torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
        print("test_mm_single_stream_block_attention Passed")
176
-
177
if __name__ == "__main__":
    # Expects RANK, LOCAL_RANK and WORLD_SIZE in the environment (read by init_dist).
    rank, world_size = init_dist()
    try:
        test_mm_double_stream_block_attention(rank, world_size)
        test_mm_single_stream_block_attention(rank, world_size)
    finally:
        # Release the process group even when a test fails so that the
        # workers shut down cleanly instead of leaking communicators.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
 
1
+ import torch
2
+ import sys
3
+ import os
4
+ current_dir = os.path.dirname(os.path.abspath(__file__))
5
+ project_root = os.path.dirname(current_dir)
6
+ sys.path.append(project_root)
7
+
8
+ from hyvideo.modules.attenion import attention
9
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
10
+ from xfuser.core.distributed import (
11
+ init_distributed_environment,
12
+ initialize_model_parallel,
13
+ # initialize_runtime_state,
14
+ )
15
+
16
def init_dist(backend="nccl"):
    """Initialise the xfuser distributed state used by the attention tests.

    Reads ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` from the environment,
    binds this process to its local CUDA device and sets up a hybrid
    sequence-parallel layout (Ulysses x ring) that spans every rank.

    Args:
        backend: Kept for backward compatibility. It is currently unused:
            ``init_distributed_environment`` is called without it.

    Returns:
        The tuple ``(rank, world_size)`` parsed from the environment.

    Raises:
        KeyError: If one of the three environment variables is missing.
        ValueError: If one of them cannot be parsed as an integer.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    print(
        f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}"
    )

    torch.cuda.set_device(local_rank)
    init_distributed_environment(rank=rank, world_size=world_size)

    # Hybrid sequence-parallel config: ulysses=2, ring=world_size // 2.
    # That split only multiplies back to world_size when world_size is even;
    # for an odd world size (including 1) fall back to pure ring parallelism
    # so that ulysses_degree * ring_degree == sequence_parallel_degree.
    ulysses_degree = 2 if world_size % 2 == 0 else 1
    ring_degree = world_size // ulysses_degree
    initialize_model_parallel(
        sequence_parallel_degree=world_size,
        ring_degree=ring_degree,
        ulysses_degree=ulysses_degree,
    )

    return rank, world_size
43
+
44
def test_mm_double_stream_block_attention(rank, world_size):
    """Compare xfuser long-context attention with the reference ``attention``.

    Image and text q/k/v are created as separate
    ``(batch, seq, heads, head_dim)`` tensors and broadcast from rank 0 so
    every rank holds identical data. They are then passed
    (a) concatenated along the sequence axis to ``attention`` and
    (b) as main + joint tensors to ``xFuserLongContextAttention``.
    Both outputs must have the same shape and match within
    ``rtol=atol=1e-3``.

    Requires an initialised process group and one CUDA device per rank.

    Args:
        rank: Global rank of this process. Only used as the device index
            when ``LOCAL_RANK`` is not set in the environment.
        world_size: Number of participating processes. Unused here; kept so
            both tests share one signature.

    Raises:
        AssertionError: If the shapes or the values of the outputs differ.
    """
    # The global rank equals the local device index only on a single node,
    # so prefer the launcher-provided LOCAL_RANK.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    device = torch.device("cuda", local_rank)
    dtype = torch.bfloat16
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128
    total_len = seq_len_img + seq_len_txt  # 119056

    def rand_tokens(seq_len):
        # One random (batch, seq, heads, head_dim) tensor on this rank's GPU.
        return torch.randn(batch_size, seq_len, heads_num, head_dim, device=device, dtype=dtype)

    img_q, img_k, img_v = (rand_tokens(seq_len_img) for _ in range(3))
    txt_q, txt_k, txt_v = (rand_tokens(seq_len_txt) for _ in range(3))

    with torch.no_grad():
        # Make every rank use rank 0's random data.
        for tensor in (img_q, img_k, img_v, txt_q, txt_k, txt_v):
            torch.distributed.broadcast(tensor, src=0)

        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)

        # Keep the sequence offsets on the same device as q/k/v instead of a
        # hard-coded 'cuda:0', which is the wrong GPU on every other rank.
        # NOTE(review): 118811 looks like seq_len_img + 11 valid text tokens - confirm.
        cu_seqlens_q = torch.tensor([0, 118811, total_len], device=device, dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, total_len], device=device, dtype=torch.int32)
        max_seqlen_q = total_len
        max_seqlen_kv = total_len
        mode = "torch"  # "torch", "vanilla", "flash"

        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size,
        )

        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            img_q,
            img_k,
            img_v,
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=txt_q,
            joint_tensor_key=txt_k,
            joint_tensor_value=txt_v,
            joint_strategy="rear",
        )

        # Merge the last two axes so the layout matches the reference output.
        b, s, _, _ = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

        assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"

        torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
        print("test_mm_double_stream_block_attention Passed")
111
+
112
def test_mm_single_stream_block_attention(rank, world_size):
    """Compare xfuser long-context attention with ``attention`` on a fused v.

    q and k are built from separate image and text tensors, while v is one
    tensor covering the whole (image + text) sequence. All inputs are
    broadcast from rank 0 so every rank holds identical data. The reference
    ``attention`` gets the full sequences; ``xFuserLongContextAttention``
    gets the image part as main tensors and the trailing text part as joint
    tensors. Both outputs must have the same shape and match within
    ``rtol=atol=1e-3``.

    Requires an initialised process group and one CUDA device per rank.

    Args:
        rank: Global rank of this process. Only used as the device index
            when ``LOCAL_RANK`` is not set in the environment.
        world_size: Number of participating processes. Unused here; kept so
            both tests share one signature.

    Raises:
        AssertionError: If the shapes or the values of the outputs differ.
    """
    # The global rank equals the local device index only on a single node,
    # so prefer the launcher-provided LOCAL_RANK.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    device = torch.device("cuda", local_rank)
    dtype = torch.bfloat16
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128
    total_len = seq_len_img + seq_len_txt  # 119056

    def rand_tokens(seq_len):
        # One random (batch, seq, heads, head_dim) tensor on this rank's GPU.
        return torch.randn(batch_size, seq_len, heads_num, head_dim, device=device, dtype=dtype)

    with torch.no_grad():
        img_q = rand_tokens(seq_len_img)
        img_k = rand_tokens(seq_len_img)
        txt_q = rand_tokens(seq_len_txt)
        txt_k = rand_tokens(seq_len_txt)
        v = rand_tokens(total_len)

        # Make every rank use rank 0's random data.
        for tensor in (img_q, img_k, txt_q, txt_k, v):
            torch.distributed.broadcast(tensor, src=0)

        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)

        # Keep the sequence offsets on the same device as q/k/v instead of a
        # hard-coded 'cuda:0', which is the wrong GPU on every other rank.
        # NOTE(review): 118811 looks like seq_len_img + 11 valid text tokens - confirm.
        cu_seqlens_q = torch.tensor([0, 118811, total_len], device=device, dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, total_len], device=device, dtype=torch.int32)
        max_seqlen_q = total_len
        max_seqlen_kv = total_len
        mode = "torch"  # "torch", "vanilla", "flash"

        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size,
        )

        # The text tokens sit at the end of the sequence: everything before
        # them is the main input, the tail is passed as the joint tensors.
        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            q[:, :-seq_len_txt, :, :],
            k[:, :-seq_len_txt, :, :],
            v[:, :-seq_len_txt, :, :],
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=q[:, -seq_len_txt:, :, :],
            joint_tensor_key=k[:, -seq_len_txt:, :, :],
            joint_tensor_value=v[:, -seq_len_txt:, :, :],
            joint_strategy="rear",
        )

        # Merge the last two axes so the layout matches the reference output.
        b, s, _, _ = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

        assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"

        torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
        print("test_mm_single_stream_block_attention Passed")
176
+
177
if __name__ == "__main__":
    # Expects RANK, LOCAL_RANK and WORLD_SIZE in the environment (read by init_dist).
    rank, world_size = init_dist()
    try:
        test_mm_double_stream_block_attention(rank, world_size)
        test_mm_single_stream_block_attention(rank, world_size)
    finally:
        # Release the process group even when a test fails so that the
        # workers shut down cleanly instead of leaking communicators.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()