|
import sys |
|
from typing import Optional |
|
|
|
import pytest |
|
import torch |
|
from packaging.version import Version |
|
from pkg_resources import get_distribution |
|
|
|
""" |
|
Adapted from: |
|
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py |
|
""" |
|
|
|
from tests.helpers.module_available import ( |
|
_DEEPSPEED_AVAILABLE, |
|
_FAIRSCALE_AVAILABLE, |
|
_IS_WINDOWS, |
|
_RPC_AVAILABLE, |
|
) |
|
|
|
|
|
class RunIf:
    """
    RunIf wrapper for conditional skipping of tests.

    Fully compatible with `@pytest.mark`.

    Example:

        @RunIf(min_torch="1.8")
        @pytest.mark.parametrize("arg1", [1.0, 2.0])
        def test_wrapper(arg1):
            assert arg1 > 0

    Note: calling ``RunIf(...)`` does not create a ``RunIf`` instance;
    ``__new__`` returns a ``pytest.mark.skipif`` mark decorator directly.
    """

    def __new__(
        cls,
        min_gpus: int = 0,
        min_torch: Optional[str] = None,
        max_torch: Optional[str] = None,
        min_python: Optional[str] = None,
        skip_windows: bool = False,
        rpc: bool = False,
        fairscale: bool = False,
        deepspeed: bool = False,
        **kwargs,
    ):
        """
        Build a ``pytest.mark.skipif`` marker from the requested requirements.

        Args:
            min_gpus: min number of gpus required to run test
            min_torch: minimum pytorch version to run test (inclusive)
            max_torch: maximum pytorch version to run test (exclusive)
            min_python: minimum python version required to run test (inclusive)
            skip_windows: skip test for Windows platform
            rpc: requires Remote Procedure Call (RPC)
            fairscale: if `fairscale` module is required to run the test
            deepspeed: if `deepspeed` module is required to run the test
            kwargs: native pytest.mark.skipif keyword arguments

        Returns:
            A ``pytest.mark.skipif`` mark that skips the test when any of the
            requested requirements is unmet; its reason lists only the unmet
            requirements.
        """
        conditions = []
        reasons = []

        if min_gpus:
            conditions.append(torch.cuda.device_count() < min_gpus)
            reasons.append(f"GPUs>={min_gpus}")

        # Look the installed torch version up once, and compare on its base
        # version (e.g. "1.8.0a0+abc" -> "1.8.0"). Otherwise nightly/dev builds
        # sort *below* the release they precede and would be wrongly skipped by
        # `min_torch` (or wrongly run under `max_torch`).
        torch_version = (
            Version(Version(get_distribution("torch").version).base_version)
            if (min_torch or max_torch)
            else None
        )

        if min_torch:
            conditions.append(torch_version < Version(min_torch))
            reasons.append(f"torch>={min_torch}")

        if max_torch:
            conditions.append(torch_version >= Version(max_torch))
            reasons.append(f"torch<{max_torch}")

        if min_python:
            py_version = (
                f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
            )
            conditions.append(Version(py_version) < Version(min_python))
            reasons.append(f"python>={min_python}")

        if skip_windows:
            conditions.append(_IS_WINDOWS)
            reasons.append("does not run on Windows")

        if rpc:
            conditions.append(not _RPC_AVAILABLE)
            reasons.append("RPC")

        if fairscale:
            conditions.append(not _FAIRSCALE_AVAILABLE)
            reasons.append("Fairscale")

        if deepspeed:
            conditions.append(not _DEEPSPEED_AVAILABLE)
            reasons.append("Deepspeed")

        # Report only the requirements that are actually unmet.
        unmet = [reason for cond, reason in zip(conditions, reasons) if cond]
        return pytest.mark.skipif(
            condition=any(conditions),
            reason=f"Requires: [{' + '.join(unmet)}]",
            **kwargs,
        )
|
|