def memory_for_attention_layer(precision: int,
                               seq_len: int,
                               batch_size: int,
                               hidden_size: int,
                               num_heads: int):
""" | |
head_dim = hidden_size // num_heads | |
Model Parameters: | |
q_proj: (hidden_size, num_heads * head_dim) | |
k_proj: (hidden_size, num_key_value_heads * head_dim) | |
v_proj: (hidden_size, num_key_value_heads * head_dim) | |
o_proj: (hidden_size, hidden_size) | |
Total parameters = 3 * hidden_size * num_heads * head_dim + hidden_size^2 | |
Memory required for model parameters = (3 * hidden_size * num_heads * head_dim + hidden_size^2) | |
Gradients: | |
Gradients have the same size as the model parameters. | |
Memory required for gradients = (3 * hidden_size * num_heads * head_dim + hidden_size^2) | |
Optimizer States: | |
Assuming Adam optimizer with two states per parameter (momentum and variance). | |
Memory required for optimizer states = 2 * (3 * hidden_size * num_heads * head_dim + hidden_size^2) | |
Activations: | |
query_states: (batch_size, num_heads, q_len, head_dim) | |
key_states: (batch_size, num_key_value_heads, q_len, head_dim) | |
value_states: (batch_size, num_key_value_heads, q_len, head_dim) | |
attn_weights: (batch_size, num_heads, q_len, q_len) | |
attn_output: (batch_size, q_len, hidden_size) | |
Total activations = batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size) | |
Memory required for activations = batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size) | |
Temporary Memory: | |
Additional temporary memory for intermediate computations and buffer storage. | |
Assuming 20% of the total memory as temporary memory. | |
total_memory = (model_parameters + gradients + optimizer_states + activations) * (1 + temporary_memory_factor) | |
((3 * hidden_size * num_heads * head_dim + hidden_size^2) + | |
(3 * hidden_size * num_heads * head_dim + hidden_size^2) + | |
2 * (3 * hidden_size * num_heads * head_dim + hidden_size^2) + | |
batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)) * (1 + 0.2) | |
""" | |
    head_dim = hidden_size // num_heads
    # Model parameters: q_proj, k_proj, v_proj and o_proj weights
    # (3 * hidden_size * num_heads * head_dim + hidden_size^2 values).
    model_memory = 3 * hidden_size * num_heads * head_dim + hidden_size ** 2
    # Gradients: same number of values as the model parameters.
    gradients = model_memory
    # Optimizer states: two Adam states (momentum and variance) per parameter.
    optimizer = 2 * model_memory
    # Activations: query/key/value states, attention weights and attention output.
    activation = batch_size * (3 * num_heads * seq_len * head_dim +
                               num_heads * seq_len ** 2 +
                               seq_len * hidden_size)
    # Convert value counts to bytes.
    total_memory = (model_memory + gradients + optimizer + activation) * precision
    return total_memory
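
# Illustrative usage, a minimal sketch with assumed hyperparameters
# (hidden_size=4096, 32 heads, 2048-token sequences, batch size 1, fp16 so
# precision=2 bytes per value); these numbers are examples, not values taken
# from a particular model config.
if __name__ == "__main__":
    attention_bytes = memory_for_attention_layer(precision=2,
                                                 seq_len=2048,
                                                 batch_size=1,
                                                 hidden_size=4096,
                                                 num_heads=32)
    print(f"Attention layer training memory: {attention_bytes / 1024 ** 3:.2f} GiB")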
def memory_mlp_layer(precision: int,
                     seq_len: int,
                     batch_size: int,
                     hidden_size: int,
                     intermediate_size: int):
""" | |
MLP model | |
gate_proj (hidden_size, intermediate_size) | |
up_proj (hidden_size, intermediate_size) | |
down_proj (intermediate_size, hidden_size) | |
Memory required for gate_proj weights = intermediate_size * hidden_size | |
Memory required for up_proj weights = intermediate_size * hidden_size | |
Memory required for down_proj weights = intermediate_size * hidden_size | |
model memory = 3 * (hidden_size * intermediate_size) | |
gradient = model_memory | |
optimizer = 2 * model_memory | |
activations = batch_size * seq_len * hidden_size + 2 * batch_size * seq_len * intermediate_size | |
total_memory = 3 * (hidden_size * intermediate_size) + 3 * (hidden_size * intermediate_size) + 6 * (hidden_size * intermediate_size) + batch_size * (2 * intermediate_size + hidden_size) | |
total_memory = (hidden_size * intermediate_size) * 12 + Batch_size * seq_len * (2 * intermediate_size + hidden_size) | |
Args: | |
hidden_size: | |
intermediate_size: | |
batch_size: | |
seq_len: | |
Returns: | |
""" | |
    # Model parameters: gate_proj, up_proj and down_proj weights.
    model_memory = 3 * (hidden_size * intermediate_size)
    # Gradients: same number of values as the model parameters.
    gradient = model_memory
    # Optimizer states: two Adam states per parameter.
    optimizer = 2 * model_memory
    # Activations: layer input plus the gate_proj and up_proj outputs.
    activation = batch_size * seq_len * (2 * intermediate_size + hidden_size)
    # Convert value counts to bytes.
    total_memory = (model_memory + gradient + optimizer + activation) * precision
    return total_memory
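
# Illustrative usage, a minimal sketch with assumed hyperparameters
# (hidden_size=4096, intermediate_size=11008, 2048-token sequences, batch size 1,
# fp16 so precision=2 bytes per value); the numbers are examples only.
if __name__ == "__main__":
    mlp_bytes = memory_mlp_layer(precision=2,
                                 seq_len=2048,
                                 batch_size=1,
                                 hidden_size=4096,
                                 intermediate_size=11008)
    print(f"MLP layer training memory: {mlp_bytes / 1024 ** 3:.2f} GiB")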
def memory_moe_mlp(precision: int,
                   seq_len: int,
                   batch_size: int,
                   hidden_size: int,
                   intermediate_size: int,
                   num_expert: int,
                   top_k: int):
    # Router (gate) parameters: one logit per expert for every hidden feature.
    gate_memory = hidden_size * num_expert
    # Expert parameters: each expert is a gated MLP (gate_proj, up_proj, down_proj)
    # with 3 * hidden_size * intermediate_size values.
    expert_memory = 3 * hidden_size * intermediate_size * num_expert
    # Total model memory, in bytes.
    model_memory = (gate_memory + expert_memory) * precision
    # Gradients and optimizer states as before: gradients match the parameters
    # and Adam keeps two states per parameter.
    # Worst-case activation memory, in bytes.
    max_memory_activation = (
        (batch_size * seq_len * num_expert * precision) +    # Router logits
        (batch_size * seq_len * top_k * precision) +         # Routing weights
        (batch_size * seq_len * top_k * precision) +         # Selected experts
        (batch_size * seq_len * hidden_size * precision) +   # Final hidden states
        (batch_size * seq_len * hidden_size * precision) +   # Current state (worst case)
        (batch_size * seq_len * hidden_size * precision)     # Current hidden states (worst case)
    )
    total_memory = model_memory + model_memory + 2 * model_memory + max_memory_activation
    return total_memory
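
# Illustrative usage, a minimal sketch with assumed Mixtral-style hyperparameters
# (hidden_size=4096, intermediate_size=14336, 8 experts, top_k=2, 2048-token
# sequences, batch size 1, fp16 so precision=2); the numbers are examples only.
if __name__ == "__main__":
    moe_bytes = memory_moe_mlp(precision=2,
                               seq_len=2048,
                               batch_size=1,
                               hidden_size=4096,
                               intermediate_size=14336,
                               num_expert=8,
                               top_k=2)
    print(f"MoE MLP layer training memory: {moe_bytes / 1024 ** 3:.2f} GiB")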