# Source listing: Test1 / lib/models/common.py
# Last commit: bfd34e9 "add demo code" by Andranik Sargsyan
# (listing page residue: raw | history | blame | 1.52 kB)
import importlib
import requests
from pathlib import Path
from os.path import dirname
from omegaconf import OmegaConf
from tqdm import tqdm
# Repository root: three dirname() steps up from this file
# (lib/models/common.py -> lib/models -> lib -> repository root).
PROJECT_DIR = dirname(dirname(dirname(__file__)))
# Asset directories under the repository root, joined with '/' separators.
CONFIG_FOLDER = '/'.join((PROJECT_DIR, 'assets', 'config'))
MODEL_FOLDER = '/'.join((PROJECT_DIR, 'assets', 'models'))
def download_file(url, save_path, chunk_size=1024, timeout=None):
    """Download ``url`` to ``save_path``, showing a tqdm progress bar.

    If ``save_path`` already exists the download is skipped. The body is
    streamed into a sibling ``<name>.part`` file, which is moved onto
    ``save_path`` only after the whole body has been written. A failed or
    interrupted download therefore never leaves a truncated file behind that
    a later call would mistake for a finished one.

    Args:
        url: URL fetched with an HTTP GET.
        save_path: Destination path (str or Path). Missing parent
            directories are created.
        chunk_size: Number of bytes requested per ``iter_content`` chunk.
        timeout: Passed through to ``requests.get``. ``None`` (the default)
            means no timeout, matching the previous behavior.

    Raises:
        Exception: ``"Download failed: ..."`` wrapping the underlying error
            (chained as ``__cause__``), including HTTP 4xx/5xx statuses.
    """
    try:
        save_path = Path(save_path)
        if save_path.exists():
            print(f'{save_path.name} exists')
            return
        save_path.parent.mkdir(exist_ok=True, parents=True)
        # Write to a temp name first so the exists() check above only ever
        # sees complete downloads.
        part_path = save_path.with_name(save_path.name + '.part')
        try:
            # Using the response as a context manager closes it (and releases
            # its connection) even if streaming stops early.
            with requests.get(url, stream=True, timeout=timeout) as resp:
                # Fail on HTTP errors instead of saving the error page as the file.
                resp.raise_for_status()
                total = int(resp.headers.get('content-length', 0))
                with open(part_path, 'wb') as file, tqdm(
                    desc=save_path.name,
                    total=total,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in resp.iter_content(chunk_size=chunk_size):
                        bar.update(file.write(data))
            part_path.replace(save_path)
        except BaseException:
            # Remove any partial download (also on KeyboardInterrupt), then re-raise.
            try:
                part_path.unlink()
            except FileNotFoundError:
                pass
            raise
        print(f'{save_path.name} download finished')
    except Exception as e:
        raise Exception(f"Download failed: {e}") from e
def get_obj_from_str(string):
    """Resolve a dotted ``"package.module.Name"`` string to the object it names.

    The module part is imported as given first. If that import raises an
    ImportError, or the module has no such attribute, the lookup is retried
    with a ``lib.`` prefix (e.g. ``models.foo.Bar`` -> ``lib.models.foo.Bar``).

    Raises:
        ValueError: if ``string`` contains no ``.`` separator.
        ImportError / AttributeError: if both the plain and the
            ``lib.``-prefixed lookups fail (the first failure is kept as the
            exception's ``__context__``).
    """
    module, sep, cls = string.rpartition(".")
    if not sep:
        raise ValueError(f"expected a dotted 'module.Name' string, got {string!r}")
    try:
        return getattr(importlib.import_module(module), cls)
    except (ImportError, AttributeError):
        # Only "not found" failures fall back to the lib.-prefixed path. Errors
        # raised while executing the module itself, and KeyboardInterrupt /
        # SystemExit, propagate instead of being masked by the retry.
        return getattr(importlib.import_module('lib.' + module), cls)
def load_obj(path):
    """Instantiate the object described by a YAML file.

    The file is loaded with ``OmegaConf.load``. Its ``__class__`` entry is a
    dotted class path resolved through ``get_obj_from_str``. Its optional
    ``__init__`` mapping supplies the constructor's keyword arguments; a
    missing or empty/null ``__init__`` means no arguments.

    NOTE(review): this imports and calls whatever class the file names, so
    ``path`` should only point at trusted config files.
    """
    objyaml = OmegaConf.load(path)
    # A bare "__init__:" key in YAML parses as null. Treat it like a missing
    # key instead of failing on ``**None``.
    init_kwargs = objyaml.get("__init__") or {}
    return get_obj_from_str(objyaml['__class__'])(**init_kwargs)