SakuraD commited on
Commit
29a3d5a
1 Parent(s): 7b6030f
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +4 -4
  2. __pycache__/imagenet_class_index.cpython-310.pyc +0 -0
  3. __pycache__/kinetics_class_index.cpython-310.pyc +0 -0
  4. __pycache__/transforms.cpython-310.pyc +0 -0
  5. __pycache__/videomamba_image.cpython-310.pyc +0 -0
  6. __pycache__/videomamba_video.cpython-310.pyc +0 -0
  7. app.py +180 -0
  8. causal-conv1d/AUTHORS +1 -0
  9. causal-conv1d/LICENSE +29 -0
  10. causal-conv1d/README.md +1 -0
  11. causal-conv1d/causal_conv1d/__init__.py +3 -0
  12. causal-conv1d/causal_conv1d/causal_conv1d_interface.py +104 -0
  13. causal-conv1d/csrc/causal_conv1d.cpp +333 -0
  14. causal-conv1d/csrc/causal_conv1d.h +53 -0
  15. causal-conv1d/csrc/causal_conv1d_bwd.cu +525 -0
  16. causal-conv1d/csrc/causal_conv1d_common.h +64 -0
  17. causal-conv1d/csrc/causal_conv1d_fwd.cu +350 -0
  18. causal-conv1d/csrc/causal_conv1d_update.cu +96 -0
  19. causal-conv1d/csrc/static_switch.h +25 -0
  20. causal-conv1d/setup.py +264 -0
  21. causal-conv1d/tests/test_causal_conv1d.py +173 -0
  22. imagenet_class_index.py +1002 -0
  23. images/cat.png +0 -0
  24. images/dog.png +0 -0
  25. images/panda.png +0 -0
  26. install.sh +2 -0
  27. kinetics_class_index.py +402 -0
  28. mamba/.gitmodules +3 -0
  29. mamba/AUTHORS +2 -0
  30. mamba/LICENSE +201 -0
  31. mamba/README.md +149 -0
  32. mamba/assets/selection.png +0 -0
  33. mamba/benchmarks/benchmark_generation_mamba_simple.py +88 -0
  34. mamba/csrc/selective_scan/reverse_scan.cuh +401 -0
  35. mamba/csrc/selective_scan/selective_scan.cpp +497 -0
  36. mamba/csrc/selective_scan/selective_scan.h +101 -0
  37. mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu +9 -0
  38. mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu +9 -0
  39. mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu +9 -0
  40. mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu +9 -0
  41. mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu +9 -0
  42. mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu +9 -0
  43. mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh +531 -0
  44. mamba/csrc/selective_scan/selective_scan_common.h +221 -0
  45. mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu +10 -0
  46. mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu +10 -0
  47. mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu +10 -0
  48. mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh +345 -0
  49. mamba/csrc/selective_scan/static_switch.h +25 -0
  50. mamba/csrc/selective_scan/uninitialized_copy.cuh +69 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: VideoMamba
3
- emoji: 🌖
4
- colorFrom: gray
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 4.21.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
  title: VideoMamba
3
+ emoji: 🐍
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.29.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
__pycache__/imagenet_class_index.cpython-310.pyc ADDED
Binary file (60.6 kB). View file
 
__pycache__/kinetics_class_index.cpython-310.pyc ADDED
Binary file (15.2 kB). View file
 
__pycache__/transforms.cpython-310.pyc ADDED
Binary file (13.9 kB). View file
 
__pycache__/videomamba_image.cpython-310.pyc ADDED
Binary file (9.64 kB). View file
 
