File size: 3,793 Bytes
d3caa11
 
8e6a186
d3caa11
 
 
 
 
91c2486
d3caa11
 
 
846299c
d3caa11
 
 
 
846299c
 
d3caa11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e6a186
 
 
 
 
 
 
 
 
d3caa11
 
 
 
 
 
 
 
 
846299c
 
 
 
d3caa11
846299c
d3caa11
8e6a186
d3caa11
 
846299c
d3caa11
 
 
 
 
 
 
 
 
 
 
8e6a186
 
 
 
d3caa11
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import argparse
import csv
import math

def generate_ratios(max_layers, max_width):
    """Return the width scaling ratios used to build the experiment grid.

    The ratios are the eight multiples of 1/8 from 0.125 up to and including
    1.0, in ascending order.

    Args:
        max_layers: Accepted for interface compatibility; currently unused.
        max_width: Accepted for interface compatibility; currently unused.

    Returns:
        A list of floats: ``[0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]``.
    """
    # Every generated value lies in (0, 1], so no upper-bound filter is needed.
    return [step / 8 for step in range(1, 9)]

def generate_experiments(max_layers, max_width, min_layers=1, min_width=1):
    """Build the grid of (layer_count, width) experiment configurations.

    Each ratio from ``generate_ratios`` is scaled by ``max_width``, truncated
    to an int and clamped from below by ``min_width``. Every resulting width is
    paired with every layer count from ``min_layers`` to ``max_layers``
    inclusive.

    Args:
        max_layers: Largest layer count to include.
        max_width: Width that the ratios are scaled against.
        min_layers: Smallest layer count to include (default 1).
        min_width: Lower bound applied to every computed width (default 1).

    Returns:
        A list of ``(layers, width)`` tuples, grouped by width in ratio order
        and, within each width, ordered by ascending layer count. The list is
        empty when ``min_layers`` exceeds ``max_layers``.
    """
    # The width depends only on the ratio, so compute it once per ratio rather
    # than once per (ratio, layers) pair. For small max_width several ratios
    # truncate or clamp to the same width; keep only the first occurrence so
    # the same experiment is not emitted more than once.
    widths = []
    for ratio in generate_ratios(max_layers, max_width):
        width = max(int(max_width * ratio), min_width)
        if width not in widths:
            widths.append(width)

    layer_counts = range(min_layers, max_layers + 1)
    return [(layers, width) for width in widths for layers in layer_counts]

def estimate_vram(layer_count, width, input_size, output_size):
    """Estimate the memory footprint, in bytes, of a fully connected network.

    The parameter count is made of:
      * the first layer: ``input_size * width`` weights plus ``width`` biases,
      * every further layer: ``width * width`` weights plus ``width`` biases,
      * the output layer: ``width * output_size`` weights plus ``output_size``
        biases.

    Parameters are costed at 4 bytes each (FP32). Activations are approximated
    as one buffer the size of the input, also at 4 bytes per value.

    Args:
        layer_count: Number of layers ahead of the output layer.
        width: Number of units in each of those layers.
        input_size: Number of input values.
        output_size: Number of output values.

    Returns:
        The estimated usage in bytes.
    """
    first_layer_params = (input_size * width) + width
    inner_layer_params = (width * width) + width

    # Only index 0 is fed by the input; every later layer is width-to-width.
    stacked_params = sum(
        first_layer_params if idx == 0 else inner_layer_params
        for idx in range(layer_count)
    )
    output_params = (width * output_size) + output_size
    total_params = stacked_params + output_params

    weights_bytes = total_params * 4
    activation_bytes = input_size * 4
    return weights_bytes + activation_bytes

def calculate_batch_size(memory_gb=20):
    """Return the largest power of two that fits in a quarter of the memory.

    The budget is converted from GiB to bytes, divided by 4, and rounded down
    to a power of two.

    Args:
        memory_gb: Memory budget in GiB (default 20). Must be large enough
            that a quarter of it is at least one byte.

    Returns:
        An int: the largest power of two not exceeding
        ``memory_gb * 1024**3 / 4``.

    Raises:
        ValueError: If a quarter of the budget is less than one byte, which
            includes zero and negative ``memory_gb``.
    """
    quarter_bytes = int(memory_gb * (1024 ** 3) // 4)
    if quarter_bytes < 1:
        raise ValueError(
            f'memory_gb must leave at least one byte after dividing by 4, got {memory_gb!r}'
        )

    # bit_length() gives the exact floor(log2(n)) + 1 for a positive int, which
    # avoids the rounding that float log2 can suffer just below a power of two.
    return 1 << (quarter_bytes.bit_length() - 1)

def write_csv(experiments, filename):
    """Write the experiment rows to a CSV file, overwriting any existing file.

    The first row is the fixed header ``layer_count,width,vram_usage``; each
    item of ``experiments`` is then written as one row, in order.

    Args:
        experiments: Iterable of row sequences to write.
        filename: Path of the CSV file to create.
    """
    header = ['layer_count', 'width', 'vram_usage']
    # newline='' lets the csv module control line endings itself.
    with open(filename, 'w', newline='') as out_file:
        csv_out = csv.writer(out_file)
        csv_out.writerow(header)
        csv_out.writerows(experiments)

def main():
    """Command-line entry point: build the experiment grid and save it as CSV.

    Parses the command-line options, generates every (layer count, width)
    experiment, estimates the VRAM usage of each one and prints it, writes all
    rows to the output CSV file, and finally prints the recommended batch size
    for the given memory budget.
    """
    parser = argparse.ArgumentParser(description='Generate a CSV file with a variety of layer counts and widths.')

    # (flag, type, default, help) in the order the options are registered.
    option_specs = (
        ('--max_layers', int, 72, 'Maximum number of layers (default: 72)'),
        ('--max_width', int, 4096, 'Maximum width (default: 4096)'),
        ('--min_layers', int, 1, 'Minimum number of layers (default: 1)'),
        ('--min_width', int, 1, 'Minimum width (default: 1)'),
        ('--output_file', str, 'experiments.csv', 'Output CSV file (default: experiments.csv)'),
        ('--input_size', int, 64 * 64 * 3, 'Input size (default: 64*64*3)'),
        ('--output_size', int, 10, 'Output size (default: 10)'),
        ('--memory_gb', int, 20, 'Total memory in GiB (default: 20)'),
    )
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    opts = parser.parse_args()

    grid = generate_experiments(opts.max_layers, opts.max_width, opts.min_layers, opts.min_width)

    rows = []
    for layer_count, width in grid:
        vram_usage = estimate_vram(layer_count, width, opts.input_size, opts.output_size)
        rows.append((layer_count, width, vram_usage))
        print(f'Layer Count: {layer_count}, Width: {width}, Estimated VRAM Usage: {vram_usage} bytes')

    write_csv(rows, opts.output_file)
    print(f'Generated {len(rows)} experiments and saved to {opts.output_file}')

    # Report the batch size derived from the memory budget.
    recommended = calculate_batch_size(opts.memory_gb)
    print(f'Recommended batch size: {recommended}')

# Run the generator only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()