# Copyright (c) 2024, EleutherAI contributors
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from typing import Optional
from torch import Tensor

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def bias_dropout_add(
    x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float, training: bool
) -> Tensor:
    # add the preceding linear layer's bias, apply dropout, then add the residual
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    if residual is not None:
        out = residual + out
    return out


def get_bias_dropout_add(training):
    # return a closure with the training flag baked in, so call sites only
    # pass (x, bias, residual, prob)
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return _bias_dropout_add


# TorchScript-compiled variants: with the fusion flags set above, the JIT
# fuser can combine the bias add, dropout, and residual add into fewer kernels.
@torch.jit.script
def bias_dropout_add_fused_train(
    x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float
) -> Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(
    x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float
) -> Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)
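

# Minimal usage sketch of the functions above. This block is illustrative
# only: the tensor shapes and the dropout probability are assumed values,
# not taken from the original module or any particular model configuration.
if __name__ == "__main__":
    hidden_dropout = 0.1  # hypothetical dropout probability
    x = torch.randn(4, 16, 64)  # e.g. attention or MLP output (batch, seq, hidden)
    bias = torch.randn(64)  # broadcast bias from the preceding linear layer
    residual = torch.randn(4, 16, 64)  # residual stream input

    # Fused path: one TorchScript function per mode (training vs. inference).
    out_train = bias_dropout_add_fused_train(x, bias, residual, hidden_dropout)
    out_eval = bias_dropout_add_fused_inference(x, bias, residual, hidden_dropout)

    # Unfused path: a closure with the training flag baked in.
    fn = get_bias_dropout_add(training=True)
    out_unfused = fn(x, bias, residual, hidden_dropout)

    assert out_train.shape == out_eval.shape == out_unfused.shape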