__pycache__/videomamba_video.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms as T
8
+ from PIL import Image
9
+ from decord import VideoReader
10
+ from decord import cpu
11
+ from videomamba_image import videomamba_image_tiny
12
+ from videomamba_video import videomamba_tiny
13
+ from kinetics_class_index import kinetics_classnames
14
+ from imagenet_class_index import imagenet_classnames
15
+ from transforms import (
16
+ GroupNormalize, GroupScale, GroupCenterCrop,
17
+ Stack, ToTorchFormatTensor
18
+ )
19
+
20
+ import gradio as gr
21
+ from huggingface_hub import hf_hub_download
22
+
23
+
24
+ # install packages for mamba
25
+ os.system("bash install.sh")
26
+
27
+
28
+ # Device on which to run the model
29
+ # Set to cuda to load on GPU
30
+ device = "cuda"
31
+ model_video_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_k400_f16_res224.pth")
32
+ model_image_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_in1k_res224.pth")
33
+ # Pick a pretrained model
34
+ model_video = videomamba_tiny(num_classes=400, num_frames=16)
35
+ video_sd = torch.load(model_video_path, map_location='cpu')
36
+ model_video.load_state_dict(video_sd)
37
+ model_image = videomamba_image_tiny()
38
+ image_sd = torch.load(model_image_path, map_location='cpu')
39
+ model_image.load_state_dict(image_sd['model'])
40
+ # Set to eval mode and move to desired device
41
+ model_video = model_video.to(device).eval()
42
+ model_image = model_image.to(device).eval()
43
+
44
+ # Create an id to label name mapping
45
+ kinetics_id_to_classname = {}
46
+ for k, v in kinetics_classnames.items():
47
+ kinetics_id_to_classname[k] = v
48
+ imagenet_id_to_classname = {}
49
+ for k, v in imagenet_classnames.items():
50
+ imagenet_id_to_classname[k] = v[1]
51
+
52
+
53
+ def get_index(num_frames, num_segments=8):
54
+ seg_size = float(num_frames - 1) / num_segments
55
+ start = int(seg_size / 2)
56
+ offsets = np.array([
57
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
58
+ ])
59
+ return offsets
60
+
61
+
62
+ def load_video(video_path):
63
+ vr = VideoReader(video_path, ctx=cpu(0))
64
+ num_frames = len(vr)
65
+ frame_indices = get_index(num_frames, 16)
66
+
67
+ # transform
68
+ crop_size = 160
69
+ scale_size = 160
70
+ input_mean = [0.485, 0.456, 0.406]
71
+ input_std = [0.229, 0.224, 0.225]
72
+
73
+ transform = T.Compose([
74
+ GroupScale(int(scale_size)),
75
+ GroupCenterCrop(crop_size),
76
+ Stack(),
77
+ ToTorchFormatTensor(),
78
+ GroupNormalize(input_mean, input_std)
79
+ ])
80
+
81
+ images_group = list()
82
+ for frame_index in frame_indices:
83
+ img = Image.fromarray(vr[frame_index].asnumpy())
84
+ images_group.append(img)
85
+ torch_imgs = transform(images_group)
86
+ return torch_imgs
87
+
88
+
89
+ def inference_video(video):
90
+ vid = load_video(video)
91
+
92
+ # The model expects inputs of shape: B x C x H x W
93
+ TC, H, W = vid.shape
94
+ inputs = vid.reshape(1, TC//3, 3, H, W).permute(0, 2, 1, 3, 4)
95
+
96
+ with torch.no_grad():
97
+ prediction = model_video(inputs.to(device))
98
+ prediction = F.softmax(prediction, dim=1).flatten()
99
+
100
+ return {kinetics_id_to_classname[str(i)]: float(prediction[i]) for i in range(400)}
101
+
102
+
103
+ def set_example_video(example: list) -> dict:
104
+ return gr.Video.update(value=example[0])
105
+
106
+
107
+ def inference_image(img):
108
+ image = img
109
+ image_transform = T.Compose(
110
+ [
111
+ T.Resize(224),
112
+ T.CenterCrop(224),
113
+ T.ToTensor(),
114
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
115
+ ]
116
+ )
117
+ image = image_transform(image)
118
+
119
+ # The model expects inputs of shape: B x C x H x W
120
+ image = image.unsqueeze(0)
121
+
122
+ with torch.no_grad():
123
+ prediction = model_image(image.to(device))
124
+ prediction = F.softmax(prediction, dim=1).flatten()
125
+
126
+ return {imagenet_id_to_classname[str(i)]: float(prediction[i]) for i in range(1000)}
127
+
128
+
129
+ def set_example_image(example: list) -> dict:
130
+ return gr.Image.update(value=example[0])
131
+
132
+
133
+ demo = gr.Blocks()
134
+ with demo:
135
+ gr.Markdown(
136
+ """
137
+ # VideoMamba-Ti
138
+ Gradio demo for <a href='https://github.com/OpenGVLab/VideoMamba' target='_blank'>VideoMamba</a>: To use it, simply upload your video, or click one of the examples to load them. Read more at the links below.
139
+ """
140
+ )
141
+
142
+ with gr.Tab("Video"):
143
+ with gr.Box():
144
+ with gr.Row():
145
+ with gr.Column():
146
+ with gr.Row():
147
+ input_video = gr.Video(label='Input Video').style(height=360)
148
+ with gr.Row():
149
+ submit_video_button = gr.Button('Submit')
150
+ with gr.Column():
151
+ label_video = gr.Label(num_top_classes=5)
152
+ with gr.Row():
153
+ example_videos = gr.Dataset(components=[input_video], samples=[['./videos/hitting_baseball.mp4'], ['./videos/hoverboarding.mp4'], ['./videos/yoga.mp4']])
154
+
155
+ with gr.Tab("Image"):
156
+ with gr.Box():
157
+ with gr.Row():
158
+ with gr.Column():
159
+ with gr.Row():
160
+ input_image = gr.Image(label='Input Image', type='pil').style(height=360)
161
+ with gr.Row():
162
+ submit_image_button = gr.Button('Submit')
163
+ with gr.Column():
164
+ label_image = gr.Label(num_top_classes=5)
165
+ with gr.Row():
166
+ example_images = gr.Dataset(components=[input_image], samples=[['./images/cat.png'], ['./images/dog.png'], ['./images/panda.png']])
167
+
168
+ gr.Markdown(
169
+ """
170
+ <p style='text-align: center'><a href='https://arxiv.org/abs/2403.06977' target='_blank'>VideoMamba: State Space Model for Efficient Video Understanding</a> | <a href='https://github.com/OpenGVLab/VideoMamba' target='_blank'>Github Repo</a></p>
171
+ """
172
+ )
173
+
174
+ submit_video_button.click(fn=inference_video, inputs=input_video, outputs=label_video)
175
+ example_videos.click(fn=set_example_video, inputs=example_videos, outputs=example_videos.components)
176
+ submit_image_button.click(fn=inference_image, inputs=input_image, outputs=label_image)
177
+ example_images.click(fn=set_example_image, inputs=example_images, outputs=example_images.components)
178
+
179
+ demo.launch(enable_queue=True)
180
+ # demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
causal-conv1d/AUTHORS ADDED
@@ -0,0 +1 @@
 
 
1
+ Tri Dao, [email protected]
causal-conv1d/LICENSE ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ * Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ * Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ * Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
causal-conv1d/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # Causal depthwise conv1d in CUDA with a PyTorch interface
causal-conv1d/causal_conv1d/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __version__ = "1.0.0"
2
+
3
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
causal-conv1d/causal_conv1d/causal_conv1d_interface.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ import causal_conv1d_cuda
8
+
9
+
10
+ class CausalConv1dFn(torch.autograd.Function):
11
+ @staticmethod
12
+ def forward(ctx, x, weight, bias=None, activation=None):
13
+ if activation not in [None, "silu", "swish"]:
14
+ raise NotImplementedError("activation must be None, silu, or swish")
15
+ if x.stride(2) != 1 and x.stride(1) != 1:
16
+ x = x.contiguous()
17
+ bias = bias.contiguous() if bias is not None else None
18
+ ctx.save_for_backward(x, weight, bias)
19
+ ctx.activation = activation in ["silu", "swish"]
20
+ out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
21
+ return out
22
+
23
+ @staticmethod
24
+ def backward(ctx, dout):
25
+ x, weight, bias = ctx.saved_tensors
26
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
27
+ dout = dout.contiguous()
28
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
29
+ # backward of conv1d with the backward of chunk).
30
+ # Here we just pass in None and dx will be allocated in the C++ code.
31
+ dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd(
32
+ x, weight, bias, dout, None, ctx.activation
33
+ )
34
+ return dx, dweight, dbias if bias is not None else None, None
35
+
36
+
37
+ def causal_conv1d_fn(x, weight, bias=None, activation=None):
38
+ """
39
+ x: (batch, dim, seqlen)
40
+ weight: (dim, width)
41
+ bias: (dim,)
42
+ activation: either None or "silu" or "swish"
43
+
44
+ out: (batch, dim, seqlen)
45
+ """
46
+ return CausalConv1dFn.apply(x, weight, bias, activation)
47
+
48
+
49
+ def causal_conv1d_ref(x, weight, bias=None, activation=None):
50
+ """
51
+ x: (batch, dim, seqlen)
52
+ weight: (dim, width)
53
+ bias: (dim,)
54
+
55
+ out: (batch, dim, seqlen)
56
+ """
57
+ if activation not in [None, "silu", "swish"]:
58
+ raise NotImplementedError("activation must be None, silu, or swish")
59
+ dtype_in = x.dtype
60
+ x = x.to(weight.dtype)
61
+ seqlen = x.shape[-1]
62
+ dim, width = weight.shape
63
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
64
+ out = out[..., :seqlen]
65
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
66
+
67
+
68
+ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
69
+ """
70
+ x: (batch, dim)
71
+ conv_state: (batch, dim, width)
72
+ weight: (dim, width)
73
+ bias: (dim,)
74
+
75
+ out: (batch, dim)
76
+ """
77
+ if activation not in [None, "silu", "swish"]:
78
+ raise NotImplementedError("activation must be None, silu, or swish")
79
+ activation = activation in ["silu", "swish"]
80
+ return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
81
+
82
+
83
+ def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
84
+ """
85
+ x: (batch, dim)
86
+ conv_state: (batch, dim, width)
87
+ weight: (dim, width)
88
+ bias: (dim,)
89
+
90
+ out: (batch, dim)
91
+ """
92
+ if activation not in [None, "silu", "swish"]:
93
+ raise NotImplementedError("activation must be None, silu, or swish")
94
+ dtype_in = x.dtype
95
+ batch, dim = x.shape
96
+ width = weight.shape[1]
97
+ assert conv_state.shape == (batch, dim, width)
98
+ assert weight.shape == (dim, width)
99
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
100
+ conv_state[:, :, -1] = x
101
+ out = torch.sum(conv_state * weight, dim=-1) # (B D)
102
+ if bias is not None:
103
+ out += bias
104
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
causal-conv1d/csrc/causal_conv1d.cpp ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <torch/extension.h>
8
+ #include <vector>
9
+
10
+ #include "causal_conv1d.h"
11
+
12
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
13
+
14
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
15
+ if (ITYPE == at::ScalarType::Half) { \
16
+ using input_t = at::Half; \
17
+ __VA_ARGS__(); \
18
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
19
+ using input_t = at::BFloat16; \
20
+ __VA_ARGS__(); \
21
+ } else if (ITYPE == at::ScalarType::Float) { \
22
+ using input_t = float; \
23
+ __VA_ARGS__(); \
24
+ } else { \
25
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
26
+ }
27
+
28
+ #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
29
+ if (WTYPE == at::ScalarType::Half) { \
30
+ using weight_t = at::Half; \
31
+ __VA_ARGS__(); \
32
+ } else if (WTYPE == at::ScalarType::BFloat16) { \
33
+ using weight_t = at::BFloat16; \
34
+ __VA_ARGS__(); \
35
+ } else if (WTYPE == at::ScalarType::Float) { \
36
+ using weight_t = float; \
37
+ __VA_ARGS__(); \
38
+ } else { \
39
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
40
+ }
41
+
42
+ template<typename input_t, typename weight_t>
43
+ void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
44
+ template <typename input_t, typename weight_t>
45
+ void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
46
+
47
+ template<typename input_t, typename weight_t>
48
+ void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
49
+ template<typename input_t, typename weight_t>
50
+ void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
51
+
52
+ template<typename input_t, typename weight_t>
53
+ void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
54
+
55
+ void set_conv_params_fwd(ConvParamsBase &params,
56
+ // sizes
57
+ const size_t batch,
58
+ const size_t dim,
59
+ const size_t seqlen,
60
+ const size_t width,
61
+ // device pointers
62
+ const at::Tensor x,
63
+ const at::Tensor weight,
64
+ const at::Tensor out,
65
+ void* bias_ptr,
66
+ bool silu_activation) {
67
+
68
+ // Reset the parameters
69
+ memset(&params, 0, sizeof(params));
70
+
71
+ params.batch = batch;
72
+ params.dim = dim;
73
+ params.seqlen = seqlen;
74
+ params.width = width;
75
+
76
+ params.silu_activation = silu_activation;
77
+
78
+ // Set the pointers and strides.
79
+ params.x_ptr = x.data_ptr();
80
+ params.weight_ptr = weight.data_ptr();
81
+ params.bias_ptr = bias_ptr;
82
+ params.out_ptr = out.data_ptr();
83
+ // All stride are in elements, not bytes.
84
+ params.x_batch_stride = x.stride(0);
85
+ params.x_c_stride = x.stride(1);
86
+ params.x_l_stride = x.stride(-1);
87
+ params.weight_c_stride = weight.stride(0);
88
+ params.weight_width_stride = weight.stride(1);
89
+ params.out_batch_stride = out.stride(0);
90
+ params.out_c_stride = out.stride(1);
91
+ params.out_l_stride = out.stride(-1);
92
+ }
93
+
94
+
95
+ void set_conv_params_bwd(ConvParamsBwd &params,
96
+ // sizes
97
+ const size_t batch,
98
+ const size_t dim,
99
+ const size_t seqlen,
100
+ const size_t width,
101
+ // device pointers
102
+ const at::Tensor x,
103
+ const at::Tensor weight,
104
+ void* bias_ptr,
105
+ const at::Tensor dout,
106
+ const at::Tensor dx,
107
+ const at::Tensor dweight,
108
+ void* dbias_ptr,
109
+ bool silu_activation) {
110
+ // Pass in "dout" instead of "out", we're not gonna use "out" at all.
111
+ set_conv_params_fwd(params, batch, dim, seqlen, width,
112
+ x, weight, dout, bias_ptr, silu_activation);
113
+
114
+ // Set the pointers and strides.
115
+ params.dout_ptr = dout.data_ptr();
116
+ params.dx_ptr = dx.data_ptr();
117
+ params.dweight_ptr = dweight.data_ptr();
118
+ params.dbias_ptr = dbias_ptr;
119
+ // All stride are in elements, not bytes.
120
+ params.dout_batch_stride = dout.stride(0);
121
+ params.dout_c_stride = dout.stride(1);
122
+ params.dout_l_stride = dout.stride(2);
123
+ params.dweight_c_stride = dweight.stride(0);
124
+ params.dweight_width_stride = dweight.stride(1);
125
+ params.dx_batch_stride = dx.stride(0);
126
+ params.dx_c_stride = dx.stride(1);
127
+ params.dx_l_stride = dx.stride(2);
128
+ }
129
+
130
+ at::Tensor
131
+ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
132
+ const c10::optional<at::Tensor> &bias_,
133
+ bool silu_activation) {
134
+ auto input_type = x.scalar_type();
135
+ auto weight_type = weight.scalar_type();
136
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
137
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
138
+
139
+ TORCH_CHECK(x.is_cuda());
140
+ TORCH_CHECK(weight.is_cuda());
141
+
142
+ const auto sizes = x.sizes();
143
+ const int batch_size = sizes[0];
144
+ const int dim = sizes[1];
145
+ const int seqlen = sizes[2];
146
+ const int width = weight.size(-1);
147
+
148
+ CHECK_SHAPE(x, batch_size, dim, seqlen);
149
+ CHECK_SHAPE(weight, dim, width);
150
+
151
+ TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
152
+ const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
153
+
154
+ if (is_channel_last) {
155
+ TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
156
+ }
157
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
158
+
159
+
160
+ if (bias_.has_value()) {
161
+ auto bias = bias_.value();
162
+ TORCH_CHECK(bias.scalar_type() == weight_type);
163
+ TORCH_CHECK(bias.is_cuda());
164
+ TORCH_CHECK(bias.stride(-1) == 1);
165
+ CHECK_SHAPE(bias, dim);
166
+ }
167
+
168
+ at::Tensor out = torch::empty_like(x);
169
+
170
+ ConvParamsBase params;
171
+ set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
172
+ bias_.has_value() ? bias_.value().data_ptr() : nullptr,
173
+ silu_activation);
174
+
175
+ // Otherwise the kernel will be launched from cuda:0 device
176
+ // Cast to char to avoid compiler warning about narrowing
177
+ at::cuda::CUDAGuard device_guard{(char)x.get_device()};
178
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
179
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
180
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
181
+ if (!is_channel_last) {
182
+ causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
183
+ } else {
184
+ causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
185
+ }
186
+ });
187
+ });
188
+ return out;
189
+ }
190
+
191
+ std::vector<at::Tensor>
192
+ causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight,
193
+ const c10::optional<at::Tensor> &bias_,
194
+ at::Tensor &dout,
195
+ c10::optional<at::Tensor> &dx_,
196
+ bool silu_activation) {
197
+ auto input_type = x.scalar_type();
198
+ auto weight_type = weight.scalar_type();
199
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
200
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
201
+
202
+ TORCH_CHECK(x.is_cuda());
203
+ TORCH_CHECK(weight.is_cuda());
204
+ TORCH_CHECK(dout.is_cuda());
205
+
206
+ const auto sizes = x.sizes();
207
+ const int batch_size = sizes[0];
208
+ const int dim = sizes[1];
209
+ const int seqlen = sizes[2];
210
+ const int width = weight.size(-1);
211
+
212
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
213
+
214
+ CHECK_SHAPE(x, batch_size, dim, seqlen);
215
+ CHECK_SHAPE(weight, dim, width);
216
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
217
+
218
+ TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
219
+ const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
220
+ if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
221
+ if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
222
+
223
+ if (bias_.has_value()) {
224
+ auto bias = bias_.value();
225
+ TORCH_CHECK(bias.scalar_type() == weight_type);
226
+ TORCH_CHECK(bias.is_cuda());
227
+ TORCH_CHECK(bias.stride(-1) == 1);
228
+ CHECK_SHAPE(bias, dim);
229
+ }
230
+
231
+ at::Tensor dx;
232
+ if (dx_.has_value()) {
233
+ dx = dx_.value();
234
+ TORCH_CHECK(dx.scalar_type() == input_type);
235
+ TORCH_CHECK(dx.is_cuda());
236
+ CHECK_SHAPE(dx, batch_size, dim, seqlen);
237
+ if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
238
+ if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
239
+ } else {
240
+ dx = torch::empty_like(x);
241
+ }
242
+
243
+ // Otherwise the kernel will be launched from cuda:0 device
244
+ // Cast to char to avoid compiler warning about narrowing
245
+ at::cuda::CUDAGuard device_guard{(char)x.get_device()};
246
+
247
+ at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat));
248
+ at::Tensor dbias;
249
+ if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); }
250
+
251
+ ConvParamsBwd params;
252
+ set_conv_params_bwd(params, batch_size, dim, seqlen, width,
253
+ x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
254
+ dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr,
255
+ silu_activation);
256
+
257
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
258
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
259
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
260
+ if (!is_channel_last) {
261
+ causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
262
+ } else {
263
+ causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
264
+ }
265
+ });
266
+ });
267
+ return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias};
268
+ }
269
+
270
+ at::Tensor
271
+ causal_conv1d_update(const at::Tensor &x,
272
+ const at::Tensor &conv_state,
273
+ const at::Tensor &weight,
274
+ const c10::optional<at::Tensor> &bias_,
275
+ bool silu_activation) {
276
+ auto input_type = x.scalar_type();
277
+ auto weight_type = weight.scalar_type();
278
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
279
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
280
+ TORCH_CHECK(conv_state.scalar_type() == input_type);
281
+
282
+ TORCH_CHECK(x.is_cuda());
283
+ TORCH_CHECK(conv_state.is_cuda());
284
+ TORCH_CHECK(weight.is_cuda());
285
+
286
+ const auto sizes = x.sizes();
287
+ const int batch_size = sizes[0];
288
+ const int dim = sizes[1];
289
+ const int width = weight.size(-1);
290
+
291
+ CHECK_SHAPE(x, batch_size, dim);
292
+ CHECK_SHAPE(conv_state, batch_size, dim, width);
293
+ CHECK_SHAPE(weight, dim, width);
294
+
295
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
296
+
297
+ if (bias_.has_value()) {
298
+ auto bias = bias_.value();
299
+ TORCH_CHECK(bias.scalar_type() == weight_type);
300
+ TORCH_CHECK(bias.is_cuda());
301
+ TORCH_CHECK(bias.stride(-1) == 1);
302
+ CHECK_SHAPE(bias, dim);
303
+ }
304
+
305
+ at::Tensor out = torch::empty_like(x);
306
+
307
+ ConvParamsBase params;
308
+ set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
309
+ bias_.has_value() ? bias_.value().data_ptr() : nullptr,
310
+ silu_activation);
311
+ params.conv_state_ptr = conv_state.data_ptr();
312
+ // All stride are in elements, not bytes.
313
+ params.conv_state_batch_stride = conv_state.stride(0);
314
+ params.conv_state_c_stride = conv_state.stride(1);
315
+ params.conv_state_l_stride = conv_state.stride(2);
316
+
317
+ // Otherwise the kernel will be launched from cuda:0 device
318
+ // Cast to char to avoid compiler warning about narrowing
319
+ at::cuda::CUDAGuard device_guard{(char)x.get_device()};
320
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
321
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
322
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
323
+ causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
324
+ });
325
+ });
326
+ return out;
327
+ }
328
+
329
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
330
+ m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
331
+ m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
332
+ m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
333
+ }
causal-conv1d/csrc/causal_conv1d.h ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
8
+
9
+ struct ConvParamsBase {
10
+ using index_t = uint32_t;
11
+
12
+ int batch, dim, seqlen, width;
13
+ bool silu_activation;
14
+
15
+ index_t x_batch_stride;
16
+ index_t x_c_stride;
17
+ index_t x_l_stride;
18
+ index_t weight_c_stride;
19
+ index_t weight_width_stride;
20
+ index_t out_batch_stride;
21
+ index_t out_c_stride;
22
+ index_t out_l_stride;
23
+
24
+ index_t conv_state_batch_stride;
25
+ index_t conv_state_c_stride;
26
+ index_t conv_state_l_stride;
27
+
28
+ // Common data pointers.
29
+ void *__restrict__ x_ptr;
30
+ void *__restrict__ weight_ptr;
31
+ void *__restrict__ bias_ptr;
32
+ void *__restrict__ out_ptr;
33
+
34
+ void *__restrict__ conv_state_ptr;
35
+ };
36
+
37
+ struct ConvParamsBwd: public ConvParamsBase {
38
+ index_t dx_batch_stride;
39
+ index_t dx_c_stride;
40
+ index_t dx_l_stride;
41
+ index_t dweight_c_stride;
42
+ index_t dweight_width_stride;
43
+ index_t dout_batch_stride;
44
+ index_t dout_c_stride;
45
+ index_t dout_l_stride;
46
+
47
+ // Common data pointers.
48
+ void *__restrict__ dx_ptr;
49
+ void *__restrict__ dweight_ptr;
50
+ void *__restrict__ dbias_ptr;
51
+ void *__restrict__ dout_ptr;
52
+ };
53
+
causal-conv1d/csrc/causal_conv1d_bwd.cu ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
+
9
+ #include <cub/block/block_load.cuh>
10
+ #include <cub/block/block_store.cuh>
11
+ #include <cub/block/block_reduce.cuh>
12
+
13
+ #include "causal_conv1d.h"
14
+ #include "causal_conv1d_common.h"
15
+ #include "static_switch.h"
16
+
17
+ template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
18
+ struct Causal_conv1d_bwd_kernel_traits {
19
+ using input_t = input_t_;
20
+ using weight_t = weight_t_;
21
+ static constexpr int kNThreads = kNThreads_;
22
+ static constexpr int kWidth = kWidth_;
23
+ static constexpr bool kSiluAct = kSiluAct_;
24
+ static constexpr int kNBytes = sizeof(input_t);
25
+ static_assert(kNBytes == 2 || kNBytes == 4);
26
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
27
+ static_assert(kWidth <= kNElts);
28
+ // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
29
+ // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
30
+ static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
31
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
32
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
33
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
34
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
35
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
36
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
37
+ using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
38
+ static constexpr int kSmemIOSize = kIsVecLoad
39
+ ? 0
40
+ : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
41
+ static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
42
+ static constexpr int kSmemSize = std::max({kSmemExchangeSize,
43
+ int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
44
+ };
45
+
46
+ template<typename Ktraits>
47
+ __global__ __launch_bounds__(Ktraits::kNThreads)
48
+ void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
49
+ constexpr int kWidth = Ktraits::kWidth;
50
+ constexpr int kNThreads = Ktraits::kNThreads;
51
+ constexpr bool kSiluAct = Ktraits::kSiluAct;
52
+ constexpr int kNElts = Ktraits::kNElts;
53
+ constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
54
+ constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
55
+ using input_t = typename Ktraits::input_t;
56
+ using vec_t = typename Ktraits::vec_t;
57
+ using weight_t = typename Ktraits::weight_t;
58
+
59
+ // Shared memory.
60
+ extern __shared__ char smem_[];
61
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
62
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
63
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
64
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
65
+ vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
66
+ vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
67
+ auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
68
+
69
+ const int tidx = threadIdx.x;
70
+ const int batch_id = blockIdx.x;
71
+ const int dim_id = blockIdx.y;
72
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
73
+ + dim_id * params.x_c_stride;
74
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
75
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
76
+ + dim_id * params.dout_c_stride;
77
+ input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
78
+ + dim_id * params.dx_c_stride;
79
+ float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
80
+ float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
81
+
82
+ // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
83
+ if (tidx == 0) {
84
+ if constexpr (!kSiluAct) {
85
+ input_t zeros[kNElts] = {0};
86
+ smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
87
+ } else {
88
+ float zeros[kNElts] = {0};
89
+ #pragma unroll
90
+ for (int r = 0; r < kNExchangeRounds; ++r) {
91
+ smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
92
+ }
93
+ }
94
+ }
95
+
96
+ float weight_vals[kWidth];
97
+ #pragma unroll
98
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
99
+
100
+ float dweight_vals[kWidth] = {0};
101
+ float dbias_val = 0;
102
+
103
+ constexpr int kChunkSize = kNThreads * kNElts;
104
+ const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
105
+ x += (n_chunks - 1) * kChunkSize;
106
+ dout += (n_chunks - 1) * kChunkSize;
107
+ dx += (n_chunks - 1) * kChunkSize;
108
+ for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
109
+ input_t x_vals_load[2 * kNElts] = {0};
110
+ input_t dout_vals_load[2 * kNElts] = {0};
111
+ if constexpr(kIsVecLoad) {
112
+ Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
113
+ Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
114
+ } else {
115
+ __syncthreads();
116
+ Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
117
+ __syncthreads();
118
+ Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
119
+ }
120
+ float dout_vals[2 * kNElts], x_vals[2 * kNElts];
121
+ if constexpr (!kSiluAct) {
122
+ __syncthreads();
123
+ // Thread 0 don't write yet, so that thread kNThreads - 1 can read
124
+ // the first elements of the next chunk.
125
+ if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
126
+ __syncthreads();
127
+ reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
128
+ __syncthreads();
129
+ // Now thread 0 can write the first elements of the current chunk.
130
+ if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
131
+ #pragma unroll
132
+ for (int i = 0; i < 2 * kNElts; ++i) {
133
+ dout_vals[i] = float(dout_vals_load[i]);
134
+ x_vals[i] = float(x_vals_load[i]);
135
+ }
136
+ } else {
137
+ if (tidx == 0 && chunk > 0) {
138
+ if constexpr(kIsVecLoad) {
139
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
140
+ } else {
141
+ #pragma unroll
142
+ for (int i = 0; i < kNElts; ++i) {
143
+ if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
144
+ }
145
+ }
146
+ }
147
+ __syncthreads();
148
+ smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
149
+ __syncthreads();
150
+ if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
151
+ #pragma unroll
152
+ for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
153
+ // Recompute the output
154
+ #pragma unroll
155
+ for (int i = 0; i < kNElts; ++i) {
156
+ float out_val = bias_val;
157
+ #pragma unroll
158
+ for (int w = 0; w < kWidth; ++w) {
159
+ out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
160
+ }
161
+ float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
162
+ dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
163
+ * (1.0f + out_val * (1.0f - out_sigmoid_val));
164
+ }
165
+ // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
166
+ // if input_t is 16 bits (since then we'd have 8 values of float)
167
+ __syncthreads();
168
+ // Thread 0 don't write yet, so that thread kNThreads - 1 can read
169
+ // the first elements of the next chunk.
170
+ if (tidx > 0) {
171
+ #pragma unroll
172
+ for (int r = 0; r < kNExchangeRounds; ++r) {
173
+ smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
174
+ }
175
+ }
176
+ __syncthreads();
177
+ #pragma unroll
178
+ for (int r = 0; r < kNExchangeRounds; ++r) {
179
+ reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
180
+ = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
181
+ }
182
+ __syncthreads();
183
+ // Now thread 0 can write the first elements of the current chunk.
184
+ if (tidx == 0) {
185
+ #pragma unroll
186
+ for (int r = 0; r < kNExchangeRounds; ++r) {
187
+ smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
188
+ }
189
+ }
190
+ }
191
+ dout -= kChunkSize;
192
+ x -= kChunkSize;
193
+
194
+ #pragma unroll
195
+ for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
196
+
197
+ float dx_vals[kNElts] = {0};
198
+ #pragma unroll
199
+ for (int i = 0; i < kNElts; ++i) {
200
+ #pragma unroll
201
+ for (int w = 0; w < kWidth; ++w) {
202
+ dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
203
+ }
204
+ }
205
+
206
+ input_t dx_vals_store[kNElts];
207
+ #pragma unroll
208
+ for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
209
+ if constexpr(kIsVecLoad) {
210
+ Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
211
+ } else {
212
+ Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
213
+ }
214
+ dx -= kChunkSize;
215
+
216
+ #pragma unroll
217
+ for (int w = 0; w < kWidth; ++w) {
218
+ #pragma unroll
219
+ for (int i = 0; i < kNElts; ++i) {
220
+ dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
221
+ }
222
+ }
223
+ }
224
+
225
+ #pragma unroll
226
+ for (int w = 0; w < kWidth; ++w) {
227
+ __syncthreads();
228
+ dweight_vals[w] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
229
+ if (tidx == 0) {
230
+ atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
231
+ }
232
+ }
233
+ if (params.bias_ptr != nullptr) {
234
+ __syncthreads();
235
+ dbias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
236
+ if (tidx == 0) {
237
+ atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
238
+ }
239
+ }
240
+ }
241
+
242
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
243
+ void causal_conv1d_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
244
+ static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
245
+ BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
246
+ BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
247
+ using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
248
+ constexpr int kSmemSize = Ktraits::kSmemSize;
249
+ dim3 grid(params.batch, params.dim);
250
+ auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
251
+ if (kSmemSize >= 48 * 1024) {
252
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
253
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
254
+ }
255
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
256
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
257
+ });
258
+ });
259
+ }
260
+
261
+ template<typename input_t, typename weight_t>
262
+ void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
263
+ if (params.width == 2) {
264
+ causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
265
+ } else if (params.width == 3) {
266
+ causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
267
+ } else if (params.width == 4) {
268
+ causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
269
+ }
270
+ }
271
+
272
+ template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
273
+ struct Causal_conv1d_channellast_bwd_kernel_traits {
274
+ // The cache line is 128 bytes, and we try to read 16 bytes per thread.
275
+ // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
276
+ // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
277
+ // threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
278
+ using input_t = input_t_;
279
+ using weight_t = weight_t_;
280
+ static constexpr bool kSiluAct = kSiluAct_;
281
+ static constexpr int kNThreads = kNThreads_;
282
+ static_assert(kNThreads % 32 == 0);
283
+ static constexpr int kNWarps = kNThreads / 32;
284
+ static constexpr int kWidth = kWidth_;
285
+ static constexpr int kChunkSizeL = kChunkSizeL_;
286
+ static constexpr int kNBytes = sizeof(input_t);
287
+ static_assert(kNBytes == 2 || kNBytes == 4);
288
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
289
+ static constexpr int kNEltsPerRow = 128 / kNBytes;
290
+ static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
291
+ static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
292
+ static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
293
+ static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
294
+ static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
295
+ static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
296
+ static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
297
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
298
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
299
+ // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
300
+ // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
301
+ // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
302
+ // sizeof(typename BlockStoreT::TempStorage)});
303
+ // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
304
+ };
305
+
306
+ template<typename Ktraits>
307
+ __global__ __launch_bounds__(Ktraits::kNThreads)
308
+ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
309
+ constexpr int kWidth = Ktraits::kWidth;
310
+ constexpr int kNThreads = Ktraits::kNThreads;
311
+ constexpr bool kSiluAct = Ktraits::kSiluAct;
312
+ constexpr int kNElts = Ktraits::kNElts;
313
+ constexpr int kNWarp = Ktraits::kNWarps;
314
+ constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
315
+ constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
316
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
317
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
318
+ using input_t = typename Ktraits::input_t;
319
+ using vec_t = typename Ktraits::vec_t;
320
+ using weight_t = typename Ktraits::weight_t;
321
+
322
+ // Shared memory.
323
+ __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
324
+ __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
325
+
326
+ const int tid = threadIdx.x;
327
+ const int l_idx = tid / kNThreadsPerC;
328
+ const int c_idx = tid % kNThreadsPerC;
329
+ const int batch_id = blockIdx.x;
330
+ const int chunk_l_id = blockIdx.y;
331
+ const int chunk_c_id = blockIdx.z;
332
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
333
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
334
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
335
+ + chunk_c_id * kChunkSizeC * params.weight_c_stride;
336
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
337
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
338
+ input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
339
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
340
+ float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
341
+ + chunk_c_id * kChunkSizeC * params.dweight_c_stride;
342
+
343
+ #pragma unroll
344
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
345
+ input_t dout_vals_load[kNElts] = {0};
346
+ input_t x_vals_load[kNElts] = {0};
347
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
348
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
349
+ reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
350
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
351
+ }
352
+ reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
353
+ reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
354
+ }
355
+ // Load the elements from the previous chunk or next chunk that are needed for convolution.
356
+ if (l_idx < kWidth - 1) {
357
+ input_t dout_vals_load[kNElts] = {0};
358
+ input_t x_vals_load[kNElts] = {0};
359
+ if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
360
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
361
+ reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
362
+ }
363
+ if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
364
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
365
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
366
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
367
+ }
368
+ reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
369
+ reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
370
+ }
371
+ // Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
372
+ if constexpr (kSiluAct) {
373
+ if (l_idx < kWidth - 1) {
374
+ input_t x_vals_load[kNElts] = {0};
375
+ if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
376
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
377
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
378
+ }
379
+ reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
380
+ }
381
+ }
382
+
383
+ __syncthreads();
384
+
385
+ constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
386
+ static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
387
+ constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
388
+ static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
389
+ // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
390
+ static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
391
+ static_assert((kLPerThread & (kLPerThread - 1)) == 0);
392
+ static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
393
+ static_assert(kNThreadsPerRow <= 32);
394
+
395
+ const int row_idx = tid / kNThreadsPerRow;
396
+ const int col_idx = tid % kNThreadsPerRow;
397
+
398
+ float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
399
+ float weight_vals[kWidth] = {0};
400
+ if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
401
+ #pragma unroll
402
+ for (int w = 0; w < kWidth; ++w) {
403
+ weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
404
+ }
405
+ }
406
+ float dout_vals[kLPerThread + kWidth - 1];
407
+ float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
408
+ #pragma unroll
409
+ for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
410
+ dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
411
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
412
+ }
413
+
414
+ if constexpr (kSiluAct) { // Recompute the output
415
+ #pragma unroll
416
+ for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
417
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
418
+ }
419
+ #pragma unroll
420
+ for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
421
+ float out_val = bias_val;
422
+ #pragma unroll
423
+ for (int w = 0; w < kWidth; ++w) { out_val += weight_vals[w] * x_vals[i + w]; }
424
+ float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
425
+ dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
426
+ }
427
+ }
428
+
429
+ float dweight_vals[kWidth] = {0};
430
+ SumOp<float> sum_op;
431
+ #pragma unroll
432
+ for (int w = 0; w < kWidth; ++w) {
433
+ #pragma unroll
434
+ for (int i = 0; i < kLPerThread; ++i) { dweight_vals[w] += x_vals[i + w] * dout_vals[i]; }
435
+ dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
436
+ if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
437
+ atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
438
+ }
439
+ }
440
+
441
+ if (params.bias_ptr != nullptr) {
442
+ float dbias_val = 0.f;
443
+ for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
444
+ dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
445
+ if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
446
+ atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
447
+ }
448
+ }
449
+
450
+ float dx_vals[kLPerThread] = {0};
451
+ #pragma unroll
452
+ for (int i = 0; i < kLPerThread; ++i) {
453
+ #pragma unroll
454
+ for (int w = 0; w < kWidth; ++w) { dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]; }
455
+ }
456
+ // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads.
457
+ __syncwarp();
458
+ #pragma unroll
459
+ for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
460
+ __syncthreads();
461
+
462
+ #pragma unroll
463
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
464
+ input_t dx_vals_store[kNElts];
465
+ reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
466
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
467
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
468
+ *reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
469
+ }
470
+ }
471
+
472
+ }
473
+
474
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
475
+ void causal_conv1d_channellast_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
476
+ BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
477
+ using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, 64, kSiluAct, true, input_t, weight_t>;
478
+ // constexpr int kSmemSize = Ktraits::kSmemSize;
479
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
480
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
481
+ const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
482
+ const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
483
+ dim3 grid(params.batch, n_chunks_L, n_chunks_C);
484
+ dim3 block(Ktraits::kNThreads);
485
+ auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits>;
486
+ // if (kSmemSize >= 48 * 1024) {
487
+ // C10_CUDA_CHECK(cudaFuncSetAttribute(
488
+ // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
489
+ // }
490
+ // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
491
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
492
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
493
+ });
494
+ }
495
+
496
+ template<typename input_t, typename weight_t>
497
+ void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
498
+ if (params.width == 2) {
499
+ causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
500
+ } else if (params.width == 3) {
501
+ causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
502
+ } else if (params.width == 4) {
503
+ causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
504
+ }
505
+ }
506
+
507
+ template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
508
+ template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
509
+ template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
510
+ template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
511
+ template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
512
+ template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
513
+ template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
514
+ template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
515
+ template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
516
+
517
+ template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
518
+ template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
519
+ template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
520
+ template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
521
+ template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
522
+ template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
523
+ template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
524
+ template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
525
+ template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
causal-conv1d/csrc/causal_conv1d_common.h ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cuda_bf16.h>
8
+ #include <cuda_fp16.h>
9
+
10
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
11
+
12
+ template<int BYTES> struct BytesToType {};
13
+
14
+ template<> struct BytesToType<16> {
15
+ using Type = uint4;
16
+ static_assert(sizeof(Type) == 16);
17
+ };
18
+
19
+ template<> struct BytesToType<8> {
20
+ using Type = uint64_t;
21
+ static_assert(sizeof(Type) == 8);
22
+ };
23
+
24
+ template<> struct BytesToType<4> {
25
+ using Type = uint32_t;
26
+ static_assert(sizeof(Type) == 4);
27
+ };
28
+
29
+ template<> struct BytesToType<2> {
30
+ using Type = uint16_t;
31
+ static_assert(sizeof(Type) == 2);
32
+ };
33
+
34
+ template<> struct BytesToType<1> {
35
+ using Type = uint8_t;
36
+ static_assert(sizeof(Type) == 1);
37
+ };
38
+
39
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ template<typename T>
42
+ struct SumOp {
43
+ __device__ inline T operator()(T const & x, T const & y) { return x + y; }
44
+ };
45
+
46
+ template<int THREADS>
47
+ struct Allreduce {
48
+ static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
49
+ template<typename T, typename Operator>
50
+ static __device__ inline T run(T x, Operator &op) {
51
+ constexpr int OFFSET = THREADS / 2;
52
+ x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
53
+ return Allreduce<OFFSET>::run(x, op);
54
+ }
55
+ };
56
+
57
+ template<>
58
+ struct Allreduce<2> {
59
+ template<typename T, typename Operator>
60
+ static __device__ inline T run(T x, Operator &op) {
61
+ x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
62
+ return x;
63
+ }
64
+ };
causal-conv1d/csrc/causal_conv1d_fwd.cu ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
+
9
+ #include <cub/block/block_load.cuh>
10
+ #include <cub/block/block_store.cuh>
11
+
12
+ #include "causal_conv1d.h"
13
+ #include "causal_conv1d_common.h"
14
+ #include "static_switch.h"
15
+
16
+ template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
17
+ struct Causal_conv1d_fwd_kernel_traits {
18
+ using input_t = input_t_;
19
+ using weight_t = weight_t_;
20
+ static constexpr int kNThreads = kNThreads_;
21
+ static constexpr int kWidth = kWidth_;
22
+ static constexpr int kNBytes = sizeof(input_t);
23
+ static_assert(kNBytes == 2 || kNBytes == 4);
24
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
25
+ static_assert(kWidth <= kNElts);
26
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
27
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
28
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
29
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
30
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
31
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
32
+ static constexpr int kSmemIOSize = kIsVecLoad
33
+ ? 0
34
+ : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
35
+ static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
36
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
37
+ };
38
+
39
+ template<typename Ktraits>
40
+ __global__ __launch_bounds__(Ktraits::kNThreads)
41
+ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
42
+ constexpr int kWidth = Ktraits::kWidth;
43
+ constexpr int kNThreads = Ktraits::kNThreads;
44
+ constexpr int kNElts = Ktraits::kNElts;
45
+ constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
46
+ using input_t = typename Ktraits::input_t;
47
+ using vec_t = typename Ktraits::vec_t;
48
+ using weight_t = typename Ktraits::weight_t;
49
+
50
+ // Shared memory.
51
+ extern __shared__ char smem_[];
52
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
53
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
54
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
55
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
56
+ vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
57
+
58
+ const int tidx = threadIdx.x;
59
+ const int batch_id = blockIdx.x;
60
+ const int channel_id = blockIdx.y;
61
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
62
+ + channel_id * params.x_c_stride;
63
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
64
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
65
+ + channel_id * params.out_c_stride;
66
+ float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
67
+
68
+ // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
69
+ if (tidx == 0) {
70
+ input_t zeros[kNElts] = {0};
71
+ smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
72
+ }
73
+
74
+ float weight_vals[kWidth];
75
+ #pragma unroll
76
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
77
+
78
+ constexpr int kChunkSize = kNThreads * kNElts;
79
+ const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
80
+ for (int chunk = 0; chunk < n_chunks; ++chunk) {
81
+ input_t x_vals_load[2 * kNElts] = {0};
82
+ if constexpr(kIsVecLoad) {
83
+ Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
84
+ } else {
85
+ __syncthreads();
86
+ Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
87
+ }
88
+ x += kChunkSize;
89
+ __syncthreads();
90
+ // Thread kNThreads - 1 don't write yet, so that thread 0 can read
91
+ // the last elements of the previous chunk.
92
+ if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
93
+ __syncthreads();
94
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
95
+ __syncthreads();
96
+ // Now thread kNThreads - 1 can write the last elements of the current chunk.
97
+ if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
98
+
99
+ float x_vals[2 * kNElts];
100
+ #pragma unroll
101
+ for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
102
+
103
+ float out_vals[kNElts];
104
+ #pragma unroll
105
+ for (int i = 0; i < kNElts; ++i) {
106
+ out_vals[i] = bias_val;
107
+ #pragma unroll
108
+ for (int w = 0; w < kWidth; ++w) {
109
+ out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
110
+ }
111
+ }
112
+
113
+ if (params.silu_activation) {
114
+ #pragma unroll
115
+ for (int i = 0; i < kNElts; ++i) {
116
+ out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
117
+ }
118
+ }
119
+
120
+ input_t out_vals_store[kNElts];
121
+ #pragma unroll
122
+ for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
123
+ if constexpr(kIsVecLoad) {
124
+ Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
125
+ } else {
126
+ Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
127
+ }
128
+ out += kChunkSize;
129
+ }
130
+ }
131
+
132
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
133
+ void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
134
+ static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
135
+ BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
136
+ using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
137
+ constexpr int kSmemSize = Ktraits::kSmemSize;
138
+ dim3 grid(params.batch, params.dim);
139
+ auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
140
+ if (kSmemSize >= 48 * 1024) {
141
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
142
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
143
+ }
144
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
145
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
146
+ });
147
+ }
148
+
149
+ template<typename input_t, typename weight_t>
150
+ void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
151
+ if (params.width == 2) {
152
+ causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
153
+ } else if (params.width == 3) {
154
+ causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
155
+ } else if (params.width == 4) {
156
+ causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
157
+ }
158
+ }
159
+
160
+ template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
161
+ struct Causal_conv1d_channellast_fwd_kernel_traits {
162
+ // The cache line is 128 bytes, and we try to read 16 bytes per thread.
163
+ // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
164
+ // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
165
+ // threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
166
+ using input_t = input_t_;
167
+ using weight_t = weight_t_;
168
+ static constexpr int kNThreads = kNThreads_;
169
+ static_assert(kNThreads % 32 == 0);
170
+ static constexpr int kNWarps = kNThreads / 32;
171
+ static constexpr int kWidth = kWidth_;
172
+ static constexpr int kChunkSizeL = kChunkSizeL_;
173
+ static constexpr int kNBytes = sizeof(input_t);
174
+ static_assert(kNBytes == 2 || kNBytes == 4);
175
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
176
+ static constexpr int kNEltsPerRow = 128 / kNBytes;
177
+ static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
178
+ static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
179
+ static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
180
+ static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
181
+ static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
182
+ static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
183
+ static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
184
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
185
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
186
+ // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
187
+ // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
188
+ // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
189
+ // sizeof(typename BlockStoreT::TempStorage)});
190
+ // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
191
+ };
192
+
193
+ template<typename Ktraits>
194
+ __global__ __launch_bounds__(Ktraits::kNThreads)
195
+ void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
196
+ constexpr int kWidth = Ktraits::kWidth;
197
+ constexpr int kNThreads = Ktraits::kNThreads;
198
+ constexpr int kNElts = Ktraits::kNElts;
199
+ constexpr int kNWarp = Ktraits::kNWarps;
200
+ constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
201
+ constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
202
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
203
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
204
+ using input_t = typename Ktraits::input_t;
205
+ using vec_t = typename Ktraits::vec_t;
206
+ using weight_t = typename Ktraits::weight_t;
207
+
208
+ // Shared memory.
209
+ __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
210
+
211
+ const int tid = threadIdx.x;
212
+ const int l_idx = tid / kNThreadsPerC;
213
+ const int c_idx = tid % kNThreadsPerC;
214
+ const int batch_id = blockIdx.x;
215
+ const int chunk_l_id = blockIdx.y;
216
+ const int chunk_c_id = blockIdx.z;
217
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
218
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
219
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
220
+ + chunk_c_id * kChunkSizeC * params.weight_c_stride;
221
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
222
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
223
+
224
+ #pragma unroll
225
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
226
+ input_t x_vals_load[kNElts] = {0};
227
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
228
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
229
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
230
+ }
231
+ reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
232
+ }
233
+ // Load the elements from the previous chunk that are needed for convolution.
234
+ if (l_idx < kWidth - 1) {
235
+ input_t x_vals_load[kNElts] = {0};
236
+ if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
237
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
238
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
239
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
240
+ }
241
+ reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
242
+ }
243
+
244
+ __syncthreads();
245
+
246
+ constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
247
+ static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
248
+ constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
249
+ static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
250
+ // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
251
+ static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
252
+ static_assert((kLPerThread & (kLPerThread - 1)) == 0);
253
+ static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
254
+ static_assert(kNThreadsPerRow <= 32);
255
+
256
+ const int row_idx = tid / kNThreadsPerRow;
257
+ const int col_idx = tid % kNThreadsPerRow;
258
+
259
+ float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
260
+ float weight_vals[kWidth] = {0};
261
+ if (chunk_c_id + kChunkSizeC + row_idx < params.dim) {
262
+ #pragma unroll
263
+ for (int w = 0; w < kWidth; ++w) {
264
+ weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
265
+ }
266
+ }
267
+ float x_vals[kWidth - 1 + kLPerThread];
268
+ #pragma unroll
269
+ for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
270
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
271
+ }
272
+
273
+ float out_vals[kLPerThread];
274
+ #pragma unroll
275
+ for (int i = 0; i < kLPerThread; ++i) {
276
+ out_vals[i] = bias_val;
277
+ #pragma unroll
278
+ for (int w = 0; w < kWidth; ++w) { out_vals[i] += weight_vals[w] * x_vals[i + w]; }
279
+ if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
280
+ }
281
+
282
+ // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads.
283
+ __syncwarp();
284
+ #pragma unroll
285
+ for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
286
+ __syncthreads();
287
+
288
+ #pragma unroll
289
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
290
+ input_t out_vals_store[kNElts];
291
+ reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
292
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
293
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
294
+ *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
295
+ }
296
+ }
297
+
298
+ }
299
+
300
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
301
+ void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
302
+ using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
303
+ // constexpr int kSmemSize = Ktraits::kSmemSize;
304
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
305
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
306
+ const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
307
+ const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
308
+ // printf("n_chunks_L: %d, n_chunks_C: %d\n", n_chunks_L, n_chunks_C);
309
+ dim3 grid(params.batch, n_chunks_L, n_chunks_C);
310
+ dim3 block(Ktraits::kNThreads);
311
+ auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits>;
312
+ // if (kSmemSize >= 48 * 1024) {
313
+ // C10_CUDA_CHECK(cudaFuncSetAttribute(
314
+ // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
315
+ // }
316
+ // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
317
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
318
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
319
+ }
320
+
321
+ template<typename input_t, typename weight_t>
322
+ void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
323
+ if (params.width == 2) {
324
+ causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
325
+ } else if (params.width == 3) {
326
+ causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
327
+ } else if (params.width == 4) {
328
+ causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
329
+ }
330
+ }
331
+
332
+ template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
333
+ template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
334
+ template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
335
+ template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
336
+ template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
337
+ template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
338
+ template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
339
+ template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
340
+ template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
341
+
342
+ template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
343
+ template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
344
+ template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
345
+ template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
346
+ template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
347
+ template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
348
+ template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
349
+ template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
350
+ template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
causal-conv1d/csrc/causal_conv1d_update.cu ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
+
9
+ #include <cub/block/block_load.cuh>
10
+ #include <cub/block/block_store.cuh>
11
+
12
+ #include "causal_conv1d.h"
13
+ #include "causal_conv1d_common.h"
14
+ #include "static_switch.h"
15
+
16
+ template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
17
+ struct Causal_conv1d_update_kernel_traits {
18
+ using input_t = input_t_;
19
+ using weight_t = weight_t_;
20
+ static constexpr int kNThreads = kNThreads_;
21
+ static constexpr int kWidth = kWidth_;
22
+ static constexpr int kNBytes = sizeof(input_t);
23
+ static_assert(kNBytes == 2 || kNBytes == 4);
24
+ };
25
+
26
+ template<typename Ktraits>
27
+ __global__ __launch_bounds__(Ktraits::kNThreads)
28
+ void causal_conv1d_update_kernel(ConvParamsBase params) {
29
+ constexpr int kWidth = Ktraits::kWidth;
30
+ constexpr int kNThreads = Ktraits::kNThreads;
31
+ using input_t = typename Ktraits::input_t;
32
+ using weight_t = typename Ktraits::weight_t;
33
+
34
+ const int tidx = threadIdx.x;
35
+ const int batch_id = blockIdx.x;
36
+ const int channel_id = blockIdx.y * kNThreads + tidx;
37
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
38
+ + channel_id * params.x_c_stride;
39
+ input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
40
+ + channel_id * params.conv_state_c_stride;
41
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
42
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
43
+ + channel_id * params.out_c_stride;
44
+ float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
45
+
46
+ float weight_vals[kWidth] = {0};
47
+ if (channel_id < params.dim) {
48
+ #pragma unroll
49
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
50
+ }
51
+
52
+ float x_vals[kWidth] = {0};
53
+ if (channel_id < params.dim) {
54
+ #pragma unroll
55
+ for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
56
+ x_vals[kWidth - 1] = float(x[0]);
57
+ #pragma unroll
58
+ for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
59
+ }
60
+
61
+ float out_val = bias_val;
62
+ #pragma unroll
63
+ for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
64
+ if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
65
+ if (channel_id < params.dim) { out[0] = input_t(out_val); }
66
+ }
67
+
68
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
69
+ void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
70
+ using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
71
+ dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
72
+ auto kernel = &causal_conv1d_update_kernel<Ktraits>;
73
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
74
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
75
+ }
76
+
77
+ template<typename input_t, typename weight_t>
78
+ void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
79
+ if (params.width == 2) {
80
+ causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
81
+ } else if (params.width == 3) {
82
+ causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
83
+ } else if (params.width == 4) {
84
+ causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
85
+ }
86
+ }
87
+
88
+ template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
89
+ template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
90
+ template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
91
+ template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
92
+ template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
93
+ template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
94
+ template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
95
+ template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
96
+ template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
causal-conv1d/csrc/static_switch.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2
+ // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3
+
4
+ #pragma once
5
+
6
+ /// @param COND - a boolean expression to switch by
7
+ /// @param CONST_NAME - a name given for the constexpr bool variable.
8
+ /// @param ... - code to execute for true and false
9
+ ///
10
+ /// Usage:
11
+ /// ```
12
+ /// BOOL_SWITCH(flag, BoolConst, [&] {
13
+ /// some_function<BoolConst>(...);
14
+ /// });
15
+ /// ```
16
+ #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17
+ [&] { \
18
+ if (COND) { \
19
+ static constexpr bool CONST_NAME = true; \
20
+ return __VA_ARGS__(); \
21
+ } else { \
22
+ static constexpr bool CONST_NAME = false; \
23
+ return __VA_ARGS__(); \
24
+ } \
25
+ }()
causal-conv1d/setup.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+ import sys
3
+ import warnings
4
+ import os
5
+ import re
6
+ import ast
7
+ from pathlib import Path
8
+ from packaging.version import parse, Version
9
+ import platform
10
+
11
+ from setuptools import setup, find_packages
12
+ import subprocess
13
+
14
+ import urllib.request
15
+ import urllib.error
16
+ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
17
+
18
+ import torch
19
+ from torch.utils.cpp_extension import (
20
+ BuildExtension,
21
+ CppExtension,
22
+ CUDAExtension,
23
+ CUDA_HOME,
24
+ )
25
+
26
+
27
+ with open("README.md", "r", encoding="utf-8") as fh:
28
+ long_description = fh.read()
29
+
30
+
31
+ # ninja build does not work unless include_dirs are abs path
32
+ this_dir = os.path.dirname(os.path.abspath(__file__))
33
+
34
+ PACKAGE_NAME = "causal_conv1d"
35
+
36
+ BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"
37
+
38
+ # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
39
+ # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
40
+ FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
41
+ SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
42
+ # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
43
+ FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
44
+
45
+
46
+ def get_platform():
47
+ """
48
+ Returns the platform name as used in wheel filenames.
49
+ """
50
+ if sys.platform.startswith("linux"):
51
+ return "linux_x86_64"
52
+ elif sys.platform == "darwin":
53
+ mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
54
+ return f"macosx_{mac_version}_x86_64"
55
+ elif sys.platform == "win32":
56
+ return "win_amd64"
57
+ else:
58
+ raise ValueError("Unsupported platform: {}".format(sys.platform))
59
+
60
+
61
+ def get_cuda_bare_metal_version(cuda_dir):
62
+ raw_output = subprocess.check_output(
63
+ [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
64
+ )
65
+ output = raw_output.split()
66
+ release_idx = output.index("release") + 1
67
+ bare_metal_version = parse(output[release_idx].split(",")[0])
68
+
69
+ return raw_output, bare_metal_version
70
+
71
+
72
+ def check_if_cuda_home_none(global_option: str) -> None:
73
+ if CUDA_HOME is not None:
74
+ return
75
+ # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
76
+ # in that case.
77
+ warnings.warn(
78
+ f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
79
+ "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
80
+ "only images whose names contain 'devel' will provide nvcc."
81
+ )
82
+
83
+
84
+ def append_nvcc_threads(nvcc_extra_args):
85
+ return nvcc_extra_args + ["--threads", "4"]
86
+
87
+
88
+ cmdclass = {}
89
+ ext_modules = []
90
+
91
+ if not SKIP_CUDA_BUILD:
92
+ print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
93
+ TORCH_MAJOR = int(torch.__version__.split(".")[0])
94
+ TORCH_MINOR = int(torch.__version__.split(".")[1])
95
+
96
+ check_if_cuda_home_none("causal_conv1d")
97
+ # Check, if CUDA11 is installed for compute capability 8.0
98
+ cc_flag = []
99
+ if CUDA_HOME is not None:
100
+ _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
101
+ if bare_metal_version < Version("11.6"):
102
+ raise RuntimeError(
103
+ "causal_conv1d is only supported on CUDA 11.6 and above. "
104
+ "Note: make sure nvcc has a supported version by running nvcc -V."
105
+ )
106
+
107
+ cc_flag.append("-gencode")
108
+ cc_flag.append("arch=compute_70,code=sm_70")
109
+ cc_flag.append("-gencode")
110
+ cc_flag.append("arch=compute_80,code=sm_80")
111
+ if bare_metal_version >= Version("11.8"):
112
+ cc_flag.append("-gencode")
113
+ cc_flag.append("arch=compute_90,code=sm_90")
114
+
115
+ # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
116
+ # torch._C._GLIBCXX_USE_CXX11_ABI
117
+ # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
118
+ if FORCE_CXX11_ABI:
119
+ torch._C._GLIBCXX_USE_CXX11_ABI = True
120
+
121
+ ext_modules.append(
122
+ CUDAExtension(
123
+ name="causal_conv1d_cuda",
124
+ sources=[
125
+ "csrc/causal_conv1d.cpp",
126
+ "csrc/causal_conv1d_fwd.cu",
127
+ "csrc/causal_conv1d_bwd.cu",
128
+ "csrc/causal_conv1d_update.cu",
129
+ ],
130
+ extra_compile_args={
131
+ "cxx": ["-O3"],
132
+ "nvcc": append_nvcc_threads(
133
+ [
134
+ "-O3",
135
+ "-U__CUDA_NO_HALF_OPERATORS__",
136
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
137
+ "-U__CUDA_NO_BFLOAT16_OPERATORS__",
138
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
139
+ "-U__CUDA_NO_BFLOAT162_OPERATORS__",
140
+ "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
141
+ "--expt-relaxed-constexpr",
142
+ "--expt-extended-lambda",
143
+ "--use_fast_math",
144
+ "--ptxas-options=-v",
145
+ "-lineinfo",
146
+ ]
147
+ + cc_flag
148
+ ),
149
+ },
150
+ include_dirs=[this_dir],
151
+ )
152
+ )
153
+
154
+
155
+ def get_package_version():
156
+ with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f:
157
+ version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
158
+ public_version = ast.literal_eval(version_match.group(1))
159
+ local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
160
+ if local_version:
161
+ return f"{public_version}+{local_version}"
162
+ else:
163
+ return str(public_version)
164
+
165
+
166
+ def get_wheel_url():
167
+ # Determine the version numbers that will be used to determine the correct wheel
168
+ # We're using the CUDA version used to build torch, not the one currently installed
169
+ # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
170
+ torch_cuda_version = parse(torch.version.cuda)
171
+ torch_version_raw = parse(torch.__version__)
172
+ # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
173
+ # to save CI time. Minor versions should be compatible.
174
+ torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
175
+ python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
176
+ platform_name = get_platform()
177
+ causal_conv1d_version = get_package_version()
178
+ # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
179
+ cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
180
+ torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
181
+ cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
182
+
183
+ # Determine wheel URL based on CUDA version, torch version, python version and OS
184
+ wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
185
+ wheel_url = BASE_WHEEL_URL.format(
186
+ tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
187
+ )
188
+ return wheel_url, wheel_filename
189
+
190
+
191
+ class CachedWheelsCommand(_bdist_wheel):
192
+ """
193
+ The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
194
+ find an existing wheel (which is currently the case for all installs). We use
195
+ the environment parameters to detect whether there is already a pre-built version of a compatible
196
+ wheel available and short-circuits the standard full build pipeline.
197
+ """
198
+
199
+ def run(self):
200
+ if FORCE_BUILD:
201
+ return super().run()
202
+
203
+ wheel_url, wheel_filename = get_wheel_url()
204
+ print("Guessing wheel URL: ", wheel_url)
205
+ try:
206
+ urllib.request.urlretrieve(wheel_url, wheel_filename)
207
+
208
+ # Make the archive
209
+ # Lifted from the root wheel processing command
210
+ # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
211
+ if not os.path.exists(self.dist_dir):
212
+ os.makedirs(self.dist_dir)
213
+
214
+ impl_tag, abi_tag, plat_tag = self.get_tag()
215
+ archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
216
+
217
+ wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
218
+ print("Raw wheel path", wheel_path)
219
+ os.rename(wheel_filename, wheel_path)
220
+ except urllib.error.HTTPError:
221
+ print("Precompiled wheel not found. Building from source...")
222
+ # If the wheel could not be downloaded, build from source
223
+ super().run()
224
+
225
+
226
+ setup(
227
+ name=PACKAGE_NAME,
228
+ version=get_package_version(),
229
+ packages=find_packages(
230
+ exclude=(
231
+ "build",
232
+ "csrc",
233
+ "include",
234
+ "tests",
235
+ "dist",
236
+ "docs",
237
+ "benchmarks",
238
+ "causal_conv1d.egg-info",
239
+ )
240
+ ),
241
+ author="Tri Dao",
242
+ author_email="[email protected]",
243
+ description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
244
+ long_description=long_description,
245
+ long_description_content_type="text/markdown",
246
+ url="https://github.com/Dao-AILab/causal-conv1d",
247
+ classifiers=[
248
+ "Programming Language :: Python :: 3",
249
+ "License :: OSI Approved :: BSD License",
250
+ "Operating System :: Unix",
251
+ ],
252
+ ext_modules=ext_modules,
253
+ cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
254
+ if ext_modules
255
+ else {
256
+ "bdist_wheel": CachedWheelsCommand,
257
+ },
258
+ python_requires=">=3.7",
259
+ install_requires=[
260
+ "torch",
261
+ "packaging",
262
+ "ninja",
263
+ ],
264
+ )
causal-conv1d/tests/test_causal_conv1d.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2023, Tri Dao.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import pytest
7
+
8
+ from einops import rearrange
9
+
10
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
11
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
12
+
13
+
14
+ @pytest.mark.parametrize("channel_last", [False, True])
15
+ # @pytest.mark.parametrize('channel_last', [True])
16
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
17
+ # @pytest.mark.parametrize('itype', [torch.float16])
18
+ @pytest.mark.parametrize("silu_activation", [False, True])
19
+ # @pytest.mark.parametrize('silu_activation', [True])
20
+ @pytest.mark.parametrize("has_bias", [False, True])
21
+ # @pytest.mark.parametrize('has_bias', [True])
22
+ @pytest.mark.parametrize("width", [2, 3, 4])
23
+ # @pytest.mark.parametrize('width', [2])
24
+ @pytest.mark.parametrize(
25
+ "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
26
+ )
27
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
28
+ # @pytest.mark.parametrize('seqlen', [128])
29
+ def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last):
30
+ device = "cuda"
31
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
32
+ if itype == torch.bfloat16:
33
+ rtol, atol = 1e-2, 5e-2
34
+ rtolw, atolw = (1e-3, 1e-3)
35
+ # set seed
36
+ torch.random.manual_seed(0)
37
+ batch_size = 2
38
+ # batch_size = 1
39
+ dim = 4096 + 32 # Try dim not divisible by 64
40
+ # dim = 64
41
+ if not channel_last:
42
+ x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
43
+ else:
44
+ x = rearrange(
45
+ torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
46
+ ).requires_grad_()
47
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
48
+ if has_bias:
49
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
50
+ else:
51
+ bias = None
52
+ x_ref = x.detach().clone().requires_grad_()
53
+ weight_ref = weight.detach().clone().requires_grad_()
54
+ bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
55
+ activation = None if not silu_activation else "silu"
56
+ out = causal_conv1d_fn(x, weight, bias, activation=activation)
57
+ out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation)
58
+
59
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
60
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
61
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
62
+
63
+ g = torch.randn_like(out)
64
+ out_ref.backward(g)
65
+ out.backward(g)
66
+
67
+ print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
68
+ print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
69
+ if has_bias:
70
+ print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
71
+
72
+ assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
73
+ assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
74
+ if has_bias:
75
+ assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
76
+
77
+
78
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
79
+ # @pytest.mark.parametrize('itype', [torch.float16])
80
+ @pytest.mark.parametrize("silu_activation", [False, True])
81
+ # @pytest.mark.parametrize('silu_activation', [False])
82
+ @pytest.mark.parametrize("has_bias", [False, True])
83
+ # @pytest.mark.parametrize('has_bias', [True])
84
+ @pytest.mark.parametrize("width", [2, 3, 4])
85
+ # @pytest.mark.parametrize('width', [2])
86
+ @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
87
+ # @pytest.mark.parametrize("dim", [2048])
88
+ def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype):
89
+ device = "cuda"
90
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
91
+ if itype == torch.bfloat16:
92
+ rtol, atol = 1e-2, 5e-2
93
+ rtolw, atolw = (1e-3, 1e-3)
94
+ # set seed
95
+ torch.random.manual_seed(0)
96
+ batch_size = 2
97
+ # batch_size = 1
98
+ # dim = 64
99
+ x = torch.randn(batch_size, dim, device=device, dtype=itype)
100
+ conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype)
101
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
102
+ if has_bias:
103
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
104
+ else:
105
+ bias = None
106
+ conv_state_ref = conv_state.detach().clone()
107
+ activation = None if not silu_activation else "silu"
108
+ out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
109
+ out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation)
110
+
111
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
112
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
113
+ assert torch.equal(conv_state, conv_state_ref)
114
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
115
+
116
+
117
+ # @pytest.mark.parametrize("channel_last", [False, True])
118
+ @pytest.mark.parametrize('channel_last', [True])
119
+ # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
120
+ @pytest.mark.parametrize('itype', [torch.bfloat16])
121
+ # @pytest.mark.parametrize("silu_activation", [False, True])
122
+ @pytest.mark.parametrize('silu_activation', [True])
123
+ # @pytest.mark.parametrize("has_bias", [False, True])
124
+ @pytest.mark.parametrize('has_bias', [True])
125
+ # @pytest.mark.parametrize("width", [2, 3, 4])
126
+ @pytest.mark.parametrize('width', [4])
127
+ @pytest.mark.parametrize(
128
+ # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
129
+ "seqlen", [2048]
130
+ )
131
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
132
+ # @pytest.mark.parametrize('seqlen', [128])
133
+ def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
134
+ device = "cuda"
135
+ # set seed
136
+ torch.random.manual_seed(0)
137
+ batch_size = 2
138
+ # batch_size = 1
139
+ dim = 4096 + 32 # Try dim not divisible by 64
140
+ # dim = 64
141
+ if not channel_last:
142
+ x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
143
+ else:
144
+ x = rearrange(
145
+ torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
146
+ ).requires_grad_()
147
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
148
+ if has_bias:
149
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
150
+ else:
151
+ bias = None
152
+ activation = None if not silu_activation else "silu"
153
+ out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
154
+ g = torch.randn_like(out0)
155
+ dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
156
+ dw_atol = 1e-4
157
+ db_atol = 1e-4
158
+
159
+ for i in range(10000):
160
+ out = causal_conv1d_fn(x, weight, bias, activation=activation)
161
+ dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
162
+ dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
163
+ # if not dw_equal:
164
+ # breakpoint()
165
+ if has_bias:
166
+ db_equal = torch.allclose(db, db0, atol=db_atol)
167
+ # if not db_equal:
168
+ # breakpoint()
169
+ assert torch.equal(out, out0)
170
+ assert torch.equal(dx, dx0)
171
+ assert dw_equal
172
+ if has_bias:
173
+ assert dw_equal
imagenet_class_index.py ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ imagenet_classnames = {
2
+ "0": ["n01440764", "tench"],
3
+ "1": ["n01443537", "goldfish"],
4
+ "2": ["n01484850", "great_white_shark"],
5
+ "3": ["n01491361", "tiger_shark"],
6
+ "4": ["n01494475", "hammerhead"],
7
+ "5": ["n01496331", "electric_ray"],
8
+ "6": ["n01498041", "stingray"],
9
+ "7": ["n01514668", "cock"],
10
+ "8": ["n01514859", "hen"],
11
+ "9": ["n01518878", "ostrich"],
12
+ "10": ["n01530575", "brambling"],
13
+ "11": ["n01531178", "goldfinch"],
14
+ "12": ["n01532829", "house_finch"],
15
+ "13": ["n01534433", "junco"],
16
+ "14": ["n01537544", "indigo_bunting"],
17
+ "15": ["n01558993", "robin"],
18
+ "16": ["n01560419", "bulbul"],
19
+ "17": ["n01580077", "jay"],
20
+ "18": ["n01582220", "magpie"],
21
+ "19": ["n01592084", "chickadee"],
22
+ "20": ["n01601694", "water_ouzel"],
23
+ "21": ["n01608432", "kite"],
24
+ "22": ["n01614925", "bald_eagle"],
25
+ "23": ["n01616318", "vulture"],
26
+ "24": ["n01622779", "great_grey_owl"],
27
+ "25": ["n01629819", "European_fire_salamander"],
28
+ "26": ["n01630670", "common_newt"],
29
+ "27": ["n01631663", "eft"],
30
+ "28": ["n01632458", "spotted_salamander"],
31
+ "29": ["n01632777", "axolotl"],
32
+ "30": ["n01641577", "bullfrog"],
33
+ "31": ["n01644373", "tree_frog"],
34
+ "32": ["n01644900", "tailed_frog"],
35
+ "33": ["n01664065", "loggerhead"],
36
+ "34": ["n01665541", "leatherback_turtle"],
37
+ "35": ["n01667114", "mud_turtle"],
38
+ "36": ["n01667778", "terrapin"],
39
+ "37": ["n01669191", "box_turtle"],
40
+ "38": ["n01675722", "banded_gecko"],
41
+ "39": ["n01677366", "common_iguana"],
42
+ "40": ["n01682714", "American_chameleon"],
43
+ "41": ["n01685808", "whiptail"],
44
+ "42": ["n01687978", "agama"],
45
+ "43": ["n01688243", "frilled_lizard"],
46
+ "44": ["n01689811", "alligator_lizard"],
47
+ "45": ["n01692333", "Gila_monster"],
48
+ "46": ["n01693334", "green_lizard"],
49
+ "47": ["n01694178", "African_chameleon"],
50
+ "48": ["n01695060", "Komodo_dragon"],
51
+ "49": ["n01697457", "African_crocodile"],
52
+ "50": ["n01698640", "American_alligator"],
53
+ "51": ["n01704323", "triceratops"],
54
+ "52": ["n01728572", "thunder_snake"],
55
+ "53": ["n01728920", "ringneck_snake"],
56
+ "54": ["n01729322", "hognose_snake"],
57
+ "55": ["n01729977", "green_snake"],
58
+ "56": ["n01734418", "king_snake"],
59
+ "57": ["n01735189", "garter_snake"],
60
+ "58": ["n01737021", "water_snake"],
61
+ "59": ["n01739381", "vine_snake"],
62
+ "60": ["n01740131", "night_snake"],
63
+ "61": ["n01742172", "boa_constrictor"],
64
+ "62": ["n01744401", "rock_python"],
65
+ "63": ["n01748264", "Indian_cobra"],
66
+ "64": ["n01749939", "green_mamba"],
67
+ "65": ["n01751748", "sea_snake"],
68
+ "66": ["n01753488", "horned_viper"],
69
+ "67": ["n01755581", "diamondback"],
70
+ "68": ["n01756291", "sidewinder"],
71
+ "69": ["n01768244", "trilobite"],
72
+ "70": ["n01770081", "harvestman"],
73
+ "71": ["n01770393", "scorpion"],
74
+ "72": ["n01773157", "black_and_gold_garden_spider"],
75
+ "73": ["n01773549", "barn_spider"],
76
+ "74": ["n01773797", "garden_spider"],
77
+ "75": ["n01774384", "black_widow"],
78
+ "76": ["n01774750", "tarantula"],
79
+ "77": ["n01775062", "wolf_spider"],
80
+ "78": ["n01776313", "tick"],
81
+ "79": ["n01784675", "centipede"],
82
+ "80": ["n01795545", "black_grouse"],
83
+ "81": ["n01796340", "ptarmigan"],
84
+ "82": ["n01797886", "ruffed_grouse"],
85
+ "83": ["n01798484", "prairie_chicken"],
86
+ "84": ["n01806143", "peacock"],
87
+ "85": ["n01806567", "quail"],
88
+ "86": ["n01807496", "partridge"],
89
+ "87": ["n01817953", "African_grey"],
90
+ "88": ["n01818515", "macaw"],
91
+ "89": ["n01819313", "sulphur-crested_cockatoo"],
92
+ "90": ["n01820546", "lorikeet"],
93
+ "91": ["n01824575", "coucal"],
94
+ "92": ["n01828970", "bee_eater"],
95
+ "93": ["n01829413", "hornbill"],
96
+ "94": ["n01833805", "hummingbird"],
97
+ "95": ["n01843065", "jacamar"],
98
+ "96": ["n01843383", "toucan"],
99
+ "97": ["n01847000", "drake"],
100
+ "98": ["n01855032", "red-breasted_merganser"],
101
+ "99": ["n01855672", "goose"],
102
+ "100": ["n01860187", "black_swan"],
103
+ "101": ["n01871265", "tusker"],
104
+ "102": ["n01872401", "echidna"],
105
+ "103": ["n01873310", "platypus"],
106
+ "104": ["n01877812", "wallaby"],
107
+ "105": ["n01882714", "koala"],
108
+ "106": ["n01883070", "wombat"],
109
+ "107": ["n01910747", "jellyfish"],
110
+ "108": ["n01914609", "sea_anemone"],
111
+ "109": ["n01917289", "brain_coral"],
112
+ "110": ["n01924916", "flatworm"],
113
+ "111": ["n01930112", "nematode"],
114
+ "112": ["n01943899", "conch"],
115
+ "113": ["n01944390", "snail"],
116
+ "114": ["n01945685", "slug"],
117
+ "115": ["n01950731", "sea_slug"],
118
+ "116": ["n01955084", "chiton"],
119
+ "117": ["n01968897", "chambered_nautilus"],
120
+ "118": ["n01978287", "Dungeness_crab"],
121
+ "119": ["n01978455", "rock_crab"],
122
+ "120": ["n01980166", "fiddler_crab"],
123
+ "121": ["n01981276", "king_crab"],
124
+ "122": ["n01983481", "American_lobster"],
125
+ "123": ["n01984695", "spiny_lobster"],
126
+ "124": ["n01985128", "crayfish"],
127
+ "125": ["n01986214", "hermit_crab"],
128
+ "126": ["n01990800", "isopod"],
129
+ "127": ["n02002556", "white_stork"],
130
+ "128": ["n02002724", "black_stork"],
131
+ "129": ["n02006656", "spoonbill"],
132
+ "130": ["n02007558", "flamingo"],
133
+ "131": ["n02009229", "little_blue_heron"],
134
+ "132": ["n02009912", "American_egret"],
135
+ "133": ["n02011460", "bittern"],
136
+ "134": ["n02012849", "crane"],
137
+ "135": ["n02013706", "limpkin"],
138
+ "136": ["n02017213", "European_gallinule"],
139
+ "137": ["n02018207", "American_coot"],
140
+ "138": ["n02018795", "bustard"],
141
+ "139": ["n02025239", "ruddy_turnstone"],
142
+ "140": ["n02027492", "red-backed_sandpiper"],
143
+ "141": ["n02028035", "redshank"],
144
+ "142": ["n02033041", "dowitcher"],
145
+ "143": ["n02037110", "oystercatcher"],
146
+ "144": ["n02051845", "pelican"],
147
+ "145": ["n02056570", "king_penguin"],
148
+ "146": ["n02058221", "albatross"],
149
+ "147": ["n02066245", "grey_whale"],
150
+ "148": ["n02071294", "killer_whale"],
151
+ "149": ["n02074367", "dugong"],
152
+ "150": ["n02077923", "sea_lion"],
153
+ "151": ["n02085620", "Chihuahua"],
154
+ "152": ["n02085782", "Japanese_spaniel"],
155
+ "153": ["n02085936", "Maltese_dog"],
156
+ "154": ["n02086079", "Pekinese"],
157
+ "155": ["n02086240", "Shih-Tzu"],
158
+ "156": ["n02086646", "Blenheim_spaniel"],
159
+ "157": ["n02086910", "papillon"],
160
+ "158": ["n02087046", "toy_terrier"],
161
+ "159": ["n02087394", "Rhodesian_ridgeback"],
162
+ "160": ["n02088094", "Afghan_hound"],
163
+ "161": ["n02088238", "basset"],
164
+ "162": ["n02088364", "beagle"],
165
+ "163": ["n02088466", "bloodhound"],
166
+ "164": ["n02088632", "bluetick"],
167
+ "165": ["n02089078", "black-and-tan_coonhound"],
168
+ "166": ["n02089867", "Walker_hound"],
169
+ "167": ["n02089973", "English_foxhound"],
170
+ "168": ["n02090379", "redbone"],
171
+ "169": ["n02090622", "borzoi"],
172
+ "170": ["n02090721", "Irish_wolfhound"],
173
+ "171": ["n02091032", "Italian_greyhound"],
174
+ "172": ["n02091134", "whippet"],
175
+ "173": ["n02091244", "Ibizan_hound"],
176
+ "174": ["n02091467", "Norwegian_elkhound"],
177
+ "175": ["n02091635", "otterhound"],
178
+ "176": ["n02091831", "Saluki"],
179
+ "177": ["n02092002", "Scottish_deerhound"],
180
+ "178": ["n02092339", "Weimaraner"],
181
+ "179": ["n02093256", "Staffordshire_bullterrier"],
182
+ "180": ["n02093428", "American_Staffordshire_terrier"],
183
+ "181": ["n02093647", "Bedlington_terrier"],
184
+ "182": ["n02093754", "Border_terrier"],
185
+ "183": ["n02093859", "Kerry_blue_terrier"],
186
+ "184": ["n02093991", "Irish_terrier"],
187
+ "185": ["n02094114", "Norfolk_terrier"],
188
+ "186": ["n02094258", "Norwich_terrier"],
189
+ "187": ["n02094433", "Yorkshire_terrier"],
190
+ "188": ["n02095314", "wire-haired_fox_terrier"],
191
+ "189": ["n02095570", "Lakeland_terrier"],
192
+ "190": ["n02095889", "Sealyham_terrier"],
193
+ "191": ["n02096051", "Airedale"],
194
+ "192": ["n02096177", "cairn"],
195
+ "193": ["n02096294", "Australian_terrier"],
196
+ "194": ["n02096437", "Dandie_Dinmont"],
197
+ "195": ["n02096585", "Boston_bull"],
198
+ "196": ["n02097047", "miniature_schnauzer"],
199
+ "197": ["n02097130", "giant_schnauzer"],
200
+ "198": ["n02097209", "standard_schnauzer"],
201
+ "199": ["n02097298", "Scotch_terrier"],
202
+ "200": ["n02097474", "Tibetan_terrier"],
203
+ "201": ["n02097658", "silky_terrier"],
204
+ "202": ["n02098105", "soft-coated_wheaten_terrier"],
205
+ "203": ["n02098286", "West_Highland_white_terrier"],
206
+ "204": ["n02098413", "Lhasa"],
207
+ "205": ["n02099267", "flat-coated_retriever"],
208
+ "206": ["n02099429", "curly-coated_retriever"],
209
+ "207": ["n02099601", "golden_retriever"],
210
+ "208": ["n02099712", "Labrador_retriever"],
211
+ "209": ["n02099849", "Chesapeake_Bay_retriever"],
212
+ "210": ["n02100236", "German_short-haired_pointer"],
213
+ "211": ["n02100583", "vizsla"],
214
+ "212": ["n02100735", "English_setter"],
215
+ "213": ["n02100877", "Irish_setter"],
216
+ "214": ["n02101006", "Gordon_setter"],
217
+ "215": ["n02101388", "Brittany_spaniel"],
218
+ "216": ["n02101556", "clumber"],
219
+ "217": ["n02102040", "English_springer"],
220
+ "218": ["n02102177", "Welsh_springer_spaniel"],
221
+ "219": ["n02102318", "cocker_spaniel"],
222
+ "220": ["n02102480", "Sussex_spaniel"],
223
+ "221": ["n02102973", "Irish_water_spaniel"],
224
+ "222": ["n02104029", "kuvasz"],
225
+ "223": ["n02104365", "schipperke"],
226
+ "224": ["n02105056", "groenendael"],
227
+ "225": ["n02105162", "malinois"],
228
+ "226": ["n02105251", "briard"],
229
+ "227": ["n02105412", "kelpie"],
230
+ "228": ["n02105505", "komondor"],
231
+ "229": ["n02105641", "Old_English_sheepdog"],
232
+ "230": ["n02105855", "Shetland_sheepdog"],
233
+ "231": ["n02106030", "collie"],
234
+ "232": ["n02106166", "Border_collie"],
235
+ "233": ["n02106382", "Bouvier_des_Flandres"],
236
+ "234": ["n02106550", "Rottweiler"],
237
+ "235": ["n02106662", "German_shepherd"],
238
+ "236": ["n02107142", "Doberman"],
239
+ "237": ["n02107312", "miniature_pinscher"],
240
+ "238": ["n02107574", "Greater_Swiss_Mountain_dog"],
241
+ "239": ["n02107683", "Bernese_mountain_dog"],
242
+ "240": ["n02107908", "Appenzeller"],
243
+ "241": ["n02108000", "EntleBucher"],
244
+ "242": ["n02108089", "boxer"],
245
+ "243": ["n02108422", "bull_mastiff"],
246
+ "244": ["n02108551", "Tibetan_mastiff"],
247
+ "245": ["n02108915", "French_bulldog"],
248
+ "246": ["n02109047", "Great_Dane"],
249
+ "247": ["n02109525", "Saint_Bernard"],
250
+ "248": ["n02109961", "Eskimo_dog"],
251
+ "249": ["n02110063", "malamute"],
252
+ "250": ["n02110185", "Siberian_husky"],
253
+ "251": ["n02110341", "dalmatian"],
254
+ "252": ["n02110627", "affenpinscher"],
255
+ "253": ["n02110806", "basenji"],
256
+ "254": ["n02110958", "pug"],
257
+ "255": ["n02111129", "Leonberg"],
258
+ "256": ["n02111277", "Newfoundland"],
259
+ "257": ["n02111500", "Great_Pyrenees"],
260
+ "258": ["n02111889", "Samoyed"],
261
+ "259": ["n02112018", "Pomeranian"],
262
+ "260": ["n02112137", "chow"],
263
+ "261": ["n02112350", "keeshond"],
264
+ "262": ["n02112706", "Brabancon_griffon"],
265
+ "263": ["n02113023", "Pembroke"],
266
+ "264": ["n02113186", "Cardigan"],
267
+ "265": ["n02113624", "toy_poodle"],
268
+ "266": ["n02113712", "miniature_poodle"],
269
+ "267": ["n02113799", "standard_poodle"],
270
+ "268": ["n02113978", "Mexican_hairless"],
271
+ "269": ["n02114367", "timber_wolf"],
272
+ "270": ["n02114548", "white_wolf"],
273
+ "271": ["n02114712", "red_wolf"],
274
+ "272": ["n02114855", "coyote"],
275
+ "273": ["n02115641", "dingo"],
276
+ "274": ["n02115913", "dhole"],
277
+ "275": ["n02116738", "African_hunting_dog"],
278
+ "276": ["n02117135", "hyena"],
279
+ "277": ["n02119022", "red_fox"],
280
+ "278": ["n02119789", "kit_fox"],
281
+ "279": ["n02120079", "Arctic_fox"],
282
+ "280": ["n02120505", "grey_fox"],
283
+ "281": ["n02123045", "tabby"],
284
+ "282": ["n02123159", "tiger_cat"],
285
+ "283": ["n02123394", "Persian_cat"],
286
+ "284": ["n02123597", "Siamese_cat"],
287
+ "285": ["n02124075", "Egyptian_cat"],
288
+ "286": ["n02125311", "cougar"],
289
+ "287": ["n02127052", "lynx"],
290
+ "288": ["n02128385", "leopard"],
291
+ "289": ["n02128757", "snow_leopard"],
292
+ "290": ["n02128925", "jaguar"],
293
+ "291": ["n02129165", "lion"],
294
+ "292": ["n02129604", "tiger"],
295
+ "293": ["n02130308", "cheetah"],
296
+ "294": ["n02132136", "brown_bear"],
297
+ "295": ["n02133161", "American_black_bear"],
298
+ "296": ["n02134084", "ice_bear"],
299
+ "297": ["n02134418", "sloth_bear"],
300
+ "298": ["n02137549", "mongoose"],
301
+ "299": ["n02138441", "meerkat"],
302
+ "300": ["n02165105", "tiger_beetle"],
303
+ "301": ["n02165456", "ladybug"],
304
+ "302": ["n02167151", "ground_beetle"],
305
+ "303": ["n02168699", "long-horned_beetle"],
306
+ "304": ["n02169497", "leaf_beetle"],
307
+ "305": ["n02172182", "dung_beetle"],
308
+ "306": ["n02174001", "rhinoceros_beetle"],
309
+ "307": ["n02177972", "weevil"],
310
+ "308": ["n02190166", "fly"],
311
+ "309": ["n02206856", "bee"],
312
+ "310": ["n02219486", "ant"],
313
+ "311": ["n02226429", "grasshopper"],
314
+ "312": ["n02229544", "cricket"],
315
+ "313": ["n02231487", "walking_stick"],
316
+ "314": ["n02233338", "cockroach"],
317
+ "315": ["n02236044", "mantis"],
318
+ "316": ["n02256656", "cicada"],
319
+ "317": ["n02259212", "leafhopper"],
320
+ "318": ["n02264363", "lacewing"],
321
+ "319": ["n02268443", "dragonfly"],
322
+ "320": ["n02268853", "damselfly"],
323
+ "321": ["n02276258", "admiral"],
324
+ "322": ["n02277742", "ringlet"],
325
+ "323": ["n02279972", "monarch"],
326
+ "324": ["n02280649", "cabbage_butterfly"],
327
+ "325": ["n02281406", "sulphur_butterfly"],
328
+ "326": ["n02281787", "lycaenid"],
329
+ "327": ["n02317335", "starfish"],
330
+ "328": ["n02319095", "sea_urchin"],
331
+ "329": ["n02321529", "sea_cucumber"],
332
+ "330": ["n02325366", "wood_rabbit"],
333
+ "331": ["n02326432", "hare"],
334
+ "332": ["n02328150", "Angora"],
335
+ "333": ["n02342885", "hamster"],
336
+ "334": ["n02346627", "porcupine"],
337
+ "335": ["n02356798", "fox_squirrel"],
338
+ "336": ["n02361337", "marmot"],
339
+ "337": ["n02363005", "beaver"],
340
+ "338": ["n02364673", "guinea_pig"],
341
+ "339": ["n02389026", "sorrel"],
342
+ "340": ["n02391049", "zebra"],
343
+ "341": ["n02395406", "hog"],
344
+ "342": ["n02396427", "wild_boar"],
345
+ "343": ["n02397096", "warthog"],
346
+ "344": ["n02398521", "hippopotamus"],
347
+ "345": ["n02403003", "ox"],
348
+ "346": ["n02408429", "water_buffalo"],
349
+ "347": ["n02410509", "bison"],
350
+ "348": ["n02412080", "ram"],
351
+ "349": ["n02415577", "bighorn"],
352
+ "350": ["n02417914", "ibex"],
353
+ "351": ["n02422106", "hartebeest"],
354
+ "352": ["n02422699", "impala"],
355
+ "353": ["n02423022", "gazelle"],
356
+ "354": ["n02437312", "Arabian_camel"],
357
+ "355": ["n02437616", "llama"],
358
+ "356": ["n02441942", "weasel"],
359
+ "357": ["n02442845", "mink"],
360
+ "358": ["n02443114", "polecat"],
361
+ "359": ["n02443484", "black-footed_ferret"],
362
+ "360": ["n02444819", "otter"],
363
+ "361": ["n02445715", "skunk"],
364
+ "362": ["n02447366", "badger"],
365
+ "363": ["n02454379", "armadillo"],
366
+ "364": ["n02457408", "three-toed_sloth"],
367
+ "365": ["n02480495", "orangutan"],
368
+ "366": ["n02480855", "gorilla"],
369
+ "367": ["n02481823", "chimpanzee"],
370
+ "368": ["n02483362", "gibbon"],
371
+ "369": ["n02483708", "siamang"],
372
+ "370": ["n02484975", "guenon"],
373
+ "371": ["n02486261", "patas"],
374
+ "372": ["n02486410", "baboon"],
375
+ "373": ["n02487347", "macaque"],
376
+ "374": ["n02488291", "langur"],
377
+ "375": ["n02488702", "colobus"],
378
+ "376": ["n02489166", "proboscis_monkey"],
379
+ "377": ["n02490219", "marmoset"],
380
+ "378": ["n02492035", "capuchin"],
381
+ "379": ["n02492660", "howler_monkey"],
382
+ "380": ["n02493509", "titi"],
383
+ "381": ["n02493793", "spider_monkey"],
384
+ "382": ["n02494079", "squirrel_monkey"],
385
+ "383": ["n02497673", "Madagascar_cat"],
386
+ "384": ["n02500267", "indri"],
387
+ "385": ["n02504013", "Indian_elephant"],
388
+ "386": ["n02504458", "African_elephant"],
389
+ "387": ["n02509815", "lesser_panda"],
390
+ "388": ["n02510455", "giant_panda"],
391
+ "389": ["n02514041", "barracouta"],
392
+ "390": ["n02526121", "eel"],
393
+ "391": ["n02536864", "coho"],
394
+ "392": ["n02606052", "rock_beauty"],
395
+ "393": ["n02607072", "anemone_fish"],
396
+ "394": ["n02640242", "sturgeon"],
397
+ "395": ["n02641379", "gar"],
398
+ "396": ["n02643566", "lionfish"],
399
+ "397": ["n02655020", "puffer"],
400
+ "398": ["n02666196", "abacus"],
401
+ "399": ["n02667093", "abaya"],
402
+ "400": ["n02669723", "academic_gown"],
403
+ "401": ["n02672831", "accordion"],
404
+ "402": ["n02676566", "acoustic_guitar"],
405
+ "403": ["n02687172", "aircraft_carrier"],
406
+ "404": ["n02690373", "airliner"],
407
+ "405": ["n02692877", "airship"],
408
+ "406": ["n02699494", "altar"],
409
+ "407": ["n02701002", "ambulance"],
410
+ "408": ["n02704792", "amphibian"],
411
+ "409": ["n02708093", "analog_clock"],
412
+ "410": ["n02727426", "apiary"],
413
+ "411": ["n02730930", "apron"],
414
+ "412": ["n02747177", "ashcan"],
415
+ "413": ["n02749479", "assault_rifle"],
416
+ "414": ["n02769748", "backpack"],
417
+ "415": ["n02776631", "bakery"],
418
+ "416": ["n02777292", "balance_beam"],
419
+ "417": ["n02782093", "balloon"],
420
+ "418": ["n02783161", "ballpoint"],
421
+ "419": ["n02786058", "Band_Aid"],
422
+ "420": ["n02787622", "banjo"],
423
+ "421": ["n02788148", "bannister"],
424
+ "422": ["n02790996", "barbell"],
425
+ "423": ["n02791124", "barber_chair"],
426
+ "424": ["n02791270", "barbershop"],
427
+ "425": ["n02793495", "barn"],
428
+ "426": ["n02794156", "barometer"],
429
+ "427": ["n02795169", "barrel"],
430
+ "428": ["n02797295", "barrow"],
431
+ "429": ["n02799071", "baseball"],
432
+ "430": ["n02802426", "basketball"],
433
+ "431": ["n02804414", "bassinet"],
434
+ "432": ["n02804610", "bassoon"],
435
+ "433": ["n02807133", "bathing_cap"],
436
+ "434": ["n02808304", "bath_towel"],
437
+ "435": ["n02808440", "bathtub"],
438
+ "436": ["n02814533", "beach_wagon"],
439
+ "437": ["n02814860", "beacon"],
440
+ "438": ["n02815834", "beaker"],
441
+ "439": ["n02817516", "bearskin"],
442
+ "440": ["n02823428", "beer_bottle"],
443
+ "441": ["n02823750", "beer_glass"],
444
+ "442": ["n02825657", "bell_cote"],
445
+ "443": ["n02834397", "bib"],
446
+ "444": ["n02835271", "bicycle-built-for-two"],
447
+ "445": ["n02837789", "bikini"],
448
+ "446": ["n02840245", "binder"],
449
+ "447": ["n02841315", "binoculars"],
450
+ "448": ["n02843684", "birdhouse"],
451
+ "449": ["n02859443", "boathouse"],
452
+ "450": ["n02860847", "bobsled"],
453
+ "451": ["n02865351", "bolo_tie"],
454
+ "452": ["n02869837", "bonnet"],
455
+ "453": ["n02870880", "bookcase"],
456
+ "454": ["n02871525", "bookshop"],
457
+ "455": ["n02877765", "bottlecap"],
458
+ "456": ["n02879718", "bow"],
459
+ "457": ["n02883205", "bow_tie"],
460
+ "458": ["n02892201", "brass"],
461
+ "459": ["n02892767", "brassiere"],
462
+ "460": ["n02894605", "breakwater"],
463
+ "461": ["n02895154", "breastplate"],
464
+ "462": ["n02906734", "broom"],
465
+ "463": ["n02909870", "bucket"],
466
+ "464": ["n02910353", "buckle"],
467
+ "465": ["n02916936", "bulletproof_vest"],
468
+ "466": ["n02917067", "bullet_train"],
469
+ "467": ["n02927161", "butcher_shop"],
470
+ "468": ["n02930766", "cab"],
471
+ "469": ["n02939185", "caldron"],
472
+ "470": ["n02948072", "candle"],
473
+ "471": ["n02950826", "cannon"],
474
+ "472": ["n02951358", "canoe"],
475
+ "473": ["n02951585", "can_opener"],
476
+ "474": ["n02963159", "cardigan"],
477
+ "475": ["n02965783", "car_mirror"],
478
+ "476": ["n02966193", "carousel"],
479
+ "477": ["n02966687", "carpenter's_kit"],
480
+ "478": ["n02971356", "carton"],
481
+ "479": ["n02974003", "car_wheel"],
482
+ "480": ["n02977058", "cash_machine"],
483
+ "481": ["n02978881", "cassette"],
484
+ "482": ["n02979186", "cassette_player"],
485
+ "483": ["n02980441", "castle"],
486
+ "484": ["n02981792", "catamaran"],
487
+ "485": ["n02988304", "CD_player"],
488
+ "486": ["n02992211", "cello"],
489
+ "487": ["n02992529", "cellular_telephone"],
490
+ "488": ["n02999410", "chain"],
491
+ "489": ["n03000134", "chainlink_fence"],
492
+ "490": ["n03000247", "chain_mail"],
493
+ "491": ["n03000684", "chain_saw"],
494
+ "492": ["n03014705", "chest"],
495
+ "493": ["n03016953", "chiffonier"],
496
+ "494": ["n03017168", "chime"],
497
+ "495": ["n03018349", "china_cabinet"],
498
+ "496": ["n03026506", "Christmas_stocking"],
499
+ "497": ["n03028079", "church"],
500
+ "498": ["n03032252", "cinema"],
501
+ "499": ["n03041632", "cleaver"],
502
+ "500": ["n03042490", "cliff_dwelling"],
503
+ "501": ["n03045698", "cloak"],
504
+ "502": ["n03047690", "clog"],
505
+ "503": ["n03062245", "cocktail_shaker"],
506
+ "504": ["n03063599", "coffee_mug"],
507
+ "505": ["n03063689", "coffeepot"],
508
+ "506": ["n03065424", "coil"],
509
+ "507": ["n03075370", "combination_lock"],
510
+ "508": ["n03085013", "computer_keyboard"],
511
+ "509": ["n03089624", "confectionery"],
512
+ "510": ["n03095699", "container_ship"],
513
+ "511": ["n03100240", "convertible"],
514
+ "512": ["n03109150", "corkscrew"],
515
+ "513": ["n03110669", "cornet"],
516
+ "514": ["n03124043", "cowboy_boot"],
517
+ "515": ["n03124170", "cowboy_hat"],
518
+ "516": ["n03125729", "cradle"],
519
+ "517": ["n03126707", "crane"],
520
+ "518": ["n03127747", "crash_helmet"],
521
+ "519": ["n03127925", "crate"],
522
+ "520": ["n03131574", "crib"],
523
+ "521": ["n03133878", "Crock_Pot"],
524
+ "522": ["n03134739", "croquet_ball"],
525
+ "523": ["n03141823", "crutch"],
526
+ "524": ["n03146219", "cuirass"],
527
+ "525": ["n03160309", "dam"],
528
+ "526": ["n03179701", "desk"],
529
+ "527": ["n03180011", "desktop_computer"],
530
+ "528": ["n03187595", "dial_telephone"],
531
+ "529": ["n03188531", "diaper"],
532
+ "530": ["n03196217", "digital_clock"],
533
+ "531": ["n03197337", "digital_watch"],
534
+ "532": ["n03201208", "dining_table"],
535
+ "533": ["n03207743", "dishrag"],
536
+ "534": ["n03207941", "dishwasher"],
537
+ "535": ["n03208938", "disk_brake"],
538
+ "536": ["n03216828", "dock"],
539
+ "537": ["n03218198", "dogsled"],
540
+ "538": ["n03220513", "dome"],
541
+ "539": ["n03223299", "doormat"],
542
+ "540": ["n03240683", "drilling_platform"],
543
+ "541": ["n03249569", "drum"],
544
+ "542": ["n03250847", "drumstick"],
545
+ "543": ["n03255030", "dumbbell"],
546
+ "544": ["n03259280", "Dutch_oven"],
547
+ "545": ["n03271574", "electric_fan"],
548
+ "546": ["n03272010", "electric_guitar"],
549
+ "547": ["n03272562", "electric_locomotive"],
550
+ "548": ["n03290653", "entertainment_center"],
551
+ "549": ["n03291819", "envelope"],
552
+ "550": ["n03297495", "espresso_maker"],
553
+ "551": ["n03314780", "face_powder"],
554
+ "552": ["n03325584", "feather_boa"],
555
+ "553": ["n03337140", "file"],
556
+ "554": ["n03344393", "fireboat"],
557
+ "555": ["n03345487", "fire_engine"],
558
+ "556": ["n03347037", "fire_screen"],
559
+ "557": ["n03355925", "flagpole"],
560
+ "558": ["n03372029", "flute"],
561
+ "559": ["n03376595", "folding_chair"],
562
+ "560": ["n03379051", "football_helmet"],
563
+ "561": ["n03384352", "forklift"],
564
+ "562": ["n03388043", "fountain"],
565
+ "563": ["n03388183", "fountain_pen"],
566
+ "564": ["n03388549", "four-poster"],
567
+ "565": ["n03393912", "freight_car"],
568
+ "566": ["n03394916", "French_horn"],
569
+ "567": ["n03400231", "frying_pan"],
570
+ "568": ["n03404251", "fur_coat"],
571
+ "569": ["n03417042", "garbage_truck"],
572
+ "570": ["n03424325", "gasmask"],
573
+ "571": ["n03425413", "gas_pump"],
574
+ "572": ["n03443371", "goblet"],
575
+ "573": ["n03444034", "go-kart"],
576
+ "574": ["n03445777", "golf_ball"],
577
+ "575": ["n03445924", "golfcart"],
578
+ "576": ["n03447447", "gondola"],
579
+ "577": ["n03447721", "gong"],
580
+ "578": ["n03450230", "gown"],
581
+ "579": ["n03452741", "grand_piano"],
582
+ "580": ["n03457902", "greenhouse"],
583
+ "581": ["n03459775", "grille"],
584
+ "582": ["n03461385", "grocery_store"],
585
+ "583": ["n03467068", "guillotine"],
586
+ "584": ["n03476684", "hair_slide"],
587
+ "585": ["n03476991", "hair_spray"],
588
+ "586": ["n03478589", "half_track"],
589
+ "587": ["n03481172", "hammer"],
590
+ "588": ["n03482405", "hamper"],
591
+ "589": ["n03483316", "hand_blower"],
592
+ "590": ["n03485407", "hand-held_computer"],
593
+ "591": ["n03485794", "handkerchief"],
594
+ "592": ["n03492542", "hard_disc"],
595
+ "593": ["n03494278", "harmonica"],
596
+ "594": ["n03495258", "harp"],
597
+ "595": ["n03496892", "harvester"],
598
+ "596": ["n03498962", "hatchet"],
599
+ "597": ["n03527444", "holster"],
600
+ "598": ["n03529860", "home_theater"],
601
+ "599": ["n03530642", "honeycomb"],
602
+ "600": ["n03532672", "hook"],
603
+ "601": ["n03534580", "hoopskirt"],
604
+ "602": ["n03535780", "horizontal_bar"],
605
+ "603": ["n03538406", "horse_cart"],
606
+ "604": ["n03544143", "hourglass"],
607
+ "605": ["n03584254", "iPod"],
608
+ "606": ["n03584829", "iron"],
609
+ "607": ["n03590841", "jack-o'-lantern"],
610
+ "608": ["n03594734", "jean"],
611
+ "609": ["n03594945", "jeep"],
612
+ "610": ["n03595614", "jersey"],
613
+ "611": ["n03598930", "jigsaw_puzzle"],
614
+ "612": ["n03599486", "jinrikisha"],
615
+ "613": ["n03602883", "joystick"],
616
+ "614": ["n03617480", "kimono"],
617
+ "615": ["n03623198", "knee_pad"],
618
+ "616": ["n03627232", "knot"],
619
+ "617": ["n03630383", "lab_coat"],
620
+ "618": ["n03633091", "ladle"],
621
+ "619": ["n03637318", "lampshade"],
622
+ "620": ["n03642806", "laptop"],
623
+ "621": ["n03649909", "lawn_mower"],
624
+ "622": ["n03657121", "lens_cap"],
625
+ "623": ["n03658185", "letter_opener"],
626
+ "624": ["n03661043", "library"],
627
+ "625": ["n03662601", "lifeboat"],
628
+ "626": ["n03666591", "lighter"],
629
+ "627": ["n03670208", "limousine"],
630
+ "628": ["n03673027", "liner"],
631
+ "629": ["n03676483", "lipstick"],
632
+ "630": ["n03680355", "Loafer"],
633
+ "631": ["n03690938", "lotion"],
634
+ "632": ["n03691459", "loudspeaker"],
635
+ "633": ["n03692522", "loupe"],
636
+ "634": ["n03697007", "lumbermill"],
637
+ "635": ["n03706229", "magnetic_compass"],
638
+ "636": ["n03709823", "mailbag"],
639
+ "637": ["n03710193", "mailbox"],
640
+ "638": ["n03710637", "maillot"],
641
+ "639": ["n03710721", "maillot"],
642
+ "640": ["n03717622", "manhole_cover"],
643
+ "641": ["n03720891", "maraca"],
644
+ "642": ["n03721384", "marimba"],
645
+ "643": ["n03724870", "mask"],
646
+ "644": ["n03729826", "matchstick"],
647
+ "645": ["n03733131", "maypole"],
648
+ "646": ["n03733281", "maze"],
649
+ "647": ["n03733805", "measuring_cup"],
650
+ "648": ["n03742115", "medicine_chest"],
651
+ "649": ["n03743016", "megalith"],
652
+ "650": ["n03759954", "microphone"],
653
+ "651": ["n03761084", "microwave"],
654
+ "652": ["n03763968", "military_uniform"],
655
+ "653": ["n03764736", "milk_can"],
656
+ "654": ["n03769881", "minibus"],
657
+ "655": ["n03770439", "miniskirt"],
658
+ "656": ["n03770679", "minivan"],
659
+ "657": ["n03773504", "missile"],
660
+ "658": ["n03775071", "mitten"],
661
+ "659": ["n03775546", "mixing_bowl"],
662
+ "660": ["n03776460", "mobile_home"],
663
+ "661": ["n03777568", "Model_T"],
664
+ "662": ["n03777754", "modem"],
665
+ "663": ["n03781244", "monastery"],
666
+ "664": ["n03782006", "monitor"],
667
+ "665": ["n03785016", "moped"],
668
+ "666": ["n03786901", "mortar"],
669
+ "667": ["n03787032", "mortarboard"],
670
+ "668": ["n03788195", "mosque"],
671
+ "669": ["n03788365", "mosquito_net"],
672
+ "670": ["n03791053", "motor_scooter"],
673
+ "671": ["n03792782", "mountain_bike"],
674
+ "672": ["n03792972", "mountain_tent"],
675
+ "673": ["n03793489", "mouse"],
676
+ "674": ["n03794056", "mousetrap"],
677
+ "675": ["n03796401", "moving_van"],
678
+ "676": ["n03803284", "muzzle"],
679
+ "677": ["n03804744", "nail"],
680
+ "678": ["n03814639", "neck_brace"],
681
+ "679": ["n03814906", "necklace"],
682
+ "680": ["n03825788", "nipple"],
683
+ "681": ["n03832673", "notebook"],
684
+ "682": ["n03837869", "obelisk"],
685
+ "683": ["n03838899", "oboe"],
686
+ "684": ["n03840681", "ocarina"],
687
+ "685": ["n03841143", "odometer"],
688
+ "686": ["n03843555", "oil_filter"],
689
+ "687": ["n03854065", "organ"],
690
+ "688": ["n03857828", "oscilloscope"],
691
+ "689": ["n03866082", "overskirt"],
692
+ "690": ["n03868242", "oxcart"],
693
+ "691": ["n03868863", "oxygen_mask"],
694
+ "692": ["n03871628", "packet"],
695
+ "693": ["n03873416", "paddle"],
696
+ "694": ["n03874293", "paddlewheel"],
697
+ "695": ["n03874599", "padlock"],
698
+ "696": ["n03876231", "paintbrush"],
699
+ "697": ["n03877472", "pajama"],
700
+ "698": ["n03877845", "palace"],
701
+ "699": ["n03884397", "panpipe"],
702
+ "700": ["n03887697", "paper_towel"],
703
+ "701": ["n03888257", "parachute"],
704
+ "702": ["n03888605", "parallel_bars"],
705
+ "703": ["n03891251", "park_bench"],
706
+ "704": ["n03891332", "parking_meter"],
707
+ "705": ["n03895866", "passenger_car"],
708
+ "706": ["n03899768", "patio"],
709
+ "707": ["n03902125", "pay-phone"],
710
+ "708": ["n03903868", "pedestal"],
711
+ "709": ["n03908618", "pencil_box"],
712
+ "710": ["n03908714", "pencil_sharpener"],
713
+ "711": ["n03916031", "perfume"],
714
+ "712": ["n03920288", "Petri_dish"],
715
+ "713": ["n03924679", "photocopier"],
716
+ "714": ["n03929660", "pick"],
717
+ "715": ["n03929855", "pickelhaube"],
718
+ "716": ["n03930313", "picket_fence"],
719
+ "717": ["n03930630", "pickup"],
720
+ "718": ["n03933933", "pier"],
721
+ "719": ["n03935335", "piggy_bank"],
722
+ "720": ["n03937543", "pill_bottle"],
723
+ "721": ["n03938244", "pillow"],
724
+ "722": ["n03942813", "ping-pong_ball"],
725
+ "723": ["n03944341", "pinwheel"],
726
+ "724": ["n03947888", "pirate"],
727
+ "725": ["n03950228", "pitcher"],
728
+ "726": ["n03954731", "plane"],
729
+ "727": ["n03956157", "planetarium"],
730
+ "728": ["n03958227", "plastic_bag"],
731
+ "729": ["n03961711", "plate_rack"],
732
+ "730": ["n03967562", "plow"],
733
+ "731": ["n03970156", "plunger"],
734
+ "732": ["n03976467", "Polaroid_camera"],
735
+ "733": ["n03976657", "pole"],
736
+ "734": ["n03977966", "police_van"],
737
+ "735": ["n03980874", "poncho"],
738
+ "736": ["n03982430", "pool_table"],
739
+ "737": ["n03983396", "pop_bottle"],
740
+ "738": ["n03991062", "pot"],
741
+ "739": ["n03992509", "potter's_wheel"],
742
+ "740": ["n03995372", "power_drill"],
743
+ "741": ["n03998194", "prayer_rug"],
744
+ "742": ["n04004767", "printer"],
745
+ "743": ["n04005630", "prison"],
746
+ "744": ["n04008634", "projectile"],
747
+ "745": ["n04009552", "projector"],
748
+ "746": ["n04019541", "puck"],
749
+ "747": ["n04023962", "punching_bag"],
750
+ "748": ["n04026417", "purse"],
751
+ "749": ["n04033901", "quill"],
752
+ "750": ["n04033995", "quilt"],
753
+ "751": ["n04037443", "racer"],
754
+ "752": ["n04039381", "racket"],
755
+ "753": ["n04040759", "radiator"],
756
+ "754": ["n04041544", "radio"],
757
+ "755": ["n04044716", "radio_telescope"],
758
+ "756": ["n04049303", "rain_barrel"],
759
+ "757": ["n04065272", "recreational_vehicle"],
760
+ "758": ["n04067472", "reel"],
761
+ "759": ["n04069434", "reflex_camera"],
762
+ "760": ["n04070727", "refrigerator"],
763
+ "761": ["n04074963", "remote_control"],
764
+ "762": ["n04081281", "restaurant"],
765
+ "763": ["n04086273", "revolver"],
766
+ "764": ["n04090263", "rifle"],
767
+ "765": ["n04099969", "rocking_chair"],
768
+ "766": ["n04111531", "rotisserie"],
769
+ "767": ["n04116512", "rubber_eraser"],
770
+ "768": ["n04118538", "rugby_ball"],
771
+ "769": ["n04118776", "rule"],
772
+ "770": ["n04120489", "running_shoe"],
773
+ "771": ["n04125021", "safe"],
774
+ "772": ["n04127249", "safety_pin"],
775
+ "773": ["n04131690", "saltshaker"],
776
+ "774": ["n04133789", "sandal"],
777
+ "775": ["n04136333", "sarong"],
778
+ "776": ["n04141076", "sax"],
779
+ "777": ["n04141327", "scabbard"],
780
+ "778": ["n04141975", "scale"],
781
+ "779": ["n04146614", "school_bus"],
782
+ "780": ["n04147183", "schooner"],
783
+ "781": ["n04149813", "scoreboard"],
784
+ "782": ["n04152593", "screen"],
785
+ "783": ["n04153751", "screw"],
786
+ "784": ["n04154565", "screwdriver"],
787
+ "785": ["n04162706", "seat_belt"],
788
+ "786": ["n04179913", "sewing_machine"],
789
+ "787": ["n04192698", "shield"],
790
+ "788": ["n04200800", "shoe_shop"],
791
+ "789": ["n04201297", "shoji"],
792
+ "790": ["n04204238", "shopping_basket"],
793
+ "791": ["n04204347", "shopping_cart"],
794
+ "792": ["n04208210", "shovel"],
795
+ "793": ["n04209133", "shower_cap"],
796
+ "794": ["n04209239", "shower_curtain"],
797
+ "795": ["n04228054", "ski"],
798
+ "796": ["n04229816", "ski_mask"],
799
+ "797": ["n04235860", "sleeping_bag"],
800
+ "798": ["n04238763", "slide_rule"],
801
+ "799": ["n04239074", "sliding_door"],
802
+ "800": ["n04243546", "slot"],
803
+ "801": ["n04251144", "snorkel"],
804
+ "802": ["n04252077", "snowmobile"],
805
+ "803": ["n04252225", "snowplow"],
806
+ "804": ["n04254120", "soap_dispenser"],
807
+ "805": ["n04254680", "soccer_ball"],
808
+ "806": ["n04254777", "sock"],
809
+ "807": ["n04258138", "solar_dish"],
810
+ "808": ["n04259630", "sombrero"],
811
+ "809": ["n04263257", "soup_bowl"],
812
+ "810": ["n04264628", "space_bar"],
813
+ "811": ["n04265275", "space_heater"],
814
+ "812": ["n04266014", "space_shuttle"],
815
+ "813": ["n04270147", "spatula"],
816
+ "814": ["n04273569", "speedboat"],
817
+ "815": ["n04275548", "spider_web"],
818
+ "816": ["n04277352", "spindle"],
819
+ "817": ["n04285008", "sports_car"],
820
+ "818": ["n04286575", "spotlight"],
821
+ "819": ["n04296562", "stage"],
822
+ "820": ["n04310018", "steam_locomotive"],
823
+ "821": ["n04311004", "steel_arch_bridge"],
824
+ "822": ["n04311174", "steel_drum"],
825
+ "823": ["n04317175", "stethoscope"],
826
+ "824": ["n04325704", "stole"],
827
+ "825": ["n04326547", "stone_wall"],
828
+ "826": ["n04328186", "stopwatch"],
829
+ "827": ["n04330267", "stove"],
830
+ "828": ["n04332243", "strainer"],
831
+ "829": ["n04335435", "streetcar"],
832
+ "830": ["n04336792", "stretcher"],
833
+ "831": ["n04344873", "studio_couch"],
834
+ "832": ["n04346328", "stupa"],
835
+ "833": ["n04347754", "submarine"],
836
+ "834": ["n04350905", "suit"],
837
+ "835": ["n04355338", "sundial"],
838
+ "836": ["n04355933", "sunglass"],
839
+ "837": ["n04356056", "sunglasses"],
840
+ "838": ["n04357314", "sunscreen"],
841
+ "839": ["n04366367", "suspension_bridge"],
842
+ "840": ["n04367480", "swab"],
843
+ "841": ["n04370456", "sweatshirt"],
844
+ "842": ["n04371430", "swimming_trunks"],
845
+ "843": ["n04371774", "swing"],
846
+ "844": ["n04372370", "switch"],
847
+ "845": ["n04376876", "syringe"],
848
+ "846": ["n04380533", "table_lamp"],
849
+ "847": ["n04389033", "tank"],
850
+ "848": ["n04392985", "tape_player"],
851
+ "849": ["n04398044", "teapot"],
852
+ "850": ["n04399382", "teddy"],
853
+ "851": ["n04404412", "television"],
854
+ "852": ["n04409515", "tennis_ball"],
855
+ "853": ["n04417672", "thatch"],
856
+ "854": ["n04418357", "theater_curtain"],
857
+ "855": ["n04423845", "thimble"],
858
+ "856": ["n04428191", "thresher"],
859
+ "857": ["n04429376", "throne"],
860
+ "858": ["n04435653", "tile_roof"],
861
+ "859": ["n04442312", "toaster"],
862
+ "860": ["n04443257", "tobacco_shop"],
863
+ "861": ["n04447861", "toilet_seat"],
864
+ "862": ["n04456115", "torch"],
865
+ "863": ["n04458633", "totem_pole"],
866
+ "864": ["n04461696", "tow_truck"],
867
+ "865": ["n04462240", "toyshop"],
868
+ "866": ["n04465501", "tractor"],
869
+ "867": ["n04467665", "trailer_truck"],
870
+ "868": ["n04476259", "tray"],
871
+ "869": ["n04479046", "trench_coat"],
872
+ "870": ["n04482393", "tricycle"],
873
+ "871": ["n04483307", "trimaran"],
874
+ "872": ["n04485082", "tripod"],
875
+ "873": ["n04486054", "triumphal_arch"],
876
+ "874": ["n04487081", "trolleybus"],
877
+ "875": ["n04487394", "trombone"],
878
+ "876": ["n04493381", "tub"],
879
+ "877": ["n04501370", "turnstile"],
880
+ "878": ["n04505470", "typewriter_keyboard"],
881
+ "879": ["n04507155", "umbrella"],
882
+ "880": ["n04509417", "unicycle"],
883
+ "881": ["n04515003", "upright"],
884
+ "882": ["n04517823", "vacuum"],
885
+ "883": ["n04522168", "vase"],
886
+ "884": ["n04523525", "vault"],
887
+ "885": ["n04525038", "velvet"],
888
+ "886": ["n04525305", "vending_machine"],
889
+ "887": ["n04532106", "vestment"],
890
+ "888": ["n04532670", "viaduct"],
891
+ "889": ["n04536866", "violin"],
892
+ "890": ["n04540053", "volleyball"],
893
+ "891": ["n04542943", "waffle_iron"],
894
+ "892": ["n04548280", "wall_clock"],
895
+ "893": ["n04548362", "wallet"],
896
+ "894": ["n04550184", "wardrobe"],
897
+ "895": ["n04552348", "warplane"],
898
+ "896": ["n04553703", "washbasin"],
899
+ "897": ["n04554684", "washer"],
900
+ "898": ["n04557648", "water_bottle"],
901
+ "899": ["n04560804", "water_jug"],
902
+ "900": ["n04562935", "water_tower"],
903
+ "901": ["n04579145", "whiskey_jug"],
904
+ "902": ["n04579432", "whistle"],
905
+ "903": ["n04584207", "wig"],
906
+ "904": ["n04589890", "window_screen"],
907
+ "905": ["n04590129", "window_shade"],
908
+ "906": ["n04591157", "Windsor_tie"],
909
+ "907": ["n04591713", "wine_bottle"],
910
+ "908": ["n04592741", "wing"],
911
+ "909": ["n04596742", "wok"],
912
+ "910": ["n04597913", "wooden_spoon"],
913
+ "911": ["n04599235", "wool"],
914
+ "912": ["n04604644", "worm_fence"],
915
+ "913": ["n04606251", "wreck"],
916
+ "914": ["n04612504", "yawl"],
917
+ "915": ["n04613696", "yurt"],
918
+ "916": ["n06359193", "web_site"],
919
+ "917": ["n06596364", "comic_book"],
920
+ "918": ["n06785654", "crossword_puzzle"],
921
+ "919": ["n06794110", "street_sign"],
922
+ "920": ["n06874185", "traffic_light"],
923
+ "921": ["n07248320", "book_jacket"],
924
+ "922": ["n07565083", "menu"],
925
+ "923": ["n07579787", "plate"],
926
+ "924": ["n07583066", "guacamole"],
927
+ "925": ["n07584110", "consomme"],
928
+ "926": ["n07590611", "hot_pot"],
929
+ "927": ["n07613480", "trifle"],
930
+ "928": ["n07614500", "ice_cream"],
931
+ "929": ["n07615774", "ice_lolly"],
932
+ "930": ["n07684084", "French_loaf"],
933
+ "931": ["n07693725", "bagel"],
934
+ "932": ["n07695742", "pretzel"],
935
+ "933": ["n07697313", "cheeseburger"],
936
+ "934": ["n07697537", "hotdog"],
937
+ "935": ["n07711569", "mashed_potato"],
938
+ "936": ["n07714571", "head_cabbage"],
939
+ "937": ["n07714990", "broccoli"],
940
+ "938": ["n07715103", "cauliflower"],
941
+ "939": ["n07716358", "zucchini"],
942
+ "940": ["n07716906", "spaghetti_squash"],
943
+ "941": ["n07717410", "acorn_squash"],
944
+ "942": ["n07717556", "butternut_squash"],
945
+ "943": ["n07718472", "cucumber"],
946
+ "944": ["n07718747", "artichoke"],
947
+ "945": ["n07720875", "bell_pepper"],
948
+ "946": ["n07730033", "cardoon"],
949
+ "947": ["n07734744", "mushroom"],
950
+ "948": ["n07742313", "Granny_Smith"],
951
+ "949": ["n07745940", "strawberry"],
952
+ "950": ["n07747607", "orange"],
953
+ "951": ["n07749582", "lemon"],
954
+ "952": ["n07753113", "fig"],
955
+ "953": ["n07753275", "pineapple"],
956
+ "954": ["n07753592", "banana"],
957
+ "955": ["n07754684", "jackfruit"],
958
+ "956": ["n07760859", "custard_apple"],
959
+ "957": ["n07768694", "pomegranate"],
960
+ "958": ["n07802026", "hay"],
961
+ "959": ["n07831146", "carbonara"],
962
+ "960": ["n07836838", "chocolate_sauce"],
963
+ "961": ["n07860988", "dough"],
964
+ "962": ["n07871810", "meat_loaf"],
965
+ "963": ["n07873807", "pizza"],
966
+ "964": ["n07875152", "potpie"],
967
+ "965": ["n07880968", "burrito"],
968
+ "966": ["n07892512", "red_wine"],
969
+ "967": ["n07920052", "espresso"],
970
+ "968": ["n07930864", "cup"],
971
+ "969": ["n07932039", "eggnog"],
972
+ "970": ["n09193705", "alp"],
973
+ "971": ["n09229709", "bubble"],
974
+ "972": ["n09246464", "cliff"],
975
+ "973": ["n09256479", "coral_reef"],
976
+ "974": ["n09288635", "geyser"],
977
+ "975": ["n09332890", "lakeside"],
978
+ "976": ["n09399592", "promontory"],
979
+ "977": ["n09421951", "sandbar"],
980
+ "978": ["n09428293", "seashore"],
981
+ "979": ["n09468604", "valley"],
982
+ "980": ["n09472597", "volcano"],
983
+ "981": ["n09835506", "ballplayer"],
984
+ "982": ["n10148035", "groom"],
985
+ "983": ["n10565667", "scuba_diver"],
986
+ "984": ["n11879895", "rapeseed"],
987
+ "985": ["n11939491", "daisy"],
988
+ "986": ["n12057211", "yellow_lady's_slipper"],
989
+ "987": ["n12144580", "corn"],
990
+ "988": ["n12267677", "acorn"],
991
+ "989": ["n12620546", "hip"],
992
+ "990": ["n12768682", "buckeye"],
993
+ "991": ["n12985857", "coral_fungus"],
994
+ "992": ["n12998815", "agaric"],
995
+ "993": ["n13037406", "gyromitra"],
996
+ "994": ["n13040303", "stinkhorn"],
997
+ "995": ["n13044778", "earthstar"],
998
+ "996": ["n13052670", "hen-of-the-woods"],
999
+ "997": ["n13054560", "bolete"],
1000
+ "998": ["n13133613", "ear"],
1001
+ "999": ["n15075141", "toilet_tissue"]
1002
+ }
images/cat.png ADDED
images/dog.png ADDED
images/panda.png ADDED
install.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pip install -e causal-conv1d
2
+ pip install -e mamba
kinetics_class_index.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ kinetics_classnames = {
2
+ "0": "riding a bike",
3
+ "1": "marching",
4
+ "2": "dodgeball",
5
+ "3": "playing cymbals",
6
+ "4": "checking tires",
7
+ "5": "roller skating",
8
+ "6": "tasting beer",
9
+ "7": "clapping",
10
+ "8": "drawing",
11
+ "9": "juggling fire",
12
+ "10": "bobsledding",
13
+ "11": "petting animal (not cat)",
14
+ "12": "spray painting",
15
+ "13": "training dog",
16
+ "14": "eating watermelon",
17
+ "15": "building cabinet",
18
+ "16": "applauding",
19
+ "17": "playing harp",
20
+ "18": "balloon blowing",
21
+ "19": "sled dog racing",
22
+ "20": "wrestling",
23
+ "21": "pole vault",
24
+ "22": "hurling (sport)",
25
+ "23": "riding scooter",
26
+ "24": "shearing sheep",
27
+ "25": "sweeping floor",
28
+ "26": "eating carrots",
29
+ "27": "skateboarding",
30
+ "28": "dunking basketball",
31
+ "29": "disc golfing",
32
+ "30": "eating spaghetti",
33
+ "31": "playing flute",
34
+ "32": "riding mechanical bull",
35
+ "33": "making sushi",
36
+ "34": "trapezing",
37
+ "35": "picking fruit",
38
+ "36": "stretching leg",
39
+ "37": "playing ukulele",
40
+ "38": "tying tie",
41
+ "39": "skydiving",
42
+ "40": "playing cello",
43
+ "41": "jumping into pool",
44
+ "42": "shooting goal (soccer)",
45
+ "43": "trimming trees",
46
+ "44": "bookbinding",
47
+ "45": "ski jumping",
48
+ "46": "walking the dog",
49
+ "47": "riding unicycle",
50
+ "48": "shaving head",
51
+ "49": "hopscotch",
52
+ "50": "playing piano",
53
+ "51": "parasailing",
54
+ "52": "bartending",
55
+ "53": "kicking field goal",
56
+ "54": "finger snapping",
57
+ "55": "dining",
58
+ "56": "yawning",
59
+ "57": "peeling potatoes",
60
+ "58": "canoeing or kayaking",
61
+ "59": "front raises",
62
+ "60": "laughing",
63
+ "61": "dancing macarena",
64
+ "62": "digging",
65
+ "63": "reading newspaper",
66
+ "64": "hitting baseball",
67
+ "65": "clay pottery making",
68
+ "66": "exercising with an exercise ball",
69
+ "67": "playing saxophone",
70
+ "68": "shooting basketball",
71
+ "69": "washing hair",
72
+ "70": "lunge",
73
+ "71": "brushing hair",
74
+ "72": "curling hair",
75
+ "73": "kitesurfing",
76
+ "74": "tapping guitar",
77
+ "75": "bending back",
78
+ "76": "skipping rope",
79
+ "77": "situp",
80
+ "78": "folding paper",
81
+ "79": "cracking neck",
82
+ "80": "assembling computer",
83
+ "81": "cleaning gutters",
84
+ "82": "blowing out candles",
85
+ "83": "shaking hands",
86
+ "84": "dancing gangnam style",
87
+ "85": "windsurfing",
88
+ "86": "tap dancing",
89
+ "87": "skiing (not slalom or crosscountry)",
90
+ "88": "bandaging",
91
+ "89": "push up",
92
+ "90": "doing nails",
93
+ "91": "punching person (boxing)",
94
+ "92": "bouncing on trampoline",
95
+ "93": "scrambling eggs",
96
+ "94": "singing",
97
+ "95": "cleaning floor",
98
+ "96": "krumping",
99
+ "97": "drumming fingers",
100
+ "98": "snowmobiling",
101
+ "99": "gymnastics tumbling",
102
+ "100": "headbanging",
103
+ "101": "catching or throwing frisbee",
104
+ "102": "riding elephant",
105
+ "103": "bee keeping",
106
+ "104": "feeding birds",
107
+ "105": "snatch weight lifting",
108
+ "106": "mowing lawn",
109
+ "107": "fixing hair",
110
+ "108": "playing trumpet",
111
+ "109": "flying kite",
112
+ "110": "crossing river",
113
+ "111": "swinging legs",
114
+ "112": "sanding floor",
115
+ "113": "belly dancing",
116
+ "114": "sneezing",
117
+ "115": "clean and jerk",
118
+ "116": "side kick",
119
+ "117": "filling eyebrows",
120
+ "118": "shuffling cards",
121
+ "119": "recording music",
122
+ "120": "cartwheeling",
123
+ "121": "feeding fish",
124
+ "122": "folding clothes",
125
+ "123": "water skiing",
126
+ "124": "tobogganing",
127
+ "125": "blowing leaves",
128
+ "126": "smoking",
129
+ "127": "unboxing",
130
+ "128": "tai chi",
131
+ "129": "waxing legs",
132
+ "130": "riding camel",
133
+ "131": "slapping",
134
+ "132": "tossing salad",
135
+ "133": "capoeira",
136
+ "134": "playing cards",
137
+ "135": "playing organ",
138
+ "136": "playing violin",
139
+ "137": "playing drums",
140
+ "138": "tapping pen",
141
+ "139": "vault",
142
+ "140": "shoveling snow",
143
+ "141": "playing tennis",
144
+ "142": "getting a tattoo",
145
+ "143": "making a sandwich",
146
+ "144": "making tea",
147
+ "145": "grinding meat",
148
+ "146": "squat",
149
+ "147": "eating doughnuts",
150
+ "148": "ice fishing",
151
+ "149": "snowkiting",
152
+ "150": "kicking soccer ball",
153
+ "151": "playing controller",
154
+ "152": "giving or receiving award",
155
+ "153": "welding",
156
+ "154": "throwing discus",
157
+ "155": "throwing axe",
158
+ "156": "ripping paper",
159
+ "157": "swimming butterfly stroke",
160
+ "158": "air drumming",
161
+ "159": "blowing nose",
162
+ "160": "hockey stop",
163
+ "161": "taking a shower",
164
+ "162": "bench pressing",
165
+ "163": "planting trees",
166
+ "164": "pumping fist",
167
+ "165": "climbing tree",
168
+ "166": "tickling",
169
+ "167": "high kick",
170
+ "168": "waiting in line",
171
+ "169": "slacklining",
172
+ "170": "tango dancing",
173
+ "171": "hurdling",
174
+ "172": "carrying baby",
175
+ "173": "celebrating",
176
+ "174": "sharpening knives",
177
+ "175": "passing American football (in game)",
178
+ "176": "headbutting",
179
+ "177": "playing recorder",
180
+ "178": "brush painting",
181
+ "179": "garbage collecting",
182
+ "180": "robot dancing",
183
+ "181": "shredding paper",
184
+ "182": "pumping gas",
185
+ "183": "rock climbing",
186
+ "184": "hula hooping",
187
+ "185": "braiding hair",
188
+ "186": "opening present",
189
+ "187": "texting",
190
+ "188": "decorating the christmas tree",
191
+ "189": "answering questions",
192
+ "190": "playing keyboard",
193
+ "191": "writing",
194
+ "192": "bungee jumping",
195
+ "193": "sniffing",
196
+ "194": "eating burger",
197
+ "195": "playing accordion",
198
+ "196": "making pizza",
199
+ "197": "playing volleyball",
200
+ "198": "tasting food",
201
+ "199": "pushing cart",
202
+ "200": "spinning poi",
203
+ "201": "cleaning windows",
204
+ "202": "arm wrestling",
205
+ "203": "changing oil",
206
+ "204": "swimming breast stroke",
207
+ "205": "tossing coin",
208
+ "206": "deadlifting",
209
+ "207": "hoverboarding",
210
+ "208": "cutting watermelon",
211
+ "209": "cheerleading",
212
+ "210": "snorkeling",
213
+ "211": "washing hands",
214
+ "212": "eating cake",
215
+ "213": "pull ups",
216
+ "214": "surfing water",
217
+ "215": "eating hotdog",
218
+ "216": "holding snake",
219
+ "217": "playing harmonica",
220
+ "218": "ironing",
221
+ "219": "cutting nails",
222
+ "220": "golf chipping",
223
+ "221": "shot put",
224
+ "222": "hugging",
225
+ "223": "playing clarinet",
226
+ "224": "faceplanting",
227
+ "225": "trimming or shaving beard",
228
+ "226": "drinking shots",
229
+ "227": "riding mountain bike",
230
+ "228": "tying bow tie",
231
+ "229": "swinging on something",
232
+ "230": "skiing crosscountry",
233
+ "231": "unloading truck",
234
+ "232": "cleaning pool",
235
+ "233": "jogging",
236
+ "234": "ice climbing",
237
+ "235": "mopping floor",
238
+ "236": "making bed",
239
+ "237": "diving cliff",
240
+ "238": "washing dishes",
241
+ "239": "grooming dog",
242
+ "240": "weaving basket",
243
+ "241": "frying vegetables",
244
+ "242": "stomping grapes",
245
+ "243": "moving furniture",
246
+ "244": "cooking sausages",
247
+ "245": "doing laundry",
248
+ "246": "dying hair",
249
+ "247": "knitting",
250
+ "248": "reading book",
251
+ "249": "baby waking up",
252
+ "250": "punching bag",
253
+ "251": "surfing crowd",
254
+ "252": "cooking chicken",
255
+ "253": "pushing car",
256
+ "254": "springboard diving",
257
+ "255": "swing dancing",
258
+ "256": "massaging legs",
259
+ "257": "beatboxing",
260
+ "258": "breading or breadcrumbing",
261
+ "259": "somersaulting",
262
+ "260": "brushing teeth",
263
+ "261": "stretching arm",
264
+ "262": "juggling balls",
265
+ "263": "massaging person's head",
266
+ "264": "eating ice cream",
267
+ "265": "extinguishing fire",
268
+ "266": "hammer throw",
269
+ "267": "whistling",
270
+ "268": "crawling baby",
271
+ "269": "using remote controller (not gaming)",
272
+ "270": "playing cricket",
273
+ "271": "opening bottle",
274
+ "272": "playing xylophone",
275
+ "273": "motorcycling",
276
+ "274": "driving car",
277
+ "275": "exercising arm",
278
+ "276": "passing American football (not in game)",
279
+ "277": "playing kickball",
280
+ "278": "sticking tongue out",
281
+ "279": "flipping pancake",
282
+ "280": "catching fish",
283
+ "281": "eating chips",
284
+ "282": "shaking head",
285
+ "283": "sword fighting",
286
+ "284": "playing poker",
287
+ "285": "cooking on campfire",
288
+ "286": "doing aerobics",
289
+ "287": "paragliding",
290
+ "288": "using segway",
291
+ "289": "folding napkins",
292
+ "290": "playing bagpipes",
293
+ "291": "gargling",
294
+ "292": "skiing slalom",
295
+ "293": "strumming guitar",
296
+ "294": "javelin throw",
297
+ "295": "waxing back",
298
+ "296": "riding or walking with horse",
299
+ "297": "plastering",
300
+ "298": "long jump",
301
+ "299": "parkour",
302
+ "300": "wrapping present",
303
+ "301": "egg hunting",
304
+ "302": "archery",
305
+ "303": "cleaning toilet",
306
+ "304": "swimming backstroke",
307
+ "305": "snowboarding",
308
+ "306": "catching or throwing baseball",
309
+ "307": "massaging back",
310
+ "308": "blowing glass",
311
+ "309": "playing guitar",
312
+ "310": "playing chess",
313
+ "311": "golf driving",
314
+ "312": "presenting weather forecast",
315
+ "313": "rock scissors paper",
316
+ "314": "high jump",
317
+ "315": "baking cookies",
318
+ "316": "using computer",
319
+ "317": "washing feet",
320
+ "318": "arranging flowers",
321
+ "319": "playing bass guitar",
322
+ "320": "spraying",
323
+ "321": "cutting pineapple",
324
+ "322": "waxing chest",
325
+ "323": "auctioning",
326
+ "324": "jetskiing",
327
+ "325": "drinking",
328
+ "326": "busking",
329
+ "327": "playing monopoly",
330
+ "328": "salsa dancing",
331
+ "329": "waxing eyebrows",
332
+ "330": "watering plants",
333
+ "331": "zumba",
334
+ "332": "chopping wood",
335
+ "333": "pushing wheelchair",
336
+ "334": "carving pumpkin",
337
+ "335": "building shed",
338
+ "336": "making jewelry",
339
+ "337": "catching or throwing softball",
340
+ "338": "bending metal",
341
+ "339": "ice skating",
342
+ "340": "dancing charleston",
343
+ "341": "abseiling",
344
+ "342": "climbing a rope",
345
+ "343": "crying",
346
+ "344": "cleaning shoes",
347
+ "345": "dancing ballet",
348
+ "346": "driving tractor",
349
+ "347": "triple jump",
350
+ "348": "throwing ball",
351
+ "349": "getting a haircut",
352
+ "350": "running on treadmill",
353
+ "351": "climbing ladder",
354
+ "352": "blasting sand",
355
+ "353": "playing trombone",
356
+ "354": "drop kicking",
357
+ "355": "country line dancing",
358
+ "356": "changing wheel",
359
+ "357": "feeding goats",
360
+ "358": "tying knot (not on a tie)",
361
+ "359": "setting table",
362
+ "360": "shaving legs",
363
+ "361": "kissing",
364
+ "362": "riding mule",
365
+ "363": "counting money",
366
+ "364": "laying bricks",
367
+ "365": "barbequing",
368
+ "366": "news anchoring",
369
+ "367": "smoking hookah",
370
+ "368": "cooking egg",
371
+ "369": "peeling apples",
372
+ "370": "yoga",
373
+ "371": "sharpening pencil",
374
+ "372": "dribbling basketball",
375
+ "373": "petting cat",
376
+ "374": "playing ice hockey",
377
+ "375": "milking cow",
378
+ "376": "shining shoes",
379
+ "377": "juggling soccer ball",
380
+ "378": "scuba diving",
381
+ "379": "playing squash or racquetball",
382
+ "380": "drinking beer",
383
+ "381": "sign language interpreting",
384
+ "382": "playing basketball",
385
+ "383": "breakdancing",
386
+ "384": "testifying",
387
+ "385": "making snowman",
388
+ "386": "golf putting",
389
+ "387": "playing didgeridoo",
390
+ "388": "biking through snow",
391
+ "389": "sailing",
392
+ "390": "jumpstyle dancing",
393
+ "391": "water sliding",
394
+ "392": "grooming horse",
395
+ "393": "massaging feet",
396
+ "394": "playing paintball",
397
+ "395": "making a cake",
398
+ "396": "bowling",
399
+ "397": "contact juggling",
400
+ "398": "applying cream",
401
+ "399": "playing badminton"
402
+ }
mamba/.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "3rdparty/lm-evaluation-harness"]
2
+ path = 3rdparty/lm-evaluation-harness
3
+ url = https://github.com/EleutherAI/lm-evaluation-harness/
mamba/AUTHORS ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Tri Dao, [email protected]
2
+ Albert Gu, [email protected]
mamba/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2023 Tri Dao, Albert Gu
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
mamba/README.md ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mamba
2
+
3
+ ![Mamba](assets/selection.png "Selective State Space")
4
+ > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
5
+ > Albert Gu*, Tri Dao*\
6
+ > Paper: https://arxiv.org/abs/2312.00752
7
+
8
+ ## About
9
+
10
+ Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
11
+ It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
12
+ with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
13
+
14
+ ## Installation
15
+
16
+ - `pip install causal-conv1d`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
17
+ - `pip install mamba-ssm`: the core Mamba package.
18
+
19
+ It can also be built from source with `pip install .` from this repository.
20
+
21
+ If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
22
+
23
+ Other requirements:
24
+ - Linux
25
+ - NVIDIA GPU
26
+ - PyTorch 1.12+
27
+ - CUDA 11.6+
28
+
29
+ ## Usage
30
+
31
+ We expose several levels of interface with the Mamba model.
32
+
33
+ ### Selective SSM
34
+
35
+ Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
36
+
37
+ Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
38
+
39
+ ### Mamba Block
40
+
41
+ The main module of this repository is the Mamba architecture block wrapping the selective SSM.
42
+
43
+ Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
44
+
45
+ Usage:
46
+ ```
47
+ from mamba_ssm import Mamba
48
+
49
+ batch, length, dim = 2, 64, 16
50
+ x = torch.randn(batch, length, dim).to("cuda")
51
+ model = Mamba(
52
+ # This module uses roughly 3 * expand * d_model^2 parameters
53
+ d_model=dim, # Model dimension d_model
54
+ d_state=16, # SSM state expansion factor
55
+ d_conv=4, # Local convolution width
56
+ expand=2, # Block expansion factor
57
+ ).to("cuda")
58
+ y = model(x)
59
+ assert y.shape == x.shape
60
+ ```
61
+
62
+ ### Mamba Language Model
63
+
64
+ Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
65
+
66
+ Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
67
+
68
+ This is an example of how to integrate Mamba into an end-to-end neural network.
69
+ This example is used in the generation scripts below.
70
+
71
+
72
+
73
+ ## Pretrained Models
74
+
75
+ Pretrained models are uploaded to
76
+ [HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
77
+ `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`.
78
+
79
+ The models will be autodownloaded by the generation script below.
80
+
81
+ These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
82
+
83
+ | Parameters | Layers | Model dim. |
84
+ |------------|--------|------------|
85
+ | 130M | 12 | 768 |
86
+ | 370M | 24 | 1024 |
87
+ | 790M | 24 | 1536 |
88
+ | 1.4B | 24 | 2048 |
89
+ | 2.8B | 32 | 2560 |
90
+
91
+ (The layer count of Mamba should be doubled, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
92
+
93
+ Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
94
+ Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
95
+
96
+
97
+ ## Evaluations
98
+
99
+ To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
100
+ we use the
101
+ [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor)
102
+ library.
103
+
104
+ 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init
105
+ --recursive`. We use the `big-refactor` branch.
106
+ 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`
107
+ 3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
108
+ ```
109
+ python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
110
+ python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
111
+ ```
112
+
113
+ Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
114
+
115
+ ## Inference
116
+
117
+ The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
118
+ 1. autoloads a model from the HuggingFace Hub,
119
+ 2. generates completions of a user-specified prompt,
120
+ 3. benchmarks the inference speed of this generation.
121
+
122
+ Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
123
+
124
+ ### Examples
125
+
126
+ To test generation latency (e.g. batch size = 1) with different sampling strategies:
127
+
128
+ ```
129
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
130
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
131
+ ```
132
+
133
+ To test generation throughput with random prompts (e.g. large batch size):
134
+ ```
135
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
136
+ python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
137
+ ```
138
+
139
+ ## Citation
140
+
141
+ If you use this codebase, or otherwise found our work valuable, please cite Mamba:
142
+ ```
143
+ @article{mamba,
144
+ title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
145
+ author={Gu, Albert and Dao, Tri},
146
+ journal={arXiv preprint arXiv:2312.00752},
147
+ year={2023}
148
+ }
149
+ ```
mamba/assets/selection.png ADDED
mamba/benchmarks/benchmark_generation_mamba_simple.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import argparse
4
+ import time
5
+ import json
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange
11
+
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
13
+
14
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
15
+
16
+
17
+ parser = argparse.ArgumentParser(description="Generation benchmarking")
18
+ parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
19
+ parser.add_argument("--prompt", type=str, default=None)
20
+ parser.add_argument("--promptlen", type=int, default=100)
21
+ parser.add_argument("--genlen", type=int, default=100)
22
+ parser.add_argument("--temperature", type=float, default=1.0)
23
+ parser.add_argument("--topk", type=int, default=1)
24
+ parser.add_argument("--topp", type=float, default=1.0)
25
+ parser.add_argument("--batch", type=int, default=1)
26
+ args = parser.parse_args()
27
+
28
+ repeats = 3
29
+ device = "cuda"
30
+ dtype = torch.float16
31
+
32
+ print(f"Loading model {args.model_name}")
33
+ is_mamba = args.model_name.startswith("state-spaces/mamba-") or "mamba" in args.model_name
34
+
35
+ if is_mamba:
36
+ tokenizer = AutoTokenizer.from_pretrained("/home/zhulianghui/VisionProjects/mamba/ckpts/gpt-neox-20b-tokenizer")
37
+ model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
38
+ else:
39
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
40
+ model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
41
+ model.eval()
42
+ print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
43
+
44
+ torch.random.manual_seed(0)
45
+ if args.prompt is None:
46
+ input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
47
+ attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
48
+ else:
49
+ tokens = tokenizer(args.prompt, return_tensors="pt")
50
+ input_ids = tokens.input_ids.to(device=device)
51
+ attn_mask = tokens.attention_mask.to(device=device)
52
+ max_length = input_ids.shape[1] + args.genlen
53
+
54
+ if is_mamba:
55
+ fn = lambda: model.generate(
56
+ input_ids=input_ids,
57
+ max_length=max_length,
58
+ cg=True,
59
+ return_dict_in_generate=True,
60
+ output_scores=True,
61
+ enable_timing=False,
62
+ temperature=args.temperature,
63
+ top_k=args.topk,
64
+ top_p=args.topp,
65
+ )
66
+ else:
67
+ fn = lambda: model.generate(
68
+ input_ids=input_ids,
69
+ attention_mask=attn_mask,
70
+ max_length=max_length,
71
+ return_dict_in_generate=True,
72
+ pad_token_id=tokenizer.eos_token_id,
73
+ do_sample=True,
74
+ temperature=args.temperature,
75
+ top_k=args.topk,
76
+ top_p=args.topp,
77
+ )
78
+ out = fn()
79
+ if args.prompt is not None:
80
+ print(tokenizer.batch_decode(out.sequences.tolist()))
81
+
82
+ torch.cuda.synchronize()
83
+ start = time.time()
84
+ for _ in range(repeats):
85
+ fn()
86
+ torch.cuda.synchronize()
87
+ print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
88
+ print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
mamba/csrc/selective_scan/reverse_scan.cuh ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cub/config.cuh>
8
+
9
+ #include <cub/util_ptx.cuh>
10
+ #include <cub/util_type.cuh>
11
+ #include <cub/block/block_raking_layout.cuh>
12
+ // #include <cub/detail/uninitialized_copy.cuh>
13
+ #include "uninitialized_copy.cuh"
14
+
15
+ /**
16
+ * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
17
+ */
18
+ template <
19
+ int LENGTH,
20
+ typename T,
21
+ typename ReductionOp>
22
+ __device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
23
+ static_assert(LENGTH > 0);
24
+ T retval = input[LENGTH - 1];
25
+ #pragma unroll
26
+ for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
27
+ return retval;
28
+ }
29
+
30
+ /**
31
+ * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
32
+ */
33
+ template <
34
+ int LENGTH,
35
+ typename T,
36
+ typename ScanOp>
37
+ __device__ __forceinline__ T ThreadReverseScanInclusive(
38
+ const T (&input)[LENGTH],
39
+ T (&output)[LENGTH],
40
+ ScanOp scan_op,
41
+ const T postfix)
42
+ {
43
+ T inclusive = postfix;
44
+ #pragma unroll
45
+ for (int i = LENGTH - 1; i >= 0; --i) {
46
+ inclusive = scan_op(inclusive, input[i]);
47
+ output[i] = inclusive;
48
+ }
49
+ }
50
+
51
+ /**
52
+ * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
53
+ */
54
+ template <
55
+ int LENGTH,
56
+ typename T,
57
+ typename ScanOp>
58
+ __device__ __forceinline__ T ThreadReverseScanExclusive(
59
+ const T (&input)[LENGTH],
60
+ T (&output)[LENGTH],
61
+ ScanOp scan_op,
62
+ const T postfix)
63
+ {
64
+ // Careful, output maybe be aliased to input
65
+ T exclusive = postfix;
66
+ T inclusive;
67
+ #pragma unroll
68
+ for (int i = LENGTH - 1; i >= 0; --i) {
69
+ inclusive = scan_op(exclusive, input[i]);
70
+ output[i] = exclusive;
71
+ exclusive = inclusive;
72
+ }
73
+ return inclusive;
74
+ }
75
+
76
+
77
+ /**
78
+ * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
79
+ *
80
+ * LOGICAL_WARP_THREADS must be a power-of-two
81
+ */
82
+ template <
83
+ typename T, ///< Data type being scanned
84
+ int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
85
+ >
86
+ struct WarpReverseScan {
87
+ //---------------------------------------------------------------------
88
+ // Constants and type definitions
89
+ //---------------------------------------------------------------------
90
+
91
+ /// Whether the logical warp size and the PTX warp size coincide
92
+ static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
93
+ /// The number of warp scan steps
94
+ static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
95
+ static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
96
+
97
+
98
+ //---------------------------------------------------------------------
99
+ // Thread fields
100
+ //---------------------------------------------------------------------
101
+
102
+ /// Lane index in logical warp
103
+ unsigned int lane_id;
104
+
105
+ /// Logical warp index in 32-thread physical warp
106
+ unsigned int warp_id;
107
+
108
+ /// 32-thread physical warp member mask of logical warp
109
+ unsigned int member_mask;
110
+
111
+ //---------------------------------------------------------------------
112
+ // Construction
113
+ //---------------------------------------------------------------------
114
+
115
+ /// Constructor
116
+ explicit __device__ __forceinline__
117
+ WarpReverseScan()
118
+ : lane_id(cub::LaneId())
119
+ , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
120
+ , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
121
+ {
122
+ if (!IS_ARCH_WARP) {
123
+ lane_id = lane_id % LOGICAL_WARP_THREADS;
124
+ }
125
+ }
126
+
127
+
128
+ /// Broadcast
129
+ __device__ __forceinline__ T Broadcast(
130
+ T input, ///< [in] The value to broadcast
131
+ int src_lane) ///< [in] Which warp lane is to do the broadcasting
132
+ {
133
+ return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
134
+ }
135
+
136
+
137
+ /// Inclusive scan
138
+ template <typename ScanOpT>
139
+ __device__ __forceinline__ void InclusiveReverseScan(
140
+ T input, ///< [in] Calling thread's input item.
141
+ T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
142
+ ScanOpT scan_op) ///< [in] Binary scan operator
143
+ {
144
+ inclusive_output = input;
145
+ #pragma unroll
146
+ for (int STEP = 0; STEP < STEPS; STEP++) {
147
+ int offset = 1 << STEP;
148
+ T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
149
+ inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
150
+ );
151
+ // Perform scan op if from a valid peer
152
+ inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
153
+ ? inclusive_output : scan_op(temp, inclusive_output);
154
+ }
155
+ }
156
+
157
+ /// Exclusive scan
158
+ // Get exclusive from inclusive
159
+ template <typename ScanOpT>
160
+ __device__ __forceinline__ void ExclusiveReverseScan(
161
+ T input, ///< [in] Calling thread's input item.
162
+ T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
163
+ ScanOpT scan_op, ///< [in] Binary scan operator
164
+ T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
165
+ {
166
+ T inclusive_output;
167
+ InclusiveReverseScan(input, inclusive_output, scan_op);
168
+ warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
169
+ // initial value unknown
170
+ exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
171
+ inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
172
+ );
173
+ }
174
+
175
+ /**
176
+ * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
177
+ */
178
+ template <typename ScanOpT>
179
+ __device__ __forceinline__ void ReverseScan(
180
+ T input, ///< [in] Calling thread's input item.
181
+ T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item.
182
+ T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
183
+ ScanOpT scan_op) ///< [in] Binary scan operator
184
+ {
185
+ InclusiveReverseScan(input, inclusive_output, scan_op);
186
+ // initial value unknown
187
+ exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
188
+ inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
189
+ );
190
+ }
191
+
192
+ };
193
+
194
+ /**
195
+ * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
196
+ */
197
+ template <
198
+ typename T, ///< Data type being scanned
199
+ int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
200
+ bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
201
+ >
202
+ struct BlockReverseScan {
203
+ //---------------------------------------------------------------------
204
+ // Types and constants
205
+ //---------------------------------------------------------------------
206
+
207
+ /// Constants
208
+ /// The thread block size in threads
209
+ static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
210
+
211
+ /// Layout type for padded thread block raking grid
212
+ using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
213
+ // The number of reduction elements is not a multiple of the number of raking threads for now
214
+ static_assert(BlockRakingLayout::UNGUARDED);
215
+
216
+ /// Number of raking threads
217
+ static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
218
+ /// Number of raking elements per warp synchronous raking thread
219
+ static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
220
+ /// Cooperative work can be entirely warp synchronous
221
+ static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
222
+
223
+ /// WarpReverseScan utility type
224
+ using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
225
+
226
+ /// Shared memory storage layout type
227
+ struct _TempStorage {
228
+ typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
229
+ };
230
+
231
+
232
+ /// Alias wrapper allowing storage to be unioned
233
+ struct TempStorage : cub::Uninitialized<_TempStorage> {};
234
+
235
+
236
+ //---------------------------------------------------------------------
237
+ // Per-thread fields
238
+ //---------------------------------------------------------------------
239
+
240
+ // Thread fields
241
+ _TempStorage &temp_storage;
242
+ unsigned int linear_tid;
243
+ T cached_segment[SEGMENT_LENGTH];
244
+
245
+
246
+ //---------------------------------------------------------------------
247
+ // Utility methods
248
+ //---------------------------------------------------------------------
249
+
250
+ /// Performs upsweep raking reduction, returning the aggregate
251
+ template <typename ScanOp>
252
+ __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
253
+ T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
254
+ // Read data into registers
255
+ #pragma unroll
256
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
257
+ T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
258
+ #pragma unroll
259
+ for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
260
+ raking_partial = scan_op(raking_partial, cached_segment[i]);
261
+ }
262
+ return raking_partial;
263
+ }
264
+
265
+
266
+ /// Performs exclusive downsweep raking scan
267
+ template <typename ScanOp>
268
+ __device__ __forceinline__ void ExclusiveDownsweep(
269
+ ScanOp scan_op,
270
+ T raking_partial)
271
+ {
272
+ T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
273
+ // Read data back into registers
274
+ if (!MEMOIZE) {
275
+ #pragma unroll
276
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
277
+ }
278
+ ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
279
+ // Write data back to smem
280
+ #pragma unroll
281
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
282
+ }
283
+
284
+
285
+ //---------------------------------------------------------------------
286
+ // Constructors
287
+ //---------------------------------------------------------------------
288
+
289
+ /// Constructor
290
+ __device__ __forceinline__ BlockReverseScan(
291
+ TempStorage &temp_storage)
292
+ :
293
+ temp_storage(temp_storage.Alias()),
294
+ linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
295
+ {}
296
+
297
+
298
+ /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
299
+ template <
300
+ typename ScanOp,
301
+ typename BlockPostfixCallbackOp>
302
+ __device__ __forceinline__ void ExclusiveReverseScan(
303
+ T input, ///< [in] Calling thread's input item
304
+ T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
305
+ ScanOp scan_op, ///< [in] Binary scan operator
306
+ BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
307
+ {
308
+ if (WARP_SYNCHRONOUS) {
309
+ // Short-circuit directly to warp-synchronous scan
310
+ T block_aggregate;
311
+ WarpReverseScan warp_scan;
312
+ warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
313
+ // Obtain warp-wide postfix in lane0, then broadcast to other lanes
314
+ T block_postfix = block_postfix_callback_op(block_aggregate);
315
+ block_postfix = warp_scan.Broadcast(block_postfix, 0);
316
+ exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
317
+ } else {
318
+ // Place thread partial into shared memory raking grid
319
+ T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
320
+ detail::uninitialized_copy(placement_ptr, input);
321
+ cub::CTA_SYNC();
322
+ // Reduce parallelism down to just raking threads
323
+ if (linear_tid < RAKING_THREADS) {
324
+ WarpReverseScan warp_scan;
325
+ // Raking upsweep reduction across shared partials
326
+ T upsweep_partial = Upsweep(scan_op);
327
+ // Warp-synchronous scan
328
+ T exclusive_partial, block_aggregate;
329
+ warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
330
+ // Obtain block-wide postfix in lane0, then broadcast to other lanes
331
+ T block_postfix = block_postfix_callback_op(block_aggregate);
332
+ block_postfix = warp_scan.Broadcast(block_postfix, 0);
333
+ // Update postfix with warpscan exclusive partial
334
+ T downsweep_postfix = linear_tid == RAKING_THREADS - 1
335
+ ? block_postfix : scan_op(block_postfix, exclusive_partial);
336
+ // Exclusive raking downsweep scan
337
+ ExclusiveDownsweep(scan_op, downsweep_postfix);
338
+ }
339
+ cub::CTA_SYNC();
340
+ // Grab thread postfix from shared memory
341
+ exclusive_output = *placement_ptr;
342
+
343
+ // // Compute warp scan in each warp.
344
+ // // The exclusive output from the last lane in each warp is invalid.
345
+ // T inclusive_output;
346
+ // WarpReverseScan warp_scan;
347
+ // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
348
+
349
+ // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
350
+ // T block_aggregate;
351
+ // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
352
+
353
+ // // Apply warp postfix to our lane's partial
354
+ // if (warp_id != 0) {
355
+ // exclusive_output = scan_op(warp_postfix, exclusive_output);
356
+ // if (lane_id == 0) { exclusive_output = warp_postfix; }
357
+ // }
358
+
359
+ // // Use the first warp to determine the thread block postfix, returning the result in lane0
360
+ // if (warp_id == 0) {
361
+ // T block_postfix = block_postfix_callback_op(block_aggregate);
362
+ // if (lane_id == 0) {
363
+ // // Share the postfix with all threads
364
+ // detail::uninitialized_copy(&temp_storage.block_postfix,
365
+ // block_postfix);
366
+
367
+ // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
368
+ // }
369
+ // }
370
+
371
+ // cub::CTA_SYNC();
372
+
373
+ // // Incorporate thread block postfix into outputs
374
+ // T block_postfix = temp_storage.block_postfix;
375
+ // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
376
+ }
377
+ }
378
+
379
+
380
+ /**
381
+ * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
382
+ */
383
+ template <
384
+ int ITEMS_PER_THREAD,
385
+ typename ScanOp,
386
+ typename BlockPostfixCallbackOp>
387
+ __device__ __forceinline__ void InclusiveReverseScan(
388
+ T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
389
+ T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
390
+ ScanOp scan_op, ///< [in] Binary scan functor
391
+ BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
392
+ {
393
+ // Reduce consecutive thread items in registers
394
+ T thread_postfix = ThreadReverseReduce(input, scan_op);
395
+ // Exclusive thread block-scan
396
+ ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
397
+ // Inclusive scan in registers with postfix as seed
398
+ ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
399
+ }
400
+
401
+ };
mamba/csrc/selective_scan/selective_scan.cpp ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <torch/extension.h>
8
+ #include <vector>
9
+
10
+ #include "selective_scan.h"
11
+
12
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
13
+
14
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
15
+ if (ITYPE == at::ScalarType::Half) { \
16
+ using input_t = at::Half; \
17
+ __VA_ARGS__(); \
18
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
19
+ using input_t = at::BFloat16; \
20
+ __VA_ARGS__(); \
21
+ } else if (ITYPE == at::ScalarType::Float) { \
22
+ using input_t = float; \
23
+ __VA_ARGS__(); \
24
+ } else { \
25
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
26
+ }
27
+
28
+ #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
29
+ if (WTYPE == at::ScalarType::Half) { \
30
+ using weight_t = at::Half; \
31
+ __VA_ARGS__(); \
32
+ } else if (WTYPE == at::ScalarType::BFloat16) { \
33
+ using weight_t = at::BFloat16; \
34
+ __VA_ARGS__(); \
35
+ } else if (WTYPE == at::ScalarType::Float) { \
36
+ using weight_t = float; \
37
+ __VA_ARGS__(); \
38
+ } else { \
39
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
40
+ }
41
+
42
+ #define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \
43
+ if (WTYPE == at::ScalarType::Float) { \
44
+ using weight_t = float; \
45
+ __VA_ARGS__(); \
46
+ } else if (WTYPE == at::ScalarType::ComplexFloat) { \
47
+ using weight_t = c10::complex<float>; \
48
+ __VA_ARGS__(); \
49
+ } else { \
50
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
51
+ }
52
+
53
+ template<typename input_t, typename weight_t>
54
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
55
+
56
+ template <typename input_t, typename weight_t>
57
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
58
+
59
+ void set_ssm_params_fwd(SSMParamsBase &params,
60
+ // sizes
61
+ const size_t batch,
62
+ const size_t dim,
63
+ const size_t seqlen,
64
+ const size_t dstate,
65
+ const size_t n_groups,
66
+ const size_t n_chunks,
67
+ const bool is_variable_B,
68
+ const bool is_variable_C,
69
+ // device pointers
70
+ const at::Tensor u,
71
+ const at::Tensor delta,
72
+ const at::Tensor A,
73
+ const at::Tensor B,
74
+ const at::Tensor C,
75
+ const at::Tensor out,
76
+ const at::Tensor z,
77
+ const at::Tensor out_z,
78
+ void* D_ptr,
79
+ void* delta_bias_ptr,
80
+ void* x_ptr,
81
+ bool has_z,
82
+ bool delta_softplus) {
83
+
84
+ // Reset the parameters
85
+ memset(&params, 0, sizeof(params));
86
+
87
+ params.batch = batch;
88
+ params.dim = dim;
89
+ params.seqlen = seqlen;
90
+ params.dstate = dstate;
91
+ params.n_groups = n_groups;
92
+ params.n_chunks = n_chunks;
93
+ params.dim_ngroups_ratio = dim / n_groups;
94
+
95
+ params.delta_softplus = delta_softplus;
96
+
97
+ params.is_variable_B = is_variable_B;
98
+ params.is_variable_C = is_variable_C;
99
+
100
+ // Set the pointers and strides.
101
+ params.u_ptr = u.data_ptr();
102
+ params.delta_ptr = delta.data_ptr();
103
+ params.A_ptr = A.data_ptr();
104
+ params.B_ptr = B.data_ptr();
105
+ params.C_ptr = C.data_ptr();
106
+ params.D_ptr = D_ptr;
107
+ params.delta_bias_ptr = delta_bias_ptr;
108
+ params.out_ptr = out.data_ptr();
109
+ params.x_ptr = x_ptr;
110
+ params.z_ptr = has_z ? z.data_ptr() : nullptr;
111
+ params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
112
+ // All stride are in elements, not bytes.
113
+ params.A_d_stride = A.stride(0);
114
+ params.A_dstate_stride = A.stride(1);
115
+ if (!is_variable_B) {
116
+ params.B_d_stride = B.stride(0);
117
+ } else {
118
+ params.B_batch_stride = B.stride(0);
119
+ params.B_group_stride = B.stride(1);
120
+ }
121
+ params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
122
+ if (!is_variable_C) {
123
+ params.C_d_stride = C.stride(0);
124
+ } else {
125
+ params.C_batch_stride = C.stride(0);
126
+ params.C_group_stride = C.stride(1);
127
+ }
128
+ params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
129
+ params.u_batch_stride = u.stride(0);
130
+ params.u_d_stride = u.stride(1);
131
+ params.delta_batch_stride = delta.stride(0);
132
+ params.delta_d_stride = delta.stride(1);
133
+ if (has_z) {
134
+ params.z_batch_stride = z.stride(0);
135
+ params.z_d_stride = z.stride(1);
136
+ params.out_z_batch_stride = out_z.stride(0);
137
+ params.out_z_d_stride = out_z.stride(1);
138
+ }
139
+ params.out_batch_stride = out.stride(0);
140
+ params.out_d_stride = out.stride(1);
141
+ }
142
+
143
+ void set_ssm_params_bwd(SSMParamsBwd &params,
144
+ // sizes
145
+ const size_t batch,
146
+ const size_t dim,
147
+ const size_t seqlen,
148
+ const size_t dstate,
149
+ const size_t n_groups,
150
+ const size_t n_chunks,
151
+ const bool is_variable_B,
152
+ const bool is_variable_C,
153
+ // device pointers
154
+ const at::Tensor u,
155
+ const at::Tensor delta,
156
+ const at::Tensor A,
157
+ const at::Tensor B,
158
+ const at::Tensor C,
159
+ const at::Tensor z,
160
+ const at::Tensor out,
161
+ const at::Tensor out_z,
162
+ void* D_ptr,
163
+ void* delta_bias_ptr,
164
+ void* x_ptr,
165
+ const at::Tensor dout,
166
+ const at::Tensor du,
167
+ const at::Tensor ddelta,
168
+ const at::Tensor dA,
169
+ const at::Tensor dB,
170
+ const at::Tensor dC,
171
+ const at::Tensor dz,
172
+ void* dD_ptr,
173
+ void* ddelta_bias_ptr,
174
+ bool has_z,
175
+ bool delta_softplus,
176
+ bool recompute_out_z) {
177
+ // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
178
+ set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
179
+ u, delta, A, B, C, has_z ? out : dout,
180
+ has_z ? z : dout,
181
+ // If not recompute_out_z, pass dout instead of out_z.
182
+ // This won't be used by the bwd kernel
183
+ recompute_out_z ? out_z : dout,
184
+ D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
185
+ if (!recompute_out_z) { params.out_z_ptr = nullptr; }
186
+
187
+ // Set the pointers and strides.
188
+ params.dout_ptr = dout.data_ptr();
189
+ params.du_ptr = du.data_ptr();
190
+ params.dA_ptr = dA.data_ptr();
191
+ params.dB_ptr = dB.data_ptr();
192
+ params.dC_ptr = dC.data_ptr();
193
+ params.dD_ptr = dD_ptr;
194
+ params.ddelta_ptr = ddelta.data_ptr();
195
+ params.ddelta_bias_ptr = ddelta_bias_ptr;
196
+ params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
197
+ // All stride are in elements, not bytes.
198
+ params.dout_batch_stride = dout.stride(0);
199
+ params.dout_d_stride = dout.stride(1);
200
+ params.dA_d_stride = dA.stride(0);
201
+ params.dA_dstate_stride = dA.stride(1);
202
+ if (!is_variable_B) {
203
+ params.dB_d_stride = dB.stride(0);
204
+ } else {
205
+ params.dB_batch_stride = dB.stride(0);
206
+ params.dB_group_stride = dB.stride(1);
207
+ }
208
+ params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
209
+ if (!is_variable_C) {
210
+ params.dC_d_stride = dC.stride(0);
211
+ } else {
212
+ params.dC_batch_stride = dC.stride(0);
213
+ params.dC_group_stride = dC.stride(1);
214
+ }
215
+ params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
216
+ params.du_batch_stride = du.stride(0);
217
+ params.du_d_stride = du.stride(1);
218
+ params.ddelta_batch_stride = ddelta.stride(0);
219
+ params.ddelta_d_stride = ddelta.stride(1);
220
+ if (has_z) {
221
+ params.dz_batch_stride = dz.stride(0);
222
+ params.dz_d_stride = dz.stride(1);
223
+ }
224
+ }
225
+
226
+ std::vector<at::Tensor>
227
+ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
228
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
229
+ const c10::optional<at::Tensor> &D_,
230
+ const c10::optional<at::Tensor> &z_,
231
+ const c10::optional<at::Tensor> &delta_bias_,
232
+ bool delta_softplus) {
233
+ auto input_type = u.scalar_type();
234
+ auto weight_type = A.scalar_type();
235
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
236
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
237
+
238
+ const bool is_variable_B = B.dim() >= 3;
239
+ const bool is_variable_C = C.dim() >= 3;
240
+ const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
241
+
242
+ TORCH_CHECK(delta.scalar_type() == input_type);
243
+ TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
244
+ TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
245
+
246
+ TORCH_CHECK(u.is_cuda());
247
+ TORCH_CHECK(delta.is_cuda());
248
+ TORCH_CHECK(A.is_cuda());
249
+ TORCH_CHECK(B.is_cuda());
250
+ TORCH_CHECK(C.is_cuda());
251
+
252
+ TORCH_CHECK(u.stride(-1) == 1);
253
+ TORCH_CHECK(delta.stride(-1) == 1);
254
+
255
+ const auto sizes = u.sizes();
256
+ const int batch_size = sizes[0];
257
+ const int dim = sizes[1];
258
+ const int seqlen = sizes[2];
259
+ const int dstate = A.size(1);
260
+ const int n_groups = is_variable_B ? B.size(1) : 1;
261
+
262
+ TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
263
+
264
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
265
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
266
+ CHECK_SHAPE(A, dim, dstate);
267
+ if (!is_variable_B) {
268
+ CHECK_SHAPE(B, dim, dstate);
269
+ } else {
270
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
271
+ TORCH_CHECK(B.stride(-1) == 1);
272
+ }
273
+ if (!is_variable_C) {
274
+ CHECK_SHAPE(C, dim, dstate);
275
+ } else {
276
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
277
+ TORCH_CHECK(C.stride(-1) == 1);
278
+ }
279
+
280
+ if (D_.has_value()) {
281
+ auto D = D_.value();
282
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
283
+ TORCH_CHECK(D.is_cuda());
284
+ TORCH_CHECK(D.stride(-1) == 1);
285
+ CHECK_SHAPE(D, dim);
286
+ }
287
+
288
+ if (delta_bias_.has_value()) {
289
+ auto delta_bias = delta_bias_.value();
290
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
291
+ TORCH_CHECK(delta_bias.is_cuda());
292
+ TORCH_CHECK(delta_bias.stride(-1) == 1);
293
+ CHECK_SHAPE(delta_bias, dim);
294
+ }
295
+
296
+ at::Tensor z, out_z;
297
+ const bool has_z = z_.has_value();
298
+ if (has_z) {
299
+ z = z_.value();
300
+ TORCH_CHECK(z.scalar_type() == input_type);
301
+ TORCH_CHECK(z.is_cuda());
302
+ TORCH_CHECK(z.stride(-1) == 1);
303
+ CHECK_SHAPE(z, batch_size, dim, seqlen);
304
+ out_z = torch::empty_like(z);
305
+ }
306
+
307
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
308
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
309
+ // at::Tensor out = torch::empty_like(u);
310
+ // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
311
+ at::Tensor out = torch::empty_like(delta);
312
+ at::Tensor x;
313
+ x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
314
+
315
+ SSMParamsBase params;
316
+ set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
317
+ u, delta, A, B, C, out, z, out_z,
318
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
319
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
320
+ x.data_ptr(),
321
+ has_z,
322
+ delta_softplus);
323
+
324
+ // Otherwise the kernel will be launched from cuda:0 device
325
+ // Cast to char to avoid compiler warning about narrowing
326
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
327
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
328
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
329
+ DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
330
+ selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
331
+ });
332
+ });
333
+ std::vector<at::Tensor> result = {out, x};
334
+ if (has_z) { result.push_back(out_z); }
335
+ return result;
336
+ }
337
+
338
+ std::vector<at::Tensor>
339
+ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
340
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
341
+ const c10::optional<at::Tensor> &D_,
342
+ const c10::optional<at::Tensor> &z_,
343
+ const c10::optional<at::Tensor> &delta_bias_,
344
+ const at::Tensor &dout,
345
+ const c10::optional<at::Tensor> &x_,
346
+ const c10::optional<at::Tensor> &out_,
347
+ c10::optional<at::Tensor> &dz_,
348
+ bool delta_softplus,
349
+ bool recompute_out_z) {
350
+ auto input_type = u.scalar_type();
351
+ auto weight_type = A.scalar_type();
352
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
353
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
354
+
355
+ const bool is_variable_B = B.dim() >= 3;
356
+ const bool is_variable_C = C.dim() >= 3;
357
+ const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
358
+
359
+ TORCH_CHECK(delta.scalar_type() == input_type);
360
+ TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
361
+ TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
362
+ TORCH_CHECK(dout.scalar_type() == input_type);
363
+
364
+ TORCH_CHECK(u.is_cuda());
365
+ TORCH_CHECK(delta.is_cuda());
366
+ TORCH_CHECK(A.is_cuda());
367
+ TORCH_CHECK(B.is_cuda());
368
+ TORCH_CHECK(C.is_cuda());
369
+ TORCH_CHECK(dout.is_cuda());
370
+
371
+ TORCH_CHECK(u.stride(-1) == 1);
372
+ TORCH_CHECK(delta.stride(-1) == 1);
373
+ TORCH_CHECK(dout.stride(-1) == 1);
374
+
375
+ const auto sizes = u.sizes();
376
+ const int batch_size = sizes[0];
377
+ const int dim = sizes[1];
378
+ const int seqlen = sizes[2];
379
+ const int dstate = A.size(1);
380
+ const int n_groups = is_variable_B ? B.size(1) : 1;
381
+
382
+ TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
383
+
384
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
385
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
386
+ CHECK_SHAPE(A, dim, dstate);
387
+ if (!is_variable_B) {
388
+ CHECK_SHAPE(B, dim, dstate);
389
+ } else {
390
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
391
+ TORCH_CHECK(B.stride(-1) == 1);
392
+ }
393
+ if (!is_variable_C) {
394
+ CHECK_SHAPE(C, dim, dstate);
395
+ } else {
396
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
397
+ TORCH_CHECK(C.stride(-1) == 1);
398
+ }
399
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
400
+
401
+ if (D_.has_value()) {
402
+ auto D = D_.value();
403
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
404
+ TORCH_CHECK(D.is_cuda());
405
+ TORCH_CHECK(D.stride(-1) == 1);
406
+ CHECK_SHAPE(D, dim);
407
+ }
408
+
409
+ if (delta_bias_.has_value()) {
410
+ auto delta_bias = delta_bias_.value();
411
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
412
+ TORCH_CHECK(delta_bias.is_cuda());
413
+ TORCH_CHECK(delta_bias.stride(-1) == 1);
414
+ CHECK_SHAPE(delta_bias, dim);
415
+ }
416
+
417
+ at::Tensor z, out, dz, out_z;
418
+ const bool has_z = z_.has_value();
419
+ if (has_z) {
420
+ z = z_.value();
421
+ TORCH_CHECK(z.scalar_type() == input_type);
422
+ TORCH_CHECK(z.is_cuda());
423
+ TORCH_CHECK(z.stride(-1) == 1);
424
+ CHECK_SHAPE(z, batch_size, dim, seqlen);
425
+
426
+ TORCH_CHECK(out_.has_value());
427
+ out = out_.value();
428
+ TORCH_CHECK(out.scalar_type() == input_type);
429
+ TORCH_CHECK(out.is_cuda());
430
+ TORCH_CHECK(out.stride(-1) == 1);
431
+ CHECK_SHAPE(out, batch_size, dim, seqlen);
432
+
433
+ if (dz_.has_value()) {
434
+ dz = dz_.value();
435
+ TORCH_CHECK(dz.scalar_type() == input_type);
436
+ TORCH_CHECK(dz.is_cuda());
437
+ TORCH_CHECK(dz.stride(-1) == 1);
438
+ CHECK_SHAPE(dz, batch_size, dim, seqlen);
439
+ } else {
440
+ dz = torch::empty_like(z);
441
+ }
442
+ if (recompute_out_z) {
443
+ out_z = torch::empty_like(out);
444
+ }
445
+ }
446
+
447
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
448
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
449
+ if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
450
+ if (x_.has_value()) {
451
+ auto x = x_.value();
452
+ TORCH_CHECK(x.scalar_type() == weight_type);
453
+ TORCH_CHECK(x.is_cuda());
454
+ TORCH_CHECK(x.is_contiguous());
455
+ CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
456
+ }
457
+
458
+ at::Tensor du = torch::empty_like(u);
459
+ at::Tensor ddelta = torch::empty_like(delta);
460
+ at::Tensor dA = torch::zeros_like(A);
461
+ at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
462
+ at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
463
+ at::Tensor dD;
464
+ if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
465
+ at::Tensor ddelta_bias;
466
+ if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
467
+
468
+ SSMParamsBwd params;
469
+ set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
470
+ u, delta, A, B, C, z, out, out_z,
471
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
472
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
473
+ x_.has_value() ? x_.value().data_ptr() : nullptr,
474
+ dout, du, ddelta, dA, dB, dC, dz,
475
+ D_.has_value() ? dD.data_ptr() : nullptr,
476
+ delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
477
+ has_z, delta_softplus, recompute_out_z);
478
+
479
+ // Otherwise the kernel will be launched from cuda:0 device
480
+ // Cast to char to avoid compiler warning about narrowing
481
+ at::cuda::CUDAGuard device_guard{(char)u.get_device()};
482
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
483
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
484
+ DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
485
+ selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
486
+ });
487
+ });
488
+ std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
489
+ if (has_z) { result.push_back(dz); }
490
+ if (recompute_out_z) { result.push_back(out_z); }
491
+ return result;
492
+ }
493
+
494
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
495
+ m.def("fwd", &selective_scan_fwd, "Selective scan forward");
496
+ m.def("bwd", &selective_scan_bwd, "Selective scan backward");
497
+ }
mamba/csrc/selective_scan/selective_scan.h ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
8
+
9
+ struct SSMScanParamsBase {
10
+ using index_t = uint32_t;
11
+
12
+ int batch, seqlen, n_chunks;
13
+ index_t a_batch_stride;
14
+ index_t b_batch_stride;
15
+ index_t out_batch_stride;
16
+
17
+ // Common data pointers.
18
+ void *__restrict__ a_ptr;
19
+ void *__restrict__ b_ptr;
20
+ void *__restrict__ out_ptr;
21
+ void *__restrict__ x_ptr;
22
+ };
23
+
24
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
25
+
26
+ struct SSMParamsBase {
27
+ using index_t = uint32_t;
28
+
29
+ int batch, dim, seqlen, dstate, n_groups, n_chunks;
30
+ int dim_ngroups_ratio;
31
+ bool is_variable_B;
32
+ bool is_variable_C;
33
+
34
+ bool delta_softplus;
35
+
36
+ index_t A_d_stride;
37
+ index_t A_dstate_stride;
38
+ index_t B_batch_stride;
39
+ index_t B_d_stride;
40
+ index_t B_dstate_stride;
41
+ index_t B_group_stride;
42
+ index_t C_batch_stride;
43
+ index_t C_d_stride;
44
+ index_t C_dstate_stride;
45
+ index_t C_group_stride;
46
+ index_t u_batch_stride;
47
+ index_t u_d_stride;
48
+ index_t delta_batch_stride;
49
+ index_t delta_d_stride;
50
+ index_t z_batch_stride;
51
+ index_t z_d_stride;
52
+ index_t out_batch_stride;
53
+ index_t out_d_stride;
54
+ index_t out_z_batch_stride;
55
+ index_t out_z_d_stride;
56
+
57
+ // Common data pointers.
58
+ void *__restrict__ A_ptr;
59
+ void *__restrict__ B_ptr;
60
+ void *__restrict__ C_ptr;
61
+ void *__restrict__ D_ptr;
62
+ void *__restrict__ u_ptr;
63
+ void *__restrict__ delta_ptr;
64
+ void *__restrict__ delta_bias_ptr;
65
+ void *__restrict__ out_ptr;
66
+ void *__restrict__ x_ptr;
67
+ void *__restrict__ z_ptr;
68
+ void *__restrict__ out_z_ptr;
69
+ };
70
+
71
+ struct SSMParamsBwd: public SSMParamsBase {
72
+ index_t dout_batch_stride;
73
+ index_t dout_d_stride;
74
+ index_t dA_d_stride;
75
+ index_t dA_dstate_stride;
76
+ index_t dB_batch_stride;
77
+ index_t dB_group_stride;
78
+ index_t dB_d_stride;
79
+ index_t dB_dstate_stride;
80
+ index_t dC_batch_stride;
81
+ index_t dC_group_stride;
82
+ index_t dC_d_stride;
83
+ index_t dC_dstate_stride;
84
+ index_t du_batch_stride;
85
+ index_t du_d_stride;
86
+ index_t dz_batch_stride;
87
+ index_t dz_d_stride;
88
+ index_t ddelta_batch_stride;
89
+ index_t ddelta_d_stride;
90
+
91
+ // Common data pointers.
92
+ void *__restrict__ dout_ptr;
93
+ void *__restrict__ dA_ptr;
94
+ void *__restrict__ dB_ptr;
95
+ void *__restrict__ dC_ptr;
96
+ void *__restrict__ dD_ptr;
97
+ void *__restrict__ du_ptr;
98
+ void *__restrict__ dz_ptr;
99
+ void *__restrict__ ddelta_ptr;
100
+ void *__restrict__ ddelta_bias_ptr;
101
+ };
mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in paralell
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in paralell
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in paralell
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in paralell
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in paralell
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in paralell
6
+
7
+ #include "selective_scan_bwd_kernel.cuh"
8
+
9
+ template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+ #include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
11
+
12
+ #include <cub/block/block_load.cuh>
13
+ #include <cub/block/block_store.cuh>
14
+ #include <cub/block/block_scan.cuh>
15
+ #include <cub/block/block_reduce.cuh>
16
+
17
+ #include "selective_scan.h"
18
+ #include "selective_scan_common.h"
19
+ #include "reverse_scan.cuh"
20
+ #include "static_switch.h"
21
+
22
+ template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x);
23
+ template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
24
+ template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }
25
+
26
+ template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
27
+ bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
28
+ struct Selective_Scan_bwd_kernel_traits {
29
+ static_assert(kNItems_ % 4 == 0);
30
+ using input_t = input_t_;
31
+ using weight_t = weight_t_;
32
+ static constexpr int kNThreads = kNThreads_;
33
+ static constexpr int kNItems = kNItems_;
34
+ static constexpr int kNBytes = sizeof(input_t);
35
+ static_assert(kNBytes == 2 || kNBytes == 4);
36
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
37
+ static_assert(kNItems % kNElts == 0);
38
+ static constexpr int kNLoads = kNItems / kNElts;
39
+ static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
40
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
41
+ static constexpr bool kIsVariableB = kIsVariableB_;
42
+ static constexpr bool kIsVariableC = kIsVariableC_;
43
+ static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
44
+ static constexpr bool kHasZ = kHasZ_;
45
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
46
+ // For complex this would lead to massive register spilling, so we keep it at 2.
47
+ static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;
48
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
49
+ using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
50
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
51
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
52
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
53
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
54
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
55
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
56
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
57
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
58
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
59
+ using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
60
+ using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
61
+ using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
62
+ using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;
63
+ using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;
64
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
65
+ sizeof(typename BlockLoadVecT::TempStorage),
66
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
67
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
68
+ sizeof(typename BlockStoreT::TempStorage),
69
+ sizeof(typename BlockStoreVecT::TempStorage)});
70
+ static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
71
+ static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
72
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
73
+ };
74
+
75
+ template<typename Ktraits>
76
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
77
+ void selective_scan_bwd_kernel(SSMParamsBwd params) {
78
+ constexpr bool kIsComplex = Ktraits::kIsComplex;
79
+ constexpr bool kIsVariableB = Ktraits::kIsVariableB;
80
+ constexpr bool kIsVariableC = Ktraits::kIsVariableC;
81
+ constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
82
+ constexpr bool kHasZ = Ktraits::kHasZ;
83
+ constexpr int kNThreads = Ktraits::kNThreads;
84
+ constexpr int kNItems = Ktraits::kNItems;
85
+ using input_t = typename Ktraits::input_t;
86
+ using weight_t = typename Ktraits::weight_t;
87
+ using scan_t = typename Ktraits::scan_t;
88
+
89
+ // Shared memory.
90
+ extern __shared__ char smem_[];
91
+ // cast to lvalue reference of expected type
92
+ // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
93
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
94
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
95
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
96
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
97
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
98
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
99
+ auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
100
+ auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
101
+ auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
102
+ auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
103
+ auto& smem_reduce_complex = *reinterpret_cast<typename Ktraits::BlockReduceComplexT::TempStorage*>(&smem_reduce);
104
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
105
+ auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
106
+ weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
107
+ scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * MAX_DSTATE + kNThreads);
108
+ weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + MAX_DSTATE);
109
+ weight_t *smem_dbc = reinterpret_cast<weight_t *>(smem_da + MAX_DSTATE);
110
+
111
+ const int batch_id = blockIdx.x;
112
+ const int dim_id = blockIdx.y;
113
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
114
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
115
+ + dim_id * params.u_d_stride;
116
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
117
+ + dim_id * params.delta_d_stride;
118
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
119
+ + dim_id * params.dout_d_stride;
120
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
121
+ weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * params.B_d_stride;
122
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
123
+ weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * params.C_d_stride;
124
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
125
+ weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
126
+ weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
127
+ + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride);
128
+ weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
129
+ + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride);
130
+ float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
131
+ float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
132
+ float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
133
+ float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
134
+ scan_t *x = params.x_ptr == nullptr
135
+ ? nullptr
136
+ : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
137
+ float dD_val = 0;
138
+ float ddelta_bias_val = 0;
139
+
140
+ constexpr int kChunkSize = kNThreads * kNItems;
141
+ u += (params.n_chunks - 1) * kChunkSize;
142
+ delta += (params.n_chunks - 1) * kChunkSize;
143
+ dout += (params.n_chunks - 1) * kChunkSize;
144
+ Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
145
+ Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
146
+ for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
147
+ input_t u_vals[kNItems];
148
+ input_t delta_vals_load[kNItems];
149
+ input_t dout_vals_load[kNItems];
150
+ __syncthreads();
151
+ load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
152
+ u -= kChunkSize;
153
+ __syncthreads();
154
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
155
+ // Will reload delta at the same location if kDeltaSoftplus
156
+ if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
157
+ __syncthreads();
158
+ load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
159
+ dout -= kChunkSize;
160
+
161
+ float dout_vals[kNItems], delta_vals[kNItems];
162
+ #pragma unroll
163
+ for (int i = 0; i < kNItems; ++i) {
164
+ dout_vals[i] = float(dout_vals_load[i]);
165
+ delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
166
+ if constexpr (kDeltaSoftplus) {
167
+ delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
168
+ }
169
+ }
170
+
171
+ if constexpr (kHasZ) {
172
+ input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
173
+ + dim_id * params.z_d_stride + chunk * kChunkSize;
174
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
175
+ + dim_id * params.out_d_stride + chunk * kChunkSize;
176
+ input_t *dz = reinterpret_cast<input_t *>(params.dz_ptr) + batch_id * params.dz_batch_stride
177
+ + dim_id * params.dz_d_stride + chunk * kChunkSize;
178
+ input_t z_vals[kNItems], out_vals[kNItems];
179
+ __syncthreads();
180
+ load_input<Ktraits>(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
181
+ __syncthreads();
182
+ load_input<Ktraits>(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize);
183
+ float dz_vals[kNItems], z_silu_vals[kNItems];
184
+ #pragma unroll
185
+ for (int i = 0; i < kNItems; ++i) {
186
+ float z_val = z_vals[i];
187
+ float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val));
188
+ z_silu_vals[i] = z_val * z_sigmoid_val;
189
+ dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val
190
+ * (1.0f + z_val * (1.0f - z_sigmoid_val));
191
+ dout_vals[i] *= z_silu_vals[i];
192
+ }
193
+ __syncthreads();
194
+ store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
195
+ if (params.out_z_ptr != nullptr) { // Recompute and store out_z
196
+ float out_z_vals[kNItems];
197
+ #pragma unroll
198
+ for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
199
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
200
+ // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
201
+ // }
202
+ input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
203
+ + dim_id * params.out_z_d_stride + chunk * kChunkSize;
204
+ __syncthreads();
205
+ store_output<Ktraits>(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize);
206
+ }
207
+ }
208
+
209
+ float du_vals[kNItems];
210
+ #pragma unroll
211
+ for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
212
+ #pragma unroll
213
+ for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
214
+
215
+ float ddelta_vals[kNItems] = {0};
216
+ __syncthreads();
217
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
218
+ const weight_t A_val = A[state_idx * params.A_dstate_stride];
219
+ // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
220
+ weight_t A_scaled;
221
+ constexpr float kLog2e = M_LOG2E;
222
+ if constexpr (!kIsComplex) {
223
+ A_scaled = A_val * kLog2e;
224
+ } else {
225
+ A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_);
226
+ }
227
+ weight_t B_val, C_val;
228
+ weight_t B_vals[kNItems], C_vals[kNItems];
229
+ if constexpr (!kIsVariableB) {
230
+ B_val = B[state_idx * params.B_dstate_stride];
231
+ } else {
232
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
233
+ smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
234
+ }
235
+ if constexpr (!kIsVariableC) {
236
+ C_val = C[state_idx * params.C_dstate_stride];
237
+ } else {
238
+ auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
239
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
240
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
241
+ }
242
+ // const weight_t A_val = smem_a[state_idx];
243
+ scan_t thread_data[kNItems], thread_reverse_data[kNItems];
244
+ if constexpr (!kIsComplex) {
245
+ #pragma unroll
246
+ for (int i = 0; i < kNItems; ++i) {
247
+ const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
248
+ thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
249
+ if (i == 0) {
250
+ smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
251
+ } else {
252
+ thread_reverse_data[i - 1].x = delta_a_exp;
253
+ }
254
+ thread_reverse_data[i].y = dout_vals[i] *
255
+ (!kIsVariableC
256
+ ? (!kIsVariableB ? B_val * C_val : C_val)
257
+ : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
258
+ }
259
+ __syncthreads();
260
+ thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
261
+ ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
262
+ : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
263
+ // Initialize running total
264
+ scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
265
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
266
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
267
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
268
+ );
269
+ scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
270
+ SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
271
+ Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
272
+ thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
273
+ );
274
+ if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
275
+ weight_t dA_val = 0, dBC_val = 0;
276
+ weight_t dB_vals[kNItems], dC_vals[kNItems];
277
+ #pragma unroll
278
+ for (int i = 0; i < kNItems; ++i) {
279
+ const float dx = thread_reverse_data[i].y;
280
+ const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i];
281
+ du_vals[i] += ddelta_u * delta_vals[i];
282
+ const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
283
+ ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
284
+ dA_val += dx * delta_vals[i] * a;
285
+ if constexpr (!kIsVariableB || !kIsVariableC) {
286
+ if constexpr (!kIsVariableB) { // dBC_val is dB_val
287
+ dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
288
+ } else { // dBC_val is dC_val
289
+ dBC_val += dout_vals[i] * thread_data[i].y;
290
+ }
291
+ }
292
+ if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
293
+ if constexpr (kIsVariableC) {
294
+ dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
295
+ }
296
+ }
297
+ // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
298
+ if constexpr (kIsVariableB || kIsVariableC) {
299
+ if constexpr (kIsVariableB) {
300
+ Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
301
+ }
302
+ if constexpr (kIsVariableC) {
303
+ auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
304
+ Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
305
+ }
306
+ const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
307
+ weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
308
+ weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
309
+ #pragma unroll
310
+ for (int i = 0; i < kNItems; ++i) {
311
+ if (i * kNThreads < seqlen_remaining) {
312
+ if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
313
+ if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
314
+ }
315
+ }
316
+ }
317
+ if constexpr (!kIsVariableB || !kIsVariableC) {
318
+ float2 dA_dBC_val = make_float2(dA_val, dBC_val);
319
+ dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
320
+ dA_val = dA_dBC_val.x;
321
+ if (threadIdx.x == 0) {
322
+ smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
323
+ }
324
+ } else {
325
+ dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
326
+ }
327
+ if (threadIdx.x == 0) {
328
+ smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
329
+ }
330
+ } else {
331
+ #pragma unroll
332
+ for (int i = 0; i < kNItems; ++i) {
333
+ // Pytorch's implementation of complex exp (which calls thrust) is very slow
334
+ complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
335
+ weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
336
+ thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
337
+ if (i == 0) {
338
+ smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
339
+ } else {
340
+ thread_reverse_data[i - 1].x = delta_a_exp.real_;
341
+ thread_reverse_data[i - 1].y = -delta_a_exp.imag_;
342
+ }
343
+ complex_t dout_BC = 2 * dout_vals[i]
344
+ * conj(!kIsVariableC
345
+ ? (!kIsVariableB ? B_val * C_val : C_val)
346
+ : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
347
+ thread_reverse_data[i].z = dout_BC.real_;
348
+ thread_reverse_data[i].w = dout_BC.imag_;
349
+ }
350
+ __syncthreads();
351
+ complex_t delta_a_exp = threadIdx.x == kNThreads - 1
352
+ ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
353
+ : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
354
+ thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
355
+ thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
356
+ // Initialize running total
357
+ scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
358
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
359
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
360
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
361
+ );
362
+ scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
363
+ SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
364
+ Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
365
+ thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
366
+ );
367
+ if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
368
+ weight_t dA_val = 0, dBC_val = 0;
369
+ weight_t dB_vals[kNItems], dC_vals[kNItems];
370
+ #pragma unroll
371
+ for (int i = 0; i < kNItems; ++i) {
372
+ complex_t x = complex_t(thread_data[i].z, thread_data[i].w);
373
+ complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
374
+ float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
375
+ if constexpr (!kIsVariableB || !kIsVariableC) {
376
+ if constexpr (!kIsVariableB) { // dBC_val is dB_val
377
+ dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
378
+ } else { // dBC_val is dC_val
379
+ dBC_val += (2 * dout_vals[i]) * conj(x);
380
+ }
381
+ }
382
+ const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]));
383
+ du_vals[i] += ddelta_u * delta_vals[i];
384
+ ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_;
385
+ dA_val += delta_vals[i] * dx * a_conj;
386
+ if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
387
+ if constexpr (kIsVariableC) {
388
+ dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
389
+ }
390
+ }
391
+ // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
392
+ if constexpr (kIsVariableB || kIsVariableC) {
393
+ float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
394
+ if constexpr (kIsVariableB) {
395
+ #pragma unroll
396
+ for (int i = 0; i < kNItems; ++i) {
397
+ dB_vals_f[i * 2] = dB_vals[i].real_;
398
+ dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
399
+ }
400
+ Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
401
+ }
402
+ if constexpr (kIsVariableC) {
403
+ #pragma unroll
404
+ for (int i = 0; i < kNItems; ++i) {
405
+ dC_vals_f[i * 2] = dC_vals[i].real_;
406
+ dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
407
+ }
408
+ auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
409
+ Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
410
+ }
411
+ const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
412
+ float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
413
+ float *dC_cur = reinterpret_cast<float *>(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
414
+ #pragma unroll
415
+ for (int i = 0; i < kNItems * 2; ++i) {
416
+ if (i * kNThreads < seqlen_remaining) {
417
+ if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); }
418
+ if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); }
419
+ }
420
+ }
421
+ }
422
+ if constexpr (!kIsVariableB || !kIsVariableC) {
423
+ float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
424
+ dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
425
+ dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
426
+ dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
427
+ if (threadIdx.x == 0) {
428
+ smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
429
+ }
430
+ } else {
431
+ dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
432
+ }
433
+ if (threadIdx.x == 0) {
434
+ smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
435
+ }
436
+ }
437
+ }
438
+
439
+ if constexpr (kDeltaSoftplus) {
440
+ __syncthreads();
441
+ input_t delta_vals_load[kNItems];
442
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
443
+ delta -= kChunkSize;
444
+ #pragma unroll
445
+ for (int i = 0; i < kNItems; ++i) {
446
+ float delta_val = float(delta_vals_load[i]) + delta_bias;
447
+ float delta_val_neg_exp = expf(-delta_val);
448
+ ddelta_vals[i] = delta_val <= 20.f
449
+ ? ddelta_vals[i] / (1.f + delta_val_neg_exp)
450
+ : ddelta_vals[i];
451
+ }
452
+ }
453
+ for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
454
+
455
+ input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
456
+ + dim_id * params.du_d_stride + chunk * kChunkSize;
457
+ input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
458
+ + dim_id * params.ddelta_d_stride + chunk * kChunkSize;
459
+ __syncthreads();
460
+ store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
461
+ __syncthreads();
462
+ store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
463
+
464
+ Bvar -= kChunkSize * (!kIsComplex ? 1 : 2);
465
+ Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
466
+ }
467
+ if (params.dD_ptr != nullptr) {
468
+ dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
469
+ if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
470
+ }
471
+ if (params.ddelta_bias_ptr != nullptr) {
472
+ __syncthreads();
473
+ ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
474
+ if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
475
+ }
476
+ for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
477
+ gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
478
+ weight_t dBC_val;
479
+ if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; }
480
+ if constexpr (!kIsVariableB) {
481
+ gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]),
482
+ !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val);
483
+ }
484
+ if constexpr (!kIsVariableC) {
485
+ gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]),
486
+ !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val);
487
+ }
488
+ }
489
+ }
490
+
491
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
492
+ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
493
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
494
+ BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
495
+ BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
496
+ BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
497
+ BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
498
+ using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
499
+ // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
500
+ // TODO: check this
501
+ constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
502
+ // printf("smem_size = %d\n", kSmemSize);
503
+ dim3 grid(params.batch, params.dim);
504
+ auto kernel = &selective_scan_bwd_kernel<Ktraits>;
505
+ if (kSmemSize >= 48 * 1024) {
506
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
507
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
508
+ }
509
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
510
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
511
+ });
512
+ });
513
+ });
514
+ });
515
+ });
516
+ }
517
+
518
+ template<typename input_t, typename weight_t>
519
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
520
+ if (params.seqlen <= 128) {
521
+ selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
522
+ } else if (params.seqlen <= 256) {
523
+ selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
524
+ } else if (params.seqlen <= 512) {
525
+ selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
526
+ } else if (params.seqlen <= 1024) {
527
+ selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
528
+ } else {
529
+ selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
530
+ }
531
+ }
mamba/csrc/selective_scan/selective_scan_common.h ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cuda_bf16.h>
8
+ #include <cuda_fp16.h>
9
+ #include <c10/util/complex.h> // For scalar_value_type
10
+
11
+ #define MAX_DSTATE 256
12
+
13
+ using complex_t = c10::complex<float>;
14
+
15
+ inline __device__ float2 operator+(const float2 & a, const float2 & b){
16
+ return {a.x + b.x, a.y + b.y};
17
+ }
18
+
19
+ inline __device__ float3 operator+(const float3 &a, const float3 &b) {
20
+ return {a.x + b.x, a.y + b.y, a.z + b.z};
21
+ }
22
+
23
+ inline __device__ float4 operator+(const float4 & a, const float4 & b){
24
+ return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
25
+ }
26
+
27
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
28
+
29
+ template<int BYTES> struct BytesToType {};
30
+
31
+ template<> struct BytesToType<16> {
32
+ using Type = uint4;
33
+ static_assert(sizeof(Type) == 16);
34
+ };
35
+
36
+ template<> struct BytesToType<8> {
37
+ using Type = uint64_t;
38
+ static_assert(sizeof(Type) == 8);
39
+ };
40
+
41
+ template<> struct BytesToType<4> {
42
+ using Type = uint32_t;
43
+ static_assert(sizeof(Type) == 4);
44
+ };
45
+
46
+ template<> struct BytesToType<2> {
47
+ using Type = uint16_t;
48
+ static_assert(sizeof(Type) == 2);
49
+ };
50
+
51
+ template<> struct BytesToType<1> {
52
+ using Type = uint8_t;
53
+ static_assert(sizeof(Type) == 1);
54
+ };
55
+
56
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
57
+
58
+ template<typename scalar_t, int N>
59
+ struct Converter{
60
+ static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
61
+ #pragma unroll
62
+ for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
63
+ }
64
+ };
65
+
66
+ template<int N>
67
+ struct Converter<at::Half, N>{
68
+ static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
69
+ static_assert(N % 2 == 0);
70
+ auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
71
+ auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
72
+ #pragma unroll
73
+ for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
74
+ }
75
+ };
76
+
77
+ #if __CUDA_ARCH__ >= 800
78
+ template<int N>
79
+ struct Converter<at::BFloat16, N>{
80
+ static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
81
+ static_assert(N % 2 == 0);
82
+ auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
83
+ auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
84
+ #pragma unroll
85
+ for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
86
+ }
87
+ };
88
+ #endif
89
+
90
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
91
+
92
+ // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
93
+ // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
94
+ __device__ __forceinline__ complex_t cexp2f(complex_t z) {
95
+ float t = exp2f(z.real_);
96
+ float c, s;
97
+ sincosf(z.imag_, &s, &c);
98
+ return complex_t(c * t, s * t);
99
+ }
100
+
101
+ __device__ __forceinline__ complex_t cexpf(complex_t z) {
102
+ float t = expf(z.real_);
103
+ float c, s;
104
+ sincosf(z.imag_, &s, &c);
105
+ return complex_t(c * t, s * t);
106
+ }
107
+
108
+ template<typename scalar_t> struct SSMScanOp;
109
+
110
+ template<>
111
+ struct SSMScanOp<float> {
112
+ __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
113
+ return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
114
+ }
115
+ };
116
+
117
+ template<>
118
+ struct SSMScanOp<complex_t> {
119
+ __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
120
+ complex_t a0 = complex_t(ab0.x, ab0.y);
121
+ complex_t b0 = complex_t(ab0.z, ab0.w);
122
+ complex_t a1 = complex_t(ab1.x, ab1.y);
123
+ complex_t b1 = complex_t(ab1.z, ab1.w);
124
+ complex_t out_a = a1 * a0;
125
+ complex_t out_b = a1 * b0 + b1;
126
+ return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
127
+ }
128
+ };
129
+
130
+ // A stateful callback functor that maintains a running prefix to be applied
131
+ // during consecutive scan operations.
132
+ template <typename scalar_t> struct SSMScanPrefixCallbackOp {
133
+ using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
134
+ scan_t running_prefix;
135
+ // Constructor
136
+ __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
137
+ // Callback operator to be entered by the first warp of threads in the block.
138
+ // Thread-0 is responsible for returning a value for seeding the block-wide scan.
139
+ __device__ scan_t operator()(scan_t block_aggregate) {
140
+ scan_t old_prefix = running_prefix;
141
+ running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
142
+ return old_prefix;
143
+ }
144
+ };
145
+
146
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
147
+
148
+ template<typename Ktraits>
149
+ inline __device__ void load_input(typename Ktraits::input_t *u,
150
+ typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
151
+ typename Ktraits::BlockLoadT::TempStorage &smem_load,
152
+ int seqlen) {
153
+ if constexpr (Ktraits::kIsEvenLen) {
154
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
155
+ using vec_t = typename Ktraits::vec_t;
156
+ Ktraits::BlockLoadVecT(smem_load_vec).Load(
157
+ reinterpret_cast<vec_t*>(u),
158
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
159
+ );
160
+ } else {
161
+ Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
162
+ }
163
+ }
164
+
165
+ template<typename Ktraits>
166
+ inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
167
+ typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
168
+ typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
169
+ int seqlen) {
170
+ constexpr int kNItems = Ktraits::kNItems;
171
+ if constexpr (!Ktraits::kIsComplex) {
172
+ typename Ktraits::input_t B_vals_load[kNItems];
173
+ if constexpr (Ktraits::kIsEvenLen) {
174
+ auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
175
+ using vec_t = typename Ktraits::vec_t;
176
+ Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
177
+ reinterpret_cast<vec_t*>(Bvar),
178
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
179
+ );
180
+ } else {
181
+ Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
182
+ }
183
+ // #pragma unroll
184
+ // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
185
+ Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
186
+ } else {
187
+ typename Ktraits::input_t B_vals_load[kNItems * 2];
188
+ if constexpr (Ktraits::kIsEvenLen) {
189
+ auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
190
+ using vec_t = typename Ktraits::vec_t;
191
+ Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
192
+ reinterpret_cast<vec_t*>(Bvar),
193
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
194
+ );
195
+ } else {
196
+ Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
197
+ }
198
+ #pragma unroll
199
+ for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
200
+ }
201
+ }
202
+
203
+ template<typename Ktraits>
204
+ inline __device__ void store_output(typename Ktraits::input_t *out,
205
+ const float (&out_vals)[Ktraits::kNItems],
206
+ typename Ktraits::BlockStoreT::TempStorage &smem_store,
207
+ int seqlen) {
208
+ typename Ktraits::input_t write_vals[Ktraits::kNItems];
209
+ #pragma unroll
210
+ for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
211
+ if constexpr (Ktraits::kIsEvenLen) {
212
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
213
+ using vec_t = typename Ktraits::vec_t;
214
+ Ktraits::BlockStoreVecT(smem_store_vec).Store(
215
+ reinterpret_cast<vec_t*>(out),
216
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
217
+ );
218
+ } else {
219
+ Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
220
+ }
221
+ }
mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in paralell
6
+
7
+ #include "selective_scan_fwd_kernel.cuh"
8
+
9
+ template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
10
+ template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in paralell
6
+
7
+ #include "selective_scan_fwd_kernel.cuh"
8
+
9
+ template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
10
+ template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in paralell
6
+
7
+ #include "selective_scan_fwd_kernel.cuh"
8
+
9
+ template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
10
+ template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+
11
+ #include <cub/block/block_load.cuh>
12
+ #include <cub/block/block_store.cuh>
13
+ #include <cub/block/block_scan.cuh>
14
+
15
+ #include "selective_scan.h"
16
+ #include "selective_scan_common.h"
17
+ #include "static_switch.h"
18
+
19
+ template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
20
+ bool kIsVariableB_, bool kIsVariableC_,
21
+ bool kHasZ_, typename input_t_, typename weight_t_>
22
+ struct Selective_Scan_fwd_kernel_traits {
23
+ static_assert(kNItems_ % 4 == 0);
24
+ using input_t = input_t_;
25
+ using weight_t = weight_t_;
26
+ static constexpr int kNThreads = kNThreads_;
27
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
28
+ static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
29
+ static constexpr int kNItems = kNItems_;
30
+ static constexpr int kNRows = kNRows_;
31
+ static constexpr int kNBytes = sizeof(input_t);
32
+ static_assert(kNBytes == 2 || kNBytes == 4);
33
+ static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
34
+ static_assert(kNItems % kNElts == 0);
35
+ static constexpr int kNLoads = kNItems / kNElts;
36
+ static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
37
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
38
+ static constexpr bool kIsVariableB = kIsVariableB_;
39
+ static constexpr bool kIsVariableC = kIsVariableC_;
40
+ static constexpr bool kHasZ = kHasZ_;
41
+
42
+ static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
43
+
44
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
45
+ using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
46
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
47
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
48
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
49
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
50
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,
51
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
52
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
53
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
54
+ !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
55
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
56
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
57
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
58
+ static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
59
+ sizeof(typename BlockLoadVecT::TempStorage),
60
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
61
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
62
+ sizeof(typename BlockStoreT::TempStorage),
63
+ sizeof(typename BlockStoreVecT::TempStorage)});
64
+ static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
65
+ };
66
+
67
+ template<typename Ktraits>
68
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
69
+ void selective_scan_fwd_kernel(SSMParamsBase params) {
70
+ constexpr bool kIsComplex = Ktraits::kIsComplex;
71
+ constexpr bool kIsVariableB = Ktraits::kIsVariableB;
72
+ constexpr bool kIsVariableC = Ktraits::kIsVariableC;
73
+ constexpr bool kHasZ = Ktraits::kHasZ;
74
+ constexpr int kNThreads = Ktraits::kNThreads;
75
+ constexpr int kNItems = Ktraits::kNItems;
76
+ constexpr int kNRows = Ktraits::kNRows;
77
+ constexpr bool kDirectIO = Ktraits::kDirectIO;
78
+ using input_t = typename Ktraits::input_t;
79
+ using weight_t = typename Ktraits::weight_t;
80
+ using scan_t = typename Ktraits::scan_t;
81
+
82
+ // Shared memory.
83
+ extern __shared__ char smem_[];
84
+ // cast to lvalue reference of expected type
85
+ // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
86
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
87
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
88
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
89
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
90
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
91
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
92
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
93
+ // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
94
+ // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
95
+ scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
96
+
97
+ const int batch_id = blockIdx.x;
98
+ const int dim_id = blockIdx.y;
99
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
100
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
101
+ + dim_id * kNRows * params.u_d_stride;
102
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
103
+ + dim_id * kNRows * params.delta_d_stride;
104
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
105
+ weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
106
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
107
+ weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
108
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
109
+ scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
110
+
111
+ float D_val[kNRows] = {0};
112
+ if (params.D_ptr != nullptr) {
113
+ #pragma unroll
114
+ for (int r = 0; r < kNRows; ++r) {
115
+ D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
116
+ }
117
+ }
118
+ float delta_bias[kNRows] = {0};
119
+ if (params.delta_bias_ptr != nullptr) {
120
+ #pragma unroll
121
+ for (int r = 0; r < kNRows; ++r) {
122
+ delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
123
+ }
124
+ }
125
+
126
+ // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
127
+ // smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
128
+ // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
129
+ // }
130
+
131
+ constexpr int kChunkSize = kNThreads * kNItems;
132
+ for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
133
+ input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
134
+ __syncthreads();
135
+ #pragma unroll
136
+ for (int r = 0; r < kNRows; ++r) {
137
+ if constexpr (!kDirectIO) {
138
+ if (r > 0) { __syncthreads(); }
139
+ }
140
+ load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
141
+ if constexpr (!kDirectIO) { __syncthreads(); }
142
+ load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
143
+ }
144
+ u += kChunkSize;
145
+ delta += kChunkSize;
146
+
147
+ float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
148
+ #pragma unroll
149
+ for (int r = 0; r < kNRows; ++r) {
150
+ #pragma unroll
151
+ for (int i = 0; i < kNItems; ++i) {
152
+ float u_val = float(u_vals[r][i]);
153
+ delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
154
+ if (params.delta_softplus) {
155
+ delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
156
+ }
157
+ delta_u_vals[r][i] = delta_vals[r][i] * u_val;
158
+ out_vals[r][i] = D_val[r] * u_val;
159
+ }
160
+ }
161
+
162
+ __syncthreads();
163
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
164
+ weight_t A_val[kNRows];
165
+ #pragma unroll
166
+ for (int r = 0; r < kNRows; ++r) {
167
+ A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
168
+ // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
169
+ constexpr float kLog2e = M_LOG2E;
170
+ if constexpr (!kIsComplex) {
171
+ A_val[r] *= kLog2e;
172
+ } else {
173
+ A_val[r].real_ *= kLog2e;
174
+ }
175
+ }
176
+ // This variable holds B * C if both B and C are constant across seqlen. If only B varies
177
+ // across seqlen, this holds C. If only C varies across seqlen, this holds B.
178
+ // If both B and C vary, this is unused.
179
+ weight_t BC_val[kNRows];
180
+ weight_t B_vals[kNItems], C_vals[kNItems];
181
+ if constexpr (kIsVariableB) {
182
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
183
+ smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
184
+ if constexpr (!kIsVariableC) {
185
+ #pragma unroll
186
+ for (int r = 0; r < kNRows; ++r) {
187
+ BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
188
+ }
189
+ }
190
+ }
191
+ if constexpr (kIsVariableC) {
192
+ auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
193
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
194
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
195
+ if constexpr (!kIsVariableB) {
196
+ #pragma unroll
197
+ for (int r = 0; r < kNRows; ++r) {
198
+ BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
199
+ }
200
+ }
201
+ }
202
+ if constexpr (!kIsVariableB && !kIsVariableC) {
203
+ #pragma unroll
204
+ for (int r = 0; r < kNRows; ++r) {
205
+ BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
206
+ }
207
+ }
208
+
209
+ #pragma unroll
210
+ for (int r = 0; r < kNRows; ++r) {
211
+ if (r > 0) { __syncthreads(); } // Scan could be using the same smem
212
+ scan_t thread_data[kNItems];
213
+ #pragma unroll
214
+ for (int i = 0; i < kNItems; ++i) {
215
+ if constexpr (!kIsComplex) {
216
+ thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
217
+ !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
218
+ if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
219
+ if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
220
+ thread_data[i] = make_float2(1.f, 0.f);
221
+ }
222
+ }
223
+ } else {
224
+ // Pytorch's implementation of complex exp (which calls thrust) is very slow
225
+ complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
226
+ weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
227
+ thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
228
+ if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
229
+ if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
230
+ thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
231
+ }
232
+ }
233
+ }
234
+ }
235
+ // Initialize running total
236
+ scan_t running_prefix;
237
+ if constexpr (!kIsComplex) {
238
+ // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
239
+ running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
240
+ // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
241
+ } else {
242
+ running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
243
+ // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
244
+ }
245
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
246
+ Ktraits::BlockScanT(smem_scan).InclusiveScan(
247
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
248
+ );
249
+ // There's a syncthreads in the scan op, so we don't need to sync here.
250
+ // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
251
+ if (threadIdx.x == 0) {
252
+ smem_running_prefix[state_idx] = prefix_op.running_prefix;
253
+ x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
254
+ }
255
+ #pragma unroll
256
+ for (int i = 0; i < kNItems; ++i) {
257
+ const weight_t C_val = !kIsVariableC
258
+ ? BC_val[r]
259
+ : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
260
+ if constexpr (!kIsComplex) {
261
+ out_vals[r][i] += thread_data[i].y * C_val;
262
+ } else {
263
+ out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;
264
+ }
265
+ }
266
+ }
267
+ }
268
+
269
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
270
+ + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
271
+ __syncthreads();
272
+ #pragma unroll
273
+ for (int r = 0; r < kNRows; ++r) {
274
+ if constexpr (!kDirectIO) {
275
+ if (r > 0) { __syncthreads(); }
276
+ }
277
+ store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
278
+ }
279
+
280
+ if constexpr (kHasZ) {
281
+ input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
282
+ + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
283
+ input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
284
+ + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
285
+ #pragma unroll
286
+ for (int r = 0; r < kNRows; ++r) {
287
+ input_t z_vals[kNItems];
288
+ __syncthreads();
289
+ load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
290
+ #pragma unroll
291
+ for (int i = 0; i < kNItems; ++i) {
292
+ float z_val = z_vals[i];
293
+ out_vals[r][i] *= z_val / (1 + expf(-z_val));
294
+ }
295
+ __syncthreads();
296
+ store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
297
+ }
298
+ }
299
+
300
+ Bvar += kChunkSize * (!kIsComplex ? 1 : 2);
301
+ Cvar += kChunkSize * (!kIsComplex ? 1 : 2);
302
+ }
303
+ }
304
+
305
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
306
+ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
307
+ // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
308
+ // processing 1 row.
309
+ constexpr int kNRows = 1;
310
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
311
+ BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
312
+ BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
313
+ BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
314
+ using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
315
+ // constexpr int kSmemSize = Ktraits::kSmemSize;
316
+ constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
317
+ // printf("smem_size = %d\n", kSmemSize);
318
+ dim3 grid(params.batch, params.dim / kNRows);
319
+ auto kernel = &selective_scan_fwd_kernel<Ktraits>;
320
+ if (kSmemSize >= 48 * 1024) {
321
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
322
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
323
+ }
324
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
325
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
326
+ });
327
+ });
328
+ });
329
+ });
330
+ }
331
+
332
+ template<typename input_t, typename weight_t>
333
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
334
+ if (params.seqlen <= 128) {
335
+ selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
336
+ } else if (params.seqlen <= 256) {
337
+ selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
338
+ } else if (params.seqlen <= 512) {
339
+ selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
340
+ } else if (params.seqlen <= 1024) {
341
+ selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
342
+ } else {
343
+ selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
344
+ }
345
+ }
mamba/csrc/selective_scan/static_switch.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2
+ // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3
+
4
+ #pragma once
5
+
6
+ /// @param COND - a boolean expression to switch by
7
+ /// @param CONST_NAME - a name given for the constexpr bool variable.
8
+ /// @param ... - code to execute for true and false
9
+ ///
10
+ /// Usage:
11
+ /// ```
12
+ /// BOOL_SWITCH(flag, BoolConst, [&] {
13
+ /// some_function<BoolConst>(...);
14
+ /// });
15
+ /// ```
16
+ #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17
+ [&] { \
18
+ if (COND) { \
19
+ constexpr bool CONST_NAME = true; \
20
+ return __VA_ARGS__(); \
21
+ } else { \
22
+ constexpr bool CONST_NAME = false; \
23
+ return __VA_ARGS__(); \
24
+ } \
25
+ }()
mamba/csrc/selective_scan/uninitialized_copy.cuh ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Redistribution and use in source and binary forms, with or without
5
+ * modification, are permitted provided that the following conditions are met:
6
+ * * Redistributions of source code must retain the above copyright
7
+ * notice, this list of conditions and the following disclaimer.
8
+ * * Redistributions in binary form must reproduce the above copyright
9
+ * notice, this list of conditions and the following disclaimer in the
10
+ * documentation and/or other materials provided with the distribution.
11
+ * * Neither the name of the NVIDIA CORPORATION nor the
12
+ * names of its contributors may be used to endorse or promote products
13
+ * derived from this software without specific prior written permission.
14
+ *
15
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18
+ * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ *
26
+ ******************************************************************************/
27
+
28
+ #pragma once
29
+
30
+ #include <cub/config.cuh>
31
+
32
+ #include <cuda/std/type_traits>
33
+
34
+
35
+ namespace detail
36
+ {
37
+
38
+ #if defined(_NVHPC_CUDA)
39
+ template <typename T, typename U>
40
+ __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
41
+ {
42
+ // NVBug 3384810
43
+ new (ptr) T(::cuda::std::forward<U>(val));
44
+ }
45
+ #else
46
+ template <typename T,
47
+ typename U,
48
+ typename ::cuda::std::enable_if<
49
+ ::cuda::std::is_trivially_copyable<T>::value,
50
+ int
51
+ >::type = 0>
52
+ __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
53
+ {
54
+ *ptr = ::cuda::std::forward<U>(val);
55
+ }
56
+
57
+ template <typename T,
58
+ typename U,
59
+ typename ::cuda::std::enable_if<
60
+ !::cuda::std::is_trivially_copyable<T>::value,
61
+ int
62
+ >::type = 0>
63
+ __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
64
+ {
65
+ new (ptr) T(::cuda::std::forward<U>(val));
66
+ }
67
+ #endif
68
+
69
+ } // namespace detail