Spaces:
Paused
Paused
File size: 1,377 Bytes
5238467 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
class TorchAutocast:
    """TorchAutocast utility class.

    Allows you to enable and disable autocast. This is especially useful
    when dealing with different architectures and clusters with different
    levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast.
    """

    def __init__(self, enabled: bool, *args, **kwargs):
        # When disabled, no torch.autocast object is built at all and the
        # context-manager methods below become no-ops.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self):
        """Enter the wrapped torch.autocast context, if enabled.

        Raises:
            RuntimeError: if entering torch.autocast raises a RuntimeError.
                The new error reports the requested dtype and device, and the
                original error is chained as its ``__cause__``.
        """
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError as err:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Chain the original error so its message and traceback are kept
            # as the explicit cause instead of being buried as implicit context.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from err

    def __exit__(self, *args, **kwargs):
        """Exit the wrapped torch.autocast context, if enabled.

        Forwards the wrapped ``__exit__`` return value so the standard
        context-manager exception-suppression contract is respected. When
        disabled, returns None, so any in-flight exception propagates.
        """
        if self.autocast is None:
            return
        return self.autocast.__exit__(*args, **kwargs)
|