dattarij's picture
adding ContraCLIP folder
8c212a5
import sys
import os
import os.path as osp
import argparse
import hashlib
import tarfile
import time
import urllib.request
from lib import GENFORCE, GENFORCE_MODELS, SFD, ARCFACE, FAIRFACE, HOPENET, AUDET, CELEBA_ATTRIBUTES, ContraCLIP_models
def reporthook(count, block_size, total_size):
    """Progress callback for ``urllib.request.urlretrieve``; prints a one-line download status.

    Args:
        count: number of blocks transferred so far. ``urlretrieve`` calls the hook with
            ``count == 0`` once before the first block; that call only records the start time.
        block_size: size of a block in bytes.
        total_size: total size of the download in bytes, or ``-1`` when the server did not report
            one (no Content-Length). In that case the percentage is shown as 0.
    """
    global start_time
    if count == 0:
        start_time = time.time()
        return
    duration = time.time() - start_time
    progress_size = int(count * block_size)
    # The first block can arrive within the timer resolution; avoid dividing by a zero duration.
    speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
    # A non-positive total size means "unknown": don't divide by it (0) or print a negative percentage (-1).
    if total_size > 0:
        percent = min(int(progress_size * 100 / total_size), 100)
    else:
        percent = 0
    sys.stdout.write("\r \\__%d%%, %d MB, %d KB/s, %d seconds passed" %
                     (percent, progress_size / (1024 * 1024), speed, duration))
    sys.stdout.flush()
def download(src, sha256sum, dest):
    """Download a tar archive, verify its SHA-256 checksum and extract it into ``dest``.

    The archive is first saved as ``<dest>/.tmp.tar``. That temporary file is always removed
    afterwards, whether the operation succeeded or failed.

    Args:
        src: URL of the tar archive.
        sha256sum: expected hexadecimal SHA-256 digest of the archive.
        dest: existing directory where the archive is downloaded and extracted.

    Raises:
        ConnectionError: if the download fails (the underlying error is chained as the cause).
        Exception: if the downloaded file's SHA-256 digest does not match ``sha256sum``.
    """
    tmp_tar = osp.join(dest, ".tmp.tar")
    try:
        try:
            urllib.request.urlretrieve(src, tmp_tar, reporthook)
        except Exception as err:
            # Catch only real errors so Ctrl-C / SystemExit still propagate; keep the cause.
            raise ConnectionError("Error: {}".format(src)) from err

        sha256_hash = hashlib.sha256()
        with open(tmp_tar, "rb") as f:
            # Read and update hash string value in blocks of 4K
            for byte_block in iter(lambda: f.read(4096), b""):
                sha256_hash.update(byte_block)
        digest = sha256_hash.hexdigest()
        sha256_check = digest == sha256sum
        # Terminate the progress line written by reporthook
        print()
        print(" \\__Check sha256: {}".format("OK!" if sha256_check else "Error"))
        if not sha256_check:
            raise Exception("Error: Invalid sha256 sum: {}".format(digest))

        # The checksum has been verified, so the archive content is the expected one.
        with tarfile.open(tmp_tar, mode='r') as tar_file:
            tar_file.extractall(dest)
    finally:
        # Never leave a partial or corrupt temporary archive behind.
        if osp.exists(tmp_tar):
            os.remove(tmp_tar)
def _download_if_missing(source, dest, required_files):
    """Download and extract ``source`` into ``dest`` unless every path in ``required_files`` exists.

    Args:
        source: indexable whose ``[0]`` is the archive URL and ``[1]`` its expected SHA-256 digest
            (the layout used by the ``lib`` constants such as ``GENFORCE`` and ``ARCFACE``).
        dest: directory the archive is downloaded to and extracted into.
        required_files: paths the extracted archive is expected to provide. If all of them
            already exist, nothing is downloaded.
    """
    if all(osp.exists(path) for path in required_files):
        print(" \\__Already exists.")
    else:
        download(src=source[0], sha256sum=source[1], dest=dest)


