Julien Blanchon committed on
Commit 28e3661 · 1 Parent(s): 548985d

Remove wandb

Files changed (1)
  1. tim/utils/misc_utils.py +19 -17
tim/utils/misc_utils.py CHANGED
@@ -1,7 +1,6 @@
 import functools
 import importlib
 import os
-import wandb
 import fsspec
 import numpy as np
 import torch
@@ -13,12 +12,13 @@ from PIL import Image, ImageDraw, ImageFont
 from safetensors.torch import load_file
 from tqdm import tqdm
 
+
 def create_npz_from_sample_folder(sample_dir, num=50_000):
     """
     Builds a single .npz file from a folder of .png samples.
     """
     samples = []
-    imgs = sorted(os.listdir(sample_dir), key=lambda x: int(x.split('.')[0]))
+    imgs = sorted(os.listdir(sample_dir), key=lambda x: int(x.split(".")[0]))
     print(len(imgs))
     assert len(imgs) >= num
     for i in tqdm(range(num), desc="Building .npz file from samples"):
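For orientation, a brief usage sketch of this helper (the sample directory below is hypothetical, not part of the commit): it assumes the folder holds files named 0.png, 1.png, ..., which is why the sort key parses the integer prefix of each filename.

from tim.utils.misc_utils import create_npz_from_sample_folder

# Assumes the sampler wrote images as <index>.png, e.g. 0.png ... 49999.png;
# the helper stacks the first `num` of them into one .npz (a layout commonly
# used for FID-style evaluation).
npz_path = create_npz_from_sample_folder("samples/run_0", num=50_000)
print(npz_path)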
@@ -32,14 +32,13 @@ def create_npz_from_sample_folder(sample_dir, num=50_000):
     print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
     return npz_path
 
-def init_from_ckpt(
-    model, checkpoint_dir, ignore_keys=None, verbose=False
-) -> None:
+
+def init_from_ckpt(model, checkpoint_dir, ignore_keys=None, verbose=False) -> None:
     if checkpoint_dir.endswith(".safetensors"):
-        model_state_dict=load_file(checkpoint_dir, device='cpu')
+        model_state_dict = load_file(checkpoint_dir, device="cpu")
     else:
-        model_state_dict=torch.load(checkpoint_dir, map_location="cpu")
-    model_new_ckpt=dict()
+        model_state_dict = torch.load(checkpoint_dir, map_location="cpu")
+    model_new_ckpt = dict()
     for i in model_state_dict.keys():
         model_new_ckpt[i] = model_state_dict[i]
     keys = list(model_new_ckpt.keys())
@@ -63,14 +62,14 @@ def init_from_ckpt(
 
 
 def get_dtype(str_dtype):
-    if str_dtype == 'fp16':
+    if str_dtype == "fp16":
         return torch.float16
-    elif str_dtype == 'bf16':
+    elif str_dtype == "bf16":
         return torch.bfloat16
     else:
         return torch.float32
-
-
+
+
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
     does not change anymore."""
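A brief usage sketch of the two helpers reformatted above, assuming a standard torch.nn.Module and a local checkpoint path (the module and file name are hypothetical; the tail of init_from_ckpt is outside this hunk, and loading the weights into the model in place is inferred from its name and None return type):

import torch

from tim.utils.misc_utils import get_dtype, init_from_ckpt

# Hypothetical model and checkpoint, purely for illustration.
model = torch.nn.Linear(16, 16)

# init_from_ckpt accepts either a .safetensors file or a torch.save checkpoint
# and loads the state dict on CPU before applying it to the model.
init_from_ckpt(model, "checkpoints/model.safetensors", verbose=True)

# get_dtype maps the config strings "fp16" / "bf16" to torch dtypes and
# falls back to float32 for anything else.
model = model.to(get_dtype("bf16"))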
@@ -221,12 +220,12 @@ def mean_flat(tensor):
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+        print(f"{model.__class__.__name__} has {total_params * 1.0e-6:.2f} M params.")
     return total_params
 
 
 def instantiate_from_config(config):
-    if not "target" in config:
+    if "target" not in config:
         if config == "__is_first_stage__":
             return None
         elif config == "__is_unconditional__":
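For context, instantiate_from_config follows the common target/params config convention; a minimal sketch of a call under that assumption (the target and params below are illustrative, and the instantiation path beyond the lines shown in this hunk is assumed to follow the usual "import the dotted target, call it with params" pattern):

from tim.utils.misc_utils import instantiate_from_config

# A config is expected to carry a dotted import path under "target" and
# constructor kwargs under "params"; the sentinel "__is_first_stage__"
# short-circuits to None in the lines shown above.
config = {
    "target": "torch.nn.Linear",  # illustrative target, not from this commit
    "params": {"in_features": 8, "out_features": 8},
}
layer = instantiate_from_config(config)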
@@ -295,7 +294,8 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
 def format_number(num):
     num = float(num)
     num /= 1000.0
-    return '{:.0f}{}'.format(num, 'k')
+    return "{:.0f}{}".format(num, "k")
+
 
 def get_num_params(model: torch.nn.ModuleList) -> int:
     num_params = sum(p.numel() for p in model.parameters())
@@ -319,13 +319,14 @@ def get_num_flop_per_token(num_params, model_config, seq_len) -> int:
 
     return flop_per_token
 
+
 def get_num_flop_per_sequence_encoder_only(num_params, model_config, seq_len) -> int:
     l, h, q = (
         model_config.n_layers,
         model_config.n_heads,
         model_config.dim // model_config.n_heads,
     )
-
+
     # 1. Each self-attention layer has 2 matrix multiplications in the forward pass and 4 in the backward pass (6)
     # 2. Each matrix multiplication performs 1 multiply and 1 add (*2)
     # 3. Bidirectional attention considers all token pairs, so the cost scales with t^2 rather than t
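The body that applies those three comments is outside this hunk; as a hedged sketch only, the per-sequence estimate they describe would combine the standard 6 * num_params * t weight term with a 12 * l * h * q * t^2 attention term, roughly:

def estimated_flop_per_sequence(num_params, n_layers, n_heads, head_dim, seq_len):
    # Weight matmuls across forward and backward passes: factor 6 per parameter per token.
    weight_flops = 6 * num_params * seq_len
    # Attention: 2 matmuls forward + 4 backward (6), each a multiply and an add (*2);
    # bidirectional attention pairs every token with every other token, hence seq_len ** 2.
    attention_flops = 12 * n_layers * n_heads * head_dim * seq_len**2
    return weight_flops + attention_flops

# Example with hypothetical model dimensions, purely for illustration.
print(estimated_flop_per_sequence(1_000_000_000, 24, 16, 64, 1024))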
@@ -351,6 +352,7 @@ def get_peak_flops(device_name: str) -> int:
     else:  # for other GPU types, assume A100
         return 312e12
 
+
 @dataclass(frozen=True)
 class Color:
     black = "\033[30m"
@@ -374,4 +376,4 @@ class NoColor:
     magenta = ""
     cyan = ""
     white = ""
-    reset = ""
+    reset = ""