llama-xformers / transformers_plugin.py
ybelkada's picture
Update transformers_plugin.py
b90a036
raw
history blame
274 Bytes
from transformers.integrations import TransformersPlugin
class LlamaXFormersPlugin(TransformersPlugin):
    """Placeholder ``TransformersPlugin`` subclass whose hooks do nothing yet.

    Every hook below accepts its argument and returns ``None`` without
    touching it. The name suggests an xformers-backed Llama integration, but
    no such logic is implemented here.
    """

    def __init__(self, config):
        """Accept a plugin configuration object and discard it.

        NOTE(review): the base-class ``__init__`` is not invoked, matching the
        original code; confirm ``TransformersPlugin`` does not need it.
        """
        del config  # unused for now; kept so the signature stays stable

    def process_model_pre_init(self, model):
        """No-op hook; returns ``None`` and leaves ``model`` untouched.

        Presumably called by the plugin framework before model
        initialisation, judging only by the name. TODO confirm, and confirm
        whether a return value is expected.
        """

    def process_model_post_init(self, model):
        """No-op hook; returns ``None`` and leaves ``model`` untouched.

        Presumably called by the plugin framework after model
        initialisation, judging only by the name. TODO confirm, and confirm
        whether a return value is expected.
        """