Crystalcareai
committed on
Delete single.py
Browse files
single.py
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import triton
|
3 |
-
import triton.language as tl
|
4 |
-
from torch.nn import functional as F
|
5 |
-
|
6 |
-
@triton.jit
def _single2scatter(
    X_ptr, stride_xm, stride_xk,
    W_ptr, stride_we, stride_wk, stride_wn,
    Y_ptr, stride_ym, stride_yn,
    expert_idxs_ptr,
    FAN_OUT: tl.constexpr,
    K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
):
    """Compute one BLOCK_N-wide slice of one output row as a vector-matrix product.

    For program (pid0, pid1) this evaluates, for the BLOCK_N columns n selected
    by pid0:  Y[pid1, n] = sum_k X[in_idx, k] * W[e, k, n],
    where e is the value loaded from ``expert_idxs_ptr + pid1``.

    Arguments:
        X_ptr, stride_xm, stride_xk: input base pointer and its row / K strides.
        W_ptr, stride_we, stride_wk, stride_wn: weight base pointer and its
            expert / K / N strides.
        Y_ptr, stride_ym, stride_yn: output base pointer and its row / N strides.
        expert_idxs_ptr: base pointer of the per-output-row expert indices; it is
            indexed with a unit stride (``+ pid1``).
        FAN_OUT: when equal to 1 each output row reads its own input row;
            otherwise every output row reads input row 0.
        K, N: reduction length and number of output columns.
        E: not referenced in the kernel body.
        BLOCK_N, BLOCK_K: tile sizes along N and K.
        ACC_TYPE: dtype of the accumulator tile.
    """
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)

    # Grid axis 0 walks column blocks, grid axis 1 walks output rows.
    N_block_id = pid0
    if FAN_OUT == 1:
        in_idx = pid1
    else:
        in_idx = 0
    out_idx = pid1

    K_block = tl.arange(0, BLOCK_K)
    # Column offsets are wrapped modulo N and annotated with alignment/contiguity hints.
    N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
    E_idx = tl.load(expert_idxs_ptr + pid1)
    X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
    W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
    acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
    for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
        # The loop runs ceil(K / BLOCK_K) times, so the final tile may extend
        # past K. Mask those rows and fill them with zeros so they neither read
        # out of bounds nor contribute to the sum.
        k_mask = K_block[:, None] < K - K_block_id * BLOCK_K
        x = tl.load(X_blk_ptrs, mask=k_mask, other=0.0)
        w = tl.load(W_blk_ptrs, mask=k_mask, other=0.0)
        # x is (BLOCK_K, 1) and broadcasts against w (BLOCK_K, BLOCK_N);
        # reducing over axis 0 yields the partial dot products for this K tile.
        acc += tl.sum(x * w, axis=0)[None, :]
        X_blk_ptrs += BLOCK_K * stride_xk
        W_blk_ptrs += BLOCK_K * stride_wk
    Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
    tl.store(Y_blk_ptrs, acc)
|
41 |
-
|
42 |
-
def single2scatter(X, W, expert_idxs):
    """Multiply input row(s) by per-row selected expert weight matrices.

    Launches the ``_single2scatter`` Triton kernel on a grid of
    ``(ydim // BLOCK_N, k)`` programs.

    Args:
        X: 2-D tensor with either ``k`` rows or a single row (a single row is
            reused for every output row).
        W: 3-D tensor whose sizes are unpacked as ``(E, xdim, ydim)``.
        expert_idxs: tensor whose second dimension has size ``k``; the kernel
            reads ``k`` consecutive elements from its base pointer.
            NOTE(review): presumably shape ``(1, k)`` and contiguous - confirm.

    Returns:
        A new ``(k, ydim)`` tensor on ``X``'s device with ``X``'s dtype.

    Raises:
        AssertionError: if ``X.size(0)`` is neither ``k`` nor ``1``.
        ValueError: if ``ydim`` is not a multiple of the column tile size.
    """
    E, xdim, ydim = W.size()
    k = expert_idxs.size(1)
    assert X.size(0) == k or X.size(0) == 1

    BLOCK_N = 128
    BLOCK_K = 128
    # The grid below uses floor division, so a ydim that is not a multiple of
    # BLOCK_N would leave the trailing columns of the (uninitialised) output
    # unwritten. Reject such shapes instead of returning garbage.
    if ydim % BLOCK_N != 0:
        raise ValueError(
            f"W.size(2)={ydim} must be a multiple of BLOCK_N={BLOCK_N}"
        )

    Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
    grid = ydim // BLOCK_N, k
    _single2scatter[grid](
        X, X.stride(0), X.stride(1),
        W, W.stride(0), W.stride(1), W.stride(2),
        Y, Y.stride(0), Y.stride(1),
        expert_idxs,
        # 1 when X supplies one row per output row, k when a single row is shared.
        FAN_OUT=Y.size(0) // X.size(0),
        K=xdim, N=ydim, E=E,
        BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        ACC_TYPE=tl.float32
    )
    return Y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|