ybelkada's picture
commit files
4d6b877
raw
history blame
35.1 kB
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
import os
import sys
import glob
import argparse
import threading
import six.moves.queue as Queue
import traceback
import numpy as np
import tensorflow as tf
import PIL.Image
import tfutil
import dataset
#----------------------------------------------------------------------------
def error(msg):
    """Print an error message to stdout and terminate the process.

    Args:
        msg: Description of the failure; printed after the 'Error: ' prefix.

    Raises:
        SystemExit: Always, with exit status 1.
    """
    print('Error: ' + msg)
    # Use sys.exit() instead of the bare exit() builtin: exit() is injected by
    # the `site` module for interactive sessions and is not guaranteed to exist
    # (e.g. `python -S`, frozen or embedded interpreters).
    sys.exit(1)
#----------------------------------------------------------------------------
class TFRecordExporter:
    """Writes images into a set of per-resolution TFRecord files plus an optional label file.

    The first image passed to add_image() fixes the expected shape (channels,
    size, size) and opens one TFRecord writer per level of detail. Every image
    is written at full resolution and at each successively halved resolution.
    Usable as a context manager; leaving the `with` block calls close().

    Args:
        tfrecord_dir: Output directory; created if it does not exist.
        expected_images: Number of images the caller intends to add. Used for
            progress output and by choose_shuffled_order().
        print_progress: Whether to print progress/status lines.
        progress_interval: Print the running count every this many images.
    """

    def __init__(self, tfrecord_dir, expected_images, print_progress=True, progress_interval=10):
        self.tfrecord_dir = tfrecord_dir
        # normpath() drops a trailing path separator first; without it
        # basename('datasets/mnist/') is '' and the output files would be
        # named '-rNN.tfrecords' instead of 'mnist-rNN.tfrecords'.
        self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(os.path.normpath(self.tfrecord_dir)))
        self.expected_images = expected_images
        self.cur_images = 0             # number of images written so far
        self.shape = None               # set from the first image added
        self.resolution_log2 = None     # log2 of the image size, set with shape
        self.tfr_writers = []           # one writer per level of detail, highest resolution first
        self.print_progress = print_progress
        self.progress_interval = progress_interval
        if self.print_progress:
            print('Creating dataset "%s"' % tfrecord_dir)
        if not os.path.isdir(self.tfrecord_dir):
            os.makedirs(self.tfrecord_dir)
        assert(os.path.isdir(self.tfrecord_dir))

    def close(self):
        """Close all open TFRecord writers and report the number of images added."""
        if self.print_progress:
            print('%-40s\r' % 'Flushing data...', end='', flush=True)
        for tfr_writer in self.tfr_writers:
            tfr_writer.close()
        self.tfr_writers = []
        if self.print_progress:
            # Overwrite the status line with blanks before the final summary.
            print('%-40s\r' % '', end='', flush=True)
            print('Added %d images.' % self.cur_images)

    def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order.
        """Return a permutation of range(expected_images) drawn with a fixed seed (123)."""
        order = np.arange(self.expected_images)
        np.random.RandomState(123).shuffle(order)
        return order

    def add_image(self, img):
        """Append one image to every per-resolution TFRecord file.

        Args:
            img: Array of shape (channels, size, size) with 1 or 3 channels and
                a power-of-two size. All images must share the first image's shape.

        Raises:
            AssertionError: If the shape constraints above are violated.
        """
        if self.print_progress and self.cur_images % self.progress_interval == 0:
            print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
        if self.shape is None:
            # First image: lock in the shape and open one writer per resolution,
            # from the full resolution down to 4x4 (log2 == 2).
            self.shape = img.shape
            self.resolution_log2 = int(np.log2(self.shape[1]))
            assert self.shape[0] in [1, 3]
            assert self.shape[1] == self.shape[2]
            assert self.shape[1] == 2**self.resolution_log2
            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            for lod in range(self.resolution_log2 - 1):
                tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod)
                self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
        assert img.shape == self.shape
        for lod, tfr_writer in enumerate(self.tfr_writers):
            if lod:
                # Halve the resolution by averaging each 2x2 pixel block. The
                # float image is carried over, so rounding only happens once
                # per level when quantizing below.
                img = img.astype(np.float32)
                img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25
            quant = np.rint(img).clip(0, 255).astype(np.uint8)
            ex = tf.train.Example(features=tf.train.Features(feature={
                'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)),
                # tobytes() replaces the deprecated ndarray.tostring() alias
                # (removed in NumPy 2.0); the produced bytes are identical.
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tobytes()]))}))
            tfr_writer.write(ex.SerializeToString())
        self.cur_images += 1

    def add_labels(self, labels):
        """Save the label array as float32 in NumPy format to '<prefix>-rxx.labels'.

        Args:
            labels: Array whose first dimension equals the number of images added so far.

        Raises:
            AssertionError: If the label count does not match the image count.
        """
        if self.print_progress:
            print('%-40s\r' % 'Saving labels...', end='', flush=True)
        assert labels.shape[0] == self.cur_images
        with open(self.tfr_prefix + '-rxx.labels', 'wb') as f:
            np.save(f, labels.astype(np.float32))

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
#----------------------------------------------------------------------------
class ExceptionInfo(object):
    """Snapshot of the exception currently being handled.

    Must be instantiated inside an `except` block. Stores the exception
    instance in `value` and the formatted traceback text in `traceback`.
    """

    def __init__(self):
        current = sys.exc_info()
        self.value = current[1]
        self.traceback = traceback.format_exc()
#----------------------------------------------------------------------------
class WorkerThread(threading.Thread):
    """Thread that executes (func, args, result_queue) tasks pulled from a shared queue.

    For every task the outcome is posted to the task's own result queue as
    (result, args). A task whose func is None makes the thread stop.
    """

    def __init__(self, task_queue):
        threading.Thread.__init__(self)
        self.task_queue = task_queue

    def run(self):
        next_task = self.task_queue.get
        func, args, result_queue = next_task()
        while func is not None:
            try:
                outcome = func(*args)
            except:
                # Deliberately catches everything: the failure is wrapped and
                # handed to the consumer of the result queue instead of
                # silently killing this worker.
                outcome = ExceptionInfo()
            result_queue.put((outcome, args))
            func, args, result_queue = next_task()
#----------------------------------------------------------------------------
class ThreadPool(object):
    """Pool of daemon WorkerThread instances that share a single task queue.

    Results are routed to one queue per submitted function, so callers fetch
    the results of a function with get_result(func). Usable as a context
    manager; leaving the `with` block calls finish().

    Args:
        num_threads: Number of worker threads to start (must be >= 1).
    """

    def __init__(self, num_threads):
        assert num_threads >= 1
        self.task_queue = Queue.Queue()
        self.result_queues = dict()     # func => queue of (result, args) tuples
        self.num_threads = num_threads
        for idx in range(self.num_threads):
            thread = WorkerThread(self.task_queue)
            thread.daemon = True
            thread.start()

    def add_task(self, func, args=()):
        """Schedule func(*args) on the pool; the result goes to func's result queue."""
        assert hasattr(func, '__call__') # must be a function
        if func not in self.result_queues:
            self.result_queues[func] = Queue.Queue()
        self.task_queue.put((func, args, self.result_queues[func]))

    def get_result(self, func): # returns (result, args)
        """Block until a result of `func` is available and return (result, args).

        Raises:
            Exception: Re-raises whatever the task raised in the worker thread,
                after printing the worker's traceback.
        """
        result, args = self.result_queues[func].get()
        if isinstance(result, ExceptionInfo):
            print('\n\nWorker thread caught an exception:\n' + result.traceback)
            raise result.value
        return result, args

    def finish(self):
        """Ask every worker to stop by queueing one terminating task per thread."""
        for idx in range(self.num_threads):
            self.task_queue.put((None, (), None))

    def __enter__(self): # for 'with' statement
        return self

    def __exit__(self, *excinfo):
        self.finish()

    def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, post_func=lambda x: x, max_items_in_flight=None):
        """Map process_func over items on the pool, yielding results in input order.

        Args:
            item_iterator: Iterable of input items.
            process_func: Executed on a worker thread for every prepared item.
            pre_func: Executed on the calling thread before an item is queued.
            post_func: Executed on the calling thread on each result just
                before it is yielded.
            max_items_in_flight: Upper bound on queued-but-not-yet-yielded
                items; defaults to 4 * num_threads.

        Yields:
            post_func(process_func(pre_func(item))) for every item, in the
            order the items were produced by item_iterator.
        """
        if max_items_in_flight is None:
            max_items_in_flight = self.num_threads * 4
        assert max_items_in_flight >= 1
        # Unique marker for "result not available yet". None cannot be used
        # for this: a process_func that legitimately returns None would never
        # be retired and the generator would block forever on the next get().
        pending = object()
        results = []
        retire_idx = [0]    # one-element list so the nested function can mutate it

        def task_func(prepared, idx):
            return process_func(prepared)

        def retire_result():
            # Wait for any one task to complete, then emit every result that
            # is now contiguous with the ones already emitted.
            processed, (prepared, idx) = self.get_result(task_func)
            results[idx] = processed
            while retire_idx[0] < len(results) and results[retire_idx[0]] is not pending:
                yield post_func(results[retire_idx[0]])
                results[retire_idx[0]] = None   # drop the reference to free memory
                retire_idx[0] += 1

        try:
            for idx, item in enumerate(item_iterator):
                prepared = pre_func(item)
                results.append(pending)
                self.add_task(func=task_func, args=(prepared, idx))
                # Throttle: do not queue more work until enough results have
                # been handed to the consumer.
                while retire_idx[0] < idx - max_items_in_flight + 2:
                    for res in retire_result():
                        yield res
            while retire_idx[0] < len(results):
                for res in retire_result():
                    yield res
        finally:
            # task_func is a fresh closure on every call, so its result queue
            # would otherwise stay in result_queues for the pool's lifetime.
            self.result_queues.pop(task_func, None)
