# llama-xformers / transformers_plugin.py
# Author: ybelkada — "Update transformers_plugin.py" (commit a5c24a9)
from transformers.integrations import TransformersPlugin, replace_target_class
from .llama_xformers_attention import LlamaXFormersAttention
class LlamaXFormersPlugin(TransformersPlugin):
    """Transformers plugin that swaps Llama's attention for an xFormers-backed one.

    Before model initialization, every ``LlamaAttention`` submodule is replaced
    with ``LlamaXFormersAttention`` via ``replace_target_class``.
    """

    def __init__(self, config):
        # Fix: the original discarded `config` entirely (`pass`). Store it so
        # the plugin's configuration remains inspectable after construction.
        # NOTE(review): the base class's __init__ is not called here, matching
        # the original behavior — confirm TransformersPlugin requires none.
        self.config = config

    def process_model_pre_init(self, model):
        """Replace ``LlamaAttention`` modules with ``LlamaXFormersAttention``.

        The replacement modules are constructed with the model's own config
        (``init_kwargs={"config": model.config}``).

        Returns the (mutated) model so callers can chain on it.
        """
        model_config = model.config
        replace_target_class(
            model,
            LlamaXFormersAttention,
            "LlamaAttention",
            init_kwargs={"config": model_config},
        )
        return model

    def process_model_post_init(self, model):
        """No post-init processing is needed; return the model unchanged."""
        return model