Fix precision error

modeling_chatglm.py  CHANGED  +9 -7
@@ -3,9 +3,7 @@
 import math
 import copy
 import warnings
-import re
 import sys
-
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F

@@ -183,9 +181,14 @@ class RMSNorm(torch.nn.Module):
         self.eps = eps
 
     def forward(self, hidden_states: torch.Tensor):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        if hidden_states == torch.bfloat16:
+            norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
+            x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
+            return self.weight * x_normed
+        else:
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
 
         return (self.weight * hidden_states).to(input_dtype)

@@ -517,8 +520,7 @@ class GLMBlock(torch.nn.Module):
 
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
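For reference, the RMSNorm.forward change above computes root-mean-square normalization: each vector is scaled by the reciprocal square root of the mean of its squares, then multiplied by a learned weight, i.e. y = weight * x / sqrt(mean(x^2) + eps). Below is a minimal standalone sketch of that math, not the repository file itself. The class name RMSNormSketch, the default eps, and the spelling of the dtype check as hidden_states.dtype == torch.bfloat16 are assumptions added for illustration; the diff as committed compares the tensor object itself to torch.bfloat16.

import torch

class RMSNormSketch(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, dtype=None):
        super().__init__()
        # Learned per-channel scale, analogous to `self.weight` in the diff.
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        if hidden_states.dtype == torch.bfloat16:  # assumed dtype check; the diff compares the tensor itself
            # bf16 path from the diff: mean of squares over the last dim, no float32 upcast.
            norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
            return self.weight * (hidden_states * torch.rsqrt(norm_x + self.eps))
        # Other dtypes: compute the second-moment term in float32, then cast back.
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden_states).to(input_dtype)

# Usage: normalize a batch of hidden states over the last dimension.
x = torch.randn(2, 4, 4096)
norm = RMSNormSketch(4096)
print(norm(x).shape)  # torch.Size([2, 4, 4096])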