import os
import shutil
from pathlib import Path
from threading import Thread
from typing import List

import requests
from tqdm import tqdm


class BaseModelDownloader:
    """
    A utility for fast downloads of a base model from S3 or any CDN-served
    storage. It works by downloading multiple files in parallel and by
    splitting large files into smaller chunks that are fetched concurrently
    and combined at the end.

    It currently uses multithreading (not multiprocessing), on the assumption
    that the GIL will not interfere with network/disk I/O.

    Created by: KP
    """

    def __init__(self, urls: List[str], url_paths: List[str], out_dir: Path):
        self.urls = urls
        self.url_paths = url_paths
        # Start from a clean output directory so stale files from a previous
        # run cannot mix with the fresh download.
        shutil.rmtree(out_dir, ignore_errors=True)
        out_dir.mkdir(parents=True, exist_ok=True)
        self.out_dir = out_dir

    def download(self):
        threads = []
        batch_urls = {}

        for url, url_path in zip(self.urls, self.url_paths):
            out_path = self.out_dir / url_path
            out_path.parent.mkdir(parents=True, exist_ok=True)
            if url.endswith(".bin"):
                if "unet/" in url_path:
                    # The UNet weights file is the large one here, so split it
                    # into byte-range chunks downloaded in parallel.
                    thread = Thread(
                        target=self.__download_parallel, args=(url, out_path, 6)
                    )
                    thread.start()
                    threads.append(thread)
                else:
                    # Other .bin files each get their own download thread.
                    thread = Thread(
                        target=self.__download_files, args=([url], [out_path])
                    )
                    thread.start()
                    threads.append(thread)
            else:
                # Small non-.bin files are batched onto a single thread.
                batch_urls[url] = out_path

        if batch_urls:
            thread = Thread(
                target=self.__download_files,
                args=(list(batch_urls.keys()), list(batch_urls.values())),
            )
            thread.start()
            threads.append(thread)

        for thread in threads:
            thread.join()

    def __download_parallel(self, url, output_filename, num_parts=4):
        # HEAD request to learn the file size; follow redirects explicitly,
        # since requests.head does not follow them by default and CDN URLs
        # often redirect.
        response = requests.head(url, allow_redirects=True)
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        print("total_size", total_size)

        # Split the file into num_parts byte ranges. HTTP Range offsets are
        # inclusive, so the final range ends at total_size - 1.
        chunk_size = total_size // num_parts
        ranges = [
            (i * chunk_size, (i + 1) * chunk_size - 1) for i in range(num_parts - 1)
        ]
        ranges.append((ranges[-1][1] + 1, total_size - 1))

        print(ranges)

        save_dir = Path.home() / ".cache" / "download_parts"
        os.makedirs(save_dir, exist_ok=True)

        threads = []
        for i, (start, end) in enumerate(ranges):
            thread = Thread(
                target=self.__download_part, args=(url, start, end, i, save_dir)
            )
            thread.start()
            threads.append(thread)

        for thread in threads:
            thread.join()

        self.__combine_parts(save_dir, output_filename, num_parts)
        os.rmdir(save_dir)

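    # Worked example of the range math above, assuming a 100-byte file and
    # num_parts=4: chunk_size is 25, giving the inclusive byte ranges
    # (0, 24), (25, 49), (50, 74), (75, 99). Each range is fetched with an
    # HTTP "Range: bytes=start-end" header, saved as part_<i>.tmp, and then
    # concatenated in order by __combine_parts below.
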
    def __combine_parts(self, save_dir, output_filename, num_parts):
        part_files = [os.path.join(save_dir, f"part_{i}.tmp") for i in range(num_parts)]

        # Concatenate the parts in order to reassemble the original file.
        output_filename.parent.mkdir(parents=True, exist_ok=True)
        with open(output_filename, "wb") as output_file:
            for part_file in part_files:
                print("combining: ", part_file)
                with open(part_file, "rb") as part:
                    output_file.write(part.read())

            out_file_size = output_file.tell()
            print("out_file_size", out_file_size)

        # Clean up the temporary part files.
        for part_file in part_files:
            os.remove(part_file)

    def __download_part(self, url, start_byte, end_byte, part_num, save_dir):
        # Request only this part's byte range; a server that supports range
        # requests replies with 206 Partial Content.
        headers = {"Range": f"bytes={start_byte}-{end_byte}"}
        response = requests.get(url, headers=headers, stream=True)
        response.raise_for_status()

        part_filename = os.path.join(save_dir, f"part_{part_num}.tmp")
        print("Downloading part: ", url, part_filename, end_byte - start_byte + 1)

        # Range offsets are inclusive, hence the +1 on the progress total.
        with open(part_filename, "wb") as part_file, tqdm(
            desc=str(part_filename),
            total=end_byte - start_byte + 1,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    size = part_file.write(chunk)
                    bar.update(size)

        return part_filename

    def __download_files(self, urls, out_paths: List[Path]):
        # Plain sequential streaming download for one or more files.
        for url, out_path in zip(urls, out_paths):
            out_path.parent.mkdir(parents=True, exist_ok=True)
            with requests.get(url, stream=True) as r:
                print("Downloading: ", url)
                r.raise_for_status()
                total_size = int(r.headers.get("content-length", 0))
                chunk_size = 8192
                with open(out_path, "wb") as f, tqdm(
                    desc=str(out_path),
                    total=total_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in r.iter_content(chunk_size=chunk_size):
                        size = f.write(data)
                        bar.update(size)
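

# A minimal usage sketch. The URLs and relative paths below are hypothetical
# placeholders for a real model manifest: each entry in `urls` is paired with
# the relative location under `out_dir` where that file should be written.
if __name__ == "__main__":
    urls = [
        "https://example-cdn.com/base-model/unet/diffusion_pytorch_model.bin",
        "https://example-cdn.com/base-model/model_index.json",
    ]
    url_paths = [
        "unet/diffusion_pytorch_model.bin",
        "model_index.json",
    ]
    downloader = BaseModelDownloader(urls, url_paths, Path("base_model"))
    downloader.download()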