Markus28 committed on
Commit
e86d612
·
1 Parent(s): a62c2ab

feat: updated .to() override to handle kwargs

Browse files
Files changed (1) hide show
  1. modeling_bert.py +3 -3
modeling_bert.py CHANGED
@@ -422,9 +422,9 @@ class BertModel(BertPreTrainedModel):
422
  pooler_output=pooled_output,
423
  )
424
 
425
- def to(self, target):
426
- result = super().to(target)
427
- if isinstance(target, torch.dtype):
428
  for layer in result.encoder.layers:
429
  layer.mixer.inner_cross_attn.alibi_slopes = layer.mixer.inner_cross_attn.alibi_slopes.to(torch.float32)
430
  layer.mixer.inner_attn.alibi_slopes = layer.mixer.inner_attn.alibi_slopes.to(torch.float32)
 
422
  pooler_output=pooled_output,
423
  )
424
 
425
+ def to(self, *args, **kwargs):
426
+ result = super().to(*args, **kwargs)
427
+ if (len(args) > 0 and isinstance(args[0], torch.dtype)) or "dtype" in kwargs:
428
  for layer in result.encoder.layers:
429
  layer.mixer.inner_cross_attn.alibi_slopes = layer.mixer.inner_cross_attn.alibi_slopes.to(torch.float32)
430
  layer.mixer.inner_attn.alibi_slopes = layer.mixer.inner_attn.alibi_slopes.to(torch.float32)