# (Web-page residue captured along with this file: "Spaces: Sleeping".)
# -*- coding: utf-8 -*- | |
import torch | |
from typing import Callable, Iterable, Sequence, Union | |
def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
    use_deepspeed: bool = False,
):
    """Evaluate ``func`` without caching intermediate activations.

    This reduces memory at the expense of extra compute in the backward
    pass: when enabled, ``func`` is evaluated through
    ``CheckpointFunction``, which re-runs it during backward.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to ``func``.
    :param params: a sequence of parameters ``func`` depends on but does not
        explicitly take as arguments. They are forwarded alongside
        ``inputs`` so gradients can be computed for them.
    :param flag: if False, disable gradient checkpointing and simply return
        ``func(*inputs)``.
    :param use_deepspeed: accepted for interface compatibility only. This
        implementation never consults it; the built-in
        ``CheckpointFunction`` is used whenever ``flag`` is True.
    :return: whatever ``func(*inputs)`` returns.
    """
    if not flag:
        return func(*inputs)

    # Materialize once so the split point handed to CheckpointFunction is
    # guaranteed to match the tensors actually forwarded.
    inputs = tuple(inputs)
    args = inputs + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *args)
class CheckpointFunction(torch.autograd.Function):
    """Autograd function that recomputes ``run_function`` during backward.

    Use as ``CheckpointFunction.apply(run_function, length, *args)``. The
    first ``length`` entries of ``args`` are passed positionally to
    ``run_function``. The remaining entries are extra tensors that
    ``run_function`` reads implicitly and that may receive gradients.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """Run ``run_function`` on the first ``length`` args without grad.

        The inputs and extra tensors are stashed on ``ctx`` so ``backward``
        can re-run the function; no intermediate activations are kept.
        """
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        """Recompute the forward pass with grad enabled and backpropagate.

        Returns one gradient per ``forward`` argument. The first two entries
        are ``None`` (for ``run_function`` and ``length``), followed by one
        entry per element of ``args``. Tensors that cannot or do not require
        grad, such as integer inputs or frozen extra tensors, get ``None``.
        """
        # Only floating-point/complex tensors may require grad; calling
        # requires_grad_(True) on e.g. an integer index tensor would raise.
        ctx.input_tensors = [
            x.detach().requires_grad_(x.is_floating_point() or x.is_complex())
            for x in ctx.input_tensors
        ]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)

        targets = ctx.input_tensors + ctx.input_params
        # torch.autograd.grad raises if asked to differentiate w.r.t. a tensor
        # that does not require grad (e.g. a frozen parameter), so request
        # gradients only for those that do and report None for the rest.
        needs_grad = [t.requires_grad for t in targets]
        wrt = [t for t, need in zip(targets, needs_grad) if need]
        if wrt:
            grads = iter(
                torch.autograd.grad(
                    output_tensors,
                    wrt,
                    output_grads,
                    allow_unused=True,
                )
            )
            input_grads = tuple(
                next(grads) if need else None for need in needs_grad
            )
        else:
            input_grads = (None,) * len(targets)

        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads