Update modeling_openelm.py
modeling_openelm.py (+5 -3)
@@ -778,9 +778,11 @@ class OpenELMModel(OpenELMPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
                 :, None, None, :
             ].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
-                padding_mask, min_dtype
-            )
+            causal_mask = causal_mask.clone()
+            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            #causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
+            #    padding_mask, min_dtype
+            #)
 
         if self.config._attn_implementation == "sdpa" and attention_mask is not None:
             # For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
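Context for the change: the causal mask is typically built once and broadcast over the batch with expand(), so the slice assignment writes into overlapping memory and PyTorch rejects the in-place edit; cloning first gives the mask its own storage. The sketch below is a minimal, self-contained reproduction of that failure mode, not the OpenELM source; the shapes, batch size, and attention_mask values are illustrative assumptions.

    import torch

    # Minimal sketch (illustrative, not the model code): why clone() is needed
    # before the in-place masked_fill. An expand()ed mask shares a single copy
    # of its storage across the batch, so writing into a slice of it fails.
    min_dtype = torch.finfo(torch.float32).min
    base = torch.full((4, 4), min_dtype).triu(1)             # 0.0 where attending is allowed
    causal_mask = base[None, None, :, :].expand(2, 1, 4, 4)  # broadcast view, batch size 2

    attention_mask = torch.tensor([[1, 1, 1, 0],             # sample 0: last position padded
                                   [1, 1, 0, 0]])            # sample 1: last two padded
    mask_length = attention_mask.shape[-1]
    # Padded key positions that the causal mask currently leaves open:
    padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)

    try:
        # Same pattern as the removed lines: in-place write into the expanded view.
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
    except RuntimeError as err:
        print("write into expanded view fails:", err)

    causal_mask = causal_mask.clone()  # own, contiguous storage, so the edit succeeds
    causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

This matches the causal_mask = causal_mask.clone() fix that upstream transformers applies in its Llama _update_causal_mask ("copy to contiguous memory for in-place edit"); OpenELM's modeling file copies that code path, so it needs the same clone before the in-place edit.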