ImageConductor / peft /tuners /vera /buffer_dict.py
Yw22's picture
init demo
d711508
raw
history blame
5.53 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Adapted from https://botorch.org/api/_modules/botorch/utils/torch.html
# TODO: To be removed once (if) https://github.com/pytorch/pytorch/pull/37385 lands
from __future__ import annotations

import collections
import collections.abc
from collections import OrderedDict

import torch
from torch.nn import Module
class BufferDict(Module):
    r"""
    Holds buffers in a dictionary.

    BufferDict can be indexed like a regular Python dictionary, but buffers it contains are properly registered, and
    will be visible by all Module methods. `BufferDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in `BufferDict.update`, the order of the merged `OrderedDict` or another `BufferDict` (the argument to
      `BufferDict.update`).

    Note that `BufferDict.update` with other unordered mapping types (e.g., Python's plain `dict`) does not preserve
    the order of the merged mapping: such mappings are inserted in sorted key order.

    Args:
        buffers (iterable, optional):
            a mapping (dictionary) of (string : `torch.Tensor`) or an iterable of key-value pairs of type (string,
            `torch.Tensor`)
        persistent (`bool`, defaults to `False`):
            passed as `persistent=` to `register_buffer` for every buffer added to this dict, i.e. whether the
            buffers are part of the module's `state_dict`.

    ```python
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # Don't name the attribute `buffers`: that would shadow `nn.Module.buffers()`.
            self.my_buffers = BufferDict({"left": torch.randn(5, 10), "right": torch.randn(5, 10)})

        def forward(self, x, choice):
            x = self.my_buffers[choice].mm(x)
            return x
    ```
    """

    def __init__(self, buffers=None, persistent: bool = False):
        r"""
        Args:
            buffers (`dict`, *optional*):
                A mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
                (string, `torch.Tensor`).
            persistent (`bool`, defaults to `False`):
                Whether buffers added to this dict are registered as persistent.
        """
        super().__init__()
        # Must be assigned before `update`: every insertion goes through `__setitem__`, which reads
        # `self.persistent`. Assigning it afterwards made `BufferDict({...})` raise AttributeError.
        self.persistent = persistent
        if buffers is not None:
            self.update(buffers)

    def __getitem__(self, key):
        return self._buffers[key]

    def __setitem__(self, key, buffer):
        # Registering (instead of writing to `_buffers` directly) keeps the buffer visible to `.to()`,
        # `state_dict()`, etc., with the persistence chosen at construction time.
        self.register_buffer(key, buffer, persistent=self.persistent)

    def __delitem__(self, key):
        del self._buffers[key]
        # Mirror `nn.Module.__delattr__`: don't leave a stale non-persistent marker behind for this name.
        self._non_persistent_buffers_set.discard(key)

    def __len__(self):
        return len(self._buffers)

    def __iter__(self):
        return iter(self._buffers.keys())

    def __contains__(self, key):
        return key in self._buffers

    def clear(self):
        """Remove all items from the BufferDict."""
        self._buffers.clear()
        # No buffers remain, so no name can be marked non-persistent any more.
        self._non_persistent_buffers_set.clear()

    def pop(self, key):
        r"""Remove key from the BufferDict and return its buffer.

        Args:
            key (`str`):
                Key to pop from the BufferDict

        Raises:
            KeyError: if `key` is not present.
        """
        v = self[key]
        del self[key]
        return v

    def keys(self):
        r"""Return an iterable of the BufferDict keys."""
        return self._buffers.keys()

    def items(self):
        r"""Return an iterable of the BufferDict key/value pairs."""
        return self._buffers.items()

    def values(self):
        r"""Return an iterable of the BufferDict values."""
        return self._buffers.values()

    def update(self, buffers):
        r"""
        Update the `BufferDict` with the key-value pairs from a mapping or an iterable, overwriting existing keys.

        Note:
            If `buffers` is an `OrderedDict`, a `BufferDict`, or an iterable of key-value pairs, the order of new
            elements in it is preserved. Any other mapping (e.g. a plain `dict`) is inserted in sorted key order.

        Args:
            buffers (iterable):
                a mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
                (string, `torch.Tensor`).

        Raises:
            TypeError: if `buffers` is not iterable, or one of its elements is not iterable.
            ValueError: if an element of a key-value-pair iterable does not have exactly two entries.
        """
        if not isinstance(buffers, collections.abc.Iterable):
            raise TypeError(
                "BufferDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(buffers).__name__
            )

        # Ordered sources keep their order. BufferDict must be checked explicitly (and before the Mapping
        # check): it is an nn.Module, not a collections.abc.Mapping, so iterating it would yield bare keys.
        if isinstance(buffers, (OrderedDict, BufferDict)):
            for key, buffer in buffers.items():
                self[key] = buffer
        elif isinstance(buffers, collections.abc.Mapping):
            for key, buffer in sorted(buffers.items()):
                self[key] = buffer
        else:
            for j, p in enumerate(buffers):
                if not isinstance(p, collections.abc.Iterable):
                    raise TypeError(
                        "BufferDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(p).__name__
                    )
                if not len(p) == 2:
                    raise ValueError(
                        "BufferDict update sequence element "
                        "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                    )
                self[p[0]] = p[1]

    def extra_repr(self):
        child_lines = []
        for k, p in self._buffers.items():
            if p is None:
                # `register_buffer` accepts None; it has no size/device to report.
                child_lines.append(" (" + k + "): None")
                continue
            size_str = "x".join(str(size) for size in p.size())
            device_str = "" if not p.is_cuda else f" (GPU {p.get_device()})"
            parastr = f"Buffer containing: [{torch.typename(p)} of size {size_str}{device_str}]"
            child_lines.append(" (" + k + "): " + parastr)
        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, input):
        raise RuntimeError("BufferDict should not be called.")