ynhe commited on
Commit
72446c5
β€’
1 Parent(s): d6602e7

[fix] update misc

Browse files
Files changed (1) hide show
  1. utils/misc.py +70 -0
utils/misc.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import functools
3
+ import os
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+ import timm.models.hub as timm_hub
8
+
9
+ def setup_for_distributed(is_master):
10
+ """
11
+ This function disables printing when not in master process
12
+ """
13
+ import builtins as __builtin__
14
+
15
+ builtin_print = __builtin__.print
16
+
17
+ def print(*args, **kwargs):
18
+ force = kwargs.pop("force", False)
19
+ if is_master or force:
20
+ builtin_print(*args, **kwargs)
21
+
22
+ __builtin__.print = print
23
+
24
+
25
+ def is_dist_avail_and_initialized():
26
+ if not dist.is_available():
27
+ return False
28
+ if not dist.is_initialized():
29
+ return False
30
+ return True
31
+
32
+
33
+ def get_world_size():
34
+ if not is_dist_avail_and_initialized():
35
+ return 1
36
+ return dist.get_world_size()
37
+
38
+
39
+ def get_rank():
40
+ if not is_dist_avail_and_initialized():
41
+ return 0
42
+ return dist.get_rank()
43
+
44
+
45
+ def is_main_process():
46
+ return get_rank() == 0
47
+
48
+
49
+
50
+ def download_cached_file(url, check_hash=True, progress=False):
51
+ """
52
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
53
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
54
+ """
55
+ if not url.startswith('http'):
56
+ return url
57
+
58
+ def get_cached_file_path():
59
+ # a hack to sync the file path across processes
60
+ parts = torch.hub.urlparse(url)
61
+ filename = os.path.basename(parts.path)
62
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
63
+
64
+ return cached_file
65
+
66
+ if is_main_process():
67
+ timm_hub.download_cached_file(url, check_hash, progress)
68
+
69
+
70
+ return get_cached_file_path()