import streamlit as st

# Define bit sizes for different quantization options
quantization_bit_sizes = {
    'float32': 32,
    'float16': 16,
    'Q2_K': 2,
    'Q3_K_L': 3,
    'Q3_K_M': 3,
    'Q3_K_S': 3,
    'Q4_0': 4,
    'Q4_1': 4,
    'Q4_K_M': 4,
    'Q4_K_S': 4,
    'Q5_0': 5,
    'Q5_1': 5,
    'Q5_K_M': 5,
    'Q5_K_S': 5,
    'Q6_K': 6,
    'Q8_0': 8
}
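
# Rough sanity check on the table above (assumed model size, for illustration):
# a 7B-parameter model stored as Q4_0 needs about 7e9 * (4 / 8) = 3.5e9 bytes
# (~3.26 GB) for the weights alone, before activations are added.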

# Bytes per activation element for each precision mode:
# full = fp32 (4 bytes), half = fp16 (2 bytes); 'mixed' budgets an fp32
# master value plus an fp16 working copy (6 bytes)
precision_options = {
    'full': 4,
    'mixed': 6,
    'half': 2
}

# Streamlit app
st.title("Memory Usage Calculator for Large Language Models")



# Activation estimate taken from "Reducing Activation Recomputation in Large
# Transformer Models" (https://arxiv.org/abs/2205.05198)
def calculate_memory_usage(parameter_count, context_length, data_type, batch_size,
                           vocab_size, precision, layers, attention_heads):
    # Convert the data type's bit width to bytes per parameter
    byte_size = quantization_bit_sizes[data_type] / 8

    # Memory usage for the model weights
    memory_params = parameter_count * byte_size

    # Memory usage for activations over the context
    activations = calculate_activations(parameter_count, context_length, batch_size,
                                        vocab_size, precision, layers, attention_heads)

    # Total memory usage in bytes
    total_memory_usage = memory_params + activations

    # Convert bytes to GB (using 1024 ** 3)
    total_memory_usage_gb = total_memory_usage / (1024 ** 3)

    return total_memory_usage_gb
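
# Quick offline sketch (assumed values, for illustration): calling the
# function directly as
#   calculate_memory_usage(7e9, 2048, 'Q4_K_M', 1, 32000, 'half', 32, 32)
# returns the estimated total in GB (weight bytes plus the activation term).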

def calculate_activations(parameter_count, context_length, batch_size, vocab_size,
                          precision, layers, attention_heads):
    # Estimate the hidden size from the parameter count via the standard
    # transformer approximation params ~= 12 * layers * hidden^2
    # (embedding parameters are ignored, so vocab_size is currently unused)
    hidden_dimensions = int((parameter_count / (12 * layers)) ** 0.5)

    # Per-layer activation footprint from the paper's formula,
    # s * b * h * (34 + 5 * a * s / h), stated in bytes at 16-bit precision
    activations_per_layer = context_length * batch_size * hidden_dimensions * (
        34 + ((5 * attention_heads * context_length) / hidden_dimensions))
    # Divide by 2 (bytes per 16-bit value) to recover the element count
    activations = layers * activations_per_layer / 2

    # Convert the element count to bytes at the selected precision
    # (precision_options holds bytes per element, not bits)
    bytes_per_param = precision_options[precision]
    total_activations = bytes_per_param * activations

    return total_activations
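
# Worked example of the formula above (assumed sizes, not from the paper):
# with s=512, b=1, h=2048, a=32 and 32 layers, each layer produces
# 512 * 1 * 2048 * (34 + 5*32*512/2048) = 1,048,576 * 74 ~= 77.6M values;
# across 32 layers and after the /2 element conversion that is ~1.24B
# elements, or ~2.31 GB at 'half' precision (2 bytes each).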

# User inputs
parameter_count = st.number_input("Parameter Count (in billions)", value=1, step=1) * 1e9
layers = st.number_input("Number of Layers", value=32, step=1)
attention_heads = st.number_input("Number of Attention Heads", value=32, step=1)
context_length = st.number_input("Context Length (number of tokens)", value=512, step=1)
data_type = st.selectbox("Data Type", options=list(quantization_bit_sizes.keys()))
batch_size = st.number_input("Batch Size", value=1, step=1)
vocab_size = st.number_input("Vocabulary Size", value=30000, step=1000)
precision = st.selectbox("Precision", options=list(precision_options.keys()))

# Calculate memory usage
if st.button("Calculate Memory Usage"):
    memory_usage = calculate_memory_usage(parameter_count, context_length, data_type,
                                          batch_size, vocab_size, precision,
                                          layers, attention_heads)
    st.write(f"Estimated Memory Usage for Inference: {memory_usage:.2f} GB")