"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""

import random
import warnings
from typing import Union, Tuple, Any

import torch
from torch import autocast


class EmptyAutocast(object):
    """
    Empty class for disable any autocasting.
    """

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        return

    def __call__(self, func):
        # Act as a no-op decorator: return the wrapped function unchanged.
        return func


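# Illustrative usage sketch (not part of the original module): EmptyAutocast
# exposes the same context-manager interface as torch.autocast, so callers can
# hold either object and enter it with `with` without branching on the device.
# `module` and `x` below are hypothetical placeholders.
def _forward_without_autocast(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    with EmptyAutocast():
        # Runs as a plain forward pass; no dtype casting is applied.
        return module(x)
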
def get_precision_autocast(
    device="cpu", fp16=True, override_dtype=None
) -> Union[
    Tuple[EmptyAutocast, Union[torch.dtype, Any]],
    Tuple[autocast, Union[torch.dtype, Any]],
]:
    """
    Returns precision and autocast settings for given device and fp16 settings.
    Args:
        device: Device to get precision and autocast settings for.
        fp16: Whether to use fp16 precision.
        override_dtype: Override dtype for autocast.

    Returns:
        Autocast object, dtype
    """
    dtype = torch.float32
    cache_enabled = None

    if device == "cpu" and fp16:
        warnings.warn('FP16 is not supported on CPU. Using FP32 instead.')
        dtype = torch.float32

        # TODO: Implement BFP16 on CPU. There are unexpected slowdowns on cpu on a clean environment.
        # warnings.warn(
        #     "Accuracy BFP16 has experimental support on the CPU. "
        #     "This may result in an unexpected reduction in quality."
        # )
        # dtype = (
        #     torch.bfloat16
        # )  # Using bfloat16 for CPU, since autocast is not supported for float16


    if "cuda" in device and fp16:
        dtype = torch.float16
        cache_enabled = True

    if override_dtype is not None:
        dtype = override_dtype

    if dtype == torch.float32 and device == "cpu":
        return EmptyAutocast(), dtype

    return (
        torch.autocast(
            device_type=device, dtype=dtype, enabled=True, cache_enabled=cache_enabled
        ),
        dtype,
    )

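# Illustrative usage sketch (not part of the original module): the returned pair
# is unpacked into a context manager plus the dtype the network should be cast
# to (see cast_network below). `model` and `x` are hypothetical placeholders.
def _forward_with_selected_precision(
    model: torch.nn.Module, x: torch.Tensor, device: str = "cpu"
) -> torch.Tensor:
    autocast_ctx, dtype = get_precision_autocast(device=device, fp16=True)
    cast_network(model, dtype)  # keep the weights consistent with the autocast dtype
    with autocast_ctx, torch.no_grad():
        return model(x.to(device))
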

def cast_network(network: torch.nn.Module, dtype: torch.dtype):
    """Cast network to given dtype

    Args:
        network: Network to be casted
        dtype: Dtype to cast network to
    """
    if dtype == torch.float16:
        network.half()
    elif dtype == torch.bfloat16:
        network.bfloat16()
    elif dtype == torch.float32:
        network.float()
    else:
        raise ValueError(f"Unknown dtype {dtype}")


def fix_seed(seed=42):
    """Sets fixed random seed

    Args:
        seed: Random seed to be set
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # noinspection PyUnresolvedReferences
        torch.backends.cudnn.deterministic = True
        # noinspection PyUnresolvedReferences
        torch.backends.cudnn.benchmark = False
    return True

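# Illustrative usage sketch (not part of the original module): resetting the
# seed before each run makes random draws repeatable, which is the point of
# calling fix_seed() up front.
def _demo_reproducible_draws(seed: int = 42) -> bool:
    fix_seed(seed)
    first = torch.rand(3)
    fix_seed(seed)
    second = torch.rand(3)
    # Both draws come from the same seeded generator state, so they are equal.
    return bool(torch.equal(first, second))
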

def suppress_warnings():
    """Suppresses a known, harmless PyTorch warning."""
    # Suppress the PyTorch 1.11.0 warning about the changing order of arguments in the nn.MaxPool2d layer,
    # since this source code is not affected by the issue and there is no other correct way to hide the message.
    # Note: the message string below intentionally mirrors PyTorch's own warning text
    # (including the missing space in "changeto") so the filter matches it.
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message="Note that order of the arguments: ceil_mode and "
        "return_indices will changeto match the args list "
        "in nn.MaxPool2d in a future release.",
        module="torch",
    )
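

# Illustrative end-to-end sketch (not part of the original module): wires the
# helpers above together on a throwaway model. The nn.Linear layer and random
# input are placeholders, not anything from the original tool.
if __name__ == "__main__":
    suppress_warnings()
    fix_seed(42)
    demo_device = "cuda" if torch.cuda.is_available() else "cpu"
    demo_model = torch.nn.Linear(8, 4).to(demo_device)
    demo_autocast, demo_dtype = get_precision_autocast(device=demo_device, fp16=True)
    cast_network(demo_model, demo_dtype)
    with demo_autocast, torch.no_grad():
        demo_out = demo_model(torch.randn(2, 8, device=demo_device))
    print(demo_out.shape, demo_out.dtype)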