#----------------------------------------------------------------------------
def display(tfrecord_dir):
    """Show the images of a dataset one at a time in an OpenCV window.

    Prints the index and label of every image. Any key advances to the next
    image; ESC stops early.

    Args:
        tfrecord_dir: Directory containing the dataset.
    """
    print('Loading dataset "%s"' % tfrecord_dir)
    tfutil.init_tf({'gpu_options.allow_growth': True})
    dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size='full', repeat=False, shuffle_mb=0)
    tfutil.init_uninited_vars()

    esc_key = 27
    window = 'dataset_tool'
    shown = 0
    while True:
        try:
            images, labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        if shown == 0:
            # The window is only set up once there is something to show.
            print('Displaying images')
            import cv2 # pip install opencv-python
            cv2.namedWindow(window)
            print('Press SPACE or ENTER to advance, ESC to exit')
        print('\nidx = %-8d\nlabel = %s' % (shown, labels[0].tolist()))
        hwc = images[0].transpose(1, 2, 0) # CHW => HWC
        bgr = hwc[:, :, ::-1] # RGB => BGR
        cv2.imshow(window, bgr)
        shown += 1
        if cv2.waitKey() == esc_key:
            break
    print('\nDisplayed %d images.' % shown)
#----------------------------------------------------------------------------
def extract(tfrecord_dir, output_dir):
    """Write every image of a dataset to output_dir as 'imgNNNNNNNN.png'.

    Single-channel images are saved in mode 'L', all others in mode 'RGB'.

    Args:
        tfrecord_dir: Directory containing the dataset.
        output_dir: Directory to write the PNG files into; created if missing.
    """
    print('Loading dataset "%s"' % tfrecord_dir)
    tfutil.init_tf({'gpu_options.allow_growth': True})
    dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0)
    tfutil.init_uninited_vars()

    print('Extracting images to "%s"' % output_dir)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    count = 0
    while True:
        if count % 10 == 0:
            print('%d\r' % count, end='', flush=True)
        try:
            images, labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        if images.shape[1] == 1:
            pixels, mode = images[0][0], 'L'
        else:
            pixels, mode = images[0].transpose(1, 2, 0), 'RGB' # CHW => HWC
        target = os.path.join(output_dir, 'img%08d.png' % count)
        PIL.Image.fromarray(pixels, mode).save(target)
        count += 1
    print('Extracted %d images.' % count)
