Spaces:
Runtime error
Runtime error
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
// Forward pass of the lightconv op, implemented outside this view
// (presumably in a companion CUDA source — confirm). Declared here so the
// checked wrapper below can call it. The signature must match the
// definition exactly: tensors are taken by value.
//   input     - input tensor (shape/layout not visible here)
//   filters   - filter tensor (shape/layout not visible here)
//   padding_l - presumably the amount of left padding; forwarded as-is
// Returns the tensors produced by the implementation.
std::vector<at::Tensor> lightconv_cuda_forward(
    at::Tensor input, at::Tensor filters, int padding_l);
// Backward pass of the lightconv op, implemented outside this view
// (presumably in a companion CUDA source — confirm). Note the argument
// order: padding_l comes second, before the forward-pass tensors. The
// signature must match the definition exactly.
//   gradOutput - presumably the gradient w.r.t. the forward output
//   padding_l  - presumably the amount of left padding; forwarded as-is
//   input      - the forward-pass input tensor
//   filters    - the forward-pass filter tensor
// Returns the tensors produced by the implementation.
std::vector<at::Tensor> lightconv_cuda_backward(
    at::Tensor gradOutput, int padding_l, at::Tensor input, at::Tensor filters);
/**
 * Checked entry point for the lightconv forward pass.
 *
 * Runs CHECK_INPUT on `input`, then on `filters`, and only then dispatches
 * to lightconv_cuda_forward. CHECK_INPUT is a macro defined outside this
 * view; presumably it rejects tensors that are not contiguous CUDA tensors
 * (confirm against its definition).
 *
 * @param input      input tensor, passed through unchanged
 * @param filters    filter tensor, passed through unchanged
 * @param padding_l  presumably the amount of left padding; passed through
 * @return the tensors returned by lightconv_cuda_forward
 */
std::vector<at::Tensor> lightconv_forward(
    at::Tensor input,
    at::Tensor filters,
    int padding_l) {
  CHECK_INPUT(input);
  CHECK_INPUT(filters);

  std::vector<at::Tensor> results =
      lightconv_cuda_forward(input, filters, padding_l);
  return results;
}
/**
 * Checked entry point for the lightconv backward pass.
 *
 * Runs CHECK_INPUT on `gradOutput`, `input`, and `filters`, in that order,
 * then dispatches to lightconv_cuda_backward with the arguments in their
 * original order. CHECK_INPUT is a macro defined outside this view;
 * presumably it rejects tensors that are not contiguous CUDA tensors
 * (confirm against its definition).
 *
 * @param gradOutput presumably the gradient w.r.t. the forward output
 * @param padding_l  presumably the amount of left padding; passed through
 * @param input      the forward-pass input tensor
 * @param filters    the forward-pass filter tensor
 * @return the tensors returned by lightconv_cuda_backward
 */
std::vector<at::Tensor> lightconv_backward(
    at::Tensor gradOutput,
    int padding_l,
    at::Tensor input,
    at::Tensor filters) {
  CHECK_INPUT(gradOutput);
  CHECK_INPUT(input);
  CHECK_INPUT(filters);

  std::vector<at::Tensor> results =
      lightconv_cuda_backward(gradOutput, padding_l, input, filters);
  return results;
}
// Python bindings. The checked wrappers are exposed as `forward` and
// `backward` on the module named by TORCH_EXTENSION_NAME. That macro is
// supplied at build time, presumably by torch.utils.cpp_extension (confirm).
// The third argument of each m.def is the Python-visible docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lightconv_forward, "lightconv forward (CUDA)");
  m.def("backward", &lightconv_backward, "lightconv backward (CUDA)");
}