diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae36f63d8859ec0c60dcbfe67c4ac324e751ddf7 --- /dev/null +++ b/util/__init__.py @@ -0,0 +1 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" diff --git a/util/__init__.pyc b/util/__init__.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c392f407f3f9e827bcc55415a3e1f9b81c90184 Binary files /dev/null and b/util/__init__.pyc differ diff --git a/util/__pycache__/__init__.cpython-36.pyc b/util/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79f57d780fcd8d6435e5814f221dcfa6cd3ed170 Binary files /dev/null and b/util/__pycache__/__init__.cpython-36.pyc differ diff --git a/util/__pycache__/__init__.cpython-37.pyc b/util/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14fdfd971d5ee9a09eac0859907d78e9dfa2e221 Binary files /dev/null and b/util/__pycache__/__init__.cpython-37.pyc differ diff --git a/util/__pycache__/__init__.cpython-38.pyc b/util/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bb33973b046dec62e812676ac03071b023ab9e7 Binary files /dev/null and b/util/__pycache__/__init__.cpython-38.pyc differ diff --git a/util/__pycache__/__init__.cpython-39.pyc b/util/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..588db021bf9941c832b3c68a8cd3a103f5f4304c Binary files /dev/null and b/util/__pycache__/__init__.cpython-39.pyc differ diff --git a/util/__pycache__/html.cpython-36.pyc b/util/__pycache__/html.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e267b2a2f9424019de2cb3017503a7a85ed5cc8 Binary files /dev/null and b/util/__pycache__/html.cpython-36.pyc differ diff --git a/util/__pycache__/html.cpython-37.pyc b/util/__pycache__/html.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a6fc285a0977fc8b6ac91828be7c4a841fb3a8c Binary files /dev/null and b/util/__pycache__/html.cpython-37.pyc differ diff --git a/util/__pycache__/misc.cpython-36.pyc b/util/__pycache__/misc.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e817e73ed71735cc03ee39bdd6821e935497cbe Binary files /dev/null and b/util/__pycache__/misc.cpython-36.pyc differ diff --git a/util/__pycache__/misc.cpython-37.pyc b/util/__pycache__/misc.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5a46050ef9e421ccd6fa9c573b5d168fcfd0480 Binary files /dev/null and b/util/__pycache__/misc.cpython-37.pyc differ diff --git a/util/__pycache__/params.cpython-37.pyc b/util/__pycache__/params.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebbe137af547756022137886f846e2c47055144f Binary files /dev/null and b/util/__pycache__/params.cpython-37.pyc differ diff --git a/util/__pycache__/util.cpython-36.pyc b/util/__pycache__/util.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b16edd442fc5fa6ae3ac105f413ae7390e369fe Binary files /dev/null and b/util/__pycache__/util.cpython-36.pyc differ diff --git a/util/__pycache__/util.cpython-37.pyc b/util/__pycache__/util.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d125f3dabee27823ba330cc59a710fc26d30e23f Binary files /dev/null and b/util/__pycache__/util.cpython-37.pyc differ 
diff --git a/util/__pycache__/util.cpython-38.pyc b/util/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17fdc53c608e350f10e6cdffb0e8fa81636150ad Binary files /dev/null and b/util/__pycache__/util.cpython-38.pyc differ diff --git a/util/__pycache__/util.cpython-39.pyc b/util/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4137011c41cbf8d240c5e1032108f951609ee3b Binary files /dev/null and b/util/__pycache__/util.cpython-39.pyc differ diff --git a/util/__pycache__/visualizer.cpython-36.pyc b/util/__pycache__/visualizer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..647da318a5b08df7da5e879f42a5989bde2b18ef Binary files /dev/null and b/util/__pycache__/visualizer.cpython-36.pyc differ diff --git a/util/__pycache__/visualizer.cpython-37.pyc b/util/__pycache__/visualizer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6334979dd8581ef65f4d7963d7494341ec2392a3 Binary files /dev/null and b/util/__pycache__/visualizer.cpython-37.pyc differ diff --git a/util/html.py b/util/html.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3262a1eafda34842e4dbad47bb6ba72f0c5a68 --- /dev/null +++ b/util/html.py @@ -0,0 +1,86 @@ +import dominate +from dominate.tags import meta, h3, table, tr, td, p, a, img, br +import os + + +class HTML: + """This HTML class allows us to save images and write texts into a single HTML file. + + It consists of functions such as <add_header> (add a text header to the HTML file), + <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk). + It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. + """ + + def __init__(self, web_dir, title, refresh=0): + """Initialize the HTML class + + Parameters: + web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/ + title (str) -- the webpage name + refresh (int) -- how often the website refreshes itself; if 0, no refreshing + """ + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + + self.doc = dominate.document(title=title) + if refresh > 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + def get_image_dir(self): + """Return the directory that stores images""" + return self.img_dir + + def add_header(self, text): + """Insert a header to the HTML file + + Parameters: + text (str) -- the header text + """ + with self.doc: + h3(text) + + def add_images(self, ims, txts, links, width=400): + """Add images to the HTML file + + Parameters: + ims (str list) -- a list of image paths + txts (str list) -- a list of image names shown on the website + links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page + """ + self.t = table(border=1, style="table-layout: fixed;") # Insert a table + self.doc.add(self.t) + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + """Save the current content to the HTML file""" + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': # we show an example usage here.
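+    # Running this block builds a small demo page and writes it to web/index.html; the four PNG paths below are placeholders and the files do not need to exist for the page to be generated.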
+ html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims, txts, links = [], [], [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/util/misc.py b/util/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..3124e56b0d26aa865754d23bfc7f9456b14c9f31 --- /dev/null +++ b/util/misc.py @@ -0,0 +1,465 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision + + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, 
tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, 
cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
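+# Both paths build the same NestedTensor: for example, inputs of shape (3, 480, 600) and (3, 512, 576) become a zero-padded (2, 3, 512, 600) batch plus a (2, 512, 600) boolean mask in which True marks padded positions (the shapes here are purely illustrative).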
+@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty 
batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if float(torchvision.__version__[:3]) < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/util/models/BigGAN_layers.py b/util/models/BigGAN_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..af85720a6b39103fd418dbde789e397fff722fc0 --- /dev/null +++ b/util/models/BigGAN_layers.py @@ -0,0 +1,469 @@ +''' Layers + This file contains various layers for the BigGAN models. +''' +import numpy as np +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P + +from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d + +# Projection of x onto y +def proj(x, y): + return torch.mm(y, x.t()) * y / torch.mm(y, y.t()) + + +# Orthogonalize x wrt list of vectors ys +def gram_schmidt(x, ys): + for y in ys: + x = x - proj(x, y) + return x + + +# Apply num_itrs steps of the power method to estimate top N singular values. +def power_iteration(W, u_, update=True, eps=1e-12): + # Lists holding singular vectors and values + us, vs, svs = [], [], [] + for i, u in enumerate(u_): + # Run one step of the power iteration + with torch.no_grad(): + v = torch.matmul(u, W) + # Run Gram-Schmidt to subtract components of all other singular vectors + v = F.normalize(gram_schmidt(v, vs), eps=eps) + # Add to the list + vs += [v] + # Update the other singular vector + u = torch.matmul(v, W.t()) + # Run Gram-Schmidt to subtract components of all other singular vectors + u = F.normalize(gram_schmidt(u, us), eps=eps) + # Add to the list + us += [u] + if update: + u_[i][:] = u + # Compute this singular value and add it to the list + svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] + # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] + return svs, us, vs + + +# Convenience passthrough function +class identity(nn.Module): + def forward(self, input): + return input + + +# Spectral normalization base class +class SN(object): + def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): + # Number of power iterations per step + self.num_itrs = num_itrs + # Number of singular values + self.num_svs = num_svs + # Transposed? + self.transpose = transpose + # Epsilon value for avoiding divide-by-0 + self.eps = eps + # Register a singular vector for each sv + for i in range(self.num_svs): + self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) + self.register_buffer('sv%d' % i, torch.ones(1)) + + # Singular vectors (u side) + @property + def u(self): + return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] + + # Singular values; + # note that these buffers are just for logging and are not used in training. 
+ @property + def sv(self): + return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] + + # Compute the spectrally-normalized weight + def W_(self): + W_mat = self.weight.view(self.weight.size(0), -1) + if self.transpose: + W_mat = W_mat.t() + # Apply num_itrs power iterations + for _ in range(self.num_itrs): + svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) + # Update the svs + if self.training: + with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks! + for i, sv in enumerate(svs): + self.sv[i][:] = sv + return self.weight / svs[0] + + +# 2D Conv layer with spectral norm +class SNConv2d(nn.Conv2d, SN): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) + + def forward(self, x): + return F.conv2d(x, self.W_(), self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +# Linear layer with spectral norm +class SNLinear(nn.Linear, SN): + def __init__(self, in_features, out_features, bias=True, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Linear.__init__(self, in_features, out_features, bias) + SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) + + def forward(self, x): + return F.linear(x, self.W_(), self.bias) + + +# Embedding layer with spectral norm +# We use num_embeddings as the dim instead of embedding_dim here +# for convenience sake +class SNEmbedding(nn.Embedding, SN): + def __init__(self, num_embeddings, embedding_dim, padding_idx=None, + max_norm=None, norm_type=2, scale_grad_by_freq=False, + sparse=False, _weight=None, + num_svs=1, num_itrs=1, eps=1e-12): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx, + max_norm, norm_type, scale_grad_by_freq, + sparse, _weight) + SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps) + + def forward(self, x): + return F.embedding(x, self.W_()) + + +# A non-local block as used in SA-GAN +# Note that the implementation as described in the paper is largely incorrect; +# refer to the released code for the actual implementation. 
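+# In short: theta/phi/g are 1x1 convs producing queries, keys and values; phi and g are 2x2 max-pooled, so attention is computed between full-resolution queries and 4x-subsampled keys/values, and the result (scaled by gamma, which starts at 0 so the block is initially an identity) is added back to the input.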
+class Attention(nn.Module): + def __init__(self, ch, which_conv=SNConv2d, name='attention'): + super(Attention, self).__init__() + # Channel multiplier + self.ch = ch + self.which_conv = which_conv + self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) + self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False) + self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False) + # Learnable gain parameter + self.gamma = P(torch.tensor(0.), requires_grad=True) + + def forward(self, x, y=None): + # Apply convs + theta = self.theta(x) + phi = F.max_pool2d(self.phi(x), [2, 2]) + g = F.max_pool2d(self.g(x), [2, 2]) + # Perform reshapes + theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3]) + try: + phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4) + except: + print(phi.shape) + g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4) + # Matmul and softmax to get attention maps + beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1) + # Attention map times g path + o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3])) + return self.gamma * o + x + + +# Fused batchnorm op +def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5): + # Apply scale and shift--if gain and bias are provided, fuse them here + # Prepare scale + scale = torch.rsqrt(var + eps) + # If a gain is provided, use it + if gain is not None: + scale = scale * gain + # Prepare shift + shift = mean * scale + # If bias is provided, use it + if bias is not None: + shift = shift - bias + return x * scale - shift + # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. + + +# Manual BN +# Calculate means and variances using mean-of-squares minus mean-squared +def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5): + # Cast x to float32 if necessary + float_x = x.float() + # Calculate expected value of x (m) and expected value of x**2 (m2) + # Mean of x + m = torch.mean(float_x, [0, 2, 3], keepdim=True) + # Mean of x squared + m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True) + # Calculate variance as mean of squared minus mean squared. 
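+    # i.e. Var[x] = E[x**2] - E[x]**2, taken per channel over the batch and spatial dimensions (cheaper than a two-pass variance, though less numerically stable).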
+ var = (m2 - m ** 2) + # Cast back to float 16 if necessary + var = var.type(x.type()) + m = m.type(x.type()) + # Return mean and variance for updating stored mean/var if requested + if return_mean_var: + return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() + else: + return fused_bn(x, m, var, gain, bias, eps) + + +# My batchnorm, supports standing stats +class myBN(nn.Module): + def __init__(self, num_channels, eps=1e-5, momentum=0.1): + super(myBN, self).__init__() + # momentum for updating running stats + self.momentum = momentum + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Register buffers + self.register_buffer('stored_mean', torch.zeros(num_channels)) + self.register_buffer('stored_var', torch.ones(num_channels)) + self.register_buffer('accumulation_counter', torch.zeros(1)) + # Accumulate running means and vars + self.accumulate_standing = False + + # reset standing stats + def reset_stats(self): + self.stored_mean[:] = 0 + self.stored_var[:] = 0 + self.accumulation_counter[:] = 0 + + def forward(self, x, gain, bias): + if self.training: + out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps) + # If accumulating standing stats, increment them + if self.accumulate_standing: + self.stored_mean[:] = self.stored_mean + mean.data + self.stored_var[:] = self.stored_var + var.data + self.accumulation_counter += 1.0 + # If not accumulating standing stats, take running averages + else: + self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum + self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum + return out + # If not in training mode, use the stored statistics + else: + mean = self.stored_mean.view(1, -1, 1, 1) + var = self.stored_var.view(1, -1, 1, 1) + # If using standing stats, divide them by the accumulation counter + if self.accumulate_standing: + mean = mean / self.accumulation_counter + var = var / self.accumulation_counter + return fused_bn(x, mean, var, gain, bias, self.eps) + + +# Simple function to handle groupnorm norm stylization +def groupnorm(x, norm_style): + # If number of channels specified in norm_style: + if 'ch' in norm_style: + ch = int(norm_style.split('_')[-1]) + groups = max(int(x.shape[1]) // ch, 1) + # If number of groups specified in norm style + elif 'grp' in norm_style: + groups = int(norm_style.split('_')[-1]) + # If neither, default to groups = 16 + else: + groups = 16 + return F.group_norm(x, groups) + + +# Class-conditional bn +# output size is the number of channels, input size is for the linear layers +# Andy's Note: this class feels messy but I'm not really sure how to clean it up +# Suggestions welcome! (By which I mean, refactor this and make a pull request +# if you want to make this more readable/usable). +class ccbn(nn.Module): + def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1, + cross_replica=False, mybn=False, norm_style='bn', ): + super(ccbn, self).__init__() + self.output_size, self.input_size = output_size, input_size + # Prepare gain and bias layers + self.gain = which_linear(input_size, output_size) + self.bias = which_linear(input_size, output_size) + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Use cross-replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + # Norm style? 
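+        # (one of 'bn', 'in', 'gn', or 'nonorm', judging from the forward pass below)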
+ self.norm_style = norm_style + + if self.cross_replica: + self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) + elif self.mybn: + self.bn = myBN(output_size, self.eps, self.momentum) + elif self.norm_style in ['bn', 'in']: + self.register_buffer('stored_mean', torch.zeros(output_size)) + self.register_buffer('stored_var', torch.ones(output_size)) + + def forward(self, x, y): + # Calculate class-conditional gains and biases + gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) + bias = self.bias(y).view(y.size(0), -1, 1, 1) + # If using my batchnorm + if self.mybn or self.cross_replica: + return self.bn(x, gain=gain, bias=bias) + # else: + else: + if self.norm_style == 'bn': + out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None, + self.training, 0.1, self.eps) + elif self.norm_style == 'in': + out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None, + self.training, 0.1, self.eps) + elif self.norm_style == 'gn': + out = groupnorm(x, self.normstyle) + elif self.norm_style == 'nonorm': + out = x + return out * gain + bias + + def extra_repr(self): + s = 'out: {output_size}, in: {input_size},' + s += ' cross_replica={cross_replica}' + return s.format(**self.__dict__) + + +# Normal, non-class-conditional BN +class bn(nn.Module): + def __init__(self, output_size, eps=1e-5, momentum=0.1, + cross_replica=False, mybn=False): + super(bn, self).__init__() + self.output_size = output_size + # Prepare gain and bias layers + self.gain = P(torch.ones(output_size), requires_grad=True) + self.bias = P(torch.zeros(output_size), requires_grad=True) + # epsilon to avoid dividing by 0 + self.eps = eps + # Momentum + self.momentum = momentum + # Use cross-replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + + if self.cross_replica: + self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False) + elif mybn: + self.bn = myBN(output_size, self.eps, self.momentum) + # Register buffers if neither of the above + else: + self.register_buffer('stored_mean', torch.zeros(output_size)) + self.register_buffer('stored_var', torch.ones(output_size)) + + def forward(self, x, y=None): + if self.cross_replica or self.mybn: + gain = self.gain.view(1, -1, 1, 1) + bias = self.bias.view(1, -1, 1, 1) + return self.bn(x, gain=gain, bias=bias) + else: + return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain, + self.bias, self.training, self.momentum, self.eps) + + +# Generator blocks +# Note that this class assumes the kernel size and padding (and any other +# settings) have been selected in the main generator module and passed in +# through the which_conv arg. 
Similar rules apply with which_bn (the input +# size [which is actually the number of channels of the conditional info] must +# be preselected) +class GBlock(nn.Module): + def __init__(self, in_channels, out_channels, + which_conv1=nn.Conv2d, which_conv2=nn.Conv2d, which_bn=bn, activation=None, + upsample=None): + super(GBlock, self).__init__() + + self.in_channels, self.out_channels = in_channels, out_channels + self.which_conv1, self.which_conv2, self.which_bn = which_conv1, which_conv2, which_bn + self.activation = activation + self.upsample = upsample + # Conv layers + self.conv1 = self.which_conv1(self.in_channels, self.out_channels) + self.conv2 = self.which_conv2(self.out_channels, self.out_channels) + self.learnable_sc = in_channels != out_channels or upsample + if self.learnable_sc: + self.conv_sc = self.which_conv1(in_channels, out_channels, + kernel_size=1, padding=0) + # Batchnorm layers + self.bn1 = self.which_bn(in_channels) + self.bn2 = self.which_bn(out_channels) + # upsample layers + self.upsample = upsample + + def forward(self, x, y): + h = self.activation(self.bn1(x, y)) + # h = self.activation(x) + # h=x + if self.upsample: + h = self.upsample(h) + x = self.upsample(x) + h = self.conv1(h) + h = self.activation(self.bn2(h, y)) + # h = self.activation(h) + h = self.conv2(h) + if self.learnable_sc: + x = self.conv_sc(x) + return h + x + + +# Residual block for the discriminator +class DBlock(nn.Module): + def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True, + preactivation=False, activation=None, downsample=None, ): + super(DBlock, self).__init__() + self.in_channels, self.out_channels = in_channels, out_channels + # If using wide D (as in SA-GAN and BigGAN), change the channel pattern + self.hidden_channels = self.out_channels if wide else self.in_channels + self.which_conv = which_conv + self.preactivation = preactivation + self.activation = activation + self.downsample = downsample + + # Conv layers + self.conv1 = self.which_conv(self.in_channels, self.hidden_channels) + self.conv2 = self.which_conv(self.hidden_channels, self.out_channels) + self.learnable_sc = True if (in_channels != out_channels) or downsample else False + if self.learnable_sc: + self.conv_sc = self.which_conv(in_channels, out_channels, + kernel_size=1, padding=0) + + def shortcut(self, x): + if self.preactivation: + if self.learnable_sc: + x = self.conv_sc(x) + if self.downsample: + x = self.downsample(x) + else: + if self.downsample: + x = self.downsample(x) + if self.learnable_sc: + x = self.conv_sc(x) + return x + + def forward(self, x): + if self.preactivation: + # h = self.activation(x) # NOT TODAY SATAN + # Andy's note: This line *must* be an out-of-place ReLU or it + # will negatively affect the shortcut connection. + h = F.relu(x) + else: + h = x + h = self.conv1(h) + h = self.conv2(self.activation(h)) + if self.downsample: + h = self.downsample(h) + + return h + self.shortcut(x) + +# dogball \ No newline at end of file diff --git a/util/models/BigGAN_networks.py b/util/models/BigGAN_networks.py new file mode 100644 index 0000000000000000000000000000000000000000..385b73b5cfd5dbd9d64c9ae10defce995ca09d81 --- /dev/null +++ b/util/models/BigGAN_networks.py @@ -0,0 +1,841 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: MIT + +import numpy as np +import math +import functools + +import torch +import torch.nn as nn +from torch.nn import init +import torch.optim as optim +import torch.nn.functional as F +from torch.nn import Parameter as P +from .transformer import Transformer +from . import BigGAN_layers as layers +from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d +from util.util import to_device, load_network +from .networks import init_weights +from params import * +# Attention is passed in in the format '32_64' to mean applying an attention +# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64. + +from models.blocks import LinearBlock, Conv2dBlock, ResBlocks, ActFirstResBlock + +class Decoder(nn.Module): + def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'): + super(Decoder, self).__init__() + + self.model = [] + self.model += [ResBlocks(n_res, dim, res_norm, + activ, pad_type=pad_type)] + for i in range(ups): + self.model += [nn.Upsample(scale_factor=2), + Conv2dBlock(dim, dim // 2, 5, 1, 2, + norm='in', + activation=activ, + pad_type=pad_type)] + dim //= 2 + self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3, + norm='none', + activation='tanh', + pad_type=pad_type)] + self.model = nn.Sequential(*self.model) + + def forward(self, x): + y = self.model(x) + + return y + + + +def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'): + arch = {} + arch[512] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]], + 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]], + 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)], + 'resolution': [8, 16, 32, 64, 128, 256, 512], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 10)}} + arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]], + 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]], + 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)], + 'resolution': [8, 16, 32, 64, 128, 256], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 9)}} + arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]], + 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]], + 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2)], + 'resolution': [8, 16, 32, 64, 128], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 8)}} + arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]], + 'out_channels': [ch * item for item in [16, 8, 4, 2]], + 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2)], + 'resolution': [8, 16, 32, 64], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 7)}} + + arch[63] = {'in_channels': [ch * item for item in [16, 16, 8, 4]], + 'out_channels': [ch * item for item in [16, 8, 4, 2]], + 'upsample': [(2, 2), (2, 2), (2, 2), (2,1)], + 'resolution': [8, 16, 32, 64], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 7)}, + 'kernel1': [3, 3, 3, 3], + 'kernel2': [3, 3, 1, 1] + } + + arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]], + 'out_channels': [ch * item for item in [4, 4, 4]], + 'upsample': [(2, 2), (2, 2), (2, 2)], + 'resolution': [8, 16, 32], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 6)}} + + arch[32] = {'in_channels': 
[ch * item for item in [4, 4, 4]], + 'out_channels': [ch * item for item in [4, 4, 4]], + 'upsample': [(2, 2), (2, 2), (2, 2)], + 'resolution': [8, 16, 32], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 6)}, + 'kernel1': [3, 3, 3], + 'kernel2': [3, 3, 1] + } + + arch[129] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]], + 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]], + 'upsample': [(2,2), (2,2), (2,2), (2,2), (2,2), (1,2), (1,2)], + 'resolution': [8, 16, 32, 64, 128, 256, 512], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 10)}} + + arch[33] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]], + 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]], + 'upsample': [(2,2), (2,2), (2,2), (1,2), (1,2)], + 'resolution': [8, 16, 32, 64, 128], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 8)}} + + arch[31] = {'in_channels': [ch * item for item in [16, 16, 4, 2]], + 'out_channels': [ch * item for item in [16, 4, 2, 1]], + 'upsample': [(2,2), (2,2), (2,2), (1,2)], + 'resolution': [8, 16, 32, 64], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 7)}, + 'kernel1':[3, 3, 3, 3], + 'kernel2': [3, 1, 1, 1]} + + arch[16] = {'in_channels': [ch * item for item in [8, 4, 2]], + 'out_channels': [ch * item for item in [4, 2, 1]], + 'upsample': [(2,2), (2,2), (2,1)], + 'resolution': [8, 16, 16], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 6)}, + 'kernel1':[3, 3, 3], + 'kernel2': [3, 3, 1]} + + arch[17] = {'in_channels': [ch * item for item in [8, 4, 2]], + 'out_channels': [ch * item for item in [4, 2, 1]], + 'upsample': [(2,2), (2,2), (2,1)], + 'resolution': [8, 16, 16], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 6)}, + 'kernel1':[3, 3, 3], + 'kernel2': [3, 3, 1]} + + arch[20] = {'in_channels': [ch * item for item in [8, 4, 2]], + 'out_channels': [ch * item for item in [4, 2, 1]], + 'upsample': [(2,2), (2,2), (2,1)], + 'resolution': [8, 16, 16], + 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')]) + for i in range(3, 6)}, + 'kernel1':[3, 3, 3], + 'kernel2': [3, 1, 1]} + + return arch + + +class Generator(nn.Module): + def __init__(self, G_ch=64, dim_z=128, bottom_width=4, bottom_height=4,resolution=128, + G_kernel_size=3, G_attn='64', n_classes=1000, + num_G_SVs=1, num_G_SV_itrs=1, + G_shared=True, shared_dim=0, no_hier=False, + cross_replica=False, mybn=False, + G_activation=nn.ReLU(inplace=False), + BN_eps=1e-5, SN_eps=1e-12, G_fp16=False, + G_init='ortho', skip_init=False, + G_param='SN', norm_style='bn',gpu_ids=[], bn_linear='embed', input_nc=3, + one_hot=False, first_layer=False, one_hot_k=1, + **kwargs): + super(Generator, self).__init__() + self.name = 'G' + # Use class only in first layer + self.first_layer = first_layer + # gpu-ids + self.gpu_ids = gpu_ids + # Use one hot vector representation for input class + self.one_hot = one_hot + # Use one hot k vector representation for input class if k is larger than 0. If it's 0, simly use the class number and not a k-hot encoding. 
+ self.one_hot_k = one_hot_k + # Channel width mulitplier + self.ch = G_ch + # Dimensionality of the latent space + self.dim_z = dim_z + # The initial width dimensions + self.bottom_width = bottom_width + # The initial height dimension + self.bottom_height = bottom_height + # Resolution of the output + self.resolution = resolution + # Kernel size? + self.kernel_size = G_kernel_size + # Attention? + self.attention = G_attn + # number of classes, for use in categorical conditional generation + self.n_classes = n_classes + # Use shared embeddings? + self.G_shared = G_shared + # Dimensionality of the shared embedding? Unused if not using G_shared + self.shared_dim = shared_dim if shared_dim > 0 else dim_z + # Hierarchical latent space? + self.hier = not no_hier + # Cross replica batchnorm? + self.cross_replica = cross_replica + # Use my batchnorm? + self.mybn = mybn + # nonlinearity for residual blocks + self.activation = G_activation + # Initialization style + self.init = G_init + # Parameterization style + self.G_param = G_param + # Normalization style + self.norm_style = norm_style + # Epsilon for BatchNorm? + self.BN_eps = BN_eps + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # fp16? + self.fp16 = G_fp16 + # Architecture dict + self.arch = G_arch(self.ch, self.attention)[resolution] + self.bn_linear = bn_linear + + #self.transformer = Transformer(d_model = 2560) + #self.input_proj = nn.Conv2d(512, 2560, kernel_size=1) + self.linear_q = nn.Linear(512,2048*2) + + self.DETR = build() + self.DEC = Decoder(res_norm = 'in') + # If using hierarchical latents, adjust z + if self.hier: + # Number of places z slots into + self.num_slots = len(self.arch['in_channels']) + 1 + self.z_chunk_size = (self.dim_z // self.num_slots) + # Recalculate latent dimensionality for even splitting into chunks + self.dim_z = self.z_chunk_size * self.num_slots + else: + self.num_slots = 1 + self.z_chunk_size = 0 + + # Which convs, batchnorms, and linear layers to use + if self.G_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + else: + self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) + self.which_linear = nn.Linear + + # We use a non-spectral-normed embedding here regardless; + # For some reason applying SN to G's embedding seems to randomly cripple G + if one_hot: + self.which_embedding = functools.partial(layers.SNLinear, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + else: + self.which_embedding = nn.Embedding + + bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared + else self.which_embedding) + if self.bn_linear=='SN': + bn_linear = functools.partial(self.which_linear, bias=False) + if self.G_shared: + input_size = self.shared_dim + self.z_chunk_size + elif self.hier: + if self.first_layer: + input_size = self.z_chunk_size + else: + input_size = self.n_classes + self.z_chunk_size + self.which_bn = functools.partial(layers.ccbn, + which_linear=bn_linear, + cross_replica=self.cross_replica, + mybn=self.mybn, + input_size=input_size, + norm_style=self.norm_style, + eps=self.BN_eps) + else: + input_size = self.n_classes + self.which_bn = functools.partial(layers.bn, + cross_replica=self.cross_replica, + mybn=self.mybn, + eps=self.BN_eps) + + + + + # Prepare model + # If not using shared embeddings, self.shared 
is just a passthrough + self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared + else layers.identity()) + # First linear layer + # The parameters for the first linear layer depend on the different input variations. + if self.first_layer: + if self.one_hot: + self.linear = self.which_linear(self.dim_z // self.num_slots + self.n_classes, + self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height)) + else: + self.linear = self.which_linear(self.dim_z // self.num_slots + 1, + self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height)) + if self.one_hot_k==1: + self.linear = self.which_linear((self.dim_z // self.num_slots) * self.n_classes, + self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height)) + if self.one_hot_k>1: + self.linear = self.which_linear(self.dim_z // self.num_slots + self.n_classes*self.one_hot_k, + self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height)) + + + else: + self.linear = self.which_linear(self.dim_z // self.num_slots, + self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height)) + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + # while the inner loop is over a given block + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + if 'kernel1' in self.arch.keys(): + padd1 = 1 if self.arch['kernel1'][index]>1 else 0 + padd2 = 1 if self.arch['kernel2'][index]>1 else 0 + conv1 = functools.partial(layers.SNConv2d, + kernel_size=self.arch['kernel1'][index], padding=padd1, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + conv2 = functools.partial(layers.SNConv2d, + kernel_size=self.arch['kernel2'][index], padding=padd2, + num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, + eps=self.SN_eps) + self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv1=conv1, + which_conv2=conv2, + which_bn=self.which_bn, + activation=self.activation, + upsample=(functools.partial(F.interpolate, + scale_factor=self.arch['upsample'][index]) + if index < len(self.arch['upsample']) else None))]] + else: + self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv1=self.which_conv, + which_conv2=self.which_conv, + which_bn=self.which_bn, + activation=self.activation, + upsample=(functools.partial(F.interpolate, scale_factor=self.arch['upsample'][index]) + if index < len(self.arch['upsample']) else None))]] + + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)] + + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + + # output layer: batchnorm-relu-conv. + # Consider using a non-spectral conv here + self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1], + cross_replica=self.cross_replica, + mybn=self.mybn), + self.activation, + self.which_conv(self.arch['out_channels'][-1], input_nc)) + + # Initialize weights. Optionally skip init for testing. 
+ if not skip_init: + self = init_weights(self, G_init) + + # Note on this forward function: we pass in a y vector which has + # already been passed through G.shared to enable easy class-wise + # interpolation later. If we passed in the one-hot and then ran it through + # G.shared in this forward function, it would be harder to handle. + def forward(self, x, y_ind, y): + # If hierarchical, concatenate zs and ys + + + h_all = self.DETR(x, y_ind) + #h_all = torch.stack([h_all, h_all, h_all]) + + #h_all_bs = torch.unbind(h_all[-1], 0) + #y_bs = torch.unbind(y_ind, 0) + + #h = torch.stack([h_i[y_j] for h_i,y_j in zip(h_all_bs, y_bs)], 0) + + + + + h = self.linear_q(h_all) + + + h = h.contiguous() + # Reshape - when y is not a single class value but rather an array of classes, the reshape is needed to create + # a separate vertical patch for each input. + if self.first_layer: + # correct reshape + h = h.view(h.size(0), h.shape[1]*2, 4, -1) + h = h.permute(0, 3, 2, 1) + + else: + h = h.view(h.size(0), -1, self.bottom_width, self.bottom_height) + + + #for index, blocklist in enumerate(self.blocks): + # Second inner loop in case block has multiple layers + # for block in blocklist: + # h = block(h, ys[index]) + + #Apply batchnorm-relu-conv-tanh at output + # h = torch.tanh(self.output_layer(h)) + + h = self.DEC(h) + return h + + + + + + + +# Discriminator architecture, same paradigm as G's above +def D_arch(ch=64, attention='64', input_nc=3, ksize='333333', dilation='111111'): + arch = {} + arch[256] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]], + 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]], + 'downsample': [True] * 6 + [False], + 'resolution': [128, 64, 32, 16, 8, 4, 4], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 8)}} + arch[128] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]], + 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]], + 'downsample': [True] * 5 + [False], + 'resolution': [64, 32, 16, 8, 4, 4], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 8)}} + arch[64] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]], + 'out_channels': [item * ch for item in [1, 2, 4, 8, 16]], + 'downsample': [True] * 4 + [False], + 'resolution': [32, 16, 8, 4, 4], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 7)}} + arch[63] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]], + 'out_channels': [item * ch for item in [1, 2, 4, 8, 16]], + 'downsample': [True] * 4 + [False], + 'resolution': [32, 16, 8, 4, 4], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 7)}} + arch[32] = {'in_channels': [input_nc] + [item * ch for item in [4, 4, 4]], + 'out_channels': [item * ch for item in [4, 4, 4, 4]], + 'downsample': [True, True, False, False], + 'resolution': [16, 16, 16, 16], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 6)}} + arch[129] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]], + 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]], + 'downsample': [True] * 6 + [False], + 'resolution': [128, 64, 32, 16, 8, 4, 4], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 8)}} + arch[33] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]], + 
'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]], + 'downsample': [True] * 5 + [False], + 'resolution': [64, 32, 16, 8, 4, 4], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 10)}} + arch[31] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]], + 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]], + 'downsample': [True] * 5 + [False], + 'resolution': [64, 32, 16, 8, 4, 4], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 10)}} + arch[16] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]], + 'out_channels': [item * ch for item in [1, 8, 16, 16]], + 'downsample': [True] * 3 + [False], + 'resolution': [16, 8, 4, 4], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 5)}} + + arch[17] = {'in_channels': [input_nc] + [ch * item for item in [1, 4]], + 'out_channels': [item * ch for item in [1, 4, 8]], + 'downsample': [True] * 3, + 'resolution': [16, 8, 4], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 5)}} + + + arch[20] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]], + 'out_channels': [item * ch for item in [1, 8, 16, 16]], + 'downsample': [True] * 3 + [False], + 'resolution': [16, 8, 4, 4], + 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')] + for i in range(2, 5)}} + return arch + + +class Discriminator(nn.Module): + + def __init__(self, D_ch=64, D_wide=True, resolution=resolution, + D_kernel_size=3, D_attn='64', n_classes=VOCAB_SIZE, + num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False), + SN_eps=1e-8, output_dim=1, D_mixed_precision=False, D_fp16=False, + D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0],bn_linear='SN', input_nc=1, one_hot=False, **kwargs): + + super(Discriminator, self).__init__() + self.name = 'D' + # gpu_ids + self.gpu_ids = gpu_ids + # one_hot representation + self.one_hot = one_hot + # Width multiplier + self.ch = D_ch + # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? + self.D_wide = D_wide + # Resolution + self.resolution = resolution + # Kernel size + self.kernel_size = D_kernel_size + # Attention? + self.attention = D_attn + # Number of classes + self.n_classes = n_classes + # Activation + self.activation = D_activation + # Initialization style + self.init = D_init + # Parameterization style + self.D_param = D_param + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # Fp16? 
+ self.fp16 = D_fp16 + # Architecture + self.arch = D_arch(self.ch, self.attention, input_nc)[resolution] + + # Which convs, batchnorms, and linear layers to use + # No option to turn off SN in D right now + if self.D_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_embedding = functools.partial(layers.SNEmbedding, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + if bn_linear=='SN': + self.which_embedding = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + else: + self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) + self.which_linear = nn.Linear + # We use a non-spectral-normed embedding here regardless; + # For some reason applying SN to G's embedding seems to randomly cripple G + self.which_embedding = nn.Embedding + if one_hot: + self.which_embedding = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + # Prepare model + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + wide=self.D_wide, + activation=self.activation, + preactivation=(index > 0), + downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]] + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], + self.which_conv)] + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + # Linear output layer. The output dimension is typically 1, but may be + # larger if we're e.g. 
turning this into a VAE with an inference output + self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) + # Embedding for projection discrimination + self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]) + + # Initialize weights + if not skip_init: + self = init_weights(self, D_init) + + def forward(self, x, y=None, **kwargs): + # Stick x into h for cleaner for loops without flow control + h = x + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + # Apply global sum pooling as in SN-GAN + h = torch.sum(self.activation(h), [2, 3]) + # Get initial class-unconditional output + out = self.linear(h) + # Get projection of final featureset onto class vectors and add to evidence + if y is not None: + out = out + torch.sum(self.embed(y) * h, 1, keepdim=True) + return out + + def return_features(self, x, y=None): + # Stick x into h for cleaner for loops without flow control + h = x + block_output = [] + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + block_output.append(h) + # Apply global sum pooling as in SN-GAN + # h = torch.sum(self.activation(h), [2, 3]) + return block_output + + + + +class WDiscriminator(nn.Module): + + def __init__(self, D_ch=64, D_wide=True, resolution=resolution, + D_kernel_size=3, D_attn='64', n_classes=VOCAB_SIZE, + num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False), + SN_eps=1e-8, output_dim=NUM_WRITERS, D_mixed_precision=False, D_fp16=False, + D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0],bn_linear='SN', input_nc=1, one_hot=False, **kwargs): + super(WDiscriminator, self).__init__() + self.name = 'D' + # gpu_ids + self.gpu_ids = gpu_ids + # one_hot representation + self.one_hot = one_hot + # Width multiplier + self.ch = D_ch + # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? + self.D_wide = D_wide + # Resolution + self.resolution = resolution + # Kernel size + self.kernel_size = D_kernel_size + # Attention? + self.attention = D_attn + # Number of classes + self.n_classes = n_classes + # Activation + self.activation = D_activation + # Initialization style + self.init = D_init + # Parameterization style + self.D_param = D_param + # Epsilon for Spectral Norm? + self.SN_eps = SN_eps + # Fp16? 
+ self.fp16 = D_fp16 + # Architecture + self.arch = D_arch(self.ch, self.attention, input_nc)[resolution] + + # Which convs, batchnorms, and linear layers to use + # No option to turn off SN in D right now + if self.D_param == 'SN': + self.which_conv = functools.partial(layers.SNConv2d, + kernel_size=3, padding=1, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_linear = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + self.which_embedding = functools.partial(layers.SNEmbedding, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + if bn_linear=='SN': + self.which_embedding = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + else: + self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) + self.which_linear = nn.Linear + # We use a non-spectral-normed embedding here regardless; + # For some reason applying SN to G's embedding seems to randomly cripple G + self.which_embedding = nn.Embedding + if one_hot: + self.which_embedding = functools.partial(layers.SNLinear, + num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, + eps=self.SN_eps) + # Prepare model + # self.blocks is a doubly-nested list of modules, the outer loop intended + # to be over blocks at a given resolution (resblocks and/or self-attention) + self.blocks = [] + for index in range(len(self.arch['out_channels'])): + self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index], + out_channels=self.arch['out_channels'][index], + which_conv=self.which_conv, + wide=self.D_wide, + activation=self.activation, + preactivation=(index > 0), + downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]] + # If attention on this block, attach it to the end + if self.arch['attention'][self.arch['resolution'][index]]: + print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) + self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], + self.which_conv)] + # Turn self.blocks into a ModuleList so that it's all properly registered. + self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) + # Linear output layer. The output dimension is typically 1, but may be + # larger if we're e.g. 
turning this into a VAE with an inference output + self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) + # Embedding for projection discrimination + self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]) + self.cross_entropy = nn.CrossEntropyLoss() + # Initialize weights + if not skip_init: + self = init_weights(self, D_init) + + def forward(self, x, y=None, **kwargs): + # Stick x into h for cleaner for loops without flow control + h = x + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + # Apply global sum pooling as in SN-GAN + h = torch.sum(self.activation(h), [2, 3]) + # Get initial class-unconditional output + out = self.linear(h) + # Get projection of final featureset onto class vectors and add to evidence + #if y is not None: + #out = out + torch.sum(self.embed(y) * h, 1, keepdim=True) + + loss = self.cross_entropy(out, y.long()) + + return loss + + def return_features(self, x, y=None): + # Stick x into h for cleaner for loops without flow control + h = x + block_output = [] + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + block_output.append(h) + # Apply global sum pooling as in SN-GAN + # h = torch.sum(self.activation(h), [2, 3]) + return block_output + +class Encoder(Discriminator): + def __init__(self, opt, output_dim, **kwargs): + super(Encoder, self).__init__(**vars(opt)) + self.output_layer = nn.Sequential(self.activation, + nn.Conv2d(self.arch['out_channels'][-1], output_dim, kernel_size=(4,2), padding=0, stride=2)) + + def forward(self, x): + # Stick x into h for cleaner for loops without flow control + h = x + # Loop over blocks + for index, blocklist in enumerate(self.blocks): + for block in blocklist: + h = block(h) + out = self.output_layer(h) + return out + +class BiDiscriminator(nn.Module): + def __init__(self, opt): + super(BiDiscriminator, self).__init__() + self.infer_img = Encoder(opt, output_dim=opt.nimg_features) + # self.infer_z = nn.Sequential( + # nn.Conv2d(opt.dim_z, 512, 1, stride=1, bias=False), + # nn.LeakyReLU(inplace=True), + # nn.Dropout2d(p=self.dropout), + # nn.Conv2d(512, opt.nz_features, 1, stride=1, bias=False), + # nn.LeakyReLU(inplace=True), + # nn.Dropout2d(p=self.dropout) + # ) + self.infer_joint = nn.Sequential( + nn.Conv2d(opt.dim_z+opt.nimg_features, 1024, 1, stride=1, bias=True), + nn.ReLU(inplace=True), + + nn.Conv2d(1024, 1024, 1, stride=1, bias=True), + nn.ReLU(inplace=True) + ) + self.final = nn.Conv2d(1024, 1, 1, stride=1, bias=True) + + def forward(self, x, z, **kwargs): + output_x = self.infer_img(x) + # output_z = self.infer_z(z) + if len(z.shape)==2: + z = z.unsqueeze(2).unsqueeze(2).repeat((1,1,1,output_x.shape[3])) + output_features = self.infer_joint(torch.cat([output_x, z], dim=1)) + output = self.final(output_features) + return output + +# Parallelized G_D to minimize cross-gpu communication +# Without this, Generator outputs would get all-gathered and then rebroadcast. 
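+# Usage sketch (illustrative only; G, D and the z/gy/x/dy tensors are assumed to come from the training loop elsewhere in this repo): +#   GD = G_D(G, D) +#   D_fake, D_real = GD(z, gy, x=real_imgs, dy=real_labels, train_G=False) +# With train_G=True the grad tape stays enabled, so generator gradients flow back through D.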
+class G_D(nn.Module): + def __init__(self, G, D): + super(G_D, self).__init__() + self.G = G + self.D = D + + def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False, + split_D=False): + # If training G, enable grad tape + with torch.set_grad_enabled(train_G): + # Get Generator output given noise + G_z = self.G(z, self.G.shared(gy)) + # Cast as necessary + if self.G.fp16 and not self.D.fp16: + G_z = G_z.float() + if self.D.fp16 and not self.G.fp16: + G_z = G_z.half() + # Split_D means to run D once with real data and once with fake, + # rather than concatenating along the batch dimension. + if split_D: + D_fake = self.D(G_z, gy) + if x is not None: + D_real = self.D(x, dy) + return D_fake, D_real + else: + if return_G_z: + return D_fake, G_z + else: + return D_fake + # If real data is provided, concatenate it with the Generator's output + # along the batch dimension for improved efficiency. + else: + D_input = torch.cat([G_z, x], 0) if x is not None else G_z + D_class = torch.cat([gy, dy], 0) if dy is not None else gy + # Get Discriminator output + D_out = self.D(D_input, D_class) + if x is not None: + return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real + else: + if return_G_z: + return D_out, G_z + else: + return D_out + diff --git a/util/models/OCR_network.py b/util/models/OCR_network.py new file mode 100644 index 0000000000000000000000000000000000000000..39b1c2a4c6a247d55055a82ac4867895d9a8bcce --- /dev/null +++ b/util/models/OCR_network.py @@ -0,0 +1,304 @@ +import torch.nn as nn +from util.util import to_device +from torch.nn import init +import os +import torch +from .networks import * +from params import * + +class BidirectionalLSTM(nn.Module): + + def __init__(self, nIn, nHidden, nOut): + super(BidirectionalLSTM, self).__init__() + + self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) + self.embedding = nn.Linear(nHidden * 2, nOut) + + + def forward(self, input): + recurrent, _ = self.rnn(input) + T, b, h = recurrent.size() + t_rec = recurrent.view(T * b, h) + + output = self.embedding(t_rec) # [T * b, nOut] + output = output.view(T, b, -1) + + return output + + +class CRNN(nn.Module): + + def __init__(self, leakyRelu=False): + super(CRNN, self).__init__() + self.name = 'OCR' + #assert opt.imgH % 16 == 0, 'imgH has to be a multiple of 16' + + ks = [3, 3, 3, 3, 3, 3, 2] + ps = [1, 1, 1, 1, 1, 1, 0] + ss = [1, 1, 1, 1, 1, 1, 1] + nm = [64, 128, 256, 256, 512, 512, 512] + + cnn = nn.Sequential() + nh = 256 + dealwith_lossnone=False # whether to replace all nan/inf in gradients to zero + + def convRelu(i, batchNormalization=False): + nIn = 1 if i == 0 else nm[i - 1] + nOut = nm[i] + cnn.add_module('conv{0}'.format(i), + nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])) + if batchNormalization: + cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut)) + if leakyRelu: + cnn.add_module('relu{0}'.format(i), + nn.LeakyReLU(0.2, inplace=True)) + else: + cnn.add_module('relu{0}'.format(i), nn.ReLU(True)) + + convRelu(0) + cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64 + convRelu(1) + cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 + convRelu(2, True) + convRelu(3) + cnn.add_module('pooling{0}'.format(2), + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 + convRelu(4, True) + if resolution==63: + cnn.add_module('pooling{0}'.format(3), + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 + convRelu(5) + cnn.add_module('pooling{0}'.format(4), + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 + convRelu(6, True) # 
512x1x16 + + self.cnn = cnn + self.use_rnn = False + if self.use_rnn: + self.rnn = nn.Sequential( + BidirectionalLSTM(512, nh, nh), + BidirectionalLSTM(nh, nh, VOCAB_SIZE)) # second BLSTM projects to the vocabulary size, mirroring the linear branch below + else: + self.linear = nn.Linear(512, VOCAB_SIZE) + + # replace all nan/inf in gradients to zero + if dealwith_lossnone: + self.register_backward_hook(self.backward_hook) + + self.device = torch.device('cuda:{}'.format(0)) + self.init = 'N02' + # Initialize weights + + self = init_weights(self, self.init) + + def forward(self, input): + # conv features + conv = self.cnn(input) + b, c, h, w = conv.size() + if h != 1: + print('unexpected conv feature height: {}'.format(h)) + assert h == 1, "the height of conv must be 1" + conv = conv.squeeze(2) + conv = conv.permute(2, 0, 1) # [w, b, c] + + if self.use_rnn: + # rnn features + output = self.rnn(conv) + else: + output = self.linear(conv) + return output + + def backward_hook(self, module, grad_input, grad_output): + for g in grad_input: + g[g != g] = 0 # g != g is True only for NaN, so this zeroes NaN entries in the gradients + + +class OCRLabelConverter(object): + """Convert between str and label. + + NOTE: + Insert `blank` to the alphabet for CTC. + + Args: + alphabet (str): set of the possible characters. + ignore_case (bool, default=True): whether or not to ignore all of the case. + """ + + def __init__(self, alphabet, ignore_case=False): + self._ignore_case = ignore_case + if self._ignore_case: + alphabet = alphabet.lower() + self.alphabet = alphabet + '-' # for `-1` index + + self.dict = {} + for i, char in enumerate(alphabet): + # NOTE: 0 is reserved for 'blank' required by wrap_ctc + self.dict[char] = i + 1 + + def encode(self, text): + """Support batch or single str. + + Args: + text (str or list of str): texts to convert. + + Returns: + torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. + torch.IntTensor [n]: length of each text. + """ + ''' + if isinstance(text, str): + text = [ + self.dict[char.lower() if self._ignore_case else char] + for char in text + ] + length = [len(text)] + elif isinstance(text, collections.Iterable): + length = [len(s) for s in text] + text = ''.join(text) + text, _ = self.encode(text) + return (torch.IntTensor(text), torch.IntTensor(length)) + ''' + length = [] + result = [] + for item in text: + item = item.decode('utf-8', 'strict') + length.append(len(item)) + for char in item: + index = self.dict[char] + result.append(index) + + text = result + return (torch.IntTensor(text), torch.IntTensor(length)) + + def decode(self, t, length, raw=False): + """Decode encoded texts back into strs. + + Args: + torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. + torch.IntTensor [n]: length of each text. + + Raises: + AssertionError: when the texts and its length does not match. + + Returns: + text (str or list of str): texts to convert. 
+ """ + if length.numel() == 1: + length = length[0] + assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), + length) + if raw: + return ''.join([self.alphabet[i - 1] for i in t]) + else: + char_list = [] + for i in range(length): + if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): + char_list.append(self.alphabet[t[i] - 1]) + return ''.join(char_list) + else: + # batch mode + assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format( + t.numel(), length.sum()) + texts = [] + index = 0 + for i in range(length.numel()): + l = length[i] + texts.append( + self.decode( + t[index:index + l], torch.IntTensor([l]), raw=raw)) + index += l + return texts + + +class strLabelConverter(object): + """Convert between str and label. + NOTE: + Insert `blank` to the alphabet for CTC. + Args: + alphabet (str): set of the possible characters. + ignore_case (bool, default=True): whether or not to ignore all of the case. + """ + + def __init__(self, alphabet, ignore_case=False): + self._ignore_case = ignore_case + if self._ignore_case: + alphabet = alphabet.lower() + self.alphabet = alphabet + '-' # for `-1` index + + self.dict = {} + for i, char in enumerate(alphabet): + # NOTE: 0 is reserved for 'blank' required by wrap_ctc + self.dict[char] = i + 1 + + def encode(self, text): + """Support batch or single str. + Args: + text (str or list of str): texts to convert. + Returns: + torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. + torch.IntTensor [n]: length of each text. + """ + ''' + if isinstance(text, str): + text = [ + self.dict[char.lower() if self._ignore_case else char] + for char in text + ] + length = [len(text)] + elif isinstance(text, collections.Iterable): + length = [len(s) for s in text] + text = ''.join(text) + text, _ = self.encode(text) + return (torch.IntTensor(text), torch.IntTensor(length)) + ''' + length = [] + result = [] + results = [] + for item in text: + item = item.decode('utf-8', 'strict') + length.append(len(item)) + for char in item: + index = self.dict[char] + result.append(index) + results.append(result) + result = [] + + return (torch.nn.utils.rnn.pad_sequence([torch.LongTensor(text) for text in results], batch_first=True), torch.IntTensor(length)) + + def decode(self, t, length, raw=False): + """Decode encoded texts back into strs. + Args: + torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. + torch.IntTensor [n]: length of each text. + Raises: + AssertionError: when the texts and its length does not match. + Returns: + text (str or list of str): texts to convert. 
+ """ + if length.numel() == 1: + length = length[0] + assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), + length) + if raw: + return ''.join([self.alphabet[i - 1] for i in t]) + else: + char_list = [] + for i in range(length): + if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): + char_list.append(self.alphabet[t[i] - 1]) + return ''.join(char_list) + else: + # batch mode + assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format( + t.numel(), length.sum()) + texts = [] + index = 0 + for i in range(length.numel()): + l = length[i] + texts.append( + self.decode( + t[index:index + l], torch.IntTensor([l]), raw=raw)) + index += l + return texts + + diff --git a/util/models/__init__.py b/util/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d25ffe5abdc7b633a164ac96d023ebbfd9d00c2 --- /dev/null +++ b/util/models/__init__.py @@ -0,0 +1,65 @@ +"""This package contains modules related to objective functions, optimizations, and network architectures. + +To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. +You need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- <set_input>: unpack data from dataset and apply preprocessing. + -- <forward>: produce intermediate results. + -- <optimize_parameters>: calculate loss, gradients, and update network weights. + -- <modify_commandline_options>: (optionally) add model-specific options and set default options. + +In the function <__init__>, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example. + +Now you can use the model class by specifying flag '--model dummy'. +""" + +import importlib + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def get_option_setter(model_name): + """Return the static method of the model class.""" + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + """Create a model given the option. + + This function wraps the model class. 
+ This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % type(instance).__name__) + return instance diff --git a/util/models/__pycache__/BigGAN_layers.cpython-36.pyc b/util/models/__pycache__/BigGAN_layers.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..031dfe4d4c1667720011d927eabf15139e4ff012 Binary files /dev/null and b/util/models/__pycache__/BigGAN_layers.cpython-36.pyc differ diff --git a/util/models/__pycache__/BigGAN_networks.cpython-36.pyc b/util/models/__pycache__/BigGAN_networks.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9330de304e8a74bc37ca5554aa0d224af4ae8aae Binary files /dev/null and b/util/models/__pycache__/BigGAN_networks.cpython-36.pyc differ diff --git a/util/models/__pycache__/OCR_network.cpython-36.pyc b/util/models/__pycache__/OCR_network.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e3736a538d8312f3ad3a1865fd305e1146daaf5 Binary files /dev/null and b/util/models/__pycache__/OCR_network.cpython-36.pyc differ diff --git a/util/models/__pycache__/__init__.cpython-36.pyc b/util/models/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f7827b146d0d3fa231614d180f24a3459ef060e Binary files /dev/null and b/util/models/__pycache__/__init__.cpython-36.pyc differ diff --git a/util/models/__pycache__/blocks.cpython-36.pyc b/util/models/__pycache__/blocks.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4df216f3a341e16c0c043892f3c4d56c05f2f26d Binary files /dev/null and b/util/models/__pycache__/blocks.cpython-36.pyc differ diff --git a/util/models/__pycache__/inception.cpython-36.pyc b/util/models/__pycache__/inception.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..556213423deaf05c1a1009176ff0a41704d30e64 Binary files /dev/null and b/util/models/__pycache__/inception.cpython-36.pyc differ diff --git a/util/models/__pycache__/model.cpython-36.pyc b/util/models/__pycache__/model.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c6de38b6e111e37bee64ba9975070ead6ad4f5e Binary files /dev/null and b/util/models/__pycache__/model.cpython-36.pyc differ diff --git a/util/models/__pycache__/model_.cpython-36.pyc b/util/models/__pycache__/model_.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7d9dec0c29fbd7d85117a39c25df7bf2a145409 Binary files /dev/null and b/util/models/__pycache__/model_.cpython-36.pyc differ diff --git a/util/models/__pycache__/networks.cpython-36.pyc b/util/models/__pycache__/networks.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cde1887c10ad7a384a36538378fb5e21e4cad474 Binary files /dev/null and b/util/models/__pycache__/networks.cpython-36.pyc differ diff --git a/util/models/__pycache__/transformer.cpython-36.pyc b/util/models/__pycache__/transformer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6a8c733e447a06abaa3e00a9a10864b050b150e Binary files /dev/null and b/util/models/__pycache__/transformer.cpython-36.pyc differ diff --git a/util/models/blocks.py b/util/models/blocks.py new file mode 100644 index 
0000000000000000000000000000000000000000..470084c509fa9193fbc3a02d83c2ceefd12368d7 --- /dev/null +++ b/util/models/blocks.py @@ -0,0 +1,190 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class ResBlocks(nn.Module): + def __init__(self, num_blocks, dim, norm, activation, pad_type): + super(ResBlocks, self).__init__() + self.model = [] + for i in range(num_blocks): + self.model += [ResBlock(dim, + norm=norm, + activation=activation, + pad_type=pad_type)] + self.model = nn.Sequential(*self.model) + + def forward(self, x): + return self.model(x) + + +class ResBlock(nn.Module): + def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): + super(ResBlock, self).__init__() + model = [] + model += [Conv2dBlock(dim, dim, 3, 1, 1, + norm=norm, + activation=activation, + pad_type=pad_type)] + model += [Conv2dBlock(dim, dim, 3, 1, 1, + norm=norm, + activation='none', + pad_type=pad_type)] + self.model = nn.Sequential(*model) + + def forward(self, x): + residual = x + out = self.model(x) + out += residual + return out + + +class ActFirstResBlock(nn.Module): + def __init__(self, fin, fout, fhid=None, + activation='lrelu', norm='none'): + super().__init__() + self.learned_shortcut = (fin != fout) + self.fin = fin + self.fout = fout + self.fhid = min(fin, fout) if fhid is None else fhid + self.conv_0 = Conv2dBlock(self.fin, self.fhid, 3, 1, + padding=1, pad_type='reflect', norm=norm, + activation=activation, activation_first=True) + self.conv_1 = Conv2dBlock(self.fhid, self.fout, 3, 1, + padding=1, pad_type='reflect', norm=norm, + activation=activation, activation_first=True) + if self.learned_shortcut: + self.conv_s = Conv2dBlock(self.fin, self.fout, 1, 1, + activation='none', use_bias=False) + + def forward(self, x): + x_s = self.conv_s(x) if self.learned_shortcut else x + dx = self.conv_0(x) + dx = self.conv_1(dx) + out = x_s + dx + return out + + +class LinearBlock(nn.Module): + def __init__(self, in_dim, out_dim, norm='none', activation='relu'): + super(LinearBlock, self).__init__() + use_bias = True + self.fc = nn.Linear(in_dim, out_dim, bias=use_bias) + + # initialize normalization + norm_dim = out_dim + if norm == 'bn': + self.norm = nn.BatchNorm1d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm1d(norm_dim) + elif norm == 'none': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=False) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=False) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + def forward(self, x): + out = self.fc(x) + if self.norm: + out = self.norm(out) + if self.activation: + out = self.activation(out) + return out + + +class Conv2dBlock(nn.Module): + def __init__(self, in_dim, out_dim, ks, st, padding=0, + norm='none', activation='relu', pad_type='zero', + use_bias=True, activation_first=False): + super(Conv2dBlock, self).__init__() + self.use_bias = use_bias + self.activation_first = activation_first + # initialize padding + if pad_type == 'reflect': + self.pad = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + self.pad = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + self.pad = nn.ZeroPad2d(padding) + else: + assert 0, "Unsupported padding type: {}".format(pad_type) + + # initialize normalization + norm_dim = out_dim 
+ if norm == 'bn': + self.norm = nn.BatchNorm2d(norm_dim) + elif norm == 'in': + self.norm = nn.InstanceNorm2d(norm_dim) + elif norm == 'adain': + self.norm = AdaptiveInstanceNorm2d(norm_dim) + elif norm == 'none': + self.norm = None + else: + assert 0, "Unsupported normalization: {}".format(norm) + + # initialize activation + if activation == 'relu': + self.activation = nn.ReLU(inplace=False) + elif activation == 'lrelu': + self.activation = nn.LeakyReLU(0.2, inplace=False) + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'none': + self.activation = None + else: + assert 0, "Unsupported activation: {}".format(activation) + + self.conv = nn.Conv2d(in_dim, out_dim, ks, st, bias=self.use_bias) + + def forward(self, x): + if self.activation_first: + if self.activation: + x = self.activation(x) + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + else: + x = self.conv(self.pad(x)) + if self.norm: + x = self.norm(x) + if self.activation: + x = self.activation(x) + return x + + +class AdaptiveInstanceNorm2d(nn.Module): + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super(AdaptiveInstanceNorm2d, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = None + self.bias = None + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + + def forward(self, x): + assert self.weight is not None and \ + self.bias is not None, "Please assign AdaIN weight first" + b, c = x.size(0), x.size(1) + running_mean = self.running_mean.repeat(b) + running_var = self.running_var.repeat(b) + x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) + out = F.batch_norm( + x_reshaped, running_mean, running_var, self.weight, self.bias, + True, self.momentum, self.eps) + return out.view(b, c, *x.size()[2:]) + + def __repr__(self): + return self.__class__.__name__ + '(' + str(self.num_features) + ')' diff --git a/util/models/inception.py b/util/models/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..01ce338a85ffd41d1f393fdbf93c4da1c0d6de41 --- /dev/null +++ b/util/models/inception.py @@ -0,0 +1,363 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models +import numpy as np + +from itertools import cycle +from scipy import linalg + + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=[DEFAULT_BLOCK_INDEX], + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + Parameters + ---------- + output_blocks : list of int 
+ Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. 
Values are expected to be in + range (0, 1) + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate(x, + size=(299, 299), + mode='bilinear', + align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. 
+ """ + inception = models.inception_v3(num_classes=1008, + aux_logits=False, + pretrained=False) + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 
1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) \ No newline at end of file diff --git a/util/models/model.py b/util/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..342a3f164361acedb090fe2df648124f12e83847 --- /dev/null +++ b/util/models/model.py @@ -0,0 +1,1389 @@ +import torch +import pandas as pd +from .OCR_network import * +from torch.nn import CTCLoss, MSELoss, L1Loss +from torch.nn.utils import clip_grad_norm_ +import random +import unicodedata +import sys +import torchvision.models as models +from models.transformer import * +from .BigGAN_networks import * +from params import * +from .OCR_network import * +from models.blocks import LinearBlock, Conv2dBlock, ResBlocks, ActFirstResBlock +from util.util import toggle_grad, loss_hinge_dis, loss_hinge_gen, ortho, default_ortho, toggle_grad, prepare_z_y, \ + make_one_hot, to_device, multiple_replace, random_word +from models.inception import InceptionV3, calculate_frechet_distance +import cv2 + +class FCNDecoder(nn.Module): + def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'): + super(FCNDecoder, self).__init__() + + self.model = [] + self.model += [ResBlocks(n_res, dim, res_norm, + activ, pad_type=pad_type)] + for i in range(ups): + self.model += [nn.Upsample(scale_factor=2), + Conv2dBlock(dim, dim // 2, 5, 1, 2, + norm='in', + activation=activ, + pad_type=pad_type)] + dim //= 2 + self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3, + norm='none', + activation='tanh', + pad_type=pad_type)] + self.model = nn.Sequential(*self.model) + + def forward(self, x): + y = self.model(x) + + return y + + + +class Generator(nn.Module): + + def __init__(self): + super(Generator, self).__init__() + + INP_CHANNEL = NUM_EXAMPLES + if IS_SEQ: INP_CHANNEL = 1 + + + encoder_layer = TransformerEncoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD, + TN_DROPOUT, "relu", True) + encoder_norm = nn.LayerNorm(TN_HIDDEN_DIM) if True else None + self.encoder = TransformerEncoder(encoder_layer, TN_ENC_LAYERS, encoder_norm) + + decoder_layer = TransformerDecoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD, + TN_DROPOUT, "relu", True) + decoder_norm = nn.LayerNorm(TN_HIDDEN_DIM) + self.decoder = TransformerDecoder(decoder_layer, TN_DEC_LAYERS, decoder_norm, + return_intermediate=True) + + self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(INP_CHANNEL, 64, kernel_size=7, stride=2, padding=3, bias=False)] 
+list(models.resnet18(pretrained=True).children())[1:-2])) + + self.query_embed = nn.Embedding(VOCAB_SIZE, TN_HIDDEN_DIM) + + + self.linear_q = nn.Linear(TN_DIM_FEEDFORWARD*2, TN_DIM_FEEDFORWARD*8) + + self.DEC = FCNDecoder(res_norm = 'in') + + + self._muE = nn.Linear(512,512) + self._logvarE = nn.Linear(512,512) + + self._muD = nn.Linear(512,512) + self._logvarD = nn.Linear(512,512) + + + self.l1loss = nn.L1Loss() + + self.noise = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0])) + + + + + + + def reparameterize(self, mu, logvar): + + mu = torch.unbind(mu , 1) + logvar = torch.unbind(logvar , 1) + + outs = [] + + for m,l in zip(mu, logvar): + + sigma = torch.exp(l) + eps = torch.cuda.FloatTensor(l.size()[0],1).normal_(0,1) + eps = eps.expand(sigma.size()) + + out = m + sigma*eps + + outs.append(out) + + + return torch.stack(outs, 1) + + + def Eval(self, ST, QRS): + + if IS_SEQ: + B, N, R, C = ST.shape + FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C)) + FEAT_ST = FEAT_ST.view(B, 512, 1, -1) + else: + FEAT_ST = self.Feat_Encoder(ST) + + + FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1) + + memory = self.encoder(FEAT_ST_ENC) + + if IS_KLD: + + Ex = memory.permute(1,0,2) + + memory_mu = self._muE(Ex) + memory_logvar = self._logvarE(Ex) + + memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2) + + + OUT_IMGS = [] + + for i in range(QRS.shape[1]): + + QR = QRS[:, i, :] + + QR_EMB = self.query_embed.weight[QR].permute(1,0,2) + + tgt = torch.zeros_like(QR_EMB) + + hs = self.decoder(tgt, memory, query_pos=QR_EMB) + + if IS_KLD: + + Dx = hs[0].permute(1,0,2) + + hs_mu = self._muD(Dx) + hs_logvar = self._logvarD(Dx) + + hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0) + + + h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1) + if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE) + + h = self.linear_q(h) + h = h.contiguous() + + h = h.view(h.size(0), h.shape[1]*2, 4, -1) + h = h.permute(0, 3, 2, 1) + + h = self.DEC(h) + + + OUT_IMGS.append(h.detach()) + + + + return OUT_IMGS + + + + + + + def forward(self, ST, QR, QRs = None, mode = 'train'): + + #Attention Visualization Init + + + enc_attn_weights, dec_attn_weights = [], [] + + self.hooks = [ + + self.encoder.layers[-1].self_attn.register_forward_hook( + lambda self, input, output: enc_attn_weights.append(output[1]) + ), + self.decoder.layers[-1].multihead_attn.register_forward_hook( + lambda self, input, output: dec_attn_weights.append(output[1]) + ), + ] + + + #Attention Visualization Init + + if IS_SEQ: + B, N, R, C = ST.shape + FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C)) + FEAT_ST = FEAT_ST.view(B, 512, 1, -1) + else: + FEAT_ST = self.Feat_Encoder(ST) + + + FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1) + + memory = self.encoder(FEAT_ST_ENC) + + if IS_KLD: + + Ex = memory.permute(1,0,2) + + memory_mu = self._muE(Ex) + memory_logvar = self._logvarE(Ex) + + memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2) + + + QR_EMB = self.query_embed.weight[QR].permute(1,0,2) + + tgt = torch.zeros_like(QR_EMB) + + hs = self.decoder(tgt, memory, query_pos=QR_EMB) + + if IS_KLD: + + Dx = hs[0].permute(1,0,2) + + hs_mu = self._muD(Dx) + hs_logvar = self._logvarD(Dx) + + hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0) + + OUT_Feats1_mu = [hs_mu] + OUT_Feats1_logvar = [hs_logvar] + + + OUT_Feats1 = [hs] + + + h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1) + + if ADD_NOISE: h = h + 
self.noise.sample(h.size()).squeeze(-1).to(DEVICE) + + h = self.linear_q(h) + h = h.contiguous() + + h = h.view(h.size(0), h.shape[1]*2, 4, -1) + h = h.permute(0, 3, 2, 1) + + h = self.DEC(h) + + self.dec_attn_weights = dec_attn_weights[-1].detach() + self.enc_attn_weights = enc_attn_weights[-1].detach() + + + + for hook in self.hooks: + hook.remove() + + if mode == 'test' or (not IS_CYCLE and not IS_KLD): + + return h + + + OUT_IMGS = [h] + + for QR in QRs: + + QR_EMB = self.query_embed.weight[QR].permute(1,0,2) + + tgt = torch.zeros_like(QR_EMB) + + hs = self.decoder(tgt, memory, query_pos=QR_EMB) + + + if IS_KLD: + + Dx = hs[0].permute(1,0,2) + + hs_mu = self._muD(Dx) + hs_logvar = self._logvarD(Dx) + + hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0) + + OUT_Feats1_mu.append(hs_mu) + OUT_Feats1_logvar.append(hs_logvar) + + + OUT_Feats1.append(hs) + + + h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1) + if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE) + + h = self.linear_q(h) + h = h.contiguous() + + h = h.view(h.size(0), h.shape[1]*2, 4, -1) + h = h.permute(0, 3, 2, 1) + + h = self.DEC(h) + + OUT_IMGS.append(h) + + + if (not IS_CYCLE) and IS_KLD: + + OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0] + + OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1); + + + KLD = (0.5 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \ + + (0.5 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp())) + + + + def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar): + return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \ + torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum() + + + lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])] + + + lda1 = torch.stack(lda1).mean() + + + + return OUT_IMGS[0], lda1, KLD + + + with torch.no_grad(): + + if IS_SEQ: + + FEAT_ST_T = torch.cat([self.Feat_Encoder(IM) for IM in OUT_IMGS], -1) + + else: + + max_width_ = max([i_.shape[-1] for i_ in OUT_IMGS]) + + FEAT_ST_T = self.Feat_Encoder(torch.cat([torch.cat([i_, torch.ones((i_.shape[0], i_.shape[1],i_.shape[2], max_width_-i_.shape[3])).to(DEVICE)], -1) for i_ in OUT_IMGS], 1)) + + FEAT_ST_ENC_T = FEAT_ST_T.flatten(2).permute(2,0,1) + + memory_T = self.encoder(FEAT_ST_ENC_T) + + if IS_KLD: + + Ex = memory_T.permute(1,0,2) + + memory_T_mu = self._muE(Ex) + memory_T_logvar = self._logvarE(Ex) + + memory_T = self.reparameterize(memory_T_mu, memory_T_logvar).permute(1,0,2) + + + QR_EMB = self.query_embed.weight[QR].permute(1,0,2) + + tgt = torch.zeros_like(QR_EMB) + + hs = self.decoder(tgt, memory_T, query_pos=QR_EMB) + + if IS_KLD: + + Dx = hs[0].permute(1,0,2) + + hs_mu = self._muD(Dx) + hs_logvar = self._logvarD(Dx) + + hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0) + + OUT_Feats2_mu = [hs_mu] + OUT_Feats2_logvar = [hs_logvar] + + + OUT_Feats2 = [hs] + + + + for QR in QRs: + + QR_EMB = self.query_embed.weight[QR].permute(1,0,2) + + tgt = torch.zeros_like(QR_EMB) + + hs = self.decoder(tgt, memory_T, query_pos=QR_EMB) + + if IS_KLD: + + Dx = hs[0].permute(1,0,2) + + hs_mu = self._muD(Dx) + hs_logvar = self._logvarD(Dx) + + hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0) + + OUT_Feats2_mu.append(hs_mu) + OUT_Feats2_logvar.append(hs_logvar) + + + 
OUT_Feats2.append(hs) + + + + + Lcycle1 = np.sum([self.l1loss(memory[m_i], memory_T[m_j]) for m_i in range(memory.shape[0]) for m_j in range(memory_T.shape[0])])/(memory.shape[0]*memory_T.shape[0]) + OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0]; OUT_Feats2 = torch.cat(OUT_Feats2, 1)[0] + + Lcycle2 = np.sum([self.l1loss(OUT_Feats1[f_i], OUT_Feats2[f_j]) for f_i in range(OUT_Feats1.shape[0]) for f_j in range(OUT_Feats2.shape[0])])/(OUT_Feats1.shape[0]*OUT_Feats2.shape[0]) + + if IS_KLD: + + OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1); + OUT_Feats2_mu = torch.cat(OUT_Feats2_mu, 1); OUT_Feats2_logvar = torch.cat(OUT_Feats2_logvar, 1); + + KLD = (0.25 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \ + + (0.25 * torch.mean(1 + memory_T_logvar - memory_T_mu.pow(2) - memory_T_logvar.exp()))\ + + (0.25 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp()))\ + + (0.25 * torch.mean(1 + OUT_Feats2_logvar - OUT_Feats2_mu.pow(2) - OUT_Feats2_logvar.exp())) + + + def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar): + return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \ + torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum() + + + lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])] + lda2 = [_get_lda(memory_T_mu[:,idi,:], OUT_Feats2_mu[:,idj,:], memory_T_logvar[:,idi,:], OUT_Feats2_logvar[:,idj,:]) for idi in range(memory_T.shape[0]) for idj in range(OUT_Feats2.shape[0])] + + lda1 = torch.stack(lda1).mean() + lda2 = torch.stack(lda2).mean() + + + return OUT_IMGS[0], Lcycle1, Lcycle2, lda1, lda2, KLD + + + return OUT_IMGS[0], Lcycle1, Lcycle2 + + + +class TRGAN(nn.Module): + + def __init__(self): + super(TRGAN, self).__init__() + + + self.epsilon = 1e-7 + self.netG = Generator().to(DEVICE) + self.netD = nn.DataParallel(Discriminator()).to(DEVICE) + self.netW = nn.DataParallel(WDiscriminator()).to(DEVICE) + self.netconverter = strLabelConverter(ALPHABET) + self.netOCR = CRNN().to(DEVICE) + self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none') + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] + self.inception = InceptionV3([block_idx]).to(DEVICE) + + + self.optimizer_G = torch.optim.Adam(self.netG.parameters(), + lr=G_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8) + self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(), + lr=OCR_LR, betas=(0.0, 0.999), weight_decay=0, + eps=1e-8) + + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), + lr=D_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8) + + + self.optimizer_wl = torch.optim.Adam(self.netW.parameters(), + lr=W_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8) + + + self.optimizers = [self.optimizer_G, self.optimizer_OCR, self.optimizer_D, self.optimizer_wl] + + + self.optimizer_G.zero_grad() + self.optimizer_OCR.zero_grad() + self.optimizer_D.zero_grad() + self.optimizer_wl.zero_grad() + + self.loss_G = 0 + self.loss_D = 0 + self.loss_Dfake = 0 + self.loss_Dreal = 0 + self.loss_OCR_fake = 0 + self.loss_OCR_real = 0 + self.loss_w_fake = 0 + self.loss_w_real = 0 + self.Lcycle1 = 0 + self.Lcycle2 = 0 + self.lda1 = 0 + self.lda2 = 0 + self.KLD = 0 + + + with open('../Lexicon/english_words.txt', 'rb') as f: + self.lex = f.read().splitlines() + lex=[] + for word in self.lex: + try: + word=word.decode("utf-8") + except: + continue + if len(word)<20: 
+ lex.append(word) + self.lex = lex + + + f = open('mytext.txt', 'r') + + self.text = [j.encode() for j in sum([i.split(' ') for i in f.readlines()], [])][:NUM_EXAMPLES] + self.eval_text_encode, self.eval_len_text = self.netconverter.encode(self.text) + self.eval_text_encode = self.eval_text_encode.to(DEVICE).repeat(batch_size, 1, 1) + + + def _generate_page(self): + + self.fakes = self.netG.Eval(self.sdata, self.eval_text_encode) + + word_t = [] + word_l = [] + + gap = np.ones([32,16]) + + line_wids = [] + + + for idx, fake_ in enumerate(self.fakes): + + word_t.append((fake_[0,0,:,:self.eval_len_text[idx]*resolution].cpu().numpy()+1)/2) + + word_t.append(gap) + + if len(word_t) == 16 or idx == len(self.fakes) - 1: + + line_ = np.concatenate(word_t, -1) + + word_l.append(line_) + line_wids.append(line_.shape[1]) + + word_t = [] + + + gap_h = np.ones([16,max(line_wids)]) + + page_= [] + + for l in word_l: + + pad_ = np.ones([32,max(line_wids) - l.shape[1]]) + + page_.append(np.concatenate([l, pad_], 1)) + page_.append(gap_h) + + + + page1 = np.concatenate(page_, 0) + + + word_t = [] + word_l = [] + + gap = np.ones([32,16]) + + line_wids = [] + + sdata_ = [i.unsqueeze(1) for i in torch.unbind(self.sdata, 1)] + + for idx, st in enumerate((sdata_)): + + word_t.append((st[0,0,:,:int(self.input['swids'].cpu().numpy()[0][idx]) +].cpu().numpy()+1)/2) + + word_t.append(gap) + + if len(word_t) == 16 or idx == len(self.fakes) - 1: + + line_ = np.concatenate(word_t, -1) + + word_l.append(line_) + line_wids.append(line_.shape[1]) + + word_t = [] + + + gap_h = np.ones([16,max(line_wids)]) + + page_= [] + + for l in word_l: + + pad_ = np.ones([32,max(line_wids) - l.shape[1]]) + + page_.append(np.concatenate([l, pad_], 1)) + page_.append(gap_h) + + + + page2 = np.concatenate(page_, 0) + + merge_w_size = max(page1.shape[0], page2.shape[0]) + + if page1.shape[0] != merge_w_size: + + page1 = np.concatenate([page1, np.ones([merge_w_size-page1.shape[0], page1.shape[1]])], 0) + + if page2.shape[0] != merge_w_size: + + page2 = np.concatenate([page2, np.ones([merge_w_size-page2.shape[0], page2.shape[1]])], 0) + + + page = np.concatenate([page2, page1], 1) + + + return page + + + + + + + + + + + + + + + + + + + + #FEAT1 = self.inception(torch.cat(self.fakes, 0).repeat(1,3,1,1))[0].detach().view(batch_size, len(self.fakes), -1).cpu().numpy() + #FEAT2 = self.inception(self.sdata.view(batch_size*NUM_EXAMPLES, 1, 32, -1).repeat(1,3,1,1))[0].detach().view(batch_size, NUM_EXAMPLES, -1 ).cpu().numpy() + #muvars1 = [{'mu':np.mean(FEAT1[i], axis=0), 'sigma' : np.cov(FEAT1[i], rowvar=False)} for i in range(FEAT1.shape[0])] + #muvars2 = [{'mu':np.mean(FEAT2[i], axis=0), 'sigma' : np.cov(FEAT2[i], rowvar=False)} for i in range(FEAT2.shape[0])] + + + + + + + def get_current_losses(self): + + losses = {} + + losses['G'] = self.loss_G + losses['D'] = self.loss_D + losses['Dfake'] = self.loss_Dfake + losses['Dreal'] = self.loss_Dreal + losses['OCR_fake'] = self.loss_OCR_fake + losses['OCR_real'] = self.loss_OCR_real + losses['w_fake'] = self.loss_w_fake + losses['w_real'] = self.loss_w_real + losses['cycle1'] = self.Lcycle1 + losses['cycle2'] = self.Lcycle2 + losses['lda1'] = self.lda1 + losses['lda2'] = self.lda2 + losses['KLD'] = self.KLD + + return losses + + def visualize_images(self): + + imgs = {} + + + imgs['fake-1']=self.netG(self.sdata[0:1], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach() + imgs['fake-2']=self.netG(self.sdata[0:1], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 
0].detach() + imgs['fake-3']=self.netG(self.sdata[0:1], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach() + + + imgs['res-1'] = torch.cat([self.sdata[0, 0],self.sdata[0, 1],self.sdata[0, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1) + + + imgs['fake-1']=self.netG(self.sdata[1:2], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach() + imgs['fake-2']=self.netG(self.sdata[1:2], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach() + imgs['fake-3']=self.netG(self.sdata[1:2], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach() + + + imgs['res-2'] = torch.cat([self.sdata[1, 0],self.sdata[1, 1],self.sdata[1, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1) + + + imgs['fake-1']=self.netG(self.sdata[2:3], self.text_encode_fake[0].unsqueeze(0) , mode = 'test' )[0, 0].detach() + imgs['fake-2']=self.netG(self.sdata[2:3], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach() + imgs['fake-3']=self.netG(self.sdata[2:3], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach() + + + imgs['res-3'] = torch.cat([self.sdata[2, 0],self.sdata[2, 1],self.sdata[2, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1) + + + + + return imgs + + + def load_networks(self, epoch): + BaseModel.load_networks(self, epoch) + if self.opt.single_writer: + load_filename = '%s_z.pkl' % (epoch) + load_path = os.path.join(self.save_dir, load_filename) + self.z = torch.load(load_path) + + def _set_input(self, input): + self.input = input + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + def forward(self): + + + self.real = self.input['img'].to(DEVICE) + self.label = self.input['label'] + self.sdata = self.input['simg'].to(DEVICE) + self.ST_LEN = self.input['swids'] + self.text_encode, self.len_text = self.netconverter.encode(self.label) + self.one_hot_real = make_one_hot(self.text_encode, self.len_text, VOCAB_SIZE).to(DEVICE).detach() + self.text_encode = self.text_encode.to(DEVICE).detach() + self.len_text = self.len_text.detach() + + self.words = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)] + self.text_encode_fake, self.len_text_fake = self.netconverter.encode(self.words) + self.text_encode_fake = self.text_encode_fake.to(DEVICE) + self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, VOCAB_SIZE).to(DEVICE) + + + self.text_encode_fake_js = [] + + for _ in range(NUM_WORDS - 1): + + self.words_j = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)] + self.text_encode_fake_j, self.len_text_fake_j = self.netconverter.encode(self.words_j) + self.text_encode_fake_j = self.text_encode_fake_j.to(DEVICE) + self.text_encode_fake_js.append(self.text_encode_fake_j) + + + if IS_CYCLE and IS_KLD: + + self.fake, self.Lcycle1, self.Lcycle2, self.lda1, self.lda2, self.KLD = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js) + + elif IS_CYCLE and (not IS_KLD): + + self.fake, self.Lcycle1, self.Lcycle2 = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js) + + elif (not IS_CYCLE) and IS_KLD: + + self.fake, self.lda1, self.KLD 
= self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js) + + else: + + self.fake = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js) + + + + def visualize_attention(self): + + def _norm_scores(arr): + return (arr - min(arr))/(max(arr) - min(arr)) + + simgs = self.sdata[0].detach().cpu().numpy() + fake = self.fake[0,0].detach().cpu().numpy() + slen = self.ST_LEN[0].detach().cpu().numpy() + selfatt = self.netG.enc_attn_weights[0].detach().cpu().numpy() + selfatt = np.stack([_norm_scores(i) for i in selfatt], 1) + fake_lab = self.words[0].decode() + + decatt = self.netG.dec_attn_weights[0].detach().cpu().numpy() + decatt = np.stack([_norm_scores(i) for i in decatt], 0) + + STdict = {} + FAKEdict = {} + count = 0 + + for sim_, sle_ in zip(simgs,slen): + + for pi in range(sim_.shape[1]//sim_.shape[0]): + + STdict[count] = {'patch':sim_[:, pi*32:(pi+1)*32], 'ischar': sle_>=pi*32, 'encoder_attention_score': selfatt[count], 'decoder_attention_score': decatt[:,count]} + count = count + 1 + + + for pi in range(fake.shape[1]//resolution): + + FAKEdict[pi] = {'patch': fake[:, pi*resolution:(pi+1)*resolution]} + + show_ims = [] + + for idx in range(len(fake_lab)): + + viz_pats = [] + viz_lin = [] + + for i in STdict.keys(): + + if STdict[i]['ischar']: + + viz_pats.append(cv2.addWeighted(STdict[i]['patch'], 0.5, np.ones_like(STdict[i]['patch'])*STdict[i]['decoder_attention_score'][idx], 0.5, 0)) + + if len(viz_pats) >= 20: + + viz_lin.append(np.concatenate(viz_pats, -1)) + + viz_pats = [] + + + + + src = np.concatenate(viz_lin[:-2], 0)*255 + + viz_gts = [] + + for i in range(len(fake_lab)): + + + + #if i == idx: + + #bordersize = 5 + + #FAKEdict[i]['patch'] = cv2.addWeighted(FAKEdict[i]['patch'] , 0.5, np.ones_like(FAKEdict[i]['patch'] ), 0.5, 0) + + + + + + + img = np.zeros((54,16)) + font = cv2.FONT_HERSHEY_SIMPLEX + text = fake_lab[i] + + # get boundary of this text + textsize = cv2.getTextSize(text, font, 1, 2)[0] + + # get coords based on boundary + textX = (img.shape[1] - textsize[0]) // 2 + textY = (img.shape[0] + textsize[1]) // 2 + + # add text centered on image + cv2.putText(img, text, (textX, textY ), font, 1, (255, 255, 255), 2) + + img = (255 - img)/255 + + if i == idx: + + img = (1 - img) + + viz_gts.append(img) + + + + tgt = np.concatenate([fake[:,:len(fake_lab)*16],np.concatenate(viz_gts, -1)], 0) + pad_ = np.ones((tgt.shape[0], (src.shape[1]-tgt.shape[1])//2)) + tgt = np.concatenate([pad_, tgt, pad_], -1)*255 + final = np.concatenate([src, tgt], 0) + + + show_ims.append(final) + + return show_ims + + + def backward_D_OCR(self): + + pred_real = self.netD(self.real.detach()) + + pred_fake = self.netD(**{'x': self.fake.detach()}) + + + self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True) + + self.loss_D = self.loss_Dreal + self.loss_Dfake + + self.pred_real_OCR = self.netOCR(self.real.detach()) + preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * batch_size).detach() + loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach()) + self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)]) + + loss_total = self.loss_D + self.loss_OCR_real + + # backward + loss_total.backward() + for param in self.netOCR.parameters(): + param.grad[param.grad!=param.grad]=0 + param.grad[torch.isnan(param.grad)]=0 + param.grad[torch.isinf(param.grad)]=0 + + + + return loss_total + + def backward_D_WL(self): + # Real 
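# Editor's note (not part of the original patch): the three masked assignments in
# backward_D_OCR above zero out NaN/Inf entries in the CTC gradients before the
# optimizer step; `param.grad != param.grad` is just another NaN test. A minimal
# equivalent sketch, assuming PyTorch >= 1.8 (torch.nan_to_num_), with a hypothetical
# helper name:
import torch

def scrub_gradients(module: torch.nn.Module) -> None:
    """Replace NaN/Inf gradient entries with 0 in-place, skipping params without grads."""
    for p in module.parameters():
        if p.grad is not None:
            torch.nan_to_num_(p.grad, nan=0.0, posinf=0.0, neginf=0.0)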
+ pred_real = self.netD(self.real.detach()) + + pred_fake = self.netD(**{'x': self.fake.detach()}) + + + self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True) + + self.loss_D = self.loss_Dreal + self.loss_Dfake + + + self.loss_w_real = self.netW(self.real.detach(), self.input['wcl'].to(DEVICE)).mean() + # total loss + loss_total = self.loss_D + self.loss_w_real + + # backward + loss_total.backward() + + + return loss_total + + def optimize_D_WL(self): + self.forward() + self.set_requires_grad([self.netD], True) + self.set_requires_grad([self.netOCR], False) + self.set_requires_grad([self.netW], True) + + self.optimizer_D.zero_grad() + self.optimizer_wl.zero_grad() + + self.backward_D_WL() + + + + + def backward_D_OCR_WL(self): + # Real + if self.real_z_mean is None: + pred_real = self.netD(self.real.detach()) + else: + pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()}) + # Fake + try: + pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()}) + except: + print('a') + # Combined loss + self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss) + + self.loss_D = self.loss_Dreal + self.loss_Dfake + # OCR loss on real data + self.pred_real_OCR = self.netOCR(self.real.detach()) + preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach() + loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach()) + self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)]) + # total loss + self.loss_w_real = self.netW(self.real.detach(), self.wcl) + loss_total = self.loss_D + self.loss_OCR_real + self.loss_w_real + + # backward + loss_total.backward() + for param in self.netOCR.parameters(): + param.grad[param.grad!=param.grad]=0 + param.grad[torch.isnan(param.grad)]=0 + param.grad[torch.isinf(param.grad)]=0 + + + + return loss_total + + def optimize_D_WL_step(self): + self.optimizer_D.step() + self.optimizer_wl.step() + self.optimizer_D.zero_grad() + self.optimizer_wl.zero_grad() + + def backward_OCR(self): + # OCR loss on real data + self.pred_real_OCR = self.netOCR(self.real.detach()) + preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach() + loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach()) + self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)]) + + # backward + self.loss_OCR_real.backward() + for param in self.netOCR.parameters(): + param.grad[param.grad!=param.grad]=0 + param.grad[torch.isnan(param.grad)]=0 + param.grad[torch.isinf(param.grad)]=0 + + return self.loss_OCR_real + + + def backward_D(self): + # Real + if self.real_z_mean is None: + pred_real = self.netD(self.real.detach()) + else: + pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()}) + pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()}) + # Combined loss + self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss) + self.loss_D = self.loss_Dreal + self.loss_Dfake + # backward + self.loss_D.backward() + + + return self.loss_D + + + def backward_G_only(self): + + self.gb_alpha = 0.7 + #self.Lcycle1 = self.Lcycle1.mean() + #self.Lcycle2 = self.Lcycle2.mean() + self.loss_G = 
loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean() + + + pred_fake_OCR = self.netOCR(self.fake) + preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * batch_size).detach() + loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach()) + self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)]) + + self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD + + self.loss_T = self.loss_G + self.loss_OCR_fake + + + + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0] + + + self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2) + grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0] + self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2) + + + self.loss_T.backward(retain_graph=True) + + + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0] + + + a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR)) + + + if a is None: + print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_OCR)) + if a>1000 or a<0.0001: + print(a) + + + self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + + self.loss_T = self.loss_G + self.loss_OCR_fake + + + self.loss_T.backward(retain_graph=True) + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0] + self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2) + self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) + + with torch.no_grad(): + self.loss_T.backward() + + if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G): + print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words) + sys.exit() + + def backward_G_WL(self): + + self.gb_alpha = 0.7 + #self.Lcycle1 = self.Lcycle1.mean() + #self.Lcycle2 = self.Lcycle2.mean() + + self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean() + + self.loss_w_fake = self.netW(self.fake, self.input['wcl'].to(DEVICE)).mean() + + self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD + + self.loss_T = self.loss_G + self.loss_w_fake + + + + + #grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, retain_graph=True)[0] + + + #self.loss_grad_fake_WL = 10**6*torch.mean(grad_fake_WL**2) + #grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0] + #self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2) + + + + self.loss_T.backward(retain_graph=True) + + + grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0] + + + a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_WL)) + + + + if a is None: + print(self.loss_w_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_WL)) + if a>1000 or a<0.0001: + print(a) + + self.loss_w_fake = a.detach() * self.loss_w_fake + + self.loss_T = self.loss_G + self.loss_w_fake + + self.loss_T.backward(retain_graph=True) + 
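# Editor's sketch (not part of the original patch): the block below rescales the
# auxiliary loss (writer-classifier here, OCR in backward_G_only above) so that its
# gradient magnitude on the fake image roughly matches the adversarial gradient:
# a = alpha * std(dL_adv/dfake) / (eps + std(dL_aux/dfake)), then a.detach() * L_aux.
# A standalone illustration of the same balancing idea, with hypothetical names:
import torch

def balance_aux_loss(loss_adv, loss_aux, fake, alpha=0.7, eps=1e-7):
    """Return loss_aux rescaled so its gradient on `fake` matches the adversarial scale."""
    g_adv = torch.autograd.grad(loss_adv, fake, retain_graph=True)[0]
    g_aux = torch.autograd.grad(loss_aux, fake, retain_graph=True)[0]
    a = alpha * g_adv.std() / (eps + g_aux.std())
    return a.detach() * loss_aux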
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=False, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0] + self.loss_grad_fake_WL = 10 ** 6 * torch.mean(grad_fake_WL ** 2) + self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) + + with torch.no_grad(): + self.loss_T.backward() + + def backward_G(self): + self.opt.gb_alpha = 0.7 + self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(), self.opt.mask_loss) + # OCR loss on real data + + pred_fake_OCR = self.netOCR(self.fake) + preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.opt.batch_size).detach() + loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach()) + self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)]) + + + self.loss_w_fake = self.netW(self.fake, self.wcl) + #self.loss_OCR_fake = self.loss_OCR_fake + self.loss_w_fake + # total loss + + # l1 = self.params[0]*self.loss_G + # l2 = self.params[0]*self.loss_OCR_fake + #l3 = self.params[0]*self.loss_w_fake + self.loss_G_ = 10*self.loss_G + self.loss_w_fake + self.loss_T = self.loss_G_ + self.loss_OCR_fake + + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0] + + + self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2) + grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, retain_graph=True)[0] + self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2) + + if not False: + + self.loss_T.backward(retain_graph=True) + + + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=True, retain_graph=True)[0] + #grad_fake_wl = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0] + + + a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR)) + + + #a0 = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_wl)) + + if a is None: + print(self.loss_OCR_fake, self.loss_G_, torch.std(grad_fake_adv), torch.std(grad_fake_OCR)) + if a>1000 or a<0.0001: + print(a) + b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) - + torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))* + torch.mean(grad_fake_OCR)) + # self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake) + self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + #self.loss_w_fake = a0.detach() * self.loss_w_fake + + self.loss_T = (1-1*self.opt.onlyOCR)*self.loss_G_ + self.loss_OCR_fake# + self.loss_w_fake + self.loss_T.backward(retain_graph=True) + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=False, retain_graph=True)[0] + self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2) + self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) + with torch.no_grad(): + self.loss_T.backward() + else: + self.loss_T.backward() + + if self.opt.clip_grad > 0: + clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad) + if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G_): + print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words) + sys.exit() + + + + def 
optimize_D_OCR(self): + self.forward() + self.set_requires_grad([self.netD], True) + self.set_requires_grad([self.netOCR], True) + self.optimizer_D.zero_grad() + #if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']: + self.optimizer_OCR.zero_grad() + self.backward_D_OCR() + + def optimize_OCR(self): + self.forward() + self.set_requires_grad([self.netD], False) + self.set_requires_grad([self.netOCR], True) + if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']: + self.optimizer_OCR.zero_grad() + self.backward_OCR() + + def optimize_D(self): + self.forward() + self.set_requires_grad([self.netD], True) + self.backward_D() + + def optimize_D_OCR_step(self): + self.optimizer_D.step() + + self.optimizer_OCR.step() + self.optimizer_D.zero_grad() + self.optimizer_OCR.zero_grad() + + + def optimize_D_OCR_WL(self): + self.forward() + self.set_requires_grad([self.netD], True) + self.set_requires_grad([self.netOCR], True) + self.set_requires_grad([self.netW], True) + self.optimizer_D.zero_grad() + self.optimizer_wl.zero_grad() + if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']: + self.optimizer_OCR.zero_grad() + self.backward_D_OCR_WL() + + def optimize_D_OCR_WL_step(self): + self.optimizer_D.step() + if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']: + self.optimizer_OCR.step() + self.optimizer_wl.step() + self.optimizer_D.zero_grad() + self.optimizer_OCR.zero_grad() + self.optimizer_wl.zero_grad() + + def optimize_D_step(self): + self.optimizer_D.step() + if any(torch.isnan(self.netD.infer_img.blocks[0][0].conv1.bias)): + print('D is nan') + sys.exit() + self.optimizer_D.zero_grad() + + def optimize_G(self): + self.forward() + self.set_requires_grad([self.netD], False) + self.set_requires_grad([self.netOCR], False) + self.set_requires_grad([self.netW], False) + self.backward_G() + + def optimize_G_WL(self): + self.forward() + self.set_requires_grad([self.netD], False) + self.set_requires_grad([self.netOCR], False) + self.set_requires_grad([self.netW], False) + self.backward_G_WL() + + + def optimize_G_only(self): + self.forward() + self.set_requires_grad([self.netD], False) + self.set_requires_grad([self.netOCR], False) + self.set_requires_grad([self.netW], False) + self.backward_G_only() + + + def optimize_G_step(self): + + self.optimizer_G.step() + self.optimizer_G.zero_grad() + + def optimize_ocr(self): + self.set_requires_grad([self.netOCR], True) + # OCR loss on real data + pred_real_OCR = self.netOCR(self.real) + preds_size =torch.IntTensor([pred_real_OCR.size(0)] * self.opt.batch_size).detach() + self.loss_OCR_real = self.OCR_criterion(pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach()) + self.loss_OCR_real.backward() + self.optimizer_OCR.step() + + def optimize_z(self): + self.set_requires_grad([self.z], True) + + + def optimize_parameters(self): + self.forward() + self.set_requires_grad([self.netD], False) + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() + + self.set_requires_grad([self.netD], True) + self.optimizer_D.zero_grad() + self.backward_D() + self.optimizer_D.step() + + def test(self): + self.visual_names = ['fake'] + self.netG.eval() + with torch.no_grad(): + self.forward() + + def train_GD(self): + self.netG.train() + self.netD.train() + self.optimizer_G.zero_grad() + self.optimizer_D.zero_grad() + # How many chunks to split x and y into? 
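# Editor's note (not part of the original patch): torch.split below yields views of
# size opt.batch_size along dim 0 (the last chunk may be smaller), one chunk per
# critic step of the loop that follows. A minimal illustration with made-up shapes:
import torch

images = torch.randn(8, 1, 32, 128)   # e.g. a macro-batch of 8 word images
chunks = torch.split(images, 4)        # -> two views of shape (4, 1, 32, 128)
assert [c.shape[0] for c in chunks] == [4, 4]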
+ x = torch.split(self.real, self.opt.batch_size) + y = torch.split(self.label, self.opt.batch_size) + counter = 0 + + # Optionally toggle D and G's "require_grad" + if self.opt.toggle_grads: + toggle_grad(self.netD, True) + toggle_grad(self.netG, False) + + for step_index in range(self.opt.num_critic_train): + self.optimizer_D.zero_grad() + with torch.set_grad_enabled(False): + self.forward() + D_input = torch.cat([self.fake, x[counter]], 0) if x is not None else self.fake + D_class = torch.cat([self.label_fake, y[counter]], 0) if y[counter] is not None else y[counter] + # Get Discriminator output + D_out = self.netD(D_input, D_class) + if x is not None: + pred_fake, pred_real = torch.split(D_out, [self.fake.shape[0], x[counter].shape[0]]) # D_fake, D_real + else: + pred_fake = D_out + # Combined loss + self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss) + self.loss_D = self.loss_Dreal + self.loss_Dfake + self.loss_D.backward() + counter += 1 + self.optimizer_D.step() + + # Optionally toggle D and G's "require_grad" + if self.opt.toggle_grads: + toggle_grad(self.netD, False) + toggle_grad(self.netG, True) + # Zero G's gradients by default before training G, for safety + self.optimizer_G.zero_grad() + self.forward() + self.loss_G = loss_hinge_gen(self.netD(self.fake, self.label_fake), self.len_text_fake.detach(), self.opt.mask_loss) + self.loss_G.backward() + self.optimizer_G.step() + + + + + + + + + + + + + + + + + diff --git a/util/models/model_.py b/util/models/model_.py new file mode 100644 index 0000000000000000000000000000000000000000..2c1c24d7d0fdfca7d68651dfeba9ddd92ba92b90 --- /dev/null +++ b/util/models/model_.py @@ -0,0 +1,1264 @@ +import torch +import pandas as pd +from .OCR_network import * +from torch.nn import CTCLoss, MSELoss, L1Loss +from torch.nn.utils import clip_grad_norm_ +import random +import unicodedata +import sys +import torchvision.models as models +from models.transformer import * +from .BigGAN_networks import * +from params import * +from .OCR_network import * +from models.blocks import LinearBlock, Conv2dBlock, ResBlocks, ActFirstResBlock +from util.util import toggle_grad, loss_hinge_dis, loss_hinge_gen, ortho, default_ortho, toggle_grad, prepare_z_y, \ + make_one_hot, to_device, multiple_replace, random_word +from models.inception import InceptionV3, calculate_frechet_distance + +class FCNDecoder(nn.Module): + def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'): + super(FCNDecoder, self).__init__() + + self.model = [] + self.model += [ResBlocks(n_res, dim, res_norm, + activ, pad_type=pad_type)] + for i in range(ups): + self.model += [nn.Upsample(scale_factor=2), + Conv2dBlock(dim, dim // 2, 5, 1, 2, + norm='in', + activation=activ, + pad_type=pad_type)] + dim //= 2 + self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3, + norm='none', + activation='tanh', + pad_type=pad_type)] + self.model = nn.Sequential(*self.model) + + def forward(self, x): + y = self.model(x) + + return y + + + +class Generator(nn.Module): + + def __init__(self): + super(Generator, self).__init__() + + INP_CHANNEL = NUM_EXAMPLES + if IS_SEQ: INP_CHANNEL = 1 + + + encoder_layer = TransformerEncoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD, + TN_DROPOUT, "relu", True) + encoder_norm = nn.LayerNorm(TN_HIDDEN_DIM) if True else None + self.encoder = TransformerEncoder(encoder_layer, TN_ENC_LAYERS, encoder_norm) + + decoder_layer = 
TransformerDecoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD, + TN_DROPOUT, "relu", True) + decoder_norm = nn.LayerNorm(TN_HIDDEN_DIM) + self.decoder = TransformerDecoder(decoder_layer, TN_DEC_LAYERS, decoder_norm, + return_intermediate=True) + + self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(INP_CHANNEL, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2])) + + self.query_embed = nn.Embedding(VOCAB_SIZE, TN_HIDDEN_DIM) + + + self.linear_q = nn.Linear(TN_DIM_FEEDFORWARD*2, TN_DIM_FEEDFORWARD*8) + + self.DEC = FCNDecoder(res_norm = 'in') + + + self._muE = nn.Linear(512,512) + self._logvarE = nn.Linear(512,512) + + self._muD = nn.Linear(512,512) + self._logvarD = nn.Linear(512,512) + + + self.l1loss = nn.L1Loss() + + self.noise = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0])) + + + + def reparameterize(self, mu, logvar): + + mu = torch.unbind(mu , 1) + logvar = torch.unbind(logvar , 1) + + outs = [] + + for m,l in zip(mu, logvar): + + sigma = torch.exp(l) + eps = torch.cuda.FloatTensor(l.size()[0],1).normal_(0,1) + eps = eps.expand(sigma.size()) + + out = m + sigma*eps + + outs.append(out) + + + return torch.stack(outs, 1) + + + def Eval(self, ST, QRS): + + if IS_SEQ: + B, N, R, C = ST.shape + FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C)) + FEAT_ST = FEAT_ST.view(B, 512, 1, -1) + else: + FEAT_ST = self.Feat_Encoder(ST) + + + FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1) + + memory = self.encoder(FEAT_ST_ENC) + + if IS_KLD: + + Ex = memory.permute(1,0,2) + + memory_mu = self._muE(Ex) + memory_logvar = self._logvarE(Ex) + + memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2) + + + OUT_IMGS = [] + + for i in range(QRS.shape[1]): + + QR = QRS[:, i, :] + + QR_EMB = self.query_embed.weight[QR].permute(1,0,2) + + tgt = torch.zeros_like(QR_EMB) + + hs = self.decoder(tgt, memory, query_pos=QR_EMB) + + if IS_KLD: + + Dx = hs[0].permute(1,0,2) + + hs_mu = self._muD(Dx) + hs_logvar = self._logvarD(Dx) + + hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0) + + + h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1) + if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE) + + h = self.linear_q(h) + h = h.contiguous() + + h = h.view(h.size(0), h.shape[1]*2, 4, -1) + h = h.permute(0, 3, 2, 1) + + h = self.DEC(h) + + + OUT_IMGS.append(h.detach()) + + + + return OUT_IMGS + + + + + + + def forward(self, ST, QR, QRs = None, mode = 'train'): + + if IS_SEQ: + B, N, R, C = ST.shape + FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C)) + FEAT_ST = FEAT_ST.view(B, 512, 1, -1) + else: + FEAT_ST = self.Feat_Encoder(ST) + + + FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1) + + memory = self.encoder(FEAT_ST_ENC) + + if IS_KLD: + + Ex = memory.permute(1,0,2) + + memory_mu = self._muE(Ex) + memory_logvar = self._logvarE(Ex) + + memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2) + + + QR_EMB = self.query_embed.weight.repeat(batch_size,1,1).permute(1,0,2) + + tgt = torch.zeros_like(QR_EMB) + + hs = self.decoder(tgt, memory, query_pos=QR_EMB) + + + + if IS_KLD: + + Dx = hs[0].permute(1,0,2) + + hs_mu = self._muD(Dx) + hs_logvar = self._logvarD(Dx) + + hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0) + + OUT_Feats1_mu = [hs_mu] + OUT_Feats1_logvar = [hs_logvar] + + + OUT_Feats1 = [hs] + + + h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1) + + if ADD_NOISE: h = h + 
self.noise.sample(h.size()).squeeze(-1).to(DEVICE) + + h = self.linear_q(h) + + h = h.contiguous() + + h = [torch.stack([h[i][QR[i]] for i in range(batch_size)], 0) for QR in QRs] + + h_list = [] + + for h_ in h: + + h_ = h_.view(h_.size(0), h_.shape[1]*2, 4, -1) + h_ = h_.permute(0, 3, 2, 1) + + #h_ = self.DEC(h_) + + h_list.append(h_) + + if mode == 'test' or (not IS_CYCLE and not IS_KLD): + + return h + + + OUT_IMGS = [h] + + for QR in QRs: + + QR_EMB = self.query_embed.weight[QR].permute(1,0,2) + + tgt = torch.zeros_like(QR_EMB) + + hs = self.decoder(tgt, memory, query_pos=QR_EMB) + + + if IS_KLD: + + Dx = hs[0].permute(1,0,2) + + hs_mu = self._muD(Dx) + hs_logvar = self._logvarD(Dx) + + hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0) + + OUT_Feats1_mu.append(hs_mu) + OUT_Feats1_logvar.append(hs_logvar) + + + OUT_Feats1.append(hs) + + + h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1) + if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE) + + h = self.linear_q(h) + h = h.contiguous() + + h = h.view(h.size(0), h.shape[1]*2, 4, -1) + h = h.permute(0, 3, 2, 1) + + h = self.DEC(h) + + OUT_IMGS.append(h) + + + if (not IS_CYCLE) and IS_KLD: + + OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0] + + OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1); + + + KLD = (0.5 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \ + + (0.5 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp())) + + + + def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar): + return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \ + torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum() + + + lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])] + + + lda1 = torch.stack(lda1).mean() + + + + return OUT_IMGS[0], lda1, KLD + + + with torch.no_grad(): + + if IS_SEQ: + + FEAT_ST_T = torch.cat([self.Feat_Encoder(IM) for IM in OUT_IMGS], -1) + + else: + + max_width_ = max([i_.shape[-1] for i_ in OUT_IMGS]) + + FEAT_ST_T = self.Feat_Encoder(torch.cat([torch.cat([i_, torch.ones((i_.shape[0], i_.shape[1],i_.shape[2], max_width_-i_.shape[3])).to(DEVICE)], -1) for i_ in OUT_IMGS], 1)) + + FEAT_ST_ENC_T = FEAT_ST_T.flatten(2).permute(2,0,1) + + memory_T = self.encoder(FEAT_ST_ENC_T) + + if IS_KLD: + + Ex = memory_T.permute(1,0,2) + + memory_T_mu = self._muE(Ex) + memory_T_logvar = self._logvarE(Ex) + + memory_T = self.reparameterize(memory_T_mu, memory_T_logvar).permute(1,0,2) + + + QR_EMB = self.query_embed.weight[QR].permute(1,0,2) + + tgt = torch.zeros_like(QR_EMB) + + hs = self.decoder(tgt, memory_T, query_pos=QR_EMB) + + if IS_KLD: + + Dx = hs[0].permute(1,0,2) + + hs_mu = self._muD(Dx) + hs_logvar = self._logvarD(Dx) + + hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0) + + OUT_Feats2_mu = [hs_mu] + OUT_Feats2_logvar = [hs_logvar] + + + OUT_Feats2 = [hs] + + + + for QR in QRs: + + QR_EMB = self.query_embed.weight[QR].permute(1,0,2) + + tgt = torch.zeros_like(QR_EMB) + + hs = self.decoder(tgt, memory_T, query_pos=QR_EMB) + + if IS_KLD: + + Dx = hs[0].permute(1,0,2) + + hs_mu = self._muD(Dx) + hs_logvar = self._logvarD(Dx) + + hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0) + + OUT_Feats2_mu.append(hs_mu) + OUT_Feats2_logvar.append(hs_logvar) + + + OUT_Feats2.append(hs) + + + 
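# Editor's sketch (not part of the original patch): the block below averages an L1
# distance over all position pairs between the source memory and the reconstructed
# memory (Lcycle1), and likewise between the two decoder feature sets (Lcycle2).
# An equivalent, more compact pure-torch form with a hypothetical helper name:
import torch

def mean_pairwise_l1(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Mean L1 distance over all (a[i], b[j]) pairs along dim 0."""
    l1 = torch.nn.L1Loss()
    pairs = [l1(a[i], b[j]) for i in range(a.shape[0]) for j in range(b.shape[0])]
    return torch.stack(pairs).mean()

# The _get_lda term further below matches the closed-form 2-Wasserstein distance
# between diagonal Gaussians: W2(N(mu1, s1^2), N(mu2, s2^2))^2 = ||mu1 - mu2||^2 + ||s1 - s2||^2,
# with s = sqrt(exp(logvar)).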
+ + Lcycle1 = np.sum([self.l1loss(memory[m_i], memory_T[m_j]) for m_i in range(memory.shape[0]) for m_j in range(memory_T.shape[0])])/(memory.shape[0]*memory_T.shape[0]) + OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0]; OUT_Feats2 = torch.cat(OUT_Feats2, 1)[0] + + Lcycle2 = np.sum([self.l1loss(OUT_Feats1[f_i], OUT_Feats2[f_j]) for f_i in range(OUT_Feats1.shape[0]) for f_j in range(OUT_Feats2.shape[0])])/(OUT_Feats1.shape[0]*OUT_Feats2.shape[0]) + + if IS_KLD: + + OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1); + OUT_Feats2_mu = torch.cat(OUT_Feats2_mu, 1); OUT_Feats2_logvar = torch.cat(OUT_Feats2_logvar, 1); + + KLD = (0.25 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \ + + (0.25 * torch.mean(1 + memory_T_logvar - memory_T_mu.pow(2) - memory_T_logvar.exp()))\ + + (0.25 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp()))\ + + (0.25 * torch.mean(1 + OUT_Feats2_logvar - OUT_Feats2_mu.pow(2) - OUT_Feats2_logvar.exp())) + + + def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar): + return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \ + torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum() + + + lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])] + lda2 = [_get_lda(memory_T_mu[:,idi,:], OUT_Feats2_mu[:,idj,:], memory_T_logvar[:,idi,:], OUT_Feats2_logvar[:,idj,:]) for idi in range(memory_T.shape[0]) for idj in range(OUT_Feats2.shape[0])] + + lda1 = torch.stack(lda1).mean() + lda2 = torch.stack(lda2).mean() + + + return OUT_IMGS[0], Lcycle1, Lcycle2, lda1, lda2, KLD + + + return OUT_IMGS[0], Lcycle1, Lcycle2 + + + +class TRGAN(nn.Module): + + def __init__(self): + super(TRGAN, self).__init__() + + + self.epsilon = 1e-7 + self.netG = Generator().to(DEVICE) + self.netD = nn.DataParallel(Discriminator()).to(DEVICE) + self.netW = nn.DataParallel(WDiscriminator()).to(DEVICE) + self.netconverter = strLabelConverter(ALPHABET) + self.netOCR = CRNN().to(DEVICE) + self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none') + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] + self.inception = InceptionV3([block_idx]).to(DEVICE) + + + self.optimizer_G = torch.optim.Adam(self.netG.parameters(), + lr=G_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8) + self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(), + lr=OCR_LR, betas=(0.0, 0.999), weight_decay=0, + eps=1e-8) + + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), + lr=D_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8) + + + self.optimizer_wl = torch.optim.Adam(self.netW.parameters(), + lr=W_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8) + + + self.optimizers = [self.optimizer_G, self.optimizer_OCR, self.optimizer_D, self.optimizer_wl] + + + self.optimizer_G.zero_grad() + self.optimizer_OCR.zero_grad() + self.optimizer_D.zero_grad() + self.optimizer_wl.zero_grad() + + self.loss_G = 0 + self.loss_D = 0 + self.loss_Dfake = 0 + self.loss_Dreal = 0 + self.loss_OCR_fake = 0 + self.loss_OCR_real = 0 + self.loss_w_fake = 0 + self.loss_w_real = 0 + self.Lcycle1 = 0 + self.Lcycle2 = 0 + self.lda1 = 0 + self.lda2 = 0 + self.KLD = 0 + + + with open('../Lexicon/english_words.txt', 'rb') as f: + self.lex = f.read().splitlines() + lex=[] + for word in self.lex: + try: + word=word.decode("utf-8") + except: + continue + if len(word)<20: + lex.append(word) + 
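# Editor's note (not part of the original patch): the loop above keeps only lexicon
# entries that decode as UTF-8 and are shorter than 20 characters. An equivalent,
# self-contained sketch that narrows the bare `except` to UnicodeDecodeError:
def load_lexicon(path: str, max_len: int = 20) -> list:
    """Return UTF-8 decodable words shorter than max_len from a one-word-per-line file."""
    with open(path, 'rb') as fh:
        raw = fh.read().splitlines()
    words = []
    for w in raw:
        try:
            word = w.decode('utf-8')
        except UnicodeDecodeError:
            continue
        if len(word) < max_len:
            words.append(word)
    return words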
self.lex = lex + + + f = open('mytext.txt', 'r') + + self.text = [j.encode() for j in sum([i.split(' ') for i in f.readlines()], [])][:NUM_EXAMPLES] + self.eval_text_encode, self.eval_len_text = self.netconverter.encode(self.text) + self.eval_text_encode = self.eval_text_encode.to(DEVICE).repeat(batch_size, 1, 1) + + + def _generate_page(self): + + self.fakes = self.netG.Eval(self.sdata, self.eval_text_encode) + + word_t = [] + word_l = [] + + gap = np.ones([32,16]) + + line_wids = [] + + + for idx, fake_ in enumerate(self.fakes): + + word_t.append((fake_[0,0,:,:self.eval_len_text[idx]*resolution].cpu().numpy()+1)/2) + + word_t.append(gap) + + if len(word_t) == 16 or idx == len(self.fakes) - 1: + + line_ = np.concatenate(word_t, -1) + + word_l.append(line_) + line_wids.append(line_.shape[1]) + + word_t = [] + + + gap_h = np.ones([16,max(line_wids)]) + + page_= [] + + for l in word_l: + + pad_ = np.ones([32,max(line_wids) - l.shape[1]]) + + page_.append(np.concatenate([l, pad_], 1)) + page_.append(gap_h) + + + + page1 = np.concatenate(page_, 0) + + + word_t = [] + word_l = [] + + gap = np.ones([32,16]) + + line_wids = [] + + sdata_ = [i.unsqueeze(1) for i in torch.unbind(self.sdata, 1)] + + for idx, st in enumerate((sdata_)): + + word_t.append((st[0,0,:,:int(self.input['swids'].cpu().numpy()[0][idx]) +].cpu().numpy()+1)/2) + + word_t.append(gap) + + if len(word_t) == 16 or idx == len(self.fakes) - 1: + + line_ = np.concatenate(word_t, -1) + + word_l.append(line_) + line_wids.append(line_.shape[1]) + + word_t = [] + + + gap_h = np.ones([16,max(line_wids)]) + + page_= [] + + for l in word_l: + + pad_ = np.ones([32,max(line_wids) - l.shape[1]]) + + page_.append(np.concatenate([l, pad_], 1)) + page_.append(gap_h) + + + + page2 = np.concatenate(page_, 0) + + merge_w_size = max(page1.shape[0], page2.shape[0]) + + if page1.shape[0] != merge_w_size: + + page1 = np.concatenate([page1, np.ones([merge_w_size-page1.shape[0], page1.shape[1]])], 0) + + if page2.shape[0] != merge_w_size: + + page2 = np.concatenate([page2, np.ones([merge_w_size-page2.shape[0], page2.shape[1]])], 0) + + + page = np.concatenate([page2, page1], 1) + + + return page + + + + + + + + + + + + + + + + + + + + #FEAT1 = self.inception(torch.cat(self.fakes, 0).repeat(1,3,1,1))[0].detach().view(batch_size, len(self.fakes), -1).cpu().numpy() + #FEAT2 = self.inception(self.sdata.view(batch_size*NUM_EXAMPLES, 1, 32, -1).repeat(1,3,1,1))[0].detach().view(batch_size, NUM_EXAMPLES, -1 ).cpu().numpy() + #muvars1 = [{'mu':np.mean(FEAT1[i], axis=0), 'sigma' : np.cov(FEAT1[i], rowvar=False)} for i in range(FEAT1.shape[0])] + #muvars2 = [{'mu':np.mean(FEAT2[i], axis=0), 'sigma' : np.cov(FEAT2[i], rowvar=False)} for i in range(FEAT2.shape[0])] + + + + + + + def get_current_losses(self): + + losses = {} + + losses['G'] = self.loss_G + losses['D'] = self.loss_D + losses['Dfake'] = self.loss_Dfake + losses['Dreal'] = self.loss_Dreal + losses['OCR_fake'] = self.loss_OCR_fake + losses['OCR_real'] = self.loss_OCR_real + losses['w_fake'] = self.loss_w_fake + losses['w_real'] = self.loss_w_real + losses['cycle1'] = self.Lcycle1 + losses['cycle2'] = self.Lcycle2 + losses['lda1'] = self.lda1 + losses['lda2'] = self.lda2 + losses['KLD'] = self.KLD + + return losses + + def visualize_images(self): + + imgs = {} + + + imgs['fake-1']=self.netG(self.sdata[0:1], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach() + imgs['fake-2']=self.netG(self.sdata[0:1], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach() + 
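# Editor's sketch (not part of the original patch): _generate_page above assembles up
# to 16 word crops per line (separated by white gaps) and then right-pads every line
# with ones to the widest line before stacking them vertically. A minimal standalone
# illustration of that padding step, with hypothetical names:
import numpy as np

def pad_lines_to_width(lines):
    """Right-pad each HxW line image with white (ones) so all lines share the max width."""
    max_w = max(l.shape[1] for l in lines)
    return [np.concatenate([l, np.ones((l.shape[0], max_w - l.shape[1]))], axis=1)
            for l in lines]

# page = np.concatenate(pad_lines_to_width(word_lines), axis=0)  # stack padded lines vertically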
imgs['fake-3']=self.netG(self.sdata[0:1], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach() + + + imgs['res-1'] = torch.cat([self.sdata[0, 0],self.sdata[0, 1],self.sdata[0, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1) + + + imgs['fake-1']=self.netG(self.sdata[1:2], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach() + imgs['fake-2']=self.netG(self.sdata[1:2], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach() + imgs['fake-3']=self.netG(self.sdata[1:2], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach() + + + imgs['res-2'] = torch.cat([self.sdata[1, 0],self.sdata[1, 1],self.sdata[1, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1) + + + imgs['fake-1']=self.netG(self.sdata[2:3], self.text_encode_fake[0].unsqueeze(0) , mode = 'test' )[0, 0].detach() + imgs['fake-2']=self.netG(self.sdata[2:3], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach() + imgs['fake-3']=self.netG(self.sdata[2:3], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach() + + + imgs['res-3'] = torch.cat([self.sdata[2, 0],self.sdata[2, 1],self.sdata[2, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1) + + + + + return imgs + + + def load_networks(self, epoch): + BaseModel.load_networks(self, epoch) + if self.opt.single_writer: + load_filename = '%s_z.pkl' % (epoch) + load_path = os.path.join(self.save_dir, load_filename) + self.z = torch.load(load_path) + + def _set_input(self, input): + self.input = input + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + def forward(self): + + + self.real = self.input['img'].to(DEVICE) + self.label = self.input['label'] + self.sdata = self.input['simg'].to(DEVICE) + self.ST_LEN = self.input['swids'] + self.text_encode, self.len_text = self.netconverter.encode(self.label) + self.one_hot_real = make_one_hot(self.text_encode, self.len_text, VOCAB_SIZE).to(DEVICE).detach() + self.text_encode = self.text_encode.to(DEVICE).detach() + self.len_text = self.len_text.detach() + + self.words = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)] + self.text_encode_fake, self.len_text_fake = self.netconverter.encode(self.words) + self.text_encode_fake = self.text_encode_fake.to(DEVICE) + self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, VOCAB_SIZE).to(DEVICE) + + + self.text_encode_fake_js = [] + + for _ in range(NUM_WORDS - 1): + + self.words_j = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)] + self.text_encode_fake_j, self.len_text_fake_j = self.netconverter.encode(self.words_j) + self.text_encode_fake_j = self.text_encode_fake_j.to(DEVICE) + self.text_encode_fake_js.append(self.text_encode_fake_j) + + + if IS_CYCLE and IS_KLD: + + self.fake, self.Lcycle1, self.Lcycle2, self.lda1, self.lda2, self.KLD = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js) + + elif IS_CYCLE and (not IS_KLD): + + self.fake, self.Lcycle1, self.Lcycle2 = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js) + + elif (not IS_CYCLE) and IS_KLD: + + self.fake, self.lda1, self.KLD = 
self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js) + + else: + + self.fake = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js) + + + + + def backward_D_OCR(self): + + pred_real = self.netD(self.real.detach()) + + pred_fake = self.netD(**{'x': self.fake.detach()}) + + + self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True) + + self.loss_D = self.loss_Dreal + self.loss_Dfake + + self.pred_real_OCR = self.netOCR(self.real.detach()) + preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * batch_size).detach() + loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach()) + self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)]) + + loss_total = self.loss_D + self.loss_OCR_real + + # backward + loss_total.backward() + for param in self.netOCR.parameters(): + param.grad[param.grad!=param.grad]=0 + param.grad[torch.isnan(param.grad)]=0 + param.grad[torch.isinf(param.grad)]=0 + + + + return loss_total + + def backward_D_WL(self): + # Real + pred_real = self.netD(self.real.detach()) + + pred_fake = self.netD(**{'x': self.fake.detach()}) + + + self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True) + + self.loss_D = self.loss_Dreal + self.loss_Dfake + + + self.loss_w_real = self.netW(self.real.detach(), self.input['wcl'].to(DEVICE)).mean() + # total loss + loss_total = self.loss_D + self.loss_w_real + + # backward + loss_total.backward() + + + return loss_total + + def optimize_D_WL(self): + self.forward() + self.set_requires_grad([self.netD], True) + self.set_requires_grad([self.netOCR], False) + self.set_requires_grad([self.netW], True) + + self.optimizer_D.zero_grad() + self.optimizer_wl.zero_grad() + + self.backward_D_WL() + + + + + def backward_D_OCR_WL(self): + # Real + if self.real_z_mean is None: + pred_real = self.netD(self.real.detach()) + else: + pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()}) + # Fake + try: + pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()}) + except: + print('a') + # Combined loss + self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss) + + self.loss_D = self.loss_Dreal + self.loss_Dfake + # OCR loss on real data + self.pred_real_OCR = self.netOCR(self.real.detach()) + preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach() + loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach()) + self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)]) + # total loss + self.loss_w_real = self.netW(self.real.detach(), self.wcl) + loss_total = self.loss_D + self.loss_OCR_real + self.loss_w_real + + # backward + loss_total.backward() + for param in self.netOCR.parameters(): + param.grad[param.grad!=param.grad]=0 + param.grad[torch.isnan(param.grad)]=0 + param.grad[torch.isinf(param.grad)]=0 + + + + return loss_total + + def optimize_D_WL_step(self): + self.optimizer_D.step() + self.optimizer_wl.step() + self.optimizer_D.zero_grad() + self.optimizer_wl.zero_grad() + + def backward_OCR(self): + # OCR loss on real data + self.pred_real_OCR = self.netOCR(self.real.detach()) + preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * 
self.opt.batch_size).detach() + loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach()) + self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)]) + + # backward + self.loss_OCR_real.backward() + for param in self.netOCR.parameters(): + param.grad[param.grad!=param.grad]=0 + param.grad[torch.isnan(param.grad)]=0 + param.grad[torch.isinf(param.grad)]=0 + + return self.loss_OCR_real + + + def backward_D(self): + # Real + if self.real_z_mean is None: + pred_real = self.netD(self.real.detach()) + else: + pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()}) + pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()}) + # Combined loss + self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss) + self.loss_D = self.loss_Dreal + self.loss_Dfake + # backward + self.loss_D.backward() + + + return self.loss_D + + + def backward_G_only(self): + + self.gb_alpha = 0.7 + #self.Lcycle1 = self.Lcycle1.mean() + #self.Lcycle2 = self.Lcycle2.mean() + self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean() + + + pred_fake_OCR = self.netOCR(self.fake) + preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * batch_size).detach() + loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach()) + self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)]) + + self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD + + self.loss_T = self.loss_G + self.loss_OCR_fake + + + + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0] + + + self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2) + grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0] + self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2) + + + self.loss_T.backward(retain_graph=True) + + + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0] + + + a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR)) + + + if a is None: + print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_OCR)) + if a>1000 or a<0.0001: + print(a) + + + self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + + self.loss_T = self.loss_G + self.loss_OCR_fake + + + self.loss_T.backward(retain_graph=True) + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0] + self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2) + self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) + + with torch.no_grad(): + self.loss_T.backward() + + if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G): + print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words) + sys.exit() + + def backward_G_WL(self): + + self.gb_alpha = 0.7 + #self.Lcycle1 = self.Lcycle1.mean() + #self.Lcycle2 = self.Lcycle2.mean() + + self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean() + + self.loss_w_fake = 
self.netW(self.fake, self.input['wcl'].to(DEVICE)).mean() + + self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD + + self.loss_T = self.loss_G + self.loss_w_fake + + + + + #grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, retain_graph=True)[0] + + + #self.loss_grad_fake_WL = 10**6*torch.mean(grad_fake_WL**2) + #grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0] + #self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2) + + + + self.loss_T.backward(retain_graph=True) + + + grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0] + + + a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_WL)) + + + + if a is None: + print(self.loss_w_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_WL)) + if a>1000 or a<0.0001: + print(a) + + self.loss_w_fake = a.detach() * self.loss_w_fake + + self.loss_T = self.loss_G + self.loss_w_fake + + self.loss_T.backward(retain_graph=True) + grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=False, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0] + self.loss_grad_fake_WL = 10 ** 6 * torch.mean(grad_fake_WL ** 2) + self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) + + with torch.no_grad(): + self.loss_T.backward() + + def backward_G(self): + self.opt.gb_alpha = 0.7 + self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(), self.opt.mask_loss) + # OCR loss on real data + + pred_fake_OCR = self.netOCR(self.fake) + preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.opt.batch_size).detach() + loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach()) + self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)]) + + + self.loss_w_fake = self.netW(self.fake, self.wcl) + #self.loss_OCR_fake = self.loss_OCR_fake + self.loss_w_fake + # total loss + + # l1 = self.params[0]*self.loss_G + # l2 = self.params[0]*self.loss_OCR_fake + #l3 = self.params[0]*self.loss_w_fake + self.loss_G_ = 10*self.loss_G + self.loss_w_fake + self.loss_T = self.loss_G_ + self.loss_OCR_fake + + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0] + + + self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2) + grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, retain_graph=True)[0] + self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2) + + if not False: + + self.loss_T.backward(retain_graph=True) + + + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=True, retain_graph=True)[0] + #grad_fake_wl = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0] + + + a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR)) + + + #a0 = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_wl)) + + if a is None: + print(self.loss_OCR_fake, self.loss_G_, torch.std(grad_fake_adv), torch.std(grad_fake_OCR)) + if a>1000 or a<0.0001: + print(a) + b = self.opt.gb_alpha * 
(torch.mean(grad_fake_adv) - + torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))* + torch.mean(grad_fake_OCR)) + # self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake) + self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + #self.loss_w_fake = a0.detach() * self.loss_w_fake + + self.loss_T = (1-1*self.opt.onlyOCR)*self.loss_G_ + self.loss_OCR_fake# + self.loss_w_fake + self.loss_T.backward(retain_graph=True) + grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0] + grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=False, retain_graph=True)[0] + self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2) + self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2) + with torch.no_grad(): + self.loss_T.backward() + else: + self.loss_T.backward() + + if self.opt.clip_grad > 0: + clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad) + if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G_): + print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words) + sys.exit() + + + + def optimize_D_OCR(self): + self.forward() + self.set_requires_grad([self.netD], True) + self.set_requires_grad([self.netOCR], True) + self.optimizer_D.zero_grad() + #if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']: + self.optimizer_OCR.zero_grad() + self.backward_D_OCR() + + def optimize_OCR(self): + self.forward() + self.set_requires_grad([self.netD], False) + self.set_requires_grad([self.netOCR], True) + if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']: + self.optimizer_OCR.zero_grad() + self.backward_OCR() + + def optimize_D(self): + self.forward() + self.set_requires_grad([self.netD], True) + self.backward_D() + + def optimize_D_OCR_step(self): + self.optimizer_D.step() + + self.optimizer_OCR.step() + self.optimizer_D.zero_grad() + self.optimizer_OCR.zero_grad() + + + def optimize_D_OCR_WL(self): + self.forward() + self.set_requires_grad([self.netD], True) + self.set_requires_grad([self.netOCR], True) + self.set_requires_grad([self.netW], True) + self.optimizer_D.zero_grad() + self.optimizer_wl.zero_grad() + if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']: + self.optimizer_OCR.zero_grad() + self.backward_D_OCR_WL() + + def optimize_D_OCR_WL_step(self): + self.optimizer_D.step() + if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']: + self.optimizer_OCR.step() + self.optimizer_wl.step() + self.optimizer_D.zero_grad() + self.optimizer_OCR.zero_grad() + self.optimizer_wl.zero_grad() + + def optimize_D_step(self): + self.optimizer_D.step() + if any(torch.isnan(self.netD.infer_img.blocks[0][0].conv1.bias)): + print('D is nan') + sys.exit() + self.optimizer_D.zero_grad() + + def optimize_G(self): + self.forward() + self.set_requires_grad([self.netD], False) + self.set_requires_grad([self.netOCR], False) + self.set_requires_grad([self.netW], False) + self.backward_G() + + def optimize_G_WL(self): + self.forward() + self.set_requires_grad([self.netD], False) + self.set_requires_grad([self.netOCR], False) + self.set_requires_grad([self.netW], False) + self.backward_G_WL() + + + def optimize_G_only(self): + self.forward() + self.set_requires_grad([self.netD], False) + self.set_requires_grad([self.netOCR], False) + self.set_requires_grad([self.netW], False) + self.backward_G_only() + + + def optimize_G_step(self): + + self.optimizer_G.step() + self.optimizer_G.zero_grad() + + def 
optimize_ocr(self): + self.set_requires_grad([self.netOCR], True) + # OCR loss on real data + pred_real_OCR = self.netOCR(self.real) + preds_size =torch.IntTensor([pred_real_OCR.size(0)] * self.opt.batch_size).detach() + self.loss_OCR_real = self.OCR_criterion(pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach()) + self.loss_OCR_real.backward() + self.optimizer_OCR.step() + + def optimize_z(self): + self.set_requires_grad([self.z], True) + + + def optimize_parameters(self): + self.forward() + self.set_requires_grad([self.netD], False) + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() + + self.set_requires_grad([self.netD], True) + self.optimizer_D.zero_grad() + self.backward_D() + self.optimizer_D.step() + + def test(self): + self.visual_names = ['fake'] + self.netG.eval() + with torch.no_grad(): + self.forward() + + def train_GD(self): + self.netG.train() + self.netD.train() + self.optimizer_G.zero_grad() + self.optimizer_D.zero_grad() + # How many chunks to split x and y into? + x = torch.split(self.real, self.opt.batch_size) + y = torch.split(self.label, self.opt.batch_size) + counter = 0 + + # Optionally toggle D and G's "require_grad" + if self.opt.toggle_grads: + toggle_grad(self.netD, True) + toggle_grad(self.netG, False) + + for step_index in range(self.opt.num_critic_train): + self.optimizer_D.zero_grad() + with torch.set_grad_enabled(False): + self.forward() + D_input = torch.cat([self.fake, x[counter]], 0) if x is not None else self.fake + D_class = torch.cat([self.label_fake, y[counter]], 0) if y[counter] is not None else y[counter] + # Get Discriminator output + D_out = self.netD(D_input, D_class) + if x is not None: + pred_fake, pred_real = torch.split(D_out, [self.fake.shape[0], x[counter].shape[0]]) # D_fake, D_real + else: + pred_fake = D_out + # Combined loss + self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss) + self.loss_D = self.loss_Dreal + self.loss_Dfake + self.loss_D.backward() + counter += 1 + self.optimizer_D.step() + + # Optionally toggle D and G's "require_grad" + if self.opt.toggle_grads: + toggle_grad(self.netD, False) + toggle_grad(self.netG, True) + # Zero G's gradients by default before training G, for safety + self.optimizer_G.zero_grad() + self.forward() + self.loss_G = loss_hinge_gen(self.netD(self.fake, self.label_fake), self.len_text_fake.detach(), self.opt.mask_loss) + self.loss_G.backward() + self.optimizer_G.step() + + + + + + + + + + + + + + + + + diff --git a/util/models/networks.py b/util/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..6f885c7a18958a4c61f278bc104fc0d34643dbbc --- /dev/null +++ b/util/models/networks.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.optim import lr_scheduler +from util.util import to_device, load_network + +############################################################################### +# Helper Functions +############################################################################### + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. 
+ + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if (isinstance(m, nn.Conv2d) + or isinstance(m, nn.Linear) + or isinstance(m, nn.Embedding)): + # if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'N02': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type in ['glorot', 'xavier']: + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'ortho': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + # if hasattr(m, 'bias') and m.bias is not None: + # init.constant_(m.bias.data, 0.0) + # elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + # init.normal_(m.weight.data, 1.0, init_gain) + # init.constant_(m.bias.data, 0.0) + if init_type in ['N02', 'glorot', 'xavier', 'kaiming', 'ortho']: + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + else: + print('loading the model from %s' % init_type) + net = load_network(net, init_type, 'latest') + return net + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. + """ + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + init_weights(net, init_type, init_gain=init_gain) + return net + + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For 'linear', we keep the same learning rate for the first epochs + and linearly decay the rate to zero over the next epochs. + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. 
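+
+    As a rough illustration of the 'linear' policy (assuming, for example,
+    epoch_count=1, niter=100 and niter_decay=100): lambda_rule returns a
+    multiplier of 1.0 up to epoch 99, roughly 0.5 around epoch 150, and close
+    to 0.0 by epoch 200.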
+ """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + diff --git a/util/models/sync_batchnorm/__init__.py b/util/models/sync_batchnorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8709d92c610b36e0bcbd7da20c1eb41dc8cfcf --- /dev/null +++ b/util/models/sync_batchnorm/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/util/models/sync_batchnorm/__pycache__/__init__.cpython-36.pyc b/util/models/sync_batchnorm/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..380d492cba466ac62c90691dad2b92f36bacc80f Binary files /dev/null and b/util/models/sync_batchnorm/__pycache__/__init__.cpython-36.pyc differ diff --git a/util/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc b/util/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8461cd52affd6473e92e41a2e0d2b1a4fd90e2d Binary files /dev/null and b/util/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc differ diff --git a/util/models/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc b/util/models/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3630ab462bf0ca66a35ab94b9df7f64a80a6c4d6 Binary files /dev/null and b/util/models/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc differ diff --git a/util/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc b/util/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..142044ed1c3f15b6839e526be075ba872e9f5c26 Binary files /dev/null and b/util/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc differ diff --git a/util/models/sync_batchnorm/__pycache__/comm.cpython-36.pyc b/util/models/sync_batchnorm/__pycache__/comm.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d68a3ba4417197e4528b428f544adba35e109a1b Binary files /dev/null and b/util/models/sync_batchnorm/__pycache__/comm.cpython-36.pyc differ diff --git a/util/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc b/util/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..569bd74b395e4b5cbad0caa1c723c4f0bc9cdd83 Binary files /dev/null and b/util/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc differ 
diff --git a/util/models/sync_batchnorm/__pycache__/replicate.cpython-36.pyc b/util/models/sync_batchnorm/__pycache__/replicate.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..300aec39e8a6ea3f395f7db2c447e71008a3362a Binary files /dev/null and b/util/models/sync_batchnorm/__pycache__/replicate.cpython-36.pyc differ diff --git a/util/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc b/util/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95c56f2c02fbba8427cbc74ac372c6db69be94a9 Binary files /dev/null and b/util/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc differ diff --git a/util/models/sync_batchnorm/batchnorm.py b/util/models/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..5453729bf6f4beb7c6412117656130360ed06164 --- /dev/null +++ b/util/models/sync_batchnorm/batchnorm.py @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) +# _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size']) + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input, gain=None, bias=None): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + out = F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + if gain is not None: + out = out + gain + if bias is not None: + out = out + bias + return out + + # Resize the input to (B, C, -1). + input_shape = input.size() + # print(input_shape) + input = input.view(input.size(0), input.size(1), -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + # Reduce-and-broadcast the statistics. 
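+        # The per-device sum and sum-of-squares are gathered on the master copy,
+        # converted into a global mean / inv_std, and broadcast back, so every
+        # replica normalizes with statistics of the whole multi-GPU batch rather
+        # than of its local slice.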
+ # print('it begins') + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + # if self._parallel_id == 0: + # # print('here') + # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + # else: + # # print('there') + # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # print('how2') + # num = sum_size + # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu()))) + # Fix the graph + # sum = (sum.detach() - input_sum.detach()) + input_sum + # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum + + # mean = sum / num + # var = ssum / num - mean ** 2 + # # var = (ssum - mean * sum) / num + # inv_std = torch.rsqrt(var + self.eps) + + # Compute the output. + if gain is not None: + # print('gaining') + # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1) + # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1) + # output = input * scale - shift + output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1) + elif self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + # print('a') + # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size) + # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device)) + # print('b') + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
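+        # sumvar = ssum - sum_ * mean equals sum((x - mean)^2); the unbiased
+        # variance (divided by size - 1) feeds the running statistics, while the
+        # biased variance (divided by size) is used to build inv_std for
+        # normalizing the current batch.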
+ mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + return mean, torch.rsqrt(bias_var + self.eps) + # return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. 
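+
+    In this package the synchronization only takes effect when the module is
+    replicated through `DataParallelWithCallback` (or a `DataParallel` patched
+    with `patch_replication_callback`, see replicate.py); otherwise the layer
+    behaves like an ordinary per-device BatchNorm.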
+ + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. 
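+
+    Note: unlike the stock layers, `forward` here also accepts optional `gain`
+    and `bias` tensors, so an externally computed affine transform (for example
+    a class-conditional one) can be applied inside the normalization.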
+ + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) \ No newline at end of file diff --git a/util/models/sync_batchnorm/batchnorm_reimpl.py b/util/models/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000000000000000000000000000000000000..7afcdaff9c56d7ac9c487f2dbe61fe6cb9c353a0 --- /dev/null +++ b/util/models/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNormReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. 
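+
+    Minimal usage sketch (illustrative only; expects NCHW input):
+
+        >>> bn = BatchNorm2dReimpl(64)
+        >>> out = bn(torch.randn(8, 64, 16, 16))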
+ + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/util/models/sync_batchnorm/comm.py b/util/models/sync_batchnorm/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b --- /dev/null +++ b/util/models/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. 
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/util/models/sync_batchnorm/replicate.py b/util/models/sync_batchnorm/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/util/models/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 
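+
+    (Within this package the hook is triggered by `DataParallelWithCallback.replicate`
+    and by the patched `replicate` installed by `patch_replication_callback` below.)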
+ + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/util/models/sync_batchnorm/unittest.py b/util/models/sync_batchnorm/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..bed56f1caa929ac3e9a57c583f8d3e42624f58be --- /dev/null +++ b/util/models/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
+ +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y), message) + diff --git a/util/models/transformer.py b/util/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2d887d54b2960f6ef63c4abba75c104f2c937daa --- /dev/null +++ b/util/models/transformer.py @@ -0,0 +1,296 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, query_embed, y_ind): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + + y_emb = query_embed[y_ind].permute(1,0,2) + + tgt = torch.zeros_like(y_emb) + memory = self.encoder(src) + hs = self.decoder(tgt, memory, query_pos=y_emb) + + return torch.cat([hs.transpose(1, 2)[-1], y_emb.permute(1,0,2)], -1) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + 
memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = 
nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/util/params.py b/util/params.py new file mode 100644 index 0000000000000000000000000000000000000000..4a66930cfd32fbaa842bcd1e7b0a3999f958766a --- /dev/null +++ b/util/params.py @@ -0,0 +1,44 @@ +import torch + 
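+
+# The names below are plain module-level constants and act as the global
+# experiment configuration (image size, transformer hyper-parameters, learning
+# rates, dataset paths, model-saving frequency, ...).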
+############################################### + +EXP_NAME = "IAM-1000"; RESUME = True + +############################################### + +IMG_HEIGHT = 32 +resolution = 16 +batch_size = 8 +NUM_EXAMPLES = 50 +TN_HIDDEN_DIM = 512 +TN_DROPOUT = 0.1 +TN_NHEADS = 8 +TN_DIM_FEEDFORWARD = 512 +TN_ENC_LAYERS = 1 +TN_DEC_LAYERS = 1 +ALPHABET = 'Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%' +VOCAB_SIZE = len(ALPHABET) +G_LR = 0.0002 +D_LR = 0.0002 +W_LR = 0.0002 +OCR_LR = 0.0002 +EPOCHS = 100000 +NUM_CRITIC_GOCR_TRAIN = 2 +NUM_CRITIC_DOCR_TRAIN = 1 +NUM_CRITIC_GWL_TRAIN = 2 +NUM_CRITIC_DWL_TRAIN = 1 +NUM_FID_FREQ = 100 +DATASET = ['IAM'] +DATASET_PATHS = {'IAM':'../IAM_32.pickle'} +NUM_WRITERS = 1000 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +IS_SEQ = True +NUM_WORDS = 3 +if not IS_SEQ: NUM_WORDS = NUM_EXAMPLES +IS_CC = False +IS_KLD = False +ADD_NOISE = False +ALL_CHARS = False +SAVE_MODEL = 5 +SAVE_MODEL_HISTORY = 100 + diff --git a/util/process.py b/util/process.py new file mode 100644 index 0000000000000000000000000000000000000000..4a87f13fb68bd37432805e0df9acf6fdcbbee7d0 --- /dev/null +++ b/util/process.py @@ -0,0 +1,743 @@ +import argparse +from collections import namedtuple +import numpy as np +import torch +import cv2,os +import torch +import torch.nn.functional as F +from collections import defaultdict +from sklearn.cluster import DBSCAN + +""" +taken from https://github.com/githubharald/WordDetectorNN +Download the models from https://www.dropbox.com/s/mqhco2q67ovpfjq/model.zip?dl=1 and pass the path to word_segment(.) as argument. +""" + +from typing import Type, Any, Callable, Union, List, Optional + +import torch.nn as nn +from torch import Tensor + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for 
downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + 
nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + out1 = self.relu(x) + x = self.maxpool(out1) + + out2 = self.layer1(x) + out3 = self.layer2(out2) + out4 = self.layer3(out3) + out5 = self.layer4(out4) + + return out5, out4, out3, out2, out1 + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. 
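+
+    Note: in this trimmed-down copy `_resnet` never downloads weights, so the
+    `pretrained` and `progress` flags documented below are accepted for API
+    compatibility but have no effect (the same holds for every resnet/resnext/
+    wide_resnet helper in this file).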
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + +def compute_iou(ra, rb): + """intersection over union of two axis aligned rectangles ra and rb""" + if ra.xmax < rb.xmin or rb.xmax < ra.xmin or ra.ymax < rb.ymin or rb.ymax < ra.ymin: + return 0 + + l = max(ra.xmin, rb.xmin) + r = min(ra.xmax, rb.xmax) + t = max(ra.ymin, rb.ymin) + b = min(ra.ymax, rb.ymax) + + intersection = (r - l) * (b - t) + union = ra.area() + rb.area() - intersection + + iou = intersection / union + return iou + +def compute_dist_mat(aabbs): + """Jaccard distance matrix of all pairs of aabbs""" + num_aabbs = len(aabbs) + + dists = np.zeros((num_aabbs, num_aabbs)) + for i in range(num_aabbs): + for j in range(num_aabbs): + if j > i: + break + + dists[i, j] = dists[j, i] = 1 - compute_iou(aabbs[i], aabbs[j]) + + return dists + + +def cluster_aabbs(aabbs): + """cluster aabbs using DBSCAN and the Jaccard distance between bounding boxes""" + if len(aabbs) < 2: + return aabbs + + dists = compute_dist_mat(aabbs) + clustering = DBSCAN(eps=0.7, min_samples=3, metric='precomputed').fit(dists) + + clusters = defaultdict(list) + for i, c in enumerate(clustering.labels_): + if c == -1: + continue + clusters[c].append(aabbs[i]) + + res_aabbs = [] + for curr_cluster in clusters.values(): + xmin = np.median([aabb.xmin for aabb in curr_cluster]) + xmax = np.median([aabb.xmax for aabb in curr_cluster]) + ymin = np.median([aabb.ymin for aabb in curr_cluster]) + ymax = np.median([aabb.ymax for aabb in curr_cluster]) + res_aabbs.append(AABB(xmin, xmax, ymin, ymax)) + + return res_aabbs + + +class AABB: + """axis aligned bounding box""" + + def __init__(self, xmin, xmax, ymin, ymax): + self.xmin = xmin + self.xmax = xmax + self.ymin = ymin + self.ymax = ymax + + def scale(self, fx, fy): + new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) + new.xmin = fx * new.xmin + new.xmax = fx * new.xmax + new.ymin = fy * new.ymin + new.ymax = fy * new.ymax + return new + + def scale_around_center(self, fx, fy): + cx = (self.xmin + self.xmax) / 2 + cy = (self.ymin + self.ymax) / 2 + + new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) + new.xmin = cx - fx * (cx - self.xmin) + new.xmax = cx + fx * (self.xmax - cx) + new.ymin = cy - fy * (cy - self.ymin) + new.ymax = cy + fy * (self.ymax - cy) + return new + + def translate(self, tx, ty): + new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) + new.xmin = new.xmin + tx + new.xmax = new.xmax + tx + new.ymin = new.ymin + ty + new.ymax = new.ymax + ty + return new + + def as_type(self, t): + new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) + new.xmin = t(new.xmin) + new.xmax = t(new.xmax) + new.ymin = t(new.ymin) + new.ymax = t(new.ymax) + return new + + def enlarge_to_int_grid(self): + new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) + new.xmin = np.floor(new.xmin) + new.xmax = np.ceil(new.xmax) + new.ymin = np.floor(new.ymin) + new.ymax = np.ceil(new.ymax) + return new + + def clip(self, clip_aabb): + new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) + new.xmin = min(max(new.xmin, clip_aabb.xmin), clip_aabb.xmax) + new.xmax = max(min(new.xmax, clip_aabb.xmax), clip_aabb.xmin) + new.ymin = min(max(new.ymin, clip_aabb.ymin), clip_aabb.ymax) + new.ymax = max(min(new.ymax, clip_aabb.ymax), clip_aabb.ymin) + return new + + def 
area(self): + return (self.xmax - self.xmin) * (self.ymax - self.ymin) + + def __str__(self): + return f'AABB(xmin={self.xmin},xmax={self.xmax},ymin={self.ymin},ymax={self.ymax})' + + def __repr__(self): + return str(self) + +class MapOrdering: + """order of the maps encoding the aabbs around the words""" + SEG_WORD = 0 + SEG_SURROUNDING = 1 + SEG_BACKGROUND = 2 + GEO_TOP = 3 + GEO_BOTTOM = 4 + GEO_LEFT = 5 + GEO_RIGHT = 6 + NUM_MAPS = 7 + + +def encode(shape, gt, f=1.0): + gt_map = np.zeros((MapOrdering.NUM_MAPS,) + shape) + for aabb in gt: + aabb = aabb.scale(f, f) + + # segmentation map + aabb_clip = AABB(0, shape[0] - 1, 0, shape[1] - 1) + + aabb_word = aabb.scale_around_center(0.5, 0.5).as_type(int).clip(aabb_clip) + aabb_sur = aabb.as_type(int).clip(aabb_clip) + gt_map[MapOrdering.SEG_SURROUNDING, aabb_sur.ymin:aabb_sur.ymax + 1, aabb_sur.xmin:aabb_sur.xmax + 1] = 1 + gt_map[MapOrdering.SEG_SURROUNDING, aabb_word.ymin:aabb_word.ymax + 1, aabb_word.xmin:aabb_word.xmax + 1] = 0 + gt_map[MapOrdering.SEG_WORD, aabb_word.ymin:aabb_word.ymax + 1, aabb_word.xmin:aabb_word.xmax + 1] = 1 + + # geometry map TODO vectorize + for x in range(aabb_word.xmin, aabb_word.xmax + 1): + for y in range(aabb_word.ymin, aabb_word.ymax + 1): + gt_map[MapOrdering.GEO_TOP, y, x] = y - aabb.ymin + gt_map[MapOrdering.GEO_BOTTOM, y, x] = aabb.ymax - y + gt_map[MapOrdering.GEO_LEFT, y, x] = x - aabb.xmin + gt_map[MapOrdering.GEO_RIGHT, y, x] = aabb.xmax - x + + gt_map[MapOrdering.SEG_BACKGROUND] = np.clip(1 - gt_map[MapOrdering.SEG_WORD] - gt_map[MapOrdering.SEG_SURROUNDING], + 0, 1) + + return gt_map + + +def subsample(idx, max_num): + """restrict fg indices to a maximum number""" + f = len(idx[0]) / max_num + if f > 1: + a = np.asarray([idx[0][int(j * f)] for j in range(max_num)], np.int64) + b = np.asarray([idx[1][int(j * f)] for j in range(max_num)], np.int64) + idx = (a, b) + return idx + + +def fg_by_threshold(thres, max_num=None): + """all pixels above threshold are fg pixels, optionally limited to a maximum number""" + + def func(seg_map): + idx = np.where(seg_map > thres) + if max_num is not None: + idx = subsample(idx, max_num) + return idx + + return func + + +def fg_by_cc(thres, max_num): + """take a maximum number of pixels per connected component, but at least 3 (->DBSCAN minPts)""" + + def func(seg_map): + seg_mask = (seg_map > thres).astype(np.uint8) + num_labels, label_img = cv2.connectedComponents(seg_mask, connectivity=4) + max_num_per_cc = max(max_num // (num_labels + 1), 3) # at least 3 because of DBSCAN clustering + + all_idx = [np.empty(0, np.int64), np.empty(0, np.int64)] + for curr_label in range(1, num_labels): + curr_idx = np.where(label_img == curr_label) + curr_idx = subsample(curr_idx, max_num_per_cc) + all_idx[0] = np.append(all_idx[0], curr_idx[0]) + all_idx[1] = np.append(all_idx[1], curr_idx[1]) + return tuple(all_idx) + + return func + + +def decode(pred_map, comp_fg=fg_by_threshold(0.5), f=1): + idx = comp_fg(pred_map[MapOrdering.SEG_WORD]) + pred_map_masked = pred_map[..., idx[0], idx[1]] + aabbs = [] + for yc, xc, pred in zip(idx[0], idx[1], pred_map_masked.T): + t = pred[MapOrdering.GEO_TOP] + b = pred[MapOrdering.GEO_BOTTOM] + l = pred[MapOrdering.GEO_LEFT] + r = pred[MapOrdering.GEO_RIGHT] + aabb = AABB(xc - l, xc + r, yc - t, yc + b) + aabbs.append(aabb.scale(f, f)) + return aabbs + + +def main(): + import matplotlib.pyplot as plt + aabbs_in = [AABB(10, 30, 30, 60)] + encoded = encode((50, 50), aabbs_in, f=0.5) + aabbs_out = decode(encoded, f=2) + print(aabbs_out[0]) + 
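+    # Illustrative sanity check (not required by the demo): decode() emits one box per foreground
+    # pixel, so each recovered box should roughly coincide with the encoded input box; an IoU close
+    # to 1.0 indicates the encode/decode round trip preserves the geometry.
+    iou = compute_iou(aabbs_in[0], aabbs_out[0])
+    print(f'IoU(input, decoded) = {iou:.3f}')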
plt.subplot(151) + plt.imshow(encoded[MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1].transpose(1, 2, 0)) + + plt.subplot(152) + plt.imshow(encoded[MapOrdering.GEO_TOP]) + plt.subplot(153) + plt.imshow(encoded[MapOrdering.GEO_BOTTOM]) + plt.subplot(154) + plt.imshow(encoded[MapOrdering.GEO_LEFT]) + plt.subplot(155) + plt.imshow(encoded[MapOrdering.GEO_RIGHT]) + + plt.show() + + +def compute_scale_down(input_size, output_size): + """compute scale down factor of neural network, given input and output size""" + return output_size[0] / input_size[0] + + +def prob_true(p): + """return True with probability p""" + return np.random.random() < p + + +class UpscaleAndConcatLayer(torch.nn.Module): + """ + take small map with cx channels + upscale to size of large map (s*s) + concat large map with cy channels and upscaled small map + apply conv and output map with cz channels + """ + + def __init__(self, cx, cy, cz): + super(UpscaleAndConcatLayer, self).__init__() + self.conv = torch.nn.Conv2d(cx + cy, cz, 3, padding=1) + + def forward(self, x, y, s): + x = F.interpolate(x, s) + z = torch.cat((x, y), 1) + z = F.relu(self.conv(z)) + return z + + +class WordDetectorNet(torch.nn.Module): + # fixed sizes for training + input_size = (448, 448) + output_size = (224, 224) + scale_down = compute_scale_down(input_size, output_size) + + def __init__(self): + super(WordDetectorNet, self).__init__() + + self.backbone = resnet18() + + self.up1 = UpscaleAndConcatLayer(512, 256, 256) # input//16 + self.up2 = UpscaleAndConcatLayer(256, 128, 128) # input//8 + self.up3 = UpscaleAndConcatLayer(128, 64, 64) # input//4 + self.up4 = UpscaleAndConcatLayer(64, 64, 32) # input//2 + + self.conv1 = torch.nn.Conv2d(32, MapOrdering.NUM_MAPS, 3, 1, padding=1) + + @staticmethod + def scale_shape(s, f): + assert s[0] % f == 0 and s[1] % f == 0 + return s[0] // f, s[1] // f + + def output_activation(self, x, apply_softmax): + if apply_softmax: + seg = torch.softmax(x[:, MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1], dim=1) + else: + seg = x[:, MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1] + geo = torch.sigmoid(x[:, MapOrdering.GEO_TOP:]) * self.input_size[0] + y = torch.cat([seg, geo], dim=1) + return y + + def forward(self, x, apply_softmax=False): + # x: BxCxHxW + # eval backbone with 448px: bb1: 224px, bb2: 112px, bb3: 56px, bb4: 28px, bb5: 14px + s = x.shape[2:] + bb5, bb4, bb3, bb2, bb1 = self.backbone(x) + + x = self.up1(bb5, bb4, self.scale_shape(s, 16)) + x = self.up2(x, bb3, self.scale_shape(s, 8)) + x = self.up3(x, bb2, self.scale_shape(s, 4)) + x = self.up4(x, bb1, self.scale_shape(s, 2)) + x = self.conv1(x) + + return self.output_activation(x, apply_softmax) + + +def ceil32(val): + if val % 32 == 0: + return val + val = (val // 32 + 1) * 32 + return val + +def word_segment(path, output_folder, model_path): + + os.makedirs(output_folder, exist_ok = True) + + max_side_len = 5000 + thres = 0.5 + max_aabbs = 1000 + + orig = cv2.imread(path, cv2.IMREAD_GRAYSCALE) + net = WordDetectorNet() + net.load_state_dict(torch.load(model_path, map_location='cuda')) + net.eval() + net.cuda() + + f = min(max_side_len / orig.shape[0], max_side_len / orig.shape[1]) + if f < 1: + orig = cv2.resize(orig, dsize=None, fx=f, fy=f) + img = np.ones((ceil32(orig.shape[0]), ceil32(orig.shape[1])), np.uint8) * 255 + img[:orig.shape[0], :orig.shape[1]] = orig + + img = (img / 255 - 0.5).astype(np.float32) + imgs = img[None, None, ...] 
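+    # imgs now has shape (1, 1, H, W): [None, None, ...] added batch and channel axes to the
+    # normalized image (values in [-0.5, 0.5], padded to multiples of 32), matching the BxCxHxW
+    # layout expected by WordDetectorNet.forward.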
+ imgs = torch.from_numpy(imgs).cuda() + with torch.no_grad(): + y = net(imgs, apply_softmax=True) + y_np = y.to('cpu').numpy() + scale_up = 1 / compute_scale_down(WordDetectorNet.input_size, WordDetectorNet.output_size) + + img_np = imgs[0, 0].to('cpu').numpy() + pred_map = y_np[0] + + aabbs = decode(pred_map, comp_fg=fg_by_cc(thres, max_aabbs), f=scale_up) + h, w = img_np.shape + aabbs = [aabb.clip(AABB(0, w - 1, 0, h - 1)) for aabb in aabbs] # bounding box must be inside img + clustered_aabbs = cluster_aabbs(aabbs) + + img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) + + for idx,bb in enumerate(clustered_aabbs): + bb1 = bb + im_i = (img_np[int(bb1.ymin):int(bb1.ymax),int(bb1.xmin):int(bb1.xmax)]+0.5)*255 + cv2.imwrite(f'{output_folder}/im_{idx}.png',im_i) diff --git a/util/util.py b/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..6fe48c4ee180990fefb38761ae87f04c1cf51614 --- /dev/null +++ b/util/util.py @@ -0,0 +1,311 @@ +"""This module contains simple helper functions """ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os +import torch.nn.functional as F +from torch.autograd import Variable + +def random_word(len_word, alphabet): + # generate a word constructed from len_word characters where each character is randomly chosen from the alphabet. + char = np.random.randint(low=0, high=len(alphabet), size=len_word) + word = [alphabet[c] for c in char] + return ''.join(word) + +def load_network(net, save_dir, epoch): + """Load all the networks from the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + load_filename = '%s_net_%s.pth' % (epoch, net.name) + load_path = os.path.join(save_dir, load_filename) + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + state_dict = torch.load(load_path) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + net.load_state_dict(state_dict) + return net + +def writeCache(env, cache): + with env.begin(write=True) as txn: + for k, v in cache.items(): + if type(k) == str: + k = k.encode() + if type(v) == str: + v = v.encode() + txn.put(k, v) + +def loadData(v, data): + with torch.no_grad(): + v.resize_(data.size()).copy_(data) + +def multiple_replace(string, rep_dict): + for key in rep_dict.keys(): + string = string.replace(key, rep_dict[key]) + return string + +def get_curr_data(data, batch_size, counter): + curr_data = {} + for key in data: + curr_data[key] = data[key][batch_size*counter:batch_size*(counter+1)] + return curr_data + +# Utility file to seed rngs +def seed_rng(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + +# turn tensor of classes to tensor of one hot tensors: +def make_one_hot(labels, len_labels, n_classes): + one_hot = torch.zeros((labels.shape[0], labels.shape[1], n_classes),dtype=torch.float32) + for i in range(len(labels)): + one_hot[i,np.array(range(len_labels[i])), labels[i,:len_labels[i]]-1]=1 + return one_hot + +# Hinge Loss +def loss_hinge_dis(dis_fake, dis_real, len_text_fake, len_text, mask_loss): + mask_real = torch.ones(dis_real.shape).to(dis_real.device) + mask_fake = torch.ones(dis_fake.shape).to(dis_fake.device) + if mask_loss and len(dis_fake.shape)>2: + for i in range(len(len_text)): + mask_real[i, :, :, len_text[i]:] = 0 + mask_fake[i, :, :, len_text_fake[i]:] = 0 + loss_real = torch.sum(F.relu(1. 
- dis_real * mask_real))/torch.sum(mask_real) + loss_fake = torch.sum(F.relu(1. + dis_fake * mask_fake))/torch.sum(mask_fake) + return loss_real, loss_fake + + +def loss_hinge_gen(dis_fake, len_text_fake, mask_loss): + mask_fake = torch.ones(dis_fake.shape).to(dis_fake.device) + if mask_loss and len(dis_fake.shape)>2: + for i in range(len(len_text_fake)): + mask_fake[i, :, :, len_text_fake[i]:] = 0 + loss = -torch.sum(dis_fake*mask_fake)/torch.sum(mask_fake) + return loss + +def loss_std(z, lengths, mask_loss): + loss_std = torch.zeros(1).to(z.device) + z_mean = torch.ones((z.shape[0], z.shape[1])).to(z.device) + for i in range(len(lengths)): + if mask_loss: + if lengths[i]>1: + loss_std += torch.mean(torch.std(z[i, :, :, :lengths[i]], 2)) + z_mean[i,:] = torch.mean(z[i, :, :, :lengths[i]], 2).squeeze(1) + else: + z_mean[i, :] = z[i, :, :, 0].squeeze(1) + else: + loss_std += torch.mean(torch.std(z[i, :, :, :], 2)) + z_mean[i,:] = torch.mean(z[i, :, :, :], 2).squeeze(1) + loss_std = loss_std/z.shape[0] + return loss_std, z_mean + +# Convenience utility to switch off requires_grad +def toggle_grad(model, on_or_off): + for param in model.parameters(): + param.requires_grad = on_or_off + + +# Apply modified ortho reg to a model +# This function is an optimized version that directly computes the gradient, +# instead of computing and then differentiating the loss. +def ortho(model, strength=1e-4, blacklist=[]): + with torch.no_grad(): + for param in model.parameters(): + # Only apply this to parameters with at least 2 axes, and not in the blacklist + if len(param.shape) < 2 or any([param is item for item in blacklist]): + continue + w = param.view(param.shape[0], -1) + grad = (2 * torch.mm(torch.mm(w, w.t()) + * (1. - torch.eye(w.shape[0], device=w.device)), w)) + param.grad.data += strength * grad.view(param.shape) + + +# Default ortho reg +# This function is an optimized version that directly computes the gradient, +# instead of computing and then differentiating the loss. +def default_ortho(model, strength=1e-4, blacklist=[]): + with torch.no_grad(): + for param in model.parameters(): + # Only apply this to parameters with at least 2 axes & not in blacklist + if len(param.shape) < 2 or param in blacklist: + continue + w = param.view(param.shape[0], -1) + grad = (2 * torch.mm(torch.mm(w, w.t()) + - torch.eye(w.shape[0], device=w.device), w)) + param.grad.data += strength * grad.view(param.shape) + + +# Convenience utility to switch off requires_grad +def toggle_grad(model, on_or_off): + for param in model.parameters(): + param.requires_grad = on_or_off + + +# A highly simplified convenience class for sampling from distributions +# One could also use PyTorch's inbuilt distributions package. 
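+# (For comparison, a one-off draw with the built-in package would look like
+#     torch.distributions.Normal(0., 1.).sample((batch_size, dim_z))
+# with illustrative batch_size/dim_z; the Distribution class below instead resamples its own
+# storage in place via sample_(), so the same tensor object can be reused across iterations.)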
+# Note that this class requires initialization to proceed as +# x = Distribution(torch.randn(size)) +# x.init_distribution(dist_type, **dist_kwargs) +# x = x.to(device,dtype) +# This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2 +class Distribution(torch.Tensor): + # Init the params of the distribution + def init_distribution(self, dist_type, **kwargs): + seed_rng(kwargs['seed']) + self.dist_type = dist_type + self.dist_kwargs = kwargs + if self.dist_type == 'normal': + self.mean, self.var = kwargs['mean'], kwargs['var'] + elif self.dist_type == 'categorical': + self.num_categories = kwargs['num_categories'] + elif self.dist_type == 'poisson': + self.lam = kwargs['var'] + elif self.dist_type == 'gamma': + self.scale = kwargs['var'] + + + def sample_(self): + if self.dist_type == 'normal': + self.normal_(self.mean, self.var) + elif self.dist_type == 'categorical': + self.random_(0, self.num_categories) + elif self.dist_type == 'poisson': + type = self.type() + device = self.device + data = np.random.poisson(self.lam, self.size()) + self.data = torch.from_numpy(data).type(type).to(device) + elif self.dist_type == 'gamma': + type = self.type() + device = self.device + data = np.random.gamma(shape=1, scale=self.scale, size=self.size()) + self.data = torch.from_numpy(data).type(type).to(device) + # return self.variable + + # Silly hack: overwrite the to() method to wrap the new object + # in a distribution as well + def to(self, *args, **kwargs): + new_obj = Distribution(self) + new_obj.init_distribution(self.dist_type, **self.dist_kwargs) + new_obj.data = super().to(*args, **kwargs) + return new_obj + + +def to_device(net, gpu_ids): + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + if len(gpu_ids)>1: + net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda() + # net = torch.nn.DistributedDataParallel(net) + return net + + +# Convenience function to prepare a z and y vector +def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda', + fp16=False, z_var=1.0, z_dist='normal', seed=0): + z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False)) + z_.init_distribution(z_dist, mean=0, var=z_var, seed=seed) + z_ = z_.to(device, torch.float16 if fp16 else torch.float32) + + if fp16: + z_ = z_.half() + + y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False)) + y_.init_distribution('categorical', num_categories=nclasses, seed=seed) + y_ = y_.to(device, torch.int64) + return z_, y_ + + +def tensor2im(input_image, imtype=np.uint8): + """"Converts a Tensor array into a numpy image array. 
+ + Parameters: + input_image (tensor) -- the input image tensor array + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array + if image_numpy.shape[0] == 1: # grayscale to RGB + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) diff --git a/util/util.pyc b/util/util.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ca1bc0413332317605a213cceff67c6461fd2b2 Binary files /dev/null and b/util/util.pyc differ diff --git a/util/visualizer.py b/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b45169c6ae2cd671e8d7c7395f7a02c21366f018 --- /dev/null +++ b/util/visualizer.py @@ -0,0 +1,227 @@ +import numpy as np +import os +import sys +import ntpath +import time +from . import util, html +from subprocess import Popen, PIPE +from scipy.misc import imresize + +if sys.version_info[0] == 2: + VisdomExceptionBase = Exception +else: + VisdomExceptionBase = ConnectionError + + +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + """Save images to the disk. 
+ + Parameters: + webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) + visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs + image_path (str) -- the string is used to create image paths + aspect_ratio (float) -- the aspect ratio of saved images + width (int) -- the images will be resized to width x width + + This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. + """ + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + image_name = '%s_%s.png' % (name, label) + save_path = os.path.join(image_dir, image_name) + h, w, _ = im.shape + if aspect_ratio > 1.0: + im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') + if aspect_ratio < 1.0: + im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') + util.save_image(im, save_path) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + """This class includes several functions that can display/save images and print/save logging information. + + It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. + """ + + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: connect to a visdom server + Step 3: create an HTML object for saveing HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the option + self.display_id = opt.display_id + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + self.port = opt.display_port + self.saved = False + if self.display_id > 0: # connect to a visdom server given and + import visdom + self.ncols = opt.display_ncols + self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) + if not self.vis.check_connection(): + self.create_visdom_connections() + + if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + """Reset the self.saved status""" + self.saved = False + + def create_visdom_connections(self): + """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ + cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port + print('\n\nCould not connect to Visdom server. 
\n Trying to start a server....')
+        print('Command: %s' % cmd)
+        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
+
+    def display_current_results(self, visuals, epoch, save_result):
+        """Display current results on visdom; save current results to an HTML file.
+
+        Parameters:
+            visuals (OrderedDict) - - dictionary of images to display or save
+            epoch (int) - - the current epoch
+            save_result (bool) - - whether to save the current results to an HTML file
+        """
+        if self.display_id > 0:  # show images in the browser using visdom
+            ncols = self.ncols
+            if ncols > 0:  # show all the images in one visdom panel
+                ncols = min(ncols, len(visuals))
+                h, w = next(iter(visuals.values())).shape[:2]
+                table_css = """<style>
+                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
+                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+                        </style>""" % (w, h)  # create a table css
+                # create a table of images.
+                title = self.name
+                label_html = ''
+                label_html_row = ''
+                images = []
+                idx = 0
+                for label, image in visuals.items():
+                    image_numpy = util.tensor2im(image)
+                    label_html_row += '<td>%s</td>' % label
+                    images.append(image_numpy.transpose([2, 0, 1]))
+                    idx += 1
+                    if idx % ncols == 0:
+                        label_html += '<tr>%s</tr>' % label_html_row
+                        label_html_row = ''
+                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
+                while idx % ncols != 0:
+                    images.append(white_image)
+                    label_html_row += '<td></td>'
+                    idx += 1
+                if label_html_row != '':
+                    label_html += '<tr>%s</tr>' % label_html_row
+                try:
+                    self.vis.images(images, nrow=ncols, win=self.display_id + 1,
+                                    padding=2, opts=dict(title=title + ' images'))
+                    label_html = '<table>%s</table>
' % label_html + self.vis.text(table_css + label_html, win=self.display_id + 2, + opts=dict(title=title + ' labels')) + except VisdomExceptionBase: + self.create_visdom_connections() + + else: # show each image in a separate visdom panel; + idx = 1 + try: + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), + win=self.display_id + idx) + idx += 1 + except VisdomExceptionBase: + self.create_visdom_connections() + + if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. + self.saved = True + # save images to the disk + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + def plot_current_losses(self, epoch, counter_ratio, losses): + """display the current losses on visdom display: dictionary of error labels and values + + Parameters: + epoch (int) -- current epoch + counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + """ + if not hasattr(self, 'plot_data'): + self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) + try: + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) + except VisdomExceptionBase: + self.create_visdom_connections() + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message
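+
+# Usage sketch (hypothetical values): given losses = OrderedDict([('G', 0.734), ('D', 0.512)]),
+# calling print_current_losses(epoch=3, iters=400, losses=losses, t_comp=0.050, t_data=0.010)
+# prints and appends the following line to loss_log.txt:
+#     (epoch: 3, iters: 400, time: 0.050, data: 0.010) G: 0.734 D: 0.512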