wzhouxiff commited on
Commit
df4a3e9
1 Parent(s): 1fa3b8b

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +38 -0
utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import requests
4
+ from torch.hub import download_url_to_file, get_dir
5
+
6
+ from urllib.parse import urlparse
7
+
8
+
9
+ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
10
+ """Load file form http url, will download models if necessary.
11
+
12
+ Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
13
+
14
+ Args:
15
+ url (str): URL to be downloaded.
16
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
17
+ Default: None.
18
+ progress (bool): Whether to show the download progress. Default: True.
19
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
20
+
21
+ Returns:
22
+ str: The path to the downloaded file.
23
+ """
24
+ if model_dir is None: # use the pytorch hub_dir
25
+ hub_dir = get_dir()
26
+ model_dir = os.path.join(hub_dir, 'checkpoints')
27
+
28
+ os.makedirs(model_dir, exist_ok=True)
29
+
30
+ parts = urlparse(url)
31
+ filename = os.path.basename(parts.path)
32
+ if file_name is not None:
33
+ filename = file_name
34
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
35
+ if not os.path.exists(cached_file):
36
+ print(f'Downloading: "{url}" to {cached_file}\n')
37
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
38
+ return cached_file