Julien Blanchon committed
Commit · 28e3661
1 Parent(s): 548985d

Remove wandb

tim/utils/misc_utils.py CHANGED (+19 -17)
@@ -1,7 +1,6 @@
 import functools
 import importlib
 import os
-import wandb
 import fsspec
 import numpy as np
 import torch
@@ -13,12 +12,13 @@ from PIL import Image, ImageDraw, ImageFont
 from safetensors.torch import load_file
 from tqdm import tqdm
 
+
 def create_npz_from_sample_folder(sample_dir, num=50_000):
     """
     Builds a single .npz file from a folder of .png samples.
     """
     samples = []
-    imgs = sorted(os.listdir(sample_dir), key=lambda x: int(x.split('.')[0]))
+    imgs = sorted(os.listdir(sample_dir), key=lambda x: int(x.split(".")[0]))
     print(len(imgs))
     assert len(imgs) >= num
     for i in tqdm(range(num), desc="Building .npz file from samples"):
@@ -32,14 +32,13 @@ def create_npz_from_sample_folder(sample_dir, num=50_000):
     print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
     return npz_path
 
-def init_from_ckpt(
-    model, checkpoint_dir, ignore_keys=None, verbose=False
-) -> None:
+
+def init_from_ckpt(model, checkpoint_dir, ignore_keys=None, verbose=False) -> None:
     if checkpoint_dir.endswith(".safetensors"):
-        model_state_dict=load_file(checkpoint_dir, device="cpu")
+        model_state_dict = load_file(checkpoint_dir, device="cpu")
     else:
-        model_state_dict=torch.load(checkpoint_dir, map_location="cpu")
-    model_new_ckpt=dict()
+        model_state_dict = torch.load(checkpoint_dir, map_location="cpu")
+    model_new_ckpt = dict()
     for i in model_state_dict.keys():
         model_new_ckpt[i] = model_state_dict[i]
     keys = list(model_new_ckpt.keys())
@@ -63,14 +62,14 @@ def init_from_ckpt(
 
 
 def get_dtype(str_dtype):
-    if str_dtype == 'fp16':
+    if str_dtype == "fp16":
         return torch.float16
-    elif str_dtype == 'bf16':
+    elif str_dtype == "bf16":
         return torch.bfloat16
     else:
         return torch.float32
-
-
+
+
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
     does not change anymore."""
@@ -221,12 +220,12 @@ def mean_flat(tensor):
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+        print(f"{model.__class__.__name__} has {total_params * 1.0e-6:.2f} M params.")
     return total_params
 
 
 def instantiate_from_config(config):
-    if not "target" in config:
+    if "target" not in config:
         if config == "__is_first_stage__":
             return None
         elif config == "__is_unconditional__":
@@ -295,7 +294,8 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
 def format_number(num):
     num = float(num)
     num /= 1000.0
-    return '{:.0f}{}'.format(num, 'k')
+    return "{:.0f}{}".format(num, "k")
+
 
 def get_num_params(model: torch.nn.ModuleList) -> int:
     num_params = sum(p.numel() for p in model.parameters())
@@ -319,13 +319,14 @@ def get_num_flop_per_token(num_params, model_config, seq_len) -> int:
 
     return flop_per_token
 
+
 def get_num_flop_per_sequence_encoder_only(num_params, model_config, seq_len) -> int:
     l, h, q = (
         model_config.n_layers,
         model_config.n_heads,
         model_config.dim // model_config.n_heads,
     )
-
+
     # 1. Each self-attention layer has 2 matmuls in the forward pass and 4 in the backward pass (6)
     # 2. Each matmul performs 1 multiplication and 1 addition (*2)
     # 3. Bidirectional attention must consider all token pairs, so it is t^2 rather than t
@@ -351,6 +352,7 @@ def get_peak_flops(device_name: str) -> int:
     else: # for other GPU types, assume A100
         return 312e12
 
+
 @dataclass(frozen=True)
 class Color:
     black = "\033[30m"
@@ -374,4 +376,4 @@ class NoColor:
     magenta = ""
     cyan = ""
     white = ""
-    reset = ""
+    reset = ""
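
A note on the create_npz_from_sample_folder hunk: the sort key parses each filename's numeric stem, so samples are ordered numerically rather than lexicographically. A minimal illustration (the filenames below are made up):

    names = ["10.png", "2.png", "1.png"]

    # Plain sorted() compares strings character by character:
    print(sorted(names))  # ['1.png', '10.png', '2.png']

    # The key used in this file parses the stem as an int first:
    print(sorted(names, key=lambda x: int(x.split(".")[0])))  # ['1.png', '2.png', '10.png']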
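
For reference, a small smoke test of the helpers touched by this commit, assuming the module is importable as tim.utils.misc_utils; DummyModel is a hypothetical stand-in, not part of the repo:

    import torch

    from tim.utils.misc_utils import count_params, format_number, get_dtype

    class DummyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(128, 128)  # 128*128 + 128 = 16512 params

    # count_params prints "DummyModel has 0.02 M params." when verbose=True.
    n = count_params(DummyModel(), verbose=True)

    # get_dtype maps config-style names to torch dtypes, defaulting to float32.
    assert get_dtype("fp16") is torch.float16
    assert get_dtype("bf16") is torch.bfloat16
    assert get_dtype("anything-else") is torch.float32

    # format_number divides by 1000 and appends "k": 16512 -> "17k".
    print(format_number(n))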
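
The translated comments above get_num_flop_per_sequence_encoder_only state the usual transformer FLOPs accounting: 6 matmuls per weight (2 forward, 4 backward), a factor of 2 for multiply-plus-add, and a t^2 attention term because bidirectional attention covers all token pairs. The function body falls outside the shown hunks, so the sketch below only mirrors that reasoning and is not a quote of the code:

    def encoder_flops_per_sequence(num_params, n_layers, n_heads, head_dim, seq_len):
        # Weight matmuls: 2N FLOPs per token forward + 4N backward = 6N per token.
        dense = 6 * num_params * seq_len
        # Attention matmuls: 6 matmuls x 2 (multiply + add) = 12, over l layers,
        # h heads, head dim q, and all t^2 token pairs (bidirectional).
        attention = 12 * n_layers * n_heads * head_dim * seq_len ** 2
        return dense + attention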
|