""" Model / Layer Config singleton state
"""
import os
import warnings
from typing import Any, Optional

import torch

__all__ = [
    'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
]

# Set to True to prefer layer implementations with no JIT optimization (includes activations)
_NO_JIT = False

# Set to True to prefer activation layers with no JIT optimization
_NO_ACTIVATION_JIT = False

# Set to True when exporting a model, so export-friendly code paths are used (e.g. fused attention is disabled)
_EXPORTABLE = False

# Set to True when intending to run torch.jit.script on a model
_SCRIPTABLE = False

# Fused attention (F.scaled_dot_product_attention) config:
# 0 == off, 1 == on, 2 == on including experimental code paths.
# Can be overridden via the TIMM_FUSED_ATTN environment variable at import time.
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if 'TIMM_FUSED_ATTN' in os.environ:
    _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
else:
    _USE_FUSED_ATTN = 1
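
# Illustrative only: TIMM_FUSED_ATTN is read once at import time, so it must be set
# before this module is imported; use set_fused_attn() to change the value afterwards.
# For example (command names are placeholders):
#   TIMM_FUSED_ATTN=0 python train.py   # force fused attention off
#   TIMM_FUSED_ATTN=2 python train.py   # enable, including experimental code paths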


def is_no_jit():
    return _NO_JIT


class set_no_jit:
    """ Context manager / setter for the global _NO_JIT flag.
    The previous value is restored on __exit__.
    """
    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        self.prev = _NO_JIT
        _NO_JIT = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        return False


def is_exportable():
    return _EXPORTABLE


class set_exportable:
    """ Context manager / setter for the global _EXPORTABLE flag.
    The previous value is restored on __exit__.
    """
    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        self.prev = _EXPORTABLE
        _EXPORTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        return False


def is_scriptable():
    return _SCRIPTABLE


class set_scriptable:
    """ Context manager / setter for the global _SCRIPTABLE flag.
    The previous value is restored on __exit__.
    """
    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        self.prev = _SCRIPTABLE
        _SCRIPTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        return False


class set_layer_config:
    """ Layer config context manager that allows setting all layer config flags at once.
    If a flag arg is None, it will not change the current value.
    """
    def __init__(
            self,
            scriptable: Optional[bool] = None,
            exportable: Optional[bool] = None,
            no_jit: Optional[bool] = None,
            no_activation_jit: Optional[bool] = None):
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        if scriptable is not None:
            _SCRIPTABLE = scriptable
        if exportable is not None:
            _EXPORTABLE = exportable
        if no_jit is not None:
            _NO_JIT = no_jit
        if no_activation_jit is not None:
            _NO_ACTIVATION_JIT = no_activation_jit

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        return False
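
# A minimal usage sketch (the model constructor is assumed, e.g. timm's create_model;
# any model-building call works):
#
#   with set_layer_config(scriptable=True, no_jit=True):
#       model = create_model('resnet50')
#   # previous flag values are restored on exit from the with-block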


def use_fused_attn(experimental: bool = False) -> bool:
    """ Return True if fused attention (F.scaled_dot_product_attention) should be used. """
    if not _HAS_FUSED_ATTN or _EXPORTABLE:
        return False
    if experimental:
        return _USE_FUSED_ATTN > 1
    return _USE_FUSED_ATTN > 0


def set_fused_attn(enable: bool = True, experimental: bool = False):
    global _USE_FUSED_ATTN
    if not _HAS_FUSED_ATTN:
        warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.')
        return
    if experimental and enable:
        _USE_FUSED_ATTN = 2
    elif enable:
        _USE_FUSED_ATTN = 1
    else:
        _USE_FUSED_ATTN = 0
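
# Illustrative sketch (class and tensor names are placeholders, not part of this module)
# of how an attention layer might consult use_fused_attn() at construction time:
#
#   class Attention(nn.Module):
#       def __init__(self, dim, num_heads=8):
#           super().__init__()
#           self.fused_attn = use_fused_attn()
#           ...
#
#       def forward(self, q, k, v):
#           if self.fused_attn:
#               return F.scaled_dot_product_attention(q, k, v)
#           q = q * (q.shape[-1] ** -0.5)
#           attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
#           return attn @ v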