from transformers.integrations import TransformersPlugin, replace_target_class

from .llama_xformers_attention import LlamaXFormersAttention


class LlamaXFormersPlugin(TransformersPlugin):
    """Plugin that swaps the stock LlamaAttention module for an xFormers-backed one."""

    def __init__(self, config):
        pass

    def process_model_pre_init(self, model):
        # Replace every LlamaAttention module with LlamaXFormersAttention,
        # forwarding the model config so the replacement modules are built consistently.
        model_config = model.config
        replace_target_class(
            model,
            LlamaXFormersAttention,
            "LlamaAttention",
            kwargs={"config": model_config},
        )

    def process_model_post_init(self, model):
        # No post-initialization work is needed for this plugin.
        pass
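

# --- Usage sketch (an assumption, not part of the plugin code above) ---
# Minimal illustration of the hook call signatures only. The checkpoint name and
# the point at which a plugin framework would actually invoke these hooks are
# assumptions; the relative import above also means this block is illustrative
# rather than meant for direct execution of this module.
if __name__ == "__main__":
    from transformers import LlamaForCausalLM

    model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

    plugin = LlamaXFormersPlugin(model.config)
    plugin.process_model_pre_init(model)   # swaps LlamaAttention -> LlamaXFormersAttention
    plugin.process_model_post_init(model)  # no-op in this plugin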