radna committed on
Commit
8f8935f
1 Parent(s): d4b4104

Update flash_attention.py

Files changed (1)
  1. flash_attention.py +3 -3
flash_attention.py CHANGED
@@ -66,7 +66,7 @@ class FlashAttention(nn.Module):
                 max_s,
                 self.dropout_p if self.training else 0.0,
                 self.softmax_scale,
-                causal,
+                causal
             )
             output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
         else:
@@ -82,7 +82,7 @@ class FlashAttention(nn.Module):
                 max_s,
                 self.dropout_p if self.training else 0.0,
                 self.softmax_scale,
-                causal,
+                causal
             )
             output = rearrange(
                 pad_input(
@@ -102,7 +102,7 @@ class FlashAttention(nn.Module):
                 max_s,
                 self.dropout_p if self.training else 0.0,
                 self.softmax_scale,
-                causal,
+                causal
             )

         return output, None
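
All three hunks make the same edit: the trailing comma after the causal argument is dropped. This is purely cosmetic in Python, since a trailing comma in an argument list does not change the call. For context, below is a minimal sketch of the shared call pattern, assuming the flash_attn v1 unpadded qkv-packed interface; the function name and the qkv/cu_seqlens arguments do not appear in the diff and are assumptions.

# Minimal sketch of one of the three call sites, assuming the flash_attn v1
# interface. Only max_s, dropout_p, softmax_scale, and causal are visible in
# the diff; flash_attn_unpadded_qkvpacked_func, qkv, and cu_seqlens are
# assumptions for illustration.
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func


def attend(qkv, cu_seqlens, max_s, dropout_p, softmax_scale, causal, training):
    # qkv: packed (total_tokens, 3, n_heads, head_dim) tensor
    # cu_seqlens: (batch + 1,) int32 cumulative sequence lengths
    return flash_attn_unpadded_qkvpacked_func(
        qkv,
        cu_seqlens,
        max_s,                            # longest sequence in the batch
        dropout_p if training else 0.0,   # dropout only during training
        softmax_scale,                    # typically 1 / sqrt(head_dim)
        causal                            # the argument touched by this commit
    )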