# File size: 225 Bytes
# 47fe089
import torch

def compute_gradient(x, degree, dim=-1):
    """Concatenate ``x`` with its first ``degree`` successive finite differences.

    Each difference is taken along ``dim`` with the first slice of the current
    tensor prepended. As a result, every difference has the same size as ``x``
    along ``dim``, and its leading entry along ``dim`` is zero.

    Args:
        x: Input tensor (at least 1-D).
        degree: Number of successive differences to append. ``0`` or a negative
            value yields just ``x`` (as a new tensor).
        dim: Dimension along which to difference and concatenate. Defaults to
            the last dimension.

    Returns:
        ``[x, diff(x), diff(diff(x)), ...]`` concatenated along ``dim``. For
        ``degree >= 0`` its size along ``dim`` is ``degree + 1`` times that of
        ``x``.
    """
    gradients = [x]
    current = x
    for _ in range(degree):
        # Prepending the first slice makes the difference length-preserving;
        # its first entry is therefore always zero.
        current = torch.diff(current, dim=dim, prepend=current.narrow(dim, 0, 1))
        gradients.append(current)
    # torch.cat is the canonical op; torch.concatenate is only a later alias.
    return torch.cat(gradients, dim=dim)