#----------------------------------------------------------------------------
def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):
    """Compare two datasets item by item and print how many images/labels match.

    Both datasets are read in lockstep, one item at a time. Every differing
    image or label is reported by index; a length mismatch is reported once
    and ends the comparison.

    Args:
        tfrecord_dir_a: Directory containing the first dataset.
        tfrecord_dir_b: Directory containing the second dataset.
        ignore_labels: If truthy, datasets are loaded with max_label_size=0
            and the label summary line is not printed.
    """
    max_label_size = 0 if ignore_labels else 'full'
    print('Loading dataset "%s"' % tfrecord_dir_a)
    tfutil.init_tf({'gpu_options.allow_growth': True})
    dset_a = dataset.TFRecordDataset(tfrecord_dir_a, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
    print('Loading dataset "%s"' % tfrecord_dir_b)
    dset_b = dataset.TFRecordDataset(tfrecord_dir_b, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
    tfutil.init_uninited_vars()

    def next_item(dset):
        # (images, labels) of the next item, or (None, None) once exhausted.
        try:
            return dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            return None, None

    def same(lhs, rhs):
        return lhs.shape == rhs.shape and np.all(lhs == rhs)

    print('Comparing datasets')
    compared = 0
    matching_images = 0
    matching_labels = 0
    while True:
        if compared % 100 == 0:
            print('%d\r' % compared, end='', flush=True)
        images_a, labels_a = next_item(dset_a)
        images_b, labels_b = next_item(dset_b)
        done_a = images_a is None
        done_b = images_b is None
        if done_a or done_b:
            if done_a != done_b:
                print('Datasets contain different number of images')
            break
        if same(images_a, images_b):
            matching_images += 1
        else:
            print('Image %d is different' % compared)
        if same(labels_a, labels_b):
            matching_labels += 1
        else:
            print('Label %d is different' % compared)
        compared += 1

    print('Identical images: %d / %d' % (matching_images, compared))
    if not ignore_labels:
        print('Identical labels: %d / %d' % (matching_labels, compared))
#----------------------------------------------------------------------------
def create_mnist(tfrecord_dir, mnist_dir):
    """Create a labeled 32x32 single-channel dataset from the MNIST training set.

    The 28x28 digits are zero-padded by 2 pixels on each side and stored in
    shuffled order together with one-hot labels.

    Args:
        tfrecord_dir: New dataset directory to be created.
        mnist_dir: Directory containing 'train-images-idx3-ubyte.gz' and
            'train-labels-idx1-ubyte.gz'.
    """
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip

    # The offsets skip the leading 16 / 8 bytes of each file (presumably the
    # IDX header) so only the raw uint8 payload is read.
    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as image_file:
        pixels = np.frombuffer(image_file.read(), np.uint8, offset=16)
    with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as label_file:
        labels = np.frombuffer(label_file.read(), np.uint8, offset=8)

    images = np.pad(pixels.reshape(-1, 1, 28, 28), [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0)

    # Sanity checks on the loaded data.
    assert images.shape == (60000, 1, 32, 32)
    assert images.dtype == np.uint8
    assert labels.shape == (60000,)
    assert labels.dtype == np.uint8
    assert np.min(images) == 0
    assert np.max(images) == 255
    assert np.min(labels) == 0
    assert np.max(labels) == 9

    num_classes = np.max(labels) + 1
    onehot = np.zeros((labels.size, num_classes), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for src_idx in order:
            tfr.add_image(images[src_idx])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------
def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123):
    """Create an unlabeled 3-channel dataset by stacking three random MNIST digits per image.

    Args:
        tfrecord_dir: New dataset directory to be created.
        mnist_dir: Directory containing 'train-images-idx3-ubyte.gz'.
        num_images: Number of composite images to create.
        random_seed: Seed for the digit selection.
    """
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip

    archive = os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz')
    with gzip.open(archive, 'rb') as image_file:
        pixels = np.frombuffer(image_file.read(), np.uint8, offset=16)

    # 28x28 digits zero-padded by 2 pixels on every side => 32x32.
    digits = np.pad(pixels.reshape(-1, 28, 28), [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert digits.shape == (60000, 32, 32)
    assert digits.dtype == np.uint8
    assert np.min(digits) == 0
    assert np.max(digits) == 255

    with TFRecordExporter(tfrecord_dir, num_images) as tfr:
        rnd = np.random.RandomState(random_seed)
        for _ in range(num_images):
            # Three independently drawn digits become the three channels.
            picks = rnd.randint(digits.shape[0], size=3)
            tfr.add_image(digits[picks])
#----------------------------------------------------------------------------
def create_cifar10(tfrecord_dir, cifar10_dir):
    """Create a labeled dataset from the five CIFAR-10 training batches.

    Images are stored in shuffled order together with one-hot labels.

    Args:
        tfrecord_dir: New dataset directory to be created.
        cifar10_dir: Directory containing 'data_batch_1' .. 'data_batch_5'.

    Raises:
        AssertionError: If the loaded data does not have the expected
            shape, dtype, or value range.
    """
    print('Loading CIFAR-10 from "%s"' % cifar10_dir)
    import pickle
    images = []
    labels = []
    for batch in range(1, 6):
        # NOTE: pickle.load() can execute arbitrary code; only point this at
        # batch files obtained from a trusted source.
        with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file:
            data = pickle.load(file, encoding='latin1')
        images.append(data['data'].reshape(-1, 3, 32, 32))
        labels.append(data['labels'])
    images = np.concatenate(images)
    labels = np.concatenate(labels)
    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    # The labels are plain Python ints, so their NumPy dtype is the platform's
    # default integer: int64 on Linux/macOS (and everywhere with NumPy 2),
    # int32 on Windows with NumPy 1.x. Both are valid here.
    assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------
def create_cifar100(tfrecord_dir, cifar100_dir):
    """Create a labeled dataset from the CIFAR-100 training set (fine labels).

    Images are stored in shuffled order together with one-hot labels.

    Args:
        tfrecord_dir: New dataset directory to be created.
        cifar100_dir: Directory containing the pickled 'train' file.

    Raises:
        AssertionError: If the loaded data does not have the expected
            shape, dtype, or value range.
    """
    print('Loading CIFAR-100 from "%s"' % cifar100_dir)
    import pickle
    # NOTE: pickle.load() can execute arbitrary code; only point this at a
    # file obtained from a trusted source.
    with open(os.path.join(cifar100_dir, 'train'), 'rb') as file:
        data = pickle.load(file, encoding='latin1')
    images = data['data'].reshape(-1, 3, 32, 32)
    labels = np.array(data['fine_labels'])
    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    # np.array() over Python ints yields the platform's default integer type:
    # int64 on Linux/macOS (and everywhere with NumPy 2), int32 on Windows
    # with NumPy 1.x. Both are valid here.
    assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 99
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------
def create_svhn(tfrecord_dir, svhn_dir):
    """Create a labeled dataset from three pickled SVHN training parts.

    Images are stored in shuffled order together with one-hot labels.

    Args:
        tfrecord_dir: New dataset directory to be created.
        svhn_dir: Directory containing 'train_1.pkl' .. 'train_3.pkl', each
            unpickling to a sequence whose items 0 and 1 are the images and
            the labels.
    """
    print('Loading SVHN from "%s"' % svhn_dir)
    import pickle

    image_parts = []
    label_parts = []
    for part in (1, 2, 3):
        part_path = os.path.join(svhn_dir, 'train_%d.pkl' % part)
        with open(part_path, 'rb') as part_file:
            payload = pickle.load(part_file, encoding='latin1')
        image_parts.append(payload[0])
        label_parts.append(payload[1])
    images = np.concatenate(image_parts)
    labels = np.concatenate(label_parts)

    # Sanity checks on the loaded data.
    assert images.shape == (73257, 3, 32, 32)
    assert images.dtype == np.uint8
    assert labels.shape == (73257,)
    assert labels.dtype == np.uint8
    assert np.min(images) == 0
    assert np.max(images) == 255
    assert np.min(labels) == 0
    assert np.max(labels) == 9

    num_classes = np.max(labels) + 1
    onehot = np.zeros((labels.size, num_classes), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for src_idx in order:
            tfr.add_image(images[src_idx])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------
def create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None):
    """Create an unlabeled dataset from one LSUN category stored as an LMDB database.

    Every record is decoded, center-cropped to a square, resized to
    resolution x resolution and added in database order. Records that fail to
    decode or convert are reported on stdout and skipped.

    Args:
        tfrecord_dir: New dataset directory to be created.
        lmdb_dir: Directory containing the LMDB database.
        resolution: Output width and height in pixels.
        max_images: Stop after this many images have been added; None means
            the number of entries in the database.
    """
    print('Loading LSUN dataset from "%s"' % lmdb_dir)
    import lmdb # pip install lmdb
    import cv2 # pip install opencv-python
    import io
    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
        total_images = txn.stat()['entries']
        if max_images is None:
            max_images = total_images
        with TFRecordExporter(tfrecord_dir, max_images) as tfr:
            for _key, value in txn.cursor():
                try:
                    try:
                        # frombuffer() replaces the deprecated binary mode of
                        # np.fromstring() and avoids copying the record.
                        img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
                        if img is None:
                            raise IOError('cv2.imdecode failed')
                        img = img[:, :, ::-1] # BGR => RGB
                    except IOError:
                        # Fall back to PIL when OpenCV cannot decode the record.
                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))
                    # Center-crop to a square of the shorter side.
                    crop = np.min(img.shape[:2])
                    img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
                    img = PIL.Image.fromarray(img, 'RGB')
                    # LANCZOS is the same filter as the ANTIALIAS alias, which
                    # was removed in Pillow 10.
                    img = img.resize((resolution, resolution), PIL.Image.LANCZOS)
                    img = np.asarray(img)
                    img = img.transpose(2, 0, 1) # HWC => CHW
                    tfr.add_image(img)
                except Exception as exc:
                    # Best effort: report the bad record and carry on. Unlike
                    # a bare `except`, this lets KeyboardInterrupt/SystemExit
                    # stop the export.
                    print(exc)
                if tfr.cur_images == max_images:
                    break
#----------------------------------------------------------------------------
def create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121):
    """Create an unlabeled 128x128 dataset from the aligned CelebA PNG images.

    A 128x128 window centered at (cx, cy) is cut out of every 218x178 image;
    the images are stored in shuffled order.

    Args:
        tfrecord_dir: New dataset directory to be created.
        celeba_dir: Directory containing 'img_align_celeba_png/*.png'.
        cx: Center X coordinate of the crop window.
        cy: Center Y coordinate of the crop window.
    """
    print('Loading CelebA from "%s"' % celeba_dir)
    glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png')
    image_filenames = sorted(glob.glob(glob_pattern))
    expected_images = 202599
    if len(image_filenames) != expected_images:
        error('Expected to find %d images' % expected_images)

    rows = slice(cy - 64, cy + 64)
    cols = slice(cx - 64, cx + 64)
    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        for src_idx in tfr.choose_shuffled_order():
            pixels = np.asarray(PIL.Image.open(image_filenames[src_idx]))
            assert pixels.shape == (218, 178, 3)
            cropped = pixels[rows, cols]
            tfr.add_image(cropped.transpose(2, 0, 1)) # HWC => CHW
#----------------------------------------------------------------------------
def create_celebahq(tfrecord_dir, celeba_dir, delta_dir, num_threads=4, num_tasks=100):
    """Create the CelebA-HQ dataset from original CelebA JPEGs plus encrypted delta images.

    Every output image is produced by re-processing an original CelebA JPEG
    (oriented crop around the face landmarks, resampling, padding) into a
    3x1024x1024 array and then adding a decrypted delta image to it. Both the
    intermediate and the final array are checked against MD5 digests listed in
    'image_list.txt', which is why exact Pillow/libjpeg versions are enforced.

    Args:
        tfrecord_dir: New dataset directory to be created.
        celeba_dir: Directory containing 'img_celeba/*.jpg' and
            'Anno/list_landmarks_celeba.txt'.
        delta_dir: Directory containing the 'delta*.zip' archives and
            'image_list.txt'.
        num_threads: Number of worker threads in the ThreadPool.
        num_tasks: Passed to the pool as max_items_in_flight.

    Raises:
        AssertionError: If a processed or final image does not match its
            expected MD5 digest.

    Exits through error() when the expected files are not found or the
    Pillow / libjpeg checks fail.
    """
    print('Loading CelebA from "%s"' % celeba_dir)
    expected_images = 202599
    if len(glob.glob(os.path.join(celeba_dir, 'img_celeba', '*.jpg'))) != expected_images:
        error('Expected to find %d images' % expected_images)
    with open(os.path.join(celeba_dir, 'Anno', 'list_landmarks_celeba.txt'), 'rt') as file:
        # Skip the first two lines and the first column of every row; the
        # remaining values are reshaped into 5 (x, y)-style pairs per image.
        landmarks = [[float(value) for value in line.split()[1:]] for line in file.readlines()[2:]]
        landmarks = np.float32(landmarks).reshape(-1, 5, 2)

    print('Loading CelebA-HQ deltas from "%s"' % delta_dir)
    import scipy.ndimage
    import hashlib
    import bz2
    import zipfile
    import base64
    import cryptography.hazmat.primitives.hashes
    import cryptography.hazmat.backends
    import cryptography.hazmat.primitives.kdf.pbkdf2
    import cryptography.fernet
    expected_zips = 30
    if len(glob.glob(os.path.join(delta_dir, 'delta*.zip'))) != expected_zips:
        error('Expected to find %d zips' % expected_zips)
    with open(os.path.join(delta_dir, 'image_list.txt'), 'rt') as file:
        # Whitespace-separated table: the first line holds the column names.
        # Columns whose name ends in 'idx' are parsed as int, the rest as str.
        lines = [line.split() for line in file]
        fields = dict()
        for idx, field in enumerate(lines[0]):
            type = int if field.endswith('idx') else str
            fields[field] = [type(line[idx]) for line in lines[1:]]
    indices = np.array(fields['idx'])

    # Must use pillow version 3.1.1 for everything to work correctly.
    if getattr(PIL, 'PILLOW_VERSION', '') != '3.1.1':
        error('create_celebahq requires pillow version 3.1.1') # conda install pillow=3.1.1

    # Must use libjpeg version 8d for everything to work correctly.
    # Checked indirectly: the decoded pixels of the first image must hash to a known MD5.
    img = np.array(PIL.Image.open(os.path.join(celeba_dir, 'img_celeba', '000001.jpg')))
    md5 = hashlib.md5()
    md5.update(img.tobytes())
    if md5.hexdigest() != '9cad8178d6cb0196b36f7b34bc5eb6d3':
        error('create_celebahq requires libjpeg version 8d') # conda install jpeg=8d

    def rot90(v):
        # Maps the 2-vector (a, b) to (-b, a).
        return np.array([-v[1], v[0]])

    def process_func(idx):
        # Runs on a worker thread; returns the final 3x1024x1024 image for entry `idx`.

        # Load original image.
        orig_idx = fields['orig_idx'][idx]
        orig_file = fields['orig_file'][idx]
        orig_path = os.path.join(celeba_dir, 'img_celeba', orig_file)
        img = PIL.Image.open(orig_path)

        # Choose oriented crop rectangle.
        # Landmarks 0 and 1 are averaged into eye_avg, landmarks 3 and 4 into mouth_avg.
        lm = landmarks[orig_idx]
        eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
        mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
        eye_to_eye = lm[1] - lm[0]
        eye_to_mouth = mouth_avg - eye_avg
        x = eye_to_eye - rot90(eye_to_mouth)
        x /= np.hypot(*x)
        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
        y = rot90(x)
        c = eye_avg + eye_to_mouth * 0.1
        # Four corners of the crop quadrilateral around center c, spanned by x and y.
        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
        # Scale factor from the quad's side length (2 * |x|) to 1024 pixels.
        zoom = 1024 / (np.hypot(*x) * 2)

        # Shrink.
        # Only applies when the source would be reduced by more than 2x (shrink > 1).
        shrink = int(np.floor(0.5 / zoom))
        if shrink > 1:
            size = (int(np.round(float(img.size[0]) / shrink)), int(np.round(float(img.size[1]) / shrink)))
            img = img.resize(size, PIL.Image.ANTIALIAS)
            quad /= shrink
            zoom *= shrink

        # Crop.
        # Bounding box of the quad, grown by `border` and clamped to the image.
        border = max(int(np.round(1024 * 0.1 / zoom)), 3)
        crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
            img = img.crop(crop)
            quad -= crop[0:2]

        # Simulate super-resolution.
        # Upscale by the smallest power of two >= zoom.
        superres = int(np.exp2(np.ceil(np.log2(zoom))))
        if superres > 1:
            img = img.resize((img.size[0] * superres, img.size[1] * superres), PIL.Image.ANTIALIAS)
            quad *= superres
            zoom /= superres

        # Pad.
        # Amount by which the quad (plus border) sticks out on each side: (left, top, right, bottom).
        pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
        if max(pad) > border - 4:
            pad = np.maximum(pad, int(np.round(1024 * 0.3 / zoom)))
            img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
            h, w, _ = img.shape
            y, x, _ = np.mgrid[:h, :w, :1]
            # mask is 1 at the outer edge of the padded image and decreases towards the interior.
            mask = 1.0 - np.minimum(np.minimum(np.float32(x) / pad[0], np.float32(y) / pad[1]), np.minimum(np.float32(w-1-x) / pad[2], np.float32(h-1-y) / pad[3]))
            blur = 1024 * 0.02 / zoom
            # Blend towards a blurred copy, then towards the per-channel median, weighted by the mask.
            img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
            img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
            img = PIL.Image.fromarray(np.uint8(np.clip(np.round(img), 0, 255)), 'RGB')
            quad += pad[0:2]

        # Transform.
        # Map the quad to 4096x4096, then downsample to 1024x1024 and convert HWC => CHW.
        img = img.transform((4096, 4096), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
        img = img.resize((1024, 1024), PIL.Image.ANTIALIAS)
        img = np.asarray(img).transpose(2, 0, 1)

        # Verify MD5.
        md5 = hashlib.md5()
        md5.update(img.tobytes())
        assert md5.hexdigest() == fields['proc_md5'][idx]

        # Load delta image and original JPG.
        # The archive name is derived from idx rounded down to a multiple of 1000.
        with zipfile.ZipFile(os.path.join(delta_dir, 'deltas%05d.zip' % (idx - idx % 1000)), 'r') as zip:
            delta_bytes = zip.read('delta%05d.dat' % idx)
        with open(orig_path, 'rb') as file:
            orig_bytes = file.read()

        # Decrypt delta image, using original JPG data as decryption key.
        # The Fernet key is the url-safe base64 encoding of a 32-byte
        # PBKDF2-HMAC-SHA256 derivation of the JPEG bytes, salted with the file name.
        algorithm = cryptography.hazmat.primitives.hashes.SHA256()
        backend = cryptography.hazmat.backends.default_backend()
        salt = bytes(orig_file, 'ascii')
        kdf = cryptography.hazmat.primitives.kdf.pbkdf2.PBKDF2HMAC(algorithm=algorithm, length=32, salt=salt, iterations=100000, backend=backend)
        key = base64.urlsafe_b64encode(kdf.derive(orig_bytes))
        delta = np.frombuffer(bz2.decompress(cryptography.fernet.Fernet(key).decrypt(delta_bytes)), dtype=np.uint8).reshape(3, 1024, 1024)

        # Apply delta image.
        img = img + delta

        # Verify MD5.
        md5 = hashlib.md5()
        md5.update(img.tobytes())
        assert md5.hexdigest() == fields['final_md5'][idx]
        return img

    # Process the entries concurrently in shuffled order; the pool yields the
    # results in submission order, so images are added in that same order.
    with TFRecordExporter(tfrecord_dir, indices.size) as tfr:
        order = tfr.choose_shuffled_order()
        with ThreadPool(num_threads) as pool:
            for img in pool.process_items_concurrently(indices[order].tolist(), process_func=process_func, max_items_in_flight=num_tasks):
                tfr.add_image(img)
#----------------------------------------------------------------------------
def create_from_images(tfrecord_dir, image_dir, shuffle):
    """Create an unlabeled dataset from every file in a directory of images.

    The first file (in sorted order) determines resolution and channel count.
    It must be square, have a power-of-two resolution, and be grayscale or RGB.

    Args:
        tfrecord_dir: New dataset directory to be created.
        image_dir: Directory containing the images.
        shuffle: If truthy, images are added in shuffled order; otherwise in
            sorted filename order.
    """
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if not image_filenames:
        error('No input images found')

    # Validate the format using the first image only.
    first = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = first.shape[0]
    if first.ndim == 3:
        channels = first.shape[2]
    else:
        channels = 1
    if first.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')

    num_files = len(image_filenames)
    with TFRecordExporter(tfrecord_dir, num_files) as tfr:
        if shuffle:
            order = tfr.choose_shuffled_order()
        else:
            order = np.arange(num_files)
        for src_idx in order:
            pixels = np.asarray(PIL.Image.open(image_filenames[src_idx]))
            if channels == 1:
                tfr.add_image(pixels[np.newaxis, :, :]) # HW => CHW
            else:
                tfr.add_image(pixels.transpose(2, 0, 1)) # HWC => CHW
#----------------------------------------------------------------------------
def create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle):
    """Create a dataset from an HDF5 archive, with labels if a sidecar file exists.

    Of all entries in the archive whose key starts with 'data', the one with
    the largest shape[3] is exported. Labels are added when a file named
    '<archive name without extension>-labels.npy' exists.

    Args:
        tfrecord_dir: New dataset directory to be created.
        hdf5_filename: HDF5 archive containing the images.
        shuffle: If truthy, images are added in shuffled order; otherwise in
            archive order.
    """
    print('Loading HDF5 archive from "%s"' % hdf5_filename)
    import h5py # conda install h5py
    with h5py.File(hdf5_filename, 'r') as hdf5_file:
        candidates = [value for key, value in hdf5_file.items() if key.startswith('data')]
        hdf5_data = max(candidates, key=lambda lod: lod.shape[3])
        num_images = hdf5_data.shape[0]
        with TFRecordExporter(tfrecord_dir, num_images) as tfr:
            if shuffle:
                order = tfr.choose_shuffled_order()
            else:
                order = np.arange(num_images)
            for src_idx in order:
                tfr.add_image(hdf5_data[src_idx])
            base_name, _ext = os.path.splitext(hdf5_filename)
            npy_filename = base_name + '-labels.npy'
            if os.path.isfile(npy_filename):
                tfr.add_labels(np.load(npy_filename)[order])
#----------------------------------------------------------------------------
def execute_cmdline(argv):
    """Parse the command line and run the module-level function named by the sub-command.

    Each sub-command has the same name as a function in this module, and its
    parsed options are passed to that function as keyword arguments. With no
    arguments after the program name, the top-level help is shown.

    Args:
        argv: Full argument vector; argv[0] is used as the program name.
    """
    prog = argv[0]
    parser = argparse.ArgumentParser(
        prog=prog,
        description='Tool for creating, extracting, and visualizing Progressive GAN datasets.',
        epilog='Type "%s <command> -h" for more information.' % prog)
    subparsers = parser.add_subparsers(dest='command')
    subparsers.required = True

    def register(cmd, desc, example=None):
        # Create the sub-parser for one command, with an optional usage example.
        usage_example = None
        if example is not None:
            usage_example = 'Example: %s %s' % (prog, example)
        return subparsers.add_parser(cmd, description=desc, help=desc, epilog=usage_example)

    sub = register('display', 'Display images in dataset.',
                   'display datasets/mnist')
    sub.add_argument('tfrecord_dir', help='Directory containing dataset')

    sub = register('extract', 'Extract images from dataset.',
                   'extract datasets/mnist mnist-images')
    sub.add_argument('tfrecord_dir', help='Directory containing dataset')
    sub.add_argument('output_dir', help='Directory to extract the images into')

    sub = register('compare', 'Compare two datasets.',
                   'compare datasets/mydataset datasets/mnist')
    sub.add_argument('tfrecord_dir_a', help='Directory containing first dataset')
    sub.add_argument('tfrecord_dir_b', help='Directory containing second dataset')
    sub.add_argument('--ignore_labels', help='Ignore labels (default: 0)', type=int, default=0)

    sub = register('create_mnist', 'Create dataset for MNIST.',
                   'create_mnist datasets/mnist ~/downloads/mnist')
    sub.add_argument('tfrecord_dir', help='New dataset directory to be created')
    sub.add_argument('mnist_dir', help='Directory containing MNIST')

    sub = register('create_mnistrgb', 'Create dataset for MNIST-RGB.',
                   'create_mnistrgb datasets/mnistrgb ~/downloads/mnist')
    sub.add_argument('tfrecord_dir', help='New dataset directory to be created')
    sub.add_argument('mnist_dir', help='Directory containing MNIST')
    sub.add_argument('--num_images', help='Number of composite images to create (default: 1000000)', type=int, default=1000000)
    sub.add_argument('--random_seed', help='Random seed (default: 123)', type=int, default=123)

    sub = register('create_cifar10', 'Create dataset for CIFAR-10.',
                   'create_cifar10 datasets/cifar10 ~/downloads/cifar10')
    sub.add_argument('tfrecord_dir', help='New dataset directory to be created')
    sub.add_argument('cifar10_dir', help='Directory containing CIFAR-10')

    sub = register('create_cifar100', 'Create dataset for CIFAR-100.',
                   'create_cifar100 datasets/cifar100 ~/downloads/cifar100')
    sub.add_argument('tfrecord_dir', help='New dataset directory to be created')
    sub.add_argument('cifar100_dir', help='Directory containing CIFAR-100')

    sub = register('create_svhn', 'Create dataset for SVHN.',
                   'create_svhn datasets/svhn ~/downloads/svhn')
    sub.add_argument('tfrecord_dir', help='New dataset directory to be created')
    sub.add_argument('svhn_dir', help='Directory containing SVHN')

    sub = register('create_lsun', 'Create dataset for single LSUN category.',
                   'create_lsun datasets/lsun-car-100k ~/downloads/lsun/car_lmdb --resolution 256 --max_images 100000')
    sub.add_argument('tfrecord_dir', help='New dataset directory to be created')
    sub.add_argument('lmdb_dir', help='Directory containing LMDB database')
    sub.add_argument('--resolution', help='Output resolution (default: 256)', type=int, default=256)
    sub.add_argument('--max_images', help='Maximum number of images (default: none)', type=int, default=None)

    sub = register('create_celeba', 'Create dataset for CelebA.',
                   'create_celeba datasets/celeba ~/downloads/celeba')
    sub.add_argument('tfrecord_dir', help='New dataset directory to be created')
    sub.add_argument('celeba_dir', help='Directory containing CelebA')
    sub.add_argument('--cx', help='Center X coordinate (default: 89)', type=int, default=89)
    sub.add_argument('--cy', help='Center Y coordinate (default: 121)', type=int, default=121)

    sub = register('create_celebahq', 'Create dataset for CelebA-HQ.',
                   'create_celebahq datasets/celebahq ~/downloads/celeba ~/downloads/celeba-hq-deltas')
    sub.add_argument('tfrecord_dir', help='New dataset directory to be created')
    sub.add_argument('celeba_dir', help='Directory containing CelebA')
    sub.add_argument('delta_dir', help='Directory containing CelebA-HQ deltas')
    sub.add_argument('--num_threads', help='Number of concurrent threads (default: 4)', type=int, default=4)
    sub.add_argument('--num_tasks', help='Number of concurrent processing tasks (default: 100)', type=int, default=100)

    sub = register('create_from_images', 'Create dataset from a directory full of images.',
                   'create_from_images datasets/mydataset myimagedir')
    sub.add_argument('tfrecord_dir', help='New dataset directory to be created')
    sub.add_argument('image_dir', help='Directory containing the images')
    sub.add_argument('--shuffle', help='Randomize image order (default: 1)', type=int, default=1)

    sub = register('create_from_hdf5', 'Create dataset from legacy HDF5 archive.',
                   'create_from_hdf5 datasets/celebahq ~/downloads/celeba-hq-1024x1024.h5')
    sub.add_argument('tfrecord_dir', help='New dataset directory to be created')
    sub.add_argument('hdf5_filename', help='HDF5 archive containing the images')
    sub.add_argument('--shuffle', help='Randomize image order (default: 1)', type=int, default=1)

    cli_args = argv[1:] if len(argv) > 1 else ['-h']
    options = vars(parser.parse_args(cli_args))
    # The sub-command name doubles as the name of the function to call; the
    # remaining options become its keyword arguments.
    handler = globals()[options.pop('command')]
    handler(**options)
#----------------------------------------------------------------------------
# Script entry point: hand the raw command line to the sub-command dispatcher
# only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    execute_cmdline(sys.argv)
#----------------------------------------------------------------------------