# StyleGene: models/stylegene/gene_pool.py
# (wmpscc, commit 7d1312d "add", 1.59 kB)
import os
import random

import pandas as pd
import torch
import torch.nn.functional as F

from configs import path_csv_ffhq_attritube
from .util import load_img
class GenePoolFactory:
    """Build and cache per-attribute "gene pools" of latent statistics.

    A pool is keyed by an ``(age, gender, race)`` triple. On the first request
    for a triple, up to ``max_sample`` FFHQ images whose attribute row matches
    the triple are encoded, and the ``(mu, var)`` pair returned by ``w2sub34``
    for each image is stored on the CPU. Later requests for the same triple
    return the cached list.
    """

    def __init__(self, root_ffhq, device, mean_latent, max_sample=100):
        """
        Args:
            root_ffhq: Directory containing the images named ``NNNNN.png``
                (zero-padded ``file_id``), or ``None`` to disable building
                new pools.
            device: Device each loaded image is moved to before encoding.
            mean_latent: Value added to the encoder output.
            max_sample: Maximum number of images encoded per pool.
        """
        self.device = device
        self.mean_latent = mean_latent
        self.root_ffhq = root_ffhq
        self.max_sample = max_sample
        # "age-gender-race" -> list of (mu, var) CPU tensors
        self.pools = {}
        self.df = pd.read_csv(path_csv_ffhq_attritube)
        # Lower-case exact 'Male'/'Female' cell values (in every column);
        # presumably so callers can pass lowercase gender labels.
        self.df.replace('Male', 'male', inplace=True)
        self.df.replace('Female', 'female', inplace=True)

    def __call__(self, encoder, w2sub34, age, gender, race):
        """Return the (possibly cached) gene pool for an attribute triple.

        Args:
            encoder: Callable applied to each image resized to 256x256;
                ``mean_latent`` is added to its output.
            w2sub34: Callable returning ``(mu, var, sub34)`` for that latent;
                only ``mu`` and ``var`` are kept.
            age, gender, race: Attribute values, matched in their string form
                against the ``age``, ``gender`` and ``race`` CSV columns.

        Returns:
            A list of ``(mu, var)`` CPU tensors, one per sampled image (empty
            when no row matches; that empty pool is cached too). When
            ``root_ffhq`` is ``None`` and no pool is cached for the triple,
            an uncached empty list is returned.
        """
        keyname = f'{age}-{gender}-{race}'
        if keyname in self.pools:
            return self.pools[keyname]
        if self.root_ffhq is None:
            return []

        # Boolean masks instead of a string-built DataFrame.query, so values
        # containing quotes cannot break (or inject into) the expression.
        # Comparing against str(...) keeps the original semantics, where the
        # values were interpolated as quoted string literals.
        df = self.df
        mask = (
            (df['gender'] == str(gender))
            & (df['age'] == str(age))
            & (df['race'] == str(race))
        )
        # Shuffle a plain list: random.shuffle on the (N, 1) ndarray from
        # `.values` swaps row *views*, which duplicates some rows and drops
        # others.
        file_ids = df.loc[mask, 'file_id'].tolist()
        random.shuffle(file_ids)

        pool = []
        # The pool is cached and reused across calls; without no_grad every
        # stored (mu, var) would keep its autograd graph (and the device
        # activations it references) alive.
        with torch.no_grad():
            for fid in file_ids[:self.max_sample]:
                filename = format(int(fid), '05d') + ".png"
                img = load_img(os.path.join(self.root_ffhq, filename))
                img = img.to(self.device)
                w18_1 = encoder(F.interpolate(img, size=(256, 256))) + self.mean_latent
                mu, var, _ = w2sub34(w18_1)
                pool.append((mu.cpu(), var.cpu()))
        self.pools[keyname] = pool
        return pool