Xidong commited on
Commit
21bc090
·
verified ·
1 Parent(s): 673755f

Upload ./cache_autogptq_cuda_kernel_256.cu with huggingface_hub

Browse files
Files changed (1) hide show
  1. cache_autogptq_cuda_kernel_256.cu +1708 -0
cache_autogptq_cuda_kernel_256.cu ADDED
@@ -0,0 +1,1708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #define _CRT_SECURE_NO_WARNINGS
2
+ #include <torch/all.h>
3
+ #include <torch/python.h>
4
+ #include <cuda.h>
5
+ #include <cuda_runtime.h>
6
+ #include <cuda_fp16.h>
7
+ #include <stdint.h>
8
+
9
+ #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
10
+ // adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu
11
+ __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
12
+ unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
13
+ unsigned int old = *address_as_ui;
14
+ unsigned int assumed;
15
+
16
+ do {
17
+ assumed = old;
18
+ unsigned short hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
19
+ hsum += val;
20
+ old = reinterpret_cast<size_t>(address) & 2
21
+ ? (old & 0xffff) | (hsum << 16)
22
+ : (old & 0xffff0000) | hsum;
23
+ old = atomicCAS(address_as_ui, assumed, old);
24
+
25
+ // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
26
+ } while (assumed != old);
27
+ }
28
+ __device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) {
29
+ unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
30
+ unsigned int old = *address_as_ui;
31
+ unsigned int assumed;
32
+
33
+ do {
34
+ assumed = old;
35
+ __half_raw hsum;
36
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
37
+ half tmpres = __hadd(hsum, val);
38
+ hsum = __half_raw(tmpres);
39
+ old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
40
+ old = atomicCAS(address_as_ui, assumed, old);
41
+ } while (assumed != old);
42
+ }
43
+ #endif
44
+
45
+ template <typename scalar_t>
46
+ __global__ void VecQuant8MatMulKernel(
47
+ const scalar_t* __restrict__ vec,
48
+ const int* __restrict__ mat,
49
+ scalar_t* __restrict__ mul,
50
+ const scalar_t* __restrict__ scales,
51
+ const int* __restrict__ zeros,
52
+ const int* __restrict__ g_idx,
53
+ int batch,
54
+ int vec_height,
55
+ int height,
56
+ int width,
57
+ int zero_width
58
+ );
59
+
60
+ template <typename scalar_t>
61
+ __global__ void VecQuant8BatchMatMulColumnCompressionKernel(
62
+ const scalar_t* __restrict__ vec,
63
+ const int* __restrict__ mat,
64
+ scalar_t* __restrict__ mul,
65
+ const scalar_t* __restrict__ scales,
66
+ const int* __restrict__ zeros,
67
+ int batch,
68
+ int heads,
69
+ int vec_row,
70
+ int height,
71
+ int width
72
+ );
73
+
74
+ template <typename scalar_t>
75
+ __global__ void VecQuant4BatchMatMulColumnCompressionKernel(
76
+ const scalar_t* __restrict__ vec,
77
+ const int* __restrict__ mat,
78
+ scalar_t* __restrict__ mul,
79
+ const scalar_t* __restrict__ scales,
80
+ const int* __restrict__ zeros,
81
+ int batch,
82
+ int heads,
83
+ int vec_row,
84
+ int height,
85
+ int width
86
+ );
87
+
88
+ template <typename scalar_t>
89
+ __global__ void VecQuant8BatchMatMulKernel(
90
+ const scalar_t* __restrict__ vec,
91
+ const int* __restrict__ mat,
92
+ scalar_t* __restrict__ mul,
93
+ const scalar_t* __restrict__ scales,
94
+ const int* __restrict__ zeros,
95
+ int batch,
96
+ int heads,
97
+ int vec_row,
98
+ int vec_height,
99
+ int height,
100
+ int width,
101
+ int zero_width
102
+ );
103
+
104
+ template <typename scalar_t>
105
+ __global__ void VecQuant4BatchMatMulKernel(
106
+ const scalar_t* __restrict__ vec,
107
+ const int* __restrict__ mat,
108
+ scalar_t* __restrict__ mul,
109
+ const scalar_t* __restrict__ scales,
110
+ const int* __restrict__ zeros,
111
+ int batch,
112
+ int heads,
113
+ int vec_row,
114
+ int vec_height,
115
+ int height,
116
+ int width,
117
+ int zero_width
118
+ );
119
+
120
+
121
+
122
+ template <typename scalar_t>
123
+ __global__ void VecQuant8BatchMatMulKernel_old(
124
+ const scalar_t* __restrict__ vec,
125
+ const uint8_t* __restrict__ mat,
126
+ scalar_t* __restrict__ mul,
127
+ const scalar_t* __restrict__ scales,
128
+ const scalar_t* __restrict__ zeros,
129
+ int batch,
130
+ int heads,
131
+ int vec_row,
132
+ int vec_height,
133
+ int height,
134
+ int width,
135
+ int zero_width
136
+ );
137
+
138
+ __global__ void VecQuant8BatchMatMulKernel_faster(
139
+ const half* __restrict__ vec,
140
+ const uint8_t* __restrict__ mat,
141
+ half* __restrict__ mul,
142
+ const half* __restrict__ scales,
143
+ const half* __restrict__ zeros,
144
+ int batch,
145
+ int heads,
146
+ int vec_row,
147
+ int vec_height,
148
+ int height,
149
+ int width,
150
+ int zero_width
151
+ );
152
+
153
+
154
+
155
+ __global__ void VecQuant8BatchMatMulKernel_faster_old(
156
+ const half* __restrict__ vec,
157
+ const uint8_t* __restrict__ mat,
158
+ half* __restrict__ mul,
159
+ const half* __restrict__ scales,
160
+ const half* __restrict__ zeros,
161
+ int batch,
162
+ int heads,
163
+ int vec_row,
164
+ int vec_height,
165
+ int height,
166
+ int width
167
+ );
168
+
169
+
170
+ template <typename scalar_t>
171
+ __global__ void VecQuant4BatchMatMulKernel_old(
172
+ const scalar_t* __restrict__ vec,
173
+ const uint8_t* __restrict__ mat,
174
+ scalar_t* __restrict__ mul,
175
+ const scalar_t* __restrict__ scales,
176
+ const scalar_t* __restrict__ zeros,
177
+ int batch,
178
+ int heads,
179
+ int vec_row,
180
+ int vec_height,
181
+ int height,
182
+ int width,
183
+ int zero_width
184
+ );
185
+
186
+
187
+ template <typename scalar_t>
188
+ __global__ void VecQuant8BatchMatMulColumnCompressionKernel_old(
189
+ const scalar_t* __restrict__ vec,
190
+ const uint8_t* __restrict__ mat,
191
+ scalar_t* __restrict__ mul,
192
+ const scalar_t* __restrict__ scales,
193
+ const scalar_t* __restrict__ zeros,
194
+ int batch,
195
+ int heads,
196
+ int vec_row,
197
+ int height,
198
+ int width
199
+ );
200
+
201
+ __global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster(
202
+ const half* __restrict__ vec,
203
+ const uint8_t* __restrict__ mat,
204
+ half* __restrict__ mul,
205
+ const half* __restrict__ scales,
206
+ const half* __restrict__ zeros,
207
+ int batch,
208
+ int heads,
209
+ int vec_row,
210
+ int height,
211
+ int width
212
+ );
213
+
214
+ __global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old(
215
+ const half* __restrict__ vec,
216
+ const uint8_t* __restrict__ mat,
217
+ half* __restrict__ mul,
218
+ const half* __restrict__ scales,
219
+ const half* __restrict__ zeros,
220
+ int batch,
221
+ int heads,
222
+ int vec_row,
223
+ int height,
224
+ int width
225
+ );
226
+
227
+
228
+ template <typename scalar_t>
229
+ __global__ void VecQuant4BatchMatMulColumnCompressionKernel_old(
230
+ const scalar_t* __restrict__ vec,
231
+ const uint8_t* __restrict__ mat,
232
+ scalar_t* __restrict__ mul,
233
+ const scalar_t* __restrict__ scales,
234
+ const scalar_t* __restrict__ zeros,
235
+ int batch,
236
+ int heads,
237
+ int vec_row,
238
+ int height,
239
+ int width
240
+ );
241
+
242
+
243
+ __global__ void VecQuant8BatchMatMulKernel_faster(
244
+ const half* __restrict__ vec,
245
+ const uint8_t* __restrict__ mat,
246
+ half* __restrict__ mul,
247
+ const half* __restrict__ scales,
248
+ const half* __restrict__ zeros,
249
+ int batch,
250
+ int heads,
251
+ int vec_row,
252
+ int vec_height,
253
+ int height,
254
+ int width
255
+ );
256
+
257
+
258
+ __global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster(
259
+ const half* __restrict__ vec,
260
+ const uint8_t* __restrict__ mat,
261
+ half* __restrict__ mul,
262
+ const half* __restrict__ scales,
263
+ const half* __restrict__ zeros,
264
+ int batch,
265
+ int heads,
266
+ int vec_row,
267
+ int height,
268
+ int width
269
+ );
270
+
271
+ const int BLOCKWIDTH = 128;
272
+ const int BLOCKHEIGHT8 = 32;
273
+ const int BLOCKHEIGHT4 = 16;
274
+ const int BLOCKHEIGHT_OLD4 = 128;
275
+ //const int BLOCKHEIGHT_OLD8 = 128;
276
+
277
+ __device__ inline unsigned int as_unsigned(int i) {
278
+ return *reinterpret_cast<unsigned int*>(&i);
279
+ }
280
+
281
+ __device__ inline int as_int(int i) {
282
+ return *reinterpret_cast<int*>(&i);
283
+ }
284
+
285
+ void vecquant8matmul_batched_column_compression_cuda(
286
+ torch::Tensor vec,
287
+ torch::Tensor mat,
288
+ torch::Tensor mul,
289
+ torch::Tensor scales,
290
+ torch::Tensor zeros
291
+ ) {
292
+ int batch = vec.size(0);
293
+ int heads = vec.size(1);
294
+ int vec_row = vec.size(2);
295
+ int height = vec.size(3);
296
+ int width = mat.size(3) * 4;
297
+
298
+ dim3 blocks(
299
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
300
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
301
+ );
302
+ dim3 threads(BLOCKWIDTH);
303
+
304
+ AT_DISPATCH_FLOATING_TYPES(
305
+ vec.type(), "vecquant8matmul_batched_cuda", ([&] {
306
+ VecQuant8BatchMatMulColumnCompressionKernel<<<blocks, threads>>>(
307
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
308
+ scales.data<scalar_t>(), zeros.data<int>(),
309
+ batch, heads, vec_row, height, width
310
+ );
311
+ })
312
+ );
313
+
314
+ }
315
+
316
+ template <typename scalar_t>
317
+ __global__ void VecQuant8BatchMatMulColumnCompressionKernel(
318
+ const scalar_t* __restrict__ vec,
319
+ const int* __restrict__ mat,
320
+ scalar_t* __restrict__ mul,
321
+ const scalar_t* __restrict__ scales,
322
+ const int* __restrict__ zeros,
323
+ int batch,
324
+ int heads,
325
+ int vec_row,
326
+ int height,
327
+ int width
328
+ ) {
329
+ int weight_total = batch * heads * height * width / 4;
330
+ int input_total = batch * heads * vec_row * height;
331
+ int out_total = batch * heads * vec_row * width;
332
+ int tid = threadIdx.x;
333
+ // h is index of height with step being BLOCKWIDTH
334
+ int h = BLOCKWIDTH * blockIdx.x;
335
+ // w is index of width with step being 1
336
+ int w = BLOCKWIDTH * blockIdx.y + tid;
337
+ if (w >= width && tid >= height) {
338
+ return;
339
+ }
340
+
341
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
342
+ int k;
343
+ scalar_t w_tmp;
344
+
345
+ float weight[BLOCKWIDTH];
346
+
347
+ for (int b = 0; b < batch; ++b){
348
+ for (int head = 0; head < heads; ++head){
349
+ int batch_shift = b * heads + head;
350
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
351
+ int i_w = (w / 4);
352
+ int w_bit = (w % 4) * 8;
353
+
354
+ int w_index = (batch_shift * height + h + k) * width / 4 + i_w;
355
+ if (w_index >= weight_total || w >= width) {
356
+ weight[k] = 0;
357
+ } else {
358
+ scalar_t scale = scales[batch_shift * height + h + k];
359
+ scalar_t zero = zeros[batch_shift * height + h + k];
360
+ w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xFF);
361
+ weight[k] = scale * (w_tmp - zero);
362
+ }
363
+ }
364
+
365
+ scalar_t res;
366
+ for (int vr = 0; vr < vec_row; ++vr){
367
+ res = 0;
368
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
369
+ if (vec_index < input_total) {
370
+ blockvec[tid] = vec[vec_index];
371
+ } else {
372
+ blockvec[tid] = 0;
373
+ }
374
+
375
+ __syncthreads();
376
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
377
+ // res is the dot product of BLOCKWIDTH elements (part of width)
378
+ res += weight[k] * blockvec[k];
379
+ }
380
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
381
+ int out_index = (batch_shift * vec_row + vr) * width + w;
382
+ if (out_index < out_total) {
383
+ atomicAdd(&mul[out_index], res);
384
+ }
385
+ __syncthreads();
386
+ }
387
+ }
388
+ }
389
+ }
390
+
391
+ void vecquant8matmul_batched_cuda(
392
+ torch::Tensor vec,
393
+ torch::Tensor mat,
394
+ torch::Tensor mul,
395
+ torch::Tensor scales,
396
+ torch::Tensor zeros
397
+ ) {
398
+ int batch = vec.size(0);
399
+ int heads = vec.size(1);
400
+ int vec_row = vec.size(2);
401
+ int vec_height = vec.size(3);
402
+ int height = mat.size(2);
403
+ int width = mat.size(3);
404
+ int zero_width = zeros.size(2);
405
+
406
+ dim3 blocks(
407
+ (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
408
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
409
+ );
410
+ dim3 threads(BLOCKWIDTH);
411
+
412
+ AT_DISPATCH_FLOATING_TYPES(
413
+ vec.type(), "vecquant8matmul_batched_cuda", ([&] {
414
+ VecQuant8BatchMatMulKernel<<<blocks, threads>>>(
415
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
416
+ scales.data<scalar_t>(), zeros.data<int>(),
417
+ batch, heads, vec_row, vec_height, height, width, zero_width
418
+ );
419
+ })
420
+ );
421
+
422
+ }
423
+
424
+ template <typename scalar_t>
425
+ __global__ void VecQuant8BatchMatMulKernel(
426
+ const scalar_t* __restrict__ vec,
427
+ const int* __restrict__ mat,
428
+ scalar_t* __restrict__ mul,
429
+ const scalar_t* __restrict__ scales,
430
+ const int* __restrict__ zeros,
431
+ int batch,
432
+ int heads,
433
+ int vec_row,
434
+ int vec_height,
435
+ int height,
436
+ int width,
437
+ int zero_width
438
+ ) {
439
+ int weight_total = batch * heads * height * width;
440
+ int input_total = batch * heads * vec_row * vec_height;
441
+ int out_total = batch * heads * vec_row * width;
442
+ int tid = threadIdx.x;
443
+ // h is index of height with step being BLOCKHEIGHT8
444
+ int h = BLOCKHEIGHT8 * blockIdx.x;
445
+ // w is index of width with step being 1
446
+ int w = BLOCKWIDTH * blockIdx.y + tid;
447
+ if (w >= width && tid >= vec_height) {
448
+ return;
449
+ }
450
+
451
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
452
+ // i is index of mat of block first row
453
+ int i = width * h + w;
454
+ // if (i >= width * height) {
455
+ // return;
456
+ // }
457
+ int k;
458
+ scalar_t w_tmp;
459
+
460
+ int z_w = w / 4;
461
+ int z_mod = (w % 4) * 8;
462
+
463
+ float weight[BLOCKWIDTH];
464
+
465
+ for (int b = 0; b < batch; ++b){
466
+ for (int head = 0; head < heads; ++head){
467
+ int batch_shift = b * heads + head;
468
+ for (k = 0; k < BLOCKWIDTH && h * 4 + k < vec_height; ++k){
469
+ int k_w = (k / 4);
470
+ int k_bit = (k % 4) * 8;
471
+
472
+ int w_index = batch_shift * height * width + i + (k_w * width);
473
+ if (w_index >= weight_total || w >= width) {
474
+ weight[k] = 0;
475
+ } else {
476
+ scalar_t scale = scales[batch_shift * width + w];
477
+ scalar_t zero;
478
+ if (zero_width == width) {
479
+ zero = zeros[batch_shift * width + w];
480
+ } else {
481
+ zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
482
+ }
483
+ w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xFF);
484
+ weight[k] = scale * (w_tmp - zero);
485
+ }
486
+ }
487
+
488
+ scalar_t res;
489
+ for (int vr = 0; vr < vec_row; ++vr){
490
+ res = 0;
491
+ int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
492
+ if (vec_index < input_total) {
493
+ blockvec[tid] = vec[vec_index];
494
+ } else {
495
+ blockvec[tid] = 0;
496
+ }
497
+
498
+ __syncthreads();
499
+ for (k = 0; k < BLOCKWIDTH && h * 4 + k < vec_height; ++k){
500
+ // res is the dot product of BLOCKWIDTH elements (part of width)
501
+ res += weight[k] * blockvec[k];
502
+ }
503
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
504
+ int out_index = (batch_shift * vec_row + vr) * width + w;
505
+ if (out_index < out_total) {
506
+ atomicAdd(&mul[out_index], res);
507
+ }
508
+ __syncthreads();
509
+ }
510
+ }
511
+ }
512
+ }
513
+
514
+
515
+ void vecquant8matmul_cuda(
516
+ torch::Tensor vec,
517
+ torch::Tensor mat,
518
+ torch::Tensor mul,
519
+ torch::Tensor scales,
520
+ torch::Tensor zeros,
521
+ torch::Tensor g_idx
522
+ ) {
523
+ int batch = vec.size(0);
524
+ int vec_height = vec.size(1);
525
+ int height = mat.size(0);
526
+ int width = mat.size(1);
527
+ int zero_width = zeros.size(1);
528
+
529
+ dim3 blocks(
530
+ (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
531
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
532
+ );
533
+ dim3 threads(BLOCKWIDTH);
534
+
535
+ AT_DISPATCH_FLOATING_TYPES(
536
+ vec.type(), "vecquant8matmul_cuda", ([&] {
537
+ VecQuant8MatMulKernel<<<blocks, threads>>>(
538
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
539
+ scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
540
+ batch, vec_height, height, width, zero_width
541
+ );
542
+ })
543
+ );
544
+ }
545
+
546
+ template <typename scalar_t>
547
+ __global__ void VecQuant8MatMulKernel(
548
+ const scalar_t* __restrict__ vec,
549
+ const int* __restrict__ mat,
550
+ scalar_t* __restrict__ mul,
551
+ const scalar_t* __restrict__ scales,
552
+ const int* __restrict__ zeros,
553
+ const int* __restrict__ g_idx,
554
+ int batch,
555
+ int vec_height,
556
+ int height,
557
+ int width,
558
+ int zero_width
559
+ ) {
560
+ int h = BLOCKHEIGHT8 * blockIdx.x;
561
+ int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
562
+
563
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
564
+ int i = width * h + w;
565
+ int g_h = h * 4;
566
+ int k;
567
+ unsigned int g;
568
+ scalar_t w_tmp;
569
+
570
+ int z_w = w / 4;
571
+ int z_mod = (w % 4) * 8;
572
+
573
+ float weight[BLOCKWIDTH];
574
+
575
+ for (k = 0; k < BLOCKWIDTH; ++k){
576
+ int k_w = (k / 4);
577
+ int k_bit = (k % 4) * 8;
578
+
579
+ g = as_int(g_idx[g_h + k]);
580
+ scalar_t scale = scales[g * width + w];
581
+ scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
582
+
583
+ w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
584
+
585
+ weight[k] = scale * (w_tmp - zero);
586
+ }
587
+
588
+
589
+ scalar_t res;
590
+ for (int b = 0; b < batch; ++b){
591
+ res = 0;
592
+ blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
593
+ __syncthreads();
594
+ for (k = 0; k < BLOCKWIDTH; ++k){
595
+ res += weight[k] * blockvec[k];
596
+ }
597
+ atomicAdd(&mul[b * width + w], res);
598
+ __syncthreads();
599
+ }
600
+ }
601
+
602
+
603
+
604
+ void vecquant4matmul_batched_cuda(
605
+ torch::Tensor vec,
606
+ torch::Tensor mat,
607
+ torch::Tensor mul,
608
+ torch::Tensor scales,
609
+ torch::Tensor zeros
610
+ ) {
611
+ int batch = vec.size(0);
612
+ int heads = vec.size(1);
613
+ int vec_row = vec.size(2);
614
+ int vec_height = vec.size(3);
615
+ int height = mat.size(2);
616
+ int width = mat.size(3);
617
+ int zero_width = zeros.size(2);
618
+
619
+ dim3 blocks(
620
+ (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
621
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
622
+ );
623
+ dim3 threads(BLOCKWIDTH);
624
+
625
+ AT_DISPATCH_FLOATING_TYPES(
626
+ vec.type(), "vecquant4matmul_batched_cuda", ([&] {
627
+ VecQuant4BatchMatMulKernel<<<blocks, threads>>>(
628
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
629
+ scales.data<scalar_t>(), zeros.data<int>(),
630
+ batch, heads, vec_row, vec_height, height, width, zero_width
631
+ );
632
+ })
633
+ );
634
+
635
+ }
636
+
637
+ template <typename scalar_t>
638
+ __global__ void VecQuant4BatchMatMulKernel(
639
+ const scalar_t* __restrict__ vec,
640
+ const int* __restrict__ mat,
641
+ scalar_t* __restrict__ mul,
642
+ const scalar_t* __restrict__ scales,
643
+ const int* __restrict__ zeros,
644
+ int batch,
645
+ int heads,
646
+ int vec_row,
647
+ int vec_height,
648
+ int height,
649
+ int width,
650
+ int zero_width
651
+ ) {
652
+ int weight_total = batch * heads * height * width;
653
+ int input_total = batch * heads * vec_row * vec_height;
654
+ int out_total = batch * heads * vec_row * width;
655
+ int tid = threadIdx.x;
656
+ // h is index of height with step being BLOCKHEIGHT4
657
+ int h = BLOCKHEIGHT4 * blockIdx.x;
658
+ // w is index of width with step being 1
659
+ int w = BLOCKWIDTH * blockIdx.y + tid;
660
+ if (w >= width && tid >= vec_height) {
661
+ return;
662
+ }
663
+
664
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
665
+ // i is index of mat of block first row
666
+ int i = width * h + w;
667
+ int k;
668
+ scalar_t w_tmp;
669
+
670
+ int z_w = w / 8;
671
+ int z_mod = (w % 8) * 4;
672
+
673
+ float weight[BLOCKWIDTH];
674
+
675
+ for (int b = 0; b < batch; ++b){
676
+ for (int head = 0; head < heads; ++head){
677
+ int batch_shift = b * heads + head;
678
+ for (k = 0; k < BLOCKWIDTH && h * 8 + k < vec_height; ++k){
679
+ int k_w = (k / 8);
680
+ int k_bit = (k % 8) * 4;
681
+
682
+ int w_index = batch_shift * height * width + i + (k_w * width);
683
+ if (w_index >= weight_total || w >= width) {
684
+ weight[k] = 0;
685
+ } else {
686
+ scalar_t scale = scales[batch_shift * width + w];
687
+ scalar_t zero;
688
+ if (zero_width == width) {
689
+ zero = zeros[batch_shift * width + w];
690
+ } else {
691
+ zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xF));
692
+ }
693
+ w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF);
694
+ weight[k] = scale * (w_tmp - zero);
695
+ }
696
+ }
697
+
698
+ scalar_t res;
699
+ for (int vr = 0; vr < vec_row; ++vr){
700
+ res = 0;
701
+ int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
702
+ if (vec_index < input_total) {
703
+ blockvec[tid] = vec[vec_index];
704
+ } else {
705
+ blockvec[tid] = 0;
706
+ }
707
+
708
+ __syncthreads();
709
+ for (k = 0; k < BLOCKWIDTH && h * 8 + k < vec_height; ++k){
710
+ // res is the dot product of BLOCKWIDTH elements (part of width)
711
+ res += weight[k] * blockvec[k];
712
+ }
713
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
714
+ int out_index = (batch_shift * vec_row + vr) * width + w;
715
+ if (out_index < out_total) {
716
+ atomicAdd(&mul[out_index], res);
717
+ }
718
+ __syncthreads();
719
+ }
720
+ }
721
+ }
722
+ }
723
+
724
+
725
+
726
+ void vecquant4matmul_batched_column_compression_cuda(
727
+ torch::Tensor vec,
728
+ torch::Tensor mat,
729
+ torch::Tensor mul,
730
+ torch::Tensor scales,
731
+ torch::Tensor zeros
732
+ ) {
733
+ int batch = vec.size(0);
734
+ int heads = vec.size(1);
735
+ int vec_row = vec.size(2);
736
+ int height = vec.size(3);
737
+ int width = mat.size(3) * 8;
738
+
739
+ dim3 blocks(
740
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
741
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
742
+ );
743
+ dim3 threads(BLOCKWIDTH);
744
+
745
+ AT_DISPATCH_FLOATING_TYPES(
746
+ vec.type(), "vecquant4matmul_batched_cuda", ([&] {
747
+ VecQuant4BatchMatMulColumnCompressionKernel<<<blocks, threads>>>(
748
+ vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
749
+ scales.data<scalar_t>(), zeros.data<int>(),
750
+ batch, heads, vec_row, height, width
751
+ );
752
+ })
753
+ );
754
+
755
+ }
756
+
757
+ template <typename scalar_t>
758
+ __global__ void VecQuant4BatchMatMulColumnCompressionKernel(
759
+ const scalar_t* __restrict__ vec,
760
+ const int* __restrict__ mat,
761
+ scalar_t* __restrict__ mul,
762
+ const scalar_t* __restrict__ scales,
763
+ const int* __restrict__ zeros,
764
+ int batch,
765
+ int heads,
766
+ int vec_row,
767
+ int height,
768
+ int width
769
+ ) {
770
+ int weight_total = batch * heads * height * width / 8;
771
+ int input_total = batch * heads * vec_row * height;
772
+ int out_total = batch * heads * vec_row * width;
773
+ int tid = threadIdx.x;
774
+ // h is index of height with step being BLOCKWIDTH
775
+ int h = BLOCKWIDTH * blockIdx.x;
776
+ // w is index of width with step being 1
777
+ int w = BLOCKWIDTH * blockIdx.y + tid;
778
+ if (w >= width && tid >= height) {
779
+ return;
780
+ }
781
+
782
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
783
+ int k;
784
+ scalar_t w_tmp;
785
+
786
+ float weight[BLOCKWIDTH];
787
+
788
+ for (int b = 0; b < batch; ++b){
789
+ for (int head = 0; head < heads; ++head){
790
+ int batch_shift = b * heads + head;
791
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
792
+ int i_w = (w / 8);
793
+ int w_bit = (w % 8) * 4;
794
+
795
+ int w_index = (batch_shift * height + h + k) * width / 8 + i_w;
796
+ if (w_index >= weight_total || w >= width) {
797
+ weight[k] = 0;
798
+ } else {
799
+ scalar_t scale = scales[batch_shift * height + h + k];
800
+ scalar_t zero = zeros[batch_shift * height + h + k];
801
+ w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xF);
802
+ weight[k] = scale * (w_tmp - zero);
803
+ }
804
+ }
805
+
806
+ scalar_t res;
807
+ for (int vr = 0; vr < vec_row; ++vr){
808
+ res = 0;
809
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
810
+ if (vec_index < input_total) {
811
+ blockvec[tid] = vec[vec_index];
812
+ } else {
813
+ blockvec[tid] = 0;
814
+ }
815
+
816
+ __syncthreads();
817
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
818
+ // res is the dot product of BLOCKWIDTH elements (part of width)
819
+ res += weight[k] * blockvec[k];
820
+ }
821
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
822
+ int out_index = (batch_shift * vec_row + vr) * width + w;
823
+ if (out_index < out_total) {
824
+ atomicAdd(&mul[out_index], res);
825
+ }
826
+ __syncthreads();
827
+ }
828
+ }
829
+ }
830
+ }
831
+
832
+
833
+ void vecquant8matmul_batched_old_cuda(
834
+ torch::Tensor vec,
835
+ torch::Tensor mat,
836
+ torch::Tensor mul,
837
+ torch::Tensor scales,
838
+ torch::Tensor zeros
839
+ ) {
840
+ int batch = vec.size(0);
841
+ int heads = vec.size(1);
842
+ int vec_row = vec.size(2);
843
+ int vec_height = vec.size(3);
844
+ int height = mat.size(2);
845
+ int width = mat.size(3);
846
+ int zero_width = zeros.size(2);
847
+
848
+ dim3 blocks(
849
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
850
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
851
+ );
852
+ dim3 threads(BLOCKWIDTH);
853
+
854
+ AT_DISPATCH_FLOATING_TYPES(
855
+ vec.type(), "vecquant8matmul_batched_old_cuda", ([&] {
856
+ VecQuant8BatchMatMulKernel_old<<<blocks, threads>>>(
857
+ vec.data<scalar_t>(), mat.data<uint8_t>(), mul.data<scalar_t>(),
858
+ scales.data<scalar_t>(), zeros.data<scalar_t>(),
859
+ batch, heads, vec_row, vec_height, height, width, zero_width
860
+ );
861
+ })
862
+ );
863
+ }
864
+
865
+
866
+ template <typename scalar_t>
867
+ __global__ void VecQuant8BatchMatMulKernel_old(
868
+ const scalar_t* __restrict__ vec,
869
+ const uint8_t* __restrict__ mat,
870
+ scalar_t* __restrict__ mul,
871
+ const scalar_t* __restrict__ scales,
872
+ const scalar_t* __restrict__ zeros,
873
+ int batch,
874
+ int heads,
875
+ int vec_row,
876
+ int vec_height,
877
+ int height,
878
+ int width,
879
+ int zero_width
880
+ ) {
881
+ int weight_total = batch * heads * height * width;
882
+ int input_total = batch * heads * vec_row * vec_height;
883
+ int out_total = batch * heads * vec_row * width;
884
+ int tid = threadIdx.x;
885
+ // h is index of height with step being BLOCKHEIGHT8
886
+ int h = BLOCKWIDTH * blockIdx.x;
887
+ // w is index of width with step being 1
888
+ int w = BLOCKWIDTH * blockIdx.y + tid;
889
+ if (w >= width && tid >= vec_height) {
890
+ return;
891
+ }
892
+
893
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
894
+ // i is index of mat of block first row
895
+ int i = width * h + w;
896
+ int k;
897
+ scalar_t w_tmp;
898
+
899
+ float weight[BLOCKWIDTH];
900
+ for (int b = 0; b < batch; ++b){
901
+ for (int head = 0; head < heads; ++head){
902
+ int batch_shift = b * heads + head;
903
+ for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){
904
+ int k_w = k;
905
+ int w_index = batch_shift * height * width + i + (k_w * width);
906
+ if (w_index >= weight_total || w >= width) {
907
+ weight[k] = 0;
908
+ } else {
909
+ scalar_t scale = scales[batch_shift * width + w];
910
+ scalar_t zero = zeros[batch_shift * width + w];
911
+ w_tmp = as_unsigned(mat[w_index]);
912
+ weight[k] = scale * (w_tmp - zero);
913
+ }
914
+ }
915
+
916
+ scalar_t res;
917
+ for (int vr = 0; vr < vec_row; ++vr){
918
+ res = 0;
919
+ int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
920
+ if (vec_index < input_total) {
921
+ blockvec[tid] = vec[vec_index];
922
+ } else {
923
+ blockvec[tid] = 0;
924
+ }
925
+
926
+ __syncthreads();
927
+ for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){
928
+ // res is the dot product of BLOCKWIDTH elements (part of width)
929
+ res += weight[k] * blockvec[k];
930
+ }
931
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
932
+ int out_index = (batch_shift * vec_row + vr) * width + w;
933
+ if (out_index < out_total) {
934
+ atomicAdd(&mul[out_index], res);
935
+ }
936
+ __syncthreads();
937
+ }
938
+ }
939
+ }
940
+ }
941
+
942
+
943
+
944
+ void vecquant8matmul_batched_faster_cuda(
945
+ torch::Tensor vec,
946
+ torch::Tensor mat,
947
+ torch::Tensor mul,
948
+ torch::Tensor scales,
949
+ torch::Tensor zeros
950
+ ) {
951
+ int batch = vec.size(0);
952
+ int heads = vec.size(1);
953
+ int vec_row = vec.size(2);
954
+ int vec_height = vec.size(3);
955
+ int height = mat.size(2);
956
+ int width = mat.size(3);
957
+ int zero_width = zeros.size(2);
958
+
959
+ dim3 blocks(
960
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
961
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
962
+ );
963
+ dim3 threads(BLOCKWIDTH);
964
+
965
+ VecQuant8BatchMatMulKernel_faster<<<blocks, threads>>>(
966
+ (half*) vec.data_ptr(),
967
+ (uint8_t*) mat.data_ptr(),
968
+ (half*) mul.data_ptr(),
969
+ (half*) scales.data_ptr(),
970
+ (half*) zeros.data_ptr(),
971
+ batch, heads, vec_row, vec_height, height, width, zero_width
972
+ );
973
+ }
974
+
975
+
976
+
977
+ __global__ void VecQuant8BatchMatMulKernel_faster(
978
+ const half* __restrict__ vec,
979
+ const uint8_t* __restrict__ mat,
980
+ half* __restrict__ mul,
981
+ const half* __restrict__ scales,
982
+ const half* __restrict__ zeros,
983
+ int batch,
984
+ int heads,
985
+ int vec_row,
986
+ int vec_height,
987
+ int height,
988
+ int width,
989
+ int zero_width
990
+ ) {
991
+ //int weight_total = batch * heads * height * width;
992
+ int input_total = batch * heads * vec_row * vec_height;
993
+ int out_total = batch * heads * vec_row * width;
994
+ int tid = threadIdx.x;
995
+ int h = BLOCKWIDTH * blockIdx.x;
996
+ int w = BLOCKWIDTH * blockIdx.y + tid;
997
+ if (w >= width && tid >= height) {
998
+ return;
999
+ }
1000
+
1001
+ __shared__ float blockvec[BLOCKWIDTH];
1002
+ int i = width * h + w;
1003
+ int k;
1004
+ float w_tmp;
1005
+
1006
+ float weight[BLOCKWIDTH];
1007
+ for (int b = 0; b < batch; ++b){
1008
+ for (int head = 0; head < heads; ++head){
1009
+ int batch_shift = b * heads + head;
1010
+ for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){
1011
+ int k_w = k;
1012
+ int w_index = batch_shift * height * width + i + (k_w * width);
1013
+ float scale = __half2float(scales[batch_shift * width + w]);
1014
+ float zero = __half2float(zeros[batch_shift * width + w]);
1015
+ w_tmp = as_unsigned(mat[w_index]);
1016
+ weight[k] = scale *(w_tmp-zero);
1017
+ }
1018
+
1019
+ float res;
1020
+ for (int vr = 0; vr < vec_row; ++vr){
1021
+ res = 0;
1022
+ int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
1023
+ if (vec_index < input_total) {
1024
+ blockvec[tid] = __half2float(vec[vec_index]);
1025
+ } else {
1026
+ blockvec[tid] = 0;
1027
+ }
1028
+ __syncthreads();
1029
+ for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){
1030
+ float temp_res = weight[k]*blockvec[k];
1031
+ res += temp_res;
1032
+ }
1033
+ int out_index = (batch_shift * vec_row + vr) * width + w;
1034
+ if (out_index < out_total) {
1035
+ atomicAdd(&mul[out_index], __float2half(res));
1036
+ }
1037
+ __syncthreads();
1038
+ }
1039
+ }
1040
+ }
1041
+ }
1042
+
1043
+
1044
+
1045
+
1046
+ void vecquant8matmul_batched_column_compression_faster_cuda(
1047
+ torch::Tensor vec,
1048
+ torch::Tensor mat,
1049
+ torch::Tensor mul,
1050
+ torch::Tensor scales,
1051
+ torch::Tensor zeros
1052
+ ) {
1053
+ int batch = vec.size(0);
1054
+ int heads = vec.size(1);
1055
+ int vec_row = vec.size(2);
1056
+ int height = vec.size(3);
1057
+ int width = mat.size(3);
1058
+
1059
+ dim3 blocks(
1060
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
1061
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
1062
+ );
1063
+ dim3 threads(BLOCKWIDTH);
1064
+
1065
+ VecQuant8BatchMatMulColumnCompressionKernel_faster<<<blocks, threads>>>(
1066
+ (half*) vec.data_ptr(),
1067
+ (uint8_t*) mat.data_ptr(),
1068
+ (half*) mul.data_ptr(),
1069
+ (half*) scales.data_ptr(),
1070
+ (half*) zeros.data_ptr(),
1071
+ batch, heads, vec_row, height, width
1072
+ );
1073
+
1074
+ }
1075
+
1076
+ __global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster(
1077
+ const half* __restrict__ vec,
1078
+ const uint8_t* __restrict__ mat,
1079
+ half* __restrict__ mul,
1080
+ const half* __restrict__ scales,
1081
+ const half* __restrict__ zeros,
1082
+ int batch,
1083
+ int heads,
1084
+ int vec_row,
1085
+ int height,
1086
+ int width
1087
+ ) {
1088
+ //int weight_total = batch * heads * height * width;
1089
+ int input_total = batch * heads * vec_row * height;
1090
+ int out_total = batch * heads * vec_row * width;
1091
+ int tid = threadIdx.x;
1092
+ int h = BLOCKWIDTH * blockIdx.x;
1093
+ int w = BLOCKWIDTH * blockIdx.y + tid;
1094
+ if (w >= width && tid >= height) {
1095
+ return;
1096
+ }
1097
+
1098
+ __shared__ float blockvec[BLOCKWIDTH];
1099
+ int k;
1100
+ float w_tmp;
1101
+ float weight[BLOCKWIDTH];
1102
+
1103
+ for (int b = 0; b < batch; ++b){
1104
+ for (int head = 0; head < heads; ++head){
1105
+ int batch_shift = b * heads + head;
1106
+ for (k = 0; k < BLOCKWIDTH; ++k){
1107
+ int w_index = (batch_shift * height + h + k) * width + w;
1108
+ float scale = __half2float(scales[batch_shift * height + h + k]);
1109
+ float zero = __half2float(zeros[batch_shift * height + h + k]);
1110
+ w_tmp = mat[w_index];
1111
+ weight[k] = scale * (w_tmp-zero);
1112
+ }
1113
+
1114
+ float res;
1115
+ for (int vr = 0; vr < vec_row; ++vr){
1116
+ res = 0;
1117
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
1118
+ if (vec_index < input_total) {
1119
+ blockvec[tid] = __half2float(vec[vec_index]);
1120
+ } else {
1121
+ blockvec[tid] = 0;
1122
+ }
1123
+ __syncthreads();
1124
+ for (k = 0; k < BLOCKWIDTH; ++k){
1125
+ res += weight[k]*blockvec[k];
1126
+ }
1127
+ int out_index = (batch_shift * vec_row + vr) * width + w;
1128
+ if (out_index < out_total) {
1129
+ atomicAdd(&mul[out_index], __float2half(res));
1130
+ }
1131
+ __syncthreads();
1132
+ }
1133
+ }
1134
+ }
1135
+ }
1136
+
1137
+
1138
+
1139
+ void vecquant8matmul_batched_column_compression_old_cuda(
1140
+ torch::Tensor vec,
1141
+ torch::Tensor mat,
1142
+ torch::Tensor mul,
1143
+ torch::Tensor scales,
1144
+ torch::Tensor zeros
1145
+ ) {
1146
+ int batch = vec.size(0);
1147
+ int heads = vec.size(1);
1148
+ int vec_row = vec.size(2);
1149
+ int height = vec.size(3);
1150
+ int width = mat.size(3);
1151
+
1152
+ dim3 blocks(
1153
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
1154
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
1155
+ );
1156
+ dim3 threads(BLOCKWIDTH);
1157
+
1158
+ AT_DISPATCH_FLOATING_TYPES(
1159
+ vec.type(), "vecquant8matmul_batched_column_compression_old_cuda", ([&] {
1160
+ VecQuant8BatchMatMulColumnCompressionKernel_old<<<blocks, threads>>>(
1161
+ vec.data<scalar_t>(), mat.data<uint8_t>(), mul.data<scalar_t>(),
1162
+ scales.data<scalar_t>(), zeros.data<scalar_t>(),
1163
+ batch, heads, vec_row, height, width
1164
+ );
1165
+ })
1166
+ );
1167
+
1168
+ }
1169
+
1170
+ template <typename scalar_t>
1171
+ __global__ void VecQuant8BatchMatMulColumnCompressionKernel_old(
1172
+ const scalar_t* __restrict__ vec,
1173
+ const uint8_t* __restrict__ mat,
1174
+ scalar_t* __restrict__ mul,
1175
+ const scalar_t* __restrict__ scales,
1176
+ const scalar_t* __restrict__ zeros,
1177
+ int batch,
1178
+ int heads,
1179
+ int vec_row,
1180
+ int height,
1181
+ int width
1182
+ ) {
1183
+ int weight_total = batch * heads * height * width;
1184
+ int input_total = batch * heads * vec_row * height;
1185
+ int out_total = batch * heads * vec_row * width;
1186
+ int tid = threadIdx.x;
1187
+ // h is index of height with step being BLOCKWIDTH
1188
+ int h = BLOCKWIDTH * blockIdx.x;
1189
+ // w is index of width with step being 1
1190
+ int w = BLOCKWIDTH * blockIdx.y + tid;
1191
+ if (w >= width && tid >= height) {
1192
+ return;
1193
+ }
1194
+
1195
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
1196
+ int k;
1197
+ scalar_t w_tmp;
1198
+
1199
+ float weight[BLOCKWIDTH];
1200
+
1201
+ for (int b = 0; b < batch; ++b){
1202
+ for (int head = 0; head < heads; ++head){
1203
+ int batch_shift = b * heads + head;
1204
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
1205
+ int w_index = (batch_shift * height + h + k) * width + w;
1206
+ if (w_index >= weight_total || w >= width) {
1207
+ weight[k] = 0;
1208
+ } else {
1209
+ scalar_t scale = scales[batch_shift * height + h + k];
1210
+ scalar_t zero = zeros[batch_shift * height + h + k];
1211
+ w_tmp = mat[w_index];
1212
+ weight[k] = scale * (w_tmp - zero);
1213
+ }
1214
+ }
1215
+
1216
+ scalar_t res;
1217
+ for (int vr = 0; vr < vec_row; ++vr){
1218
+ res = 0;
1219
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
1220
+ if (vec_index < input_total) {
1221
+ blockvec[tid] = vec[vec_index];
1222
+ } else {
1223
+ blockvec[tid] = 0;
1224
+ }
1225
+
1226
+ __syncthreads();
1227
+ for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){
1228
+ // res is the dot product of BLOCKWIDTH elements (part of width)
1229
+ res += weight[k] * blockvec[k];
1230
+ }
1231
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
1232
+ int out_index = (batch_shift * vec_row + vr) * width + w;
1233
+ if (out_index < out_total) {
1234
+ atomicAdd(&mul[out_index], res);
1235
+ }
1236
+ __syncthreads();
1237
+ }
1238
+ }
1239
+ }
1240
+ }
1241
+
1242
+
1243
+ void vecquant4matmul_batched_old_cuda(
1244
+ torch::Tensor vec,
1245
+ torch::Tensor mat,
1246
+ torch::Tensor mul,
1247
+ torch::Tensor scales,
1248
+ torch::Tensor zeros
1249
+ ) {
1250
+ int batch = vec.size(0);
1251
+ int heads = vec.size(1);
1252
+ int vec_row = vec.size(2);
1253
+ int vec_height = vec.size(3);
1254
+ int height = mat.size(2);
1255
+ int width = mat.size(3);
1256
+ int zero_width = zeros.size(2);
1257
+
1258
+ dim3 blocks(
1259
+ (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4,
1260
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
1261
+ );
1262
+ dim3 threads(BLOCKWIDTH);
1263
+
1264
+ AT_DISPATCH_FLOATING_TYPES(
1265
+ vec.type(), "vecquant4matmul_batched_old_cuda", ([&] {
1266
+ VecQuant4BatchMatMulKernel_old<<<blocks, threads>>>(
1267
+ vec.data<scalar_t>(), mat.data<uint8_t>(), mul.data<scalar_t>(),
1268
+ scales.data<scalar_t>(), zeros.data<scalar_t>(),
1269
+ batch, heads, vec_row, vec_height, height, width, zero_width
1270
+ );
1271
+ })
1272
+ );
1273
+
1274
+ }
1275
+
1276
+ template <typename scalar_t>
1277
+ __global__ void VecQuant4BatchMatMulKernel_old(
1278
+ const scalar_t* __restrict__ vec,
1279
+ const uint8_t* __restrict__ mat,
1280
+ scalar_t* __restrict__ mul,
1281
+ const scalar_t* __restrict__ scales,
1282
+ const scalar_t* __restrict__ zeros,
1283
+ int batch,
1284
+ int heads,
1285
+ int vec_row,
1286
+ int vec_height,
1287
+ int height,
1288
+ int width,
1289
+ int zero_width
1290
+ ) {
1291
+ int weight_total = batch * heads * height * width;
1292
+ int input_total = batch * heads * vec_row * vec_height;
1293
+ int out_total = batch * heads * vec_row * width;
1294
+ int tid = threadIdx.x;
1295
+ // h is index of height with step being BLOCKHEIGHT_OLD4
1296
+ int h = BLOCKHEIGHT_OLD4 * blockIdx.x;
1297
+ // w is index of width with step being 1
1298
+ int w = BLOCKWIDTH * blockIdx.y + tid;
1299
+ if (w >= width && tid >= vec_height) {
1300
+ return;
1301
+ }
1302
+
1303
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
1304
+ // i is index of mat of block first row
1305
+ int i = width * h + w;
1306
+ int k;
1307
+ scalar_t w_tmp;
1308
+
1309
+ float weight[BLOCKWIDTH];
1310
+ for (int b = 0; b < batch; ++b){
1311
+ for (int head = 0; head < heads; ++head){
1312
+ int batch_shift = b * heads + head;
1313
+ for (k = 0; k < BLOCKWIDTH && h*2 + k < vec_height; ++k){
1314
+ int k_w = (k / 2);
1315
+ int k_bit = (k % 2) * 4;
1316
+ int w_index = batch_shift * height * width + i + (k_w * width);
1317
+ if (w_index >= weight_total || w >= width) {
1318
+ weight[k] = 0;
1319
+ } else {
1320
+ scalar_t scale = scales[batch_shift * width + w];
1321
+ scalar_t zero = zeros[batch_shift * width + w];
1322
+ w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF);
1323
+ weight[k] = scale * (w_tmp - zero);
1324
+ }
1325
+ }
1326
+
1327
+ scalar_t res;
1328
+ for (int vr = 0; vr < vec_row; ++vr){
1329
+ res = 0;
1330
+ int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid;
1331
+ if (vec_index < input_total) {
1332
+ blockvec[tid] = vec[vec_index];
1333
+ } else {
1334
+ blockvec[tid] = 0;
1335
+ }
1336
+
1337
+ __syncthreads();
1338
+ for (k = 0; k < BLOCKWIDTH && h*2 + k < vec_height; ++k){
1339
+ // res is the dot product of BLOCKWIDTH elements (part of width)
1340
+ res += weight[k] * blockvec[k];
1341
+ }
1342
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
1343
+ int out_index = (batch_shift * vec_row + vr) * width + w;
1344
+ if (out_index < out_total) {
1345
+ atomicAdd(&mul[out_index], res);
1346
+ }
1347
+ __syncthreads();
1348
+ }
1349
+ }
1350
+ }
1351
+ }
1352
+
1353
+
1354
+
1355
+
1356
+
1357
+ void vecquant4matmul_batched_column_compression_old_cuda(
1358
+ torch::Tensor vec,
1359
+ torch::Tensor mat,
1360
+ torch::Tensor mul,
1361
+ torch::Tensor scales,
1362
+ torch::Tensor zeros
1363
+ ) {
1364
+ int batch = vec.size(0);
1365
+ int heads = vec.size(1);
1366
+ int vec_row = vec.size(2);
1367
+ int height = vec.size(3);
1368
+ int width = mat.size(3);
1369
+
1370
+ dim3 blocks(
1371
+ (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4,
1372
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
1373
+ );
1374
+ dim3 threads(BLOCKWIDTH);
1375
+
1376
+ AT_DISPATCH_FLOATING_TYPES(
1377
+ vec.type(), "vecquant4matmul_batched_column_compression_old_cuda", ([&] {
1378
+ VecQuant4BatchMatMulColumnCompressionKernel_old<<<blocks, threads>>>(
1379
+ vec.data<scalar_t>(), mat.data<uint8_t>(), mul.data<scalar_t>(),
1380
+ scales.data<scalar_t>(), zeros.data<scalar_t>(),
1381
+ batch, heads, vec_row, height, width
1382
+ );
1383
+ })
1384
+ );
1385
+
1386
+ }
1387
+
1388
+ template <typename scalar_t>
1389
+ __global__ void VecQuant4BatchMatMulColumnCompressionKernel_old(
1390
+ const scalar_t* __restrict__ vec,
1391
+ const uint8_t* __restrict__ mat,
1392
+ scalar_t* __restrict__ mul,
1393
+ const scalar_t* __restrict__ scales,
1394
+ const scalar_t* __restrict__ zeros,
1395
+ int batch,
1396
+ int heads,
1397
+ int vec_row,
1398
+ int height,
1399
+ int width
1400
+ ) {
1401
+ int weight_total = batch * heads * height * width;
1402
+ int input_total = batch * heads * vec_row * height;
1403
+ int out_total = batch * heads * vec_row * width;
1404
+ int tid = threadIdx.x;
1405
+ // h is index of height with step being BLOCKWIDTH
1406
+ int h = BLOCKHEIGHT_OLD4 * blockIdx.x;
1407
+ // w is index of width with step being 1
1408
+ int w = BLOCKWIDTH * blockIdx.y + tid;
1409
+ if (w >= width && tid >= height) {
1410
+ return;
1411
+ }
1412
+
1413
+ __shared__ scalar_t blockvec[BLOCKWIDTH];
1414
+ int k;
1415
+ scalar_t w_tmp;
1416
+
1417
+ float weight[BLOCKWIDTH];
1418
+
1419
+ for (int b = 0; b < batch; ++b){
1420
+ for (int head = 0; head < heads; ++head){
1421
+ int batch_shift = b * heads + head;
1422
+ for (k = 0; k < BLOCKWIDTH && h*2 + k < height; ++k){
1423
+ int k_w = (k / 2);
1424
+ int k_bit = (k % 2) * 4;
1425
+ int w_index = (batch_shift * height + h + k) * width + k_w;
1426
+ if (w_index >= weight_total || w >= width) {
1427
+ weight[k] = 0;
1428
+ } else {
1429
+ scalar_t scale = scales[batch_shift * height + h + k];
1430
+ scalar_t zero = zeros[batch_shift * height + h + k];
1431
+ w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF);
1432
+ weight[k] = scale * (w_tmp - zero);
1433
+ }
1434
+ }
1435
+
1436
+ scalar_t res;
1437
+ for (int vr = 0; vr < vec_row; ++vr){
1438
+ res = 0;
1439
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
1440
+ if (vec_index < input_total) {
1441
+ blockvec[tid] = vec[vec_index];
1442
+ } else {
1443
+ blockvec[tid] = 0;
1444
+ }
1445
+
1446
+ __syncthreads();
1447
+ for (k = 0; k < BLOCKWIDTH && h*2 + k < height; ++k){
1448
+ // res is the dot product of BLOCKWIDTH elements (part of width)
1449
+ res += weight[k] * blockvec[k];
1450
+ }
1451
+ // add res to the final result, final matrix shape: (batch, vec_row, width)
1452
+ int out_index = (batch_shift * vec_row + vr) * width + w;
1453
+ if (out_index < out_total) {
1454
+ atomicAdd(&mul[out_index], res);
1455
+ }
1456
+ __syncthreads();
1457
+ }
1458
+ }
1459
+ }
1460
+ }
1461
+
1462
+
1463
+
1464
+
1465
+
1466
+ void vecquant8matmul_batched_faster_old_cuda(
1467
+ torch::Tensor vec,
1468
+ torch::Tensor mat,
1469
+ torch::Tensor mul,
1470
+ torch::Tensor scales,
1471
+ torch::Tensor zeros
1472
+ ) {
1473
+ int batch = vec.size(0);
1474
+ int heads = vec.size(1);
1475
+ int vec_row = vec.size(2);
1476
+ int vec_height = vec.size(3);
1477
+ int height = mat.size(2);
1478
+ int width = mat.size(3);
1479
+
1480
+ dim3 blocks(
1481
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
1482
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
1483
+ );
1484
+ dim3 threads(BLOCKWIDTH);
1485
+
1486
+ VecQuant8BatchMatMulKernel_faster_old<<<blocks, threads>>>(
1487
+ (half*) vec.data_ptr(),
1488
+ (uint8_t*) mat.data_ptr(),
1489
+ (half*) mul.data_ptr(),
1490
+ (half*) scales.data_ptr(),
1491
+ (half*) zeros.data_ptr(),
1492
+ batch, heads, vec_row, vec_height, height, width
1493
+ );
1494
+ }
1495
+
1496
+
1497
+ __global__ void VecQuant8BatchMatMulKernel_faster_old(
1498
+ const half* __restrict__ vec,
1499
+ const uint8_t* __restrict__ mat,
1500
+ half* __restrict__ mul,
1501
+ const half* __restrict__ scales,
1502
+ const half* __restrict__ zeros,
1503
+ int batch,
1504
+ int heads,
1505
+ int vec_row,
1506
+ int vec_height,
1507
+ int height,
1508
+ int width
1509
+ ) {
1510
+ int weight_total = batch * heads * height * width;
1511
+ int input_total = batch * heads * vec_row * vec_height;
1512
+ int out_total = batch * heads * vec_row * width;
1513
+ int tid = threadIdx.x;
1514
+ const int BLOCKWIDTH_half = BLOCKWIDTH/2;
1515
+
1516
+ int h = BLOCKWIDTH * blockIdx.x; //head_dim, dim=-1
1517
+ int w = BLOCKWIDTH * blockIdx.y + tid; //seq-len, +0-256 ,dim=-2
1518
+ /*
1519
+ if (w >= width && tid >= vec_height) {
1520
+ return;
1521
+ }
1522
+ */
1523
+ __shared__ half blockvec[BLOCKWIDTH]; //256
1524
+ int i = width * h + w;
1525
+ int k;
1526
+
1527
+ half w_tmp1 = __float2half(0);
1528
+ half w_tmp2 = __float2half(0);
1529
+
1530
+ half2 weight[BLOCKWIDTH_half];
1531
+ for (int b = 0; b < batch; ++b){
1532
+ for (int head = 0; head < heads; ++head){
1533
+ int batch_shift = b * heads + head;
1534
+ //int zero_index = batch_shift;
1535
+ for (k = 0; k < BLOCKWIDTH_half; ++k){
1536
+ int w_index1 = batch_shift * height * width + i + (2 * k * width); // [batch,head,h+k, w]
1537
+ int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width);
1538
+ int zero_index = batch_shift * width + w; // [batch,head, w]
1539
+ if (w_index1 >= weight_total || w >= width || (2 * k + h) >= height) {
1540
+ weight[k] = __float2half2_rn(0);
1541
+ } else {
1542
+ float zero_f=__half2float(zeros[zero_index]);
1543
+ float scale_f= __half2float(scales[zero_index]);
1544
+ if (w_index2 >= weight_total){
1545
+ w_tmp1 = __float2half((as_unsigned(mat[w_index1]) -zero_f)*scale_f);
1546
+ w_tmp2 = __float2half(0);
1547
+ weight[k] = __halves2half2(w_tmp1,w_tmp2);
1548
+ //printf("zero_index is %d w is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,w,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k]));
1549
+ }else{
1550
+ w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1]));
1551
+ w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2]));
1552
+
1553
+ //weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero,zero)),__halves2half2(scale,scale));
1554
+ weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f)));
1555
+ //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k]));
1556
+ }
1557
+ }
1558
+ }
1559
+
1560
+
1561
+ for (int vr = 0; vr < vec_row; ++vr){
1562
+ float res=0;
1563
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
1564
+ int out_index = (batch_shift * vec_row + vr) * width + w;
1565
+ if (vec_index < input_total) {
1566
+ //blockvec[tid] = __half2float(vec[vec_index]);// [batch, head, vr, tid(seq_len dim+)]
1567
+ blockvec[tid] = vec[vec_index];
1568
+ //printf("width is %d height is %d h is %d w is %d vec_index is %d out_index is %d vec_row is %d vec_height is %d,vr is %d tid is %d blockvec is %f\n",width,height, h,w,vec_index,out_index,vec_row,vec_height,vr,tid,blockvec[tid]);
1569
+ } else {
1570
+ blockvec[tid] = __float2half(0);
1571
+ }
1572
+ __syncthreads();
1573
+ if (out_index < out_total) {
1574
+ for (k = 0; k < BLOCKWIDTH_half; ++k){
1575
+ half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1]));
1576
+ res += __low2float(res2) + __high2float(res2);
1577
+ }
1578
+ atomicAdd(&mul[out_index], __float2half(res));
1579
+ }
1580
+ __syncthreads();
1581
+ }
1582
+ }
1583
+ }
1584
+ }
1585
+
1586
+
1587
+ void vecquant8matmul_batched_column_compression_faster_old_cuda(
1588
+ torch::Tensor vec, // [batch,heads, seq_q, seq_v]
1589
+ torch::Tensor mat, // [batch,heads, seq_v, head_dim]
1590
+ torch::Tensor mul, // [batch,heads, seq_q,head_dim]
1591
+ torch::Tensor scales, // [batch,heads, head_dim]
1592
+ torch::Tensor zeros
1593
+ ) {
1594
+ int batch = vec.size(0);
1595
+ int heads = vec.size(1);
1596
+ int vec_row = vec.size(2); //ql
1597
+ int height = mat.size(2); //vl
1598
+ int width = mat.size(3); //head_dim
1599
+
1600
+ dim3 blocks(
1601
+ (height + BLOCKWIDTH - 1) / BLOCKWIDTH,
1602
+ (width + BLOCKWIDTH - 1) / BLOCKWIDTH
1603
+ );
1604
+ dim3 threads(BLOCKWIDTH);
1605
+
1606
+ VecQuant8BatchMatMulColumnCompressionKernel_faster_old<<<blocks, threads>>>(
1607
+ (half*) vec.data_ptr(),
1608
+ (uint8_t*) mat.data_ptr(),
1609
+ (half*) mul.data_ptr(),
1610
+ (half*) scales.data_ptr(),
1611
+ (half*) zeros.data_ptr(),
1612
+ batch, heads, vec_row, height, width
1613
+ );
1614
+
1615
+ }
1616
+
1617
+
1618
+ __global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old(
1619
+ const half* __restrict__ vec, // [batch,heads, seq_q, seq_v]
1620
+ const uint8_t* __restrict__ mat, // [batch,heads, seq_v, head_dim]
1621
+ half* __restrict__ mul, // [batch,heads, seq_q,head_dim]
1622
+ const half* __restrict__ scales, // [batch,heads, seq_v]
1623
+ const half* __restrict__ zeros,
1624
+ int batch,
1625
+ int heads,
1626
+ int vec_row, //seq_q
1627
+ int height, //seq_v
1628
+ int width //head_dim
1629
+ ) {
1630
+ int weight_total = batch * heads * height * width;
1631
+ int input_total = batch * heads * vec_row * height;
1632
+ int out_total = batch * heads * vec_row * width;
1633
+ int tid = threadIdx.x;
1634
+ int h = BLOCKWIDTH * blockIdx.x; // vl
1635
+ int w = BLOCKWIDTH * blockIdx.y + tid; //head_dim + block
1636
+ if (w >= width && tid >= height) {
1637
+ return;
1638
+ }
1639
+ __shared__ half blockvec[BLOCKWIDTH];
1640
+ int k;
1641
+ half w_tmp1 = __float2half(0);
1642
+ half w_tmp2 = __float2half(0);
1643
+ int i = width * h + w;
1644
+ const int BLOCKWIDTH_half = BLOCKWIDTH/2;
1645
+ half2 weight[BLOCKWIDTH_half];
1646
+
1647
+ for (int b = 0; b < batch; ++b){
1648
+ for (int head = 0; head < heads; ++head){
1649
+ int batch_shift = b * heads + head;
1650
+ //int zero_index = batch_shift;
1651
+ for (k = 0; k < BLOCKWIDTH_half; ++k){
1652
+ int w_index1 = batch_shift * height * width + i + (2 * k) * width; // [batch,head, h+k, w]
1653
+ int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width);
1654
+ int zero_index1 = batch_shift * height + h + 2*k; // [batch,head, w]
1655
+ int zero_index2 = batch_shift * height + h + 2*k+1; // [batch,head, w]
1656
+
1657
+ if (w_index1 >= weight_total || (2 * k + h)>=height) {
1658
+ weight[k]=__float2half2_rn(0);
1659
+ } else{
1660
+ //int zero_index = batch_shift + h; // [batch,head, w]
1661
+ //float scale_f1 = __half2float(scales[zero_index1]);
1662
+ //float zero_f1 = __half2float(zeros[zero_index1]);
1663
+ if (w_index2>=weight_total){
1664
+ w_tmp1 = __float2half((as_unsigned(mat[w_index1]) - __half2float(zeros[zero_index1]))* __half2float(scales[zero_index1]));
1665
+ w_tmp2 = __float2half(0);
1666
+ weight[k] = __halves2half2(w_tmp1,w_tmp2);
1667
+ //printf("zero_index is %d k is %d w is %d head is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,k,w,head,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k]));
1668
+ }else{
1669
+ w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1]));
1670
+ w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2]));
1671
+ half zero1=zeros[zero_index1];
1672
+ half zero2=zeros[zero_index2];
1673
+ half scale1=scales[zero_index1];
1674
+ half scale2=scales[zero_index2];
1675
+ weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero1,zero2)),__halves2half2(scale1,scale2));
1676
+ //weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f)));
1677
+ //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k]));
1678
+ }
1679
+ }
1680
+ }
1681
+
1682
+
1683
+ for (int vr = 0; vr < vec_row; ++vr){
1684
+ float res=0;
1685
+ int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid;
1686
+ int out_index = (batch_shift * vec_row + vr) * width + w;
1687
+
1688
+ if (vec_index < input_total) {
1689
+ //blockvec[tid] = __half2float(vec[vec_index]);
1690
+ blockvec[tid] = vec[vec_index];
1691
+ //printf("vec_index is %d out_index is %d vec_row is %d ,vr is %d tid is %d blockvec is %f\n",vec_index,out_index,vec_row,vr,tid,blockvec[tid]);
1692
+ } else {
1693
+ blockvec[tid] = __float2half(0);
1694
+ //blockvec[tid] = 0;
1695
+ }
1696
+ __syncthreads();
1697
+ if (out_index < out_total) {
1698
+ for (k = 0; k < BLOCKWIDTH_half; ++k){
1699
+ half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1]));
1700
+ res += __low2float(res2) + __high2float(res2);
1701
+ }
1702
+ atomicAdd(&mul[out_index], __float2half(res));
1703
+ }
1704
+ __syncthreads();
1705
+ }
1706
+ }
1707
+ }
1708
+ }