ybelkada commited on
Commit
2d1091a
1 Parent(s): 26dc06a

Update transformers_plugin.py

Browse files
Files changed (1) hide show
  1. transformers_plugin.py +1 -1
transformers_plugin.py CHANGED
@@ -8,7 +8,7 @@ class LlamaXFormersPlugin(TransformersPlugin):
8
 
9
  def process_model_pre_init(self, model):
10
  model_config = model.config
11
- replace_target_class(model, LlamaXFormersAttention, "LlamaAttention", kwargs={"config": model_config})
12
 
13
 
14
  def process_model_post_init(self, model):
 
8
 
9
  def process_model_pre_init(self, model):
10
  model_config = model.config
11
+ replace_target_class(model, LlamaXFormersAttention, "LlamaAttention", init_kwargs={"config": model_config})
12
 
13
 
14
  def process_model_post_init(self, model):