# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
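
"""Tool for creating, extracting, and visualizing Progressive GAN datasets
stored as multi-resolution TFRecords. Run the script with "-h" for the list
of available commands (summary taken from the command-line help below)."""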

import os
import sys
import glob
import argparse
import threading
import six.moves.queue as Queue
import traceback
import numpy as np
import tensorflow as tf
import PIL.Image

import tfutil
import dataset

#----------------------------------------------------------------------------

def error(msg):
    print('Error: ' + msg)
    exit(1)

#----------------------------------------------------------------------------
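
# The exporter writes one <dataset>-rNN.tfrecords file per power-of-two
# resolution (NN = log2 of the side length, files named after the dataset
# directory), plus an optional <dataset>-rxx.labels file containing the
# labels as a float32 .npy array.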
class TFRecordExporter:
    def __init__(self, tfrecord_dir, expected_images, print_progress=True, progress_interval=10):
        self.tfrecord_dir = tfrecord_dir
        self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
        self.expected_images = expected_images
        self.cur_images = 0
        self.shape = None
        self.resolution_log2 = None
        self.tfr_writers = []
        self.print_progress = print_progress
        self.progress_interval = progress_interval
        if self.print_progress:
            print('Creating dataset "%s"' % tfrecord_dir)
        if not os.path.isdir(self.tfrecord_dir):
            os.makedirs(self.tfrecord_dir)
        assert os.path.isdir(self.tfrecord_dir)

    def close(self):
        if self.print_progress:
            print('%-40s\r' % 'Flushing data...', end='', flush=True)
        for tfr_writer in self.tfr_writers:
            tfr_writer.close()
        self.tfr_writers = []
        if self.print_progress:
            print('%-40s\r' % '', end='', flush=True)
            print('Added %d images.' % self.cur_images)

    def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order.
        order = np.arange(self.expected_images)
        np.random.RandomState(123).shuffle(order)
        return order
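
    # add_image() lazily creates one TFRecordWriter per resolution the first
    # time it is called, then writes the full-resolution image followed by
    # progressively halved copies (2x2 average pooling) for the lower LODs.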
    def add_image(self, img):
        if self.print_progress and self.cur_images % self.progress_interval == 0:
            print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
        if self.shape is None:
            self.shape = img.shape
            self.resolution_log2 = int(np.log2(self.shape[1]))
            assert self.shape[0] in [1, 3]
            assert self.shape[1] == self.shape[2]
            assert self.shape[1] == 2**self.resolution_log2
            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            for lod in range(self.resolution_log2 - 1):
                tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod)
                self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
        assert img.shape == self.shape
        for lod, tfr_writer in enumerate(self.tfr_writers):
            if lod:
                img = img.astype(np.float32)
                img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25
            quant = np.rint(img).clip(0, 255).astype(np.uint8)
            ex = tf.train.Example(features=tf.train.Features(feature={
                'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)),
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tostring()]))}))
            tfr_writer.write(ex.SerializeToString())
        self.cur_images += 1

    def add_labels(self, labels):
        if self.print_progress:
            print('%-40s\r' % 'Saving labels...', end='', flush=True)
        assert labels.shape[0] == self.cur_images
        with open(self.tfr_prefix + '-rxx.labels', 'wb') as f:
            np.save(f, labels.astype(np.float32))

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

#----------------------------------------------------------------------------

class ExceptionInfo(object):
    def __init__(self):
        self.value = sys.exc_info()[1]
        self.traceback = traceback.format_exc()

#----------------------------------------------------------------------------

class WorkerThread(threading.Thread):
    def __init__(self, task_queue):
        threading.Thread.__init__(self)
        self.task_queue = task_queue

    def run(self):
        while True:
            func, args, result_queue = self.task_queue.get()
            if func is None:
                break
            try:
                result = func(*args)
            except:
                result = ExceptionInfo()
            result_queue.put((result, args))

#----------------------------------------------------------------------------

class ThreadPool(object):
    def __init__(self, num_threads):
        assert num_threads >= 1
        self.task_queue = Queue.Queue()
        self.result_queues = dict()
        self.num_threads = num_threads
        for idx in range(self.num_threads):
            thread = WorkerThread(self.task_queue)
            thread.daemon = True
            thread.start()

    def add_task(self, func, args=()):
        assert hasattr(func, '__call__') # must be a function
        if func not in self.result_queues:
            self.result_queues[func] = Queue.Queue()
        self.task_queue.put((func, args, self.result_queues[func]))

    def get_result(self, func): # returns (result, args)
        result, args = self.result_queues[func].get()
        if isinstance(result, ExceptionInfo):
            print('\n\nWorker thread caught an exception:\n' + result.traceback)
            raise result.value
        return result, args

    def finish(self):
        for idx in range(self.num_threads):
            self.task_queue.put((None, (), None))

    def __enter__(self): # for 'with' statement
        return self

    def __exit__(self, *excinfo):
        self.finish()
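
    # process_items_concurrently() runs pre_func on the calling thread,
    # process_func on the worker threads, and post_func on the calling thread
    # again. Results are yielded in the original submission order, and at most
    # max_items_in_flight items are queued at any one time.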
    def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, post_func=lambda x: x, max_items_in_flight=None):
        if max_items_in_flight is None: max_items_in_flight = self.num_threads * 4
        assert max_items_in_flight >= 1
        results = []
        retire_idx = [0]

        def task_func(prepared, idx):
            return process_func(prepared)

        def retire_result():
            processed, (prepared, idx) = self.get_result(task_func)
            results[idx] = processed
            while retire_idx[0] < len(results) and results[retire_idx[0]] is not None:
                yield post_func(results[retire_idx[0]])
                results[retire_idx[0]] = None
                retire_idx[0] += 1

        for idx, item in enumerate(item_iterator):
            prepared = pre_func(item)
            results.append(None)
            self.add_task(func=task_func, args=(prepared, idx))
            while retire_idx[0] < idx - max_items_in_flight + 2:
                for res in retire_result(): yield res
        while retire_idx[0] < len(results):
            for res in retire_result(): yield res

#----------------------------------------------------------------------------

def display(tfrecord_dir):
    print('Loading dataset "%s"' % tfrecord_dir)
    tfutil.init_tf({'gpu_options.allow_growth': True})
    dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size='full', repeat=False, shuffle_mb=0)
    tfutil.init_uninited_vars()

    idx = 0
    while True:
        try:
            images, labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        if idx == 0:
            print('Displaying images')
            import cv2 # pip install opencv-python
            cv2.namedWindow('dataset_tool')
            print('Press SPACE or ENTER to advance, ESC to exit')
        print('\nidx = %-8d\nlabel = %s' % (idx, labels[0].tolist()))
        cv2.imshow('dataset_tool', images[0].transpose(1, 2, 0)[:, :, ::-1]) # CHW => HWC, RGB => BGR
        idx += 1
        if cv2.waitKey() == 27:
            break
    print('\nDisplayed %d images.' % idx)

#----------------------------------------------------------------------------

def extract(tfrecord_dir, output_dir):
    print('Loading dataset "%s"' % tfrecord_dir)
    tfutil.init_tf({'gpu_options.allow_growth': True})
    dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0)
    tfutil.init_uninited_vars()

    print('Extracting images to "%s"' % output_dir)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    idx = 0
    while True:
        if idx % 10 == 0:
            print('%d\r' % idx, end='', flush=True)
        try:
            images, labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        if images.shape[1] == 1:
            img = PIL.Image.fromarray(images[0][0], 'L')
        else:
            img = PIL.Image.fromarray(images[0].transpose(1, 2, 0), 'RGB')
        img.save(os.path.join(output_dir, 'img%08d.png' % idx))
        idx += 1
    print('Extracted %d images.' % idx)

#----------------------------------------------------------------------------

def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):
    max_label_size = 0 if ignore_labels else 'full'
    print('Loading dataset "%s"' % tfrecord_dir_a)
    tfutil.init_tf({'gpu_options.allow_growth': True})
    dset_a = dataset.TFRecordDataset(tfrecord_dir_a, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
    print('Loading dataset "%s"' % tfrecord_dir_b)
    dset_b = dataset.TFRecordDataset(tfrecord_dir_b, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
    tfutil.init_uninited_vars()

    print('Comparing datasets')
    idx = 0
    identical_images = 0
    identical_labels = 0
    while True:
        if idx % 100 == 0:
            print('%d\r' % idx, end='', flush=True)
        try:
            images_a, labels_a = dset_a.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            images_a, labels_a = None, None
        try:
            images_b, labels_b = dset_b.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            images_b, labels_b = None, None
        if images_a is None or images_b is None:
            if images_a is not None or images_b is not None:
                print('Datasets contain different number of images')
            break
        if images_a.shape == images_b.shape and np.all(images_a == images_b):
            identical_images += 1
        else:
            print('Image %d is different' % idx)
        if labels_a.shape == labels_b.shape and np.all(labels_a == labels_b):
            identical_labels += 1
        else:
            print('Label %d is different' % idx)
        idx += 1
    print('Identical images: %d / %d' % (identical_images, idx))
    if not ignore_labels:
        print('Identical labels: %d / %d' % (identical_labels, idx))

#----------------------------------------------------------------------------

def create_mnist(tfrecord_dir, mnist_dir):
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip
    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
        images = np.frombuffer(file.read(), np.uint8, offset=16)
    with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as file:
        labels = np.frombuffer(file.read(), np.uint8, offset=8)
    images = images.reshape(-1, 1, 28, 28)
    images = np.pad(images, [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])

#----------------------------------------------------------------------------
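
# Each output image in the MNIST-RGB dataset is a composite of three randomly
# chosen MNIST digits, stacked as the image's three color channels.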
def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123):
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip
    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
        images = np.frombuffer(file.read(), np.uint8, offset=16)
    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255

    with TFRecordExporter(tfrecord_dir, num_images) as tfr:
        rnd = np.random.RandomState(random_seed)
        for idx in range(num_images):
            tfr.add_image(images[rnd.randint(images.shape[0], size=3)])

#----------------------------------------------------------------------------

def create_cifar10(tfrecord_dir, cifar10_dir):
    print('Loading CIFAR-10 from "%s"' % cifar10_dir)
    import pickle
    images = []
    labels = []
    for batch in range(1, 6):
        with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file:
            data = pickle.load(file, encoding='latin1')
        images.append(data['data'].reshape(-1, 3, 32, 32))
        labels.append(data['labels'])
    images = np.concatenate(images)
    labels = np.concatenate(labels)
    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype == np.int32
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])

#----------------------------------------------------------------------------

def create_cifar100(tfrecord_dir, cifar100_dir):
    print('Loading CIFAR-100 from "%s"' % cifar100_dir)
    import pickle
    with open(os.path.join(cifar100_dir, 'train'), 'rb') as file:
        data = pickle.load(file, encoding='latin1')
    images = data['data'].reshape(-1, 3, 32, 32)
    labels = np.array(data['fine_labels'])
    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype == np.int32
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 99
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])

#----------------------------------------------------------------------------

def create_svhn(tfrecord_dir, svhn_dir):
    print('Loading SVHN from "%s"' % svhn_dir)
    import pickle
    images = []
    labels = []
    for batch in range(1, 4):
        with open(os.path.join(svhn_dir, 'train_%d.pkl' % batch), 'rb') as file:
            data = pickle.load(file, encoding='latin1')
        images.append(data[0])
        labels.append(data[1])
    images = np.concatenate(images)
    labels = np.concatenate(labels)
    assert images.shape == (73257, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (73257,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])

#----------------------------------------------------------------------------
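
# Each LSUN entry is decoded from the LMDB database, center-cropped to a
# square, and resized to the requested resolution before being exported.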
def create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None):
    print('Loading LSUN dataset from "%s"' % lmdb_dir)
    import lmdb # pip install lmdb
    import cv2 # pip install opencv-python
    import io
    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
        total_images = txn.stat()['entries']
        if max_images is None:
            max_images = total_images
        with TFRecordExporter(tfrecord_dir, max_images) as tfr:
            for idx, (key, value) in enumerate(txn.cursor()):
                try:
                    try:
                        img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)
                        if img is None:
                            raise IOError('cv2.imdecode failed')
                        img = img[:, :, ::-1] # BGR => RGB
                    except IOError:
                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))
                    crop = np.min(img.shape[:2])
                    img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
                    img = PIL.Image.fromarray(img, 'RGB')
                    img = img.resize((resolution, resolution), PIL.Image.ANTIALIAS)
                    img = np.asarray(img)
                    img = img.transpose(2, 0, 1) # HWC => CHW
                    tfr.add_image(img)
                except:
                    print(sys.exc_info()[1])
                if tfr.cur_images == max_images:
                    break

#----------------------------------------------------------------------------
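
# (cx, cy) is the center of the fixed 128x128 crop taken from each aligned
# 218x178 (HxW) CelebA PNG; see the shape assert below.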
def create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121):
    print('Loading CelebA from "%s"' % celeba_dir)
    glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png')
    image_filenames = sorted(glob.glob(glob_pattern))
    expected_images = 202599
    if len(image_filenames) != expected_images:
        error('Expected to find %d images' % expected_images)

    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            assert img.shape == (218, 178, 3)
            img = img[cy - 64 : cy + 64, cx - 64 : cx + 64]
            img = img.transpose(2, 0, 1) # HWC => CHW
            tfr.add_image(img)

#----------------------------------------------------------------------------

def create_celebahq(tfrecord_dir, celeba_dir, delta_dir, num_threads=4, num_tasks=100):
    print('Loading CelebA from "%s"' % celeba_dir)
    expected_images = 202599
    if len(glob.glob(os.path.join(celeba_dir, 'img_celeba', '*.jpg'))) != expected_images:
        error('Expected to find %d images' % expected_images)
    with open(os.path.join(celeba_dir, 'Anno', 'list_landmarks_celeba.txt'), 'rt') as file:
        landmarks = [[float(value) for value in line.split()[1:]] for line in file.readlines()[2:]]
        landmarks = np.float32(landmarks).reshape(-1, 5, 2)

    print('Loading CelebA-HQ deltas from "%s"' % delta_dir)
    import scipy.ndimage
    import hashlib
    import bz2
    import zipfile
    import base64
    import cryptography.hazmat.primitives.hashes
    import cryptography.hazmat.backends
    import cryptography.hazmat.primitives.kdf.pbkdf2
    import cryptography.fernet
    expected_zips = 30
    if len(glob.glob(os.path.join(delta_dir, 'delta*.zip'))) != expected_zips:
        error('Expected to find %d zips' % expected_zips)
    with open(os.path.join(delta_dir, 'image_list.txt'), 'rt') as file:
        lines = [line.split() for line in file]
    fields = dict()
    for idx, field in enumerate(lines[0]):
        type = int if field.endswith('idx') else str
        fields[field] = [type(line[idx]) for line in lines[1:]]
    indices = np.array(fields['idx'])

    # Must use pillow version 3.1.1 for everything to work correctly.
    if getattr(PIL, 'PILLOW_VERSION', '') != '3.1.1':
        error('create_celebahq requires pillow version 3.1.1') # conda install pillow=3.1.1

    # Must use libjpeg version 8d for everything to work correctly.
    img = np.array(PIL.Image.open(os.path.join(celeba_dir, 'img_celeba', '000001.jpg')))
    md5 = hashlib.md5()
    md5.update(img.tobytes())
    if md5.hexdigest() != '9cad8178d6cb0196b36f7b34bc5eb6d3':
        error('create_celebahq requires libjpeg version 8d') # conda install jpeg=8d

    def rot90(v):
        return np.array([-v[1], v[0]])
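
    # Each CelebA-HQ image is recreated by aligning and cropping the original
    # in-the-wild CelebA JPEG according to its facial landmarks, and then
    # applying an encrypted delta image whose decryption key is derived from
    # the bytes of that original JPEG; MD5 checksums from image_list.txt are
    # verified both before and after the delta is applied.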
    def process_func(idx):
        # Load original image.
        orig_idx = fields['orig_idx'][idx]
        orig_file = fields['orig_file'][idx]
        orig_path = os.path.join(celeba_dir, 'img_celeba', orig_file)
        img = PIL.Image.open(orig_path)

        # Choose oriented crop rectangle.
        lm = landmarks[orig_idx]
        eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
        mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
        eye_to_eye = lm[1] - lm[0]
        eye_to_mouth = mouth_avg - eye_avg
        x = eye_to_eye - rot90(eye_to_mouth)
        x /= np.hypot(*x)
        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
        y = rot90(x)
        c = eye_avg + eye_to_mouth * 0.1
        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
        zoom = 1024 / (np.hypot(*x) * 2)

        # Shrink.
        shrink = int(np.floor(0.5 / zoom))
        if shrink > 1:
            size = (int(np.round(float(img.size[0]) / shrink)), int(np.round(float(img.size[1]) / shrink)))
            img = img.resize(size, PIL.Image.ANTIALIAS)
            quad /= shrink
            zoom *= shrink

        # Crop.
        border = max(int(np.round(1024 * 0.1 / zoom)), 3)
        crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
            img = img.crop(crop)
            quad -= crop[0:2]

        # Simulate super-resolution.
        superres = int(np.exp2(np.ceil(np.log2(zoom))))
        if superres > 1:
            img = img.resize((img.size[0] * superres, img.size[1] * superres), PIL.Image.ANTIALIAS)
            quad *= superres
            zoom /= superres

        # Pad.
        pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
        if max(pad) > border - 4:
            pad = np.maximum(pad, int(np.round(1024 * 0.3 / zoom)))
            img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
            h, w, _ = img.shape
            y, x, _ = np.mgrid[:h, :w, :1]
            mask = 1.0 - np.minimum(np.minimum(np.float32(x) / pad[0], np.float32(y) / pad[1]), np.minimum(np.float32(w-1-x) / pad[2], np.float32(h-1-y) / pad[3]))
            blur = 1024 * 0.02 / zoom
            img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
            img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
            img = PIL.Image.fromarray(np.uint8(np.clip(np.round(img), 0, 255)), 'RGB')
            quad += pad[0:2]

        # Transform.
        img = img.transform((4096, 4096), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
        img = img.resize((1024, 1024), PIL.Image.ANTIALIAS)
        img = np.asarray(img).transpose(2, 0, 1)

        # Verify MD5.
        md5 = hashlib.md5()
        md5.update(img.tobytes())
        assert md5.hexdigest() == fields['proc_md5'][idx]

        # Load delta image and original JPG.
        with zipfile.ZipFile(os.path.join(delta_dir, 'deltas%05d.zip' % (idx - idx % 1000)), 'r') as zip:
            delta_bytes = zip.read('delta%05d.dat' % idx)
        with open(orig_path, 'rb') as file:
            orig_bytes = file.read()

        # Decrypt delta image, using original JPG data as decryption key.
        algorithm = cryptography.hazmat.primitives.hashes.SHA256()
        backend = cryptography.hazmat.backends.default_backend()
        salt = bytes(orig_file, 'ascii')
        kdf = cryptography.hazmat.primitives.kdf.pbkdf2.PBKDF2HMAC(algorithm=algorithm, length=32, salt=salt, iterations=100000, backend=backend)
        key = base64.urlsafe_b64encode(kdf.derive(orig_bytes))
        delta = np.frombuffer(bz2.decompress(cryptography.fernet.Fernet(key).decrypt(delta_bytes)), dtype=np.uint8).reshape(3, 1024, 1024)

        # Apply delta image.
        img = img + delta

        # Verify MD5.
        md5 = hashlib.md5()
        md5.update(img.tobytes())
        assert md5.hexdigest() == fields['final_md5'][idx]
        return img

    with TFRecordExporter(tfrecord_dir, indices.size) as tfr:
        order = tfr.choose_shuffled_order()
        with ThreadPool(num_threads) as pool:
            for img in pool.process_items_concurrently(indices[order].tolist(), process_func=process_func, max_items_in_flight=num_tasks):
                tfr.add_image(img)

#----------------------------------------------------------------------------

def create_from_images(tfrecord_dir, image_dir, shuffle):
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if len(image_filenames) == 0:
        error('No input images found')

    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1
    if img.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')

    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            if channels == 1:
                img = img[np.newaxis, :, :] # HW => CHW
            else:
                img = img.transpose(2, 0, 1) # HWC => CHW
            tfr.add_image(img)

#----------------------------------------------------------------------------
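
# create_from_hdf5() reads only the highest-resolution 'data*' dataset from
# the legacy HDF5 archive (the one with the largest last dimension); the
# lower resolutions are regenerated by TFRecordExporter.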
def create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle):
    print('Loading HDF5 archive from "%s"' % hdf5_filename)
    import h5py # conda install h5py
    with h5py.File(hdf5_filename, 'r') as hdf5_file:
        hdf5_data = max([value for key, value in hdf5_file.items() if key.startswith('data')], key=lambda lod: lod.shape[3])
        with TFRecordExporter(tfrecord_dir, hdf5_data.shape[0]) as tfr:
            order = tfr.choose_shuffled_order() if shuffle else np.arange(hdf5_data.shape[0])
            for idx in range(order.size):
                tfr.add_image(hdf5_data[order[idx]])
            npy_filename = os.path.splitext(hdf5_filename)[0] + '-labels.npy'
            if os.path.isfile(npy_filename):
                tfr.add_labels(np.load(npy_filename)[order])

#----------------------------------------------------------------------------

def execute_cmdline(argv):
    prog = argv[0]
    parser = argparse.ArgumentParser(
        prog=prog,
        description='Tool for creating, extracting, and visualizing Progressive GAN datasets.',
        epilog='Type "%s <command> -h" for more information.' % prog)
    subparsers = parser.add_subparsers(dest='command')
    subparsers.required = True

    def add_command(cmd, desc, example=None):
        epilog = 'Example: %s %s' % (prog, example) if example is not None else None
        return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog)

    p = add_command('display', 'Display images in dataset.',
                    'display datasets/mnist')
    p.add_argument('tfrecord_dir', help='Directory containing dataset')

    p = add_command('extract', 'Extract images from dataset.',
                    'extract datasets/mnist mnist-images')
    p.add_argument('tfrecord_dir', help='Directory containing dataset')
    p.add_argument('output_dir', help='Directory to extract the images into')

    p = add_command('compare', 'Compare two datasets.',
                    'compare datasets/mydataset datasets/mnist')
    p.add_argument('tfrecord_dir_a', help='Directory containing first dataset')
    p.add_argument('tfrecord_dir_b', help='Directory containing second dataset')
    p.add_argument('--ignore_labels', help='Ignore labels (default: 0)', type=int, default=0)

    p = add_command('create_mnist', 'Create dataset for MNIST.',
                    'create_mnist datasets/mnist ~/downloads/mnist')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('mnist_dir', help='Directory containing MNIST')

    p = add_command('create_mnistrgb', 'Create dataset for MNIST-RGB.',
                    'create_mnistrgb datasets/mnistrgb ~/downloads/mnist')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('mnist_dir', help='Directory containing MNIST')
    p.add_argument('--num_images', help='Number of composite images to create (default: 1000000)', type=int, default=1000000)
    p.add_argument('--random_seed', help='Random seed (default: 123)', type=int, default=123)

    p = add_command('create_cifar10', 'Create dataset for CIFAR-10.',
                    'create_cifar10 datasets/cifar10 ~/downloads/cifar10')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('cifar10_dir', help='Directory containing CIFAR-10')

    p = add_command('create_cifar100', 'Create dataset for CIFAR-100.',
                    'create_cifar100 datasets/cifar100 ~/downloads/cifar100')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('cifar100_dir', help='Directory containing CIFAR-100')

    p = add_command('create_svhn', 'Create dataset for SVHN.',
                    'create_svhn datasets/svhn ~/downloads/svhn')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('svhn_dir', help='Directory containing SVHN')

    p = add_command('create_lsun', 'Create dataset for single LSUN category.',
                    'create_lsun datasets/lsun-car-100k ~/downloads/lsun/car_lmdb --resolution 256 --max_images 100000')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('lmdb_dir', help='Directory containing LMDB database')
    p.add_argument('--resolution', help='Output resolution (default: 256)', type=int, default=256)
    p.add_argument('--max_images', help='Maximum number of images (default: none)', type=int, default=None)

    p = add_command('create_celeba', 'Create dataset for CelebA.',
                    'create_celeba datasets/celeba ~/downloads/celeba')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('celeba_dir', help='Directory containing CelebA')
    p.add_argument('--cx', help='Center X coordinate (default: 89)', type=int, default=89)
    p.add_argument('--cy', help='Center Y coordinate (default: 121)', type=int, default=121)

    p = add_command('create_celebahq', 'Create dataset for CelebA-HQ.',
                    'create_celebahq datasets/celebahq ~/downloads/celeba ~/downloads/celeba-hq-deltas')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('celeba_dir', help='Directory containing CelebA')
    p.add_argument('delta_dir', help='Directory containing CelebA-HQ deltas')
    p.add_argument('--num_threads', help='Number of concurrent threads (default: 4)', type=int, default=4)
    p.add_argument('--num_tasks', help='Number of concurrent processing tasks (default: 100)', type=int, default=100)

    p = add_command('create_from_images', 'Create dataset from a directory full of images.',
                    'create_from_images datasets/mydataset myimagedir')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('image_dir', help='Directory containing the images')
    p.add_argument('--shuffle', help='Randomize image order (default: 1)', type=int, default=1)

    p = add_command('create_from_hdf5', 'Create dataset from legacy HDF5 archive.',
                    'create_from_hdf5 datasets/celebahq ~/downloads/celeba-hq-1024x1024.h5')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('hdf5_filename', help='HDF5 archive containing the images')
    p.add_argument('--shuffle', help='Randomize image order (default: 1)', type=int, default=1)

    args = parser.parse_args(argv[1:] if len(argv) > 1 else ['-h'])
    func = globals()[args.command]
    del args.command
    func(**vars(args))

#----------------------------------------------------------------------------

if __name__ == "__main__":
    execute_cmdline(sys.argv)

#----------------------------------------------------------------------------