
Fix For NaN Logits in HuggingFace Distribution of OpenELM

#3 by jasonkrone

I found that left padding of inputs leads to NaN logits. The fix (credit to this thread) is to change the line min_dtype = torch.finfo(dtype).min to min_dtype = torch.finfo(dtype).min / 2 in the function _update_causal_mask.
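
For intuition, here is a minimal, self-contained sketch of why halving the fill value matters. It assumes the NaN comes from the mask fill value overflowing float16 to -inf once it is added to the attention scores, which is my reading of the failure mode rather than something stated above; the shapes and score values are made up for illustration.

```python
import torch

dtype = torch.float16
# Attention scores for a single fully-masked query row, e.g. a left-padding
# position; values this negative are realistic in practice.
scores = torch.full((1, 4), -1000.0, dtype=dtype)

# Filling masked positions with finfo.min: adding it to the scores overflows
# float16 to -inf, and softmax over a row that is all -inf returns NaN.
mask_min = torch.full_like(scores, torch.finfo(dtype).min)
print(torch.softmax(scores + mask_min, dim=-1))   # tensor([[nan, nan, nan, nan]])

# Filling with finfo.min / 2 keeps the sum finite, so softmax stays defined
# (a uniform distribution over the fully-masked row instead of NaN).
mask_half = torch.full_like(scores, torch.finfo(dtype).min / 2)
print(torch.softmax(scores + mask_half, dim=-1))  # finite values
```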

I presume all other OpenELM model sizes and variations require the same fix.

Note: the condition if not is_tracing and torch.any(attention_mask != 1): in _update_causal_mask seems to address the same issue (see the sketch below); however, that mitigation is only applied when self.config._attn_implementation == "sdpa", whereas the issue is present even when self.config._attn_implementation == "eager".
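
For reference, a rough sketch of the branch being described; the helper name and structure are my paraphrase of the usual HuggingFace unmask-unattended pattern, not code copied from the OpenELM source.

```python
import torch

def apply_sdpa_mitigation(causal_mask, attention_mask, attn_implementation, min_dtype):
    # Hypothetical helper mirroring the branch described above (names assumed).
    is_tracing = torch.jit.is_tracing()
    if attn_implementation == "sdpa" and not is_tracing and torch.any(attention_mask != 1):
        # Rows of the causal mask that are masked everywhere (left-padding query
        # rows) are reset to zero, so softmax over them stays finite.
        causal_mask = causal_mask.mul(
            ~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
        )
    # With "eager" attention this branch is skipped entirely, which is why the
    # NaN can still appear there.
    return causal_mask
```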

P.S. thanks for your work on OpenELM!

