import contextlib

import torch
import torch.nn as nn
from .gptq import *
from .modelutils import *
from .quant import *
import transformers.modeling_utils
from transformers import BloomForCausalLM as LM

# torch.nn.init functions replaced with a no-op while the model is built.
_SKIPPED_INIT_FNS = ("kaiming_uniform_", "uniform_", "normal_")

# Sentinel meaning "attribute was absent" when saving a module global.
_MISSING = object()


def _noop(*args, **kwargs):
    """Stand-in for a torch.nn.init function; leaves the tensor untouched."""


@contextlib.contextmanager
def _skip_weight_init():
    """Disable random weight initialisation for the duration of the block.

    Replaces the torch.nn.init functions named in ``_SKIPPED_INIT_FNS`` with a
    no-op and sets ``transformers.modeling_utils._init_weights`` to False.
    Everything is restored on exit, even if the body raises, so the patch
    does not leak into models constructed later in the same process.
    """
    modeling_utils = transformers.modeling_utils
    saved_fns = {name: getattr(torch.nn.init, name) for name in _SKIPPED_INIT_FNS}
    saved_flag = getattr(modeling_utils, "_init_weights", _MISSING)
    try:
        for name in _SKIPPED_INIT_FNS:
            setattr(torch.nn.init, name, _noop)
        modeling_utils._init_weights = False
        yield
    finally:
        for name, fn in saved_fns.items():
            setattr(torch.nn.init, name, fn)
        if saved_flag is _MISSING:
            # The flag did not exist before; remove the one we created.
            vars(modeling_utils).pop("_init_weights", None)
        else:
            modeling_utils._init_weights = saved_flag


@contextlib.contextmanager
def _default_dtype(dtype):
    """Temporarily make *dtype* torch's default floating-point dtype.

    The previous default is restored on exit, even if the body raises.
    """
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


class SakuraForCausalLM(LM):
    """BLOOM causal LM whose layers are replaced via ``make_quant``.

    The model is constructed with fp16 as the default dtype and with random
    weight initialisation disabled. NOTE(review): this presumes real weights
    are loaded afterwards by the caller; confirm with the loader.
    It is then put in eval mode. Every layer reported by ``find_layers``,
    except those named in ``quant_skip_layers``, is passed to ``make_quant``.
    """

    # Positional arguments forwarded to make_quant after the layer dict.
    # Presumably (bits, groupsize) as in GPTQ, where -1 means no grouping;
    # confirm against .quant.
    quant_bits = 4
    quant_groupsize = -1
    # Layer names excluded from make_quant.
    quant_skip_layers = ("lm_head",)

    def __init__(self, *args, **kwargs):
        """Build the model and hand its layers to ``make_quant``.

        All positional and keyword arguments are forwarded unchanged to the
        ``BloomForCausalLM`` constructor.
        """
        # Skipping random init only saves construction time. The patches
        # stay active through make_quant, as they did originally, and are
        # undone once construction finishes.
        with _skip_weight_init():
            # Allocate parameters with fp16 as the default dtype. The
            # caller's default dtype is restored even if construction fails.
            with _default_dtype(torch.half):
                super().__init__(*args, **kwargs)
            self.eval()
            layers = find_layers(self)
            for name in self.quant_skip_layers:
                if name in layers:
                    del layers[name]
            make_quant(self, layers, self.quant_bits, self.quant_groupsize)