ybelkada committed
Commit 1141922
1 Parent(s): 4f95d51

Update llama_xformers_attention.py

Files changed (1):
  1. llama_xformers_attention.py +5 -1
llama_xformers_attention.py CHANGED
@@ -3,7 +3,11 @@ import torch.nn as nn
 
 from typing import Optional, Tuple
 
-from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+
+from xformers.ops.fmha import (
+    memory_efficient_attention,
+)
 
 class LlamaXFormersAttention(LlamaAttention):
     def forward(
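
The commit only changes the imports; the forward body is not shown here. For context, below is a minimal sketch of how the newly imported memory_efficient_attention is typically called in place of standard scaled-dot-product attention. The example dimensions, the tensor-layout transposes, and the LowerTriangularMask causal bias are illustrative assumptions, not the actual body of this file's forward method; likewise, apply_rotary_pos_emb would be applied to the query and key states before this call, as in the stock LlamaAttention (its exact signature varies across transformers versions, so it is omitted here).

# Minimal sketch, NOT the actual forward body of llama_xformers_attention.py:
# how xformers' memory_efficient_attention is typically swapped in for
# standard scaled-dot-product attention. Shapes and sizes are illustrative
# assumptions; running this requires a CUDA device with xformers installed.
import torch
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import LowerTriangularMask

batch, num_heads, seq_len, head_dim = 2, 8, 16, 64

# LLaMA attention keeps q/k/v as [batch, heads, seq, head_dim];
# memory_efficient_attention expects [batch, seq, heads, head_dim],
# hence the transposes below.
q = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

attn_output = memory_efficient_attention(
    q.transpose(1, 2),
    k.transpose(1, 2),
    v.transpose(1, 2),
    attn_bias=LowerTriangularMask(),  # causal masking for autoregressive decoding
)

# Flatten heads back to [batch, seq, hidden] before the output projection.
attn_output = attn_output.reshape(batch, seq_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([2, 16, 512])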