Spaces:
Running
Running
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import random | |
import warnings | |
from typing import Union, Tuple, Any | |
import torch | |
from torch import autocast | |
class EmptyAutocast(object):
    """
    No-op stand-in for ``torch.autocast`` that disables any autocasting.

    It can be used either as a context manager (``with EmptyAutocast(): ...``)
    or as a function decorator (``@EmptyAutocast()``). In both forms the
    wrapped code runs exactly as it would without it.
    """

    def __enter__(self):
        # Nothing meaningful to bind for ``with EmptyAutocast() as x``.
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # A falsy return value means exceptions raised in the body propagate.
        return None

    def __call__(self, func):
        # Decorator form: hand the function back untouched. Returning None
        # here would replace the decorated function with None.
        return func
def get_precision_autocast(
    device="cpu", fp16=True, override_dtype=None
) -> Union[
    Tuple[EmptyAutocast, Union[torch.dtype, Any]],
    Tuple[autocast, Union[torch.dtype, Any]],
]:
    """
    Returns precision and autocast settings for given device and fp16 settings.

    Args:
        device: Device to get precision and autocast settings for, e.g.
            ``"cpu"``, ``"cuda"`` or an indexed device such as ``"cuda:1"``.
            A ``torch.device`` is accepted as well (it is converted with
            ``str``).
        fp16: Whether to use fp16 precision. Only honoured for CUDA devices;
            on CPU a warning is emitted and fp32 is used instead.
        override_dtype: Override dtype for autocast. When not ``None`` it
            replaces the dtype chosen from ``device`` and ``fp16``.

    Returns:
        Autocast object, dtype. If the resulting dtype is ``torch.float32``
        on CPU, the autocast object is a no-op ``EmptyAutocast``; otherwise it
        is a ``torch.autocast`` configured for the device type and dtype.
    """
    # torch.autocast expects a bare device type ("cuda"), not an indexed
    # device string ("cuda:0"), so drop any ":<index>" suffix up front.
    device_type = str(device).split(":")[0]

    dtype = torch.float32
    cache_enabled = None
    if device_type == "cpu" and fp16:
        warnings.warn('FP16 is not supported on CPU. Using FP32 instead.')
        # TODO: Implement BFP16 on CPU. There are unexpected slowdowns on cpu on a clean environment.
        # warnings.warn(
        #     "Accuracy BFP16 has experimental support on the CPU. "
        #     "This may result in an unexpected reduction in quality."
        # )
        # dtype = (
        #     torch.bfloat16
        # )  # Using bfloat16 for CPU, since autocast is not supported for float16
    if device_type == "cuda" and fp16:
        dtype = torch.float16
        cache_enabled = True
    if override_dtype is not None:
        dtype = override_dtype
    # Plain fp32 on CPU needs no autocasting at all.
    if dtype == torch.float32 and device_type == "cpu":
        return EmptyAutocast(), dtype
    return (
        torch.autocast(
            device_type=device_type,
            dtype=dtype,
            enabled=True,
            cache_enabled=cache_enabled,
        ),
        dtype,
    )
def cast_network(network: torch.nn.Module, dtype: torch.dtype):
    """Cast ``network`` to ``dtype`` using the matching module conversion method.

    The conversion is done by calling ``network.half()``, ``network.bfloat16()``
    or ``network.float()`` (per the PyTorch docs these convert the module's
    floating-point parameters and buffers in place). Nothing is returned.

    Args:
        network: Module to be converted.
        dtype: Target precision: ``torch.float16``, ``torch.bfloat16`` or
            ``torch.float32``.

    Raises:
        ValueError: If ``dtype`` is none of the supported precisions.
    """
    # Checked in this order, first match wins.
    converters = (
        (torch.float16, "half"),
        (torch.bfloat16, "bfloat16"),
        (torch.float32, "float"),
    )
    for target_dtype, method_name in converters:
        if dtype == target_dtype:
            getattr(network, method_name)()
            return
    raise ValueError(f"Unknown dtype {dtype}")
def fix_seed(seed=42):
    """Seed Python's ``random`` module and PyTorch for reproducible runs.

    When CUDA is available, the CUDA generators are seeded as well, and cuDNN
    is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Value fed to every random generator touched here.

    Returns:
        Always ``True``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return True

    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # noinspection PyUnresolvedReferences
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    return True
def suppress_warnings():
    """Hide a known, harmless ``UserWarning`` emitted by PyTorch 1.11.0.

    PyTorch 1.11.0 warns that the order of the ``ceil_mode`` and
    ``return_indices`` arguments will change to match ``nn.MaxPool2d``. This
    code is not affected by that change, and filtering by message is the only
    correct way found to hide it. The filter is installed as an "ignore" rule
    for ``UserWarning`` coming from modules whose name starts with ``torch``.

    The message literal appears to reproduce PyTorch's wording verbatim
    (including "changeto"); keep it byte-for-byte unless re-checked against
    the warning PyTorch actually emits, or the filter will stop matching.
    """
    maxpool_arg_order_notice = (
        "Note that order of the arguments: ceil_mode and return_indices "
        "will changeto match the args list in nn.MaxPool2d "
        "in a future release."
    )
    warnings.filterwarnings(
        action="ignore",
        message=maxpool_arg_order_notice,
        category=UserWarning,
        module="torch",
    )