# Source: MTECBS/utils/transform.py (yan123yan, commit 47fe089, "first version")
import torch
def compute_gradient(x: torch.Tensor, degree: int) -> torch.Tensor:
    """Stack ``x`` with its repeated finite differences along the last axis.

    Each pass takes the first-order difference of the previous result along
    ``dim=-1``. The previous result's own first element is prepended before
    differencing, so every difference keeps the same last-axis length as
    ``x`` and its leading entry is always zero.

    Args:
        x: Input tensor; differences are taken along its last dimension.
        degree: Number of successive differences to compute. Values
            ``<= 0`` produce no differences, so the output equals ``x``.

    Returns:
        A tensor formed by concatenating ``x`` and its ``degree`` successive
        differences along the last dimension, in that order; the last
        dimension is ``(degree + 1)`` times that of ``x``.
    """
    stacked = [x]
    current = x
    for _ in range(degree):
        # Prepending the first element pads the diff back to the input length
        # (the padded position evaluates to current[0] - current[0] == 0).
        current = torch.diff(current, dim=-1, prepend=current[..., 0:1])
        stacked.append(current)
    # torch.cat is the canonical name; torch.concatenate is a later alias
    # that is missing from older PyTorch releases.
    return torch.cat(stacked, dim=-1)