Spaces:
Build error
Build error
[fix] update misc
Browse files- utils/misc.py +70 -0
utils/misc.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import functools
|
3 |
+
import os
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
import timm.models.hub as timm_hub
|
8 |
+
|
9 |
+
def setup_for_distributed(is_master):
|
10 |
+
"""
|
11 |
+
This function disables printing when not in master process
|
12 |
+
"""
|
13 |
+
import builtins as __builtin__
|
14 |
+
|
15 |
+
builtin_print = __builtin__.print
|
16 |
+
|
17 |
+
def print(*args, **kwargs):
|
18 |
+
force = kwargs.pop("force", False)
|
19 |
+
if is_master or force:
|
20 |
+
builtin_print(*args, **kwargs)
|
21 |
+
|
22 |
+
__builtin__.print = print
|
23 |
+
|
24 |
+
|
25 |
+
def is_dist_avail_and_initialized():
|
26 |
+
if not dist.is_available():
|
27 |
+
return False
|
28 |
+
if not dist.is_initialized():
|
29 |
+
return False
|
30 |
+
return True
|
31 |
+
|
32 |
+
|
33 |
+
def get_world_size():
|
34 |
+
if not is_dist_avail_and_initialized():
|
35 |
+
return 1
|
36 |
+
return dist.get_world_size()
|
37 |
+
|
38 |
+
|
39 |
+
def get_rank():
|
40 |
+
if not is_dist_avail_and_initialized():
|
41 |
+
return 0
|
42 |
+
return dist.get_rank()
|
43 |
+
|
44 |
+
|
45 |
+
def is_main_process():
|
46 |
+
return get_rank() == 0
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
def download_cached_file(url, check_hash=True, progress=False):
|
51 |
+
"""
|
52 |
+
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
|
53 |
+
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
|
54 |
+
"""
|
55 |
+
if not url.startswith('http'):
|
56 |
+
return url
|
57 |
+
|
58 |
+
def get_cached_file_path():
|
59 |
+
# a hack to sync the file path across processes
|
60 |
+
parts = torch.hub.urlparse(url)
|
61 |
+
filename = os.path.basename(parts.path)
|
62 |
+
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
|
63 |
+
|
64 |
+
return cached_file
|
65 |
+
|
66 |
+
if is_main_process():
|
67 |
+
timm_hub.download_cached_file(url, check_hash, progress)
|
68 |
+
|
69 |
+
|
70 |
+
return get_cached_file_path()
|