import sys
from typing import Optional

import pytest
import torch
from packaging.version import Version
from pkg_resources import get_distribution

# Adapted from:
#     https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py

from tests.helpers.module_available import (
    _DEEPSPEED_AVAILABLE,
    _FAIRSCALE_AVAILABLE,
    _IS_WINDOWS,
    _RPC_AVAILABLE,
)


class RunIf:
    """
    RunIf wrapper for conditional skipping of tests.
    Fully compatible with `@pytest.mark`.

    Instantiating ``RunIf(...)`` does not produce a ``RunIf`` object: ``__new__``
    returns a ``pytest.mark.skipif`` marker, so the call result is used directly
    as a decorator.

    Example:

        @RunIf(min_torch="1.8")
        @pytest.mark.parametrize("arg1", [1.0, 2.0])
        def test_wrapper(arg1):
            assert arg1 > 0
    """

    def __new__(
        cls,
        min_gpus: int = 0,
        min_torch: Optional[str] = None,
        max_torch: Optional[str] = None,
        min_python: Optional[str] = None,
        skip_windows: bool = False,
        rpc: bool = False,
        fairscale: bool = False,
        deepspeed: bool = False,
        **kwargs,
    ):
        """
        Build a ``pytest.mark.skipif`` marker from the requested requirements.

        The test is skipped when at least one requested requirement is not met;
        the skip reason lists only the unmet requirements, joined by `` + ``.

        Args:
            min_gpus: min number of gpus required to run test
            min_torch: minimum pytorch version to run test (inclusive)
            max_torch: maximum pytorch version to run test (exclusive: the test
                is skipped when the installed version is >= this value)
            min_python: minimum python version required to run test
            skip_windows: skip test for Windows platform
            rpc: requires Remote Procedure Call (RPC)
            fairscale: if `fairscale` module is required to run the test
            deepspeed: if `deepspeed` module is required to run the test
            kwargs: native pytest.mark.skipif keyword arguments

        Returns:
            The marker returned by ``pytest.mark.skipif(condition=..., reason=...,
            **kwargs)``.
        """
        # Each entry is (should_skip, requirement label). Keeping the pair together
        # avoids maintaining two parallel lists that must stay index-aligned.
        checks = []

        if min_gpus:
            checks.append((torch.cuda.device_count() < min_gpus, f"GPUs>={min_gpus}"))

        if min_torch or max_torch:
            # Look up and parse the installed torch version once, and only when a
            # torch bound was actually requested.
            installed_torch = Version(get_distribution("torch").version)
            if min_torch:
                checks.append(
                    (installed_torch < Version(min_torch), f"torch>={min_torch}")
                )
            if max_torch:
                checks.append(
                    (installed_torch >= Version(max_torch), f"torch<{max_torch}")
                )

        if min_python:
            running_python = Version(".".join(str(part) for part in sys.version_info[:3]))
            checks.append(
                (running_python < Version(min_python), f"python>={min_python}")
            )

        if skip_windows:
            checks.append((_IS_WINDOWS, "does not run on Windows"))

        if rpc:
            checks.append((not _RPC_AVAILABLE, "RPC"))

        if fairscale:
            checks.append((not _FAIRSCALE_AVAILABLE, "Fairscale"))

        if deepspeed:
            checks.append((not _DEEPSPEED_AVAILABLE, "Deepspeed"))

        # Only requirements that are NOT satisfied end up in the reason string.
        unmet = [label for should_skip, label in checks if should_skip]

        return pytest.mark.skipif(
            condition=bool(unmet),
            reason=f"Requires: [{' + '.join(unmet)}]",
            **kwargs,
        )