import torch


def compute_gradient(x, degree, *, dim=-1):
    """Concatenate a tensor with its first ``degree`` repeated finite differences.

    Each difference is computed with ``torch.diff`` along ``dim`` after
    prepending the leading slice of the current tensor. As a result, every
    difference has the same length as ``x`` along ``dim``, and its first entry
    is zero. The k-th difference is the difference of the (k-1)-th one.

    Args:
        x: Input tensor with at least one dimension.
        degree: Number of difference orders to append. ``0`` (or any
            non-positive value) yields only a copy of ``x``.
        dim: Dimension to difference and concatenate along. Defaults to the
            last dimension, matching the original behavior.

    Returns:
        ``[x, diff1, diff2, ...]`` concatenated along ``dim``. For
        ``degree >= 0`` its size along ``dim`` is ``(degree + 1)`` times
        that of ``x``.
    """
    terms = [x]
    current = x
    for _ in range(degree):
        # Prepend the first slice so the output length matches the input.
        # min() keeps the zero-length case identical to slicing with [0:1].
        head = current.narrow(dim, 0, min(1, current.size(dim)))
        current = torch.diff(current, dim=dim, prepend=head)
        terms.append(current)
    # torch.cat is available in all PyTorch versions; torch.concatenate is a
    # newer alias of the same operation.
    return torch.cat(terms, dim=dim)