import sys |
import os |
import os.path as osp |
import argparse |
import hashlib |
import tarfile |
import time |
import urllib.request |
def reporthook(count, block_size, total_size):
    """Print a one-line progress report; usable as a ``urlretrieve`` reporthook.

    The first call (``count == 0``) only records the start time in the
    module-level ``start_time`` and prints nothing. Every later call rewrites
    the current terminal line with the percentage done, the amount
    transferred, the average speed and the elapsed time.

    Args:
        count: Number of blocks transferred so far.
        block_size: Size of a single block, in bytes.
        total_size: Total size of the download in bytes. A non-positive value
            is treated as "size unknown" and the percentage is reported as 0.
    """
    global start_time
    if count == 0:
        start_time = time.time()
        return

    duration = time.time() - start_time
    progress_size = int(count * block_size)

    # The clock may not have advanced since the first call; avoid dividing by
    # a zero elapsed time.
    speed = int(progress_size / (1024 * duration)) if duration > 0 else 0

    # A non-positive total means the size is unknown; dividing by it would
    # raise (0) or yield a negative percentage (-1).
    if total_size > 0:
        percent = min(int(count * block_size * 100 / total_size), 100)
    else:
        percent = 0

    sys.stdout.write("\r \\__%d%%, %d MB, %d KB/s, %d seconds passed" %
                     (percent, progress_size / (1024 * 1024), speed, duration))
    sys.stdout.flush()
def download(src, sha256sum, dest):
    """Download a tar archive, verify its SHA-256 digest and extract it.

    The archive is fetched to ``<dest>/.tmp.tar``, hashed, and — only if the
    digest equals ``sha256sum`` — extracted into ``dest``. The temporary
    archive is removed afterwards, whether or not the operation succeeded.

    Args:
        src: URL of the tar archive to download.
        sha256sum: Expected SHA-256 hex digest of the archive.
        dest: Existing directory that receives the temporary archive and the
            extracted contents.

    Raises:
        ConnectionError: If the archive could not be retrieved.
        Exception: If the digest of the downloaded file does not match
            ``sha256sum``.
    """
    tmp_tar = osp.join(dest, ".tmp.tar")
    try:
        try:
            urllib.request.urlretrieve(src, tmp_tar, reporthook)
        except Exception as err:
            # Catch ``Exception`` rather than everything so Ctrl-C and
            # SystemExit still propagate; keep the root cause chained.
            raise ConnectionError("Error: {}".format(src)) from err

        # Hash the file in small chunks so large archives are never loaded
        # into memory at once.
        sha256_hash = hashlib.sha256()
        with open(tmp_tar, "rb") as f:
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        digest = sha256_hash.hexdigest()
        sha256_check = digest == sha256sum

        print()
        print(" \\__Check sha256: {}".format("OK!" if sha256_check else "Error"))
        if not sha256_check:
            raise Exception("Error: Invalid sha256 sum: {}".format(digest))

        # NOTE(review): extractall() trusts the member paths inside the
        # archive; it is only reached after the digest check above passed.
        with tarfile.open(tmp_tar, mode='r') as tar_file:
            tar_file.extractall(dest)
    finally:
        # Never leave a partial or rejected archive behind.
        if osp.exists(tmp_tar):
            os.remove(tmp_tar)
def _download_unless_present(header, label, required_files, archive, dest):
    """Announce one model bundle and download it unless all its files exist.

    Args:
        header: Section line printed first.
        label: Bundle line printed second.
        required_files: Iterable of path-component tuples, relative to
            ``dest``; the bundle counts as present only if every one exists.
        archive: Indexable whose item 0 is the archive URL and item 1 its
            expected SHA-256 hex digest.
        dest: Root directory the archive is extracted into.
    """
    print(header)
    print(label)
    if all(osp.exists(osp.join(dest, *parts)) for parts in required_files):
        print(" \\__Already exists.")
    else:
        download(src=archive[0], sha256sum=archive[1], dest=dest)


