davidhd commited on
Commit
35101e6
·
verified ·
1 Parent(s): 3fe920c

Fix attention mask dtype issues.

Browse files
Files changed (1) hide show
  1. amplify.py +1 -1
amplify.py CHANGED
@@ -248,7 +248,7 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
248
  if attention_mask is not None and not torch.all(attention_mask == 0):
249
  assert attention_mask.dtype != torch.bool and 1.0 not in attention_mask, (
250
  "AMPLIFY expects an additive attention_mask.\n"
251
- "Modify the output of the tokenizer with attention_mask = torch.where(attention_mask, float(0.0), float("-inf"))"
252
  )
253
  attention_mask = (
254
  attention_mask.unsqueeze(1)
 
248
  if attention_mask is not None and not torch.all(attention_mask == 0):
249
  assert attention_mask.dtype != torch.bool and 1.0 not in attention_mask, (
250
  "AMPLIFY expects an additive attention_mask.\n"
251
+ "Modify the output of the tokenizer with attention_mask = torch.where(attention_mask, float(0.0), float('-inf'))"
252
  )
253
  attention_mask = (
254
  attention_mask.unsqueeze(1)