Fix For NaN Logits in HuggingFace Distribution of OpenELM
Pull request #3, opened by jasonkrone
Files changed: modeling_openelm.py (+1 -1)

modeling_openelm.py CHANGED
@@ -766,7 +766,7 @@ class OpenELMModel(OpenELMPreTrainedModel):
         )
 
         # We use the current dtype to avoid any overflows
-        min_dtype = torch.finfo(dtype).min
+        min_dtype = torch.finfo(dtype).min / 2
         causal_mask = (
             self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
             * min_dtype