Tests: actually gate FP8 use based on capability
tests/test_triton_scaled_mm.py
@@ -37,9 +37,9 @@ def get_8bit_types():
     capability = major * 10 + minor
     supports_fp8 = capability >= 89
 
-    if torch.version.hip is not None:
+    if supports_fp8 and torch.version.hip is not None:
         types.append(torch.float8_e4m3fnuz)
-    elif torch.version.cuda is not None and torch.cuda.is_available():
+    elif supports_fp8 and torch.version.cuda is not None and torch.cuda.is_available():
         types.append(torch.float8_e4m3fn)
     return types
 
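For context, here is a minimal sketch of how the gated helper might look in full. Everything outside the hunk above is an assumption, not the test file's actual code: in particular, the `torch.int8` default and the use of `torch.cuda.get_device_capability()` to obtain `major`/`minor` are illustrative.

# A minimal sketch, assuming major/minor come from the active device's
# compute capability; only the gating logic matches the diff above.
import torch


def get_8bit_types():
    types = [torch.int8]

    # Assumed: derive the numeric capability from the device, e.g. (8, 9) -> 89.
    major, minor = (torch.cuda.get_device_capability()
                    if torch.cuda.is_available() else (0, 0))
    capability = major * 10 + minor
    supports_fp8 = capability >= 89

    # FP8 dtypes are appended only when the hardware supports them.
    if supports_fp8 and torch.version.hip is not None:
        types.append(torch.float8_e4m3fnuz)
    elif supports_fp8 and torch.version.cuda is not None and torch.cuda.is_available():
        types.append(torch.float8_e4m3fn)
    return types

With this gating, GPUs below capability 8.9 keep the list at int8 only, so the FP8 parametrizations are simply not generated on hardware that cannot run them.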