Update modeling_openelm.py
Browse files- modeling_openelm.py +3 -1
modeling_openelm.py
CHANGED
@@ -779,7 +779,9 @@ class OpenELMModel(OpenELMPreTrainedModel):
|
|
779 |
:, None, None, :
|
780 |
].eq(0.0)
|
781 |
causal_mask = causal_mask.clone()
|
782 |
-
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
|
|
|
|
783 |
#causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
784 |
# padding_mask, min_dtype
|
785 |
#)
|
|
|
779 |
:, None, None, :
|
780 |
].eq(0.0)
|
781 |
causal_mask = causal_mask.clone()
|
782 |
+
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
783 |
+
padding_mask, min_dtype
|
784 |
+
)
|
785 |
#causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
786 |
# padding_mask, min_dtype
|
787 |
#)
|