minchul commited on
Commit
87d7f4b
1 Parent(s): 79ca423

Upload directory

Browse files
Files changed (1) hide show
  1. models/base/utils.py +91 -0
models/base/utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from typing import List, Optional, Tuple, Union
3
+ import safetensors
4
+ import torch
5
+ from torch import Tensor
6
+ import os
7
+ from pathlib import Path
8
+ from omegaconf import DictConfig, OmegaConf
9
+
10
+
11
+ def get_parameter_device(parameter: torch.nn.Module):
12
+ try:
13
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
14
+ return next(parameters_and_buffers).device
15
+ except StopIteration:
16
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
17
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
18
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
19
+ return tuples
20
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
21
+ first_tuple = next(gen)
22
+ return first_tuple[1].device
23
+
24
+
25
+ def get_parameter_dtype(parameter: torch.nn.Module):
26
+ try:
27
+ params = tuple(parameter.parameters())
28
+ if len(params) > 0:
29
+ return params[0].dtype
30
+
31
+ buffers = tuple(parameter.buffers())
32
+ if len(buffers) > 0:
33
+ return buffers[0].dtype
34
+
35
+ except StopIteration:
36
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
37
+
38
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
39
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
40
+ return tuples
41
+
42
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
43
+ first_tuple = next(gen)
44
+ return first_tuple[1].dtype
45
+
46
+
47
+ def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path:
48
+ path_obj = Path(save_path)
49
+ return path_obj.parent
50
+
51
+ def get_base_name(save_path: Union[str, os.PathLike]) -> str:
52
+ path_obj = Path(save_path)
53
+ return path_obj.name
54
+
55
+ def load_state_dict_from_path(path: Union[str, os.PathLike]):
56
+ # Load a state dict from a path.
57
+ if 'safetensors' in path:
58
+ state_dict = safetensors.torch.load_file(path)
59
+ else:
60
+ state_dict = torch.load(path, map_location="cpu")
61
+ return state_dict
62
+
63
+ def replace_extension(path, new_extension):
64
+ if not new_extension.startswith('.'):
65
+ new_extension = '.' + new_extension
66
+ return os.path.splitext(path)[0] + new_extension
67
+
68
+ def make_config_path(save_path):
69
+ config_path = replace_extension(save_path, '.yaml')
70
+ return config_path
71
+
72
+ def save_config(config, config_path):
73
+ assert isinstance(config, dict) or isinstance(config, DictConfig)
74
+ os.makedirs(get_parent_directory(config_path), exist_ok=True)
75
+ if isinstance(config, dict):
76
+ config = OmegaConf.create(config)
77
+ OmegaConf.save(config, config_path)
78
+
79
+
80
+ def save_state_dict_and_config(state_dict, config, save_path):
81
+ os.makedirs(get_parent_directory(save_path), exist_ok=True)
82
+
83
+ # save config dict
84
+ config_path = make_config_path(save_path)
85
+ save_config(config, config_path)
86
+
87
+ # Save the model
88
+ if 'safetensors' in save_path:
89
+ safetensors.torch.save_file(state_dict, save_path, metadata={"format": "pt"})
90
+ else:
91
+ torch.save(state_dict, save_path)