"""
Prints the ratio of activation memory for a transformer Block when using ReLU vs GELU.
"""

import torch
import torch.nn as nn

import act_mem
import layers

if __name__ == "__main__":
    batch_size, seq_len, d_model, n_heads = 2, 4096, 1024, 2
    dtype = torch.bfloat16
    inputs = torch.randn(
        batch_size,
        seq_len,
        d_model,
        device="cuda",
        requires_grad=True,
        dtype=dtype,
    )
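    # For scale: the input alone holds batch_size * seq_len * d_model = 2 * 4096 * 1024
    # bf16 elements, i.e. 16 MiB.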

    act_fn_dict = {"ReLU": nn.ReLU(), "GELU": nn.GELU()}
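    # Expectation (not asserted by this script): ReLU's backward can reuse its output,
    # which autograd already saves as the input to the following Linear layer, while
    # GELU saves its own input as an extra tensor, so the GELU block should save more.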

    outputs = []
    mem_bytes = []

    for name, act_fn in act_fn_dict.items():
        block = layers.Block(
            d_model=d_model,
            act_fn=act_fn,
            n_heads=n_heads,
            device="cuda",
            dtype=dtype,
        )
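        # SavedTensorContext records the tensors autograd saves for backward during the
        # forward pass, ignoring the block's parameters so only activations are counted;
        # AllocatedMemContext (captured as `mem`, unused below) tracks allocated memory.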
        with act_mem.AllocatedMemContext() as mem, act_mem.SavedTensorContext(
            ignored_tensors=block.parameters()
        ) as saved:
            out = block(inputs)
        outputs.append(out)
        print(f"{name} block bytes: {saved.saved_tensor_mem}")
        mem_bytes.append(saved.saved_tensor_mem)

    print(f"ReLU/GELU block act mem ratio: {mem_bytes[0] / mem_bytes[1]}")