TeacherPuffy commited on
Commit
8e6a186
·
verified ·
1 Parent(s): 30e65d2

Update create_experiments.py

Browse files
Files changed (1) hide show
  1. create_experiments.py +15 -0
create_experiments.py CHANGED
@@ -1,5 +1,6 @@
1
  import argparse
2
  import csv
 
3
 
4
  def generate_ratios(max_layers, max_width):
5
  ratios = []
@@ -39,6 +40,15 @@ def estimate_vram(layer_count, width, input_size, output_size):
39
 
40
  return vram_usage
41
 
 
 
 
 
 
 
 
 
 
42
  def write_csv(experiments, filename):
43
  with open(filename, 'w', newline='') as csvfile:
44
  writer = csv.writer(csvfile)
@@ -55,6 +65,7 @@ def main():
55
  parser.add_argument('--output_file', type=str, default='experiments.csv', help='Output CSV file (default: experiments.csv)')
56
  parser.add_argument('--input_size', type=int, default=64*64*3, help='Input size (default: 64*64*3)')
57
  parser.add_argument('--output_size', type=int, default=10, help='Output size (default: 10)')
 
58
  args = parser.parse_args()
59
 
60
  experiments = generate_experiments(args.max_layers, args.max_width, args.min_layers, args.min_width)
@@ -69,5 +80,9 @@ def main():
69
  write_csv(experiments_with_vram, args.output_file)
70
  print(f'Generated {len(experiments_with_vram)} experiments and saved to {args.output_file}')
71
 
 
 
 
 
72
  if __name__ == '__main__':
73
  main()
 
1
  import argparse
2
  import csv
3
+ import math
4
 
5
  def generate_ratios(max_layers, max_width):
6
  ratios = []
 
40
 
41
  return vram_usage
42
 
43
+ def calculate_batch_size(memory_gb=20):
44
+ memory_bytes = memory_gb * (1024 ** 3) # Convert GiB to bytes
45
+ batch_memory_bytes = memory_bytes / 4 # Divide by 4
46
+
47
+ # Find the nearest power of 2
48
+ batch_size = 2 ** int(math.log2(batch_memory_bytes))
49
+
50
+ return batch_size
51
+
52
  def write_csv(experiments, filename):
53
  with open(filename, 'w', newline='') as csvfile:
54
  writer = csv.writer(csvfile)
 
65
  parser.add_argument('--output_file', type=str, default='experiments.csv', help='Output CSV file (default: experiments.csv)')
66
  parser.add_argument('--input_size', type=int, default=64*64*3, help='Input size (default: 64*64*3)')
67
  parser.add_argument('--output_size', type=int, default=10, help='Output size (default: 10)')
68
+ parser.add_argument('--memory_gb', type=int, default=20, help='Total memory in GiB (default: 20)')
69
  args = parser.parse_args()
70
 
71
  experiments = generate_experiments(args.max_layers, args.max_width, args.min_layers, args.min_width)
 
80
  write_csv(experiments_with_vram, args.output_file)
81
  print(f'Generated {len(experiments_with_vram)} experiments and saved to {args.output_file}')
82
 
83
+ # Calculate and print the batch size
84
+ batch_size = calculate_batch_size(args.memory_gb)
85
+ print(f'Recommended batch size: {batch_size}')
86
+
87
  if __name__ == '__main__':
88
  main()