|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
import torch |
|
|
|
|
|
def require_torch_gpu(test_case):
    """
    Skip ``test_case`` unless PyTorch reports that CUDA is available.

    Args:
        test_case: The test function or ``unittest.TestCase`` subclass being decorated.

    Returns:
        ``test_case`` itself when ``torch.cuda.is_available()`` is true; otherwise
        ``test_case`` wrapped by ``unittest.skip("test requires GPU")``.
    """
    # skipUnless yields an identity decorator when the condition holds and
    # unittest.skip(reason) when it does not, matching an explicit if/else.
    has_cuda = torch.cuda.is_available()
    return unittest.skipUnless(has_cuda, "test requires GPU")(test_case)
|
|
|
|
|
def require_torch_multi_gpu(test_case):
    """
    Skip ``test_case`` unless CUDA is available with at least two devices.

    Args:
        test_case: The test function or ``unittest.TestCase`` subclass being decorated.

    Returns:
        ``test_case`` itself when CUDA is available and
        ``torch.cuda.device_count()`` is 2 or more; otherwise ``test_case``
        wrapped by ``unittest.skip("test requires multiple GPUs")``.
    """
    # The ``and`` short-circuits, so device_count() is only queried once
    # is_available() has returned true.
    enough_gpus = torch.cuda.is_available() and torch.cuda.device_count() >= 2
    return unittest.skipUnless(enough_gpus, "test requires multiple GPUs")(test_case)
|
|
|
|
|
def require_bitsandbytes(test_case):
    """
    Skip ``test_case`` when the ``bitsandbytes`` package cannot be imported.

    The check runs once, when the decorator is applied. It attempts the import
    and treats an ``ImportError`` as "not installed".

    Args:
        test_case: The test function or ``unittest.TestCase`` subclass being decorated.

    Returns:
        ``test_case`` itself when the import succeeds; otherwise ``test_case``
        wrapped by ``unittest.skip("test requires bitsandbytes")``.
    """
    try:
        # Imported only to probe availability; the name itself is never used.
        import bitsandbytes  # noqa: F401
    except ImportError:
        available = False
    else:
        available = True
    return unittest.skipUnless(available, "test requires bitsandbytes")(test_case)
|
|