# Measure the memory allocated and the activation tensors saved for backward
# by a single attention layer.
import torch

import act_mem
import layers

if __name__ == "__main__":
    batch_size, seq_len, d_model, n_heads = 1, 128, 1024, 32
    print(f"Batch size: {batch_size}, sequence length: {seq_len}, d_model: {d_model}, n_heads: {n_heads}")
    dtype = torch.bfloat16
    # Inputs require grad so that the layer saves activations for the backward pass.
    inputs = torch.randn(
        batch_size,
        seq_len,
        d_model,
        device="cuda",
        requires_grad=True,
        dtype=dtype,
    )

    attn = layers.Attention(
        d_model=d_model,
        n_heads=n_heads,
        device="cuda",
        dtype=dtype,
    )
    # Track the change in allocated memory and the tensors saved for backward
    # during the forward pass, ignoring the layer's own parameters.
    with act_mem.AllocatedMemContext() as mem, act_mem.SavedTensorContext(
        ignored_tensors=attn.parameters()
    ) as saved:
        out = attn(inputs)
    stm = saved.saved_tensor_mem
    print(f'{mem.delta["current"]=}')
    print(f"{stm=}")
    # Saved activation memory per output element.
    print(f"{stm/out.numel()=}")
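
# NOTE: running this script assumes a CUDA-capable GPU (the inputs and the layer
# are created with device="cuda") and that this repository's act_mem and layers
# modules are importable.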