# Copyright (c) OpenMMLab. All rights reserved. import os from collections import defaultdict from contextlib import contextmanager from itertools import chain from typing import Dict, List, Optional, Union import torch import torch.nn as nn from mmpretrain.utils import require @require('torch>=1.9.0', 'https://pytorch.org/get-started/locally/') @require('accelerate') def dispatch_model( model, device_map: Union[str, dict], max_memory: Optional[dict] = None, no_split_module_classes: Optional[List[str]] = None, offload_folder: str = None, offload_buffers: bool = False, preload_module_classes: Optional[List[str]] = None, ): """Split and dispatch a model across devices. The function depends on the `accelerate` package. Refers to https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling Args: model (torch.nn.Module): The model to dispatch. device_map (str | dict | None): A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the same device. You can use `device_map="auto"` to automatically generate the device map. Defaults to None. max_memory (dict | None): A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. Defaults to None. no_split_module_classes (List[str] | None): A list of layer class names that should never be split across device (for instance any layer that has a residual connection). If None, try to get the settings from the model class. Defaults to None. offload_folder (str | None): If the `device_map` contains any value `"disk"`, the folder where we will offload weights. offload_buffers (bool): In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as well as the parameters. Defaults to False. preload_module_classes (List[str] | None): A list of classes whose instances should load all their weights (even in the submodules) at the beginning of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. Defaults to None. """ from accelerate import dispatch_model, infer_auto_device_map # Check valid device_map string. valid_map_option = ['auto', 'balanced', 'balanced_low_0', 'sequential'] if isinstance(device_map, str) and device_map not in valid_map_option: raise ValueError('If passing a string for `device_map`, please choose ' f'from {valid_map_option}.') # Generate device map automatically if isinstance(device_map, str): if no_split_module_classes is None: no_split_module_classes = getattr(model, '_no_split_modules', None) if no_split_module_classes is None: raise ValueError(f'{model.__class__.__name__} does not support ' f"`device_map='{device_map}'` yet.") if device_map != 'sequential': from accelerate.utils import get_balanced_memory max_memory = get_balanced_memory( model, max_memory=max_memory, no_split_module_classes=no_split_module_classes, dtype=None, low_zero=(device_map == 'balanced_low_0'), ) max_memory[0] *= 0.9 device_map = infer_auto_device_map( model, max_memory=max_memory, no_split_module_classes=no_split_module_classes, dtype=None, ) if 'disk' in device_map.values(): if offload_folder is None: raise ValueError( 'The current `device_map` had weights offloaded to the disk. ' 'Please provide an `offload_folder` for them.') os.makedirs(offload_folder, exist_ok=True) main_device = next( (d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu') model = dispatch_model( model, device_map=device_map, main_device=main_device, offload_dir=offload_folder, offload_buffers=offload_buffers, preload_module_classes=preload_module_classes, ) if hasattr(model, 'data_preprocessor'): model.data_preprocessor._device = torch.device(main_device) return model @contextmanager def init_empty_weights(include_buffers: bool = False): """A context manager under which models are initialized with all parameters on the meta device. With this context manager, we can create an empty model. Useful when just initializing the model would blow the available RAM. Besides move the parameters to meta device, this method will also avoid load checkpoint from `mmengine.runner.load_checkpoint` and `transformers.PreTrainedModel.from_pretrained`. Modified from https://github.com/huggingface/accelerate Args: include_buffers (bool): Whether put all buffers on the meta device during initialization. """ device = torch.device('meta') # move parameter and buffer to meta device old_register_parameter = nn.Module.register_parameter if include_buffers: old_register_buffer = nn.Module.register_buffer # See https://github.com/huggingface/accelerate/pull/699 tensor_constructors_to_patch = { torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full'] } def register_parameter(module, name, param): old_register_parameter(module, name, param) if param is not None: param_cls = type(module._parameters[name]) kwargs = module._parameters[name].__dict__ module._parameters[name] = param_cls( module._parameters[name].to(device), **kwargs) def register_buffer(module, name, buffer, *args, **kwargs): old_register_buffer(module, name, buffer, *args, **kwargs) if buffer is not None: module._buffers[name] = module._buffers[name].to(device) def patch_tensor_constructor(fn): def wrapper(*args, **kwargs): kwargs['device'] = device return fn(*args, **kwargs) return wrapper # Patch load_checkpoint import mmengine.runner.checkpoint as mmengine_load old_load_checkpoint = mmengine_load.load_checkpoint def patch_load_checkpoint(*args, **kwargs): return {} # Patch transformers from pretrained try: from transformers import PreTrainedModel from transformers.models.auto.auto_factory import (AutoConfig, _BaseAutoModelClass) with_transformers = True except ImportError: with_transformers = False @classmethod def patch_auto_model(cls, pretrained_model_name_or_path, *model_args, **kwargs): cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return cls.from_config(cfg) @classmethod def patch_pretrained_model(cls, pretrained_model_name_or_path, *model_args, **kwargs): cfg = cls.config_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return cls(cfg) if with_transformers: old_pretrained_model = PreTrainedModel.from_pretrained old_auto_model = _BaseAutoModelClass.from_pretrained try: nn.Module.register_parameter = register_parameter mmengine_load.load_checkpoint = patch_load_checkpoint if with_transformers: PreTrainedModel.from_pretrained = patch_pretrained_model _BaseAutoModelClass.from_pretrained = patch_auto_model if include_buffers: nn.Module.register_buffer = register_buffer for func in tensor_constructors_to_patch.keys(): tensor_constructor = patch_tensor_constructor( getattr(torch, func)) setattr(torch, func, tensor_constructor) yield finally: nn.Module.register_parameter = old_register_parameter mmengine_load.load_checkpoint = old_load_checkpoint if with_transformers: PreTrainedModel.from_pretrained = old_pretrained_model _BaseAutoModelClass.from_pretrained = old_auto_model if include_buffers: nn.Module.register_buffer = old_register_buffer for func, ori in tensor_constructors_to_patch.items(): setattr(torch, func, ori) def compute_module_sizes( model: nn.Module, dtype: Union[str, torch.dtype, None] = None, special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None): """Compute the size of each submodule of a given model.""" def get_dtype(dtype): if isinstance(dtype, str): dtype = getattr(torch, dtype) if dtype is not None: assert issubclass(dtype, torch.dtype) return dtype def dtype_bytes(dtype: torch.dtype): if dtype is torch.bool: return 1 if dtype.is_floating_point: return torch.finfo(dtype).bits / 8 else: return torch.iinfo(dtype).bits / 8 if dtype is not None: dtype = get_dtype(dtype) dtype_size = dtype_bytes(dtype) if special_dtypes is not None: special_dtypes = { key: dtype_bytes(dtype) for key, dtype in special_dtypes.items() } module_sizes = defaultdict(int) for name, tensor in chain( model.named_parameters(recurse=True), model.named_buffers(recurse=True)): if special_dtypes is not None and name in special_dtypes: size = tensor.numel() * special_dtypes[name] elif dtype is None: size = tensor.numel() * tensor.element_size() else: size = tensor.numel() * min(dtype_size, tensor.element_size()) name_parts = name.split('.') for idx in range(len(name_parts) + 1): module_sizes['.'.join(name_parts[:idx])] += size return module_sizes