import os
import random

import pandas as pd
import torch.nn.functional as F

from .util import load_img
from configs import path_csv_ffhq_attritube


class GenePoolFactory(object):
    """Build and cache per-(age, gender, race) pools of encoded FFHQ samples.

    A pool is a list of ``(mu, var)`` tuples, one per sampled image, with both
    values moved to the CPU. Pools are computed lazily on first request and
    memoised in ``self.pools`` under the key ``"{age}-{gender}-{race}"``.
    """

    def __init__(self, root_ffhq, device, mean_latent, max_sample=100):
        """Load the FFHQ attribute table and set up an empty pool cache.

        Args:
            root_ffhq: Directory holding the FFHQ images named ``NNNNN.png``,
                or ``None`` to disable pool construction (``__call__`` then
                returns an empty list for any key that is not cached).
            device: Device each loaded image is moved to before encoding.
            mean_latent: Value added to the encoder output for every image.
            max_sample: Maximum number of images used to build one pool.
        """
        self.device = device
        self.mean_latent = mean_latent
        self.root_ffhq = root_ffhq
        self.max_sample = max_sample
        self.pools = {}

        self.df = pd.read_csv(path_csv_ffhq_attritube)
        # Normalise the capitalised gender labels so queries can use lowercase.
        self.df.replace('Male', 'male', inplace=True)
        self.df.replace('Female', 'female', inplace=True)

    def __call__(self, encoder, w2sub34, age, gender, race):
        """Return the pool for the given attributes, building it on first use.

        Args:
            encoder: Callable applied to each image after it is resized to
                256x256; ``self.mean_latent`` is added to its output.
            w2sub34: Callable applied to the resulting latent. It must return
                three values; the first two are stored as ``(mu, var)`` and
                the third is discarded.
            age: Value matched (as a string) against the ``age`` column.
            gender: Value matched against the ``gender`` column.
            race: Value matched against the ``race`` column.

        Returns:
            A list of ``(mu.cpu(), var.cpu())`` tuples with at most
            ``max_sample`` entries, or ``[]`` when the pool is not cached and
            ``root_ffhq`` is ``None``.
        """
        keyname = f'{age}-{gender}-{race}'
        if keyname in self.pools:
            return self.pools[keyname]
        if self.root_ffhq is None:
            return []

        matches = self.df.query(
            f'gender == "{gender}" and age == "{age}" and race == "{race}"'
        )
        # Shuffle a plain Python list: random.shuffle on a 2-D NumPy array
        # swaps row *views*, which duplicates some rows and drops others.
        file_ids = matches['file_id'].tolist()
        random.shuffle(file_ids)

        pool = []
        for fid in file_ids[:self.max_sample]:
            filename = format(int(fid), '05d') + ".png"
            img = load_img(os.path.join(self.root_ffhq, filename))
            img = img.to(self.device)
            w18 = encoder(F.interpolate(img, size=(256, 256))) + self.mean_latent
            mu, var, _ = w2sub34(w18)
            pool.append((mu.cpu(), var.cpu()))

        self.pools[keyname] = pool
        return pool