"""Detect optional accelerated backends (xformers, triton) at import time."""

import importlib.util
import logging
import warnings

import importlib_metadata
from packaging import version

logger = logging.getLogger(__name__)

# --- xformers ---------------------------------------------------------------
# find_spec() only says an importable top-level `xformers` module exists; the
# distribution-metadata lookups below additionally require installed metadata
# for both xformers and torch.
_xformers_available = importlib.util.find_spec("xformers") is not None
try:
    if _xformers_available:
        _xformers_version = importlib_metadata.version("xformers")
        _torch_version = importlib_metadata.version("torch")
        # An xformers install paired with a too-old torch is a hard error: this
        # ValueError is not handled below and aborts importing this module.
        if version.Version(_torch_version) < version.Version("1.12"):
            raise ValueError("xformers is installed but requires PyTorch >= 1.12")
        logger.debug("Successfully imported xformers version %s", _xformers_version)
except importlib_metadata.PackageNotFoundError:
    # xformers or torch has no installed distribution metadata.
    _xformers_available = False

# --- triton -----------------------------------------------------------------
# Triton only gates optional kernels, so every failure mode here degrades to
# "unavailable" (with a warning) instead of breaking the import of this module.
# Note: no warning is issued when find_spec() finds no `triton` module at all.
_triton_modules_available = importlib.util.find_spec("triton") is not None
if _triton_modules_available:
    try:
        _triton_version = importlib_metadata.version("triton")
    except importlib_metadata.PackageNotFoundError:
        # A `triton` module is importable but has no installed metadata.
        _triton_modules_available = False
        warnings.warn("TritonLiteMLA and TritonMBConvPreGLU with `triton` is not available on your platform.")
    else:
        if version.Version(_triton_version) < version.Version("3.0.0"):
            # Previously this raised an uncaught ValueError and crashed the
            # import; an outdated triton now simply disables the Triton kernels.
            _triton_modules_available = False
            warnings.warn(
                f"triton {_triton_version} is installed but Triton >= 3.0.0 is required; "
                "TritonLiteMLA and TritonMBConvPreGLU with `triton` are disabled."
            )
        else:
            logger.debug("Successfully imported triton version %s", _triton_version)


def is_xformers_available() -> bool:
    """Return whether xformers was detected when this module was imported.

    True only if an ``xformers`` module was found and installed metadata for
    both ``xformers`` and ``torch`` could be read.
    """
    return _xformers_available


def is_triton_module_available() -> bool:
    """Return whether a usable triton (>= 3.0.0) was detected at import time.

    False if no ``triton`` module was found, its installed metadata is
    missing, or its version is older than 3.0.0.
    """
    return _triton_modules_available