feat: updated .to() override to handle kwargs
Browse files
- modeling_bert.py +3 -3
modeling_bert.py
CHANGED
@@ -422,9 +422,9 @@ class BertModel(BertPreTrainedModel):
|
|
422 |
pooler_output=pooled_output,
|
423 |
)
|
424 |
|
425 |
-
def to(self,
|
426 |
-
result = super().to(
|
427 |
-
if isinstance(
|
428 |
for layer in result.encoder.layers:
|
429 |
layer.mixer.inner_cross_attn.alibi_slopes = layer.mixer.inner_cross_attn.alibi_slopes.to(torch.float32)
|
430 |
layer.mixer.inner_attn.alibi_slopes = layer.mixer.inner_attn.alibi_slopes.to(torch.float32)
|
|
|
422 |
pooler_output=pooled_output,
|
423 |
)
|
424 |
|
425 |
+
def to(self, *args, **kwargs):
|
426 |
+
result = super().to(*args, **kwargs)
|
427 |
+
if (len(args) > 0 and isinstance(args[0], torch.dtype)) or "dtype" in kwargs:
|
428 |
for layer in result.encoder.layers:
|
429 |
layer.mixer.inner_cross_attn.alibi_slopes = layer.mixer.inner_cross_attn.alibi_slopes.to(torch.float32)
|
430 |
layer.mixer.inner_attn.alibi_slopes = layer.mixer.inner_attn.alibi_slopes.to(torch.float32)
|