# Spaces: Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
class GradMultiply(torch.autograd.Function):
    """Pass a tensor through unchanged while scaling its gradient.

    The forward pass returns ``x`` (as a new tensor object built with
    ``x.new(x)``); the backward pass multiplies the incoming gradient by
    ``scale``. ``scale`` itself receives no gradient.

    Call through ``apply`` (not by instantiating the class)::

        y = GradMultiply.apply(x, scale)
    """

    @staticmethod
    def forward(ctx, x, scale):
        """Return ``x`` as a new tensor object and stash ``scale`` for backward.

        Args:
            ctx: autograd context; ``scale`` is stored on it as ``ctx.scale``.
            x: input tensor.
            scale: multiplier applied to the gradient in ``backward``.

        Returns:
            ``x.new(x)`` -- a tensor with the same values as ``x``.
        """
        ctx.scale = scale
        # Return a distinct tensor object rather than ``x`` itself so autograd
        # treats it as an output of this Function. NOTE(review): ``x.new(x)``
        # is the legacy constructor and is expected to alias x's storage rather
        # than copy it -- confirm if changing PyTorch versions.
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        """Scale the upstream gradient by ``ctx.scale``.

        Returns:
            A pair ``(grad * ctx.scale, None)``: the gradient for ``x`` and
            no gradient for ``scale``.
        """
        return grad * ctx.scale, None