from functools import partial

import torch
from optimum.quanto.tensor import QTensor
|
def hacked_state_dict(self, *args, **kwargs):
    """Return a state dict in which quanto-quantized weights are replaced by their
    dequantized floating point values and the quantization scale entries are dropped."""
    orig_state_dict = self.orig_state_dict(*args, **kwargs)
    new_state_dict = {}
    for key, value in orig_state_dict.items():
        # Scale entries are folded into the dequantized weights below, so skip them here.
        if key.endswith("._scale"):
            continue
        if key.endswith(".input_scale"):
            continue
        if key.endswith(".output_scale"):
            continue
        if key.endswith("._data"):
            key = key[: -len("._data")]
            scale = orig_state_dict[key + "._scale"]

            # Dequantize in float32, then cast back to the scale's dtype at the end.
            dtype = scale.dtype
            scale = scale.float()
            value = value.float()
            dequantized = value * scale

            # Only weights with unit input/output activation scales can be
            # dequantized by a simple elementwise multiply.
            input_scale = orig_state_dict.get(key + ".input_scale")
            if input_scale is not None:
                if input_scale.item() != 1.0:
                    raise ValueError("Input scale is not 1.0, cannot dequantize")

            output_scale = orig_state_dict.get(key + ".output_scale")
            if output_scale is not None:
                if output_scale.item() != 1.0:
                    raise ValueError("Output scale is not 1.0, cannot dequantize")

            new_state_dict[key] = dequantized.to("cpu", dtype=dtype)
        else:
            new_state_dict[key] = value
    return new_state_dict
|
|
|
|
|
def patch_dequantization_on_save(model):
    """Monkey-patch `model.state_dict` so that saving the model produces dequantized
    floating point weights instead of quanto's packed data/scale tensors."""
    model.orig_state_dict = model.state_dict
    model.state_dict = partial(hacked_state_dict, model)
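
# A minimal usage sketch (assumption: `model` is an nn.Module whose weights were
# quantized and frozen with optimum-quanto; the output path is arbitrary):
#
#   patch_dequantization_on_save(model)
#   torch.save(model.state_dict(), "model_dequantized.pt")  # saved weights are plain float tensors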
|
|
|
|
|
def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
    """
    Convert a quantized parameter back to a regular Parameter with floating point values.

    Args:
        module: The module containing the parameter to dequantize
        param_name: Name of the parameter to dequantize (e.g., 'weight', 'bias')

    Returns:
        bool: True if the parameter was dequantized, False if it was not quantized to begin with
    """
    if not hasattr(module, param_name):
        raise AttributeError(f"Module has no parameter named '{param_name}'")

    param = getattr(module, param_name)
    if not isinstance(param, torch.nn.Parameter):
        raise TypeError(f"'{param_name}' is not a Parameter")
    if not isinstance(param, QTensor):
        # Already a plain floating point parameter; nothing to do.
        return False

    # Materialize the quantized tensor as float and swap it into the module.
    with torch.no_grad():
        float_tensor = param.float()
        new_param = torch.nn.Parameter(float_tensor, requires_grad=param.requires_grad)

    setattr(module, param_name, new_param)
    return True
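
# A minimal usage sketch (assumption: `model` contains modules whose `weight`
# parameters were quantized with optimum-quanto; names are illustrative):
#
#   for module in model.modules():
#       if isinstance(getattr(module, "weight", None), torch.nn.Parameter):
#           dequantize_parameter(module, "weight")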