def main():
    """Download pre-trained GAN generators and various pre-trained detectors (used only during testing), as well as
    pre-trained ContraCLIP models:
        -- GenForce GAN generators [1]
        -- SFD face detector [2]
        -- ArcFace [3]
        -- FairFace [4]
        -- Hopenet [5]
        -- AU detector [6] for 12 DISFA [7] Action Units
        -- Facial attributes detector [8] for 5 CelebA [9] attributes
        -- ContraCLIP [10] pre-trained models:
                StyleGAN2@FFHQ
                ProgGAN@CelebA-HQ
                StyleGAN2@AFHQ-Cats
                StyleGAN2@AFHQ-Dogs
                StyleGAN2@AFHQ-Cars

    The detectors and generators are stored under ``models/pretrained`` and are skipped when their files are already
    there. The ContraCLIP models are downloaded into ``experiments/complete`` only when the ``-m`` /
    ``--contraclip-models`` command-line flag is given, and in that case they are always downloaded.

    References:
         [1] https://genforce.github.io/
         [2] Zhang, Shifeng, et al. "S3FD: Single shot scale-invariant face detector." Proceedings of the IEEE
             international conference on computer vision. 2017.
         [3] Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
             Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
         [4] Karkkainen, Kimmo, and Jungseock Joo. "FairFace: Face attribute dataset for balanced race, gender, and age."
             arXiv preprint arXiv:1908.04913 (2019).
         [5] Doosti, Bardia, et al. "Hope-net: A graph-based model for hand-object pose estimation." Proceedings of the
             IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020.
         [6] Ntinou, Ioanna, et al. "A transfer learning approach to heatmap regression for action unit intensity
             estimation." IEEE Transactions on Affective Computing (2021).
         [7] Mavadati, S. Mohammad, et al. "DISFA: A spontaneous facial action intensity database." IEEE Transactions on
             Affective Computing 4.2 (2013): 151-160.
         [8] Jiang, Yuming, et al. "Talk-to-Edit: Fine-Grained Facial Editing via Dialog." Proceedings of the IEEE/CVF
             International Conference on Computer Vision. 2021.
         [9] Liu, Ziwei, et al. "Deep learning face attributes in the wild." Proceedings of the IEEE international
             conference on computer vision. 2015.
         [10] Tzelepis, C., Oldfield, J., Tzimiropoulos, G., & Patras, I. (2022). ContraCLIP: Interpretable GAN
              generation driven by pairs of contrasting sentences. arXiv preprint arXiv:2206.02104.
    """
    parser = argparse.ArgumentParser(description="Download pre-trained models")
    parser.add_argument('-m', '--contraclip-models', action='store_true', help="download pre-trained ContraCLIP models")
    args = parser.parse_args()

    pretrained_models_root = osp.join('models', 'pretrained')
    os.makedirs(pretrained_models_root, exist_ok=True)

    # GenForce: the single archive is fetched as soon as any generator listed
    # in GENFORCE_MODELS is missing.
    _download_unless_present(
        "#. Download pre-trained GAN generators...",
        " \\__.GenForce",
        [('genforce', v[0]) for v in GENFORCE_MODELS.values()],
        GENFORCE,
        pretrained_models_root,
    )

    _download_unless_present(
        "#. Download pre-trained ArcFace model...",
        " \\__.ArcFace",
        [('arcface', 'model_ir_se50.pth')],
        ARCFACE,
        pretrained_models_root,
    )

    _download_unless_present(
        "#. Download pre-trained SFD face detector model...",
        " \\__.Face detector (SFD)",
        [('sfd', 's3fd-619a316812.pth')],
        SFD,
        pretrained_models_root,
    )

    _download_unless_present(
        "#. Download pre-trained FairFace model...",
        " \\__.FairFace",
        [
            ('fairface', 'fairface_alldata_4race_20191111.pt'),
            ('fairface', 'res34_fair_align_multi_7_20190809.pt'),
        ],
        FAIRFACE,
        pretrained_models_root,
    )

    _download_unless_present(
        "#. Download pre-trained Hopenet model...",
        " \\__.Hopenet",
        [
            ('hopenet', 'hopenet_alpha1.pkl'),
            ('hopenet', 'hopenet_alpha2.pkl'),
            ('hopenet', 'hopenet_robust_alpha1.pkl'),
        ],
        HOPENET,
        pretrained_models_root,
    )

    _download_unless_present(
        "#. Download pre-trained AU detector model...",
        " \\__.FANet",
        [('au_detector', 'disfa_adaptation_f0.pth')],
        AUDET,
        pretrained_models_root,
    )

    _download_unless_present(
        "#. Download pre-trained CelebA attributes predictors models...",
        " \\__.CelebA",
        [('celeba_attributes', 'eval_predictor.pth.tar')],
        CELEBA_ATTRIBUTES,
        pretrained_models_root,
    )

    # ContraCLIP models are opt-in and, unlike the bundles above, are
    # downloaded unconditionally when requested.
    if args.contraclip_models:
        pretrained_contraclip_root = osp.join('experiments', 'complete')
        os.makedirs(pretrained_contraclip_root, exist_ok=True)
        print("#. Download pre-trained ContraCLIP models...")
        print(" \\__.ContraCLIP pre-trained models...")
        download(src=ContraCLIP_models[0],
                 sha256sum=ContraCLIP_models[1],
                 dest=pretrained_contraclip_root)
# Run the downloads only when executed as a script, not when imported.
if __name__ == "__main__":
    main()