|
""" |
|
Source url: https://github.com/OPHoperHPO/image-background-remove-tool |
|
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. |
|
License: Apache License 2.0 |
|
""" |
|
|
|
import random |
|
import warnings |
|
from typing import Union, Tuple, Any |
|
|
|
import torch |
|
from torch import autocast |
|
|
|
|
|
class EmptyAutocast(object):
    """
    No-op stand-in for ``torch.autocast`` that disables any autocasting.

    Can be used as a context manager (``with EmptyAutocast(): ...``) or as a
    function decorator (``@EmptyAutocast()``); in both cases nothing changes.
    """

    def __enter__(self):
        # Nothing is bound by ``with EmptyAutocast() as x`` (x is None).
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Falsy return value: exceptions raised inside the block propagate.
        return None

    def __call__(self, func):
        # Decorator form: hand back the function untouched. Returning None
        # here would replace the decorated function with None.
        return func
|
|
|
|
|
def get_precision_autocast(
    device="cpu", fp16=True, override_dtype=None
) -> Tuple[Union["EmptyAutocast", autocast], Union[torch.dtype, Any]]:
    """
    Returns precision and autocast settings for given device and fp16 settings.

    Args:
        device: Device to get precision and autocast settings for. Either a
            device string (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:1"``) or a
            ``torch.device``.
        fp16: Whether to use fp16 precision. Only honoured on CUDA devices;
            on CPU a warning is emitted and FP32 is used instead.
        override_dtype: If not None, used as the dtype regardless of the
            device / fp16 combination.

    Returns:
        Autocast object, dtype. For a CPU device with float32 dtype a no-op
        ``EmptyAutocast`` is returned instead of ``torch.autocast``.
    """
    # torch.autocast expects a bare device type ("cuda"), not an indexed
    # device ("cuda:0"); normalizing through torch.device also accepts
    # torch.device objects, on which a substring test would fail.
    device_type = torch.device(device).type

    dtype = torch.float32
    cache_enabled = None  # None defers to torch.autocast's own default

    if device_type == "cpu" and fp16:
        warnings.warn('FP16 is not supported on CPU. Using FP32 instead.')

    if device_type == "cuda" and fp16:
        dtype = torch.float16
        cache_enabled = True

    if override_dtype is not None:
        dtype = override_dtype

    if dtype == torch.float32 and device_type == "cpu":
        return EmptyAutocast(), dtype

    return (
        torch.autocast(
            device_type=device_type,
            dtype=dtype,
            enabled=True,
            cache_enabled=cache_enabled,
        ),
        dtype,
    )
|
|
|
|
|
def cast_network(network: torch.nn.Module, dtype: torch.dtype):
    """Convert a network in place to the requested floating-point precision.

    Args:
        network: Module whose ``half``/``bfloat16``/``float`` method is invoked.
        dtype: Target precision: torch.float16, torch.bfloat16 or torch.float32.

    Raises:
        ValueError: If dtype is none of the supported precisions.
    """
    # Checked in order with ``==``; the method is looked up only on a match.
    converters = (
        (torch.float16, "half"),
        (torch.bfloat16, "bfloat16"),
        (torch.float32, "float"),
    )
    for target, method_name in converters:
        if dtype == target:
            getattr(network, method_name)()
            return
    raise ValueError(f"Unknown dtype {dtype}")
|
|
|
|
|
def fix_seed(seed=42):
    """Seed Python's ``random`` module and PyTorch with a fixed value.

    When CUDA is available, the CUDA generators are seeded as well and cuDNN
    is switched to deterministic, non-benchmarking mode.

    Args:
        seed: Value used to seed every generator.

    Returns:
        Always True.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return True

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    return True
|
|
|
|
|
def suppress_warnings():
    """Install a filter that ignores one known PyTorch ``UserWarning``.

    The silenced warning announces a future re-ordering of the ``ceil_mode``
    and ``return_indices`` arguments to match ``nn.MaxPool2d``. Only
    UserWarnings issued from modules whose name starts with ``torch`` are
    affected.
    """
    # NOTE(review): the missing space in "changeto" looks deliberate; it
    # presumably mirrors the exact text of the PyTorch warning. Correcting the
    # spelling could stop this filter from matching, so confirm first.
    pattern = (
        "Note that order of the arguments: ceil_mode and "
        "return_indices will changeto match the args list "
        "in nn.MaxPool2d in a future release."
    )
    warnings.filterwarnings(
        action="ignore",
        message=pattern,
        category=UserWarning,
        module="torch",
    )
|
|