|
from typing import Any, Iterable, Optional, Union

import torch
import torch.utils.weak
|
|
|
|
|
def B_to_GiB(num_bytes: Union[int, float]) -> float:
    """Converts a number of bytes to GiB."""
    return num_bytes / 2**30
|
|
|
|
|
def get_tensor_bytes(tensor: torch.Tensor) -> int: |
|
""" |
|
Returns the bytes of storage a given tensor takes up. If `tensor` is a view of a larger tensor, |
|
this function only returns the bytes associated with the view. |
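
    For example (illustrative; values assume the default float32 dtype):

    ```python
    base = torch.randn(4, 4)  # 64 bytes of storage
    view = base[:2]           # a view over half of that storage
    assert get_tensor_bytes(view) == 32
    assert view.untyped_storage().nbytes() == 64
    ```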
|
""" |
|
tensor_bytes = tensor.numel() * tensor.element_size() |
|
return tensor_bytes |
|
|
|
|
|
class AllocatedMemContext: |
|
""" |
|
    Context manager which captures the allocated GPU memory at context enter and exit, along with
    the change between the two. Results are stored as dicts in the `before`, `after`, and `delta`
    attributes.

    Only includes the `allocated_bytes.all.`-prefixed keys of `torch.cuda.memory_stats()`, with the
    prefix stripped from the key names. All readings are in bytes; use `B_to_GiB` to convert to GiB.
|
|
|
Example: |
|
|
|
```python |
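    # Minimal sketch: measure the memory allocated by creating a CUDA tensor
    # (assumes a CUDA-capable device is available).
    with AllocatedMemContext() as mem:
        t = torch.randn(2**10, device="cuda")
    print(mem.delta["current"])  # bytes newly allocated inside the context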
|
|
|
``` |
|
""" |
|
|
|
def __init__(self) -> None: |
|
|
|
        # Touch cuBLAS so that CUDA and the cuBLAS handle/workspace are initialized up front and
        # their one-time setup allocations are not attributed to code run inside the context.
        torch.cuda.current_blas_handle()
|
|
|
self.before: dict[str, int] = {} |
|
self.after: dict[str, int] = {} |
|
self.delta: dict[str, int] = {} |
|
|
|
self._mem_key_prefix = "allocated_bytes.all." |
|
|
|
def _get_mem_dict(self) -> dict[str, int]: |
|
return { |
|
k.replace(self._mem_key_prefix, ""): v |
|
for k, v in torch.cuda.memory_stats().items() |
|
if self._mem_key_prefix in k |
|
} |
|
|
|
def __enter__(self) -> "AllocatedMemContext": |
|
self.before = self._get_mem_dict() |
|
return self |
|
|
|
def __exit__(self, *args: Any, **kwargs: Any) -> None: |
|
self.after = self._get_mem_dict() |
|
self.delta = {k: v - self.before[k] for k, v in self.after.items()} |
|
|
|
|
|
class SavedTensorContext: |
|
""" |
|
    Context manager which captures all tensors registered as saved for the backward pass within the
    context window. Does not work with `meta`-device tensors.
|
|
|
All saved tensors are stored in the `saved_tensor_dict` attr, which is an instance of torch's |
|
WeakTensorKeyDictionary with tensor/data_ptr key/value pairs. Some of these tensors may be |
|
views of the same underlying storage. The total memory of all saved tensors in bytes, accounting |
|
for redundant views, can be accessed through `saved_tensor_mem`. |
|
|
|
    Use:
    ```
    model = ...
    with SavedTensorContext(ignored_tensors=model.parameters()) as saved:
        # Do some computation with `model`, capturing saved tensors which are not model weights.
        ...
    saved.saved_tensor_dict  # WeakTensorKeyDictionary of all saved tensors
    saved.saved_tensor_mem   # bytes from all saved tensors (activation memory)
    ```
|
""" |
|
|
|
def __init__( |
|
self, |
|
ignored_tensors: Optional[Iterable[torch.Tensor]] = None, |
|
) -> None: |
|
|
|
|
|
self._ignored_data_ptrs = ( |
|
set() |
|
if ignored_tensors is None |
|
else {t.untyped_storage().data_ptr() for t in ignored_tensors} |
|
) |
|
|
|
|
|
|
|
        # Maps each saved tensor to its storage's data_ptr. The weak references ensure this
        # bookkeeping does not keep saved tensors alive longer than autograd would.
        self.saved_tensor_dict = torch.utils.weak.WeakTensorKeyDictionary()
|
|
|
        # pack_hook runs whenever autograd saves a tensor for backward: record its data_ptr unless
        # it belongs to an ignored tensor (e.g. model weights). The tensor is returned unchanged,
        # so autograd behavior is unaffected.
        def pack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
|
data_ptr = saved_tensor.untyped_storage().data_ptr() |
|
if data_ptr not in self._ignored_data_ptrs: |
|
self.saved_tensor_dict[saved_tensor] = data_ptr |
|
return saved_tensor |
|
|
|
def unpack_hook(saved_tensor: torch.Tensor) -> torch.Tensor: |
|
return saved_tensor |
|
|
|
self._saved_tensors_hook = torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook) |
|
|
|
def __enter__(self) -> "SavedTensorContext": |
|
self._saved_tensors_hook.__enter__() |
|
return self |
|
|
|
def __exit__(self, *args: Any, **kwargs: Any) -> None: |
|
self._saved_tensors_hook.__exit__(*args, **kwargs) |
|
|
|
@property |
|
def saved_tensor_mem(self) -> int: |
|
""" |
|
The memory in bytes of all saved tensors, accounting for views into the same storage. |
|
""" |
|
accounted_for = self._ignored_data_ptrs.copy() |
|
total_bytes = 0 |
|
        for t in self.saved_tensor_dict:
            data_ptr = t.untyped_storage().data_ptr()
            if data_ptr not in accounted_for:
                total_bytes += t.untyped_storage().nbytes()
                accounted_for.add(data_ptr)
|
return total_bytes |
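

if __name__ == "__main__":
    # Illustrative sketch (not part of the utilities above): combine both context managers to
    # compare autograd's saved activation memory against the caching allocator's view for a small
    # toy model. Assumes a CUDA device is available; the layer and batch sizes are arbitrary.
    device = "cuda"
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024, device=device),
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 1024, device=device),
    )
    inputs = torch.randn(64, 1024, device=device)
    with AllocatedMemContext() as mem:
        with SavedTensorContext(ignored_tensors=model.parameters()) as saved:
            out = model(inputs)
    print(f"Allocator delta: {B_to_GiB(mem.delta['current']):.6f} GiB")
    print(f"Saved activation memory: {B_to_GiB(saved.saved_tensor_mem):.6f} GiB")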
|
|