Fix For NaN Logits in HuggingFace Distribution of OpenELM
#3 by jasonkrone - opened
I found that left padding of inputs led to NaN logits. The fix (credit to this thread) is to change the line `min_dtype = torch.finfo(dtype).min` to `min_dtype = torch.finfo(dtype).min / 2` in the function `_update_causal_mask`.
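For anyone curious why halving the value helps, here is a minimal standalone sketch of one way the NaNs can arise (my own illustration, not taken from the OpenELM source): in fp16, adding attention scores to a row that is entirely filled with `torch.finfo(dtype).min` (a fully left-padded row) can overflow to `-inf`, and softmax over an all-`-inf` row returns NaN. With `min / 2` the sum stays representable.

```python
import torch

dtype = torch.float16
min_dtype = torch.finfo(dtype).min  # -65504.0 for fp16

# Toy attention scores for a query whose keys are all padding tokens;
# the magnitude is large enough to push the fp16 sum past the overflow threshold.
scores = torch.full((1, 4), -100.0, dtype=dtype)

# Original mask value: the addition overflows to -inf in fp16,
# and softmax over an all--inf row yields NaN.
masked = scores + min_dtype
print(torch.softmax(masked, dim=-1))        # tensor([[nan, nan, nan, nan]], dtype=torch.float16)

# Proposed fix: min_dtype / 2 keeps the masked scores finite, so softmax
# returns (meaningless but finite) uniform weights instead of NaN.
masked_fixed = scores + min_dtype / 2
print(torch.softmax(masked_fixed, dim=-1))  # finite values, no NaN
```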
I presume all other OpenELM model sizes and variations require the same fix.
Note: the `if not is_tracing and torch.any(attention_mask != 1):` condition in the `_update_causal_mask` function seems to be addressing the same issue; however, that mitigation only runs when `self.config._attn_implementation == "sdpa"`, whereas the issue is present even when `self.config._attn_implementation == "eager"`.
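For completeness, a rough way to check whether a given checkpoint needs the fix is to run a left-padded batch through the model and look for NaNs in the logits. The snippet below is only a sketch of that check: the checkpoint id, the tokenizer choice (OpenELM's model card points to the Llama 2 tokenizer), and the `attn_implementation` argument are assumptions on my part, not something stated in this thread.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint and tokenizer; per the post above, other OpenELM sizes likely behave the same.
model_id = "apple/OpenELM-270M"
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # left padding is what triggers the NaN logits

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="eager",  # per the note above, the issue also reproduces with eager
    trust_remote_code=True,
)

batch = tokenizer(
    ["hi", "a noticeably longer prompt so the short one gets left-padded"],
    padding=True,
    return_tensors="pt",
)
with torch.no_grad():
    logits = model(**batch).logits

print("any NaN logits:", torch.isnan(logits).any().item())
```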
P.S. thanks for your work on OpenELM!