diff --git a/README.md b/README.md
index 4f7f9af39121839bb024750cfabb6c071a3c07f1..b09c625b97e427efca23e21e9646168ff37af838 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
---
title: VideoMamba
-emoji: 🌖
-colorFrom: gray
-colorTo: pink
+emoji: 🐍
+colorFrom: blue
+colorTo: green
sdk: gradio
-sdk_version: 4.21.0
+sdk_version: 3.29.0
app_file: app.py
pinned: false
license: apache-2.0
diff --git a/__pycache__/imagenet_class_index.cpython-310.pyc b/__pycache__/imagenet_class_index.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d64027ffc7eb3f1fa4a0dee54322c7bc9c0d8090
Binary files /dev/null and b/__pycache__/imagenet_class_index.cpython-310.pyc differ
diff --git a/__pycache__/kinetics_class_index.cpython-310.pyc b/__pycache__/kinetics_class_index.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a60bb523534163496896bee5e1a288480abdba6
Binary files /dev/null and b/__pycache__/kinetics_class_index.cpython-310.pyc differ
diff --git a/__pycache__/transforms.cpython-310.pyc b/__pycache__/transforms.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec70fceee33c7c1ed22a17f285e7eff49b10ad2e
Binary files /dev/null and b/__pycache__/transforms.cpython-310.pyc differ
diff --git a/__pycache__/videomamba_image.cpython-310.pyc b/__pycache__/videomamba_image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63fb562f07d2d0db094219e15289d80e2f44919d
Binary files /dev/null and b/__pycache__/videomamba_image.cpython-310.pyc differ
diff --git a/__pycache__/videomamba_video.cpython-310.pyc b/__pycache__/videomamba_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b8abea20d021b0cb8f4f7bd195f99a4bf421e55
Binary files /dev/null and b/__pycache__/videomamba_video.cpython-310.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dde9587a53c230d455141a7e85e42ae260ba68a
--- /dev/null
+++ b/app.py
@@ -0,0 +1,180 @@
+import os
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+import torchvision.transforms as T
+from PIL import Image
+from decord import VideoReader
+from decord import cpu
+from videomamba_image import videomamba_image_tiny
+from videomamba_video import videomamba_tiny
+from kinetics_class_index import kinetics_classnames
+from imagenet_class_index import imagenet_classnames
+from transforms import (
+ GroupNormalize, GroupScale, GroupCenterCrop,
+ Stack, ToTorchFormatTensor
+)
+
+import gradio as gr
+from huggingface_hub import hf_hub_download
+
+
+# install packages for mamba
+os.system("bash install.sh")
+
+
+# Device on which to run the model
+# Set to cuda to load on GPU
+device = "cuda"
+model_video_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_k400_f16_res224.pth")
+model_image_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_in1k_res224.pth")
+# Pick a pretrained model
+model_video = videomamba_tiny(num_classes=400, num_frames=16)
+video_sd = torch.load(model_video_path, map_location='cpu')
+model_video.load_state_dict(video_sd)
+model_image = videomamba_image_tiny()
+image_sd = torch.load(model_image_path, map_location='cpu')
+model_image.load_state_dict(image_sd['model'])
+# Set to eval mode and move to desired device
+model_video = model_video.to(device).eval()
+model_image = model_image.to(device).eval()
+
+# Create an id to label name mapping
+kinetics_id_to_classname = {}
+for k, v in kinetics_classnames.items():
+ kinetics_id_to_classname[k] = v
+imagenet_id_to_classname = {}
+for k, v in imagenet_classnames.items():
+ imagenet_id_to_classname[k] = v[1]
+
+
+def get_index(num_frames, num_segments=8):
+ seg_size = float(num_frames - 1) / num_segments
+ start = int(seg_size / 2)
+ offsets = np.array([
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
+ ])
+ return offsets
+
+
+def load_video(video_path):
+ vr = VideoReader(video_path, ctx=cpu(0))
+ num_frames = len(vr)
+ frame_indices = get_index(num_frames, 16)
+
+ # transform
+ crop_size = 160
+ scale_size = 160
+ input_mean = [0.485, 0.456, 0.406]
+ input_std = [0.229, 0.224, 0.225]
+
+ transform = T.Compose([
+ GroupScale(int(scale_size)),
+ GroupCenterCrop(crop_size),
+ Stack(),
+ ToTorchFormatTensor(),
+ GroupNormalize(input_mean, input_std)
+ ])
+
+ images_group = list()
+ for frame_index in frame_indices:
+ img = Image.fromarray(vr[frame_index].asnumpy())
+ images_group.append(img)
+ torch_imgs = transform(images_group)
+ return torch_imgs
+
+
+def inference_video(video):
+ vid = load_video(video)
+
+    # The video model expects inputs of shape: B x C x T x H x W
+ TC, H, W = vid.shape
+ inputs = vid.reshape(1, TC//3, 3, H, W).permute(0, 2, 1, 3, 4)
+
+ with torch.no_grad():
+ prediction = model_video(inputs.to(device))
+ prediction = F.softmax(prediction, dim=1).flatten()
+
+ return {kinetics_id_to_classname[str(i)]: float(prediction[i]) for i in range(400)}
+
+
+def set_example_video(example: list) -> dict:
+ return gr.Video.update(value=example[0])
+
+
+def inference_image(img):
+ image = img
+ image_transform = T.Compose(
+ [
+ T.Resize(224),
+ T.CenterCrop(224),
+ T.ToTensor(),
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ]
+ )
+ image = image_transform(image)
+
+ # The model expects inputs of shape: B x C x H x W
+ image = image.unsqueeze(0)
+
+ with torch.no_grad():
+ prediction = model_image(image.to(device))
+ prediction = F.softmax(prediction, dim=1).flatten()
+
+ return {imagenet_id_to_classname[str(i)]: float(prediction[i]) for i in range(1000)}
+
+
+def set_example_image(example: list) -> dict:
+ return gr.Image.update(value=example[0])
+
+
+demo = gr.Blocks()
+with demo:
+ gr.Markdown(
+ """
+ # VideoMamba-Ti
+    Gradio demo for VideoMamba. To use it, simply upload a video or image, or click one of the examples to load one. Read more at the links below.
+ """
+ )
+
+ with gr.Tab("Video"):
+ with gr.Box():
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ input_video = gr.Video(label='Input Video').style(height=360)
+ with gr.Row():
+ submit_video_button = gr.Button('Submit')
+ with gr.Column():
+ label_video = gr.Label(num_top_classes=5)
+ with gr.Row():
+ example_videos = gr.Dataset(components=[input_video], samples=[['./videos/hitting_baseball.mp4'], ['./videos/hoverboarding.mp4'], ['./videos/yoga.mp4']])
+
+ with gr.Tab("Image"):
+ with gr.Box():
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ input_image = gr.Image(label='Input Image', type='pil').style(height=360)
+ with gr.Row():
+ submit_image_button = gr.Button('Submit')
+ with gr.Column():
+ label_image = gr.Label(num_top_classes=5)
+ with gr.Row():
+ example_images = gr.Dataset(components=[input_image], samples=[['./images/cat.png'], ['./images/dog.png'], ['./images/panda.png']])
+
+ gr.Markdown(
+ """
+
+        <p style='text-align: center'><a href='https://arxiv.org/abs/2403.06977' target='_blank'>VideoMamba: State Space Model for Efficient Video Understanding</a> | <a href='https://github.com/OpenGVLab/VideoMamba' target='_blank'>Github Repo</a></p>
+ """
+ )
+
+ submit_video_button.click(fn=inference_video, inputs=input_video, outputs=label_video)
+ example_videos.click(fn=set_example_video, inputs=example_videos, outputs=example_videos.components)
+ submit_image_button.click(fn=inference_image, inputs=input_image, outputs=label_image)
+ example_images.click(fn=set_example_image, inputs=example_images, outputs=example_images.components)
+
+demo.launch(enable_queue=True)
+# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
\ No newline at end of file
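
For reference, a minimal sketch of the reshape that `inference_video` applies to the stacked frames (shapes follow `load_video` above, which samples 16 frames and center-crops to 160; illustrative only):

```python
import torch

# load_video's transform stacks 16 RGB frames along the channel axis,
# producing a (16*3, 160, 160) tensor.
vid = torch.randn(16 * 3, 160, 160)

TC, H, W = vid.shape
# (1, T, 3, H, W) -> (1, 3, T, H, W): batch, channels, time, height, width.
inputs = vid.reshape(1, TC // 3, 3, H, W).permute(0, 2, 1, 3, 4)
assert inputs.shape == (1, 3, 16, 160, 160)
```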
diff --git a/causal-conv1d/AUTHORS b/causal-conv1d/AUTHORS
new file mode 100644
index 0000000000000000000000000000000000000000..88193855314bb723ced1860384e417954f559700
--- /dev/null
+++ b/causal-conv1d/AUTHORS
@@ -0,0 +1 @@
+Tri Dao, tri@tridao.me
diff --git a/causal-conv1d/LICENSE b/causal-conv1d/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..5860e4b33f3d9d85fc636137c559331d51783a5b
--- /dev/null
+++ b/causal-conv1d/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/causal-conv1d/README.md b/causal-conv1d/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4e905425a650d77c5c4854e4c4a261778c4d2690
--- /dev/null
+++ b/causal-conv1d/README.md
@@ -0,0 +1 @@
+# Causal depthwise conv1d in CUDA with a PyTorch interface
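
A minimal usage sketch based on the interface docstrings below (shapes, dtypes, and the CUDA device are illustrative assumptions):

```python
import torch
from causal_conv1d import causal_conv1d_fn

batch, dim, seqlen, width = 2, 64, 512, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

# out: (batch, dim, seqlen); "silu" fuses the activation into the kernel.
out = causal_conv1d_fn(x, weight, bias, activation="silu")
```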
diff --git a/causal-conv1d/causal_conv1d/__init__.py b/causal-conv1d/causal_conv1d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc4d610a1e557cabd723fb6e33438f03c5c4bf66
--- /dev/null
+++ b/causal-conv1d/causal_conv1d/__init__.py
@@ -0,0 +1,3 @@
+__version__ = "1.0.0"
+
+from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
diff --git a/causal-conv1d/causal_conv1d/causal_conv1d_interface.py b/causal-conv1d/causal_conv1d/causal_conv1d_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..f66143c39e767572ca12112811a384239b8beb63
--- /dev/null
+++ b/causal-conv1d/causal_conv1d/causal_conv1d_interface.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2023, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+
+import causal_conv1d_cuda
+
+
+class CausalConv1dFn(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, weight, bias=None, activation=None):
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ if x.stride(2) != 1 and x.stride(1) != 1:
+ x = x.contiguous()
+ bias = bias.contiguous() if bias is not None else None
+ ctx.save_for_backward(x, weight, bias)
+ ctx.activation = activation in ["silu", "swish"]
+ out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
+ return out
+
+ @staticmethod
+ def backward(ctx, dout):
+ x, weight, bias = ctx.saved_tensors
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ # Here we just pass in None and dx will be allocated in the C++ code.
+ dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd(
+ x, weight, bias, dout, None, ctx.activation
+ )
+ return dx, dweight, dbias if bias is not None else None, None
+
+
+def causal_conv1d_fn(x, weight, bias=None, activation=None):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+ activation: either None or "silu" or "swish"
+
+ out: (batch, dim, seqlen)
+ """
+ return CausalConv1dFn.apply(x, weight, bias, activation)
+
+
+def causal_conv1d_ref(x, weight, bias=None, activation=None):
+ """
+ x: (batch, dim, seqlen)
+ weight: (dim, width)
+ bias: (dim,)
+
+ out: (batch, dim, seqlen)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ x = x.to(weight.dtype)
+ seqlen = x.shape[-1]
+ dim, width = weight.shape
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+ out = out[..., :seqlen]
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
+ """
+ x: (batch, dim)
+ conv_state: (batch, dim, width)
+ weight: (dim, width)
+ bias: (dim,)
+
+ out: (batch, dim)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ activation = activation in ["silu", "swish"]
+ return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
+ """
+ x: (batch, dim)
+ conv_state: (batch, dim, width)
+ weight: (dim, width)
+ bias: (dim,)
+
+ out: (batch, dim)
+ """
+ if activation not in [None, "silu", "swish"]:
+ raise NotImplementedError("activation must be None, silu, or swish")
+ dtype_in = x.dtype
+ batch, dim = x.shape
+ width = weight.shape[1]
+ assert conv_state.shape == (batch, dim, width)
+ assert weight.shape == (dim, width)
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
+ conv_state[:, :, -1] = x
+ out = torch.sum(conv_state * weight, dim=-1) # (B D)
+ if bias is not None:
+ out += bias
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
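
The two `_ref` functions double as executable documentation: stepping `causal_conv1d_update_ref` through a sequence one timestep at a time should reproduce `causal_conv1d_ref` on the full sequence. A minimal CPU sketch (assuming both functions above are in scope):

```python
import torch

batch, dim, seqlen, width = 2, 8, 16, 4
x = torch.randn(batch, dim, seqlen)
weight = torch.randn(dim, width)
bias = torch.randn(dim)

# Full-sequence reference output.
out_full = causal_conv1d_ref(x, weight, bias, activation="silu")

# Step the stateful update through the sequence one frame at a time.
conv_state = torch.zeros(batch, dim, width)
steps = [
    causal_conv1d_update_ref(x[:, :, t], conv_state, weight, bias, activation="silu")
    for t in range(seqlen)
]
out_step = torch.stack(steps, dim=-1)

assert torch.allclose(out_full, out_step, atol=1e-5)
```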
diff --git a/causal-conv1d/csrc/causal_conv1d.cpp b/causal-conv1d/csrc/causal_conv1d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1c80516ac8599d4d80910a1d4d85c4c435cf1e4f
--- /dev/null
+++ b/causal-conv1d/csrc/causal_conv1d.cpp
@@ -0,0 +1,333 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/extension.h>
+#include <vector>
+
+#include "causal_conv1d.h"
+
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+
+#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
+ if (ITYPE == at::ScalarType::Half) { \
+ using input_t = at::Half; \
+ __VA_ARGS__(); \
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
+ using input_t = at::BFloat16; \
+ __VA_ARGS__(); \
+ } else if (ITYPE == at::ScalarType::Float) { \
+ using input_t = float; \
+ __VA_ARGS__(); \
+ } else { \
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
+ }
+
+#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
+ if (WTYPE == at::ScalarType::Half) { \
+ using weight_t = at::Half; \
+ __VA_ARGS__(); \
+ } else if (WTYPE == at::ScalarType::BFloat16) { \
+ using weight_t = at::BFloat16; \
+ __VA_ARGS__(); \
+ } else if (WTYPE == at::ScalarType::Float) { \
+ using weight_t = float; \
+ __VA_ARGS__(); \
+ } else { \
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
+ }
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
+template<typename input_t, typename weight_t>
+void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
+template<typename input_t, typename weight_t>
+void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
+
+void set_conv_params_fwd(ConvParamsBase &params,
+ // sizes
+ const size_t batch,
+ const size_t dim,
+ const size_t seqlen,
+ const size_t width,
+ // device pointers
+ const at::Tensor x,
+ const at::Tensor weight,
+ const at::Tensor out,
+ void* bias_ptr,
+ bool silu_activation) {
+
+ // Reset the parameters
+ memset(¶ms, 0, sizeof(params));
+
+ params.batch = batch;
+ params.dim = dim;
+ params.seqlen = seqlen;
+ params.width = width;
+
+ params.silu_activation = silu_activation;
+
+ // Set the pointers and strides.
+ params.x_ptr = x.data_ptr();
+ params.weight_ptr = weight.data_ptr();
+ params.bias_ptr = bias_ptr;
+ params.out_ptr = out.data_ptr();
+ // All stride are in elements, not bytes.
+ params.x_batch_stride = x.stride(0);
+ params.x_c_stride = x.stride(1);
+ params.x_l_stride = x.stride(-1);
+ params.weight_c_stride = weight.stride(0);
+ params.weight_width_stride = weight.stride(1);
+ params.out_batch_stride = out.stride(0);
+ params.out_c_stride = out.stride(1);
+ params.out_l_stride = out.stride(-1);
+}
+
+
+void set_conv_params_bwd(ConvParamsBwd &params,
+ // sizes
+ const size_t batch,
+ const size_t dim,
+ const size_t seqlen,
+ const size_t width,
+ // device pointers
+ const at::Tensor x,
+ const at::Tensor weight,
+ void* bias_ptr,
+ const at::Tensor dout,
+ const at::Tensor dx,
+ const at::Tensor dweight,
+ void* dbias_ptr,
+ bool silu_activation) {
+ // Pass in "dout" instead of "out", we're not gonna use "out" at all.
+ set_conv_params_fwd(params, batch, dim, seqlen, width,
+ x, weight, dout, bias_ptr, silu_activation);
+
+ // Set the pointers and strides.
+ params.dout_ptr = dout.data_ptr();
+ params.dx_ptr = dx.data_ptr();
+ params.dweight_ptr = dweight.data_ptr();
+ params.dbias_ptr = dbias_ptr;
+ // All stride are in elements, not bytes.
+ params.dout_batch_stride = dout.stride(0);
+ params.dout_c_stride = dout.stride(1);
+ params.dout_l_stride = dout.stride(2);
+ params.dweight_c_stride = dweight.stride(0);
+ params.dweight_width_stride = dweight.stride(1);
+ params.dx_batch_stride = dx.stride(0);
+ params.dx_c_stride = dx.stride(1);
+ params.dx_l_stride = dx.stride(2);
+}
+
+at::Tensor
+causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
+                  const c10::optional<at::Tensor> &bias_,
+ bool silu_activation) {
+ auto input_type = x.scalar_type();
+ auto weight_type = weight.scalar_type();
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
+
+ TORCH_CHECK(x.is_cuda());
+ TORCH_CHECK(weight.is_cuda());
+
+ const auto sizes = x.sizes();
+ const int batch_size = sizes[0];
+ const int dim = sizes[1];
+ const int seqlen = sizes[2];
+ const int width = weight.size(-1);
+
+ CHECK_SHAPE(x, batch_size, dim, seqlen);
+ CHECK_SHAPE(weight, dim, width);
+
+ TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
+ const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
+
+ if (is_channel_last) {
+ TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
+ }
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
+
+
+ if (bias_.has_value()) {
+ auto bias = bias_.value();
+ TORCH_CHECK(bias.scalar_type() == weight_type);
+ TORCH_CHECK(bias.is_cuda());
+ TORCH_CHECK(bias.stride(-1) == 1);
+ CHECK_SHAPE(bias, dim);
+ }
+
+ at::Tensor out = torch::empty_like(x);
+
+ ConvParamsBase params;
+ set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
+ bias_.has_value() ? bias_.value().data_ptr() : nullptr,
+ silu_activation);
+
+ // Otherwise the kernel will be launched from cuda:0 device
+ // Cast to char to avoid compiler warning about narrowing
+ at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
+            if (!is_channel_last) {
+                causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
+            } else {
+                causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
+            }
+ });
+ });
+ return out;
+}
+
+std::vector<at::Tensor>
+causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight,
+                  const c10::optional<at::Tensor> &bias_,
+                  at::Tensor &dout,
+                  c10::optional<at::Tensor> &dx_,
+ bool silu_activation) {
+ auto input_type = x.scalar_type();
+ auto weight_type = weight.scalar_type();
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
+
+ TORCH_CHECK(x.is_cuda());
+ TORCH_CHECK(weight.is_cuda());
+ TORCH_CHECK(dout.is_cuda());
+
+ const auto sizes = x.sizes();
+ const int batch_size = sizes[0];
+ const int dim = sizes[1];
+ const int seqlen = sizes[2];
+ const int width = weight.size(-1);
+
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
+
+ CHECK_SHAPE(x, batch_size, dim, seqlen);
+ CHECK_SHAPE(weight, dim, width);
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
+
+ TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
+ const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
+ if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
+ if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
+
+ if (bias_.has_value()) {
+ auto bias = bias_.value();
+ TORCH_CHECK(bias.scalar_type() == weight_type);
+ TORCH_CHECK(bias.is_cuda());
+ TORCH_CHECK(bias.stride(-1) == 1);
+ CHECK_SHAPE(bias, dim);
+ }
+
+ at::Tensor dx;
+ if (dx_.has_value()) {
+ dx = dx_.value();
+ TORCH_CHECK(dx.scalar_type() == input_type);
+ TORCH_CHECK(dx.is_cuda());
+ CHECK_SHAPE(dx, batch_size, dim, seqlen);
+ if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
+ if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
+ } else {
+ dx = torch::empty_like(x);
+ }
+
+ // Otherwise the kernel will be launched from cuda:0 device
+ // Cast to char to avoid compiler warning about narrowing
+ at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+
+ at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat));
+ at::Tensor dbias;
+ if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); }
+
+ ConvParamsBwd params;
+ set_conv_params_bwd(params, batch_size, dim, seqlen, width,
+ x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
+ dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr,
+ silu_activation);
+
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
+            if (!is_channel_last) {
+                causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
+            } else {
+                causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
+            }
+ });
+ });
+ return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias};
+}
+
+at::Tensor
+causal_conv1d_update(const at::Tensor &x,
+ const at::Tensor &conv_state,
+ const at::Tensor &weight,
+                     const c10::optional<at::Tensor> &bias_,
+ bool silu_activation) {
+ auto input_type = x.scalar_type();
+ auto weight_type = weight.scalar_type();
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
+ TORCH_CHECK(conv_state.scalar_type() == input_type);
+
+ TORCH_CHECK(x.is_cuda());
+ TORCH_CHECK(conv_state.is_cuda());
+ TORCH_CHECK(weight.is_cuda());
+
+ const auto sizes = x.sizes();
+ const int batch_size = sizes[0];
+ const int dim = sizes[1];
+ const int width = weight.size(-1);
+
+ CHECK_SHAPE(x, batch_size, dim);
+ CHECK_SHAPE(conv_state, batch_size, dim, width);
+ CHECK_SHAPE(weight, dim, width);
+
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
+
+ if (bias_.has_value()) {
+ auto bias = bias_.value();
+ TORCH_CHECK(bias.scalar_type() == weight_type);
+ TORCH_CHECK(bias.is_cuda());
+ TORCH_CHECK(bias.stride(-1) == 1);
+ CHECK_SHAPE(bias, dim);
+ }
+
+ at::Tensor out = torch::empty_like(x);
+
+ ConvParamsBase params;
+ set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
+ bias_.has_value() ? bias_.value().data_ptr() : nullptr,
+ silu_activation);
+ params.conv_state_ptr = conv_state.data_ptr();
+ // All stride are in elements, not bytes.
+ params.conv_state_batch_stride = conv_state.stride(0);
+ params.conv_state_c_stride = conv_state.stride(1);
+ params.conv_state_l_stride = conv_state.stride(2);
+
+ // Otherwise the kernel will be launched from cuda:0 device
+ // Cast to char to avoid compiler warning about narrowing
+ at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
+            causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
+ });
+ });
+ return out;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
+ m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
+ m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
+}
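
As context for the layout checks above: `causal_conv1d_fwd` accepts either a contiguous (batch, dim, seqlen) tensor or a channel-last view where the channel stride is 1, and routes to a different kernel for each. A small PyTorch sketch of the same stride test (illustrative only):

```python
import torch

def is_channel_last(x):
    # Mirrors the check in causal_conv1d_fwd: channels are the fastest-moving
    # axis, while the sequence axis still has a non-trivial stride.
    return x.stride(1) == 1 and x.stride(2) > 1

x_contig = torch.randn(2, 64, 512)                  # strides (32768, 512, 1)
x_chlast = torch.randn(2, 512, 64).transpose(1, 2)  # strides (32768, 1, 64)

assert not is_channel_last(x_contig)
assert is_channel_last(x_chlast)
```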
diff --git a/causal-conv1d/csrc/causal_conv1d.h b/causal-conv1d/csrc/causal_conv1d.h
new file mode 100644
index 0000000000000000000000000000000000000000..844ed92cfc91a881e58fccfca001a13ebcc434cc
--- /dev/null
+++ b/causal-conv1d/csrc/causal_conv1d.h
@@ -0,0 +1,53 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct ConvParamsBase {
+ using index_t = uint32_t;
+
+ int batch, dim, seqlen, width;
+ bool silu_activation;
+
+ index_t x_batch_stride;
+ index_t x_c_stride;
+ index_t x_l_stride;
+ index_t weight_c_stride;
+ index_t weight_width_stride;
+ index_t out_batch_stride;
+ index_t out_c_stride;
+ index_t out_l_stride;
+
+ index_t conv_state_batch_stride;
+ index_t conv_state_c_stride;
+ index_t conv_state_l_stride;
+
+ // Common data pointers.
+ void *__restrict__ x_ptr;
+ void *__restrict__ weight_ptr;
+ void *__restrict__ bias_ptr;
+ void *__restrict__ out_ptr;
+
+ void *__restrict__ conv_state_ptr;
+};
+
+struct ConvParamsBwd: public ConvParamsBase {
+ index_t dx_batch_stride;
+ index_t dx_c_stride;
+ index_t dx_l_stride;
+ index_t dweight_c_stride;
+ index_t dweight_width_stride;
+ index_t dout_batch_stride;
+ index_t dout_c_stride;
+ index_t dout_l_stride;
+
+ // Common data pointers.
+ void *__restrict__ dx_ptr;
+ void *__restrict__ dweight_ptr;
+ void *__restrict__ dbias_ptr;
+ void *__restrict__ dout_ptr;
+};
+
diff --git a/causal-conv1d/csrc/causal_conv1d_bwd.cu b/causal-conv1d/csrc/causal_conv1d_bwd.cu
new file mode 100644
index 0000000000000000000000000000000000000000..66609750a30a86a284451871ca163d79a0529047
--- /dev/null
+++ b/causal-conv1d/csrc/causal_conv1d_bwd.cu
@@ -0,0 +1,525 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#include <c10/util/BFloat16.h>
+#include <c10/util/Half.h>
+#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
+
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_store.cuh>
+#include <cub/block/block_reduce.cuh>
+
+#include "causal_conv1d.h"
+#include "causal_conv1d_common.h"
+#include "static_switch.h"
+
+template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
+struct Causal_conv1d_bwd_kernel_traits {
+ using input_t = input_t_;
+ using weight_t = weight_t_;
+ static constexpr int kNThreads = kNThreads_;
+ static constexpr int kWidth = kWidth_;
+ static constexpr bool kSiluAct = kSiluAct_;
+ static constexpr int kNBytes = sizeof(input_t);
+ static_assert(kNBytes == 2 || kNBytes == 4);
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
+ static_assert(kWidth <= kNElts);
+ // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
+ // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
+ static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
+    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
+    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
+    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
+    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
+    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
+    using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
+ static constexpr int kSmemIOSize = kIsVecLoad
+ ? 0
+ : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
+ static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
+ static constexpr int kSmemSize = std::max({kSmemExchangeSize,
+ int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
+};
+
+template<typename Ktraits>
+__global__ __launch_bounds__(Ktraits::kNThreads)
+void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
+ constexpr int kWidth = Ktraits::kWidth;
+ constexpr int kNThreads = Ktraits::kNThreads;
+ constexpr bool kSiluAct = Ktraits::kSiluAct;
+ constexpr int kNElts = Ktraits::kNElts;
+ constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
+ constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
+ using input_t = typename Ktraits::input_t;
+ using vec_t = typename Ktraits::vec_t;
+ using weight_t = typename Ktraits::weight_t;
+
+ // Shared memory.
+ extern __shared__ char smem_[];
+    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
+    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
+    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
+    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
+    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
+    vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
+    auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
+
+ const int tidx = threadIdx.x;
+ const int batch_id = blockIdx.x;
+ const int dim_id = blockIdx.y;
+    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+        + dim_id * params.x_c_stride;
+    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
+    input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
+        + dim_id * params.dout_c_stride;
+    input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
+        + dim_id * params.dx_c_stride;
+    float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
+    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
+
+ // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
+ if (tidx == 0) {
+ if constexpr (!kSiluAct) {
+ input_t zeros[kNElts] = {0};
+            smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
+ } else {
+ float zeros[kNElts] = {0};
+ #pragma unroll
+ for (int r = 0; r < kNExchangeRounds; ++r) {
+                smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
+ }
+ }
+ }
+
+ float weight_vals[kWidth];
+ #pragma unroll
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
+
+ float dweight_vals[kWidth] = {0};
+ float dbias_val = 0;
+
+ constexpr int kChunkSize = kNThreads * kNElts;
+ const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
+ x += (n_chunks - 1) * kChunkSize;
+ dout += (n_chunks - 1) * kChunkSize;
+ dx += (n_chunks - 1) * kChunkSize;
+ for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
+ input_t x_vals_load[2 * kNElts] = {0};
+ input_t dout_vals_load[2 * kNElts] = {0};
+ if constexpr(kIsVecLoad) {
+            Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
+            Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
+        } else {
+            __syncthreads();
+            Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
+            __syncthreads();
+            Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
+ }
+ float dout_vals[2 * kNElts], x_vals[2 * kNElts];
+ if constexpr (!kSiluAct) {
+ __syncthreads();
+            // Thread 0 doesn't write yet, so that thread kNThreads - 1 can read
+            // the first elements of the next chunk.
+            if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
+            __syncthreads();
+            reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
+            __syncthreads();
+            // Now thread 0 can write the first elements of the current chunk.
+            if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
+ #pragma unroll
+ for (int i = 0; i < 2 * kNElts; ++i) {
+ dout_vals[i] = float(dout_vals_load[i]);
+ x_vals[i] = float(x_vals_load[i]);
+ }
+ } else {
+ if (tidx == 0 && chunk > 0) {
+ if constexpr(kIsVecLoad) {
+                    reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
+ } else {
+ #pragma unroll
+ for (int i = 0; i < kNElts; ++i) {
+ if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
+ }
+ }
+ }
+ __syncthreads();
+            smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
+            __syncthreads();
+            if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
+ #pragma unroll
+ for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
+ // Recompute the output
+ #pragma unroll
+ for (int i = 0; i < kNElts; ++i) {
+ float out_val = bias_val;
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) {
+ out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
+ }
+ float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
+ dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
+ * (1.0f + out_val * (1.0f - out_sigmoid_val));
+ }
+ // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
+ // if input_t is 16 bits (since then we'd have 8 values of float)
+ __syncthreads();
+            // Thread 0 doesn't write yet, so that thread kNThreads - 1 can read
+            // the first elements of the next chunk.
+            if (tidx > 0) {
+                #pragma unroll
+                for (int r = 0; r < kNExchangeRounds; ++r) {
+                    smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
+                }
+            }
+            __syncthreads();
+            #pragma unroll
+            for (int r = 0; r < kNExchangeRounds; ++r) {
+                reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
+                    = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
+            }
+            __syncthreads();
+            // Now thread 0 can write the first elements of the current chunk.
+            if (tidx == 0) {
+                #pragma unroll
+                for (int r = 0; r < kNExchangeRounds; ++r) {
+                    smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
+                }
+            }
+ }
+ dout -= kChunkSize;
+ x -= kChunkSize;
+
+ #pragma unroll
+ for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
+
+ float dx_vals[kNElts] = {0};
+ #pragma unroll
+ for (int i = 0; i < kNElts; ++i) {
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) {
+ dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
+ }
+ }
+
+ input_t dx_vals_store[kNElts];
+ #pragma unroll
+ for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
+ if constexpr(kIsVecLoad) {
+            Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
+ } else {
+ Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
+ }
+ dx -= kChunkSize;
+
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) {
+ #pragma unroll
+ for (int i = 0; i < kNElts; ++i) {
+ dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
+ }
+ }
+ }
+
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) {
+ __syncthreads();
+ dweight_vals[w] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
+ if (tidx == 0) {
+            atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
+ }
+ }
+ if (params.bias_ptr != nullptr) {
+ __syncthreads();
+ dbias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
+ if (tidx == 0) {
+            atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
+ }
+ }
+}
+
+template<int kNThreads, int kWidth, typename input_t, typename weight_t>
+void causal_conv1d_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
+ static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
+ BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
+ BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
+            using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
+ constexpr int kSmemSize = Ktraits::kSmemSize;
+ dim3 grid(params.batch, params.dim);
+            auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
+ if (kSmemSize >= 48 * 1024) {
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+ }
+            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+ });
+}
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
+ if (params.width == 2) {
+ causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
+ } else if (params.width == 3) {
+ causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
+ } else if (params.width == 4) {
+ causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
+ }
+}
+
+template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
+struct Causal_conv1d_channellast_bwd_kernel_traits {
+ // The cache line is 128 bytes, and we try to read 16 bytes per thread.
+ // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
+ // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
+    // threads). Each load is 16 x 32|64 elements in the L x C dimensions.
+ using input_t = input_t_;
+ using weight_t = weight_t_;
+ static constexpr bool kSiluAct = kSiluAct_;
+ static constexpr int kNThreads = kNThreads_;
+ static_assert(kNThreads % 32 == 0);
+ static constexpr int kNWarps = kNThreads / 32;
+ static constexpr int kWidth = kWidth_;
+ static constexpr int kChunkSizeL = kChunkSizeL_;
+ static constexpr int kNBytes = sizeof(input_t);
+ static_assert(kNBytes == 2 || kNBytes == 4);
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
+ static constexpr int kNEltsPerRow = 128 / kNBytes;
+ static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
+ static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
+ static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
+ static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
+ static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
+ static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
+ static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
+    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
+    // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
+    // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
+ // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
+ // sizeof(typename BlockStoreT::TempStorage)});
+ // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
+};
+
+template<typename Ktraits>
+__global__ __launch_bounds__(Ktraits::kNThreads)
+void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
+ constexpr int kWidth = Ktraits::kWidth;
+ constexpr int kNThreads = Ktraits::kNThreads;
+ constexpr bool kSiluAct = Ktraits::kSiluAct;
+ constexpr int kNElts = Ktraits::kNElts;
+ constexpr int kNWarp = Ktraits::kNWarps;
+ constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
+ constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
+ using input_t = typename Ktraits::input_t;
+ using vec_t = typename Ktraits::vec_t;
+ using weight_t = typename Ktraits::weight_t;
+
+ // Shared memory.
+ __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
+ __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
+
+ const int tid = threadIdx.x;
+ const int l_idx = tid / kNThreadsPerC;
+ const int c_idx = tid % kNThreadsPerC;
+ const int batch_id = blockIdx.x;
+ const int chunk_l_id = blockIdx.y;
+ const int chunk_c_id = blockIdx.z;
+    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+        + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
+    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
+        + chunk_c_id * kChunkSizeC * params.weight_c_stride;
+    input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
+        + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
+    input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
+        + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
+    float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
+        + chunk_c_id * kChunkSizeC * params.dweight_c_stride;
+
+ #pragma unroll
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
+ input_t dout_vals_load[kNElts] = {0};
+ input_t x_vals_load[kNElts] = {0};
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
+            reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
+            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
+        }
+        reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
+        reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
+ }
+ // Load the elements from the previous chunk or next chunk that are needed for convolution.
+ if (l_idx < kWidth - 1) {
+ input_t dout_vals_load[kNElts] = {0};
+ input_t x_vals_load[kNElts] = {0};
+ if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
+            reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
+ }
+ if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
+            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
+ }
+        reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
+        reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
+ }
+    // Need to load (kWidth - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
+ if constexpr (kSiluAct) {
+ if (l_idx < kWidth - 1) {
+ input_t x_vals_load[kNElts] = {0};
+ if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
+                reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
+ }
+            reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
+ }
+ }
+
+ __syncthreads();
+
+ constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
+ static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
+ constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
+ static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
+ // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
+ static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
+ static_assert((kLPerThread & (kLPerThread - 1)) == 0);
+ static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
+ static_assert(kNThreadsPerRow <= 32);
+
+ const int row_idx = tid / kNThreadsPerRow;
+ const int col_idx = tid % kNThreadsPerRow;
+
+    float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
+ float weight_vals[kWidth] = {0};
+ if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) {
+ weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
+ }
+ }
+ float dout_vals[kLPerThread + kWidth - 1];
+ float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
+ #pragma unroll
+ for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
+ dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
+ }
+
+ if constexpr (kSiluAct) { // Recompute the output
+ #pragma unroll
+ for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
+ }
+ #pragma unroll
+ for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
+ float out_val = bias_val;
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) { out_val += weight_vals[w] * x_vals[i + w]; }
+ float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
+ dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
+ }
+ }
+
+ float dweight_vals[kWidth] = {0};
+    SumOp<float> sum_op;
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) {
+ #pragma unroll
+ for (int i = 0; i < kLPerThread; ++i) { dweight_vals[w] += x_vals[i + w] * dout_vals[i]; }
+        dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
+ if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
+            atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
+ }
+ }
+
+ if (params.bias_ptr != nullptr) {
+ float dbias_val = 0.f;
+ for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
+        dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
+ if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
+            atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
+ }
+ }
+
+ float dx_vals[kLPerThread] = {0};
+ #pragma unroll
+ for (int i = 0; i < kLPerThread; ++i) {
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) { dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]; }
+ }
+ // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads.
+ __syncwarp();
+ #pragma unroll
+ for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
+ __syncthreads();
+
+ #pragma unroll
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
+ input_t dx_vals_store[kNElts];
+        reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
+            *reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
+ }
+ }
+
+}
+
+template<int kNThreads, int kWidth, typename input_t, typename weight_t>
+void causal_conv1d_channellast_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
+ BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
+        using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, 64, kSiluAct, true, input_t, weight_t>;
+ // constexpr int kSmemSize = Ktraits::kSmemSize;
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
+ const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
+ const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
+ dim3 grid(params.batch, n_chunks_L, n_chunks_C);
+ dim3 block(Ktraits::kNThreads);
+        auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits>;
+ // if (kSmemSize >= 48 * 1024) {
+ // C10_CUDA_CHECK(cudaFuncSetAttribute(
+ // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+ // }
+ // kernel<<>>(params);
+        kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+}
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
+ if (params.width == 2) {
+ causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
+ } else if (params.width == 3) {
+ causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
+ } else if (params.width == 4) {
+ causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
+ }
+}
+
+template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
+
+template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
+template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/causal-conv1d/csrc/causal_conv1d_common.h b/causal-conv1d/csrc/causal_conv1d_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..8dd6a333b52163986c085f71475709706ce8f9c3
--- /dev/null
+++ b/causal-conv1d/csrc/causal_conv1d_common.h
@@ -0,0 +1,64 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int BYTES> struct BytesToType {};
+
+template<> struct BytesToType<16> {
+ using Type = uint4;
+ static_assert(sizeof(Type) == 16);
+};
+
+template<> struct BytesToType<8> {
+ using Type = uint64_t;
+ static_assert(sizeof(Type) == 8);
+};
+
+template<> struct BytesToType<4> {
+ using Type = uint32_t;
+ static_assert(sizeof(Type) == 4);
+};
+
+template<> struct BytesToType<2> {
+ using Type = uint16_t;
+ static_assert(sizeof(Type) == 2);
+};
+
+template<> struct BytesToType<1> {
+ using Type = uint8_t;
+ static_assert(sizeof(Type) == 1);
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct SumOp {
+__device__ inline T operator()(T const & x, T const & y) { return x + y; }
+};
+
+template<int THREADS>
+struct Allreduce {
+ static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+    template<typename T, typename Operator>
+ static __device__ inline T run(T x, Operator &op) {
+ constexpr int OFFSET = THREADS / 2;
+ x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+        return Allreduce<OFFSET>::run(x, op);
+ }
+};
+
+template<>
+struct Allreduce<2> {
+template<typename T, typename Operator>
+static __device__ inline T run(T x, Operator &op) {
+ x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+ return x;
+}
+};
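
`Allreduce` implements a butterfly reduction: each round a lane combines its value with the lane whose index differs by `OFFSET` (via `__shfl_xor_sync`), halving the offset until every lane holds the full result. A Python sketch of the same access pattern, with a list standing in for the warp (illustrative only):

```python
def allreduce_sum(lanes):
    # Each round, lane i pairs with lane i ^ offset; after log2(n) rounds
    # every lane holds the sum of all lanes, as in Allreduce<THREADS>::run.
    n = len(lanes)
    assert n & (n - 1) == 0, "lane count must be a power of two"
    offset = n // 2
    while offset >= 1:
        lanes = [lanes[i] + lanes[i ^ offset] for i in range(n)]
        offset //= 2
    return lanes

print(allreduce_sum([1, 2, 3, 4]))  # [10, 10, 10, 10]
```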
diff --git a/causal-conv1d/csrc/causal_conv1d_fwd.cu b/causal-conv1d/csrc/causal_conv1d_fwd.cu
new file mode 100644
index 0000000000000000000000000000000000000000..74a1459f88a87ef427075a25e5081899e382efc0
--- /dev/null
+++ b/causal-conv1d/csrc/causal_conv1d_fwd.cu
@@ -0,0 +1,350 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#include <c10/util/BFloat16.h>
+#include <c10/util/Half.h>
+#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
+
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_store.cuh>
+
+#include "causal_conv1d.h"
+#include "causal_conv1d_common.h"
+#include "static_switch.h"
+
+template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
+struct Causal_conv1d_fwd_kernel_traits {
+ using input_t = input_t_;
+ using weight_t = weight_t_;
+ static constexpr int kNThreads = kNThreads_;
+ static constexpr int kWidth = kWidth_;
+ static constexpr int kNBytes = sizeof(input_t);
+ static_assert(kNBytes == 2 || kNBytes == 4);
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
+ static_assert(kWidth <= kNElts);
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
+    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
+    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
+    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
+    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
+    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
+ static constexpr int kSmemIOSize = kIsVecLoad
+ ? 0
+ : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
+ static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
+};
+
+template<typename Ktraits>
+__global__ __launch_bounds__(Ktraits::kNThreads)
+void causal_conv1d_fwd_kernel(ConvParamsBase params) {
+ constexpr int kWidth = Ktraits::kWidth;
+ constexpr int kNThreads = Ktraits::kNThreads;
+ constexpr int kNElts = Ktraits::kNElts;
+ constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
+ using input_t = typename Ktraits::input_t;
+ using vec_t = typename Ktraits::vec_t;
+ using weight_t = typename Ktraits::weight_t;
+
+ // Shared memory.
+ extern __shared__ char smem_[];
+    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
+    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
+    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
+    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
+    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
+
+ const int tidx = threadIdx.x;
+ const int batch_id = blockIdx.x;
+ const int channel_id = blockIdx.y;
+    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+        + channel_id * params.x_c_stride;
+    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
+    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+        + channel_id * params.out_c_stride;
+    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
+
+ // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
+ if (tidx == 0) {
+ input_t zeros[kNElts] = {0};
+        smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
+ }
+
+ float weight_vals[kWidth];
+ #pragma unroll
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
+
+ constexpr int kChunkSize = kNThreads * kNElts;
+ const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
+ for (int chunk = 0; chunk < n_chunks; ++chunk) {
+ input_t x_vals_load[2 * kNElts] = {0};
+ if constexpr(kIsVecLoad) {
+            Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
+ } else {
+ __syncthreads();
+            Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
+ }
+ x += kChunkSize;
+ __syncthreads();
+        // Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
+        // the last elements of the previous chunk.
+        if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
+        __syncthreads();
+        reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
+        __syncthreads();
+        // Now thread kNThreads - 1 can write the last elements of the current chunk.
+        if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
+
+ float x_vals[2 * kNElts];
+ #pragma unroll
+ for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
+
+ float out_vals[kNElts];
+ #pragma unroll
+ for (int i = 0; i < kNElts; ++i) {
+ out_vals[i] = bias_val;
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) {
+ out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
+ }
+ }
+
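+ // SiLU (a.k.a. swish), computed in fp32: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).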
+ if (params.silu_activation) {
+ #pragma unroll
+ for (int i = 0; i < kNElts; ++i) {
+ out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
+ }
+ }
+
+ input_t out_vals_store[kNElts];
+ #pragma unroll
+ for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
+ if constexpr(kIsVecLoad) {
+ Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
+ } else {
+ Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
+ }
+ out += kChunkSize;
+ }
+}
+
+template<int kNThreads, int kWidth, typename input_t, typename weight_t>
+void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
+ static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
+ BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
+ using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
+ constexpr int kSmemSize = Ktraits::kSmemSize;
+ dim3 grid(params.batch, params.dim);
+ auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
+ if (kSmemSize >= 48 * 1024) {
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+ }
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+}
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
+ if (params.width == 2) {
+ causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
+ } else if (params.width == 3) {
+ causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
+ } else if (params.width == 4) {
+ causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
+ }
+}
+
+template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
+struct Causal_conv1d_channellast_fwd_kernel_traits {
+ // The cache line is 128 bytes, and we try to read 16 bytes per thread.
+ // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
+ // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
+ // threads). Each load is 16 x 32|64 elements in the L x C dimensions.
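+ // For instance (illustration only): with fp16 (kNBytes == 2), kNElts == 8 and
+ // kNEltsPerRow == 64, so kNThreadsPerRow == 8 and kNColsPerWarp == 4; with
+ // kNThreads == 128 (kNWarps == 4), each load covers kNColsPerLoad == 16 positions of L.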
+ using input_t = input_t_;
+ using weight_t = weight_t_;
+ static constexpr int kNThreads = kNThreads_;
+ static_assert(kNThreads % 32 == 0);
+ static constexpr int kNWarps = kNThreads / 32;
+ static constexpr int kWidth = kWidth_;
+ static constexpr int kChunkSizeL = kChunkSizeL_;
+ static constexpr int kNBytes = sizeof(input_t);
+ static_assert(kNBytes == 2 || kNBytes == 4);
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
+ static constexpr int kNEltsPerRow = 128 / kNBytes;
+ static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
+ static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
+ static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
+ static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
+ static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
+ static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
+ static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
+ // using BlockLoadT = cub::BlockLoad;
+ // using BlockStoreT = cub::BlockStore;
+ // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
+ // sizeof(typename BlockStoreT::TempStorage)});
+ // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
+};
+
+template<typename Ktraits>
+__global__ __launch_bounds__(Ktraits::kNThreads)
+void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
+ constexpr int kWidth = Ktraits::kWidth;
+ constexpr int kNThreads = Ktraits::kNThreads;
+ constexpr int kNElts = Ktraits::kNElts;
+ constexpr int kNWarp = Ktraits::kNWarps;
+ constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
+ constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
+ using input_t = typename Ktraits::input_t;
+ using vec_t = typename Ktraits::vec_t;
+ using weight_t = typename Ktraits::weight_t;
+
+ // Shared memory.
+ __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
+
+ const int tid = threadIdx.x;
+ const int l_idx = tid / kNThreadsPerC;
+ const int c_idx = tid % kNThreadsPerC;
+ const int batch_id = blockIdx.x;
+ const int chunk_l_id = blockIdx.y;
+ const int chunk_c_id = blockIdx.z;
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
+ + chunk_c_id * kChunkSizeC * params.weight_c_stride;
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
+
+ #pragma unroll
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
+ input_t x_vals_load[kNElts] = {0};
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
+ }
+ reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
+ }
+ // Load the elements from the previous chunk that are needed for convolution.
+ if (l_idx < kWidth - 1) {
+ input_t x_vals_load[kNElts] = {0};
+ if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
+ }
+ reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
+ }
+
+ __syncthreads();
+
+ constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
+ static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
+ constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
+ static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
+ // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
+ static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
+ static_assert((kLPerThread & (kLPerThread - 1)) == 0);
+ static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
+ static_assert(kNThreadsPerRow <= 32);
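+ // Re-partition the block for the compute phase: each thread now owns kLPerThread
+ // consecutive L positions of a single channel (row_idx) staged in x_smem.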
+
+ const int row_idx = tid / kNThreadsPerRow;
+ const int col_idx = tid % kNThreadsPerRow;
+
+ float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
+ float weight_vals[kWidth] = {0};
+ if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) {
+ weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
+ }
+ }
+ float x_vals[kWidth - 1 + kLPerThread];
+ #pragma unroll
+ for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
+ }
+
+ float out_vals[kLPerThread];
+ #pragma unroll
+ for (int i = 0; i < kLPerThread; ++i) {
+ out_vals[i] = bias_val;
+ #pragma unroll
+ for (int w = 0; w < kWidth; ++w) { out_vals[i] += weight_vals[w] * x_vals[i + w]; }
+ if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
+ }
+
+ // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads.
+ __syncwarp();
+ #pragma unroll
+ for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
+ __syncthreads();
+
+ #pragma unroll
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
+ input_t out_vals_store[kNElts];
+ reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
+ *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
+ }
+ }
+
+}
+
+template<int kNThreads, int kWidth, typename input_t, typename weight_t>
+void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
+ using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
+ // constexpr int kSmemSize = Ktraits::kSmemSize;
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
+ const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
+ const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
+ // printf("n_chunks_L: %d, n_chunks_C: %d\n", n_chunks_L, n_chunks_C);
+ dim3 grid(params.batch, n_chunks_L, n_chunks_C);
+ dim3 block(Ktraits::kNThreads);
+ auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits>;
+ // if (kSmemSize >= 48 * 1024) {
+ // C10_CUDA_CHECK(cudaFuncSetAttribute(
+ // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+ // }
+ // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
+ if (params.width == 2) {
+ causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
+ } else if (params.width == 3) {
+ causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
+ } else if (params.width == 4) {
+ causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
+ }
+}
+
+template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
+
+template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/causal-conv1d/csrc/causal_conv1d_update.cu b/causal-conv1d/csrc/causal_conv1d_update.cu
new file mode 100644
index 0000000000000000000000000000000000000000..713e0ac883853491f9bdb0015b578657c228c1e7
--- /dev/null
+++ b/causal-conv1d/csrc/causal_conv1d_update.cu
@@ -0,0 +1,96 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#include <c10/util/BFloat16.h>
+#include <c10/util/Half.h>
+#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
+
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_store.cuh>
+
+#include "causal_conv1d.h"
+#include "causal_conv1d_common.h"
+#include "static_switch.h"
+
+template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
+struct Causal_conv1d_update_kernel_traits {
+ using input_t = input_t_;
+ using weight_t = weight_t_;
+ static constexpr int kNThreads = kNThreads_;
+ static constexpr int kWidth = kWidth_;
+ static constexpr int kNBytes = sizeof(input_t);
+ static_assert(kNBytes == 2 || kNBytes == 4);
+};
+
+template<typename Ktraits>
+__global__ __launch_bounds__(Ktraits::kNThreads)
+void causal_conv1d_update_kernel(ConvParamsBase params) {
+ constexpr int kWidth = Ktraits::kWidth;
+ constexpr int kNThreads = Ktraits::kNThreads;
+ using input_t = typename Ktraits::input_t;
+ using weight_t = typename Ktraits::weight_t;
+
+ const int tidx = threadIdx.x;
+ const int batch_id = blockIdx.x;
+ const int channel_id = blockIdx.y * kNThreads + tidx;
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ + channel_id * params.x_c_stride;
+ input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
+ + channel_id * params.conv_state_c_stride;
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ + channel_id * params.out_c_stride;
+ float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
+
+ float weight_vals[kWidth] = {0};
+ if (channel_id < params.dim) {
+ #pragma unroll
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
+ }
+
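+ // Roll the per-channel conv state left by one sample, append the new input, and
+ // write the updated window back before computing the depthwise FIR output below.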
+ float x_vals[kWidth] = {0};
+ if (channel_id < params.dim) {
+ #pragma unroll
+ for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
+ x_vals[kWidth - 1] = float(x[0]);
+ #pragma unroll
+ for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
+ }
+
+ float out_val = bias_val;
+ #pragma unroll
+ for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
+ if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
+ if (channel_id < params.dim) { out[0] = input_t(out_val); }
+}
+
+template<int kNThreads, int kWidth, typename input_t, typename weight_t>
+void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
+ using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
+ dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
+ auto kernel = &causal_conv1d_update_kernel<Ktraits>;
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
+ if (params.width == 2) {
+ causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
+ } else if (params.width == 3) {
+ causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
+ } else if (params.width == 4) {
+ causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
+ }
+}
+
+template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/causal-conv1d/csrc/static_switch.h b/causal-conv1d/csrc/static_switch.h
new file mode 100644
index 0000000000000000000000000000000000000000..0f4ad3eb62235443d15c454b6691c2ec63645219
--- /dev/null
+++ b/causal-conv1d/csrc/static_switch.h
@@ -0,0 +1,25 @@
+// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+
+#pragma once
+
+/// @param COND - a boolean expression to switch by
+/// @param CONST_NAME - a name given for the constexpr bool variable.
+/// @param ... - code to execute for true and false
+///
+/// Usage:
+/// ```
+/// BOOL_SWITCH(flag, BoolConst, [&] {
+/// some_function(...);
+/// });
+/// ```
+#define BOOL_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ if (COND) { \
+ static constexpr bool CONST_NAME = true; \
+ return __VA_ARGS__(); \
+ } else { \
+ static constexpr bool CONST_NAME = false; \
+ return __VA_ARGS__(); \
+ } \
+ }()
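+
+// Example (illustration, mirroring causal_conv1d_fwd_launch): turning a runtime flag
+// into a compile-time constant to pick a kernel-traits configuration:
+//
+// BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
+//     using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
+//     ...
+// });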
diff --git a/causal-conv1d/setup.py b/causal-conv1d/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..12e36bf988215a4c536278026e6f4401e66534da
--- /dev/null
+++ b/causal-conv1d/setup.py
@@ -0,0 +1,264 @@
+# Copyright (c) 2023, Tri Dao.
+import sys
+import warnings
+import os
+import re
+import ast
+from pathlib import Path
+from packaging.version import parse, Version
+import platform
+
+from setuptools import setup, find_packages
+import subprocess
+
+import urllib.request
+import urllib.error
+from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+
+import torch
+from torch.utils.cpp_extension import (
+ BuildExtension,
+ CppExtension,
+ CUDAExtension,
+ CUDA_HOME,
+)
+
+
+with open("README.md", "r", encoding="utf-8") as fh:
+ long_description = fh.read()
+
+
+# ninja build does not work unless include_dirs are abs path
+this_dir = os.path.dirname(os.path.abspath(__file__))
+
+PACKAGE_NAME = "causal_conv1d"
+
+BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"
+
+# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
+# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
+FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
+SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
+FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
+
+
+def get_platform():
+ """
+ Returns the platform name as used in wheel filenames.
+ """
+ if sys.platform.startswith("linux"):
+ return "linux_x86_64"
+ elif sys.platform == "darwin":
+ mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
+ return f"macosx_{mac_version}_x86_64"
+ elif sys.platform == "win32":
+ return "win_amd64"
+ else:
+ raise ValueError("Unsupported platform: {}".format(sys.platform))
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+ raw_output = subprocess.check_output(
+ [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
+ )
+ output = raw_output.split()
+ release_idx = output.index("release") + 1
+ bare_metal_version = parse(output[release_idx].split(",")[0])
+
+ return raw_output, bare_metal_version
+
+
+def check_if_cuda_home_none(global_option: str) -> None:
+ if CUDA_HOME is not None:
+ return
+ # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
+ # in that case.
+ warnings.warn(
+ f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
+ "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+ "only images whose names contain 'devel' will provide nvcc."
+ )
+
+
+def append_nvcc_threads(nvcc_extra_args):
+ return nvcc_extra_args + ["--threads", "4"]
+
+
+cmdclass = {}
+ext_modules = []
+
+if not SKIP_CUDA_BUILD:
+ print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
+ TORCH_MAJOR = int(torch.__version__.split(".")[0])
+ TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+ check_if_cuda_home_none("causal_conv1d")
+ # Check if CUDA 11 is installed for compute capability 8.0
+ cc_flag = []
+ if CUDA_HOME is not None:
+ _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+ if bare_metal_version < Version("11.6"):
+ raise RuntimeError(
+ "causal_conv1d is only supported on CUDA 11.6 and above. "
+ "Note: make sure nvcc has a supported version by running nvcc -V."
+ )
+
+ cc_flag.append("-gencode")
+ cc_flag.append("arch=compute_70,code=sm_70")
+ cc_flag.append("-gencode")
+ cc_flag.append("arch=compute_80,code=sm_80")
+ if bare_metal_version >= Version("11.8"):
+ cc_flag.append("-gencode")
+ cc_flag.append("arch=compute_90,code=sm_90")
+
+ # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
+ # torch._C._GLIBCXX_USE_CXX11_ABI
+ # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
+ if FORCE_CXX11_ABI:
+ torch._C._GLIBCXX_USE_CXX11_ABI = True
+
+ ext_modules.append(
+ CUDAExtension(
+ name="causal_conv1d_cuda",
+ sources=[
+ "csrc/causal_conv1d.cpp",
+ "csrc/causal_conv1d_fwd.cu",
+ "csrc/causal_conv1d_bwd.cu",
+ "csrc/causal_conv1d_update.cu",
+ ],
+ extra_compile_args={
+ "cxx": ["-O3"],
+ "nvcc": append_nvcc_threads(
+ [
+ "-O3",
+ "-U__CUDA_NO_HALF_OPERATORS__",
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
+ "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+ "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+ "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+ "--expt-relaxed-constexpr",
+ "--expt-extended-lambda",
+ "--use_fast_math",
+ "--ptxas-options=-v",
+ "-lineinfo",
+ ]
+ + cc_flag
+ ),
+ },
+ include_dirs=[this_dir],
+ )
+ )
+
+
+def get_package_version():
+ with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f:
+ version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
+ public_version = ast.literal_eval(version_match.group(1))
+ local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
+ if local_version:
+ return f"{public_version}+{local_version}"
+ else:
+ return str(public_version)
+
+
+def get_wheel_url():
+ # Determine the version numbers that will be used to determine the correct wheel
+ # We're using the CUDA version used to build torch, not the one currently installed
+ # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+ torch_cuda_version = parse(torch.version.cuda)
+ torch_version_raw = parse(torch.__version__)
+ # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
+ # to save CI time. Minor versions should be compatible.
+ torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
+ python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+ platform_name = get_platform()
+ causal_conv1d_version = get_package_version()
+ # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
+ cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
+ torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
+ cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
+
+ # Determine wheel URL based on CUDA version, torch version, python version and OS
+ wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+ wheel_url = BASE_WHEEL_URL.format(
+ tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
+ )
+ return wheel_url, wheel_filename
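+
+# Illustrative example (values are hypothetical): CUDA 11.8, torch 2.1, CPython 3.10 on
+# Linux would produce a wheel filename of the form
+#   causal_conv1d-<version>+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl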
+
+
+class CachedWheelsCommand(_bdist_wheel):
+ """
+ The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
+ find an existing wheel (which is currently the case for all installs). We use
+ the environment parameters to detect whether there is already a pre-built version of a compatible
+ wheel available and, if so, short-circuit the standard full build pipeline.
+ """
+
+ def run(self):
+ if FORCE_BUILD:
+ return super().run()
+
+ wheel_url, wheel_filename = get_wheel_url()
+ print("Guessing wheel URL: ", wheel_url)
+ try:
+ urllib.request.urlretrieve(wheel_url, wheel_filename)
+
+ # Make the archive
+ # Lifted from the root wheel processing command
+ # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
+ if not os.path.exists(self.dist_dir):
+ os.makedirs(self.dist_dir)
+
+ impl_tag, abi_tag, plat_tag = self.get_tag()
+ archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
+
+ wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
+ print("Raw wheel path", wheel_path)
+ os.rename(wheel_filename, wheel_path)
+ except urllib.error.HTTPError:
+ print("Precompiled wheel not found. Building from source...")
+ # If the wheel could not be downloaded, build from source
+ super().run()
+
+
+setup(
+ name=PACKAGE_NAME,
+ version=get_package_version(),
+ packages=find_packages(
+ exclude=(
+ "build",
+ "csrc",
+ "include",
+ "tests",
+ "dist",
+ "docs",
+ "benchmarks",
+ "causal_conv1d.egg-info",
+ )
+ ),
+ author="Tri Dao",
+ author_email="tri@tridao.me",
+ description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ url="https://github.com/Dao-AILab/causal-conv1d",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: BSD License",
+ "Operating System :: Unix",
+ ],
+ ext_modules=ext_modules,
+ cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
+ if ext_modules
+ else {
+ "bdist_wheel": CachedWheelsCommand,
+ },
+ python_requires=">=3.7",
+ install_requires=[
+ "torch",
+ "packaging",
+ "ninja",
+ ],
+)
diff --git a/causal-conv1d/tests/test_causal_conv1d.py b/causal-conv1d/tests/test_causal_conv1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e5985cfb0582e6656afb1d8b5c1de78f24f4276
--- /dev/null
+++ b/causal-conv1d/tests/test_causal_conv1d.py
@@ -0,0 +1,173 @@
+# Copyright (C) 2023, Tri Dao.
+
+import math
+
+import torch
+import pytest
+
+from einops import rearrange
+
+from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
+from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
+
+
+@pytest.mark.parametrize("channel_last", [False, True])
+# @pytest.mark.parametrize('channel_last', [True])
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize('itype', [torch.float16])
+@pytest.mark.parametrize("silu_activation", [False, True])
+# @pytest.mark.parametrize('silu_activation', [True])
+@pytest.mark.parametrize("has_bias", [False, True])
+# @pytest.mark.parametrize('has_bias', [True])
+@pytest.mark.parametrize("width", [2, 3, 4])
+# @pytest.mark.parametrize('width', [2])
+@pytest.mark.parametrize(
+ "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
+)
+# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
+# @pytest.mark.parametrize('seqlen', [128])
+def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last):
+ device = "cuda"
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
+ if itype == torch.bfloat16:
+ rtol, atol = 1e-2, 5e-2
+ rtolw, atolw = (1e-3, 1e-3)
+ # set seed
+ torch.random.manual_seed(0)
+ batch_size = 2
+ # batch_size = 1
+ dim = 4096 + 32 # Try dim not divisible by 64
+ # dim = 64
+ if not channel_last:
+ x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
+ else:
+ x = rearrange(
+ torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
+ ).requires_grad_()
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
+ if has_bias:
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
+ else:
+ bias = None
+ x_ref = x.detach().clone().requires_grad_()
+ weight_ref = weight.detach().clone().requires_grad_()
+ bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
+ activation = None if not silu_activation else "silu"
+ out = causal_conv1d_fn(x, weight, bias, activation=activation)
+ out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation)
+
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+ g = torch.randn_like(out)
+ out_ref.backward(g)
+ out.backward(g)
+
+ print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
+ print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
+ if has_bias:
+ print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
+
+ assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
+ assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
+ if has_bias:
+ assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize('itype', [torch.float16])
+@pytest.mark.parametrize("silu_activation", [False, True])
+# @pytest.mark.parametrize('silu_activation', [False])
+@pytest.mark.parametrize("has_bias", [False, True])
+# @pytest.mark.parametrize('has_bias', [True])
+@pytest.mark.parametrize("width", [2, 3, 4])
+# @pytest.mark.parametrize('width', [2])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+# @pytest.mark.parametrize("dim", [2048])
+def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype):
+ device = "cuda"
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
+ if itype == torch.bfloat16:
+ rtol, atol = 1e-2, 5e-2
+ rtolw, atolw = (1e-3, 1e-3)
+ # set seed
+ torch.random.manual_seed(0)
+ batch_size = 2
+ # batch_size = 1
+ # dim = 64
+ x = torch.randn(batch_size, dim, device=device, dtype=itype)
+ conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype)
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
+ if has_bias:
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
+ else:
+ bias = None
+ conv_state_ref = conv_state.detach().clone()
+ activation = None if not silu_activation else "silu"
+ out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
+ out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation)
+
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+ assert torch.equal(conv_state, conv_state_ref)
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+# @pytest.mark.parametrize("channel_last", [False, True])
+@pytest.mark.parametrize('channel_last', [True])
+# @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize('itype', [torch.bfloat16])
+# @pytest.mark.parametrize("silu_activation", [False, True])
+@pytest.mark.parametrize('silu_activation', [True])
+# @pytest.mark.parametrize("has_bias", [False, True])
+@pytest.mark.parametrize('has_bias', [True])
+# @pytest.mark.parametrize("width", [2, 3, 4])
+@pytest.mark.parametrize('width', [4])
+@pytest.mark.parametrize(
+ # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
+ "seqlen", [2048]
+)
+# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
+# @pytest.mark.parametrize('seqlen', [128])
+def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
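+ # Re-runs the same forward/backward repeatedly: out and dx must match bitwise,
+ # while dw/db are only required to match within a small tolerance.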
+ device = "cuda"
+ # set seed
+ torch.random.manual_seed(0)
+ batch_size = 2
+ # batch_size = 1
+ dim = 4096 + 32 # Try dim not divisible by 64
+ # dim = 64
+ if not channel_last:
+ x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
+ else:
+ x = rearrange(
+ torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
+ ).requires_grad_()
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
+ if has_bias:
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
+ else:
+ bias = None
+ activation = None if not silu_activation else "silu"
+ out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
+ g = torch.randn_like(out0)
+ dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
+ dw_atol = 1e-4
+ db_atol = 1e-4
+
+ for i in range(10000):
+ out = causal_conv1d_fn(x, weight, bias, activation=activation)
+ dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
+ dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
+ # if not dw_equal:
+ # breakpoint()
+ if has_bias:
+ db_equal = torch.allclose(db, db0, atol=db_atol)
+ # if not db_equal:
+ # breakpoint()
+ assert torch.equal(out, out0)
+ assert torch.equal(dx, dx0)
+ assert dw_equal
+ if has_bias:
+ assert db_equal
diff --git a/imagenet_class_index.py b/imagenet_class_index.py
new file mode 100644
index 0000000000000000000000000000000000000000..5407d1471197fd9bfa466c47d8dfb6683cb9f551
--- /dev/null
+++ b/imagenet_class_index.py
@@ -0,0 +1,1002 @@
+imagenet_classnames = {
+ "0": ["n01440764", "tench"],
+ "1": ["n01443537", "goldfish"],
+ "2": ["n01484850", "great_white_shark"],
+ "3": ["n01491361", "tiger_shark"],
+ "4": ["n01494475", "hammerhead"],
+ "5": ["n01496331", "electric_ray"],
+ "6": ["n01498041", "stingray"],
+ "7": ["n01514668", "cock"],
+ "8": ["n01514859", "hen"],
+ "9": ["n01518878", "ostrich"],
+ "10": ["n01530575", "brambling"],
+ "11": ["n01531178", "goldfinch"],
+ "12": ["n01532829", "house_finch"],
+ "13": ["n01534433", "junco"],
+ "14": ["n01537544", "indigo_bunting"],
+ "15": ["n01558993", "robin"],
+ "16": ["n01560419", "bulbul"],
+ "17": ["n01580077", "jay"],
+ "18": ["n01582220", "magpie"],
+ "19": ["n01592084", "chickadee"],
+ "20": ["n01601694", "water_ouzel"],
+ "21": ["n01608432", "kite"],
+ "22": ["n01614925", "bald_eagle"],
+ "23": ["n01616318", "vulture"],
+ "24": ["n01622779", "great_grey_owl"],
+ "25": ["n01629819", "European_fire_salamander"],
+ "26": ["n01630670", "common_newt"],
+ "27": ["n01631663", "eft"],
+ "28": ["n01632458", "spotted_salamander"],
+ "29": ["n01632777", "axolotl"],
+ "30": ["n01641577", "bullfrog"],
+ "31": ["n01644373", "tree_frog"],
+ "32": ["n01644900", "tailed_frog"],
+ "33": ["n01664065", "loggerhead"],
+ "34": ["n01665541", "leatherback_turtle"],
+ "35": ["n01667114", "mud_turtle"],
+ "36": ["n01667778", "terrapin"],
+ "37": ["n01669191", "box_turtle"],
+ "38": ["n01675722", "banded_gecko"],
+ "39": ["n01677366", "common_iguana"],
+ "40": ["n01682714", "American_chameleon"],
+ "41": ["n01685808", "whiptail"],
+ "42": ["n01687978", "agama"],
+ "43": ["n01688243", "frilled_lizard"],
+ "44": ["n01689811", "alligator_lizard"],
+ "45": ["n01692333", "Gila_monster"],
+ "46": ["n01693334", "green_lizard"],
+ "47": ["n01694178", "African_chameleon"],
+ "48": ["n01695060", "Komodo_dragon"],
+ "49": ["n01697457", "African_crocodile"],
+ "50": ["n01698640", "American_alligator"],
+ "51": ["n01704323", "triceratops"],
+ "52": ["n01728572", "thunder_snake"],
+ "53": ["n01728920", "ringneck_snake"],
+ "54": ["n01729322", "hognose_snake"],
+ "55": ["n01729977", "green_snake"],
+ "56": ["n01734418", "king_snake"],
+ "57": ["n01735189", "garter_snake"],
+ "58": ["n01737021", "water_snake"],
+ "59": ["n01739381", "vine_snake"],
+ "60": ["n01740131", "night_snake"],
+ "61": ["n01742172", "boa_constrictor"],
+ "62": ["n01744401", "rock_python"],
+ "63": ["n01748264", "Indian_cobra"],
+ "64": ["n01749939", "green_mamba"],
+ "65": ["n01751748", "sea_snake"],
+ "66": ["n01753488", "horned_viper"],
+ "67": ["n01755581", "diamondback"],
+ "68": ["n01756291", "sidewinder"],
+ "69": ["n01768244", "trilobite"],
+ "70": ["n01770081", "harvestman"],
+ "71": ["n01770393", "scorpion"],
+ "72": ["n01773157", "black_and_gold_garden_spider"],
+ "73": ["n01773549", "barn_spider"],
+ "74": ["n01773797", "garden_spider"],
+ "75": ["n01774384", "black_widow"],
+ "76": ["n01774750", "tarantula"],
+ "77": ["n01775062", "wolf_spider"],
+ "78": ["n01776313", "tick"],
+ "79": ["n01784675", "centipede"],
+ "80": ["n01795545", "black_grouse"],
+ "81": ["n01796340", "ptarmigan"],
+ "82": ["n01797886", "ruffed_grouse"],
+ "83": ["n01798484", "prairie_chicken"],
+ "84": ["n01806143", "peacock"],
+ "85": ["n01806567", "quail"],
+ "86": ["n01807496", "partridge"],
+ "87": ["n01817953", "African_grey"],
+ "88": ["n01818515", "macaw"],
+ "89": ["n01819313", "sulphur-crested_cockatoo"],
+ "90": ["n01820546", "lorikeet"],
+ "91": ["n01824575", "coucal"],
+ "92": ["n01828970", "bee_eater"],
+ "93": ["n01829413", "hornbill"],
+ "94": ["n01833805", "hummingbird"],
+ "95": ["n01843065", "jacamar"],
+ "96": ["n01843383", "toucan"],
+ "97": ["n01847000", "drake"],
+ "98": ["n01855032", "red-breasted_merganser"],
+ "99": ["n01855672", "goose"],
+ "100": ["n01860187", "black_swan"],
+ "101": ["n01871265", "tusker"],
+ "102": ["n01872401", "echidna"],
+ "103": ["n01873310", "platypus"],
+ "104": ["n01877812", "wallaby"],
+ "105": ["n01882714", "koala"],
+ "106": ["n01883070", "wombat"],
+ "107": ["n01910747", "jellyfish"],
+ "108": ["n01914609", "sea_anemone"],
+ "109": ["n01917289", "brain_coral"],
+ "110": ["n01924916", "flatworm"],
+ "111": ["n01930112", "nematode"],
+ "112": ["n01943899", "conch"],
+ "113": ["n01944390", "snail"],
+ "114": ["n01945685", "slug"],
+ "115": ["n01950731", "sea_slug"],
+ "116": ["n01955084", "chiton"],
+ "117": ["n01968897", "chambered_nautilus"],
+ "118": ["n01978287", "Dungeness_crab"],
+ "119": ["n01978455", "rock_crab"],
+ "120": ["n01980166", "fiddler_crab"],
+ "121": ["n01981276", "king_crab"],
+ "122": ["n01983481", "American_lobster"],
+ "123": ["n01984695", "spiny_lobster"],
+ "124": ["n01985128", "crayfish"],
+ "125": ["n01986214", "hermit_crab"],
+ "126": ["n01990800", "isopod"],
+ "127": ["n02002556", "white_stork"],
+ "128": ["n02002724", "black_stork"],
+ "129": ["n02006656", "spoonbill"],
+ "130": ["n02007558", "flamingo"],
+ "131": ["n02009229", "little_blue_heron"],
+ "132": ["n02009912", "American_egret"],
+ "133": ["n02011460", "bittern"],
+ "134": ["n02012849", "crane"],
+ "135": ["n02013706", "limpkin"],
+ "136": ["n02017213", "European_gallinule"],
+ "137": ["n02018207", "American_coot"],
+ "138": ["n02018795", "bustard"],
+ "139": ["n02025239", "ruddy_turnstone"],
+ "140": ["n02027492", "red-backed_sandpiper"],
+ "141": ["n02028035", "redshank"],
+ "142": ["n02033041", "dowitcher"],
+ "143": ["n02037110", "oystercatcher"],
+ "144": ["n02051845", "pelican"],
+ "145": ["n02056570", "king_penguin"],
+ "146": ["n02058221", "albatross"],
+ "147": ["n02066245", "grey_whale"],
+ "148": ["n02071294", "killer_whale"],
+ "149": ["n02074367", "dugong"],
+ "150": ["n02077923", "sea_lion"],
+ "151": ["n02085620", "Chihuahua"],
+ "152": ["n02085782", "Japanese_spaniel"],
+ "153": ["n02085936", "Maltese_dog"],
+ "154": ["n02086079", "Pekinese"],
+ "155": ["n02086240", "Shih-Tzu"],
+ "156": ["n02086646", "Blenheim_spaniel"],
+ "157": ["n02086910", "papillon"],
+ "158": ["n02087046", "toy_terrier"],
+ "159": ["n02087394", "Rhodesian_ridgeback"],
+ "160": ["n02088094", "Afghan_hound"],
+ "161": ["n02088238", "basset"],
+ "162": ["n02088364", "beagle"],
+ "163": ["n02088466", "bloodhound"],
+ "164": ["n02088632", "bluetick"],
+ "165": ["n02089078", "black-and-tan_coonhound"],
+ "166": ["n02089867", "Walker_hound"],
+ "167": ["n02089973", "English_foxhound"],
+ "168": ["n02090379", "redbone"],
+ "169": ["n02090622", "borzoi"],
+ "170": ["n02090721", "Irish_wolfhound"],
+ "171": ["n02091032", "Italian_greyhound"],
+ "172": ["n02091134", "whippet"],
+ "173": ["n02091244", "Ibizan_hound"],
+ "174": ["n02091467", "Norwegian_elkhound"],
+ "175": ["n02091635", "otterhound"],
+ "176": ["n02091831", "Saluki"],
+ "177": ["n02092002", "Scottish_deerhound"],
+ "178": ["n02092339", "Weimaraner"],
+ "179": ["n02093256", "Staffordshire_bullterrier"],
+ "180": ["n02093428", "American_Staffordshire_terrier"],
+ "181": ["n02093647", "Bedlington_terrier"],
+ "182": ["n02093754", "Border_terrier"],
+ "183": ["n02093859", "Kerry_blue_terrier"],
+ "184": ["n02093991", "Irish_terrier"],
+ "185": ["n02094114", "Norfolk_terrier"],
+ "186": ["n02094258", "Norwich_terrier"],
+ "187": ["n02094433", "Yorkshire_terrier"],
+ "188": ["n02095314", "wire-haired_fox_terrier"],
+ "189": ["n02095570", "Lakeland_terrier"],
+ "190": ["n02095889", "Sealyham_terrier"],
+ "191": ["n02096051", "Airedale"],
+ "192": ["n02096177", "cairn"],
+ "193": ["n02096294", "Australian_terrier"],
+ "194": ["n02096437", "Dandie_Dinmont"],
+ "195": ["n02096585", "Boston_bull"],
+ "196": ["n02097047", "miniature_schnauzer"],
+ "197": ["n02097130", "giant_schnauzer"],
+ "198": ["n02097209", "standard_schnauzer"],
+ "199": ["n02097298", "Scotch_terrier"],
+ "200": ["n02097474", "Tibetan_terrier"],
+ "201": ["n02097658", "silky_terrier"],
+ "202": ["n02098105", "soft-coated_wheaten_terrier"],
+ "203": ["n02098286", "West_Highland_white_terrier"],
+ "204": ["n02098413", "Lhasa"],
+ "205": ["n02099267", "flat-coated_retriever"],
+ "206": ["n02099429", "curly-coated_retriever"],
+ "207": ["n02099601", "golden_retriever"],
+ "208": ["n02099712", "Labrador_retriever"],
+ "209": ["n02099849", "Chesapeake_Bay_retriever"],
+ "210": ["n02100236", "German_short-haired_pointer"],
+ "211": ["n02100583", "vizsla"],
+ "212": ["n02100735", "English_setter"],
+ "213": ["n02100877", "Irish_setter"],
+ "214": ["n02101006", "Gordon_setter"],
+ "215": ["n02101388", "Brittany_spaniel"],
+ "216": ["n02101556", "clumber"],
+ "217": ["n02102040", "English_springer"],
+ "218": ["n02102177", "Welsh_springer_spaniel"],
+ "219": ["n02102318", "cocker_spaniel"],
+ "220": ["n02102480", "Sussex_spaniel"],
+ "221": ["n02102973", "Irish_water_spaniel"],
+ "222": ["n02104029", "kuvasz"],
+ "223": ["n02104365", "schipperke"],
+ "224": ["n02105056", "groenendael"],
+ "225": ["n02105162", "malinois"],
+ "226": ["n02105251", "briard"],
+ "227": ["n02105412", "kelpie"],
+ "228": ["n02105505", "komondor"],
+ "229": ["n02105641", "Old_English_sheepdog"],
+ "230": ["n02105855", "Shetland_sheepdog"],
+ "231": ["n02106030", "collie"],
+ "232": ["n02106166", "Border_collie"],
+ "233": ["n02106382", "Bouvier_des_Flandres"],
+ "234": ["n02106550", "Rottweiler"],
+ "235": ["n02106662", "German_shepherd"],
+ "236": ["n02107142", "Doberman"],
+ "237": ["n02107312", "miniature_pinscher"],
+ "238": ["n02107574", "Greater_Swiss_Mountain_dog"],
+ "239": ["n02107683", "Bernese_mountain_dog"],
+ "240": ["n02107908", "Appenzeller"],
+ "241": ["n02108000", "EntleBucher"],
+ "242": ["n02108089", "boxer"],
+ "243": ["n02108422", "bull_mastiff"],
+ "244": ["n02108551", "Tibetan_mastiff"],
+ "245": ["n02108915", "French_bulldog"],
+ "246": ["n02109047", "Great_Dane"],
+ "247": ["n02109525", "Saint_Bernard"],
+ "248": ["n02109961", "Eskimo_dog"],
+ "249": ["n02110063", "malamute"],
+ "250": ["n02110185", "Siberian_husky"],
+ "251": ["n02110341", "dalmatian"],
+ "252": ["n02110627", "affenpinscher"],
+ "253": ["n02110806", "basenji"],
+ "254": ["n02110958", "pug"],
+ "255": ["n02111129", "Leonberg"],
+ "256": ["n02111277", "Newfoundland"],
+ "257": ["n02111500", "Great_Pyrenees"],
+ "258": ["n02111889", "Samoyed"],
+ "259": ["n02112018", "Pomeranian"],
+ "260": ["n02112137", "chow"],
+ "261": ["n02112350", "keeshond"],
+ "262": ["n02112706", "Brabancon_griffon"],
+ "263": ["n02113023", "Pembroke"],
+ "264": ["n02113186", "Cardigan"],
+ "265": ["n02113624", "toy_poodle"],
+ "266": ["n02113712", "miniature_poodle"],
+ "267": ["n02113799", "standard_poodle"],
+ "268": ["n02113978", "Mexican_hairless"],
+ "269": ["n02114367", "timber_wolf"],
+ "270": ["n02114548", "white_wolf"],
+ "271": ["n02114712", "red_wolf"],
+ "272": ["n02114855", "coyote"],
+ "273": ["n02115641", "dingo"],
+ "274": ["n02115913", "dhole"],
+ "275": ["n02116738", "African_hunting_dog"],
+ "276": ["n02117135", "hyena"],
+ "277": ["n02119022", "red_fox"],
+ "278": ["n02119789", "kit_fox"],
+ "279": ["n02120079", "Arctic_fox"],
+ "280": ["n02120505", "grey_fox"],
+ "281": ["n02123045", "tabby"],
+ "282": ["n02123159", "tiger_cat"],
+ "283": ["n02123394", "Persian_cat"],
+ "284": ["n02123597", "Siamese_cat"],
+ "285": ["n02124075", "Egyptian_cat"],
+ "286": ["n02125311", "cougar"],
+ "287": ["n02127052", "lynx"],
+ "288": ["n02128385", "leopard"],
+ "289": ["n02128757", "snow_leopard"],
+ "290": ["n02128925", "jaguar"],
+ "291": ["n02129165", "lion"],
+ "292": ["n02129604", "tiger"],
+ "293": ["n02130308", "cheetah"],
+ "294": ["n02132136", "brown_bear"],
+ "295": ["n02133161", "American_black_bear"],
+ "296": ["n02134084", "ice_bear"],
+ "297": ["n02134418", "sloth_bear"],
+ "298": ["n02137549", "mongoose"],
+ "299": ["n02138441", "meerkat"],
+ "300": ["n02165105", "tiger_beetle"],
+ "301": ["n02165456", "ladybug"],
+ "302": ["n02167151", "ground_beetle"],
+ "303": ["n02168699", "long-horned_beetle"],
+ "304": ["n02169497", "leaf_beetle"],
+ "305": ["n02172182", "dung_beetle"],
+ "306": ["n02174001", "rhinoceros_beetle"],
+ "307": ["n02177972", "weevil"],
+ "308": ["n02190166", "fly"],
+ "309": ["n02206856", "bee"],
+ "310": ["n02219486", "ant"],
+ "311": ["n02226429", "grasshopper"],
+ "312": ["n02229544", "cricket"],
+ "313": ["n02231487", "walking_stick"],
+ "314": ["n02233338", "cockroach"],
+ "315": ["n02236044", "mantis"],
+ "316": ["n02256656", "cicada"],
+ "317": ["n02259212", "leafhopper"],
+ "318": ["n02264363", "lacewing"],
+ "319": ["n02268443", "dragonfly"],
+ "320": ["n02268853", "damselfly"],
+ "321": ["n02276258", "admiral"],
+ "322": ["n02277742", "ringlet"],
+ "323": ["n02279972", "monarch"],
+ "324": ["n02280649", "cabbage_butterfly"],
+ "325": ["n02281406", "sulphur_butterfly"],
+ "326": ["n02281787", "lycaenid"],
+ "327": ["n02317335", "starfish"],
+ "328": ["n02319095", "sea_urchin"],
+ "329": ["n02321529", "sea_cucumber"],
+ "330": ["n02325366", "wood_rabbit"],
+ "331": ["n02326432", "hare"],
+ "332": ["n02328150", "Angora"],
+ "333": ["n02342885", "hamster"],
+ "334": ["n02346627", "porcupine"],
+ "335": ["n02356798", "fox_squirrel"],
+ "336": ["n02361337", "marmot"],
+ "337": ["n02363005", "beaver"],
+ "338": ["n02364673", "guinea_pig"],
+ "339": ["n02389026", "sorrel"],
+ "340": ["n02391049", "zebra"],
+ "341": ["n02395406", "hog"],
+ "342": ["n02396427", "wild_boar"],
+ "343": ["n02397096", "warthog"],
+ "344": ["n02398521", "hippopotamus"],
+ "345": ["n02403003", "ox"],
+ "346": ["n02408429", "water_buffalo"],
+ "347": ["n02410509", "bison"],
+ "348": ["n02412080", "ram"],
+ "349": ["n02415577", "bighorn"],
+ "350": ["n02417914", "ibex"],
+ "351": ["n02422106", "hartebeest"],
+ "352": ["n02422699", "impala"],
+ "353": ["n02423022", "gazelle"],
+ "354": ["n02437312", "Arabian_camel"],
+ "355": ["n02437616", "llama"],
+ "356": ["n02441942", "weasel"],
+ "357": ["n02442845", "mink"],
+ "358": ["n02443114", "polecat"],
+ "359": ["n02443484", "black-footed_ferret"],
+ "360": ["n02444819", "otter"],
+ "361": ["n02445715", "skunk"],
+ "362": ["n02447366", "badger"],
+ "363": ["n02454379", "armadillo"],
+ "364": ["n02457408", "three-toed_sloth"],
+ "365": ["n02480495", "orangutan"],
+ "366": ["n02480855", "gorilla"],
+ "367": ["n02481823", "chimpanzee"],
+ "368": ["n02483362", "gibbon"],
+ "369": ["n02483708", "siamang"],
+ "370": ["n02484975", "guenon"],
+ "371": ["n02486261", "patas"],
+ "372": ["n02486410", "baboon"],
+ "373": ["n02487347", "macaque"],
+ "374": ["n02488291", "langur"],
+ "375": ["n02488702", "colobus"],
+ "376": ["n02489166", "proboscis_monkey"],
+ "377": ["n02490219", "marmoset"],
+ "378": ["n02492035", "capuchin"],
+ "379": ["n02492660", "howler_monkey"],
+ "380": ["n02493509", "titi"],
+ "381": ["n02493793", "spider_monkey"],
+ "382": ["n02494079", "squirrel_monkey"],
+ "383": ["n02497673", "Madagascar_cat"],
+ "384": ["n02500267", "indri"],
+ "385": ["n02504013", "Indian_elephant"],
+ "386": ["n02504458", "African_elephant"],
+ "387": ["n02509815", "lesser_panda"],
+ "388": ["n02510455", "giant_panda"],
+ "389": ["n02514041", "barracouta"],
+ "390": ["n02526121", "eel"],
+ "391": ["n02536864", "coho"],
+ "392": ["n02606052", "rock_beauty"],
+ "393": ["n02607072", "anemone_fish"],
+ "394": ["n02640242", "sturgeon"],
+ "395": ["n02641379", "gar"],
+ "396": ["n02643566", "lionfish"],
+ "397": ["n02655020", "puffer"],
+ "398": ["n02666196", "abacus"],
+ "399": ["n02667093", "abaya"],
+ "400": ["n02669723", "academic_gown"],
+ "401": ["n02672831", "accordion"],
+ "402": ["n02676566", "acoustic_guitar"],
+ "403": ["n02687172", "aircraft_carrier"],
+ "404": ["n02690373", "airliner"],
+ "405": ["n02692877", "airship"],
+ "406": ["n02699494", "altar"],
+ "407": ["n02701002", "ambulance"],
+ "408": ["n02704792", "amphibian"],
+ "409": ["n02708093", "analog_clock"],
+ "410": ["n02727426", "apiary"],
+ "411": ["n02730930", "apron"],
+ "412": ["n02747177", "ashcan"],
+ "413": ["n02749479", "assault_rifle"],
+ "414": ["n02769748", "backpack"],
+ "415": ["n02776631", "bakery"],
+ "416": ["n02777292", "balance_beam"],
+ "417": ["n02782093", "balloon"],
+ "418": ["n02783161", "ballpoint"],
+ "419": ["n02786058", "Band_Aid"],
+ "420": ["n02787622", "banjo"],
+ "421": ["n02788148", "bannister"],
+ "422": ["n02790996", "barbell"],
+ "423": ["n02791124", "barber_chair"],
+ "424": ["n02791270", "barbershop"],
+ "425": ["n02793495", "barn"],
+ "426": ["n02794156", "barometer"],
+ "427": ["n02795169", "barrel"],
+ "428": ["n02797295", "barrow"],
+ "429": ["n02799071", "baseball"],
+ "430": ["n02802426", "basketball"],
+ "431": ["n02804414", "bassinet"],
+ "432": ["n02804610", "bassoon"],
+ "433": ["n02807133", "bathing_cap"],
+ "434": ["n02808304", "bath_towel"],
+ "435": ["n02808440", "bathtub"],
+ "436": ["n02814533", "beach_wagon"],
+ "437": ["n02814860", "beacon"],
+ "438": ["n02815834", "beaker"],
+ "439": ["n02817516", "bearskin"],
+ "440": ["n02823428", "beer_bottle"],
+ "441": ["n02823750", "beer_glass"],
+ "442": ["n02825657", "bell_cote"],
+ "443": ["n02834397", "bib"],
+ "444": ["n02835271", "bicycle-built-for-two"],
+ "445": ["n02837789", "bikini"],
+ "446": ["n02840245", "binder"],
+ "447": ["n02841315", "binoculars"],
+ "448": ["n02843684", "birdhouse"],
+ "449": ["n02859443", "boathouse"],
+ "450": ["n02860847", "bobsled"],
+ "451": ["n02865351", "bolo_tie"],
+ "452": ["n02869837", "bonnet"],
+ "453": ["n02870880", "bookcase"],
+ "454": ["n02871525", "bookshop"],
+ "455": ["n02877765", "bottlecap"],
+ "456": ["n02879718", "bow"],
+ "457": ["n02883205", "bow_tie"],
+ "458": ["n02892201", "brass"],
+ "459": ["n02892767", "brassiere"],
+ "460": ["n02894605", "breakwater"],
+ "461": ["n02895154", "breastplate"],
+ "462": ["n02906734", "broom"],
+ "463": ["n02909870", "bucket"],
+ "464": ["n02910353", "buckle"],
+ "465": ["n02916936", "bulletproof_vest"],
+ "466": ["n02917067", "bullet_train"],
+ "467": ["n02927161", "butcher_shop"],
+ "468": ["n02930766", "cab"],
+ "469": ["n02939185", "caldron"],
+ "470": ["n02948072", "candle"],
+ "471": ["n02950826", "cannon"],
+ "472": ["n02951358", "canoe"],
+ "473": ["n02951585", "can_opener"],
+ "474": ["n02963159", "cardigan"],
+ "475": ["n02965783", "car_mirror"],
+ "476": ["n02966193", "carousel"],
+ "477": ["n02966687", "carpenter's_kit"],
+ "478": ["n02971356", "carton"],
+ "479": ["n02974003", "car_wheel"],
+ "480": ["n02977058", "cash_machine"],
+ "481": ["n02978881", "cassette"],
+ "482": ["n02979186", "cassette_player"],
+ "483": ["n02980441", "castle"],
+ "484": ["n02981792", "catamaran"],
+ "485": ["n02988304", "CD_player"],
+ "486": ["n02992211", "cello"],
+ "487": ["n02992529", "cellular_telephone"],
+ "488": ["n02999410", "chain"],
+ "489": ["n03000134", "chainlink_fence"],
+ "490": ["n03000247", "chain_mail"],
+ "491": ["n03000684", "chain_saw"],
+ "492": ["n03014705", "chest"],
+ "493": ["n03016953", "chiffonier"],
+ "494": ["n03017168", "chime"],
+ "495": ["n03018349", "china_cabinet"],
+ "496": ["n03026506", "Christmas_stocking"],
+ "497": ["n03028079", "church"],
+ "498": ["n03032252", "cinema"],
+ "499": ["n03041632", "cleaver"],
+ "500": ["n03042490", "cliff_dwelling"],
+ "501": ["n03045698", "cloak"],
+ "502": ["n03047690", "clog"],
+ "503": ["n03062245", "cocktail_shaker"],
+ "504": ["n03063599", "coffee_mug"],
+ "505": ["n03063689", "coffeepot"],
+ "506": ["n03065424", "coil"],
+ "507": ["n03075370", "combination_lock"],
+ "508": ["n03085013", "computer_keyboard"],
+ "509": ["n03089624", "confectionery"],
+ "510": ["n03095699", "container_ship"],
+ "511": ["n03100240", "convertible"],
+ "512": ["n03109150", "corkscrew"],
+ "513": ["n03110669", "cornet"],
+ "514": ["n03124043", "cowboy_boot"],
+ "515": ["n03124170", "cowboy_hat"],
+ "516": ["n03125729", "cradle"],
+ "517": ["n03126707", "crane"],
+ "518": ["n03127747", "crash_helmet"],
+ "519": ["n03127925", "crate"],
+ "520": ["n03131574", "crib"],
+ "521": ["n03133878", "Crock_Pot"],
+ "522": ["n03134739", "croquet_ball"],
+ "523": ["n03141823", "crutch"],
+ "524": ["n03146219", "cuirass"],
+ "525": ["n03160309", "dam"],
+ "526": ["n03179701", "desk"],
+ "527": ["n03180011", "desktop_computer"],
+ "528": ["n03187595", "dial_telephone"],
+ "529": ["n03188531", "diaper"],
+ "530": ["n03196217", "digital_clock"],
+ "531": ["n03197337", "digital_watch"],
+ "532": ["n03201208", "dining_table"],
+ "533": ["n03207743", "dishrag"],
+ "534": ["n03207941", "dishwasher"],
+ "535": ["n03208938", "disk_brake"],
+ "536": ["n03216828", "dock"],
+ "537": ["n03218198", "dogsled"],
+ "538": ["n03220513", "dome"],
+ "539": ["n03223299", "doormat"],
+ "540": ["n03240683", "drilling_platform"],
+ "541": ["n03249569", "drum"],
+ "542": ["n03250847", "drumstick"],
+ "543": ["n03255030", "dumbbell"],
+ "544": ["n03259280", "Dutch_oven"],
+ "545": ["n03271574", "electric_fan"],
+ "546": ["n03272010", "electric_guitar"],
+ "547": ["n03272562", "electric_locomotive"],
+ "548": ["n03290653", "entertainment_center"],
+ "549": ["n03291819", "envelope"],
+ "550": ["n03297495", "espresso_maker"],
+ "551": ["n03314780", "face_powder"],
+ "552": ["n03325584", "feather_boa"],
+ "553": ["n03337140", "file"],
+ "554": ["n03344393", "fireboat"],
+ "555": ["n03345487", "fire_engine"],
+ "556": ["n03347037", "fire_screen"],
+ "557": ["n03355925", "flagpole"],
+ "558": ["n03372029", "flute"],
+ "559": ["n03376595", "folding_chair"],
+ "560": ["n03379051", "football_helmet"],
+ "561": ["n03384352", "forklift"],
+ "562": ["n03388043", "fountain"],
+ "563": ["n03388183", "fountain_pen"],
+ "564": ["n03388549", "four-poster"],
+ "565": ["n03393912", "freight_car"],
+ "566": ["n03394916", "French_horn"],
+ "567": ["n03400231", "frying_pan"],
+ "568": ["n03404251", "fur_coat"],
+ "569": ["n03417042", "garbage_truck"],
+ "570": ["n03424325", "gasmask"],
+ "571": ["n03425413", "gas_pump"],
+ "572": ["n03443371", "goblet"],
+ "573": ["n03444034", "go-kart"],
+ "574": ["n03445777", "golf_ball"],
+ "575": ["n03445924", "golfcart"],
+ "576": ["n03447447", "gondola"],
+ "577": ["n03447721", "gong"],
+ "578": ["n03450230", "gown"],
+ "579": ["n03452741", "grand_piano"],
+ "580": ["n03457902", "greenhouse"],
+ "581": ["n03459775", "grille"],
+ "582": ["n03461385", "grocery_store"],
+ "583": ["n03467068", "guillotine"],
+ "584": ["n03476684", "hair_slide"],
+ "585": ["n03476991", "hair_spray"],
+ "586": ["n03478589", "half_track"],
+ "587": ["n03481172", "hammer"],
+ "588": ["n03482405", "hamper"],
+ "589": ["n03483316", "hand_blower"],
+ "590": ["n03485407", "hand-held_computer"],
+ "591": ["n03485794", "handkerchief"],
+ "592": ["n03492542", "hard_disc"],
+ "593": ["n03494278", "harmonica"],
+ "594": ["n03495258", "harp"],
+ "595": ["n03496892", "harvester"],
+ "596": ["n03498962", "hatchet"],
+ "597": ["n03527444", "holster"],
+ "598": ["n03529860", "home_theater"],
+ "599": ["n03530642", "honeycomb"],
+ "600": ["n03532672", "hook"],
+ "601": ["n03534580", "hoopskirt"],
+ "602": ["n03535780", "horizontal_bar"],
+ "603": ["n03538406", "horse_cart"],
+ "604": ["n03544143", "hourglass"],
+ "605": ["n03584254", "iPod"],
+ "606": ["n03584829", "iron"],
+ "607": ["n03590841", "jack-o'-lantern"],
+ "608": ["n03594734", "jean"],
+ "609": ["n03594945", "jeep"],
+ "610": ["n03595614", "jersey"],
+ "611": ["n03598930", "jigsaw_puzzle"],
+ "612": ["n03599486", "jinrikisha"],
+ "613": ["n03602883", "joystick"],
+ "614": ["n03617480", "kimono"],
+ "615": ["n03623198", "knee_pad"],
+ "616": ["n03627232", "knot"],
+ "617": ["n03630383", "lab_coat"],
+ "618": ["n03633091", "ladle"],
+ "619": ["n03637318", "lampshade"],
+ "620": ["n03642806", "laptop"],
+ "621": ["n03649909", "lawn_mower"],
+ "622": ["n03657121", "lens_cap"],
+ "623": ["n03658185", "letter_opener"],
+ "624": ["n03661043", "library"],
+ "625": ["n03662601", "lifeboat"],
+ "626": ["n03666591", "lighter"],
+ "627": ["n03670208", "limousine"],
+ "628": ["n03673027", "liner"],
+ "629": ["n03676483", "lipstick"],
+ "630": ["n03680355", "Loafer"],
+ "631": ["n03690938", "lotion"],
+ "632": ["n03691459", "loudspeaker"],
+ "633": ["n03692522", "loupe"],
+ "634": ["n03697007", "lumbermill"],
+ "635": ["n03706229", "magnetic_compass"],
+ "636": ["n03709823", "mailbag"],
+ "637": ["n03710193", "mailbox"],
+ "638": ["n03710637", "maillot"],
+ "639": ["n03710721", "maillot"],
+ "640": ["n03717622", "manhole_cover"],
+ "641": ["n03720891", "maraca"],
+ "642": ["n03721384", "marimba"],
+ "643": ["n03724870", "mask"],
+ "644": ["n03729826", "matchstick"],
+ "645": ["n03733131", "maypole"],
+ "646": ["n03733281", "maze"],
+ "647": ["n03733805", "measuring_cup"],
+ "648": ["n03742115", "medicine_chest"],
+ "649": ["n03743016", "megalith"],
+ "650": ["n03759954", "microphone"],
+ "651": ["n03761084", "microwave"],
+ "652": ["n03763968", "military_uniform"],
+ "653": ["n03764736", "milk_can"],
+ "654": ["n03769881", "minibus"],
+ "655": ["n03770439", "miniskirt"],
+ "656": ["n03770679", "minivan"],
+ "657": ["n03773504", "missile"],
+ "658": ["n03775071", "mitten"],
+ "659": ["n03775546", "mixing_bowl"],
+ "660": ["n03776460", "mobile_home"],
+ "661": ["n03777568", "Model_T"],
+ "662": ["n03777754", "modem"],
+ "663": ["n03781244", "monastery"],
+ "664": ["n03782006", "monitor"],
+ "665": ["n03785016", "moped"],
+ "666": ["n03786901", "mortar"],
+ "667": ["n03787032", "mortarboard"],
+ "668": ["n03788195", "mosque"],
+ "669": ["n03788365", "mosquito_net"],
+ "670": ["n03791053", "motor_scooter"],
+ "671": ["n03792782", "mountain_bike"],
+ "672": ["n03792972", "mountain_tent"],
+ "673": ["n03793489", "mouse"],
+ "674": ["n03794056", "mousetrap"],
+ "675": ["n03796401", "moving_van"],
+ "676": ["n03803284", "muzzle"],
+ "677": ["n03804744", "nail"],
+ "678": ["n03814639", "neck_brace"],
+ "679": ["n03814906", "necklace"],
+ "680": ["n03825788", "nipple"],
+ "681": ["n03832673", "notebook"],
+ "682": ["n03837869", "obelisk"],
+ "683": ["n03838899", "oboe"],
+ "684": ["n03840681", "ocarina"],
+ "685": ["n03841143", "odometer"],
+ "686": ["n03843555", "oil_filter"],
+ "687": ["n03854065", "organ"],
+ "688": ["n03857828", "oscilloscope"],
+ "689": ["n03866082", "overskirt"],
+ "690": ["n03868242", "oxcart"],
+ "691": ["n03868863", "oxygen_mask"],
+ "692": ["n03871628", "packet"],
+ "693": ["n03873416", "paddle"],
+ "694": ["n03874293", "paddlewheel"],
+ "695": ["n03874599", "padlock"],
+ "696": ["n03876231", "paintbrush"],
+ "697": ["n03877472", "pajama"],
+ "698": ["n03877845", "palace"],
+ "699": ["n03884397", "panpipe"],
+ "700": ["n03887697", "paper_towel"],
+ "701": ["n03888257", "parachute"],
+ "702": ["n03888605", "parallel_bars"],
+ "703": ["n03891251", "park_bench"],
+ "704": ["n03891332", "parking_meter"],
+ "705": ["n03895866", "passenger_car"],
+ "706": ["n03899768", "patio"],
+ "707": ["n03902125", "pay-phone"],
+ "708": ["n03903868", "pedestal"],
+ "709": ["n03908618", "pencil_box"],
+ "710": ["n03908714", "pencil_sharpener"],
+ "711": ["n03916031", "perfume"],
+ "712": ["n03920288", "Petri_dish"],
+ "713": ["n03924679", "photocopier"],
+ "714": ["n03929660", "pick"],
+ "715": ["n03929855", "pickelhaube"],
+ "716": ["n03930313", "picket_fence"],
+ "717": ["n03930630", "pickup"],
+ "718": ["n03933933", "pier"],
+ "719": ["n03935335", "piggy_bank"],
+ "720": ["n03937543", "pill_bottle"],
+ "721": ["n03938244", "pillow"],
+ "722": ["n03942813", "ping-pong_ball"],
+ "723": ["n03944341", "pinwheel"],
+ "724": ["n03947888", "pirate"],
+ "725": ["n03950228", "pitcher"],
+ "726": ["n03954731", "plane"],
+ "727": ["n03956157", "planetarium"],
+ "728": ["n03958227", "plastic_bag"],
+ "729": ["n03961711", "plate_rack"],
+ "730": ["n03967562", "plow"],
+ "731": ["n03970156", "plunger"],
+ "732": ["n03976467", "Polaroid_camera"],
+ "733": ["n03976657", "pole"],
+ "734": ["n03977966", "police_van"],
+ "735": ["n03980874", "poncho"],
+ "736": ["n03982430", "pool_table"],
+ "737": ["n03983396", "pop_bottle"],
+ "738": ["n03991062", "pot"],
+ "739": ["n03992509", "potter's_wheel"],
+ "740": ["n03995372", "power_drill"],
+ "741": ["n03998194", "prayer_rug"],
+ "742": ["n04004767", "printer"],
+ "743": ["n04005630", "prison"],
+ "744": ["n04008634", "projectile"],
+ "745": ["n04009552", "projector"],
+ "746": ["n04019541", "puck"],
+ "747": ["n04023962", "punching_bag"],
+ "748": ["n04026417", "purse"],
+ "749": ["n04033901", "quill"],
+ "750": ["n04033995", "quilt"],
+ "751": ["n04037443", "racer"],
+ "752": ["n04039381", "racket"],
+ "753": ["n04040759", "radiator"],
+ "754": ["n04041544", "radio"],
+ "755": ["n04044716", "radio_telescope"],
+ "756": ["n04049303", "rain_barrel"],
+ "757": ["n04065272", "recreational_vehicle"],
+ "758": ["n04067472", "reel"],
+ "759": ["n04069434", "reflex_camera"],
+ "760": ["n04070727", "refrigerator"],
+ "761": ["n04074963", "remote_control"],
+ "762": ["n04081281", "restaurant"],
+ "763": ["n04086273", "revolver"],
+ "764": ["n04090263", "rifle"],
+ "765": ["n04099969", "rocking_chair"],
+ "766": ["n04111531", "rotisserie"],
+ "767": ["n04116512", "rubber_eraser"],
+ "768": ["n04118538", "rugby_ball"],
+ "769": ["n04118776", "rule"],
+ "770": ["n04120489", "running_shoe"],
+ "771": ["n04125021", "safe"],
+ "772": ["n04127249", "safety_pin"],
+ "773": ["n04131690", "saltshaker"],
+ "774": ["n04133789", "sandal"],
+ "775": ["n04136333", "sarong"],
+ "776": ["n04141076", "sax"],
+ "777": ["n04141327", "scabbard"],
+ "778": ["n04141975", "scale"],
+ "779": ["n04146614", "school_bus"],
+ "780": ["n04147183", "schooner"],
+ "781": ["n04149813", "scoreboard"],
+ "782": ["n04152593", "screen"],
+ "783": ["n04153751", "screw"],
+ "784": ["n04154565", "screwdriver"],
+ "785": ["n04162706", "seat_belt"],
+ "786": ["n04179913", "sewing_machine"],
+ "787": ["n04192698", "shield"],
+ "788": ["n04200800", "shoe_shop"],
+ "789": ["n04201297", "shoji"],
+ "790": ["n04204238", "shopping_basket"],
+ "791": ["n04204347", "shopping_cart"],
+ "792": ["n04208210", "shovel"],
+ "793": ["n04209133", "shower_cap"],
+ "794": ["n04209239", "shower_curtain"],
+ "795": ["n04228054", "ski"],
+ "796": ["n04229816", "ski_mask"],
+ "797": ["n04235860", "sleeping_bag"],
+ "798": ["n04238763", "slide_rule"],
+ "799": ["n04239074", "sliding_door"],
+ "800": ["n04243546", "slot"],
+ "801": ["n04251144", "snorkel"],
+ "802": ["n04252077", "snowmobile"],
+ "803": ["n04252225", "snowplow"],
+ "804": ["n04254120", "soap_dispenser"],
+ "805": ["n04254680", "soccer_ball"],
+ "806": ["n04254777", "sock"],
+ "807": ["n04258138", "solar_dish"],
+ "808": ["n04259630", "sombrero"],
+ "809": ["n04263257", "soup_bowl"],
+ "810": ["n04264628", "space_bar"],
+ "811": ["n04265275", "space_heater"],
+ "812": ["n04266014", "space_shuttle"],
+ "813": ["n04270147", "spatula"],
+ "814": ["n04273569", "speedboat"],
+ "815": ["n04275548", "spider_web"],
+ "816": ["n04277352", "spindle"],
+ "817": ["n04285008", "sports_car"],
+ "818": ["n04286575", "spotlight"],
+ "819": ["n04296562", "stage"],
+ "820": ["n04310018", "steam_locomotive"],
+ "821": ["n04311004", "steel_arch_bridge"],
+ "822": ["n04311174", "steel_drum"],
+ "823": ["n04317175", "stethoscope"],
+ "824": ["n04325704", "stole"],
+ "825": ["n04326547", "stone_wall"],
+ "826": ["n04328186", "stopwatch"],
+ "827": ["n04330267", "stove"],
+ "828": ["n04332243", "strainer"],
+ "829": ["n04335435", "streetcar"],
+ "830": ["n04336792", "stretcher"],
+ "831": ["n04344873", "studio_couch"],
+ "832": ["n04346328", "stupa"],
+ "833": ["n04347754", "submarine"],
+ "834": ["n04350905", "suit"],
+ "835": ["n04355338", "sundial"],
+ "836": ["n04355933", "sunglass"],
+ "837": ["n04356056", "sunglasses"],
+ "838": ["n04357314", "sunscreen"],
+ "839": ["n04366367", "suspension_bridge"],
+ "840": ["n04367480", "swab"],
+ "841": ["n04370456", "sweatshirt"],
+ "842": ["n04371430", "swimming_trunks"],
+ "843": ["n04371774", "swing"],
+ "844": ["n04372370", "switch"],
+ "845": ["n04376876", "syringe"],
+ "846": ["n04380533", "table_lamp"],
+ "847": ["n04389033", "tank"],
+ "848": ["n04392985", "tape_player"],
+ "849": ["n04398044", "teapot"],
+ "850": ["n04399382", "teddy"],
+ "851": ["n04404412", "television"],
+ "852": ["n04409515", "tennis_ball"],
+ "853": ["n04417672", "thatch"],
+ "854": ["n04418357", "theater_curtain"],
+ "855": ["n04423845", "thimble"],
+ "856": ["n04428191", "thresher"],
+ "857": ["n04429376", "throne"],
+ "858": ["n04435653", "tile_roof"],
+ "859": ["n04442312", "toaster"],
+ "860": ["n04443257", "tobacco_shop"],
+ "861": ["n04447861", "toilet_seat"],
+ "862": ["n04456115", "torch"],
+ "863": ["n04458633", "totem_pole"],
+ "864": ["n04461696", "tow_truck"],
+ "865": ["n04462240", "toyshop"],
+ "866": ["n04465501", "tractor"],
+ "867": ["n04467665", "trailer_truck"],
+ "868": ["n04476259", "tray"],
+ "869": ["n04479046", "trench_coat"],
+ "870": ["n04482393", "tricycle"],
+ "871": ["n04483307", "trimaran"],
+ "872": ["n04485082", "tripod"],
+ "873": ["n04486054", "triumphal_arch"],
+ "874": ["n04487081", "trolleybus"],
+ "875": ["n04487394", "trombone"],
+ "876": ["n04493381", "tub"],
+ "877": ["n04501370", "turnstile"],
+ "878": ["n04505470", "typewriter_keyboard"],
+ "879": ["n04507155", "umbrella"],
+ "880": ["n04509417", "unicycle"],
+ "881": ["n04515003", "upright"],
+ "882": ["n04517823", "vacuum"],
+ "883": ["n04522168", "vase"],
+ "884": ["n04523525", "vault"],
+ "885": ["n04525038", "velvet"],
+ "886": ["n04525305", "vending_machine"],
+ "887": ["n04532106", "vestment"],
+ "888": ["n04532670", "viaduct"],
+ "889": ["n04536866", "violin"],
+ "890": ["n04540053", "volleyball"],
+ "891": ["n04542943", "waffle_iron"],
+ "892": ["n04548280", "wall_clock"],
+ "893": ["n04548362", "wallet"],
+ "894": ["n04550184", "wardrobe"],
+ "895": ["n04552348", "warplane"],
+ "896": ["n04553703", "washbasin"],
+ "897": ["n04554684", "washer"],
+ "898": ["n04557648", "water_bottle"],
+ "899": ["n04560804", "water_jug"],
+ "900": ["n04562935", "water_tower"],
+ "901": ["n04579145", "whiskey_jug"],
+ "902": ["n04579432", "whistle"],
+ "903": ["n04584207", "wig"],
+ "904": ["n04589890", "window_screen"],
+ "905": ["n04590129", "window_shade"],
+ "906": ["n04591157", "Windsor_tie"],
+ "907": ["n04591713", "wine_bottle"],
+ "908": ["n04592741", "wing"],
+ "909": ["n04596742", "wok"],
+ "910": ["n04597913", "wooden_spoon"],
+ "911": ["n04599235", "wool"],
+ "912": ["n04604644", "worm_fence"],
+ "913": ["n04606251", "wreck"],
+ "914": ["n04612504", "yawl"],
+ "915": ["n04613696", "yurt"],
+ "916": ["n06359193", "web_site"],
+ "917": ["n06596364", "comic_book"],
+ "918": ["n06785654", "crossword_puzzle"],
+ "919": ["n06794110", "street_sign"],
+ "920": ["n06874185", "traffic_light"],
+ "921": ["n07248320", "book_jacket"],
+ "922": ["n07565083", "menu"],
+ "923": ["n07579787", "plate"],
+ "924": ["n07583066", "guacamole"],
+ "925": ["n07584110", "consomme"],
+ "926": ["n07590611", "hot_pot"],
+ "927": ["n07613480", "trifle"],
+ "928": ["n07614500", "ice_cream"],
+ "929": ["n07615774", "ice_lolly"],
+ "930": ["n07684084", "French_loaf"],
+ "931": ["n07693725", "bagel"],
+ "932": ["n07695742", "pretzel"],
+ "933": ["n07697313", "cheeseburger"],
+ "934": ["n07697537", "hotdog"],
+ "935": ["n07711569", "mashed_potato"],
+ "936": ["n07714571", "head_cabbage"],
+ "937": ["n07714990", "broccoli"],
+ "938": ["n07715103", "cauliflower"],
+ "939": ["n07716358", "zucchini"],
+ "940": ["n07716906", "spaghetti_squash"],
+ "941": ["n07717410", "acorn_squash"],
+ "942": ["n07717556", "butternut_squash"],
+ "943": ["n07718472", "cucumber"],
+ "944": ["n07718747", "artichoke"],
+ "945": ["n07720875", "bell_pepper"],
+ "946": ["n07730033", "cardoon"],
+ "947": ["n07734744", "mushroom"],
+ "948": ["n07742313", "Granny_Smith"],
+ "949": ["n07745940", "strawberry"],
+ "950": ["n07747607", "orange"],
+ "951": ["n07749582", "lemon"],
+ "952": ["n07753113", "fig"],
+ "953": ["n07753275", "pineapple"],
+ "954": ["n07753592", "banana"],
+ "955": ["n07754684", "jackfruit"],
+ "956": ["n07760859", "custard_apple"],
+ "957": ["n07768694", "pomegranate"],
+ "958": ["n07802026", "hay"],
+ "959": ["n07831146", "carbonara"],
+ "960": ["n07836838", "chocolate_sauce"],
+ "961": ["n07860988", "dough"],
+ "962": ["n07871810", "meat_loaf"],
+ "963": ["n07873807", "pizza"],
+ "964": ["n07875152", "potpie"],
+ "965": ["n07880968", "burrito"],
+ "966": ["n07892512", "red_wine"],
+ "967": ["n07920052", "espresso"],
+ "968": ["n07930864", "cup"],
+ "969": ["n07932039", "eggnog"],
+ "970": ["n09193705", "alp"],
+ "971": ["n09229709", "bubble"],
+ "972": ["n09246464", "cliff"],
+ "973": ["n09256479", "coral_reef"],
+ "974": ["n09288635", "geyser"],
+ "975": ["n09332890", "lakeside"],
+ "976": ["n09399592", "promontory"],
+ "977": ["n09421951", "sandbar"],
+ "978": ["n09428293", "seashore"],
+ "979": ["n09468604", "valley"],
+ "980": ["n09472597", "volcano"],
+ "981": ["n09835506", "ballplayer"],
+ "982": ["n10148035", "groom"],
+ "983": ["n10565667", "scuba_diver"],
+ "984": ["n11879895", "rapeseed"],
+ "985": ["n11939491", "daisy"],
+ "986": ["n12057211", "yellow_lady's_slipper"],
+ "987": ["n12144580", "corn"],
+ "988": ["n12267677", "acorn"],
+ "989": ["n12620546", "hip"],
+ "990": ["n12768682", "buckeye"],
+ "991": ["n12985857", "coral_fungus"],
+ "992": ["n12998815", "agaric"],
+ "993": ["n13037406", "gyromitra"],
+ "994": ["n13040303", "stinkhorn"],
+ "995": ["n13044778", "earthstar"],
+ "996": ["n13052670", "hen-of-the-woods"],
+ "997": ["n13054560", "bolete"],
+ "998": ["n13133613", "ear"],
+ "999": ["n15075141", "toilet_tissue"]
+}
\ No newline at end of file
diff --git a/images/cat.png b/images/cat.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c869201dbfa331abf49c21435186d96d8a9f18b
Binary files /dev/null and b/images/cat.png differ
diff --git a/images/dog.png b/images/dog.png
new file mode 100644
index 0000000000000000000000000000000000000000..a59f1ca272f13e2df0453871311fc84ab0d5da24
Binary files /dev/null and b/images/dog.png differ
diff --git a/images/panda.png b/images/panda.png
new file mode 100644
index 0000000000000000000000000000000000000000..b43f4b5c1cc135f45f4056ef8912fdc4ce5b2571
Binary files /dev/null and b/images/panda.png differ
diff --git a/install.sh b/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..034b4ecad14b2670f72e8b4fe52c6acf57686489
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,2 @@
+pip install -e causal-conv1d
+pip install -e mamba
\ No newline at end of file
diff --git a/kinetics_class_index.py b/kinetics_class_index.py
new file mode 100644
index 0000000000000000000000000000000000000000..597e23e72c690f2dce0525b24bdcc2a992c4d594
--- /dev/null
+++ b/kinetics_class_index.py
@@ -0,0 +1,402 @@
+kinetics_classnames = {
+ "0": "riding a bike",
+ "1": "marching",
+ "2": "dodgeball",
+ "3": "playing cymbals",
+ "4": "checking tires",
+ "5": "roller skating",
+ "6": "tasting beer",
+ "7": "clapping",
+ "8": "drawing",
+ "9": "juggling fire",
+ "10": "bobsledding",
+ "11": "petting animal (not cat)",
+ "12": "spray painting",
+ "13": "training dog",
+ "14": "eating watermelon",
+ "15": "building cabinet",
+ "16": "applauding",
+ "17": "playing harp",
+ "18": "balloon blowing",
+ "19": "sled dog racing",
+ "20": "wrestling",
+ "21": "pole vault",
+ "22": "hurling (sport)",
+ "23": "riding scooter",
+ "24": "shearing sheep",
+ "25": "sweeping floor",
+ "26": "eating carrots",
+ "27": "skateboarding",
+ "28": "dunking basketball",
+ "29": "disc golfing",
+ "30": "eating spaghetti",
+ "31": "playing flute",
+ "32": "riding mechanical bull",
+ "33": "making sushi",
+ "34": "trapezing",
+ "35": "picking fruit",
+ "36": "stretching leg",
+ "37": "playing ukulele",
+ "38": "tying tie",
+ "39": "skydiving",
+ "40": "playing cello",
+ "41": "jumping into pool",
+ "42": "shooting goal (soccer)",
+ "43": "trimming trees",
+ "44": "bookbinding",
+ "45": "ski jumping",
+ "46": "walking the dog",
+ "47": "riding unicycle",
+ "48": "shaving head",
+ "49": "hopscotch",
+ "50": "playing piano",
+ "51": "parasailing",
+ "52": "bartending",
+ "53": "kicking field goal",
+ "54": "finger snapping",
+ "55": "dining",
+ "56": "yawning",
+ "57": "peeling potatoes",
+ "58": "canoeing or kayaking",
+ "59": "front raises",
+ "60": "laughing",
+ "61": "dancing macarena",
+ "62": "digging",
+ "63": "reading newspaper",
+ "64": "hitting baseball",
+ "65": "clay pottery making",
+ "66": "exercising with an exercise ball",
+ "67": "playing saxophone",
+ "68": "shooting basketball",
+ "69": "washing hair",
+ "70": "lunge",
+ "71": "brushing hair",
+ "72": "curling hair",
+ "73": "kitesurfing",
+ "74": "tapping guitar",
+ "75": "bending back",
+ "76": "skipping rope",
+ "77": "situp",
+ "78": "folding paper",
+ "79": "cracking neck",
+ "80": "assembling computer",
+ "81": "cleaning gutters",
+ "82": "blowing out candles",
+ "83": "shaking hands",
+ "84": "dancing gangnam style",
+ "85": "windsurfing",
+ "86": "tap dancing",
+ "87": "skiing (not slalom or crosscountry)",
+ "88": "bandaging",
+ "89": "push up",
+ "90": "doing nails",
+ "91": "punching person (boxing)",
+ "92": "bouncing on trampoline",
+ "93": "scrambling eggs",
+ "94": "singing",
+ "95": "cleaning floor",
+ "96": "krumping",
+ "97": "drumming fingers",
+ "98": "snowmobiling",
+ "99": "gymnastics tumbling",
+ "100": "headbanging",
+ "101": "catching or throwing frisbee",
+ "102": "riding elephant",
+ "103": "bee keeping",
+ "104": "feeding birds",
+ "105": "snatch weight lifting",
+ "106": "mowing lawn",
+ "107": "fixing hair",
+ "108": "playing trumpet",
+ "109": "flying kite",
+ "110": "crossing river",
+ "111": "swinging legs",
+ "112": "sanding floor",
+ "113": "belly dancing",
+ "114": "sneezing",
+ "115": "clean and jerk",
+ "116": "side kick",
+ "117": "filling eyebrows",
+ "118": "shuffling cards",
+ "119": "recording music",
+ "120": "cartwheeling",
+ "121": "feeding fish",
+ "122": "folding clothes",
+ "123": "water skiing",
+ "124": "tobogganing",
+ "125": "blowing leaves",
+ "126": "smoking",
+ "127": "unboxing",
+ "128": "tai chi",
+ "129": "waxing legs",
+ "130": "riding camel",
+ "131": "slapping",
+ "132": "tossing salad",
+ "133": "capoeira",
+ "134": "playing cards",
+ "135": "playing organ",
+ "136": "playing violin",
+ "137": "playing drums",
+ "138": "tapping pen",
+ "139": "vault",
+ "140": "shoveling snow",
+ "141": "playing tennis",
+ "142": "getting a tattoo",
+ "143": "making a sandwich",
+ "144": "making tea",
+ "145": "grinding meat",
+ "146": "squat",
+ "147": "eating doughnuts",
+ "148": "ice fishing",
+ "149": "snowkiting",
+ "150": "kicking soccer ball",
+ "151": "playing controller",
+ "152": "giving or receiving award",
+ "153": "welding",
+ "154": "throwing discus",
+ "155": "throwing axe",
+ "156": "ripping paper",
+ "157": "swimming butterfly stroke",
+ "158": "air drumming",
+ "159": "blowing nose",
+ "160": "hockey stop",
+ "161": "taking a shower",
+ "162": "bench pressing",
+ "163": "planting trees",
+ "164": "pumping fist",
+ "165": "climbing tree",
+ "166": "tickling",
+ "167": "high kick",
+ "168": "waiting in line",
+ "169": "slacklining",
+ "170": "tango dancing",
+ "171": "hurdling",
+ "172": "carrying baby",
+ "173": "celebrating",
+ "174": "sharpening knives",
+ "175": "passing American football (in game)",
+ "176": "headbutting",
+ "177": "playing recorder",
+ "178": "brush painting",
+ "179": "garbage collecting",
+ "180": "robot dancing",
+ "181": "shredding paper",
+ "182": "pumping gas",
+ "183": "rock climbing",
+ "184": "hula hooping",
+ "185": "braiding hair",
+ "186": "opening present",
+ "187": "texting",
+ "188": "decorating the christmas tree",
+ "189": "answering questions",
+ "190": "playing keyboard",
+ "191": "writing",
+ "192": "bungee jumping",
+ "193": "sniffing",
+ "194": "eating burger",
+ "195": "playing accordion",
+ "196": "making pizza",
+ "197": "playing volleyball",
+ "198": "tasting food",
+ "199": "pushing cart",
+ "200": "spinning poi",
+ "201": "cleaning windows",
+ "202": "arm wrestling",
+ "203": "changing oil",
+ "204": "swimming breast stroke",
+ "205": "tossing coin",
+ "206": "deadlifting",
+ "207": "hoverboarding",
+ "208": "cutting watermelon",
+ "209": "cheerleading",
+ "210": "snorkeling",
+ "211": "washing hands",
+ "212": "eating cake",
+ "213": "pull ups",
+ "214": "surfing water",
+ "215": "eating hotdog",
+ "216": "holding snake",
+ "217": "playing harmonica",
+ "218": "ironing",
+ "219": "cutting nails",
+ "220": "golf chipping",
+ "221": "shot put",
+ "222": "hugging",
+ "223": "playing clarinet",
+ "224": "faceplanting",
+ "225": "trimming or shaving beard",
+ "226": "drinking shots",
+ "227": "riding mountain bike",
+ "228": "tying bow tie",
+ "229": "swinging on something",
+ "230": "skiing crosscountry",
+ "231": "unloading truck",
+ "232": "cleaning pool",
+ "233": "jogging",
+ "234": "ice climbing",
+ "235": "mopping floor",
+ "236": "making bed",
+ "237": "diving cliff",
+ "238": "washing dishes",
+ "239": "grooming dog",
+ "240": "weaving basket",
+ "241": "frying vegetables",
+ "242": "stomping grapes",
+ "243": "moving furniture",
+ "244": "cooking sausages",
+ "245": "doing laundry",
+ "246": "dying hair",
+ "247": "knitting",
+ "248": "reading book",
+ "249": "baby waking up",
+ "250": "punching bag",
+ "251": "surfing crowd",
+ "252": "cooking chicken",
+ "253": "pushing car",
+ "254": "springboard diving",
+ "255": "swing dancing",
+ "256": "massaging legs",
+ "257": "beatboxing",
+ "258": "breading or breadcrumbing",
+ "259": "somersaulting",
+ "260": "brushing teeth",
+ "261": "stretching arm",
+ "262": "juggling balls",
+ "263": "massaging person's head",
+ "264": "eating ice cream",
+ "265": "extinguishing fire",
+ "266": "hammer throw",
+ "267": "whistling",
+ "268": "crawling baby",
+ "269": "using remote controller (not gaming)",
+ "270": "playing cricket",
+ "271": "opening bottle",
+ "272": "playing xylophone",
+ "273": "motorcycling",
+ "274": "driving car",
+ "275": "exercising arm",
+ "276": "passing American football (not in game)",
+ "277": "playing kickball",
+ "278": "sticking tongue out",
+ "279": "flipping pancake",
+ "280": "catching fish",
+ "281": "eating chips",
+ "282": "shaking head",
+ "283": "sword fighting",
+ "284": "playing poker",
+ "285": "cooking on campfire",
+ "286": "doing aerobics",
+ "287": "paragliding",
+ "288": "using segway",
+ "289": "folding napkins",
+ "290": "playing bagpipes",
+ "291": "gargling",
+ "292": "skiing slalom",
+ "293": "strumming guitar",
+ "294": "javelin throw",
+ "295": "waxing back",
+ "296": "riding or walking with horse",
+ "297": "plastering",
+ "298": "long jump",
+ "299": "parkour",
+ "300": "wrapping present",
+ "301": "egg hunting",
+ "302": "archery",
+ "303": "cleaning toilet",
+ "304": "swimming backstroke",
+ "305": "snowboarding",
+ "306": "catching or throwing baseball",
+ "307": "massaging back",
+ "308": "blowing glass",
+ "309": "playing guitar",
+ "310": "playing chess",
+ "311": "golf driving",
+ "312": "presenting weather forecast",
+ "313": "rock scissors paper",
+ "314": "high jump",
+ "315": "baking cookies",
+ "316": "using computer",
+ "317": "washing feet",
+ "318": "arranging flowers",
+ "319": "playing bass guitar",
+ "320": "spraying",
+ "321": "cutting pineapple",
+ "322": "waxing chest",
+ "323": "auctioning",
+ "324": "jetskiing",
+ "325": "drinking",
+ "326": "busking",
+ "327": "playing monopoly",
+ "328": "salsa dancing",
+ "329": "waxing eyebrows",
+ "330": "watering plants",
+ "331": "zumba",
+ "332": "chopping wood",
+ "333": "pushing wheelchair",
+ "334": "carving pumpkin",
+ "335": "building shed",
+ "336": "making jewelry",
+ "337": "catching or throwing softball",
+ "338": "bending metal",
+ "339": "ice skating",
+ "340": "dancing charleston",
+ "341": "abseiling",
+ "342": "climbing a rope",
+ "343": "crying",
+ "344": "cleaning shoes",
+ "345": "dancing ballet",
+ "346": "driving tractor",
+ "347": "triple jump",
+ "348": "throwing ball",
+ "349": "getting a haircut",
+ "350": "running on treadmill",
+ "351": "climbing ladder",
+ "352": "blasting sand",
+ "353": "playing trombone",
+ "354": "drop kicking",
+ "355": "country line dancing",
+ "356": "changing wheel",
+ "357": "feeding goats",
+ "358": "tying knot (not on a tie)",
+ "359": "setting table",
+ "360": "shaving legs",
+ "361": "kissing",
+ "362": "riding mule",
+ "363": "counting money",
+ "364": "laying bricks",
+ "365": "barbequing",
+ "366": "news anchoring",
+ "367": "smoking hookah",
+ "368": "cooking egg",
+ "369": "peeling apples",
+ "370": "yoga",
+ "371": "sharpening pencil",
+ "372": "dribbling basketball",
+ "373": "petting cat",
+ "374": "playing ice hockey",
+ "375": "milking cow",
+ "376": "shining shoes",
+ "377": "juggling soccer ball",
+ "378": "scuba diving",
+ "379": "playing squash or racquetball",
+ "380": "drinking beer",
+ "381": "sign language interpreting",
+ "382": "playing basketball",
+ "383": "breakdancing",
+ "384": "testifying",
+ "385": "making snowman",
+ "386": "golf putting",
+ "387": "playing didgeridoo",
+ "388": "biking through snow",
+ "389": "sailing",
+ "390": "jumpstyle dancing",
+ "391": "water sliding",
+ "392": "grooming horse",
+ "393": "massaging feet",
+ "394": "playing paintball",
+ "395": "making a cake",
+ "396": "bowling",
+ "397": "contact juggling",
+ "398": "applying cream",
+ "399": "playing badminton"
+}
\ No newline at end of file
diff --git a/mamba/.gitmodules b/mamba/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..a7445800fb64f3ae664c0b994a54235105986d2e
--- /dev/null
+++ b/mamba/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "3rdparty/lm-evaluation-harness"]
+ path = 3rdparty/lm-evaluation-harness
+ url = https://github.com/EleutherAI/lm-evaluation-harness/
diff --git a/mamba/AUTHORS b/mamba/AUTHORS
new file mode 100644
index 0000000000000000000000000000000000000000..38557a872f8d603ed963a05c211de7032de5926b
--- /dev/null
+++ b/mamba/AUTHORS
@@ -0,0 +1,2 @@
+Tri Dao, tri@tridao.me
+Albert Gu, agu@andrew.cmu.edu
diff --git a/mamba/LICENSE b/mamba/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f4abe24eb520fbb077753ae4f34bfaa43cb3b83f
--- /dev/null
+++ b/mamba/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2023 Tri Dao, Albert Gu
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/mamba/README.md b/mamba/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..754cefd7f862a90bad8fbdff71e3793a4e7849e3
--- /dev/null
+++ b/mamba/README.md
@@ -0,0 +1,149 @@
+# Mamba
+
+![Mamba](assets/selection.png "Selective State Space")
+> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
+> Albert Gu*, Tri Dao*\
+> Paper: https://arxiv.org/abs/2312.00752
+
+## About
+
+Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
+It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
+with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
+
+## Installation
+
+- `pip install causal-conv1d`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
+- `pip install mamba-ssm`: the core Mamba package.
+
+It can also be built from source with `pip install .` from this repository.
+
+If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
+
+Other requirements:
+- Linux
+- NVIDIA GPU
+- PyTorch 1.12+
+- CUDA 11.6+
+
+## Usage
+
+We expose several levels of interface with the Mamba model.
+
+### Selective SSM
+
+Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
+
+Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
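+
+For intuition, a minimal (and deliberately slow) pure-PyTorch sketch of the selective scan recurrence is shown below. The fused CUDA kernel in `selective_scan_interface.py` computes the same recurrence far more efficiently; the helper name and exact shape conventions here are ours, for illustration only:
+```
+import torch
+
+def selective_scan_sketch(u, delta, A, B, C, D):
+    # u, delta: (batch, dim, len); A: (dim, state); B, C: (batch, state, len); D: (dim,)
+    batch, dim, length = u.shape
+    h = u.new_zeros(batch, dim, A.shape[1])  # hidden state, h_0 = 0
+    ys = []
+    for t in range(length):
+        dA = torch.exp(delta[:, :, t, None] * A)   # input-dependent discretization of A
+        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
+        h = dA * h + dBu                            # h_t = exp(delta_t A) h_{t-1} + delta_t B_t u_t
+        ys.append((h * C[:, None, :, t]).sum(-1))   # y_t = <C_t, h_t>
+    return torch.stack(ys, dim=-1) + D[None, :, None] * u  # plus skip connection D * u
+```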
+
+### Mamba Block
+
+The main module of this repository is the Mamba architecture block wrapping the selective SSM.
+
+Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
+
+Usage:
+```
+import torch
+
+from mamba_ssm import Mamba
+
+batch, length, dim = 2, 64, 16
+x = torch.randn(batch, length, dim).to("cuda")
+model = Mamba(
+ # This module uses roughly 3 * expand * d_model^2 parameters
+ d_model=dim, # Model dimension d_model
+ d_state=16, # SSM state expansion factor
+ d_conv=4, # Local convolution width
+ expand=2, # Block expansion factor
+).to("cuda")
+y = model(x)
+assert y.shape == x.shape
+```
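+
+(The parameter estimate in the comment above comes from the block's large projections: the input projection maps `d_model -> 2 * expand * d_model` for the main and gating branches, and the output projection maps `expand * d_model` back to `d_model`, so the weights total roughly `3 * expand * d_model^2`; the short convolution and SSM parameters add comparatively little.)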
+
+### Mamba Language Model
+
+Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
+
+Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
+
+This is an example of how to integrate Mamba into an end-to-end neural network.
+This example is used in the generation scripts below.
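+
+A minimal generation sketch using this class (it mirrors how the benchmark script below calls `from_pretrained` and `generate`; the model name and prompt are placeholders, and the GPT-NeoX tokenizer is the one used by the pretrained models):
+```
+import torch
+from transformers import AutoTokenizer
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+device, dtype = "cuda", torch.float16
+tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
+model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device=device, dtype=dtype).eval()
+
+input_ids = tokenizer("Mamba is", return_tensors="pt").input_ids.to(device)
+out = model.generate(
+    input_ids=input_ids,
+    max_length=input_ids.shape[1] + 50,
+    cg=True,  # capture the decoding step in a CUDA graph, as in the benchmark script
+    return_dict_in_generate=True,
+    output_scores=True,
+    temperature=1.0,
+    top_k=1,
+    top_p=1.0,
+)
+print(tokenizer.batch_decode(out.sequences.tolist()))
+```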
+
+
+
+## Pretrained Models
+
+Pretrained models are uploaded to
+[HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
+`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`.
+
+The models will be autodownloaded by the generation script below.
+
+These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
+
+| Parameters | Layers | Model dim. |
+|------------|--------|------------|
+| 130M | 12 | 768 |
+| 370M | 24 | 1024 |
+| 790M | 24 | 1536 |
+| 1.4B | 24 | 2048 |
+| 2.8B | 32 | 2560 |
+
+(The layer counts above follow the Transformer convention; Mamba's actual block count is double that, since two Mamba blocks stand in for each Transformer "layer" (an MHA block plus an MLP block).)
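+
+As a rough sanity check of the table, combine the `3 * expand * d_model^2` per-block estimate from the usage snippet above with a ~50k GPT-NeoX vocabulary (the exact padded vocabulary size is an assumption here; the embedding is tied with the LM head):
+```
+d_model, n_blocks, vocab = 768, 24, 50280       # 130M row: 12 "layers" -> 24 Mamba blocks
+per_block = 3 * 2 * d_model ** 2                # expand = 2
+total = n_blocks * per_block + vocab * d_model  # backbone + tied embedding
+print(f"~{total / 1e6:.0f}M")                   # ~124M, close to the nominal 130M
+```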
+
+Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
+Performance is expected to be comparable to, or better than, other architectures trained on similar data, but not to match larger or fine-tuned models.
+
+
+## Evaluations
+
+To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
+we use the
+[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor)
+library.
+
+1. Pull the `lm-evaluation-harness` repo with `git submodule update --init
+ --recursive`. We use the `big-refactor` branch.
+2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`
+3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
+```
+python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
+python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
+```
+
+Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
+
+## Inference
+
+The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
+1. autoloads a model from the HuggingFace Hub,
+2. generates completions of a user-specified prompt,
+3. benchmarks the inference speed of this generation.
+
+Other configurable options include the top-p (nucleus sampling) probability and the softmax temperature.
+
+### Examples
+
+To test generation latency (e.g. batch size = 1) with different sampling strategies:
+
+```
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
+```
+
+To test generation throughput with random prompts (e.g. large batch size):
+```
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
+```
+
+## Citation
+
+If you use this codebase, or otherwise find our work valuable, please cite Mamba:
+```
+@article{mamba,
+ title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
+ author={Gu, Albert and Dao, Tri},
+ journal={arXiv preprint arXiv:2312.00752},
+ year={2023}
+}
+```
diff --git a/mamba/assets/selection.png b/mamba/assets/selection.png
new file mode 100644
index 0000000000000000000000000000000000000000..69b109a8eed4e3c7516b23e2b39d37e842a4464b
Binary files /dev/null and b/mamba/assets/selection.png differ
diff --git a/mamba/benchmarks/benchmark_generation_mamba_simple.py b/mamba/benchmarks/benchmark_generation_mamba_simple.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f2943cb4bde6f25eddb82b7b999c5c5f8b39acc
--- /dev/null
+++ b/mamba/benchmarks/benchmark_generation_mamba_simple.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2023, Tri Dao, Albert Gu.
+
+import argparse
+import time
+
+import torch
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+
+parser = argparse.ArgumentParser(description="Generation benchmarking")
+parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
+parser.add_argument("--prompt", type=str, default=None)
+parser.add_argument("--promptlen", type=int, default=100)
+parser.add_argument("--genlen", type=int, default=100)
+parser.add_argument("--temperature", type=float, default=1.0)
+parser.add_argument("--topk", type=int, default=1)
+parser.add_argument("--topp", type=float, default=1.0)
+parser.add_argument("--batch", type=int, default=1)
+args = parser.parse_args()
+
+repeats = 3
+device = "cuda"
+dtype = torch.float16
+
+print(f"Loading model {args.model_name}")
+is_mamba = "mamba" in args.model_name  # also matches "state-spaces/mamba-*"
+
+if is_mamba:
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # GPT-NeoX tokenizer used by the pretrained Mamba models
+ model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
+else:
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+ model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
+model.eval()
+print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+
+torch.random.manual_seed(0)
+if args.prompt is None:
+ input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
+ attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
+else:
+ tokens = tokenizer(args.prompt, return_tensors="pt")
+ input_ids = tokens.input_ids.to(device=device)
+ attn_mask = tokens.attention_mask.to(device=device)
+max_length = input_ids.shape[1] + args.genlen
+
+if is_mamba:
+ fn = lambda: model.generate(
+ input_ids=input_ids,
+ max_length=max_length,
+ cg=True,
+ return_dict_in_generate=True,
+ output_scores=True,
+ enable_timing=False,
+ temperature=args.temperature,
+ top_k=args.topk,
+ top_p=args.topp,
+ )
+else:
+ fn = lambda: model.generate(
+ input_ids=input_ids,
+ attention_mask=attn_mask,
+ max_length=max_length,
+ return_dict_in_generate=True,
+ pad_token_id=tokenizer.eos_token_id,
+ do_sample=True,
+ temperature=args.temperature,
+ top_k=args.topk,
+ top_p=args.topp,
+ )
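+
+# Warm-up call, not timed; its output is also used to print the completion and report lengths below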
+out = fn()
+if args.prompt is not None:
+ print(tokenizer.batch_decode(out.sequences.tolist()))
+
+torch.cuda.synchronize()
+start = time.time()
+for _ in range(repeats):
+ fn()
+torch.cuda.synchronize()
+print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
+print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
diff --git a/mamba/csrc/selective_scan/reverse_scan.cuh b/mamba/csrc/selective_scan/reverse_scan.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..d7e93174bb391d45271e6c77669a5e52d6c9cc78
--- /dev/null
+++ b/mamba/csrc/selective_scan/reverse_scan.cuh
@@ -0,0 +1,401 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include
+
+#include