def memory_for_attention_layer(precision: int,
seq_len: int,
batch_size: int,
hidden_size: int,
num_heads: int):
"""
head_dim = hidden_size // num_heads
Model Parameters:
q_proj: (hidden_size, num_heads * head_dim)
k_proj: (hidden_size, num_key_value_heads * head_dim)
v_proj: (hidden_size, num_key_value_heads * head_dim)
o_proj: (hidden_size, hidden_size)
Total parameters = 3 * hidden_size * num_heads * head_dim + hidden_size^2
Memory required for model parameters = (3 * hidden_size * num_heads * head_dim + hidden_size^2)
Gradients:
Gradients have the same size as the model parameters.
Memory required for gradients = (3 * hidden_size * num_heads * head_dim + hidden_size^2)
Optimizer States:
Assuming Adam optimizer with two states per parameter (momentum and variance).
Memory required for optimizer states = 2 * (3 * hidden_size * num_heads * head_dim + hidden_size^2)
Activations:
query_states: (batch_size, num_heads, q_len, head_dim)
key_states: (batch_size, num_key_value_heads, q_len, head_dim)
value_states: (batch_size, num_key_value_heads, q_len, head_dim)
attn_weights: (batch_size, num_heads, q_len, q_len)
attn_output: (batch_size, q_len, hidden_size)
Total activations = batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)
Memory required for activations = batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)
Temporary Memory:
Additional temporary memory for intermediate computations and buffer storage.
Assuming 20% of the total memory as temporary memory.
total_memory = (model_parameters + gradients + optimizer_states + activations) * (1 + temporary_memory_factor)
((3 * hidden_size * num_heads * head_dim + hidden_size^2) +
(3 * hidden_size * num_heads * head_dim + hidden_size^2) +
2 * (3 * hidden_size * num_heads * head_dim + hidden_size^2) +
batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)) * (1 + 0.2)
"""
head_dim = hidden_size // num_heads
# Model Memory (3 * hidden_size * num_heads * head_dim + hidden_size^2)
model_memory = 3 * hidden_size * num_heads * head_dim + hidden_size ** 2
# Gradients = model_memory
gradients = model_memory
# Optimizer
optimizer = 2 * model_memory
    # Activations: q/k/v states, attention weights, and attention output
    # (assumes num_key_value_heads == num_heads, i.e. no grouped-query attention)
    activation = batch_size * (3 * num_heads * seq_len * head_dim +
                               num_heads * seq_len ** 2 +
                               seq_len * hidden_size)
    # Apply the 20% temporary-memory factor from the docstring and convert
    # element counts to bytes with `precision` (bytes per element).
    temporary_memory_factor = 0.2
    total_memory = (model_memory + gradients + optimizer + activation) * (1 + temporary_memory_factor) * precision
return total_memory
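
# Illustrative usage (not part of the original utilities): the dimensions below
# are assumed, Llama-2-7B-like values (hidden_size=4096, num_heads=32), and
# precision=2 corresponds to fp16/bf16 bytes per element.
#
#     attn_bytes = memory_for_attention_layer(precision=2, seq_len=4096,
#                                             batch_size=1, hidden_size=4096,
#                                             num_heads=32)
#     print(f"attention layer: {attn_bytes / 1024 ** 3:.2f} GiB")
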
def memory_mlp_layer(precision: int,
seq_len: int,
batch_size: int,
hidden_size: int,
intermediate_size: int):
"""
MLP model
gate_proj (hidden_size, intermediate_size)
up_proj (hidden_size, intermediate_size)
down_proj (intermediate_size, hidden_size)
Memory required for gate_proj weights = intermediate_size * hidden_size
Memory required for up_proj weights = intermediate_size * hidden_size
Memory required for down_proj weights = intermediate_size * hidden_size
model memory = 3 * (hidden_size * intermediate_size)
gradient = model_memory
optimizer = 2 * model_memory
activations = batch_size * seq_len * hidden_size + 2 * batch_size * seq_len * intermediate_size
total_memory = 3 * (hidden_size * intermediate_size) + 3 * (hidden_size * intermediate_size) + 6 * (hidden_size * intermediate_size) + batch_size * (2 * intermediate_size + hidden_size)
total_memory = (hidden_size * intermediate_size) * 12 + Batch_size * seq_len * (2 * intermediate_size + hidden_size)
Args:
hidden_size:
intermediate_size:
batch_size:
seq_len:
Returns:
"""
model_memory = 3 * (hidden_size * intermediate_size)
gradient = model_memory
optimizer = 2 * model_memory
activation = batch_size * seq_len * (2 * intermediate_size + hidden_size)
    total_memory = (model_memory + gradient + optimizer + activation) * precision
return total_memory
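
# Illustrative usage (assumed values, not from the original utilities): an MLP
# layer with Llama-2-7B-like sizes (hidden_size=4096, intermediate_size=11008)
# in bf16 (precision=2 bytes per element).
#
#     mlp_bytes = memory_mlp_layer(precision=2, seq_len=4096, batch_size=1,
#                                  hidden_size=4096, intermediate_size=11008)
#     print(f"MLP layer: {mlp_bytes / 1024 ** 3:.2f} GiB")
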
def memory_moe_mlp(precision: int,
seq_len: int,
batch_size: int,
hidden_size: int,
intermediate_size: int,
num_expert: int,
top_k: int):
    """
    Estimate the training-time memory of one mixture-of-experts (MoE) MLP layer:
    a router (gate) of shape (hidden_size, num_expert) plus num_expert standard
    MLP experts, with top_k experts selected per token. The result is in bytes.
    """
    # Router (gate) weights: (hidden_size, num_expert), counted in elements.
    gate_memory = hidden_size * num_expert
    # Each expert is a standard MLP layer; memory_mlp_layer already accounts for
    # its parameters, gradients, optimizer states and activations (result in
    # bytes, worst case: every expert processes every token).
    experts_memory = memory_mlp_layer(precision, seq_len, batch_size, hidden_size,
                                      intermediate_size) * num_expert
    # Gate parameters, gradients and the two Adam states: 4 copies, in bytes.
    gate_total = 4 * gate_memory * precision
    # Routing activations (worst case), in bytes.
    max_memory_activation = (
        (batch_size * seq_len * num_expert * precision) +   # router logits
        (batch_size * seq_len * top_k * precision) +        # routing weights
        (batch_size * seq_len * top_k * precision) +        # selected experts
        (batch_size * seq_len * hidden_size * precision) +  # final hidden states
        (batch_size * seq_len * hidden_size * precision) +  # current state (worst case)
        (batch_size * seq_len * hidden_size * precision)    # current hidden states (worst case)
    )
    total_memory = gate_total + experts_memory + max_memory_activation
return total_memory
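

if __name__ == "__main__":
    # Minimal demo sketch (not part of the original utilities). The values are
    # assumed, Mixtral-8x7B-like settings: hidden_size=4096, 32 attention heads,
    # intermediate_size=14336, 8 experts with top_k=2, bf16 (2 bytes/element).
    precision = 2
    seq_len = 4096
    batch_size = 1
    hidden_size = 4096

    attn = memory_for_attention_layer(precision, seq_len, batch_size,
                                      hidden_size, num_heads=32)
    moe = memory_moe_mlp(precision, seq_len, batch_size, hidden_size,
                         intermediate_size=14336, num_expert=8, top_k=2)
    per_layer = attn + moe
    print(f"attention: {attn / 1024 ** 3:.2f} GiB")
    print(f"MoE MLP:   {moe / 1024 ** 3:.2f} GiB")
    print(f"per layer: {per_layer / 1024 ** 3:.2f} GiB")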