def main():
    """Download pre-trained GAN generators and various pre-trained detectors (used only during testing), as well as
    pre-trained ContraCLIP models:
        -- GenForce GAN generators [1]
        -- SFD face detector [2]
        -- ArcFace [3]
        -- FairFace [4]
        -- Hopenet [5]
        -- AU detector [6] for 12 DISFA [7] Action Units
        -- Facial attributes detector [8] for 5 CelebA [9] attributes
        -- ContraCLIP [10] pre-trained models:
            StyleGAN2@FFHQ
            ProgGAN@CelebA-HQ
            StyleGAN2@AFHQ-Cats
            StyleGAN2@AFHQ-Dogs
            StyleGAN2@AFHQ-Cars

    Detector/generator archives are extracted under ``models/pretrained/`` and skipped when their files already
    exist. ContraCLIP models are only downloaded when ``-m/--contraclip-models`` is given; they are extracted under
    ``experiments/complete/`` and are always re-downloaded.

    References:
        [1] https://genforce.github.io/
        [2] Zhang, Shifeng, et al. "S3FD: Single shot scale-invariant face detector." Proceedings of the IEEE
            international conference on computer vision. 2017.
        [3] Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
            Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
        [4] Karkkainen, Kimmo, and Jungseock Joo. "FairFace: Face attribute dataset for balanced race, gender, and age."
            arXiv preprint arXiv:1908.04913 (2019).
        [5] Ruiz, Nataniel, Eunji Chong, and James M. Rehg. "Fine-grained head pose estimation without keypoints."
            Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops. 2018.
        [6] Ntinou, Ioanna, et al. "A transfer learning approach to heatmap regression for action unit intensity
            estimation." IEEE Transactions on Affective Computing (2021).
        [7] Mavadati, S. Mohammad, et al. "DISFA: A spontaneous facial action intensity database." IEEE Transactions on
            Affective Computing 4.2 (2013): 151-160.
        [8] Jiang, Yuming, et al. "Talk-to-Edit: Fine-Grained Facial Editing via Dialog." Proceedings of the IEEE/CVF
            International Conference on Computer Vision. 2021.
        [9] Liu, Ziwei, et al. "Deep learning face attributes in the wild." Proceedings of the IEEE international
            conference on computer vision. 2015.
        [10] Tzelepis, C., Oldfield, J., Tzimiropoulos, G., & Patras, I. (2022). ContraCLIP: Interpretable GAN
             generation driven by pairs of contrasting sentences. arXiv preprint arXiv:2206.02104.
    """
    parser = argparse.ArgumentParser(description="Download pre-trained models")
    parser.add_argument('-m', '--contraclip-models', action='store_true', help="download pre-trained ContraCLIP models")
    args = parser.parse_args()

    # Create pre-trained models root directory
    pretrained_models_root = osp.join('models', 'pretrained')
    os.makedirs(pretrained_models_root, exist_ok=True)

    def under_root(*parts):
        # Path of a file inside the pre-trained models root directory.
        return osp.join(pretrained_models_root, *parts)

    # Download the following pre-trained GAN generators (under models/pretrained/)
    print("#. Download pre-trained GAN generators...")
    print(" \\__.GenForce")
    # One GenForce archive provides all generators; each GENFORCE_MODELS value's [0] is a file name under
    # genforce/, and the archive is fetched if any of those files is missing.
    _download_if_missing(GENFORCE, pretrained_models_root,
                         [under_root('genforce', v[0]) for v in GENFORCE_MODELS.values()])

    print("#. Download pre-trained ArcFace model...")
    print(" \\__.ArcFace")
    _download_if_missing(ARCFACE, pretrained_models_root,
                         [under_root('arcface', 'model_ir_se50.pth')])

    print("#. Download pre-trained SFD face detector model...")
    print(" \\__.Face detector (SFD)")
    _download_if_missing(SFD, pretrained_models_root,
                         [under_root('sfd', 's3fd-619a316812.pth')])

    print("#. Download pre-trained FairFace model...")
    print(" \\__.FairFace")
    _download_if_missing(FAIRFACE, pretrained_models_root,
                         [under_root('fairface', 'fairface_alldata_4race_20191111.pt'),
                          under_root('fairface', 'res34_fair_align_multi_7_20190809.pt')])

    print("#. Download pre-trained Hopenet model...")
    print(" \\__.Hopenet")
    _download_if_missing(HOPENET, pretrained_models_root,
                         [under_root('hopenet', 'hopenet_alpha1.pkl'),
                          under_root('hopenet', 'hopenet_alpha2.pkl'),
                          under_root('hopenet', 'hopenet_robust_alpha1.pkl')])

    print("#. Download pre-trained AU detector model...")
    print(" \\__.FANet")
    _download_if_missing(AUDET, pretrained_models_root,
                         [under_root('au_detector', 'disfa_adaptation_f0.pth')])

    print("#. Download pre-trained CelebA attributes predictors models...")
    print(" \\__.CelebA")
    _download_if_missing(CELEBA_ATTRIBUTES, pretrained_models_root,
                         [under_root('celeba_attributes', 'eval_predictor.pth.tar')])

    # Download pre-trained ContraCLIP models (opt-in; no existence check, always re-downloaded)
    if args.contraclip_models:
        pretrained_contraclip_root = osp.join('experiments', 'complete')
        os.makedirs(pretrained_contraclip_root, exist_ok=True)
        print("#. Download pre-trained ContraCLIP models...")
        print(" \\__.ContraCLIP pre-trained models...")
        download(src=ContraCLIP_models[0],
                 sha256sum=ContraCLIP_models[1],
                 dest=pretrained_contraclip_root)
# Run the downloader only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()