# Copyright Forge 2024

import time
import torch
import contextlib

from backend import stream, memory_management, utils
from backend.patcher.lora import merge_lora_to_weight


# Keeps (weight, bias, finished_signal) tuples alive until the compute stream has finished
# using the casted tensors; entries are reclaimed in main_stream_worker().
stash = {}


def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
    # Collect the layer's weight and bias, optionally transformed by weight_fn/bias_fn
    # (e.g. dequantization), moved/cast via weight_args/bias_args, and with any online
    # LoRA patches attached to the layer merged in.
    patches = getattr(layer, 'forge_online_loras', None)
    weight_patches, bias_patches = None, None

    if patches is not None:
        weight_patches = patches.get('weight', None)
        bias_patches = patches.get('bias', None)

    weight = None
    if layer.weight is not None:
        weight = layer.weight
        if weight_fn is not None:
            if weight_args is not None:
                fn_device = weight_args.get('device', None)
                if fn_device is not None:
                    weight = weight.to(device=fn_device)
            weight = weight_fn(weight)
        if weight_args is not None:
            weight = weight.to(**weight_args)
        if weight_patches is not None:
            weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)

    bias = None
    if layer.bias is not None:
        bias = layer.bias
        if bias_fn is not None:
            if bias_args is not None:
                fn_device = bias_args.get('device', None)
                if fn_device is not None:
                    bias = bias.to(device=fn_device)
            bias = bias_fn(bias)
        if bias_args is not None:
            bias = bias.to(**bias_args)
        if bias_patches is not None:
            bias = merge_lora_to_weight(patches=bias_patches, weight=bias, key="online bias lora", computation_dtype=bias.dtype)

    return weight, bias


def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
    # Cast the layer's parameters to x's device/dtype. When a mover stream is available,
    # the transfer runs on that stream and a recorded event ('signal') is returned so the
    # compute stream can wait on it inside main_stream_worker().
    weight, bias, signal = None, None, None
    non_blocking = True

    if getattr(x.device, 'type', None) == 'mps':
        non_blocking = False

    target_dtype = x.dtype
    target_device = x.device

    if skip_weight_dtype:
        weight_args = dict(device=target_device, non_blocking=non_blocking)
    else:
        weight_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)

    if skip_bias_dtype:
        bias_args = dict(device=target_device, non_blocking=non_blocking)
    else:
        bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)

    if stream.should_use_stream():
        with stream.stream_context()(stream.mover_stream):
            weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
            signal = stream.mover_stream.record_event()
    else:
        weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)

    return weight, bias, signal


@contextlib.contextmanager
def main_stream_worker(weight, bias, signal):
    # Run the wrapped computation on the current stream once the mover stream's copy has
    # finished, then stash the casted tensors until the compute stream is done with them.
    if signal is None or not stream.should_use_stream():
        yield
        return

    with stream.stream_context()(stream.current_stream):
        stream.current_stream.wait_event(signal)
        yield
        finished_signal = stream.current_stream.record_event()
        stash[id(finished_signal)] = (weight, bias, finished_signal)

    # Drop stash entries whose compute work has already completed.
    garbage = []
    for k, (w, b, s) in stash.items():
        if s.query():
            garbage.append(k)

    for k in garbage:
        del stash[k]
    return


def cleanup_cache():
    if not stream.should_use_stream():
        return

    stream.current_stream.synchronize()
    stream.mover_stream.synchronize()
    stash.clear()
    return
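# Illustrative sketch: how online LoRA patches reach get_weight_and_bias() above. Patches
# are attached to a layer under the `forge_online_loras` attribute (presumably by the
# patcher); the names `weight_patch_list` / `bias_patch_list` are hypothetical placeholders,
# and the real patch format is whatever backend.patcher.lora.merge_lora_to_weight() expects.
#
#     layer = torch.nn.Linear(16, 16)
#     layer.forge_online_loras = {
#         'weight': weight_patch_list,   # consumed by merge_lora_to_weight()
#         'bias': bias_patch_list,       # only merged if the layer actually has a bias
#     }
#     w, b = get_weight_and_bias(layer)  # returns tensors with the patches merged in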
current_device = None
current_dtype = None
current_manual_cast_enabled = False
current_bnb_dtype = None


class ForgeOperations:
    # Drop-in replacements for the torch.nn layers patched in by using_forge_operations().
    # Each class captures the module-level current_* settings at construction time and can
    # either run directly or manually cast its parameters to the input's device/dtype.
    class Linear(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if hasattr(self, 'dummy'):
                if prefix + 'weight' in state_dict:
                    self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
                if prefix + 'bias' in state_dict:
                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
                del self.dummy
            else:
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.linear(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return torch.nn.functional.linear(x, weight, bias)

    class Conv2d(torch.nn.Conv2d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv3d(torch.nn.Conv3d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv1d(torch.nn.Conv1d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class ConvTranspose2d(torch.nn.ConvTranspose2d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 2
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 2
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose1d(torch.nn.ConvTranspose1d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 1
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 1
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose3d(torch.nn.ConvTranspose3d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 3
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 3
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class GroupNorm(torch.nn.GroupNorm):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.group_norm(x, self.num_groups, weight, bias, self.eps)
            else:
                return super().forward(x)

    class LayerNorm(torch.nn.LayerNorm):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
            else:
                return super().forward(x)

    class Embedding(torch.nn.Embedding):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled
            self.bias = None

        def reset_parameters(self):
            self.bias = None
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
            else:
                return super().forward(x)
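# Illustrative sketch (hypothetical shapes, simplified call): ForgeOperations.Linear does not
# allocate parameters up front; `weight`/`bias` stay None until a state dict is loaded, at
# which point the checkpoint tensors are cast onto the device/dtype captured from
# current_device/current_dtype at construction time via the `dummy` parameter.
#
#     with using_forge_operations(device=torch.device('cpu'), dtype=torch.float16):
#         layer = torch.nn.Linear(8, 8)   # actually ForgeOperations.Linear; layer.weight is None
#     layer.load_state_dict({'weight': torch.randn(8, 8), 'bias': torch.zeros(8)}, strict=False)
#     # layer.weight is now a float16 Parameter taken from the state dict, not a random init.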
try:
    from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits, functional_dequantize_4bit

    class ForgeOperationsBNB4bits(ForgeOperations):
        class Linear(ForgeLoader4Bit):
            def __init__(self, *args, **kwargs):
                super().__init__(device=current_device, dtype=current_dtype, quant_type=current_bnb_dtype)
                self.parameters_manual_cast = current_manual_cast_enabled

            def forward(self, x):
                if self.bias is not None and self.bias.dtype != x.dtype:
                    # Maybe this can also be done for all non-bnb ops, since the cost is very low:
                    # it only runs once, and most Linear layers do not have a bias.
                    self.bias = utils.tensor2parameter(self.bias.to(x.dtype))

                if hasattr(self, 'forge_online_loras'):
                    weight, bias, signal = weights_manual_cast(self, x, weight_fn=functional_dequantize_4bit, bias_fn=None, skip_bias_dtype=True)
                    with main_stream_worker(weight, bias, signal):
                        return torch.nn.functional.linear(x, weight, bias)

                if not self.parameters_manual_cast:
                    return functional_linear_4bits(x, self.weight, self.bias)
                elif not self.weight.bnb_quantized:
                    assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
                    layer_original_device = self.weight.device
                    self.weight = self.weight._quantize(x.device)
                    bias = self.bias.to(x.device) if self.bias is not None else None
                    out = functional_linear_4bits(x, self.weight, bias)
                    self.weight = self.weight.to(layer_original_device)
                    return out
                else:
                    weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                    with main_stream_worker(weight, bias, signal):
                        return functional_linear_4bits(x, weight, bias)

    bnb_avaliable = True
except:
    bnb_avaliable = False

from backend.operations_gguf import dequantize_tensor


class ForgeOperationsGGUF(ForgeOperations):
    class Linear(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if hasattr(self, 'dummy'):
                if prefix + 'weight' in state_dict:
                    self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device)
                if prefix + 'bias' in state_dict:
                    self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device)
                del self.dummy
            else:
                if prefix + 'weight' in state_dict:
                    self.weight = state_dict[prefix + 'weight']
                if prefix + 'bias' in state_dict:
                    self.bias = state_dict[prefix + 'bias']

        def _apply(self, fn, recurse=True):
            if self.weight is not None:
                self.weight = utils.tensor2parameter(fn(self.weight))
            if self.bias is not None:
                self.bias = utils.tensor2parameter(fn(self.bias))
            return self

        def forward(self, x):
            if self.bias is not None and self.bias.dtype != x.dtype:
                self.bias = utils.tensor2parameter(dequantize_tensor(self.bias).to(x.dtype))

            if self.weight is not None and self.weight.dtype != x.dtype and getattr(self.weight, 'gguf_cls', None) is None:
                self.weight = utils.tensor2parameter(self.weight.to(x.dtype))

            weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=None, skip_bias_dtype=True)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.linear(x, weight, bias)
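# Sketch of the selection performed by using_forge_operations() below, assuming the optional
# backends above import cleanly (the dtype strings come from the checks in that function):
#
#     bnb_dtype == 'gguf'          -> ForgeOperationsGGUF      (weights dequantized on the fly)
#     bnb_dtype in ('nf4', 'fp4')  -> ForgeOperationsBNB4bits  (bitsandbytes 4-bit kernels)
#     anything else                -> ForgeOperations          (plain manual-cast layers)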
@contextlib.contextmanager
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
    global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype

    current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype

    if operations is None:
        if bnb_dtype in ['gguf']:
            operations = ForgeOperationsGGUF
        elif bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
            operations = ForgeOperationsBNB4bits
        else:
            operations = ForgeOperations

    op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}

    try:
        for op_name in op_names:
            setattr(torch.nn, op_name, getattr(operations, op_name))
        yield
    finally:
        for op_name in op_names:
            setattr(torch.nn, op_name, backups[op_name])

    return


def shift_manual_cast(model, enabled):
    for m in model.modules():
        if hasattr(m, 'parameters_manual_cast'):
            m.parameters_manual_cast = enabled
    return
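# Usage sketch (hypothetical model class and devices): build the model inside
# using_forge_operations() so that every torch.nn.Linear / Conv2d / ... created during
# __init__ is the Forge variant above, then toggle manual casting later if needed.
#
#     with using_forge_operations(device=torch.device('cuda'), dtype=torch.float16,
#                                 manual_cast_enabled=True, bnb_dtype=None):
#         model = MyDiffusionModel()           # layers are ForgeOperations.* instances
#
#     shift_manual_cast(model, enabled=False)  # switch all Forge layers to the direct path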
@contextlib.contextmanager
def automatic_memory_management():
    memory_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,
        device=memory_management.get_torch_device()
    )

    module_list = []

    original_init = torch.nn.Module.__init__
    original_to = torch.nn.Module.to

    def patched_init(self, *args, **kwargs):
        module_list.append(self)
        return original_init(self, *args, **kwargs)

    def patched_to(self, *args, **kwargs):
        module_list.append(self)
        return original_to(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        torch.nn.Module.to = patched_to
        yield
    finally:
        torch.nn.Module.__init__ = original_init
        torch.nn.Module.to = original_to

    start = time.perf_counter()
    module_list = set(module_list)

    for module in module_list:
        module.cpu()

    memory_management.soft_empty_cache()
    end = time.perf_counter()

    print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')
    return


class DynamicSwapInstaller:
    @staticmethod
    def _install_module(module: torch.nn.Module, target_device: torch.device):
        original_class = module.__class__
        module.__dict__['forge_backup_original_class'] = original_class

        def hacked_get_attr(self, name: str):
            if '_parameters' in self.__dict__:
                _parameters = self.__dict__['_parameters']
                if name in _parameters:
                    p = _parameters[name]
                    if p is None:
                        return None
                    if p.__class__ == torch.nn.Parameter:
                        return torch.nn.Parameter(p.to(target_device), requires_grad=p.requires_grad)
                    else:
                        return p.to(target_device)
            if '_buffers' in self.__dict__:
                _buffers = self.__dict__['_buffers']
                if name in _buffers:
                    return _buffers[name].to(target_device)
            return super(original_class, self).__getattr__(name)

        module.__class__ = type('DynamicSwapInstance', (original_class,), {
            '__getattr__': hacked_get_attr,
        })
        return

    @staticmethod
    def _uninstall_module(module: torch.nn.Module):
        if 'forge_backup_original_class' in module.__dict__:
            module.__class__ = module.__dict__.pop('forge_backup_original_class')
        return

    @staticmethod
    def install_model(model: torch.nn.Module, target_device: torch.device):
        for m in model.modules():
            DynamicSwapInstaller._install_module(m, target_device)
        return

    @staticmethod
    def uninstall_model(model: torch.nn.Module):
        for m in model.modules():
            DynamicSwapInstaller._uninstall_module(m)
        return
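# Usage sketches (hypothetical loader/model names):
#
# automatic_memory_management() tracks every module constructed or .to()-moved inside the
# block and pushes all of them back to CPU afterwards, so temporary models do not pile up
# on the GPU:
#
#     with automatic_memory_management():
#         text_encoder = load_text_encoder(path)   # whatever loader creates the modules
#
# DynamicSwapInstaller keeps a model's parameters on their current (e.g. CPU) device but
# returns on-the-fly copies on `target_device` whenever a parameter or buffer attribute is
# accessed, so the forward pass can run on GPU without a full model.to('cuda'):
#
#     DynamicSwapInstaller.install_model(model, target_device=torch.device('cuda'))
#     y = model(x_on_cuda)
#     DynamicSwapInstaller.uninstall_model(model)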