|
import importlib.util |
|
import logging |
|
import warnings |
|
|
|
import importlib_metadata |
|
from packaging import version |
|
|
|
logger = logging.getLogger(__name__)

# xformers counts as available only when the module can be located AND the
# installed-distribution metadata for both xformers and torch can be read.
_xformers_available = importlib.util.find_spec("xformers") is not None
if _xformers_available:
    try:
        _xformers_version = importlib_metadata.version("xformers")
        _torch_version = importlib_metadata.version("torch")
    except importlib_metadata.PackageNotFoundError:
        # Metadata for one of the two distributions is missing: report
        # xformers as unavailable instead of failing.
        _xformers_available = False
    else:
        # NOTE(review): this ValueError is not handled here, so it propagates
        # out of the module import when torch is older than 1.12 -- confirm
        # that a hard failure (rather than a fallback) is intended.
        if version.Version(_torch_version) < version.Version("1.12"):
            raise ValueError("xformers is installed but requires PyTorch >= 1.12")
        logger.debug(f"Successfully imported xformers version {_xformers_version}")
|
|
|
# The triton-backed modules are enabled only when the `triton` module can be
# located and its installed-distribution version can be read.
_triton_modules_available = importlib.util.find_spec("triton") is not None
if _triton_modules_available:
    try:
        _triton_version = importlib_metadata.version("triton")
    except ImportError:
        # The version lookup failed with an ImportError (presumably the
        # package-not-found error raised by importlib_metadata, which derives
        # from ImportError -- verify): disable the triton modules and warn.
        _triton_modules_available = False
        warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.")
    else:
        # NOTE(review): this ValueError is not an ImportError, so it is not
        # handled here and propagates out of the module import when triton is
        # older than 3.0.0 -- confirm that a hard failure is intended.
        if version.Version(_triton_version) < version.Version("3.0.0"):
            raise ValueError("triton is installed but requires Triton >= 3.0.0")
        logger.debug(f"Successfully imported triton version {_triton_version}")
|
|
|
|
|
def is_xformers_available() -> bool:
    """Return the module-level ``_xformers_available`` flag.

    The flag is computed once, when this module is imported; calling this
    function does not re-run the detection.
    """
    return _xformers_available
|
|
|
|
|
def is_triton_module_available() -> bool:
    """Return the module-level ``_triton_modules_available`` flag.

    The flag is computed once, when this module is imported; calling this
    function does not re-run the detection.
    """
    return _triton_modules_available
|
|