Spaces: Running on L40S
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
class FP32LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in at least float32 precision.

    Low-precision inputs (e.g. ``float16`` / ``bfloat16``) are upcast to
    ``float32``, together with the affine parameters, before normalizing.
    The result is cast back to the input's dtype. Inputs that are already
    ``float64`` are computed in ``float64`` rather than being downcast.

    Accepts the same constructor arguments as :class:`torch.nn.LayerNorm`.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply layer normalization over ``self.normalized_shape``.

        Args:
            inputs: Tensor whose trailing dimensions match
                ``self.normalized_shape``.

        Returns:
            The normalized tensor, with the same dtype as ``inputs``.
        """
        origin_dtype = inputs.dtype
        # Promote to *at least* fp32. Half-precision inputs are upcast so the
        # mean/variance reduction is numerically stable. fp64 inputs stay
        # fp64 instead of being silently downcast and losing precision.
        compute_dtype = torch.promote_types(origin_dtype, torch.float32)
        weight = self.weight.to(compute_dtype) if self.weight is not None else None
        bias = self.bias.to(compute_dtype) if self.bias is not None else None
        return F.layer_norm(
            inputs.to(compute_dtype),
            self.normalized_shape,
            weight,
            bias,
            self.eps,
        ).to(origin_dtype)