from typing import Optional

import pytest
import torch
import torch.nn as nn

import act_mem
import layers

BATCH_SIZES = (1, 2)
D_MODELS = (128, 256)
SEQ_LENS = (64, 128)
N_HEADS = (2, 4)

DEVICES = ["cpu"]
if torch.cuda.is_available():
    DEVICES.append("cuda")

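# Activation functions that should not require any extra saved activation memory:
# their backward pass can be computed from their output, which is already saved as the
# input to the following linear layer (and the in-place LeakyReLU's output aliases its
# input), so no additional tensor needs to be stored.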
ZERO_MEM_ACT_FNS = [
    nn.ReLU(),
    nn.Sigmoid(),
    nn.Tanh(),
    nn.LeakyReLU(inplace=True),
]
ALL_ACT_FNS = ZERO_MEM_ACT_FNS + [
    nn.ELU(),
    nn.GELU(),
    nn.Hardshrink(),
    nn.Hardsigmoid(),
    nn.Hardswish(),
    nn.Hardtanh(),
    nn.LeakyReLU(),
    nn.SELU(),
    nn.SiLU(),
]


class TestSavedTensorContext:
    @pytest.mark.parametrize("device", DEVICES)
    @pytest.mark.parametrize("d_model", D_MODELS)
    @pytest.mark.parametrize("batch_size", BATCH_SIZES)
    def test_linear(self, device: str, d_model: int, batch_size: int) -> None:
        """
        Test a simple linear layer. The inputs should be saved for the backward pass.
        """
        inputs = torch.randn(batch_size, d_model, requires_grad=True, device=device)
        lin = nn.Linear(d_model, d_model, device=device)
        with act_mem.SavedTensorContext(ignored_tensors=lin.parameters()) as saved:
            _ = lin(inputs)
        assert saved.saved_tensor_mem == inputs.numel() * inputs.element_size()

    @pytest.mark.parametrize("device", DEVICES)
    @pytest.mark.parametrize("d_model", D_MODELS)
    @pytest.mark.parametrize("batch_size", BATCH_SIZES)
    def test_linear_amp(self, device: str, d_model: int, batch_size: int) -> None:
        """
        Test a linear layer with AMP. The saved tensors should now be the low-precision
        version of the inputs and the low-precision version of the weights.
        """
        inputs = torch.randn(batch_size, d_model, requires_grad=True, device=device)
        lin = nn.Linear(d_model, d_model, device=device)
        dtype = torch.bfloat16
        with torch.autocast(device_type=device, dtype=dtype):
            with act_mem.SavedTensorContext(ignored_tensors=lin.parameters()) as saved:
                out = lin(inputs)
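        # The saved low-precision copy of `inputs` has the same shape and dtype as
        # `out`, so its size is computed from `out`; the bf16 weight copy accounts for
        # the remaining bytes.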
        assert (
            saved.saved_tensor_mem
            == out.numel() * out.element_size() + lin.weight.numel() * dtype.itemsize
        )

    @pytest.mark.parametrize("act_fn", ALL_ACT_FNS)
    @pytest.mark.parametrize("dropout_prob", (None, 0.5))
    @pytest.mark.parametrize("device", DEVICES)
    @pytest.mark.parametrize("d_model", D_MODELS)
    @pytest.mark.parametrize("batch_size", BATCH_SIZES)
    @pytest.mark.parametrize("seq_len", SEQ_LENS)
    def test_mlp(
        self,
        act_fn: nn.Module,
        dropout_prob: Optional[float],
        device: str,
        d_model: int,
        batch_size: int,
        seq_len: int,
    ) -> None:
        """
        For the transformer MLP layer, even with a ReLU non-linearity, the initial
        inputs and the inputs to the final linear layer (which are four times as large)
        must always be saved. If the derivative of the activation function cannot be
        expressed in terms of the activation function's *outputs*, then the activation
        inputs must also be saved, and these are again four times as large as the MLP's
        inputs. The MLP activation memory can therefore be nearly halved by the choice
        of activation function.
        """
        inputs = torch.randn(
            batch_size, seq_len, d_model, requires_grad=True, device=device
        )
        expansion_factor = 4
        mlp = layers.MLP(
            d_model=d_model, act_fn=act_fn, dropout_prob=dropout_prob, device=device
        )
        with act_mem.SavedTensorContext(ignored_tensors=mlp.parameters()) as saved:
            _ = mlp(inputs)

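        # Expected saved tensors: the first linear layer's input (the MLP input), the
        # second linear layer's input (expansion_factor times larger), the activation
        # input only for activations outside ZERO_MEM_ACT_FNS, and, when dropout is
        # active, its mask. The mask cost below reflects the test's assumption of a
        # 4-byte-per-element mask on CPU and a 1-byte bool mask on CUDA.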
        first_lin_input_mem = act_mem.get_tensor_bytes(inputs)
        second_lin_input_mem = expansion_factor * first_lin_input_mem

        activation_input_mem = 0 if act_fn in ZERO_MEM_ACT_FNS else second_lin_input_mem
        dropout_act_mem = (
            0 if not dropout_prob else inputs.numel() * (4 if device == "cpu" else 1)
        )

        expected_mem = (
            first_lin_input_mem
            + second_lin_input_mem
            + activation_input_mem
            + dropout_act_mem
        )
        assert saved.saved_tensor_mem == expected_mem

    @pytest.mark.parametrize("act_fn", ALL_ACT_FNS)
    @pytest.mark.parametrize("dropout_prob", (None, 0.5))
    @pytest.mark.parametrize("device", DEVICES)
    @pytest.mark.parametrize("d_model", D_MODELS)
    @pytest.mark.parametrize("batch_size", BATCH_SIZES)
    @pytest.mark.parametrize("seq_len", SEQ_LENS)
    def test_mlp_amp(
        self,
        act_fn: nn.Module,
        dropout_prob: Optional[float],
        device: str,
        d_model: int,
        batch_size: int,
        seq_len: int,
    ) -> None:
        """
        Similar story with AMP. The only changes come from the modified dtypes and
        needing to also save references to the low-precision weights in the Linear
        layers.
        """
        inputs = torch.randn(
            batch_size, seq_len, d_model, requires_grad=True, device=device
        )
        expansion_factor = 4
        mlp = layers.MLP(
            d_model=d_model, act_fn=act_fn, dropout_prob=dropout_prob, device=device
        )
        dtype = torch.bfloat16
        with torch.autocast(device_type=device, dtype=dtype):
            with act_mem.SavedTensorContext(ignored_tensors=mlp.parameters()) as saved:
                _ = mlp(inputs)

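        # Each of the two linear layers has expansion_factor * d_model ** 2 weight
        # elements, and autocast saves a bf16 copy of each, hence amp_weight_mem. The
        # activation terms mirror the non-AMP test at bf16 width; the dropout mask cost
        # again reflects the test's assumed per-device mask dtypes.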
        amp_weight_mem = 2 * expansion_factor * d_model**2 * dtype.itemsize
        first_lin_input_mem = inputs.numel() * dtype.itemsize
        second_lin_input_mem = expansion_factor * inputs.numel() * dtype.itemsize

        activation_input_mem = 0 if act_fn in ZERO_MEM_ACT_FNS else second_lin_input_mem
        dropout_act_mem = (
            0
            if not dropout_prob
            else inputs.numel() * (dtype.itemsize if device == "cpu" else 1)
        )

        expected_mem = (
            amp_weight_mem
            + first_lin_input_mem
            + second_lin_input_mem
            + activation_input_mem
            + dropout_act_mem
        )
        assert (
            saved.saved_tensor_mem == expected_mem
        ), f"Failed on {act_fn=}, {dropout_prob=}"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not available")
class TestCUDAMemReadings:
    @pytest.mark.parametrize("d_model", D_MODELS)
    @pytest.mark.parametrize("batch_size", BATCH_SIZES)
    @pytest.mark.parametrize("seq_len", SEQ_LENS)
    @pytest.mark.parametrize("act_fn", ALL_ACT_FNS)
    def test_mlp(
        self, d_model: int, batch_size: int, seq_len: int, act_fn: nn.Module
    ) -> None:
        """
        Track saved tensors and allocated memory and verify they agree.
        """
        inputs = torch.randn(batch_size, seq_len, d_model, device="cuda")
        mlp = layers.MLP(d_model=d_model, act_fn=act_fn, device="cuda")

        with act_mem.AllocatedMemContext() as mem, act_mem.SavedTensorContext(
            ignored_tensors=mlp.parameters()
        ) as saved:
            outputs = mlp(inputs)

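        # Keeping `outputs` referenced holds the autograd graph (and its saved tensors)
        # alive. The two totals should agree because the MLP input is saved but was
        # allocated before the context, while `outputs` is allocated inside the context
        # but not saved; the two tensors have the same size, so the discrepancies
        # cancel.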
        assert mem.delta["current"] == saved.saved_tensor_mem