|
import sys |
|
import os |
|
import os.path as osp |
|
import argparse |
|
import hashlib |
|
import tarfile |
|
import time |
|
import urllib.request |
|
from lib import GENFORCE, GENFORCE_MODELS, SFD, ARCFACE, FAIRFACE, HOPENET, AUDET, CELEBA_ATTRIBUTES, ContraCLIP_models |
|
|
|
|
|
def reporthook(count, block_size, total_size): |
|
global start_time |
|
if count == 0: |
|
start_time = time.time() |
|
return |
|
duration = time.time() - start_time |
|
progress_size = int(count * block_size) |
|
speed = int(progress_size / (1024 * duration)) |
|
percent = min(int(count * block_size * 100 / total_size), 100) |
|
sys.stdout.write("\r \\__%d%%, %d MB, %d KB/s, %d seconds passed" % |
|
(percent, progress_size / (1024 * 1024), speed, duration)) |
|
|
|
sys.stdout.flush() |
|
|
|
|
|
def download(src, sha256sum, dest): |
|
tmp_tar = osp.join(dest, ".tmp.tar") |
|
try: |
|
urllib.request.urlretrieve(src, tmp_tar, reporthook) |
|
except: |
|
raise ConnectionError("Error: {}".format(src)) |
|
|
|
sha256_hash = hashlib.sha256() |
|
with open(tmp_tar, "rb") as f: |
|
|
|
for byte_block in iter(lambda: f.read(4096), b""): |
|
sha256_hash.update(byte_block) |
|
|
|
sha256_check = sha256_hash.hexdigest() == sha256sum |
|
print() |
|
print(" \\__Check sha256: {}".format("OK!" if sha256_check else "Error")) |
|
if not sha256_check: |
|
raise Exception("Error: Invalid sha256 sum: {}".format(sha256_hash.hexdigest())) |
|
|
|
tar_file = tarfile.open(tmp_tar, mode='r') |
|
tar_file.extractall(dest) |
|
os.remove(tmp_tar) |
|
|
|
|
|
def main(): |
|
"""Download pre-trained GAN generators and various pre-trained detectors (used only during testing), as well as |
|
pre-trained ContraCLIP models: |
|
-- GenForce GAN generators [1] |
|
-- SFD face detector [2] |
|
-- ArcFace [3] |
|
-- FairFace [4] |
|
-- Hopenet [5] |
|
-- AU detector [6] for 12 DISFA [7] Action Units |
|
-- Facial attributes detector [8] for 5 CelebA [9] attributes |
|
-- ContraCLIP [10] pre-trained models: |
|
StyleGAN2@FFHQ |
|
ProgGAN@CelebA-HQ: |
|
StyleGAN2@AFHQ-Cats |
|
StyleGAN2@AFHQ-Dogs |
|
StyleGAN2@AFHQ-Cars |
|
|
|
References: |
|
[1] https://genforce.github.io/ |
|
[2] Zhang, Shifeng, et al. "S3FD: Single shot scale-invariant face detector." Proceedings of the IEEE |
|
international conference on computer vision. 2017. |
|
[3] Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition." |
|
Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019. |
|
[4] Karkkainen, Kimmo, and Jungseock Joo. "FairFace: Face attribute dataset for balanced race, gender, and age." |
|
arXiv preprint arXiv:1908.04913 (2019). |
|
[5] Doosti, Bardia, et al. "Hope-net: A graph-based model for hand-object pose estimation." Proceedings of the |
|
IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020. |
|
[6] Ntinou, Ioanna, et al. "A transfer learning approach to heatmap regression for action unit intensity |
|
estimation." IEEE Transactions on Affective Computing (2021). |
|
[7] Mavadati, S. Mohammad, et al. "DISFA: A spontaneous facial action intensity database." IEEE Transactions on |
|
Affective Computing 4.2 (2013): 151-160. |
|
[8] Jiang, Yuming, et al. "Talk-to-Edit: Fine-Grained Facial Editing via Dialog." Proceedings of the IEEE/CVF |
|
International Conference on Computer Vision. 2021. |
|
[9] Liu, Ziwei, et al. "Deep learning face attributes in the wild." Proceedings of the IEEE international |
|
conference on computer vision. 2015. |
|
[10] Tzelepis, C., Oldfield, J., Tzimiropoulos, G., & Patras, I. (2022). ContraCLIP: Interpretable GAN |
|
generation driven by pairs of contrasting sentences. arXiv preprint arXiv:2206.02104. |
|
""" |
|
parser = argparse.ArgumentParser(description="Download pre-trained models") |
|
parser.add_argument('-m', '--contraclip-models', action='store_true', help="download pre-trained ContraCLIP models") |
|
args = parser.parse_args() |
|
|
|
|
|
pretrained_models_root = osp.join('models', 'pretrained') |
|
os.makedirs(pretrained_models_root, exist_ok=True) |
|
|
|
|
|
print("#. Download pre-trained GAN generators...") |
|
print(" \\__.GenForce") |
|
download_genforce_models = False |
|
for k, v in GENFORCE_MODELS.items(): |
|
if not osp.exists(osp.join(pretrained_models_root, 'genforce', v[0])): |
|
download_genforce_models = True |
|
break |
|
if download_genforce_models: |
|
download(src=GENFORCE[0], sha256sum=GENFORCE[1], dest=pretrained_models_root) |
|
else: |
|
print(" \\__Already exists.") |
|
|
|
print("#. Download pre-trained ArcFace model...") |
|
print(" \\__.ArcFace") |
|
if osp.exists(osp.join(pretrained_models_root, 'arcface', 'model_ir_se50.pth')): |
|
print(" \\__Already exists.") |
|
else: |
|
download(src=ARCFACE[0], sha256sum=ARCFACE[1], dest=pretrained_models_root) |
|
|
|
print("#. Download pre-trained SFD face detector model...") |
|
print(" \\__.Face detector (SFD)") |
|
if osp.exists(osp.join(pretrained_models_root, 'sfd', 's3fd-619a316812.pth')): |
|
print(" \\__Already exists.") |
|
else: |
|
download(src=SFD[0], sha256sum=SFD[1], dest=pretrained_models_root) |
|
|
|
print("#. Download pre-trained FairFace model...") |
|
print(" \\__.FairFace") |
|
if osp.exists(osp.join(pretrained_models_root, 'fairface', 'fairface_alldata_4race_20191111.pt')) and \ |
|
osp.exists(osp.join(pretrained_models_root, 'fairface', 'res34_fair_align_multi_7_20190809.pt')): |
|
print(" \\__Already exists.") |
|
else: |
|
download(src=FAIRFACE[0], sha256sum=FAIRFACE[1], dest=pretrained_models_root) |
|
|
|
print("#. Download pre-trained Hopenet model...") |
|
print(" \\__.Hopenet") |
|
if osp.exists(osp.join(pretrained_models_root, 'hopenet', 'hopenet_alpha1.pkl')) and \ |
|
osp.exists(osp.join(pretrained_models_root, 'hopenet', 'hopenet_alpha2.pkl')) and \ |
|
osp.exists(osp.join(pretrained_models_root, 'hopenet', 'hopenet_robust_alpha1.pkl')): |
|
print(" \\__Already exists.") |
|
else: |
|
download(src=HOPENET[0], sha256sum=HOPENET[1], dest=pretrained_models_root) |
|
|
|
print("#. Download pre-trained AU detector model...") |
|
print(" \\__.FANet") |
|
if osp.exists(osp.join(pretrained_models_root, 'au_detector', 'disfa_adaptation_f0.pth')): |
|
print(" \\__Already exists.") |
|
else: |
|
download(src=AUDET[0], sha256sum=AUDET[1], dest=pretrained_models_root) |
|
|
|
print("#. Download pre-trained CelebA attributes predictors models...") |
|
print(" \\__.CelebA") |
|
if osp.exists(osp.join(pretrained_models_root, 'celeba_attributes', 'eval_predictor.pth.tar')): |
|
print(" \\__Already exists.") |
|
else: |
|
download(src=CELEBA_ATTRIBUTES[0], sha256sum=CELEBA_ATTRIBUTES[1], dest=pretrained_models_root) |
|
|
|
|
|
if args.contraclip_models: |
|
pretrained_contraclip_root = osp.join('experiments', 'complete') |
|
os.makedirs(pretrained_contraclip_root, exist_ok=True) |
|
|
|
print("#. Download pre-trained ContraCLIP models...") |
|
print(" \\__.ContraCLIP pre-trained models...") |
|
download(src=ContraCLIP_models[0], |
|
sha256sum=ContraCLIP_models[1], |
|
dest=pretrained_contraclip_root) |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|