Fix a couple of bugs and add tests from vLLM
- ext-torch/__init__.py +35 -29
- ext-torch/torch_binding.cpp +4 -0
- tests/__init__.py +0 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/allclose_default.py +14 -0
- tests/kernels/test_activation.py +139 -0
- tests/kernels/utils.py +73 -0
    	
ext-torch/__init__.py CHANGED

@@ -6,36 +6,42 @@ except ImportError as e:
     # Fallback for local development.
     try:
         import _activation
+
         ops = torch.ops._activition
     except ImportError:
         raise e
-
-
-def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    ops.silu_and_mul(out, x)
[old lines 16-40: a further 25 removed lines are not legible in the source page]
+
+
+def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    ops.silu_and_mul(out, x)
+    return out
+
+
+def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    ops.gelu_and_mul(out, x)
+    return out
+
+
+def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    ops.gelu_tanh_and_mul(out, x)
+    return out
+
+
+def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0) -> None:
+    ops.fatrelu_and_mul(out, x, threshold)
+    return out
+
+
+def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
+    ops.gelu_fast(out, x)
+    return out
+
+
+def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
+    ops.gelu_new(out, x)
+    return out
+
+
+def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
     ops.gelu_quick(out, x)
+    return out
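With this change every Python wrapper returns out after launching the kernel, so the result can be used as an expression as well as mutated in place. A minimal usage sketch, assuming the built extension is importable as activation (the package name the new tests import) and a CUDA device is available:

import torch
import activation

x = torch.randn(16, 2 * 512, dtype=torch.float16, device="cuda")
out = torch.empty(16, 512, dtype=x.dtype, device=x.device)

# silu_and_mul writes SiLU(x[..., :512]) * x[..., 512:] into `out`
# and, after this change, also returns it.
y = activation.silu_and_mul(out, x)
assert y is out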
    	
ext-torch/torch_binding.cpp CHANGED

@@ -28,6 +28,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
+
+  // Quick GELU implementation.
+  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
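This registration appears to be the second bug fix: the Python wrapper already called ops.gelu_quick, but no gelu_quick schema or CUDA impl was defined in the binding, so the op could not resolve. A hedged sketch of calling it through the activation.ops alias that the tests use:

import torch
import activation  # importing the package loads the extension and registers the ops

x = torch.randn(8, 512, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)

# Dispatches to the registered CUDA kernel; the schema
# "gelu_quick(Tensor! out, Tensor input) -> ()" means it mutates `out` in place.
activation.ops.gelu_quick(out, x)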
    	
tests/__init__.py ADDED

(Empty file: marks tests/ as a package.)

tests/kernels/__init__.py ADDED

(Empty file: marks tests/kernels/ as a package so the relative imports in test_activation.py resolve.)
    	
tests/kernels/allclose_default.py ADDED

@@ -0,0 +1,14 @@
import torch

# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}


def get_default_atol(output) -> float:
    return default_atol[output.dtype]


def get_default_rtol(output) -> float:
    return default_rtol[output.dtype]
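These tables provide dtype-dependent tolerances that test_activation passes to torch.testing.assert_close. A small illustration of the lookup, assuming the repository root is on sys.path so that tests.kernels is importable (the empty __init__.py files above make it a package):

import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol

out = torch.randn(4, 8, dtype=torch.bfloat16)
ref = out.clone()

# bfloat16 maps to atol=1e-3, rtol=1.6e-2 in the tables above.
torch.testing.assert_close(
    out, ref, atol=get_default_atol(out), rtol=get_default_rtol(out)
)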
    	
tests/kernels/test_activation.py ADDED

@@ -0,0 +1,139 @@
import math
import random
from typing import Type

import activation
import pytest
import torch
import torch.nn.functional as F

from .utils import opcheck
from .allclose_default import get_default_atol, get_default_rtol

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 13824]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]


def gelu_fast(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))


def gelu_new(x: torch.Tensor) -> torch.Tensor:
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))


def gelu_quick(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(1.702 * x)


def fatrelu_and_mul(x: torch.Tensor, threshold: float) -> torch.Tensor:
    d = x.shape[-1] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    x1 = F.threshold(x1, threshold, 0.0)
    return x1 * x2


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]


@pytest.mark.parametrize("activation_name", ["silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation_name: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
    if activation_name == "silu":
        torch_fn = silu_and_mul
        fn = activation.silu_and_mul
        op = activation.ops.silu_and_mul
    elif activation_name == "gelu":
        torch_fn = lambda x: gelu_and_mul(x, "none")
        fn = activation.gelu_and_mul
        op = activation.ops.gelu_and_mul
    elif activation_name == "gelu_tanh":
        torch_fn = lambda x: gelu_and_mul(x, "tanh")
        fn = activation.gelu_tanh_and_mul
        op = activation.ops.gelu_tanh_and_mul
    elif activation_name == "fatrelu":
        threshold = random.uniform(0, 1)
        torch_fn = lambda x: fatrelu_and_mul(x, threshold)
        fn = lambda out, x: activation.fatrelu_and_mul(out, x, threshold)
        op = activation.ops.fatrelu_and_mul

    out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    out = fn(out, x)
    ref_out = torch_fn(x)

    # The SiLU, GELU and FatReLU implementations are equivalent to the native
    # PyTorch implementations, so we can do exact comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    if activation_name == "fatrelu":
        opcheck(op, (out, x, threshold))
    else:
        opcheck(op, (out, x))


@pytest.mark.parametrize(
    "activation_fns",
    [
        (gelu_fast, activation.gelu_fast, activation.ops.gelu_fast),
        (gelu_new, activation.gelu_new, activation.ops.gelu_new),
        (gelu_quick, activation.gelu_quick, activation.ops.gelu_quick),
    ],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation_fns,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)
    torch_fn, fn, op = activation_fns
    out = fn(torch.empty_like(x), x)
    ref_out = torch_fn(x)
    torch.testing.assert_close(
        out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
    )

    out = torch.empty_like(x)
    opcheck(op, (out, x))
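For reference, a single test_act_and_mul case (the "silu" branch with num_tokens=7, d=512, fp16) reduces to the following standalone sketch; it assumes the built activation package and a CUDA device and mirrors the names used in the test above:

import torch
import torch.nn.functional as F
import activation

torch.manual_seed(0)
torch.set_default_device("cuda:0")

num_tokens, d = 7, 512
x = torch.randn(num_tokens, 2 * d, dtype=torch.half)

# Kernel output: the wrapper fills the preallocated buffer and returns it.
out = activation.silu_and_mul(torch.empty(num_tokens, d, dtype=x.dtype), x)

# Eager reference: SiLU of the first half, gated by the second half.
ref = F.silu(x[..., :d]) * x[..., d:]
torch.testing.assert_close(out, ref, atol=0.0, rtol=0.0)

The full suite would typically be run with pytest from the repository root (for example pytest tests/kernels/test_activation.py), which is why the package __init__.py files are needed for the relative imports.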
    	
tests/kernels/utils.py ADDED

@@ -0,0 +1,73 @@
"""Kernel test utils"""

import itertools
import random
import unittest
from numbers import Number
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import pytest
import torch
from torch._prims_common import TensorLikeType

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    return bool(
        torch.all(
            torch.isclose(
                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
            )
        ).item()
    )


# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(
    op: Union[
        torch._ops.OpOverload,
        torch._ops.OpOverloadPacket,
        torch._library.custom_ops.CustomOpDef,
    ],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True
) -> Dict[str, str]:
    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
        return (
            torch.library.opcheck(
                op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
            )
            if cond
            else {}
        )
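opcheck here is a thin wrapper around torch.library.opcheck that patches torch.allclose with the fp8-aware fp8_allclose and degrades to a no-op (returning {}) when cond is False. A hedged usage sketch mirroring how the activation tests call it, but with the restricted DEFAULT_OPCHECK_TEST_UTILS set:

import torch
import activation
from tests.kernels.utils import opcheck, DEFAULT_OPCHECK_TEST_UTILS

x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
out = torch.empty(4, 512, dtype=x.dtype, device=x.device)

# Checks the schema, autograd registration and fake-tensor support of the
# registered op; returns a dict mapping each test_util name to its result.
results = opcheck(
    activation.ops.silu_and_mul,
    (out, x),
    test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)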

