# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


def gen_forward():
    # Supported filter sizes and the sequence-length buckets (multiples of 32)
    # for which specialized kernel instantiations are emitted.
    kernels = [3, 5, 7, 15, 31, 63, 127, 255]
    seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]]

    head = """
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "lightconv_cuda.cuh"

std::vector<at::Tensor> lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l) {

    at::DeviceGuard g(input.device());
    const auto minibatch = input.size(0);
    const auto numFeatures = input.size(1);
    const auto sequenceLength = input.size(2);

    const auto numHeads = filters.size(0);
    const auto filterSize = filters.size(1);

    const auto numFiltersInBlock = numFeatures / numHeads;

    const dim3 blocks(minibatch, numFeatures);

    auto output = at::zeros_like(input);
    auto stream = at::cuda::getCurrentCUDAStream();
"""

    sequence_if = """
    if (sequenceLength <= {seq}) {{
        switch(filterSize) {{
"""

    case_k = """
            case {k}:
"""

    main_block = """
                if (padding_l == {pad}) {{
                    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_forward", ([&] {{
                        lightconv_forward_kernel<{k}, {b_size}, {pad}, scalar_t>
                        <<<blocks, {b_size}, 0, stream>>>(
                                input.data<scalar_t>(),
                                filters.data<scalar_t>(),
                                minibatch,
                                sequenceLength,
                                numFeatures,
                                numFiltersInBlock,
                                output.data<scalar_t>());
                    }}));
                }} else
"""

    bad_padding = """
                {
                    std::cout << "WARNING: Unsupported padding size - skipping forward pass" << std::endl;
                }
                break;
"""

    bad_filter = """
            default:
                std::cout << "WARNING: Unsupported filter length passed - skipping forward pass" << std::endl;
        }
"""

    con_else = """
    } else
"""

    final_else = """
    {
        switch(filterSize) {
"""

    final_return = """
    }

    return {output};
}
"""

    with open("lightconv_cuda_forward.cu", "w") as forward:
        forward.write(head)
        # One branch per sequence-length bucket, each dispatching on filter size
        # and padding (k // 2 or k - 1).
        for seq in seqs:
            forward.write(sequence_if.format(seq=seq))
            for k in kernels:
                forward.write(case_k.format(k=k))
                for pad in [k // 2, k - 1]:
                    forward.write(main_block.format(k=k, b_size=seq, pad=pad))
                forward.write(bad_padding)
            forward.write(bad_filter)
            forward.write(con_else)

        # Fallback branch for sequences longer than the largest bucket; it
        # reuses the last value of seq (the largest bucket) as the block size.
        forward.write(final_else)
        for k in kernels:
            forward.write(case_k.format(k=k))
            for pad in [k // 2, k - 1]:
                forward.write(main_block.format(k=k, b_size=seq, pad=pad))
            forward.write(bad_padding)
        forward.write(bad_filter)
        forward.write(final_return)


def gen_backward():
    head = """
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "lightconv_cuda.cuh"

std::vector<at::Tensor> lightconv_cuda_backward(
        at::Tensor gradOutput,
        int padding_l,
        at::Tensor input,
        at::Tensor filters) {

    // gradWrtInput
    const int minibatch = input.size(0);
    const int numFeatures = input.size(1);
    const int sequenceLength = input.size(2);

    const int numHeads = filters.size(0);
    const int filterSize = filters.size(1);

    const dim3 gradBlocks(minibatch, numFeatures);
    const dim3 weightGradFirstpassShortBlocks(minibatch, numHeads);
    const dim3 weightGradSecondpassBlocks(numHeads, filterSize);

    const int numFiltersInBlock = numFeatures / numHeads;

    auto gradInput = at::zeros_like(input);
    auto gradFilters = at::zeros_like(filters);

    at::DeviceGuard g(input.device());
    auto stream = at::cuda::getCurrentCUDAStream();

    switch(filterSize) {
"""

    sequence_if = """
                if (sequenceLength <= {seq}) {{
"""

    case_k = """
        case {k}:
"""

    main_block = """
                    if (padding_l == {p}) {{
                        AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_backward", ([&] {{
                            lightconv_grad_wrt_input_kernel<{k}, {b_size}, {p}, scalar_t>
                            <<<gradBlocks, {b_size}, 0, stream>>>(
                                    gradOutput.data<scalar_t>(),
                                    filters.data<scalar_t>(),
                                    minibatch,
                                    sequenceLength,
                                    numFeatures,
                                    numFiltersInBlock,
                                    gradInput.data<scalar_t>());
"""

    weight_grad_short = """
                            at::Tensor tempSumGradFilters = at::zeros({{minibatch, numHeads, filterSize}}, input.options().dtype(at::kFloat));
                            lightconv_grad_wrt_weights_firstpass_short_kernel<{k}, {b_size}, {p}, scalar_t>
                            <<<weightGradFirstpassShortBlocks, {b_size}, 0, stream>>>(
                                    input.data<scalar_t>(),
                                    gradOutput.data<scalar_t>(),
                                    minibatch,
                                    sequenceLength,
                                    numFeatures,
                                    numFiltersInBlock,
                                    numHeads,
                                    tempSumGradFilters.data<float>()
                            );

                            lightconv_grad_wrt_weights_secondpass_short_kernel<{k}, {b_size}, scalar_t>
                            <<<weightGradSecondpassBlocks, {b_size}, 0, stream>>>(
                                    tempSumGradFilters.data<float>(),
                                    minibatch,
                                    numFiltersInBlock,
                                    gradFilters.data<scalar_t>()
                            );
                        }}));
                    }} else
"""

    weight_grad = """
                            at::Tensor tempSumGradFilters = at::zeros({{minibatch, numFeatures, filterSize}}, input.options().dtype(at::kFloat));
                            lightconv_grad_wrt_weights_firstpass_kernel<{k}, {b_size}, {p}, scalar_t>
                            <<<gradBlocks, {b_size}, 0, stream>>>(
                                    input.data<scalar_t>(),
                                    gradOutput.data<scalar_t>(),
                                    minibatch,
                                    sequenceLength,
                                    numFeatures,
                                    numFiltersInBlock,
                                    tempSumGradFilters.data<float>()
                            );

                            lightconv_grad_wrt_weights_secondpass_kernel<{k}, {b_size}, scalar_t>
                            <<<weightGradSecondpassBlocks, {b_size}, 0, stream>>>(
                                    tempSumGradFilters.data<float>(),
                                    minibatch,
                                    numFiltersInBlock,
                                    gradFilters.data<scalar_t>()
                            );
                        }}));
                    }} else
"""

    bad_padding = """
                    {
                        std::cout << "WARNING: Unsupported padding size - skipping backward pass" << std::endl;
                    }
"""

    breakout = """
    break;
"""

    bad_filter = """
            default:
                std::cout << "WARNING: Unsupported filter length passed - skipping backward pass" << std::endl;
"""

    con_else = """
                } else
"""

    final_else = """
                {
                    switch(filterSize) {
"""

    last_return = """
    }

    return {gradInput, gradFilters};
}
"""

    kernels = [3, 5, 7, 15, 31, 63, 127, 255]
    seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]]
    # Per-filter-size limits on the specialized "short" weight-gradient path:
    # thresh caps the sequence-length buckets that use it, and max_mem caps it
    # for the large filter sizes. A value of -1 means no limit.
    thresh = [32, 32, 64, 128, 256, -1, -1, -1]
    max_mem = [-1, -1, -1, -1, -1, 192, 96, 64]

    with open("lightconv_cuda_backward.cu", "w") as backward:
        backward.write(head)
        for (k, t, mem) in zip(kernels, thresh, max_mem):
            backward.write(case_k.format(k=k))
            for seq in seqs:
                if (t == -1 or seq <= t) and (mem == -1 or seq < mem):
                    # Specialized short-sequence weight-gradient kernels for
                    # this bucket.
                    backward.write(sequence_if.format(seq=seq))
                    for p in [k // 2, k - 1]:
                        backward.write(main_block.format(k=k, b_size=seq, p=p))
                        backward.write(weight_grad_short.format(k=k, b_size=seq, p=p))
                    backward.write(bad_padding)
                else:
                    # Generic fallback with a fixed block size of 32; no further
                    # sequence buckets are emitted for this filter size.
                    for p in [k // 2, k - 1]:
                        backward.write(main_block.format(k=k, b_size=32, p=p))
                        backward.write(weight_grad.format(k=k, b_size=32, p=p))
                    backward.write(bad_padding)
                    backward.write(breakout)
                    break
                backward.write(con_else)
            backward.write(bad_filter)
        backward.write(last_return)


if __name__ == "__main__":
    gen_forward()
    gen_backward()
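
# A minimal sketch (not part of this generator) of how the generated .cu files
# might be compiled into a PyTorch extension with torch.utils.cpp_extension.
# The extension name and the extra binding source listed here are assumptions
# for illustration, not this project's actual build configuration:
#
#     from setuptools import setup
#     from torch.utils.cpp_extension import BuildExtension, CUDAExtension
#
#     setup(
#         name="lightconv_layer",
#         ext_modules=[
#             CUDAExtension(
#                 "lightconv_cuda",
#                 sources=[
#                     "lightconv_cuda.cpp",          # assumed pybind11 bindings
#                     "lightconv_cuda_forward.cu",   # written by gen_forward()
#                     "lightconv_cuda_backward.cu",  # written by gen_backward()
#                 ],
#             ),
#         ],
#         cmdclass={"build_ext": BuildExtension},
#     )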