ankankbhunia commited on
Commit
9acea67
1 Parent(s): 83cab2b

Upload 58 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. util/__init__.py +1 -0
  2. util/__init__.pyc +0 -0
  3. util/__pycache__/__init__.cpython-36.pyc +0 -0
  4. util/__pycache__/__init__.cpython-37.pyc +0 -0
  5. util/__pycache__/__init__.cpython-38.pyc +0 -0
  6. util/__pycache__/__init__.cpython-39.pyc +0 -0
  7. util/__pycache__/html.cpython-36.pyc +0 -0
  8. util/__pycache__/html.cpython-37.pyc +0 -0
  9. util/__pycache__/misc.cpython-36.pyc +0 -0
  10. util/__pycache__/misc.cpython-37.pyc +0 -0
  11. util/__pycache__/params.cpython-37.pyc +0 -0
  12. util/__pycache__/util.cpython-36.pyc +0 -0
  13. util/__pycache__/util.cpython-37.pyc +0 -0
  14. util/__pycache__/util.cpython-38.pyc +0 -0
  15. util/__pycache__/util.cpython-39.pyc +0 -0
  16. util/__pycache__/visualizer.cpython-36.pyc +0 -0
  17. util/__pycache__/visualizer.cpython-37.pyc +0 -0
  18. util/html.py +86 -0
  19. util/misc.py +465 -0
  20. util/models/BigGAN_layers.py +469 -0
  21. util/models/BigGAN_networks.py +841 -0
  22. util/models/OCR_network.py +304 -0
  23. util/models/__init__.py +65 -0
  24. util/models/__pycache__/BigGAN_layers.cpython-36.pyc +0 -0
  25. util/models/__pycache__/BigGAN_networks.cpython-36.pyc +0 -0
  26. util/models/__pycache__/OCR_network.cpython-36.pyc +0 -0
  27. util/models/__pycache__/__init__.cpython-36.pyc +0 -0
  28. util/models/__pycache__/blocks.cpython-36.pyc +0 -0
  29. util/models/__pycache__/inception.cpython-36.pyc +0 -0
  30. util/models/__pycache__/model.cpython-36.pyc +0 -0
  31. util/models/__pycache__/model_.cpython-36.pyc +0 -0
  32. util/models/__pycache__/networks.cpython-36.pyc +0 -0
  33. util/models/__pycache__/transformer.cpython-36.pyc +0 -0
  34. util/models/blocks.py +190 -0
  35. util/models/inception.py +363 -0
  36. util/models/model.py +1389 -0
  37. util/models/model_.py +1264 -0
  38. util/models/networks.py +98 -0
  39. util/models/sync_batchnorm/__init__.py +12 -0
  40. util/models/sync_batchnorm/__pycache__/__init__.cpython-36.pyc +0 -0
  41. util/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc +0 -0
  42. util/models/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc +0 -0
  43. util/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc +0 -0
  44. util/models/sync_batchnorm/__pycache__/comm.cpython-36.pyc +0 -0
  45. util/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc +0 -0
  46. util/models/sync_batchnorm/__pycache__/replicate.cpython-36.pyc +0 -0
  47. util/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc +0 -0
  48. util/models/sync_batchnorm/batchnorm.py +349 -0
  49. util/models/sync_batchnorm/batchnorm_reimpl.py +74 -0
  50. util/models/sync_batchnorm/comm.py +137 -0
util/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """This package includes a miscellaneous collection of useful helper functions."""
util/__init__.pyc ADDED
Binary file (265 Bytes). View file
 
util/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (215 Bytes). View file
 
util/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (261 Bytes). View file
 
util/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (259 Bytes). View file
 
util/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (252 Bytes). View file
 
util/__pycache__/html.cpython-36.pyc ADDED
Binary file (3.56 kB). View file
 
util/__pycache__/html.cpython-37.pyc ADDED
Binary file (3.57 kB). View file
 
util/__pycache__/misc.cpython-36.pyc ADDED
Binary file (14.4 kB). View file
 
util/__pycache__/misc.cpython-37.pyc ADDED
Binary file (14.3 kB). View file
 
util/__pycache__/params.cpython-37.pyc ADDED
Binary file (1.15 kB). View file
 
util/__pycache__/util.cpython-36.pyc ADDED
Binary file (9.98 kB). View file
 
util/__pycache__/util.cpython-37.pyc ADDED
Binary file (10 kB). View file
 
util/__pycache__/util.cpython-38.pyc ADDED
Binary file (10.1 kB). View file
 
util/__pycache__/util.cpython-39.pyc ADDED
Binary file (10.1 kB). View file
 
util/__pycache__/visualizer.cpython-36.pyc ADDED
Binary file (8.55 kB). View file
 
util/__pycache__/visualizer.cpython-37.pyc ADDED
Binary file (8.55 kB). View file
 
util/html.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dominate
2
+ from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3
+ import os
4
+
5
+
6
+ class HTML:
7
+ """This HTML class allows us to save images and write texts into a single HTML file.
8
+
9
+ It consists of functions such as <add_header> (add a text header to the HTML file),
10
+ <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12
+ """
13
+
14
+ def __init__(self, web_dir, title, refresh=0):
15
+ """Initialize the HTML classes
16
+
17
+ Parameters:
18
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
19
+ title (str) -- the webpage name
20
+ refresh (int) -- how often the website refresh itself; if 0; no refreshing
21
+ """
22
+ self.title = title
23
+ self.web_dir = web_dir
24
+ self.img_dir = os.path.join(self.web_dir, 'images')
25
+ if not os.path.exists(self.web_dir):
26
+ os.makedirs(self.web_dir)
27
+ if not os.path.exists(self.img_dir):
28
+ os.makedirs(self.img_dir)
29
+
30
+ self.doc = dominate.document(title=title)
31
+ if refresh > 0:
32
+ with self.doc.head:
33
+ meta(http_equiv="refresh", content=str(refresh))
34
+
35
+ def get_image_dir(self):
36
+ """Return the directory that stores images"""
37
+ return self.img_dir
38
+
39
+ def add_header(self, text):
40
+ """Insert a header to the HTML file
41
+
42
+ Parameters:
43
+ text (str) -- the header text
44
+ """
45
+ with self.doc:
46
+ h3(text)
47
+
48
+ def add_images(self, ims, txts, links, width=400):
49
+ """add images to the HTML file
50
+
51
+ Parameters:
52
+ ims (str list) -- a list of image paths
53
+ txts (str list) -- a list of image names shown on the website
54
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55
+ """
56
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57
+ self.doc.add(self.t)
58
+ with self.t:
59
+ with tr():
60
+ for im, txt, link in zip(ims, txts, links):
61
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
62
+ with p():
63
+ with a(href=os.path.join('images', link)):
64
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
65
+ br()
66
+ p(txt)
67
+
68
+ def save(self):
69
+ """save the current content to the HMTL file"""
70
+ html_file = '%s/index.html' % self.web_dir
71
+ f = open(html_file, 'wt')
72
+ f.write(self.doc.render())
73
+ f.close()
74
+
75
+
76
+ if __name__ == '__main__': # we show an example usage here.
77
+ html = HTML('web/', 'test_html')
78
+ html.add_header('hello world')
79
+
80
+ ims, txts, links = [], [], []
81
+ for n in range(4):
82
+ ims.append('image_%d.png' % n)
83
+ txts.append('text_%d' % n)
84
+ links.append('image_%d.png' % n)
85
+ html.add_images(ims, txts, links)
86
+ html.save()
util/misc.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Misc functions, including distributed helpers.
4
+
5
+ Mostly copy-paste from torchvision references.
6
+ """
7
+ import os
8
+ import subprocess
9
+ import time
10
+ from collections import defaultdict, deque
11
+ import datetime
12
+ import pickle
13
+ from typing import Optional, List
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+ from torch import Tensor
18
+
19
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
20
+ import torchvision
21
+
22
+
23
+
24
+ class SmoothedValue(object):
25
+ """Track a series of values and provide access to smoothed values over a
26
+ window or the global series average.
27
+ """
28
+
29
+ def __init__(self, window_size=20, fmt=None):
30
+ if fmt is None:
31
+ fmt = "{median:.4f} ({global_avg:.4f})"
32
+ self.deque = deque(maxlen=window_size)
33
+ self.total = 0.0
34
+ self.count = 0
35
+ self.fmt = fmt
36
+
37
+ def update(self, value, n=1):
38
+ self.deque.append(value)
39
+ self.count += n
40
+ self.total += value * n
41
+
42
+ def synchronize_between_processes(self):
43
+ """
44
+ Warning: does not synchronize the deque!
45
+ """
46
+ if not is_dist_avail_and_initialized():
47
+ return
48
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
49
+ dist.barrier()
50
+ dist.all_reduce(t)
51
+ t = t.tolist()
52
+ self.count = int(t[0])
53
+ self.total = t[1]
54
+
55
+ @property
56
+ def median(self):
57
+ d = torch.tensor(list(self.deque))
58
+ return d.median().item()
59
+
60
+ @property
61
+ def avg(self):
62
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
63
+ return d.mean().item()
64
+
65
+ @property
66
+ def global_avg(self):
67
+ return self.total / self.count
68
+
69
+ @property
70
+ def max(self):
71
+ return max(self.deque)
72
+
73
+ @property
74
+ def value(self):
75
+ return self.deque[-1]
76
+
77
+ def __str__(self):
78
+ return self.fmt.format(
79
+ median=self.median,
80
+ avg=self.avg,
81
+ global_avg=self.global_avg,
82
+ max=self.max,
83
+ value=self.value)
84
+
85
+
86
+ def all_gather(data):
87
+ """
88
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
89
+ Args:
90
+ data: any picklable object
91
+ Returns:
92
+ list[data]: list of data gathered from each rank
93
+ """
94
+ world_size = get_world_size()
95
+ if world_size == 1:
96
+ return [data]
97
+
98
+ # serialized to a Tensor
99
+ buffer = pickle.dumps(data)
100
+ storage = torch.ByteStorage.from_buffer(buffer)
101
+ tensor = torch.ByteTensor(storage).to("cuda")
102
+
103
+ # obtain Tensor size of each rank
104
+ local_size = torch.tensor([tensor.numel()], device="cuda")
105
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
106
+ dist.all_gather(size_list, local_size)
107
+ size_list = [int(size.item()) for size in size_list]
108
+ max_size = max(size_list)
109
+
110
+ # receiving Tensor from all ranks
111
+ # we pad the tensor because torch all_gather does not support
112
+ # gathering tensors of different shapes
113
+ tensor_list = []
114
+ for _ in size_list:
115
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
116
+ if local_size != max_size:
117
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
118
+ tensor = torch.cat((tensor, padding), dim=0)
119
+ dist.all_gather(tensor_list, tensor)
120
+
121
+ data_list = []
122
+ for size, tensor in zip(size_list, tensor_list):
123
+ buffer = tensor.cpu().numpy().tobytes()[:size]
124
+ data_list.append(pickle.loads(buffer))
125
+
126
+ return data_list
127
+
128
+
129
+ def reduce_dict(input_dict, average=True):
130
+ """
131
+ Args:
132
+ input_dict (dict): all the values will be reduced
133
+ average (bool): whether to do average or sum
134
+ Reduce the values in the dictionary from all processes so that all processes
135
+ have the averaged results. Returns a dict with the same fields as
136
+ input_dict, after reduction.
137
+ """
138
+ world_size = get_world_size()
139
+ if world_size < 2:
140
+ return input_dict
141
+ with torch.no_grad():
142
+ names = []
143
+ values = []
144
+ # sort the keys so that they are consistent across processes
145
+ for k in sorted(input_dict.keys()):
146
+ names.append(k)
147
+ values.append(input_dict[k])
148
+ values = torch.stack(values, dim=0)
149
+ dist.all_reduce(values)
150
+ if average:
151
+ values /= world_size
152
+ reduced_dict = {k: v for k, v in zip(names, values)}
153
+ return reduced_dict
154
+
155
+
156
+ class MetricLogger(object):
157
+ def __init__(self, delimiter="\t"):
158
+ self.meters = defaultdict(SmoothedValue)
159
+ self.delimiter = delimiter
160
+
161
+ def update(self, **kwargs):
162
+ for k, v in kwargs.items():
163
+ if isinstance(v, torch.Tensor):
164
+ v = v.item()
165
+ assert isinstance(v, (float, int))
166
+ self.meters[k].update(v)
167
+
168
+ def __getattr__(self, attr):
169
+ if attr in self.meters:
170
+ return self.meters[attr]
171
+ if attr in self.__dict__:
172
+ return self.__dict__[attr]
173
+ raise AttributeError("'{}' object has no attribute '{}'".format(
174
+ type(self).__name__, attr))
175
+
176
+ def __str__(self):
177
+ loss_str = []
178
+ for name, meter in self.meters.items():
179
+ loss_str.append(
180
+ "{}: {}".format(name, str(meter))
181
+ )
182
+ return self.delimiter.join(loss_str)
183
+
184
+ def synchronize_between_processes(self):
185
+ for meter in self.meters.values():
186
+ meter.synchronize_between_processes()
187
+
188
+ def add_meter(self, name, meter):
189
+ self.meters[name] = meter
190
+
191
+ def log_every(self, iterable, print_freq, header=None):
192
+ i = 0
193
+ if not header:
194
+ header = ''
195
+ start_time = time.time()
196
+ end = time.time()
197
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
198
+ data_time = SmoothedValue(fmt='{avg:.4f}')
199
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
200
+ if torch.cuda.is_available():
201
+ log_msg = self.delimiter.join([
202
+ header,
203
+ '[{0' + space_fmt + '}/{1}]',
204
+ 'eta: {eta}',
205
+ '{meters}',
206
+ 'time: {time}',
207
+ 'data: {data}',
208
+ 'max mem: {memory:.0f}'
209
+ ])
210
+ else:
211
+ log_msg = self.delimiter.join([
212
+ header,
213
+ '[{0' + space_fmt + '}/{1}]',
214
+ 'eta: {eta}',
215
+ '{meters}',
216
+ 'time: {time}',
217
+ 'data: {data}'
218
+ ])
219
+ MB = 1024.0 * 1024.0
220
+ for obj in iterable:
221
+ data_time.update(time.time() - end)
222
+ yield obj
223
+ iter_time.update(time.time() - end)
224
+ if i % print_freq == 0 or i == len(iterable) - 1:
225
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
226
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
227
+ if torch.cuda.is_available():
228
+ print(log_msg.format(
229
+ i, len(iterable), eta=eta_string,
230
+ meters=str(self),
231
+ time=str(iter_time), data=str(data_time),
232
+ memory=torch.cuda.max_memory_allocated() / MB))
233
+ else:
234
+ print(log_msg.format(
235
+ i, len(iterable), eta=eta_string,
236
+ meters=str(self),
237
+ time=str(iter_time), data=str(data_time)))
238
+ i += 1
239
+ end = time.time()
240
+ total_time = time.time() - start_time
241
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
242
+ print('{} Total time: {} ({:.4f} s / it)'.format(
243
+ header, total_time_str, total_time / len(iterable)))
244
+
245
+
246
+ def get_sha():
247
+ cwd = os.path.dirname(os.path.abspath(__file__))
248
+
249
+ def _run(command):
250
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
251
+ sha = 'N/A'
252
+ diff = "clean"
253
+ branch = 'N/A'
254
+ try:
255
+ sha = _run(['git', 'rev-parse', 'HEAD'])
256
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
257
+ diff = _run(['git', 'diff-index', 'HEAD'])
258
+ diff = "has uncommited changes" if diff else "clean"
259
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
260
+ except Exception:
261
+ pass
262
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
263
+ return message
264
+
265
+
266
+ def collate_fn(batch):
267
+ batch = list(zip(*batch))
268
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
269
+ return tuple(batch)
270
+
271
+
272
+ def _max_by_axis(the_list):
273
+ # type: (List[List[int]]) -> List[int]
274
+ maxes = the_list[0]
275
+ for sublist in the_list[1:]:
276
+ for index, item in enumerate(sublist):
277
+ maxes[index] = max(maxes[index], item)
278
+ return maxes
279
+
280
+
281
+ class NestedTensor(object):
282
+ def __init__(self, tensors, mask: Optional[Tensor]):
283
+ self.tensors = tensors
284
+ self.mask = mask
285
+
286
+ def to(self, device):
287
+ # type: (Device) -> NestedTensor # noqa
288
+ cast_tensor = self.tensors.to(device)
289
+ mask = self.mask
290
+ if mask is not None:
291
+ assert mask is not None
292
+ cast_mask = mask.to(device)
293
+ else:
294
+ cast_mask = None
295
+ return NestedTensor(cast_tensor, cast_mask)
296
+
297
+ def decompose(self):
298
+ return self.tensors, self.mask
299
+
300
+ def __repr__(self):
301
+ return str(self.tensors)
302
+
303
+
304
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
305
+ # TODO make this more general
306
+ if tensor_list[0].ndim == 3:
307
+ if torchvision._is_tracing():
308
+ # nested_tensor_from_tensor_list() does not export well to ONNX
309
+ # call _onnx_nested_tensor_from_tensor_list() instead
310
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
311
+
312
+ # TODO make it support different-sized images
313
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
314
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
315
+ batch_shape = [len(tensor_list)] + max_size
316
+ b, c, h, w = batch_shape
317
+ dtype = tensor_list[0].dtype
318
+ device = tensor_list[0].device
319
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
320
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
321
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
322
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
323
+ m[: img.shape[1], :img.shape[2]] = False
324
+ else:
325
+ raise ValueError('not supported')
326
+ return NestedTensor(tensor, mask)
327
+
328
+
329
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
330
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
331
+ @torch.jit.unused
332
+ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
333
+ max_size = []
334
+ for i in range(tensor_list[0].dim()):
335
+ max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
336
+ max_size.append(max_size_i)
337
+ max_size = tuple(max_size)
338
+
339
+ # work around for
340
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
341
+ # m[: img.shape[1], :img.shape[2]] = False
342
+ # which is not yet supported in onnx
343
+ padded_imgs = []
344
+ padded_masks = []
345
+ for img in tensor_list:
346
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
347
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
348
+ padded_imgs.append(padded_img)
349
+
350
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
351
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
352
+ padded_masks.append(padded_mask.to(torch.bool))
353
+
354
+ tensor = torch.stack(padded_imgs)
355
+ mask = torch.stack(padded_masks)
356
+
357
+ return NestedTensor(tensor, mask=mask)
358
+
359
+
360
+ def setup_for_distributed(is_master):
361
+ """
362
+ This function disables printing when not in master process
363
+ """
364
+ import builtins as __builtin__
365
+ builtin_print = __builtin__.print
366
+
367
+ def print(*args, **kwargs):
368
+ force = kwargs.pop('force', False)
369
+ if is_master or force:
370
+ builtin_print(*args, **kwargs)
371
+
372
+ __builtin__.print = print
373
+
374
+
375
+ def is_dist_avail_and_initialized():
376
+ if not dist.is_available():
377
+ return False
378
+ if not dist.is_initialized():
379
+ return False
380
+ return True
381
+
382
+
383
+ def get_world_size():
384
+ if not is_dist_avail_and_initialized():
385
+ return 1
386
+ return dist.get_world_size()
387
+
388
+
389
+ def get_rank():
390
+ if not is_dist_avail_and_initialized():
391
+ return 0
392
+ return dist.get_rank()
393
+
394
+
395
+ def is_main_process():
396
+ return get_rank() == 0
397
+
398
+
399
+ def save_on_master(*args, **kwargs):
400
+ if is_main_process():
401
+ torch.save(*args, **kwargs)
402
+
403
+
404
+ def init_distributed_mode(args):
405
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
406
+ args.rank = int(os.environ["RANK"])
407
+ args.world_size = int(os.environ['WORLD_SIZE'])
408
+ args.gpu = int(os.environ['LOCAL_RANK'])
409
+ elif 'SLURM_PROCID' in os.environ:
410
+ args.rank = int(os.environ['SLURM_PROCID'])
411
+ args.gpu = args.rank % torch.cuda.device_count()
412
+ else:
413
+ print('Not using distributed mode')
414
+ args.distributed = False
415
+ return
416
+
417
+ args.distributed = True
418
+
419
+ torch.cuda.set_device(args.gpu)
420
+ args.dist_backend = 'nccl'
421
+ print('| distributed init (rank {}): {}'.format(
422
+ args.rank, args.dist_url), flush=True)
423
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
424
+ world_size=args.world_size, rank=args.rank)
425
+ torch.distributed.barrier()
426
+ setup_for_distributed(args.rank == 0)
427
+
428
+
429
+ @torch.no_grad()
430
+ def accuracy(output, target, topk=(1,)):
431
+ """Computes the precision@k for the specified values of k"""
432
+ if target.numel() == 0:
433
+ return [torch.zeros([], device=output.device)]
434
+ maxk = max(topk)
435
+ batch_size = target.size(0)
436
+
437
+ _, pred = output.topk(maxk, 1, True, True)
438
+ pred = pred.t()
439
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
440
+
441
+ res = []
442
+ for k in topk:
443
+ correct_k = correct[:k].view(-1).float().sum(0)
444
+ res.append(correct_k.mul_(100.0 / batch_size))
445
+ return res
446
+
447
+
448
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
449
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
450
+ """
451
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
452
+ This will eventually be supported natively by PyTorch, and this
453
+ class can go away.
454
+ """
455
+ if float(torchvision.__version__[:3]) < 0.7:
456
+ if input.numel() > 0:
457
+ return torch.nn.functional.interpolate(
458
+ input, size, scale_factor, mode, align_corners
459
+ )
460
+
461
+ output_shape = _output_size(2, input, size, scale_factor)
462
+ output_shape = list(input.shape[:-2]) + list(output_shape)
463
+ return _new_empty_tensor(input, output_shape)
464
+ else:
465
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
util/models/BigGAN_layers.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ''' Layers
2
+ This file contains various layers for the BigGAN models.
3
+ '''
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn import init
8
+ import torch.optim as optim
9
+ import torch.nn.functional as F
10
+ from torch.nn import Parameter as P
11
+
12
+ from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
13
+
14
+ # Projection of x onto y
15
+ def proj(x, y):
16
+ return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
17
+
18
+
19
+ # Orthogonalize x wrt list of vectors ys
20
+ def gram_schmidt(x, ys):
21
+ for y in ys:
22
+ x = x - proj(x, y)
23
+ return x
24
+
25
+
26
+ # Apply num_itrs steps of the power method to estimate top N singular values.
27
+ def power_iteration(W, u_, update=True, eps=1e-12):
28
+ # Lists holding singular vectors and values
29
+ us, vs, svs = [], [], []
30
+ for i, u in enumerate(u_):
31
+ # Run one step of the power iteration
32
+ with torch.no_grad():
33
+ v = torch.matmul(u, W)
34
+ # Run Gram-Schmidt to subtract components of all other singular vectors
35
+ v = F.normalize(gram_schmidt(v, vs), eps=eps)
36
+ # Add to the list
37
+ vs += [v]
38
+ # Update the other singular vector
39
+ u = torch.matmul(v, W.t())
40
+ # Run Gram-Schmidt to subtract components of all other singular vectors
41
+ u = F.normalize(gram_schmidt(u, us), eps=eps)
42
+ # Add to the list
43
+ us += [u]
44
+ if update:
45
+ u_[i][:] = u
46
+ # Compute this singular value and add it to the list
47
+ svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
48
+ # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
49
+ return svs, us, vs
50
+
51
+
52
+ # Convenience passthrough function
53
+ class identity(nn.Module):
54
+ def forward(self, input):
55
+ return input
56
+
57
+
58
+ # Spectral normalization base class
59
+ class SN(object):
60
+ def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
61
+ # Number of power iterations per step
62
+ self.num_itrs = num_itrs
63
+ # Number of singular values
64
+ self.num_svs = num_svs
65
+ # Transposed?
66
+ self.transpose = transpose
67
+ # Epsilon value for avoiding divide-by-0
68
+ self.eps = eps
69
+ # Register a singular vector for each sv
70
+ for i in range(self.num_svs):
71
+ self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
72
+ self.register_buffer('sv%d' % i, torch.ones(1))
73
+
74
+ # Singular vectors (u side)
75
+ @property
76
+ def u(self):
77
+ return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
78
+
79
+ # Singular values;
80
+ # note that these buffers are just for logging and are not used in training.
81
+ @property
82
+ def sv(self):
83
+ return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
84
+
85
+ # Compute the spectrally-normalized weight
86
+ def W_(self):
87
+ W_mat = self.weight.view(self.weight.size(0), -1)
88
+ if self.transpose:
89
+ W_mat = W_mat.t()
90
+ # Apply num_itrs power iterations
91
+ for _ in range(self.num_itrs):
92
+ svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
93
+ # Update the svs
94
+ if self.training:
95
+ with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
96
+ for i, sv in enumerate(svs):
97
+ self.sv[i][:] = sv
98
+ return self.weight / svs[0]
99
+
100
+
101
+ # 2D Conv layer with spectral norm
102
+ class SNConv2d(nn.Conv2d, SN):
103
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
104
+ padding=0, dilation=1, groups=1, bias=True,
105
+ num_svs=1, num_itrs=1, eps=1e-12):
106
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
107
+ padding, dilation, groups, bias)
108
+ SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
109
+
110
+ def forward(self, x):
111
+ return F.conv2d(x, self.W_(), self.bias, self.stride,
112
+ self.padding, self.dilation, self.groups)
113
+
114
+
115
+ # Linear layer with spectral norm
116
+ class SNLinear(nn.Linear, SN):
117
+ def __init__(self, in_features, out_features, bias=True,
118
+ num_svs=1, num_itrs=1, eps=1e-12):
119
+ nn.Linear.__init__(self, in_features, out_features, bias)
120
+ SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
121
+
122
+ def forward(self, x):
123
+ return F.linear(x, self.W_(), self.bias)
124
+
125
+
126
+ # Embedding layer with spectral norm
127
+ # We use num_embeddings as the dim instead of embedding_dim here
128
+ # for convenience sake
129
+ class SNEmbedding(nn.Embedding, SN):
130
+ def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
131
+ max_norm=None, norm_type=2, scale_grad_by_freq=False,
132
+ sparse=False, _weight=None,
133
+ num_svs=1, num_itrs=1, eps=1e-12):
134
+ nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
135
+ max_norm, norm_type, scale_grad_by_freq,
136
+ sparse, _weight)
137
+ SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
138
+
139
+ def forward(self, x):
140
+ return F.embedding(x, self.W_())
141
+
142
+
143
+ # A non-local block as used in SA-GAN
144
+ # Note that the implementation as described in the paper is largely incorrect;
145
+ # refer to the released code for the actual implementation.
146
+ class Attention(nn.Module):
147
+ def __init__(self, ch, which_conv=SNConv2d, name='attention'):
148
+ super(Attention, self).__init__()
149
+ # Channel multiplier
150
+ self.ch = ch
151
+ self.which_conv = which_conv
152
+ self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
153
+ self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
154
+ self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
155
+ self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
156
+ # Learnable gain parameter
157
+ self.gamma = P(torch.tensor(0.), requires_grad=True)
158
+
159
+ def forward(self, x, y=None):
160
+ # Apply convs
161
+ theta = self.theta(x)
162
+ phi = F.max_pool2d(self.phi(x), [2, 2])
163
+ g = F.max_pool2d(self.g(x), [2, 2])
164
+ # Perform reshapes
165
+ theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
166
+ try:
167
+ phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
168
+ except:
169
+ print(phi.shape)
170
+ g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
171
+ # Matmul and softmax to get attention maps
172
+ beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
173
+ # Attention map times g path
174
+ o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
175
+ return self.gamma * o + x
176
+
177
+
178
+ # Fused batchnorm op
179
+ def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
180
+ # Apply scale and shift--if gain and bias are provided, fuse them here
181
+ # Prepare scale
182
+ scale = torch.rsqrt(var + eps)
183
+ # If a gain is provided, use it
184
+ if gain is not None:
185
+ scale = scale * gain
186
+ # Prepare shift
187
+ shift = mean * scale
188
+ # If bias is provided, use it
189
+ if bias is not None:
190
+ shift = shift - bias
191
+ return x * scale - shift
192
+ # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
193
+
194
+
195
+ # Manual BN
196
+ # Calculate means and variances using mean-of-squares minus mean-squared
197
+ def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
198
+ # Cast x to float32 if necessary
199
+ float_x = x.float()
200
+ # Calculate expected value of x (m) and expected value of x**2 (m2)
201
+ # Mean of x
202
+ m = torch.mean(float_x, [0, 2, 3], keepdim=True)
203
+ # Mean of x squared
204
+ m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
205
+ # Calculate variance as mean of squared minus mean squared.
206
+ var = (m2 - m ** 2)
207
+ # Cast back to float 16 if necessary
208
+ var = var.type(x.type())
209
+ m = m.type(x.type())
210
+ # Return mean and variance for updating stored mean/var if requested
211
+ if return_mean_var:
212
+ return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
213
+ else:
214
+ return fused_bn(x, m, var, gain, bias, eps)
215
+
216
+
217
+ # My batchnorm, supports standing stats
218
+ class myBN(nn.Module):
219
+ def __init__(self, num_channels, eps=1e-5, momentum=0.1):
220
+ super(myBN, self).__init__()
221
+ # momentum for updating running stats
222
+ self.momentum = momentum
223
+ # epsilon to avoid dividing by 0
224
+ self.eps = eps
225
+ # Momentum
226
+ self.momentum = momentum
227
+ # Register buffers
228
+ self.register_buffer('stored_mean', torch.zeros(num_channels))
229
+ self.register_buffer('stored_var', torch.ones(num_channels))
230
+ self.register_buffer('accumulation_counter', torch.zeros(1))
231
+ # Accumulate running means and vars
232
+ self.accumulate_standing = False
233
+
234
+ # reset standing stats
235
+ def reset_stats(self):
236
+ self.stored_mean[:] = 0
237
+ self.stored_var[:] = 0
238
+ self.accumulation_counter[:] = 0
239
+
240
+ def forward(self, x, gain, bias):
241
+ if self.training:
242
+ out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
243
+ # If accumulating standing stats, increment them
244
+ if self.accumulate_standing:
245
+ self.stored_mean[:] = self.stored_mean + mean.data
246
+ self.stored_var[:] = self.stored_var + var.data
247
+ self.accumulation_counter += 1.0
248
+ # If not accumulating standing stats, take running averages
249
+ else:
250
+ self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
251
+ self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
252
+ return out
253
+ # If not in training mode, use the stored statistics
254
+ else:
255
+ mean = self.stored_mean.view(1, -1, 1, 1)
256
+ var = self.stored_var.view(1, -1, 1, 1)
257
+ # If using standing stats, divide them by the accumulation counter
258
+ if self.accumulate_standing:
259
+ mean = mean / self.accumulation_counter
260
+ var = var / self.accumulation_counter
261
+ return fused_bn(x, mean, var, gain, bias, self.eps)
262
+
263
+
264
+ # Simple function to handle groupnorm norm stylization
265
+ def groupnorm(x, norm_style):
266
+ # If number of channels specified in norm_style:
267
+ if 'ch' in norm_style:
268
+ ch = int(norm_style.split('_')[-1])
269
+ groups = max(int(x.shape[1]) // ch, 1)
270
+ # If number of groups specified in norm style
271
+ elif 'grp' in norm_style:
272
+ groups = int(norm_style.split('_')[-1])
273
+ # If neither, default to groups = 16
274
+ else:
275
+ groups = 16
276
+ return F.group_norm(x, groups)
277
+
278
+
279
+ # Class-conditional bn
280
+ # output size is the number of channels, input size is for the linear layers
281
+ # Andy's Note: this class feels messy but I'm not really sure how to clean it up
282
+ # Suggestions welcome! (By which I mean, refactor this and make a pull request
283
+ # if you want to make this more readable/usable).
284
+ class ccbn(nn.Module):
285
+ def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
286
+ cross_replica=False, mybn=False, norm_style='bn', ):
287
+ super(ccbn, self).__init__()
288
+ self.output_size, self.input_size = output_size, input_size
289
+ # Prepare gain and bias layers
290
+ self.gain = which_linear(input_size, output_size)
291
+ self.bias = which_linear(input_size, output_size)
292
+ # epsilon to avoid dividing by 0
293
+ self.eps = eps
294
+ # Momentum
295
+ self.momentum = momentum
296
+ # Use cross-replica batchnorm?
297
+ self.cross_replica = cross_replica
298
+ # Use my batchnorm?
299
+ self.mybn = mybn
300
+ # Norm style?
301
+ self.norm_style = norm_style
302
+
303
+ if self.cross_replica:
304
+ self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
305
+ elif self.mybn:
306
+ self.bn = myBN(output_size, self.eps, self.momentum)
307
+ elif self.norm_style in ['bn', 'in']:
308
+ self.register_buffer('stored_mean', torch.zeros(output_size))
309
+ self.register_buffer('stored_var', torch.ones(output_size))
310
+
311
+ def forward(self, x, y):
312
+ # Calculate class-conditional gains and biases
313
+ gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
314
+ bias = self.bias(y).view(y.size(0), -1, 1, 1)
315
+ # If using my batchnorm
316
+ if self.mybn or self.cross_replica:
317
+ return self.bn(x, gain=gain, bias=bias)
318
+ # else:
319
+ else:
320
+ if self.norm_style == 'bn':
321
+ out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
322
+ self.training, 0.1, self.eps)
323
+ elif self.norm_style == 'in':
324
+ out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
325
+ self.training, 0.1, self.eps)
326
+ elif self.norm_style == 'gn':
327
+ out = groupnorm(x, self.normstyle)
328
+ elif self.norm_style == 'nonorm':
329
+ out = x
330
+ return out * gain + bias
331
+
332
+ def extra_repr(self):
333
+ s = 'out: {output_size}, in: {input_size},'
334
+ s += ' cross_replica={cross_replica}'
335
+ return s.format(**self.__dict__)
336
+
337
+
338
+ # Normal, non-class-conditional BN
339
+ class bn(nn.Module):
340
+ def __init__(self, output_size, eps=1e-5, momentum=0.1,
341
+ cross_replica=False, mybn=False):
342
+ super(bn, self).__init__()
343
+ self.output_size = output_size
344
+ # Prepare gain and bias layers
345
+ self.gain = P(torch.ones(output_size), requires_grad=True)
346
+ self.bias = P(torch.zeros(output_size), requires_grad=True)
347
+ # epsilon to avoid dividing by 0
348
+ self.eps = eps
349
+ # Momentum
350
+ self.momentum = momentum
351
+ # Use cross-replica batchnorm?
352
+ self.cross_replica = cross_replica
353
+ # Use my batchnorm?
354
+ self.mybn = mybn
355
+
356
+ if self.cross_replica:
357
+ self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
358
+ elif mybn:
359
+ self.bn = myBN(output_size, self.eps, self.momentum)
360
+ # Register buffers if neither of the above
361
+ else:
362
+ self.register_buffer('stored_mean', torch.zeros(output_size))
363
+ self.register_buffer('stored_var', torch.ones(output_size))
364
+
365
+ def forward(self, x, y=None):
366
+ if self.cross_replica or self.mybn:
367
+ gain = self.gain.view(1, -1, 1, 1)
368
+ bias = self.bias.view(1, -1, 1, 1)
369
+ return self.bn(x, gain=gain, bias=bias)
370
+ else:
371
+ return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
372
+ self.bias, self.training, self.momentum, self.eps)
373
+
374
+
375
+ # Generator blocks
376
+ # Note that this class assumes the kernel size and padding (and any other
377
+ # settings) have been selected in the main generator module and passed in
378
+ # through the which_conv arg. Similar rules apply with which_bn (the input
379
+ # size [which is actually the number of channels of the conditional info] must
380
+ # be preselected)
381
+ class GBlock(nn.Module):
382
+ def __init__(self, in_channels, out_channels,
383
+ which_conv1=nn.Conv2d, which_conv2=nn.Conv2d, which_bn=bn, activation=None,
384
+ upsample=None):
385
+ super(GBlock, self).__init__()
386
+
387
+ self.in_channels, self.out_channels = in_channels, out_channels
388
+ self.which_conv1, self.which_conv2, self.which_bn = which_conv1, which_conv2, which_bn
389
+ self.activation = activation
390
+ self.upsample = upsample
391
+ # Conv layers
392
+ self.conv1 = self.which_conv1(self.in_channels, self.out_channels)
393
+ self.conv2 = self.which_conv2(self.out_channels, self.out_channels)
394
+ self.learnable_sc = in_channels != out_channels or upsample
395
+ if self.learnable_sc:
396
+ self.conv_sc = self.which_conv1(in_channels, out_channels,
397
+ kernel_size=1, padding=0)
398
+ # Batchnorm layers
399
+ self.bn1 = self.which_bn(in_channels)
400
+ self.bn2 = self.which_bn(out_channels)
401
+ # upsample layers
402
+ self.upsample = upsample
403
+
404
+ def forward(self, x, y):
405
+ h = self.activation(self.bn1(x, y))
406
+ # h = self.activation(x)
407
+ # h=x
408
+ if self.upsample:
409
+ h = self.upsample(h)
410
+ x = self.upsample(x)
411
+ h = self.conv1(h)
412
+ h = self.activation(self.bn2(h, y))
413
+ # h = self.activation(h)
414
+ h = self.conv2(h)
415
+ if self.learnable_sc:
416
+ x = self.conv_sc(x)
417
+ return h + x
418
+
419
+
420
+ # Residual block for the discriminator
421
+ class DBlock(nn.Module):
422
+ def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
423
+ preactivation=False, activation=None, downsample=None, ):
424
+ super(DBlock, self).__init__()
425
+ self.in_channels, self.out_channels = in_channels, out_channels
426
+ # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
427
+ self.hidden_channels = self.out_channels if wide else self.in_channels
428
+ self.which_conv = which_conv
429
+ self.preactivation = preactivation
430
+ self.activation = activation
431
+ self.downsample = downsample
432
+
433
+ # Conv layers
434
+ self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
435
+ self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
436
+ self.learnable_sc = True if (in_channels != out_channels) or downsample else False
437
+ if self.learnable_sc:
438
+ self.conv_sc = self.which_conv(in_channels, out_channels,
439
+ kernel_size=1, padding=0)
440
+
441
+ def shortcut(self, x):
442
+ if self.preactivation:
443
+ if self.learnable_sc:
444
+ x = self.conv_sc(x)
445
+ if self.downsample:
446
+ x = self.downsample(x)
447
+ else:
448
+ if self.downsample:
449
+ x = self.downsample(x)
450
+ if self.learnable_sc:
451
+ x = self.conv_sc(x)
452
+ return x
453
+
454
+ def forward(self, x):
455
+ if self.preactivation:
456
+ # h = self.activation(x) # NOT TODAY SATAN
457
+ # Andy's note: This line *must* be an out-of-place ReLU or it
458
+ # will negatively affect the shortcut connection.
459
+ h = F.relu(x)
460
+ else:
461
+ h = x
462
+ h = self.conv1(h)
463
+ h = self.conv2(self.activation(h))
464
+ if self.downsample:
465
+ h = self.downsample(h)
466
+
467
+ return h + self.shortcut(x)
468
+
469
+ # dogball
util/models/BigGAN_networks.py ADDED
@@ -0,0 +1,841 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import numpy as np
5
+ import math
6
+ import functools
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import init
11
+ import torch.optim as optim
12
+ import torch.nn.functional as F
13
+ from torch.nn import Parameter as P
14
+ from .transformer import Transformer
15
+ from . import BigGAN_layers as layers
16
+ from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
17
+ from util.util import to_device, load_network
18
+ from .networks import init_weights
19
+ from params import *
20
+ # Attention is passed in in the format '32_64' to mean applying an attention
21
+ # block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
22
+
23
+ from models.blocks import LinearBlock, Conv2dBlock, ResBlocks, ActFirstResBlock
24
+
25
+ class Decoder(nn.Module):
26
+ def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'):
27
+ super(Decoder, self).__init__()
28
+
29
+ self.model = []
30
+ self.model += [ResBlocks(n_res, dim, res_norm,
31
+ activ, pad_type=pad_type)]
32
+ for i in range(ups):
33
+ self.model += [nn.Upsample(scale_factor=2),
34
+ Conv2dBlock(dim, dim // 2, 5, 1, 2,
35
+ norm='in',
36
+ activation=activ,
37
+ pad_type=pad_type)]
38
+ dim //= 2
39
+ self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3,
40
+ norm='none',
41
+ activation='tanh',
42
+ pad_type=pad_type)]
43
+ self.model = nn.Sequential(*self.model)
44
+
45
+ def forward(self, x):
46
+ y = self.model(x)
47
+
48
+ return y
49
+
50
+
51
+
52
+ def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
53
+ arch = {}
54
+ arch[512] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
55
+ 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
56
+ 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],
57
+ 'resolution': [8, 16, 32, 64, 128, 256, 512],
58
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
59
+ for i in range(3, 10)}}
60
+ arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]],
61
+ 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]],
62
+ 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],
63
+ 'resolution': [8, 16, 32, 64, 128, 256],
64
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
65
+ for i in range(3, 9)}}
66
+ arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
67
+ 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
68
+ 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],
69
+ 'resolution': [8, 16, 32, 64, 128],
70
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
71
+ for i in range(3, 8)}}
72
+ arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
73
+ 'out_channels': [ch * item for item in [16, 8, 4, 2]],
74
+ 'upsample': [(2, 2), (2, 2), (2, 2), (2, 2)],
75
+ 'resolution': [8, 16, 32, 64],
76
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
77
+ for i in range(3, 7)}}
78
+
79
+ arch[63] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
80
+ 'out_channels': [ch * item for item in [16, 8, 4, 2]],
81
+ 'upsample': [(2, 2), (2, 2), (2, 2), (2,1)],
82
+ 'resolution': [8, 16, 32, 64],
83
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
84
+ for i in range(3, 7)},
85
+ 'kernel1': [3, 3, 3, 3],
86
+ 'kernel2': [3, 3, 1, 1]
87
+ }
88
+
89
+ arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
90
+ 'out_channels': [ch * item for item in [4, 4, 4]],
91
+ 'upsample': [(2, 2), (2, 2), (2, 2)],
92
+ 'resolution': [8, 16, 32],
93
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
94
+ for i in range(3, 6)}}
95
+
96
+ arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
97
+ 'out_channels': [ch * item for item in [4, 4, 4]],
98
+ 'upsample': [(2, 2), (2, 2), (2, 2)],
99
+ 'resolution': [8, 16, 32],
100
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
101
+ for i in range(3, 6)},
102
+ 'kernel1': [3, 3, 3],
103
+ 'kernel2': [3, 3, 1]
104
+ }
105
+
106
+ arch[129] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
107
+ 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
108
+ 'upsample': [(2,2), (2,2), (2,2), (2,2), (2,2), (1,2), (1,2)],
109
+ 'resolution': [8, 16, 32, 64, 128, 256, 512],
110
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
111
+ for i in range(3, 10)}}
112
+
113
+ arch[33] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
114
+ 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
115
+ 'upsample': [(2,2), (2,2), (2,2), (1,2), (1,2)],
116
+ 'resolution': [8, 16, 32, 64, 128],
117
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
118
+ for i in range(3, 8)}}
119
+
120
+ arch[31] = {'in_channels': [ch * item for item in [16, 16, 4, 2]],
121
+ 'out_channels': [ch * item for item in [16, 4, 2, 1]],
122
+ 'upsample': [(2,2), (2,2), (2,2), (1,2)],
123
+ 'resolution': [8, 16, 32, 64],
124
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
125
+ for i in range(3, 7)},
126
+ 'kernel1':[3, 3, 3, 3],
127
+ 'kernel2': [3, 1, 1, 1]}
128
+
129
+ arch[16] = {'in_channels': [ch * item for item in [8, 4, 2]],
130
+ 'out_channels': [ch * item for item in [4, 2, 1]],
131
+ 'upsample': [(2,2), (2,2), (2,1)],
132
+ 'resolution': [8, 16, 16],
133
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
134
+ for i in range(3, 6)},
135
+ 'kernel1':[3, 3, 3],
136
+ 'kernel2': [3, 3, 1]}
137
+
138
+ arch[17] = {'in_channels': [ch * item for item in [8, 4, 2]],
139
+ 'out_channels': [ch * item for item in [4, 2, 1]],
140
+ 'upsample': [(2,2), (2,2), (2,1)],
141
+ 'resolution': [8, 16, 16],
142
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
143
+ for i in range(3, 6)},
144
+ 'kernel1':[3, 3, 3],
145
+ 'kernel2': [3, 3, 1]}
146
+
147
+ arch[20] = {'in_channels': [ch * item for item in [8, 4, 2]],
148
+ 'out_channels': [ch * item for item in [4, 2, 1]],
149
+ 'upsample': [(2,2), (2,2), (2,1)],
150
+ 'resolution': [8, 16, 16],
151
+ 'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
152
+ for i in range(3, 6)},
153
+ 'kernel1':[3, 3, 3],
154
+ 'kernel2': [3, 1, 1]}
155
+
156
+ return arch
157
+
158
+
159
+ class Generator(nn.Module):
160
+ def __init__(self, G_ch=64, dim_z=128, bottom_width=4, bottom_height=4,resolution=128,
161
+ G_kernel_size=3, G_attn='64', n_classes=1000,
162
+ num_G_SVs=1, num_G_SV_itrs=1,
163
+ G_shared=True, shared_dim=0, no_hier=False,
164
+ cross_replica=False, mybn=False,
165
+ G_activation=nn.ReLU(inplace=False),
166
+ BN_eps=1e-5, SN_eps=1e-12, G_fp16=False,
167
+ G_init='ortho', skip_init=False,
168
+ G_param='SN', norm_style='bn',gpu_ids=[], bn_linear='embed', input_nc=3,
169
+ one_hot=False, first_layer=False, one_hot_k=1,
170
+ **kwargs):
171
+ super(Generator, self).__init__()
172
+ self.name = 'G'
173
+ # Use class only in first layer
174
+ self.first_layer = first_layer
175
+ # gpu-ids
176
+ self.gpu_ids = gpu_ids
177
+ # Use one hot vector representation for input class
178
+ self.one_hot = one_hot
179
+ # Use one hot k vector representation for input class if k is larger than 0. If it's 0, simly use the class number and not a k-hot encoding.
180
+ self.one_hot_k = one_hot_k
181
+ # Channel width mulitplier
182
+ self.ch = G_ch
183
+ # Dimensionality of the latent space
184
+ self.dim_z = dim_z
185
+ # The initial width dimensions
186
+ self.bottom_width = bottom_width
187
+ # The initial height dimension
188
+ self.bottom_height = bottom_height
189
+ # Resolution of the output
190
+ self.resolution = resolution
191
+ # Kernel size?
192
+ self.kernel_size = G_kernel_size
193
+ # Attention?
194
+ self.attention = G_attn
195
+ # number of classes, for use in categorical conditional generation
196
+ self.n_classes = n_classes
197
+ # Use shared embeddings?
198
+ self.G_shared = G_shared
199
+ # Dimensionality of the shared embedding? Unused if not using G_shared
200
+ self.shared_dim = shared_dim if shared_dim > 0 else dim_z
201
+ # Hierarchical latent space?
202
+ self.hier = not no_hier
203
+ # Cross replica batchnorm?
204
+ self.cross_replica = cross_replica
205
+ # Use my batchnorm?
206
+ self.mybn = mybn
207
+ # nonlinearity for residual blocks
208
+ self.activation = G_activation
209
+ # Initialization style
210
+ self.init = G_init
211
+ # Parameterization style
212
+ self.G_param = G_param
213
+ # Normalization style
214
+ self.norm_style = norm_style
215
+ # Epsilon for BatchNorm?
216
+ self.BN_eps = BN_eps
217
+ # Epsilon for Spectral Norm?
218
+ self.SN_eps = SN_eps
219
+ # fp16?
220
+ self.fp16 = G_fp16
221
+ # Architecture dict
222
+ self.arch = G_arch(self.ch, self.attention)[resolution]
223
+ self.bn_linear = bn_linear
224
+
225
+ #self.transformer = Transformer(d_model = 2560)
226
+ #self.input_proj = nn.Conv2d(512, 2560, kernel_size=1)
227
+ self.linear_q = nn.Linear(512,2048*2)
228
+
229
+ self.DETR = build()
230
+ self.DEC = Decoder(res_norm = 'in')
231
+ # If using hierarchical latents, adjust z
232
+ if self.hier:
233
+ # Number of places z slots into
234
+ self.num_slots = len(self.arch['in_channels']) + 1
235
+ self.z_chunk_size = (self.dim_z // self.num_slots)
236
+ # Recalculate latent dimensionality for even splitting into chunks
237
+ self.dim_z = self.z_chunk_size * self.num_slots
238
+ else:
239
+ self.num_slots = 1
240
+ self.z_chunk_size = 0
241
+
242
+ # Which convs, batchnorms, and linear layers to use
243
+ if self.G_param == 'SN':
244
+ self.which_conv = functools.partial(layers.SNConv2d,
245
+ kernel_size=3, padding=1,
246
+ num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
247
+ eps=self.SN_eps)
248
+ self.which_linear = functools.partial(layers.SNLinear,
249
+ num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
250
+ eps=self.SN_eps)
251
+ else:
252
+ self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
253
+ self.which_linear = nn.Linear
254
+
255
+ # We use a non-spectral-normed embedding here regardless;
256
+ # For some reason applying SN to G's embedding seems to randomly cripple G
257
+ if one_hot:
258
+ self.which_embedding = functools.partial(layers.SNLinear,
259
+ num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
260
+ eps=self.SN_eps)
261
+ else:
262
+ self.which_embedding = nn.Embedding
263
+
264
+ bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
265
+ else self.which_embedding)
266
+ if self.bn_linear=='SN':
267
+ bn_linear = functools.partial(self.which_linear, bias=False)
268
+ if self.G_shared:
269
+ input_size = self.shared_dim + self.z_chunk_size
270
+ elif self.hier:
271
+ if self.first_layer:
272
+ input_size = self.z_chunk_size
273
+ else:
274
+ input_size = self.n_classes + self.z_chunk_size
275
+ self.which_bn = functools.partial(layers.ccbn,
276
+ which_linear=bn_linear,
277
+ cross_replica=self.cross_replica,
278
+ mybn=self.mybn,
279
+ input_size=input_size,
280
+ norm_style=self.norm_style,
281
+ eps=self.BN_eps)
282
+ else:
283
+ input_size = self.n_classes
284
+ self.which_bn = functools.partial(layers.bn,
285
+ cross_replica=self.cross_replica,
286
+ mybn=self.mybn,
287
+ eps=self.BN_eps)
288
+
289
+
290
+
291
+
292
+ # Prepare model
293
+ # If not using shared embeddings, self.shared is just a passthrough
294
+ self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
295
+ else layers.identity())
296
+ # First linear layer
297
+ # The parameters for the first linear layer depend on the different input variations.
298
+ if self.first_layer:
299
+ if self.one_hot:
300
+ self.linear = self.which_linear(self.dim_z // self.num_slots + self.n_classes,
301
+ self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
302
+ else:
303
+ self.linear = self.which_linear(self.dim_z // self.num_slots + 1,
304
+ self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
305
+ if self.one_hot_k==1:
306
+ self.linear = self.which_linear((self.dim_z // self.num_slots) * self.n_classes,
307
+ self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
308
+ if self.one_hot_k>1:
309
+ self.linear = self.which_linear(self.dim_z // self.num_slots + self.n_classes*self.one_hot_k,
310
+ self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
311
+
312
+
313
+ else:
314
+ self.linear = self.which_linear(self.dim_z // self.num_slots,
315
+ self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
316
+ # self.blocks is a doubly-nested list of modules, the outer loop intended
317
+ # to be over blocks at a given resolution (resblocks and/or self-attention)
318
+ # while the inner loop is over a given block
319
+ self.blocks = []
320
+ for index in range(len(self.arch['out_channels'])):
321
+ if 'kernel1' in self.arch.keys():
322
+ padd1 = 1 if self.arch['kernel1'][index]>1 else 0
323
+ padd2 = 1 if self.arch['kernel2'][index]>1 else 0
324
+ conv1 = functools.partial(layers.SNConv2d,
325
+ kernel_size=self.arch['kernel1'][index], padding=padd1,
326
+ num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
327
+ eps=self.SN_eps)
328
+ conv2 = functools.partial(layers.SNConv2d,
329
+ kernel_size=self.arch['kernel2'][index], padding=padd2,
330
+ num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
331
+ eps=self.SN_eps)
332
+ self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
333
+ out_channels=self.arch['out_channels'][index],
334
+ which_conv1=conv1,
335
+ which_conv2=conv2,
336
+ which_bn=self.which_bn,
337
+ activation=self.activation,
338
+ upsample=(functools.partial(F.interpolate,
339
+ scale_factor=self.arch['upsample'][index])
340
+ if index < len(self.arch['upsample']) else None))]]
341
+ else:
342
+ self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
343
+ out_channels=self.arch['out_channels'][index],
344
+ which_conv1=self.which_conv,
345
+ which_conv2=self.which_conv,
346
+ which_bn=self.which_bn,
347
+ activation=self.activation,
348
+ upsample=(functools.partial(F.interpolate, scale_factor=self.arch['upsample'][index])
349
+ if index < len(self.arch['upsample']) else None))]]
350
+
351
+ # If attention on this block, attach it to the end
352
+ if self.arch['attention'][self.arch['resolution'][index]]:
353
+ print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
354
+ self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]
355
+
356
+ # Turn self.blocks into a ModuleList so that it's all properly registered.
357
+ self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
358
+
359
+ # output layer: batchnorm-relu-conv.
360
+ # Consider using a non-spectral conv here
361
+ self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
362
+ cross_replica=self.cross_replica,
363
+ mybn=self.mybn),
364
+ self.activation,
365
+ self.which_conv(self.arch['out_channels'][-1], input_nc))
366
+
367
+ # Initialize weights. Optionally skip init for testing.
368
+ if not skip_init:
369
+ self = init_weights(self, G_init)
370
+
371
+ # Note on this forward function: we pass in a y vector which has
372
+ # already been passed through G.shared to enable easy class-wise
373
+ # interpolation later. If we passed in the one-hot and then ran it through
374
+ # G.shared in this forward function, it would be harder to handle.
375
+ def forward(self, x, y_ind, y):
376
+ # If hierarchical, concatenate zs and ys
377
+
378
+
379
+ h_all = self.DETR(x, y_ind)
380
+ #h_all = torch.stack([h_all, h_all, h_all])
381
+
382
+ #h_all_bs = torch.unbind(h_all[-1], 0)
383
+ #y_bs = torch.unbind(y_ind, 0)
384
+
385
+ #h = torch.stack([h_i[y_j] for h_i,y_j in zip(h_all_bs, y_bs)], 0)
386
+
387
+
388
+
389
+
390
+ h = self.linear_q(h_all)
391
+
392
+
393
+ h = h.contiguous()
394
+ # Reshape - when y is not a single class value but rather an array of classes, the reshape is needed to create
395
+ # a separate vertical patch for each input.
396
+ if self.first_layer:
397
+ # correct reshape
398
+ h = h.view(h.size(0), h.shape[1]*2, 4, -1)
399
+ h = h.permute(0, 3, 2, 1)
400
+
401
+ else:
402
+ h = h.view(h.size(0), -1, self.bottom_width, self.bottom_height)
403
+
404
+
405
+ #for index, blocklist in enumerate(self.blocks):
406
+ # Second inner loop in case block has multiple layers
407
+ # for block in blocklist:
408
+ # h = block(h, ys[index])
409
+
410
+ #Apply batchnorm-relu-conv-tanh at output
411
+ # h = torch.tanh(self.output_layer(h))
412
+
413
+ h = self.DEC(h)
414
+ return h
415
+
416
+
417
+
418
+
419
+
420
+
421
+
422
+ # Discriminator architecture, same paradigm as G's above
423
+ def D_arch(ch=64, attention='64', input_nc=3, ksize='333333', dilation='111111'):
424
+ arch = {}
425
+ arch[256] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
426
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
427
+ 'downsample': [True] * 6 + [False],
428
+ 'resolution': [128, 64, 32, 16, 8, 4, 4],
429
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
430
+ for i in range(2, 8)}}
431
+ arch[128] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
432
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
433
+ 'downsample': [True] * 5 + [False],
434
+ 'resolution': [64, 32, 16, 8, 4, 4],
435
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
436
+ for i in range(2, 8)}}
437
+ arch[64] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]],
438
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
439
+ 'downsample': [True] * 4 + [False],
440
+ 'resolution': [32, 16, 8, 4, 4],
441
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
442
+ for i in range(2, 7)}}
443
+ arch[63] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]],
444
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
445
+ 'downsample': [True] * 4 + [False],
446
+ 'resolution': [32, 16, 8, 4, 4],
447
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
448
+ for i in range(2, 7)}}
449
+ arch[32] = {'in_channels': [input_nc] + [item * ch for item in [4, 4, 4]],
450
+ 'out_channels': [item * ch for item in [4, 4, 4, 4]],
451
+ 'downsample': [True, True, False, False],
452
+ 'resolution': [16, 16, 16, 16],
453
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
454
+ for i in range(2, 6)}}
455
+ arch[129] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
456
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
457
+ 'downsample': [True] * 6 + [False],
458
+ 'resolution': [128, 64, 32, 16, 8, 4, 4],
459
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
460
+ for i in range(2, 8)}}
461
+ arch[33] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
462
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
463
+ 'downsample': [True] * 5 + [False],
464
+ 'resolution': [64, 32, 16, 8, 4, 4],
465
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
466
+ for i in range(2, 10)}}
467
+ arch[31] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
468
+ 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
469
+ 'downsample': [True] * 5 + [False],
470
+ 'resolution': [64, 32, 16, 8, 4, 4],
471
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
472
+ for i in range(2, 10)}}
473
+ arch[16] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]],
474
+ 'out_channels': [item * ch for item in [1, 8, 16, 16]],
475
+ 'downsample': [True] * 3 + [False],
476
+ 'resolution': [16, 8, 4, 4],
477
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
478
+ for i in range(2, 5)}}
479
+
480
+ arch[17] = {'in_channels': [input_nc] + [ch * item for item in [1, 4]],
481
+ 'out_channels': [item * ch for item in [1, 4, 8]],
482
+ 'downsample': [True] * 3,
483
+ 'resolution': [16, 8, 4],
484
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
485
+ for i in range(2, 5)}}
486
+
487
+
488
+ arch[20] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]],
489
+ 'out_channels': [item * ch for item in [1, 8, 16, 16]],
490
+ 'downsample': [True] * 3 + [False],
491
+ 'resolution': [16, 8, 4, 4],
492
+ 'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
493
+ for i in range(2, 5)}}
494
+ return arch
495
+
496
+
497
+ class Discriminator(nn.Module):
498
+
499
+ def __init__(self, D_ch=64, D_wide=True, resolution=resolution,
500
+ D_kernel_size=3, D_attn='64', n_classes=VOCAB_SIZE,
501
+ num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
502
+ SN_eps=1e-8, output_dim=1, D_mixed_precision=False, D_fp16=False,
503
+ D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0],bn_linear='SN', input_nc=1, one_hot=False, **kwargs):
504
+
505
+ super(Discriminator, self).__init__()
506
+ self.name = 'D'
507
+ # gpu_ids
508
+ self.gpu_ids = gpu_ids
509
+ # one_hot representation
510
+ self.one_hot = one_hot
511
+ # Width multiplier
512
+ self.ch = D_ch
513
+ # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
514
+ self.D_wide = D_wide
515
+ # Resolution
516
+ self.resolution = resolution
517
+ # Kernel size
518
+ self.kernel_size = D_kernel_size
519
+ # Attention?
520
+ self.attention = D_attn
521
+ # Number of classes
522
+ self.n_classes = n_classes
523
+ # Activation
524
+ self.activation = D_activation
525
+ # Initialization style
526
+ self.init = D_init
527
+ # Parameterization style
528
+ self.D_param = D_param
529
+ # Epsilon for Spectral Norm?
530
+ self.SN_eps = SN_eps
531
+ # Fp16?
532
+ self.fp16 = D_fp16
533
+ # Architecture
534
+ self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]
535
+
536
+ # Which convs, batchnorms, and linear layers to use
537
+ # No option to turn off SN in D right now
538
+ if self.D_param == 'SN':
539
+ self.which_conv = functools.partial(layers.SNConv2d,
540
+ kernel_size=3, padding=1,
541
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
542
+ eps=self.SN_eps)
543
+ self.which_linear = functools.partial(layers.SNLinear,
544
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
545
+ eps=self.SN_eps)
546
+ self.which_embedding = functools.partial(layers.SNEmbedding,
547
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
548
+ eps=self.SN_eps)
549
+ if bn_linear=='SN':
550
+ self.which_embedding = functools.partial(layers.SNLinear,
551
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
552
+ eps=self.SN_eps)
553
+ else:
554
+ self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
555
+ self.which_linear = nn.Linear
556
+ # We use a non-spectral-normed embedding here regardless;
557
+ # For some reason applying SN to G's embedding seems to randomly cripple G
558
+ self.which_embedding = nn.Embedding
559
+ if one_hot:
560
+ self.which_embedding = functools.partial(layers.SNLinear,
561
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
562
+ eps=self.SN_eps)
563
+ # Prepare model
564
+ # self.blocks is a doubly-nested list of modules, the outer loop intended
565
+ # to be over blocks at a given resolution (resblocks and/or self-attention)
566
+ self.blocks = []
567
+ for index in range(len(self.arch['out_channels'])):
568
+ self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
569
+ out_channels=self.arch['out_channels'][index],
570
+ which_conv=self.which_conv,
571
+ wide=self.D_wide,
572
+ activation=self.activation,
573
+ preactivation=(index > 0),
574
+ downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
575
+ # If attention on this block, attach it to the end
576
+ if self.arch['attention'][self.arch['resolution'][index]]:
577
+ print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
578
+ self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
579
+ self.which_conv)]
580
+ # Turn self.blocks into a ModuleList so that it's all properly registered.
581
+ self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
582
+ # Linear output layer. The output dimension is typically 1, but may be
583
+ # larger if we're e.g. turning this into a VAE with an inference output
584
+ self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
585
+ # Embedding for projection discrimination
586
+ self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
587
+
588
+ # Initialize weights
589
+ if not skip_init:
590
+ self = init_weights(self, D_init)
591
+
592
+ def forward(self, x, y=None, **kwargs):
593
+ # Stick x into h for cleaner for loops without flow control
594
+ h = x
595
+ # Loop over blocks
596
+ for index, blocklist in enumerate(self.blocks):
597
+ for block in blocklist:
598
+ h = block(h)
599
+ # Apply global sum pooling as in SN-GAN
600
+ h = torch.sum(self.activation(h), [2, 3])
601
+ # Get initial class-unconditional output
602
+ out = self.linear(h)
603
+ # Get projection of final featureset onto class vectors and add to evidence
604
+ if y is not None:
605
+ out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
606
+ return out
607
+
608
+ def return_features(self, x, y=None):
609
+ # Stick x into h for cleaner for loops without flow control
610
+ h = x
611
+ block_output = []
612
+ # Loop over blocks
613
+ for index, blocklist in enumerate(self.blocks):
614
+ for block in blocklist:
615
+ h = block(h)
616
+ block_output.append(h)
617
+ # Apply global sum pooling as in SN-GAN
618
+ # h = torch.sum(self.activation(h), [2, 3])
619
+ return block_output
620
+
621
+
622
+
623
+
624
+ class WDiscriminator(nn.Module):
625
+
626
+ def __init__(self, D_ch=64, D_wide=True, resolution=resolution,
627
+ D_kernel_size=3, D_attn='64', n_classes=VOCAB_SIZE,
628
+ num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
629
+ SN_eps=1e-8, output_dim=NUM_WRITERS, D_mixed_precision=False, D_fp16=False,
630
+ D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0],bn_linear='SN', input_nc=1, one_hot=False, **kwargs):
631
+ super(WDiscriminator, self).__init__()
632
+ self.name = 'D'
633
+ # gpu_ids
634
+ self.gpu_ids = gpu_ids
635
+ # one_hot representation
636
+ self.one_hot = one_hot
637
+ # Width multiplier
638
+ self.ch = D_ch
639
+ # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
640
+ self.D_wide = D_wide
641
+ # Resolution
642
+ self.resolution = resolution
643
+ # Kernel size
644
+ self.kernel_size = D_kernel_size
645
+ # Attention?
646
+ self.attention = D_attn
647
+ # Number of classes
648
+ self.n_classes = n_classes
649
+ # Activation
650
+ self.activation = D_activation
651
+ # Initialization style
652
+ self.init = D_init
653
+ # Parameterization style
654
+ self.D_param = D_param
655
+ # Epsilon for Spectral Norm?
656
+ self.SN_eps = SN_eps
657
+ # Fp16?
658
+ self.fp16 = D_fp16
659
+ # Architecture
660
+ self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]
661
+
662
+ # Which convs, batchnorms, and linear layers to use
663
+ # No option to turn off SN in D right now
664
+ if self.D_param == 'SN':
665
+ self.which_conv = functools.partial(layers.SNConv2d,
666
+ kernel_size=3, padding=1,
667
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
668
+ eps=self.SN_eps)
669
+ self.which_linear = functools.partial(layers.SNLinear,
670
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
671
+ eps=self.SN_eps)
672
+ self.which_embedding = functools.partial(layers.SNEmbedding,
673
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
674
+ eps=self.SN_eps)
675
+ if bn_linear=='SN':
676
+ self.which_embedding = functools.partial(layers.SNLinear,
677
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
678
+ eps=self.SN_eps)
679
+ else:
680
+ self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
681
+ self.which_linear = nn.Linear
682
+ # We use a non-spectral-normed embedding here regardless;
683
+ # For some reason applying SN to G's embedding seems to randomly cripple G
684
+ self.which_embedding = nn.Embedding
685
+ if one_hot:
686
+ self.which_embedding = functools.partial(layers.SNLinear,
687
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
688
+ eps=self.SN_eps)
689
+ # Prepare model
690
+ # self.blocks is a doubly-nested list of modules, the outer loop intended
691
+ # to be over blocks at a given resolution (resblocks and/or self-attention)
692
+ self.blocks = []
693
+ for index in range(len(self.arch['out_channels'])):
694
+ self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
695
+ out_channels=self.arch['out_channels'][index],
696
+ which_conv=self.which_conv,
697
+ wide=self.D_wide,
698
+ activation=self.activation,
699
+ preactivation=(index > 0),
700
+ downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
701
+ # If attention on this block, attach it to the end
702
+ if self.arch['attention'][self.arch['resolution'][index]]:
703
+ print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
704
+ self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
705
+ self.which_conv)]
706
+ # Turn self.blocks into a ModuleList so that it's all properly registered.
707
+ self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
708
+ # Linear output layer. The output dimension is typically 1, but may be
709
+ # larger if we're e.g. turning this into a VAE with an inference output
710
+ self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
711
+ # Embedding for projection discrimination
712
+ self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
713
+ self.cross_entropy = nn.CrossEntropyLoss()
714
+ # Initialize weights
715
+ if not skip_init:
716
+ self = init_weights(self, D_init)
717
+
718
+ def forward(self, x, y=None, **kwargs):
719
+ # Stick x into h for cleaner for loops without flow control
720
+ h = x
721
+ # Loop over blocks
722
+ for index, blocklist in enumerate(self.blocks):
723
+ for block in blocklist:
724
+ h = block(h)
725
+ # Apply global sum pooling as in SN-GAN
726
+ h = torch.sum(self.activation(h), [2, 3])
727
+ # Get initial class-unconditional output
728
+ out = self.linear(h)
729
+ # Get projection of final featureset onto class vectors and add to evidence
730
+ #if y is not None:
731
+ #out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
732
+
733
+ loss = self.cross_entropy(out, y.long())
734
+
735
+ return loss
736
+
737
+ def return_features(self, x, y=None):
738
+ # Stick x into h for cleaner for loops without flow control
739
+ h = x
740
+ block_output = []
741
+ # Loop over blocks
742
+ for index, blocklist in enumerate(self.blocks):
743
+ for block in blocklist:
744
+ h = block(h)
745
+ block_output.append(h)
746
+ # Apply global sum pooling as in SN-GAN
747
+ # h = torch.sum(self.activation(h), [2, 3])
748
+ return block_output
749
+
750
+ class Encoder(Discriminator):
751
+ def __init__(self, opt, output_dim, **kwargs):
752
+ super(Encoder, self).__init__(**vars(opt))
753
+ self.output_layer = nn.Sequential(self.activation,
754
+ nn.Conv2d(self.arch['out_channels'][-1], output_dim, kernel_size=(4,2), padding=0, stride=2))
755
+
756
+ def forward(self, x):
757
+ # Stick x into h for cleaner for loops without flow control
758
+ h = x
759
+ # Loop over blocks
760
+ for index, blocklist in enumerate(self.blocks):
761
+ for block in blocklist:
762
+ h = block(h)
763
+ out = self.output_layer(h)
764
+ return out
765
+
766
+ class BiDiscriminator(nn.Module):
767
+ def __init__(self, opt):
768
+ super(BiDiscriminator, self).__init__()
769
+ self.infer_img = Encoder(opt, output_dim=opt.nimg_features)
770
+ # self.infer_z = nn.Sequential(
771
+ # nn.Conv2d(opt.dim_z, 512, 1, stride=1, bias=False),
772
+ # nn.LeakyReLU(inplace=True),
773
+ # nn.Dropout2d(p=self.dropout),
774
+ # nn.Conv2d(512, opt.nz_features, 1, stride=1, bias=False),
775
+ # nn.LeakyReLU(inplace=True),
776
+ # nn.Dropout2d(p=self.dropout)
777
+ # )
778
+ self.infer_joint = nn.Sequential(
779
+ nn.Conv2d(opt.dim_z+opt.nimg_features, 1024, 1, stride=1, bias=True),
780
+ nn.ReLU(inplace=True),
781
+
782
+ nn.Conv2d(1024, 1024, 1, stride=1, bias=True),
783
+ nn.ReLU(inplace=True)
784
+ )
785
+ self.final = nn.Conv2d(1024, 1, 1, stride=1, bias=True)
786
+
787
+ def forward(self, x, z, **kwargs):
788
+ output_x = self.infer_img(x)
789
+ # output_z = self.infer_z(z)
790
+ if len(z.shape)==2:
791
+ z = z.unsqueeze(2).unsqueeze(2).repeat((1,1,1,output_x.shape[3]))
792
+ output_features = self.infer_joint(torch.cat([output_x, z], dim=1))
793
+ output = self.final(output_features)
794
+ return output
795
+
796
+ # Parallelized G_D to minimize cross-gpu communication
797
+ # Without this, Generator outputs would get all-gathered and then rebroadcast.
798
+ class G_D(nn.Module):
799
+ def __init__(self, G, D):
800
+ super(G_D, self).__init__()
801
+ self.G = G
802
+ self.D = D
803
+
804
+ def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
805
+ split_D=False):
806
+ # If training G, enable grad tape
807
+ with torch.set_grad_enabled(train_G):
808
+ # Get Generator output given noise
809
+ G_z = self.G(z, self.G.shared(gy))
810
+ # Cast as necessary
811
+ if self.G.fp16 and not self.D.fp16:
812
+ G_z = G_z.float()
813
+ if self.D.fp16 and not self.G.fp16:
814
+ G_z = G_z.half()
815
+ # Split_D means to run D once with real data and once with fake,
816
+ # rather than concatenating along the batch dimension.
817
+ if split_D:
818
+ D_fake = self.D(G_z, gy)
819
+ if x is not None:
820
+ D_real = self.D(x, dy)
821
+ return D_fake, D_real
822
+ else:
823
+ if return_G_z:
824
+ return D_fake, G_z
825
+ else:
826
+ return D_fake
827
+ # If real data is provided, concatenate it with the Generator's output
828
+ # along the batch dimension for improved efficiency.
829
+ else:
830
+ D_input = torch.cat([G_z, x], 0) if x is not None else G_z
831
+ D_class = torch.cat([gy, dy], 0) if dy is not None else gy
832
+ # Get Discriminator output
833
+ D_out = self.D(D_input, D_class)
834
+ if x is not None:
835
+ return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real
836
+ else:
837
+ if return_G_z:
838
+ return D_out, G_z
839
+ else:
840
+ return D_out
841
+
util/models/OCR_network.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from util.util import to_device
3
+ from torch.nn import init
4
+ import os
5
+ import torch
6
+ from .networks import *
7
+ from params import *
8
+
9
+ class BidirectionalLSTM(nn.Module):
10
+
11
+ def __init__(self, nIn, nHidden, nOut):
12
+ super(BidirectionalLSTM, self).__init__()
13
+
14
+ self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
15
+ self.embedding = nn.Linear(nHidden * 2, nOut)
16
+
17
+
18
+ def forward(self, input):
19
+ recurrent, _ = self.rnn(input)
20
+ T, b, h = recurrent.size()
21
+ t_rec = recurrent.view(T * b, h)
22
+
23
+ output = self.embedding(t_rec) # [T * b, nOut]
24
+ output = output.view(T, b, -1)
25
+
26
+ return output
27
+
28
+
29
+ class CRNN(nn.Module):
30
+
31
+ def __init__(self, leakyRelu=False):
32
+ super(CRNN, self).__init__()
33
+ self.name = 'OCR'
34
+ #assert opt.imgH % 16 == 0, 'imgH has to be a multiple of 16'
35
+
36
+ ks = [3, 3, 3, 3, 3, 3, 2]
37
+ ps = [1, 1, 1, 1, 1, 1, 0]
38
+ ss = [1, 1, 1, 1, 1, 1, 1]
39
+ nm = [64, 128, 256, 256, 512, 512, 512]
40
+
41
+ cnn = nn.Sequential()
42
+ nh = 256
43
+ dealwith_lossnone=False # whether to replace all nan/inf in gradients to zero
44
+
45
+ def convRelu(i, batchNormalization=False):
46
+ nIn = 1 if i == 0 else nm[i - 1]
47
+ nOut = nm[i]
48
+ cnn.add_module('conv{0}'.format(i),
49
+ nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
50
+ if batchNormalization:
51
+ cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
52
+ if leakyRelu:
53
+ cnn.add_module('relu{0}'.format(i),
54
+ nn.LeakyReLU(0.2, inplace=True))
55
+ else:
56
+ cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
57
+
58
+ convRelu(0)
59
+ cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
60
+ convRelu(1)
61
+ cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
62
+ convRelu(2, True)
63
+ convRelu(3)
64
+ cnn.add_module('pooling{0}'.format(2),
65
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
66
+ convRelu(4, True)
67
+ if resolution==63:
68
+ cnn.add_module('pooling{0}'.format(3),
69
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
70
+ convRelu(5)
71
+ cnn.add_module('pooling{0}'.format(4),
72
+ nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
73
+ convRelu(6, True) # 512x1x16
74
+
75
+ self.cnn = cnn
76
+ self.use_rnn = False
77
+ if self.use_rnn:
78
+ self.rnn = nn.Sequential(
79
+ BidirectionalLSTM(512, nh, nh),
80
+ BidirectionalLSTM(nh, nh, ))
81
+ else:
82
+ self.linear = nn.Linear(512, VOCAB_SIZE)
83
+
84
+ # replace all nan/inf in gradients to zero
85
+ if dealwith_lossnone:
86
+ self.register_backward_hook(self.backward_hook)
87
+
88
+ self.device = torch.device('cuda:{}'.format(0))
89
+ self.init = 'N02'
90
+ # Initialize weights
91
+
92
+ self = init_weights(self, self.init)
93
+
94
+ def forward(self, input):
95
+ # conv features
96
+ conv = self.cnn(input)
97
+ b, c, h, w = conv.size()
98
+ if h!=1:
99
+ print('a')
100
+ assert h == 1, "the height of conv must be 1"
101
+ conv = conv.squeeze(2)
102
+ conv = conv.permute(2, 0, 1) # [w, b, c]
103
+
104
+ if self.use_rnn:
105
+ # rnn features
106
+ output = self.rnn(conv)
107
+ else:
108
+ output = self.linear(conv)
109
+ return output
110
+
111
+ def backward_hook(self, module, grad_input, grad_output):
112
+ for g in grad_input:
113
+ g[g != g] = 0 # replace all nan/inf in gradients to zero
114
+
115
+
116
+ class OCRLabelConverter(object):
117
+ """Convert between str and label.
118
+
119
+ NOTE:
120
+ Insert `blank` to the alphabet for CTC.
121
+
122
+ Args:
123
+ alphabet (str): set of the possible characters.
124
+ ignore_case (bool, default=True): whether or not to ignore all of the case.
125
+ """
126
+
127
+ def __init__(self, alphabet, ignore_case=False):
128
+ self._ignore_case = ignore_case
129
+ if self._ignore_case:
130
+ alphabet = alphabet.lower()
131
+ self.alphabet = alphabet + '-' # for `-1` index
132
+
133
+ self.dict = {}
134
+ for i, char in enumerate(alphabet):
135
+ # NOTE: 0 is reserved for 'blank' required by wrap_ctc
136
+ self.dict[char] = i + 1
137
+
138
+ def encode(self, text):
139
+ """Support batch or single str.
140
+
141
+ Args:
142
+ text (str or list of str): texts to convert.
143
+
144
+ Returns:
145
+ torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
146
+ torch.IntTensor [n]: length of each text.
147
+ """
148
+ '''
149
+ if isinstance(text, str):
150
+ text = [
151
+ self.dict[char.lower() if self._ignore_case else char]
152
+ for char in text
153
+ ]
154
+ length = [len(text)]
155
+ elif isinstance(text, collections.Iterable):
156
+ length = [len(s) for s in text]
157
+ text = ''.join(text)
158
+ text, _ = self.encode(text)
159
+ return (torch.IntTensor(text), torch.IntTensor(length))
160
+ '''
161
+ length = []
162
+ result = []
163
+ for item in text:
164
+ item = item.decode('utf-8', 'strict')
165
+ length.append(len(item))
166
+ for char in item:
167
+ index = self.dict[char]
168
+ result.append(index)
169
+
170
+ text = result
171
+ return (torch.IntTensor(text), torch.IntTensor(length))
172
+
173
+ def decode(self, t, length, raw=False):
174
+ """Decode encoded texts back into strs.
175
+
176
+ Args:
177
+ torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
178
+ torch.IntTensor [n]: length of each text.
179
+
180
+ Raises:
181
+ AssertionError: when the texts and its length does not match.
182
+
183
+ Returns:
184
+ text (str or list of str): texts to convert.
185
+ """
186
+ if length.numel() == 1:
187
+ length = length[0]
188
+ assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
189
+ length)
190
+ if raw:
191
+ return ''.join([self.alphabet[i - 1] for i in t])
192
+ else:
193
+ char_list = []
194
+ for i in range(length):
195
+ if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
196
+ char_list.append(self.alphabet[t[i] - 1])
197
+ return ''.join(char_list)
198
+ else:
199
+ # batch mode
200
+ assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
201
+ t.numel(), length.sum())
202
+ texts = []
203
+ index = 0
204
+ for i in range(length.numel()):
205
+ l = length[i]
206
+ texts.append(
207
+ self.decode(
208
+ t[index:index + l], torch.IntTensor([l]), raw=raw))
209
+ index += l
210
+ return texts
211
+
212
+
213
+ class strLabelConverter(object):
214
+ """Convert between str and label.
215
+ NOTE:
216
+ Insert `blank` to the alphabet for CTC.
217
+ Args:
218
+ alphabet (str): set of the possible characters.
219
+ ignore_case (bool, default=True): whether or not to ignore all of the case.
220
+ """
221
+
222
+ def __init__(self, alphabet, ignore_case=False):
223
+ self._ignore_case = ignore_case
224
+ if self._ignore_case:
225
+ alphabet = alphabet.lower()
226
+ self.alphabet = alphabet + '-' # for `-1` index
227
+
228
+ self.dict = {}
229
+ for i, char in enumerate(alphabet):
230
+ # NOTE: 0 is reserved for 'blank' required by wrap_ctc
231
+ self.dict[char] = i + 1
232
+
233
+ def encode(self, text):
234
+ """Support batch or single str.
235
+ Args:
236
+ text (str or list of str): texts to convert.
237
+ Returns:
238
+ torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
239
+ torch.IntTensor [n]: length of each text.
240
+ """
241
+ '''
242
+ if isinstance(text, str):
243
+ text = [
244
+ self.dict[char.lower() if self._ignore_case else char]
245
+ for char in text
246
+ ]
247
+ length = [len(text)]
248
+ elif isinstance(text, collections.Iterable):
249
+ length = [len(s) for s in text]
250
+ text = ''.join(text)
251
+ text, _ = self.encode(text)
252
+ return (torch.IntTensor(text), torch.IntTensor(length))
253
+ '''
254
+ length = []
255
+ result = []
256
+ results = []
257
+ for item in text:
258
+ item = item.decode('utf-8', 'strict')
259
+ length.append(len(item))
260
+ for char in item:
261
+ index = self.dict[char]
262
+ result.append(index)
263
+ results.append(result)
264
+ result = []
265
+
266
+ return (torch.nn.utils.rnn.pad_sequence([torch.LongTensor(text) for text in results], batch_first=True), torch.IntTensor(length))
267
+
268
+ def decode(self, t, length, raw=False):
269
+ """Decode encoded texts back into strs.
270
+ Args:
271
+ torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
272
+ torch.IntTensor [n]: length of each text.
273
+ Raises:
274
+ AssertionError: when the texts and its length does not match.
275
+ Returns:
276
+ text (str or list of str): texts to convert.
277
+ """
278
+ if length.numel() == 1:
279
+ length = length[0]
280
+ assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
281
+ length)
282
+ if raw:
283
+ return ''.join([self.alphabet[i - 1] for i in t])
284
+ else:
285
+ char_list = []
286
+ for i in range(length):
287
+ if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
288
+ char_list.append(self.alphabet[t[i] - 1])
289
+ return ''.join(char_list)
290
+ else:
291
+ # batch mode
292
+ assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
293
+ t.numel(), length.sum())
294
+ texts = []
295
+ index = 0
296
+ for i in range(length.numel()):
297
+ l = length[i]
298
+ texts.append(
299
+ self.decode(
300
+ t[index:index + l], torch.IntTensor([l]), raw=raw))
301
+ index += l
302
+ return texts
303
+
304
+
util/models/__init__.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ """
19
+
20
+ import importlib
21
+
22
+
23
+ def find_model_using_name(model_name):
24
+ """Import the module "models/[model_name]_model.py".
25
+
26
+ In the file, the class called DatasetNameModel() will
27
+ be instantiated. It has to be a subclass of BaseModel,
28
+ and it is case-insensitive.
29
+ """
30
+ model_filename = "models." + model_name + "_model"
31
+ modellib = importlib.import_module(model_filename)
32
+ model = None
33
+ target_model_name = model_name.replace('_', '') + 'model'
34
+ for name, cls in modellib.__dict__.items():
35
+ if name.lower() == target_model_name.lower() \
36
+ and issubclass(cls, BaseModel):
37
+ model = cls
38
+
39
+ if model is None:
40
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
41
+ exit(0)
42
+
43
+ return model
44
+
45
+
46
+ def get_option_setter(model_name):
47
+ """Return the static method <modify_commandline_options> of the model class."""
48
+ model_class = find_model_using_name(model_name)
49
+ return model_class.modify_commandline_options
50
+
51
+
52
+ def create_model(opt):
53
+ """Create a model given the option.
54
+
55
+ This function warps the class CustomDatasetDataLoader.
56
+ This is the main interface between this package and 'train.py'/'test.py'
57
+
58
+ Example:
59
+ >>> from models import create_model
60
+ >>> model = create_model(opt)
61
+ """
62
+ model = find_model_using_name(opt.model)
63
+ instance = model(opt)
64
+ print("model [%s] was created" % type(instance).__name__)
65
+ return instance
util/models/__pycache__/BigGAN_layers.cpython-36.pyc ADDED
Binary file (12.8 kB). View file
 
util/models/__pycache__/BigGAN_networks.cpython-36.pyc ADDED
Binary file (30 kB). View file
 
util/models/__pycache__/OCR_network.cpython-36.pyc ADDED
Binary file (8.62 kB). View file
 
util/models/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (3.1 kB). View file
 
util/models/__pycache__/blocks.cpython-36.pyc ADDED
Binary file (6 kB). View file
 
util/models/__pycache__/inception.cpython-36.pyc ADDED
Binary file (10.7 kB). View file
 
util/models/__pycache__/model.cpython-36.pyc ADDED
Binary file (31.1 kB). View file
 
util/models/__pycache__/model_.cpython-36.pyc ADDED
Binary file (28.8 kB). View file
 
util/models/__pycache__/networks.cpython-36.pyc ADDED
Binary file (4.08 kB). View file
 
util/models/__pycache__/transformer.cpython-36.pyc ADDED
Binary file (8.92 kB). View file
 
util/models/blocks.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+
6
+ class ResBlocks(nn.Module):
7
+ def __init__(self, num_blocks, dim, norm, activation, pad_type):
8
+ super(ResBlocks, self).__init__()
9
+ self.model = []
10
+ for i in range(num_blocks):
11
+ self.model += [ResBlock(dim,
12
+ norm=norm,
13
+ activation=activation,
14
+ pad_type=pad_type)]
15
+ self.model = nn.Sequential(*self.model)
16
+
17
+ def forward(self, x):
18
+ return self.model(x)
19
+
20
+
21
+ class ResBlock(nn.Module):
22
+ def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
23
+ super(ResBlock, self).__init__()
24
+ model = []
25
+ model += [Conv2dBlock(dim, dim, 3, 1, 1,
26
+ norm=norm,
27
+ activation=activation,
28
+ pad_type=pad_type)]
29
+ model += [Conv2dBlock(dim, dim, 3, 1, 1,
30
+ norm=norm,
31
+ activation='none',
32
+ pad_type=pad_type)]
33
+ self.model = nn.Sequential(*model)
34
+
35
+ def forward(self, x):
36
+ residual = x
37
+ out = self.model(x)
38
+ out += residual
39
+ return out
40
+
41
+
42
+ class ActFirstResBlock(nn.Module):
43
+ def __init__(self, fin, fout, fhid=None,
44
+ activation='lrelu', norm='none'):
45
+ super().__init__()
46
+ self.learned_shortcut = (fin != fout)
47
+ self.fin = fin
48
+ self.fout = fout
49
+ self.fhid = min(fin, fout) if fhid is None else fhid
50
+ self.conv_0 = Conv2dBlock(self.fin, self.fhid, 3, 1,
51
+ padding=1, pad_type='reflect', norm=norm,
52
+ activation=activation, activation_first=True)
53
+ self.conv_1 = Conv2dBlock(self.fhid, self.fout, 3, 1,
54
+ padding=1, pad_type='reflect', norm=norm,
55
+ activation=activation, activation_first=True)
56
+ if self.learned_shortcut:
57
+ self.conv_s = Conv2dBlock(self.fin, self.fout, 1, 1,
58
+ activation='none', use_bias=False)
59
+
60
+ def forward(self, x):
61
+ x_s = self.conv_s(x) if self.learned_shortcut else x
62
+ dx = self.conv_0(x)
63
+ dx = self.conv_1(dx)
64
+ out = x_s + dx
65
+ return out
66
+
67
+
68
+ class LinearBlock(nn.Module):
69
+ def __init__(self, in_dim, out_dim, norm='none', activation='relu'):
70
+ super(LinearBlock, self).__init__()
71
+ use_bias = True
72
+ self.fc = nn.Linear(in_dim, out_dim, bias=use_bias)
73
+
74
+ # initialize normalization
75
+ norm_dim = out_dim
76
+ if norm == 'bn':
77
+ self.norm = nn.BatchNorm1d(norm_dim)
78
+ elif norm == 'in':
79
+ self.norm = nn.InstanceNorm1d(norm_dim)
80
+ elif norm == 'none':
81
+ self.norm = None
82
+ else:
83
+ assert 0, "Unsupported normalization: {}".format(norm)
84
+
85
+ # initialize activation
86
+ if activation == 'relu':
87
+ self.activation = nn.ReLU(inplace=False)
88
+ elif activation == 'lrelu':
89
+ self.activation = nn.LeakyReLU(0.2, inplace=False)
90
+ elif activation == 'tanh':
91
+ self.activation = nn.Tanh()
92
+ elif activation == 'none':
93
+ self.activation = None
94
+ else:
95
+ assert 0, "Unsupported activation: {}".format(activation)
96
+
97
+ def forward(self, x):
98
+ out = self.fc(x)
99
+ if self.norm:
100
+ out = self.norm(out)
101
+ if self.activation:
102
+ out = self.activation(out)
103
+ return out
104
+
105
+
106
+ class Conv2dBlock(nn.Module):
107
+ def __init__(self, in_dim, out_dim, ks, st, padding=0,
108
+ norm='none', activation='relu', pad_type='zero',
109
+ use_bias=True, activation_first=False):
110
+ super(Conv2dBlock, self).__init__()
111
+ self.use_bias = use_bias
112
+ self.activation_first = activation_first
113
+ # initialize padding
114
+ if pad_type == 'reflect':
115
+ self.pad = nn.ReflectionPad2d(padding)
116
+ elif pad_type == 'replicate':
117
+ self.pad = nn.ReplicationPad2d(padding)
118
+ elif pad_type == 'zero':
119
+ self.pad = nn.ZeroPad2d(padding)
120
+ else:
121
+ assert 0, "Unsupported padding type: {}".format(pad_type)
122
+
123
+ # initialize normalization
124
+ norm_dim = out_dim
125
+ if norm == 'bn':
126
+ self.norm = nn.BatchNorm2d(norm_dim)
127
+ elif norm == 'in':
128
+ self.norm = nn.InstanceNorm2d(norm_dim)
129
+ elif norm == 'adain':
130
+ self.norm = AdaptiveInstanceNorm2d(norm_dim)
131
+ elif norm == 'none':
132
+ self.norm = None
133
+ else:
134
+ assert 0, "Unsupported normalization: {}".format(norm)
135
+
136
+ # initialize activation
137
+ if activation == 'relu':
138
+ self.activation = nn.ReLU(inplace=False)
139
+ elif activation == 'lrelu':
140
+ self.activation = nn.LeakyReLU(0.2, inplace=False)
141
+ elif activation == 'tanh':
142
+ self.activation = nn.Tanh()
143
+ elif activation == 'none':
144
+ self.activation = None
145
+ else:
146
+ assert 0, "Unsupported activation: {}".format(activation)
147
+
148
+ self.conv = nn.Conv2d(in_dim, out_dim, ks, st, bias=self.use_bias)
149
+
150
+ def forward(self, x):
151
+ if self.activation_first:
152
+ if self.activation:
153
+ x = self.activation(x)
154
+ x = self.conv(self.pad(x))
155
+ if self.norm:
156
+ x = self.norm(x)
157
+ else:
158
+ x = self.conv(self.pad(x))
159
+ if self.norm:
160
+ x = self.norm(x)
161
+ if self.activation:
162
+ x = self.activation(x)
163
+ return x
164
+
165
+
166
+ class AdaptiveInstanceNorm2d(nn.Module):
167
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
168
+ super(AdaptiveInstanceNorm2d, self).__init__()
169
+ self.num_features = num_features
170
+ self.eps = eps
171
+ self.momentum = momentum
172
+ self.weight = None
173
+ self.bias = None
174
+ self.register_buffer('running_mean', torch.zeros(num_features))
175
+ self.register_buffer('running_var', torch.ones(num_features))
176
+
177
+ def forward(self, x):
178
+ assert self.weight is not None and \
179
+ self.bias is not None, "Please assign AdaIN weight first"
180
+ b, c = x.size(0), x.size(1)
181
+ running_mean = self.running_mean.repeat(b)
182
+ running_var = self.running_var.repeat(b)
183
+ x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
184
+ out = F.batch_norm(
185
+ x_reshaped, running_mean, running_var, self.weight, self.bias,
186
+ True, self.momentum, self.eps)
187
+ return out.view(b, c, *x.size()[2:])
188
+
189
+ def __repr__(self):
190
+ return self.__class__.__name__ + '(' + str(self.num_features) + ')'
util/models/inception.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import models
5
+ import numpy as np
6
+
7
+ from itertools import cycle
8
+ from scipy import linalg
9
+
10
+
11
+ try:
12
+ from torchvision.models.utils import load_state_dict_from_url
13
+ except ImportError:
14
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
15
+
16
+ # Inception weights ported to Pytorch from
17
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
18
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
19
+
20
+
21
+ class InceptionV3(nn.Module):
22
+ """Pretrained InceptionV3 network returning feature maps"""
23
+
24
+ # Index of default block of inception to return,
25
+ # corresponds to output of final average pooling
26
+ DEFAULT_BLOCK_INDEX = 3
27
+
28
+ # Maps feature dimensionality to their output blocks indices
29
+ BLOCK_INDEX_BY_DIM = {
30
+ 64: 0, # First max pooling features
31
+ 192: 1, # Second max pooling featurs
32
+ 768: 2, # Pre-aux classifier features
33
+ 2048: 3 # Final average pooling features
34
+ }
35
+
36
+ def __init__(self,
37
+ output_blocks=[DEFAULT_BLOCK_INDEX],
38
+ resize_input=True,
39
+ normalize_input=True,
40
+ requires_grad=False,
41
+ use_fid_inception=True):
42
+ """Build pretrained InceptionV3
43
+ Parameters
44
+ ----------
45
+ output_blocks : list of int
46
+ Indices of blocks to return features of. Possible values are:
47
+ - 0: corresponds to output of first max pooling
48
+ - 1: corresponds to output of second max pooling
49
+ - 2: corresponds to output which is fed to aux classifier
50
+ - 3: corresponds to output of final average pooling
51
+ resize_input : bool
52
+ If true, bilinearly resizes input to width and height 299 before
53
+ feeding input to model. As the network without fully connected
54
+ layers is fully convolutional, it should be able to handle inputs
55
+ of arbitrary size, so resizing might not be strictly needed
56
+ normalize_input : bool
57
+ If true, scales the input from range (0, 1) to the range the
58
+ pretrained Inception network expects, namely (-1, 1)
59
+ requires_grad : bool
60
+ If true, parameters of the model require gradients. Possibly useful
61
+ for finetuning the network
62
+ use_fid_inception : bool
63
+ If true, uses the pretrained Inception model used in Tensorflow's
64
+ FID implementation. If false, uses the pretrained Inception model
65
+ available in torchvision. The FID Inception model has different
66
+ weights and a slightly different structure from torchvision's
67
+ Inception model. If you want to compute FID scores, you are
68
+ strongly advised to set this parameter to true to get comparable
69
+ results.
70
+ """
71
+ super(InceptionV3, self).__init__()
72
+
73
+ self.resize_input = resize_input
74
+ self.normalize_input = normalize_input
75
+ self.output_blocks = sorted(output_blocks)
76
+ self.last_needed_block = max(output_blocks)
77
+
78
+ assert self.last_needed_block <= 3, \
79
+ 'Last possible output block index is 3'
80
+
81
+ self.blocks = nn.ModuleList()
82
+
83
+ if use_fid_inception:
84
+ inception = fid_inception_v3()
85
+ else:
86
+ inception = models.inception_v3(pretrained=True)
87
+
88
+ # Block 0: input to maxpool1
89
+ block0 = [
90
+ inception.Conv2d_1a_3x3,
91
+ inception.Conv2d_2a_3x3,
92
+ inception.Conv2d_2b_3x3,
93
+ nn.MaxPool2d(kernel_size=3, stride=2)
94
+ ]
95
+ self.blocks.append(nn.Sequential(*block0))
96
+
97
+ # Block 1: maxpool1 to maxpool2
98
+ if self.last_needed_block >= 1:
99
+ block1 = [
100
+ inception.Conv2d_3b_1x1,
101
+ inception.Conv2d_4a_3x3,
102
+ nn.MaxPool2d(kernel_size=3, stride=2)
103
+ ]
104
+ self.blocks.append(nn.Sequential(*block1))
105
+
106
+ # Block 2: maxpool2 to aux classifier
107
+ if self.last_needed_block >= 2:
108
+ block2 = [
109
+ inception.Mixed_5b,
110
+ inception.Mixed_5c,
111
+ inception.Mixed_5d,
112
+ inception.Mixed_6a,
113
+ inception.Mixed_6b,
114
+ inception.Mixed_6c,
115
+ inception.Mixed_6d,
116
+ inception.Mixed_6e,
117
+ ]
118
+ self.blocks.append(nn.Sequential(*block2))
119
+
120
+ # Block 3: aux classifier to final avgpool
121
+ if self.last_needed_block >= 3:
122
+ block3 = [
123
+ inception.Mixed_7a,
124
+ inception.Mixed_7b,
125
+ inception.Mixed_7c,
126
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
127
+ ]
128
+ self.blocks.append(nn.Sequential(*block3))
129
+
130
+ for param in self.parameters():
131
+ param.requires_grad = requires_grad
132
+
133
+ def forward(self, inp):
134
+ """Get Inception feature maps
135
+ Parameters
136
+ ----------
137
+ inp : torch.autograd.Variable
138
+ Input tensor of shape Bx3xHxW. Values are expected to be in
139
+ range (0, 1)
140
+ Returns
141
+ -------
142
+ List of torch.autograd.Variable, corresponding to the selected output
143
+ block, sorted ascending by index
144
+ """
145
+ outp = []
146
+ x = inp
147
+
148
+ if self.resize_input:
149
+ x = F.interpolate(x,
150
+ size=(299, 299),
151
+ mode='bilinear',
152
+ align_corners=False)
153
+
154
+ if self.normalize_input:
155
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
156
+
157
+ for idx, block in enumerate(self.blocks):
158
+ x = block(x)
159
+ if idx in self.output_blocks:
160
+ outp.append(x)
161
+
162
+ if idx == self.last_needed_block:
163
+ break
164
+
165
+ return outp
166
+
167
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
168
+ """Numpy implementation of the Frechet Distance.
169
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
170
+ and X_2 ~ N(mu_2, C_2) is
171
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
172
+ Stable version by Dougal J. Sutherland.
173
+ Params:
174
+ -- mu1 : Numpy array containing the activations of a layer of the
175
+ inception net (like returned by the function 'get_predictions')
176
+ for generated samples.
177
+ -- mu2 : The sample mean over activations, precalculated on an
178
+ representative data set.
179
+ -- sigma1: The covariance matrix over activations for generated samples.
180
+ -- sigma2: The covariance matrix over activations, precalculated on an
181
+ representative data set.
182
+ Returns:
183
+ -- : The Frechet Distance.
184
+ """
185
+
186
+ mu1 = np.atleast_1d(mu1)
187
+ mu2 = np.atleast_1d(mu2)
188
+
189
+ sigma1 = np.atleast_2d(sigma1)
190
+ sigma2 = np.atleast_2d(sigma2)
191
+
192
+ assert mu1.shape == mu2.shape, \
193
+ 'Training and test mean vectors have different lengths'
194
+ assert sigma1.shape == sigma2.shape, \
195
+ 'Training and test covariances have different dimensions'
196
+
197
+ diff = mu1 - mu2
198
+
199
+ # Product might be almost singular
200
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
201
+ if not np.isfinite(covmean).all():
202
+ msg = ('fid calculation produces singular product; '
203
+ 'adding %s to diagonal of cov estimates') % eps
204
+ print(msg)
205
+ offset = np.eye(sigma1.shape[0]) * eps
206
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
207
+
208
+ # Numerical error might give slight imaginary component
209
+ if np.iscomplexobj(covmean):
210
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
211
+ m = np.max(np.abs(covmean.imag))
212
+ raise ValueError('Imaginary component {}'.format(m))
213
+ covmean = covmean.real
214
+
215
+ tr_covmean = np.trace(covmean)
216
+
217
+ return (diff.dot(diff) + np.trace(sigma1) +
218
+ np.trace(sigma2) - 2 * tr_covmean)
219
+
220
+
221
+ def fid_inception_v3():
222
+ """Build pretrained Inception model for FID computation
223
+ The Inception model for FID computation uses a different set of weights
224
+ and has a slightly different structure than torchvision's Inception.
225
+ This method first constructs torchvision's Inception and then patches the
226
+ necessary parts that are different in the FID Inception model.
227
+ """
228
+ inception = models.inception_v3(num_classes=1008,
229
+ aux_logits=False,
230
+ pretrained=False)
231
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
232
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
233
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
234
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
235
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
236
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
237
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
238
+ inception.Mixed_7b = FIDInceptionE_1(1280)
239
+ inception.Mixed_7c = FIDInceptionE_2(2048)
240
+
241
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
242
+ inception.load_state_dict(state_dict)
243
+ return inception
244
+
245
+
246
+ class FIDInceptionA(models.inception.InceptionA):
247
+ """InceptionA block patched for FID computation"""
248
+ def __init__(self, in_channels, pool_features):
249
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
250
+
251
+ def forward(self, x):
252
+ branch1x1 = self.branch1x1(x)
253
+
254
+ branch5x5 = self.branch5x5_1(x)
255
+ branch5x5 = self.branch5x5_2(branch5x5)
256
+
257
+ branch3x3dbl = self.branch3x3dbl_1(x)
258
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
259
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
260
+
261
+ # Patch: Tensorflow's average pool does not use the padded zero's in
262
+ # its average calculation
263
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
264
+ count_include_pad=False)
265
+ branch_pool = self.branch_pool(branch_pool)
266
+
267
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
268
+ return torch.cat(outputs, 1)
269
+
270
+
271
+ class FIDInceptionC(models.inception.InceptionC):
272
+ """InceptionC block patched for FID computation"""
273
+ def __init__(self, in_channels, channels_7x7):
274
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
275
+
276
+ def forward(self, x):
277
+ branch1x1 = self.branch1x1(x)
278
+
279
+ branch7x7 = self.branch7x7_1(x)
280
+ branch7x7 = self.branch7x7_2(branch7x7)
281
+ branch7x7 = self.branch7x7_3(branch7x7)
282
+
283
+ branch7x7dbl = self.branch7x7dbl_1(x)
284
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
285
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
286
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
287
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
288
+
289
+ # Patch: Tensorflow's average pool does not use the padded zero's in
290
+ # its average calculation
291
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
292
+ count_include_pad=False)
293
+ branch_pool = self.branch_pool(branch_pool)
294
+
295
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
296
+ return torch.cat(outputs, 1)
297
+
298
+
299
+ class FIDInceptionE_1(models.inception.InceptionE):
300
+ """First InceptionE block patched for FID computation"""
301
+ def __init__(self, in_channels):
302
+ super(FIDInceptionE_1, self).__init__(in_channels)
303
+
304
+ def forward(self, x):
305
+ branch1x1 = self.branch1x1(x)
306
+
307
+ branch3x3 = self.branch3x3_1(x)
308
+ branch3x3 = [
309
+ self.branch3x3_2a(branch3x3),
310
+ self.branch3x3_2b(branch3x3),
311
+ ]
312
+ branch3x3 = torch.cat(branch3x3, 1)
313
+
314
+ branch3x3dbl = self.branch3x3dbl_1(x)
315
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
316
+ branch3x3dbl = [
317
+ self.branch3x3dbl_3a(branch3x3dbl),
318
+ self.branch3x3dbl_3b(branch3x3dbl),
319
+ ]
320
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
321
+
322
+ # Patch: Tensorflow's average pool does not use the padded zero's in
323
+ # its average calculation
324
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
325
+ count_include_pad=False)
326
+ branch_pool = self.branch_pool(branch_pool)
327
+
328
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
329
+ return torch.cat(outputs, 1)
330
+
331
+
332
+ class FIDInceptionE_2(models.inception.InceptionE):
333
+ """Second InceptionE block patched for FID computation"""
334
+ def __init__(self, in_channels):
335
+ super(FIDInceptionE_2, self).__init__(in_channels)
336
+
337
+ def forward(self, x):
338
+ branch1x1 = self.branch1x1(x)
339
+
340
+ branch3x3 = self.branch3x3_1(x)
341
+ branch3x3 = [
342
+ self.branch3x3_2a(branch3x3),
343
+ self.branch3x3_2b(branch3x3),
344
+ ]
345
+ branch3x3 = torch.cat(branch3x3, 1)
346
+
347
+ branch3x3dbl = self.branch3x3dbl_1(x)
348
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
349
+ branch3x3dbl = [
350
+ self.branch3x3dbl_3a(branch3x3dbl),
351
+ self.branch3x3dbl_3b(branch3x3dbl),
352
+ ]
353
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
354
+
355
+ # Patch: The FID Inception model uses max pooling instead of average
356
+ # pooling. This is likely an error in this specific Inception
357
+ # implementation, as other Inception models use average pooling here
358
+ # (which matches the description in the paper).
359
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
360
+ branch_pool = self.branch_pool(branch_pool)
361
+
362
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
363
+ return torch.cat(outputs, 1)
util/models/model.py ADDED
@@ -0,0 +1,1389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ from .OCR_network import *
4
+ from torch.nn import CTCLoss, MSELoss, L1Loss
5
+ from torch.nn.utils import clip_grad_norm_
6
+ import random
7
+ import unicodedata
8
+ import sys
9
+ import torchvision.models as models
10
+ from models.transformer import *
11
+ from .BigGAN_networks import *
12
+ from params import *
13
+ from .OCR_network import *
14
+ from models.blocks import LinearBlock, Conv2dBlock, ResBlocks, ActFirstResBlock
15
+ from util.util import toggle_grad, loss_hinge_dis, loss_hinge_gen, ortho, default_ortho, toggle_grad, prepare_z_y, \
16
+ make_one_hot, to_device, multiple_replace, random_word
17
+ from models.inception import InceptionV3, calculate_frechet_distance
18
+ import cv2
19
+
20
+ class FCNDecoder(nn.Module):
21
+ def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'):
22
+ super(FCNDecoder, self).__init__()
23
+
24
+ self.model = []
25
+ self.model += [ResBlocks(n_res, dim, res_norm,
26
+ activ, pad_type=pad_type)]
27
+ for i in range(ups):
28
+ self.model += [nn.Upsample(scale_factor=2),
29
+ Conv2dBlock(dim, dim // 2, 5, 1, 2,
30
+ norm='in',
31
+ activation=activ,
32
+ pad_type=pad_type)]
33
+ dim //= 2
34
+ self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3,
35
+ norm='none',
36
+ activation='tanh',
37
+ pad_type=pad_type)]
38
+ self.model = nn.Sequential(*self.model)
39
+
40
+ def forward(self, x):
41
+ y = self.model(x)
42
+
43
+ return y
44
+
45
+
46
+
47
+ class Generator(nn.Module):
48
+
49
+ def __init__(self):
50
+ super(Generator, self).__init__()
51
+
52
+ INP_CHANNEL = NUM_EXAMPLES
53
+ if IS_SEQ: INP_CHANNEL = 1
54
+
55
+
56
+ encoder_layer = TransformerEncoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD,
57
+ TN_DROPOUT, "relu", True)
58
+ encoder_norm = nn.LayerNorm(TN_HIDDEN_DIM) if True else None
59
+ self.encoder = TransformerEncoder(encoder_layer, TN_ENC_LAYERS, encoder_norm)
60
+
61
+ decoder_layer = TransformerDecoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD,
62
+ TN_DROPOUT, "relu", True)
63
+ decoder_norm = nn.LayerNorm(TN_HIDDEN_DIM)
64
+ self.decoder = TransformerDecoder(decoder_layer, TN_DEC_LAYERS, decoder_norm,
65
+ return_intermediate=True)
66
+
67
+ self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(INP_CHANNEL, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2]))
68
+
69
+ self.query_embed = nn.Embedding(VOCAB_SIZE, TN_HIDDEN_DIM)
70
+
71
+
72
+ self.linear_q = nn.Linear(TN_DIM_FEEDFORWARD*2, TN_DIM_FEEDFORWARD*8)
73
+
74
+ self.DEC = FCNDecoder(res_norm = 'in')
75
+
76
+
77
+ self._muE = nn.Linear(512,512)
78
+ self._logvarE = nn.Linear(512,512)
79
+
80
+ self._muD = nn.Linear(512,512)
81
+ self._logvarD = nn.Linear(512,512)
82
+
83
+
84
+ self.l1loss = nn.L1Loss()
85
+
86
+ self.noise = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0]))
87
+
88
+
89
+
90
+
91
+
92
+
93
+ def reparameterize(self, mu, logvar):
94
+
95
+ mu = torch.unbind(mu , 1)
96
+ logvar = torch.unbind(logvar , 1)
97
+
98
+ outs = []
99
+
100
+ for m,l in zip(mu, logvar):
101
+
102
+ sigma = torch.exp(l)
103
+ eps = torch.cuda.FloatTensor(l.size()[0],1).normal_(0,1)
104
+ eps = eps.expand(sigma.size())
105
+
106
+ out = m + sigma*eps
107
+
108
+ outs.append(out)
109
+
110
+
111
+ return torch.stack(outs, 1)
112
+
113
+
114
+ def Eval(self, ST, QRS):
115
+
116
+ if IS_SEQ:
117
+ B, N, R, C = ST.shape
118
+ FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C))
119
+ FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
120
+ else:
121
+ FEAT_ST = self.Feat_Encoder(ST)
122
+
123
+
124
+ FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1)
125
+
126
+ memory = self.encoder(FEAT_ST_ENC)
127
+
128
+ if IS_KLD:
129
+
130
+ Ex = memory.permute(1,0,2)
131
+
132
+ memory_mu = self._muE(Ex)
133
+ memory_logvar = self._logvarE(Ex)
134
+
135
+ memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2)
136
+
137
+
138
+ OUT_IMGS = []
139
+
140
+ for i in range(QRS.shape[1]):
141
+
142
+ QR = QRS[:, i, :]
143
+
144
+ QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
145
+
146
+ tgt = torch.zeros_like(QR_EMB)
147
+
148
+ hs = self.decoder(tgt, memory, query_pos=QR_EMB)
149
+
150
+ if IS_KLD:
151
+
152
+ Dx = hs[0].permute(1,0,2)
153
+
154
+ hs_mu = self._muD(Dx)
155
+ hs_logvar = self._logvarD(Dx)
156
+
157
+ hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
158
+
159
+
160
+ h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
161
+ if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
162
+
163
+ h = self.linear_q(h)
164
+ h = h.contiguous()
165
+
166
+ h = h.view(h.size(0), h.shape[1]*2, 4, -1)
167
+ h = h.permute(0, 3, 2, 1)
168
+
169
+ h = self.DEC(h)
170
+
171
+
172
+ OUT_IMGS.append(h.detach())
173
+
174
+
175
+
176
+ return OUT_IMGS
177
+
178
+
179
+
180
+
181
+
182
+
183
+ def forward(self, ST, QR, QRs = None, mode = 'train'):
184
+
185
+ #Attention Visualization Init
186
+
187
+
188
+ enc_attn_weights, dec_attn_weights = [], []
189
+
190
+ self.hooks = [
191
+
192
+ self.encoder.layers[-1].self_attn.register_forward_hook(
193
+ lambda self, input, output: enc_attn_weights.append(output[1])
194
+ ),
195
+ self.decoder.layers[-1].multihead_attn.register_forward_hook(
196
+ lambda self, input, output: dec_attn_weights.append(output[1])
197
+ ),
198
+ ]
199
+
200
+
201
+ #Attention Visualization Init
202
+
203
+ if IS_SEQ:
204
+ B, N, R, C = ST.shape
205
+ FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C))
206
+ FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
207
+ else:
208
+ FEAT_ST = self.Feat_Encoder(ST)
209
+
210
+
211
+ FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1)
212
+
213
+ memory = self.encoder(FEAT_ST_ENC)
214
+
215
+ if IS_KLD:
216
+
217
+ Ex = memory.permute(1,0,2)
218
+
219
+ memory_mu = self._muE(Ex)
220
+ memory_logvar = self._logvarE(Ex)
221
+
222
+ memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2)
223
+
224
+
225
+ QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
226
+
227
+ tgt = torch.zeros_like(QR_EMB)
228
+
229
+ hs = self.decoder(tgt, memory, query_pos=QR_EMB)
230
+
231
+ if IS_KLD:
232
+
233
+ Dx = hs[0].permute(1,0,2)
234
+
235
+ hs_mu = self._muD(Dx)
236
+ hs_logvar = self._logvarD(Dx)
237
+
238
+ hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
239
+
240
+ OUT_Feats1_mu = [hs_mu]
241
+ OUT_Feats1_logvar = [hs_logvar]
242
+
243
+
244
+ OUT_Feats1 = [hs]
245
+
246
+
247
+ h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
248
+
249
+ if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
250
+
251
+ h = self.linear_q(h)
252
+ h = h.contiguous()
253
+
254
+ h = h.view(h.size(0), h.shape[1]*2, 4, -1)
255
+ h = h.permute(0, 3, 2, 1)
256
+
257
+ h = self.DEC(h)
258
+
259
+ self.dec_attn_weights = dec_attn_weights[-1].detach()
260
+ self.enc_attn_weights = enc_attn_weights[-1].detach()
261
+
262
+
263
+
264
+ for hook in self.hooks:
265
+ hook.remove()
266
+
267
+ if mode == 'test' or (not IS_CYCLE and not IS_KLD):
268
+
269
+ return h
270
+
271
+
272
+ OUT_IMGS = [h]
273
+
274
+ for QR in QRs:
275
+
276
+ QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
277
+
278
+ tgt = torch.zeros_like(QR_EMB)
279
+
280
+ hs = self.decoder(tgt, memory, query_pos=QR_EMB)
281
+
282
+
283
+ if IS_KLD:
284
+
285
+ Dx = hs[0].permute(1,0,2)
286
+
287
+ hs_mu = self._muD(Dx)
288
+ hs_logvar = self._logvarD(Dx)
289
+
290
+ hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
291
+
292
+ OUT_Feats1_mu.append(hs_mu)
293
+ OUT_Feats1_logvar.append(hs_logvar)
294
+
295
+
296
+ OUT_Feats1.append(hs)
297
+
298
+
299
+ h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
300
+ if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
301
+
302
+ h = self.linear_q(h)
303
+ h = h.contiguous()
304
+
305
+ h = h.view(h.size(0), h.shape[1]*2, 4, -1)
306
+ h = h.permute(0, 3, 2, 1)
307
+
308
+ h = self.DEC(h)
309
+
310
+ OUT_IMGS.append(h)
311
+
312
+
313
+ if (not IS_CYCLE) and IS_KLD:
314
+
315
+ OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0]
316
+
317
+ OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1);
318
+
319
+
320
+ KLD = (0.5 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \
321
+ + (0.5 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp()))
322
+
323
+
324
+
325
+ def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar):
326
+ return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \
327
+ torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum()
328
+
329
+
330
+ lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])]
331
+
332
+
333
+ lda1 = torch.stack(lda1).mean()
334
+
335
+
336
+
337
+ return OUT_IMGS[0], lda1, KLD
338
+
339
+
340
+ with torch.no_grad():
341
+
342
+ if IS_SEQ:
343
+
344
+ FEAT_ST_T = torch.cat([self.Feat_Encoder(IM) for IM in OUT_IMGS], -1)
345
+
346
+ else:
347
+
348
+ max_width_ = max([i_.shape[-1] for i_ in OUT_IMGS])
349
+
350
+ FEAT_ST_T = self.Feat_Encoder(torch.cat([torch.cat([i_, torch.ones((i_.shape[0], i_.shape[1],i_.shape[2], max_width_-i_.shape[3])).to(DEVICE)], -1) for i_ in OUT_IMGS], 1))
351
+
352
+ FEAT_ST_ENC_T = FEAT_ST_T.flatten(2).permute(2,0,1)
353
+
354
+ memory_T = self.encoder(FEAT_ST_ENC_T)
355
+
356
+ if IS_KLD:
357
+
358
+ Ex = memory_T.permute(1,0,2)
359
+
360
+ memory_T_mu = self._muE(Ex)
361
+ memory_T_logvar = self._logvarE(Ex)
362
+
363
+ memory_T = self.reparameterize(memory_T_mu, memory_T_logvar).permute(1,0,2)
364
+
365
+
366
+ QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
367
+
368
+ tgt = torch.zeros_like(QR_EMB)
369
+
370
+ hs = self.decoder(tgt, memory_T, query_pos=QR_EMB)
371
+
372
+ if IS_KLD:
373
+
374
+ Dx = hs[0].permute(1,0,2)
375
+
376
+ hs_mu = self._muD(Dx)
377
+ hs_logvar = self._logvarD(Dx)
378
+
379
+ hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
380
+
381
+ OUT_Feats2_mu = [hs_mu]
382
+ OUT_Feats2_logvar = [hs_logvar]
383
+
384
+
385
+ OUT_Feats2 = [hs]
386
+
387
+
388
+
389
+ for QR in QRs:
390
+
391
+ QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
392
+
393
+ tgt = torch.zeros_like(QR_EMB)
394
+
395
+ hs = self.decoder(tgt, memory_T, query_pos=QR_EMB)
396
+
397
+ if IS_KLD:
398
+
399
+ Dx = hs[0].permute(1,0,2)
400
+
401
+ hs_mu = self._muD(Dx)
402
+ hs_logvar = self._logvarD(Dx)
403
+
404
+ hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
405
+
406
+ OUT_Feats2_mu.append(hs_mu)
407
+ OUT_Feats2_logvar.append(hs_logvar)
408
+
409
+
410
+ OUT_Feats2.append(hs)
411
+
412
+
413
+
414
+
415
+ Lcycle1 = np.sum([self.l1loss(memory[m_i], memory_T[m_j]) for m_i in range(memory.shape[0]) for m_j in range(memory_T.shape[0])])/(memory.shape[0]*memory_T.shape[0])
416
+ OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0]; OUT_Feats2 = torch.cat(OUT_Feats2, 1)[0]
417
+
418
+ Lcycle2 = np.sum([self.l1loss(OUT_Feats1[f_i], OUT_Feats2[f_j]) for f_i in range(OUT_Feats1.shape[0]) for f_j in range(OUT_Feats2.shape[0])])/(OUT_Feats1.shape[0]*OUT_Feats2.shape[0])
419
+
420
+ if IS_KLD:
421
+
422
+ OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1);
423
+ OUT_Feats2_mu = torch.cat(OUT_Feats2_mu, 1); OUT_Feats2_logvar = torch.cat(OUT_Feats2_logvar, 1);
424
+
425
+ KLD = (0.25 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \
426
+ + (0.25 * torch.mean(1 + memory_T_logvar - memory_T_mu.pow(2) - memory_T_logvar.exp()))\
427
+ + (0.25 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp()))\
428
+ + (0.25 * torch.mean(1 + OUT_Feats2_logvar - OUT_Feats2_mu.pow(2) - OUT_Feats2_logvar.exp()))
429
+
430
+
431
+ def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar):
432
+ return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \
433
+ torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum()
434
+
435
+
436
+ lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])]
437
+ lda2 = [_get_lda(memory_T_mu[:,idi,:], OUT_Feats2_mu[:,idj,:], memory_T_logvar[:,idi,:], OUT_Feats2_logvar[:,idj,:]) for idi in range(memory_T.shape[0]) for idj in range(OUT_Feats2.shape[0])]
438
+
439
+ lda1 = torch.stack(lda1).mean()
440
+ lda2 = torch.stack(lda2).mean()
441
+
442
+
443
+ return OUT_IMGS[0], Lcycle1, Lcycle2, lda1, lda2, KLD
444
+
445
+
446
+ return OUT_IMGS[0], Lcycle1, Lcycle2
447
+
448
+
449
+
450
+ class TRGAN(nn.Module):
451
+
452
+ def __init__(self):
453
+ super(TRGAN, self).__init__()
454
+
455
+
456
+ self.epsilon = 1e-7
457
+ self.netG = Generator().to(DEVICE)
458
+ self.netD = nn.DataParallel(Discriminator()).to(DEVICE)
459
+ self.netW = nn.DataParallel(WDiscriminator()).to(DEVICE)
460
+ self.netconverter = strLabelConverter(ALPHABET)
461
+ self.netOCR = CRNN().to(DEVICE)
462
+ self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none')
463
+
464
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
465
+ self.inception = InceptionV3([block_idx]).to(DEVICE)
466
+
467
+
468
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
469
+ lr=G_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
470
+ self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(),
471
+ lr=OCR_LR, betas=(0.0, 0.999), weight_decay=0,
472
+ eps=1e-8)
473
+
474
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
475
+ lr=D_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
476
+
477
+
478
+ self.optimizer_wl = torch.optim.Adam(self.netW.parameters(),
479
+ lr=W_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
480
+
481
+
482
+ self.optimizers = [self.optimizer_G, self.optimizer_OCR, self.optimizer_D, self.optimizer_wl]
483
+
484
+
485
+ self.optimizer_G.zero_grad()
486
+ self.optimizer_OCR.zero_grad()
487
+ self.optimizer_D.zero_grad()
488
+ self.optimizer_wl.zero_grad()
489
+
490
+ self.loss_G = 0
491
+ self.loss_D = 0
492
+ self.loss_Dfake = 0
493
+ self.loss_Dreal = 0
494
+ self.loss_OCR_fake = 0
495
+ self.loss_OCR_real = 0
496
+ self.loss_w_fake = 0
497
+ self.loss_w_real = 0
498
+ self.Lcycle1 = 0
499
+ self.Lcycle2 = 0
500
+ self.lda1 = 0
501
+ self.lda2 = 0
502
+ self.KLD = 0
503
+
504
+
505
+ with open('../Lexicon/english_words.txt', 'rb') as f:
506
+ self.lex = f.read().splitlines()
507
+ lex=[]
508
+ for word in self.lex:
509
+ try:
510
+ word=word.decode("utf-8")
511
+ except:
512
+ continue
513
+ if len(word)<20:
514
+ lex.append(word)
515
+ self.lex = lex
516
+
517
+
518
+ f = open('mytext.txt', 'r')
519
+
520
+ self.text = [j.encode() for j in sum([i.split(' ') for i in f.readlines()], [])][:NUM_EXAMPLES]
521
+ self.eval_text_encode, self.eval_len_text = self.netconverter.encode(self.text)
522
+ self.eval_text_encode = self.eval_text_encode.to(DEVICE).repeat(batch_size, 1, 1)
523
+
524
+
525
+ def _generate_page(self):
526
+
527
+ self.fakes = self.netG.Eval(self.sdata, self.eval_text_encode)
528
+
529
+ word_t = []
530
+ word_l = []
531
+
532
+ gap = np.ones([32,16])
533
+
534
+ line_wids = []
535
+
536
+
537
+ for idx, fake_ in enumerate(self.fakes):
538
+
539
+ word_t.append((fake_[0,0,:,:self.eval_len_text[idx]*resolution].cpu().numpy()+1)/2)
540
+
541
+ word_t.append(gap)
542
+
543
+ if len(word_t) == 16 or idx == len(self.fakes) - 1:
544
+
545
+ line_ = np.concatenate(word_t, -1)
546
+
547
+ word_l.append(line_)
548
+ line_wids.append(line_.shape[1])
549
+
550
+ word_t = []
551
+
552
+
553
+ gap_h = np.ones([16,max(line_wids)])
554
+
555
+ page_= []
556
+
557
+ for l in word_l:
558
+
559
+ pad_ = np.ones([32,max(line_wids) - l.shape[1]])
560
+
561
+ page_.append(np.concatenate([l, pad_], 1))
562
+ page_.append(gap_h)
563
+
564
+
565
+
566
+ page1 = np.concatenate(page_, 0)
567
+
568
+
569
+ word_t = []
570
+ word_l = []
571
+
572
+ gap = np.ones([32,16])
573
+
574
+ line_wids = []
575
+
576
+ sdata_ = [i.unsqueeze(1) for i in torch.unbind(self.sdata, 1)]
577
+
578
+ for idx, st in enumerate((sdata_)):
579
+
580
+ word_t.append((st[0,0,:,:int(self.input['swids'].cpu().numpy()[0][idx])
581
+ ].cpu().numpy()+1)/2)
582
+
583
+ word_t.append(gap)
584
+
585
+ if len(word_t) == 16 or idx == len(self.fakes) - 1:
586
+
587
+ line_ = np.concatenate(word_t, -1)
588
+
589
+ word_l.append(line_)
590
+ line_wids.append(line_.shape[1])
591
+
592
+ word_t = []
593
+
594
+
595
+ gap_h = np.ones([16,max(line_wids)])
596
+
597
+ page_= []
598
+
599
+ for l in word_l:
600
+
601
+ pad_ = np.ones([32,max(line_wids) - l.shape[1]])
602
+
603
+ page_.append(np.concatenate([l, pad_], 1))
604
+ page_.append(gap_h)
605
+
606
+
607
+
608
+ page2 = np.concatenate(page_, 0)
609
+
610
+ merge_w_size = max(page1.shape[0], page2.shape[0])
611
+
612
+ if page1.shape[0] != merge_w_size:
613
+
614
+ page1 = np.concatenate([page1, np.ones([merge_w_size-page1.shape[0], page1.shape[1]])], 0)
615
+
616
+ if page2.shape[0] != merge_w_size:
617
+
618
+ page2 = np.concatenate([page2, np.ones([merge_w_size-page2.shape[0], page2.shape[1]])], 0)
619
+
620
+
621
+ page = np.concatenate([page2, page1], 1)
622
+
623
+
624
+ return page
625
+
626
+
627
+
628
+
629
+
630
+
631
+
632
+
633
+
634
+
635
+
636
+
637
+
638
+
639
+
640
+
641
+
642
+
643
+
644
+ #FEAT1 = self.inception(torch.cat(self.fakes, 0).repeat(1,3,1,1))[0].detach().view(batch_size, len(self.fakes), -1).cpu().numpy()
645
+ #FEAT2 = self.inception(self.sdata.view(batch_size*NUM_EXAMPLES, 1, 32, -1).repeat(1,3,1,1))[0].detach().view(batch_size, NUM_EXAMPLES, -1 ).cpu().numpy()
646
+ #muvars1 = [{'mu':np.mean(FEAT1[i], axis=0), 'sigma' : np.cov(FEAT1[i], rowvar=False)} for i in range(FEAT1.shape[0])]
647
+ #muvars2 = [{'mu':np.mean(FEAT2[i], axis=0), 'sigma' : np.cov(FEAT2[i], rowvar=False)} for i in range(FEAT2.shape[0])]
648
+
649
+
650
+
651
+
652
+
653
+
654
+ def get_current_losses(self):
655
+
656
+ losses = {}
657
+
658
+ losses['G'] = self.loss_G
659
+ losses['D'] = self.loss_D
660
+ losses['Dfake'] = self.loss_Dfake
661
+ losses['Dreal'] = self.loss_Dreal
662
+ losses['OCR_fake'] = self.loss_OCR_fake
663
+ losses['OCR_real'] = self.loss_OCR_real
664
+ losses['w_fake'] = self.loss_w_fake
665
+ losses['w_real'] = self.loss_w_real
666
+ losses['cycle1'] = self.Lcycle1
667
+ losses['cycle2'] = self.Lcycle2
668
+ losses['lda1'] = self.lda1
669
+ losses['lda2'] = self.lda2
670
+ losses['KLD'] = self.KLD
671
+
672
+ return losses
673
+
674
+ def visualize_images(self):
675
+
676
+ imgs = {}
677
+
678
+
679
+ imgs['fake-1']=self.netG(self.sdata[0:1], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach()
680
+ imgs['fake-2']=self.netG(self.sdata[0:1], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
681
+ imgs['fake-3']=self.netG(self.sdata[0:1], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
682
+
683
+
684
+ imgs['res-1'] = torch.cat([self.sdata[0, 0],self.sdata[0, 1],self.sdata[0, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
685
+
686
+
687
+ imgs['fake-1']=self.netG(self.sdata[1:2], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach()
688
+ imgs['fake-2']=self.netG(self.sdata[1:2], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
689
+ imgs['fake-3']=self.netG(self.sdata[1:2], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
690
+
691
+
692
+ imgs['res-2'] = torch.cat([self.sdata[1, 0],self.sdata[1, 1],self.sdata[1, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
693
+
694
+
695
+ imgs['fake-1']=self.netG(self.sdata[2:3], self.text_encode_fake[0].unsqueeze(0) , mode = 'test' )[0, 0].detach()
696
+ imgs['fake-2']=self.netG(self.sdata[2:3], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
697
+ imgs['fake-3']=self.netG(self.sdata[2:3], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
698
+
699
+
700
+ imgs['res-3'] = torch.cat([self.sdata[2, 0],self.sdata[2, 1],self.sdata[2, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
701
+
702
+
703
+
704
+
705
+ return imgs
706
+
707
+
708
+ def load_networks(self, epoch):
709
+ BaseModel.load_networks(self, epoch)
710
+ if self.opt.single_writer:
711
+ load_filename = '%s_z.pkl' % (epoch)
712
+ load_path = os.path.join(self.save_dir, load_filename)
713
+ self.z = torch.load(load_path)
714
+
715
+ def _set_input(self, input):
716
+ self.input = input
717
+
718
+ def set_requires_grad(self, nets, requires_grad=False):
719
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
720
+ Parameters:
721
+ nets (network list) -- a list of networks
722
+ requires_grad (bool) -- whether the networks require gradients or not
723
+ """
724
+ if not isinstance(nets, list):
725
+ nets = [nets]
726
+ for net in nets:
727
+ if net is not None:
728
+ for param in net.parameters():
729
+ param.requires_grad = requires_grad
730
+
731
+ def forward(self):
732
+
733
+
734
+ self.real = self.input['img'].to(DEVICE)
735
+ self.label = self.input['label']
736
+ self.sdata = self.input['simg'].to(DEVICE)
737
+ self.ST_LEN = self.input['swids']
738
+ self.text_encode, self.len_text = self.netconverter.encode(self.label)
739
+ self.one_hot_real = make_one_hot(self.text_encode, self.len_text, VOCAB_SIZE).to(DEVICE).detach()
740
+ self.text_encode = self.text_encode.to(DEVICE).detach()
741
+ self.len_text = self.len_text.detach()
742
+
743
+ self.words = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)]
744
+ self.text_encode_fake, self.len_text_fake = self.netconverter.encode(self.words)
745
+ self.text_encode_fake = self.text_encode_fake.to(DEVICE)
746
+ self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, VOCAB_SIZE).to(DEVICE)
747
+
748
+
749
+ self.text_encode_fake_js = []
750
+
751
+ for _ in range(NUM_WORDS - 1):
752
+
753
+ self.words_j = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)]
754
+ self.text_encode_fake_j, self.len_text_fake_j = self.netconverter.encode(self.words_j)
755
+ self.text_encode_fake_j = self.text_encode_fake_j.to(DEVICE)
756
+ self.text_encode_fake_js.append(self.text_encode_fake_j)
757
+
758
+
759
+ if IS_CYCLE and IS_KLD:
760
+
761
+ self.fake, self.Lcycle1, self.Lcycle2, self.lda1, self.lda2, self.KLD = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
762
+
763
+ elif IS_CYCLE and (not IS_KLD):
764
+
765
+ self.fake, self.Lcycle1, self.Lcycle2 = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
766
+
767
+ elif (not IS_CYCLE) and IS_KLD:
768
+
769
+ self.fake, self.lda1, self.KLD = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
770
+
771
+ else:
772
+
773
+ self.fake = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
774
+
775
+
776
+
777
+ def visualize_attention(self):
778
+
779
+ def _norm_scores(arr):
780
+ return (arr - min(arr))/(max(arr) - min(arr))
781
+
782
+ simgs = self.sdata[0].detach().cpu().numpy()
783
+ fake = self.fake[0,0].detach().cpu().numpy()
784
+ slen = self.ST_LEN[0].detach().cpu().numpy()
785
+ selfatt = self.netG.enc_attn_weights[0].detach().cpu().numpy()
786
+ selfatt = np.stack([_norm_scores(i) for i in selfatt], 1)
787
+ fake_lab = self.words[0].decode()
788
+
789
+ decatt = self.netG.dec_attn_weights[0].detach().cpu().numpy()
790
+ decatt = np.stack([_norm_scores(i) for i in decatt], 0)
791
+
792
+ STdict = {}
793
+ FAKEdict = {}
794
+ count = 0
795
+
796
+ for sim_, sle_ in zip(simgs,slen):
797
+
798
+ for pi in range(sim_.shape[1]//sim_.shape[0]):
799
+
800
+ STdict[count] = {'patch':sim_[:, pi*32:(pi+1)*32], 'ischar': sle_>=pi*32, 'encoder_attention_score': selfatt[count], 'decoder_attention_score': decatt[:,count]}
801
+ count = count + 1
802
+
803
+
804
+ for pi in range(fake.shape[1]//resolution):
805
+
806
+ FAKEdict[pi] = {'patch': fake[:, pi*resolution:(pi+1)*resolution]}
807
+
808
+ show_ims = []
809
+
810
+ for idx in range(len(fake_lab)):
811
+
812
+ viz_pats = []
813
+ viz_lin = []
814
+
815
+ for i in STdict.keys():
816
+
817
+ if STdict[i]['ischar']:
818
+
819
+ viz_pats.append(cv2.addWeighted(STdict[i]['patch'], 0.5, np.ones_like(STdict[i]['patch'])*STdict[i]['decoder_attention_score'][idx], 0.5, 0))
820
+
821
+ if len(viz_pats) >= 20:
822
+
823
+ viz_lin.append(np.concatenate(viz_pats, -1))
824
+
825
+ viz_pats = []
826
+
827
+
828
+
829
+
830
+ src = np.concatenate(viz_lin[:-2], 0)*255
831
+
832
+ viz_gts = []
833
+
834
+ for i in range(len(fake_lab)):
835
+
836
+
837
+
838
+ #if i == idx:
839
+
840
+ #bordersize = 5
841
+
842
+ #FAKEdict[i]['patch'] = cv2.addWeighted(FAKEdict[i]['patch'] , 0.5, np.ones_like(FAKEdict[i]['patch'] ), 0.5, 0)
843
+
844
+
845
+
846
+
847
+
848
+
849
+ img = np.zeros((54,16))
850
+ font = cv2.FONT_HERSHEY_SIMPLEX
851
+ text = fake_lab[i]
852
+
853
+ # get boundary of this text
854
+ textsize = cv2.getTextSize(text, font, 1, 2)[0]
855
+
856
+ # get coords based on boundary
857
+ textX = (img.shape[1] - textsize[0]) // 2
858
+ textY = (img.shape[0] + textsize[1]) // 2
859
+
860
+ # add text centered on image
861
+ cv2.putText(img, text, (textX, textY ), font, 1, (255, 255, 255), 2)
862
+
863
+ img = (255 - img)/255
864
+
865
+ if i == idx:
866
+
867
+ img = (1 - img)
868
+
869
+ viz_gts.append(img)
870
+
871
+
872
+
873
+ tgt = np.concatenate([fake[:,:len(fake_lab)*16],np.concatenate(viz_gts, -1)], 0)
874
+ pad_ = np.ones((tgt.shape[0], (src.shape[1]-tgt.shape[1])//2))
875
+ tgt = np.concatenate([pad_, tgt, pad_], -1)*255
876
+ final = np.concatenate([src, tgt], 0)
877
+
878
+
879
+ show_ims.append(final)
880
+
881
+ return show_ims
882
+
883
+
884
+ def backward_D_OCR(self):
885
+
886
+ pred_real = self.netD(self.real.detach())
887
+
888
+ pred_fake = self.netD(**{'x': self.fake.detach()})
889
+
890
+
891
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True)
892
+
893
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
894
+
895
+ self.pred_real_OCR = self.netOCR(self.real.detach())
896
+ preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * batch_size).detach()
897
+ loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
898
+ self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
899
+
900
+ loss_total = self.loss_D + self.loss_OCR_real
901
+
902
+ # backward
903
+ loss_total.backward()
904
+ for param in self.netOCR.parameters():
905
+ param.grad[param.grad!=param.grad]=0
906
+ param.grad[torch.isnan(param.grad)]=0
907
+ param.grad[torch.isinf(param.grad)]=0
908
+
909
+
910
+
911
+ return loss_total
912
+
913
+ def backward_D_WL(self):
914
+ # Real
915
+ pred_real = self.netD(self.real.detach())
916
+
917
+ pred_fake = self.netD(**{'x': self.fake.detach()})
918
+
919
+
920
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True)
921
+
922
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
923
+
924
+
925
+ self.loss_w_real = self.netW(self.real.detach(), self.input['wcl'].to(DEVICE)).mean()
926
+ # total loss
927
+ loss_total = self.loss_D + self.loss_w_real
928
+
929
+ # backward
930
+ loss_total.backward()
931
+
932
+
933
+ return loss_total
934
+
935
+ def optimize_D_WL(self):
936
+ self.forward()
937
+ self.set_requires_grad([self.netD], True)
938
+ self.set_requires_grad([self.netOCR], False)
939
+ self.set_requires_grad([self.netW], True)
940
+
941
+ self.optimizer_D.zero_grad()
942
+ self.optimizer_wl.zero_grad()
943
+
944
+ self.backward_D_WL()
945
+
946
+
947
+
948
+
949
+ def backward_D_OCR_WL(self):
950
+ # Real
951
+ if self.real_z_mean is None:
952
+ pred_real = self.netD(self.real.detach())
953
+ else:
954
+ pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()})
955
+ # Fake
956
+ try:
957
+ pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()})
958
+ except:
959
+ print('a')
960
+ # Combined loss
961
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
962
+
963
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
964
+ # OCR loss on real data
965
+ self.pred_real_OCR = self.netOCR(self.real.detach())
966
+ preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach()
967
+ loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
968
+ self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
969
+ # total loss
970
+ self.loss_w_real = self.netW(self.real.detach(), self.wcl)
971
+ loss_total = self.loss_D + self.loss_OCR_real + self.loss_w_real
972
+
973
+ # backward
974
+ loss_total.backward()
975
+ for param in self.netOCR.parameters():
976
+ param.grad[param.grad!=param.grad]=0
977
+ param.grad[torch.isnan(param.grad)]=0
978
+ param.grad[torch.isinf(param.grad)]=0
979
+
980
+
981
+
982
+ return loss_total
983
+
984
+ def optimize_D_WL_step(self):
985
+ self.optimizer_D.step()
986
+ self.optimizer_wl.step()
987
+ self.optimizer_D.zero_grad()
988
+ self.optimizer_wl.zero_grad()
989
+
990
+ def backward_OCR(self):
991
+ # OCR loss on real data
992
+ self.pred_real_OCR = self.netOCR(self.real.detach())
993
+ preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach()
994
+ loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
995
+ self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
996
+
997
+ # backward
998
+ self.loss_OCR_real.backward()
999
+ for param in self.netOCR.parameters():
1000
+ param.grad[param.grad!=param.grad]=0
1001
+ param.grad[torch.isnan(param.grad)]=0
1002
+ param.grad[torch.isinf(param.grad)]=0
1003
+
1004
+ return self.loss_OCR_real
1005
+
1006
+
1007
+ def backward_D(self):
1008
+ # Real
1009
+ if self.real_z_mean is None:
1010
+ pred_real = self.netD(self.real.detach())
1011
+ else:
1012
+ pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()})
1013
+ pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()})
1014
+ # Combined loss
1015
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
1016
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
1017
+ # backward
1018
+ self.loss_D.backward()
1019
+
1020
+
1021
+ return self.loss_D
1022
+
1023
+
1024
+ def backward_G_only(self):
1025
+
1026
+ self.gb_alpha = 0.7
1027
+ #self.Lcycle1 = self.Lcycle1.mean()
1028
+ #self.Lcycle2 = self.Lcycle2.mean()
1029
+ self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
1030
+
1031
+
1032
+ pred_fake_OCR = self.netOCR(self.fake)
1033
+ preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * batch_size).detach()
1034
+ loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach())
1035
+ self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
1036
+
1037
+ self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD
1038
+
1039
+ self.loss_T = self.loss_G + self.loss_OCR_fake
1040
+
1041
+
1042
+
1043
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
1044
+
1045
+
1046
+ self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
1047
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
1048
+ self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
1049
+
1050
+
1051
+ self.loss_T.backward(retain_graph=True)
1052
+
1053
+
1054
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
1055
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
1056
+
1057
+
1058
+ a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))
1059
+
1060
+
1061
+ if a is None:
1062
+ print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
1063
+ if a>1000 or a<0.0001:
1064
+ print(a)
1065
+
1066
+
1067
+ self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
1068
+
1069
+ self.loss_T = self.loss_G + self.loss_OCR_fake
1070
+
1071
+
1072
+ self.loss_T.backward(retain_graph=True)
1073
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
1074
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
1075
+ self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
1076
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
1077
+
1078
+ with torch.no_grad():
1079
+ self.loss_T.backward()
1080
+
1081
+ if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G):
1082
+ print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
1083
+ sys.exit()
1084
+
1085
+ def backward_G_WL(self):
1086
+
1087
+ self.gb_alpha = 0.7
1088
+ #self.Lcycle1 = self.Lcycle1.mean()
1089
+ #self.Lcycle2 = self.Lcycle2.mean()
1090
+
1091
+ self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
1092
+
1093
+ self.loss_w_fake = self.netW(self.fake, self.input['wcl'].to(DEVICE)).mean()
1094
+
1095
+ self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD
1096
+
1097
+ self.loss_T = self.loss_G + self.loss_w_fake
1098
+
1099
+
1100
+
1101
+
1102
+ #grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, retain_graph=True)[0]
1103
+
1104
+
1105
+ #self.loss_grad_fake_WL = 10**6*torch.mean(grad_fake_WL**2)
1106
+ #grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
1107
+ #self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
1108
+
1109
+
1110
+
1111
+ self.loss_T.backward(retain_graph=True)
1112
+
1113
+
1114
+ grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
1115
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
1116
+
1117
+
1118
+ a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_WL))
1119
+
1120
+
1121
+
1122
+ if a is None:
1123
+ print(self.loss_w_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_WL))
1124
+ if a>1000 or a<0.0001:
1125
+ print(a)
1126
+
1127
+ self.loss_w_fake = a.detach() * self.loss_w_fake
1128
+
1129
+ self.loss_T = self.loss_G + self.loss_w_fake
1130
+
1131
+ self.loss_T.backward(retain_graph=True)
1132
+ grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=False, retain_graph=True)[0]
1133
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
1134
+ self.loss_grad_fake_WL = 10 ** 6 * torch.mean(grad_fake_WL ** 2)
1135
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
1136
+
1137
+ with torch.no_grad():
1138
+ self.loss_T.backward()
1139
+
1140
+ def backward_G(self):
1141
+ self.opt.gb_alpha = 0.7
1142
+ self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(), self.opt.mask_loss)
1143
+ # OCR loss on real data
1144
+
1145
+ pred_fake_OCR = self.netOCR(self.fake)
1146
+ preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.opt.batch_size).detach()
1147
+ loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach())
1148
+ self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
1149
+
1150
+
1151
+ self.loss_w_fake = self.netW(self.fake, self.wcl)
1152
+ #self.loss_OCR_fake = self.loss_OCR_fake + self.loss_w_fake
1153
+ # total loss
1154
+
1155
+ # l1 = self.params[0]*self.loss_G
1156
+ # l2 = self.params[0]*self.loss_OCR_fake
1157
+ #l3 = self.params[0]*self.loss_w_fake
1158
+ self.loss_G_ = 10*self.loss_G + self.loss_w_fake
1159
+ self.loss_T = self.loss_G_ + self.loss_OCR_fake
1160
+
1161
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
1162
+
1163
+
1164
+ self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
1165
+ grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, retain_graph=True)[0]
1166
+ self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
1167
+
1168
+ if not False:
1169
+
1170
+ self.loss_T.backward(retain_graph=True)
1171
+
1172
+
1173
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
1174
+ grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=True, retain_graph=True)[0]
1175
+ #grad_fake_wl = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
1176
+
1177
+
1178
+ a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))
1179
+
1180
+
1181
+ #a0 = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_wl))
1182
+
1183
+ if a is None:
1184
+ print(self.loss_OCR_fake, self.loss_G_, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
1185
+ if a>1000 or a<0.0001:
1186
+ print(a)
1187
+ b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) -
1188
+ torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))*
1189
+ torch.mean(grad_fake_OCR))
1190
+ # self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake)
1191
+ self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
1192
+ #self.loss_w_fake = a0.detach() * self.loss_w_fake
1193
+
1194
+ self.loss_T = (1-1*self.opt.onlyOCR)*self.loss_G_ + self.loss_OCR_fake# + self.loss_w_fake
1195
+ self.loss_T.backward(retain_graph=True)
1196
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
1197
+ grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=False, retain_graph=True)[0]
1198
+ self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
1199
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
1200
+ with torch.no_grad():
1201
+ self.loss_T.backward()
1202
+ else:
1203
+ self.loss_T.backward()
1204
+
1205
+ if self.opt.clip_grad > 0:
1206
+ clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad)
1207
+ if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G_):
1208
+ print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
1209
+ sys.exit()
1210
+
1211
+
1212
+
1213
+ def optimize_D_OCR(self):
1214
+ self.forward()
1215
+ self.set_requires_grad([self.netD], True)
1216
+ self.set_requires_grad([self.netOCR], True)
1217
+ self.optimizer_D.zero_grad()
1218
+ #if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
1219
+ self.optimizer_OCR.zero_grad()
1220
+ self.backward_D_OCR()
1221
+
1222
+ def optimize_OCR(self):
1223
+ self.forward()
1224
+ self.set_requires_grad([self.netD], False)
1225
+ self.set_requires_grad([self.netOCR], True)
1226
+ if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
1227
+ self.optimizer_OCR.zero_grad()
1228
+ self.backward_OCR()
1229
+
1230
+ def optimize_D(self):
1231
+ self.forward()
1232
+ self.set_requires_grad([self.netD], True)
1233
+ self.backward_D()
1234
+
1235
+ def optimize_D_OCR_step(self):
1236
+ self.optimizer_D.step()
1237
+
1238
+ self.optimizer_OCR.step()
1239
+ self.optimizer_D.zero_grad()
1240
+ self.optimizer_OCR.zero_grad()
1241
+
1242
+
1243
+ def optimize_D_OCR_WL(self):
1244
+ self.forward()
1245
+ self.set_requires_grad([self.netD], True)
1246
+ self.set_requires_grad([self.netOCR], True)
1247
+ self.set_requires_grad([self.netW], True)
1248
+ self.optimizer_D.zero_grad()
1249
+ self.optimizer_wl.zero_grad()
1250
+ if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
1251
+ self.optimizer_OCR.zero_grad()
1252
+ self.backward_D_OCR_WL()
1253
+
1254
+ def optimize_D_OCR_WL_step(self):
1255
+ self.optimizer_D.step()
1256
+ if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
1257
+ self.optimizer_OCR.step()
1258
+ self.optimizer_wl.step()
1259
+ self.optimizer_D.zero_grad()
1260
+ self.optimizer_OCR.zero_grad()
1261
+ self.optimizer_wl.zero_grad()
1262
+
1263
+ def optimize_D_step(self):
1264
+ self.optimizer_D.step()
1265
+ if any(torch.isnan(self.netD.infer_img.blocks[0][0].conv1.bias)):
1266
+ print('D is nan')
1267
+ sys.exit()
1268
+ self.optimizer_D.zero_grad()
1269
+
1270
+ def optimize_G(self):
1271
+ self.forward()
1272
+ self.set_requires_grad([self.netD], False)
1273
+ self.set_requires_grad([self.netOCR], False)
1274
+ self.set_requires_grad([self.netW], False)
1275
+ self.backward_G()
1276
+
1277
+ def optimize_G_WL(self):
1278
+ self.forward()
1279
+ self.set_requires_grad([self.netD], False)
1280
+ self.set_requires_grad([self.netOCR], False)
1281
+ self.set_requires_grad([self.netW], False)
1282
+ self.backward_G_WL()
1283
+
1284
+
1285
+ def optimize_G_only(self):
1286
+ self.forward()
1287
+ self.set_requires_grad([self.netD], False)
1288
+ self.set_requires_grad([self.netOCR], False)
1289
+ self.set_requires_grad([self.netW], False)
1290
+ self.backward_G_only()
1291
+
1292
+
1293
+ def optimize_G_step(self):
1294
+
1295
+ self.optimizer_G.step()
1296
+ self.optimizer_G.zero_grad()
1297
+
1298
+ def optimize_ocr(self):
1299
+ self.set_requires_grad([self.netOCR], True)
1300
+ # OCR loss on real data
1301
+ pred_real_OCR = self.netOCR(self.real)
1302
+ preds_size =torch.IntTensor([pred_real_OCR.size(0)] * self.opt.batch_size).detach()
1303
+ self.loss_OCR_real = self.OCR_criterion(pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
1304
+ self.loss_OCR_real.backward()
1305
+ self.optimizer_OCR.step()
1306
+
1307
+ def optimize_z(self):
1308
+ self.set_requires_grad([self.z], True)
1309
+
1310
+
1311
+ def optimize_parameters(self):
1312
+ self.forward()
1313
+ self.set_requires_grad([self.netD], False)
1314
+ self.optimizer_G.zero_grad()
1315
+ self.backward_G()
1316
+ self.optimizer_G.step()
1317
+
1318
+ self.set_requires_grad([self.netD], True)
1319
+ self.optimizer_D.zero_grad()
1320
+ self.backward_D()
1321
+ self.optimizer_D.step()
1322
+
1323
+ def test(self):
1324
+ self.visual_names = ['fake']
1325
+ self.netG.eval()
1326
+ with torch.no_grad():
1327
+ self.forward()
1328
+
1329
+ def train_GD(self):
1330
+ self.netG.train()
1331
+ self.netD.train()
1332
+ self.optimizer_G.zero_grad()
1333
+ self.optimizer_D.zero_grad()
1334
+ # How many chunks to split x and y into?
1335
+ x = torch.split(self.real, self.opt.batch_size)
1336
+ y = torch.split(self.label, self.opt.batch_size)
1337
+ counter = 0
1338
+
1339
+ # Optionally toggle D and G's "require_grad"
1340
+ if self.opt.toggle_grads:
1341
+ toggle_grad(self.netD, True)
1342
+ toggle_grad(self.netG, False)
1343
+
1344
+ for step_index in range(self.opt.num_critic_train):
1345
+ self.optimizer_D.zero_grad()
1346
+ with torch.set_grad_enabled(False):
1347
+ self.forward()
1348
+ D_input = torch.cat([self.fake, x[counter]], 0) if x is not None else self.fake
1349
+ D_class = torch.cat([self.label_fake, y[counter]], 0) if y[counter] is not None else y[counter]
1350
+ # Get Discriminator output
1351
+ D_out = self.netD(D_input, D_class)
1352
+ if x is not None:
1353
+ pred_fake, pred_real = torch.split(D_out, [self.fake.shape[0], x[counter].shape[0]]) # D_fake, D_real
1354
+ else:
1355
+ pred_fake = D_out
1356
+ # Combined loss
1357
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
1358
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
1359
+ self.loss_D.backward()
1360
+ counter += 1
1361
+ self.optimizer_D.step()
1362
+
1363
+ # Optionally toggle D and G's "require_grad"
1364
+ if self.opt.toggle_grads:
1365
+ toggle_grad(self.netD, False)
1366
+ toggle_grad(self.netG, True)
1367
+ # Zero G's gradients by default before training G, for safety
1368
+ self.optimizer_G.zero_grad()
1369
+ self.forward()
1370
+ self.loss_G = loss_hinge_gen(self.netD(self.fake, self.label_fake), self.len_text_fake.detach(), self.opt.mask_loss)
1371
+ self.loss_G.backward()
1372
+ self.optimizer_G.step()
1373
+
1374
+
1375
+
1376
+
1377
+
1378
+
1379
+
1380
+
1381
+
1382
+
1383
+
1384
+
1385
+
1386
+
1387
+
1388
+
1389
+
util/models/model_.py ADDED
@@ -0,0 +1,1264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ from .OCR_network import *
4
+ from torch.nn import CTCLoss, MSELoss, L1Loss
5
+ from torch.nn.utils import clip_grad_norm_
6
+ import random
7
+ import unicodedata
8
+ import sys
9
+ import torchvision.models as models
10
+ from models.transformer import *
11
+ from .BigGAN_networks import *
12
+ from params import *
13
+ from .OCR_network import *
14
+ from models.blocks import LinearBlock, Conv2dBlock, ResBlocks, ActFirstResBlock
15
+ from util.util import toggle_grad, loss_hinge_dis, loss_hinge_gen, ortho, default_ortho, toggle_grad, prepare_z_y, \
16
+ make_one_hot, to_device, multiple_replace, random_word
17
+ from models.inception import InceptionV3, calculate_frechet_distance
18
+
19
+ class FCNDecoder(nn.Module):
20
+ def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'):
21
+ super(FCNDecoder, self).__init__()
22
+
23
+ self.model = []
24
+ self.model += [ResBlocks(n_res, dim, res_norm,
25
+ activ, pad_type=pad_type)]
26
+ for i in range(ups):
27
+ self.model += [nn.Upsample(scale_factor=2),
28
+ Conv2dBlock(dim, dim // 2, 5, 1, 2,
29
+ norm='in',
30
+ activation=activ,
31
+ pad_type=pad_type)]
32
+ dim //= 2
33
+ self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3,
34
+ norm='none',
35
+ activation='tanh',
36
+ pad_type=pad_type)]
37
+ self.model = nn.Sequential(*self.model)
38
+
39
+ def forward(self, x):
40
+ y = self.model(x)
41
+
42
+ return y
43
+
44
+
45
+
46
+ class Generator(nn.Module):
47
+
48
+ def __init__(self):
49
+ super(Generator, self).__init__()
50
+
51
+ INP_CHANNEL = NUM_EXAMPLES
52
+ if IS_SEQ: INP_CHANNEL = 1
53
+
54
+
55
+ encoder_layer = TransformerEncoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD,
56
+ TN_DROPOUT, "relu", True)
57
+ encoder_norm = nn.LayerNorm(TN_HIDDEN_DIM) if True else None
58
+ self.encoder = TransformerEncoder(encoder_layer, TN_ENC_LAYERS, encoder_norm)
59
+
60
+ decoder_layer = TransformerDecoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD,
61
+ TN_DROPOUT, "relu", True)
62
+ decoder_norm = nn.LayerNorm(TN_HIDDEN_DIM)
63
+ self.decoder = TransformerDecoder(decoder_layer, TN_DEC_LAYERS, decoder_norm,
64
+ return_intermediate=True)
65
+
66
+ self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(INP_CHANNEL, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2]))
67
+
68
+ self.query_embed = nn.Embedding(VOCAB_SIZE, TN_HIDDEN_DIM)
69
+
70
+
71
+ self.linear_q = nn.Linear(TN_DIM_FEEDFORWARD*2, TN_DIM_FEEDFORWARD*8)
72
+
73
+ self.DEC = FCNDecoder(res_norm = 'in')
74
+
75
+
76
+ self._muE = nn.Linear(512,512)
77
+ self._logvarE = nn.Linear(512,512)
78
+
79
+ self._muD = nn.Linear(512,512)
80
+ self._logvarD = nn.Linear(512,512)
81
+
82
+
83
+ self.l1loss = nn.L1Loss()
84
+
85
+ self.noise = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0]))
86
+
87
+
88
+
89
+ def reparameterize(self, mu, logvar):
90
+
91
+ mu = torch.unbind(mu , 1)
92
+ logvar = torch.unbind(logvar , 1)
93
+
94
+ outs = []
95
+
96
+ for m,l in zip(mu, logvar):
97
+
98
+ sigma = torch.exp(l)
99
+ eps = torch.cuda.FloatTensor(l.size()[0],1).normal_(0,1)
100
+ eps = eps.expand(sigma.size())
101
+
102
+ out = m + sigma*eps
103
+
104
+ outs.append(out)
105
+
106
+
107
+ return torch.stack(outs, 1)
108
+
109
+
110
+ def Eval(self, ST, QRS):
111
+
112
+ if IS_SEQ:
113
+ B, N, R, C = ST.shape
114
+ FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C))
115
+ FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
116
+ else:
117
+ FEAT_ST = self.Feat_Encoder(ST)
118
+
119
+
120
+ FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1)
121
+
122
+ memory = self.encoder(FEAT_ST_ENC)
123
+
124
+ if IS_KLD:
125
+
126
+ Ex = memory.permute(1,0,2)
127
+
128
+ memory_mu = self._muE(Ex)
129
+ memory_logvar = self._logvarE(Ex)
130
+
131
+ memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2)
132
+
133
+
134
+ OUT_IMGS = []
135
+
136
+ for i in range(QRS.shape[1]):
137
+
138
+ QR = QRS[:, i, :]
139
+
140
+ QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
141
+
142
+ tgt = torch.zeros_like(QR_EMB)
143
+
144
+ hs = self.decoder(tgt, memory, query_pos=QR_EMB)
145
+
146
+ if IS_KLD:
147
+
148
+ Dx = hs[0].permute(1,0,2)
149
+
150
+ hs_mu = self._muD(Dx)
151
+ hs_logvar = self._logvarD(Dx)
152
+
153
+ hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
154
+
155
+
156
+ h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
157
+ if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
158
+
159
+ h = self.linear_q(h)
160
+ h = h.contiguous()
161
+
162
+ h = h.view(h.size(0), h.shape[1]*2, 4, -1)
163
+ h = h.permute(0, 3, 2, 1)
164
+
165
+ h = self.DEC(h)
166
+
167
+
168
+ OUT_IMGS.append(h.detach())
169
+
170
+
171
+
172
+ return OUT_IMGS
173
+
174
+
175
+
176
+
177
+
178
+
179
+ def forward(self, ST, QR, QRs = None, mode = 'train'):
180
+
181
+ if IS_SEQ:
182
+ B, N, R, C = ST.shape
183
+ FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C))
184
+ FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
185
+ else:
186
+ FEAT_ST = self.Feat_Encoder(ST)
187
+
188
+
189
+ FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1)
190
+
191
+ memory = self.encoder(FEAT_ST_ENC)
192
+
193
+ if IS_KLD:
194
+
195
+ Ex = memory.permute(1,0,2)
196
+
197
+ memory_mu = self._muE(Ex)
198
+ memory_logvar = self._logvarE(Ex)
199
+
200
+ memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2)
201
+
202
+
203
+ QR_EMB = self.query_embed.weight.repeat(batch_size,1,1).permute(1,0,2)
204
+
205
+ tgt = torch.zeros_like(QR_EMB)
206
+
207
+ hs = self.decoder(tgt, memory, query_pos=QR_EMB)
208
+
209
+
210
+
211
+ if IS_KLD:
212
+
213
+ Dx = hs[0].permute(1,0,2)
214
+
215
+ hs_mu = self._muD(Dx)
216
+ hs_logvar = self._logvarD(Dx)
217
+
218
+ hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
219
+
220
+ OUT_Feats1_mu = [hs_mu]
221
+ OUT_Feats1_logvar = [hs_logvar]
222
+
223
+
224
+ OUT_Feats1 = [hs]
225
+
226
+
227
+ h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
228
+
229
+ if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
230
+
231
+ h = self.linear_q(h)
232
+
233
+ h = h.contiguous()
234
+
235
+ h = [torch.stack([h[i][QR[i]] for i in range(batch_size)], 0) for QR in QRs]
236
+
237
+ h_list = []
238
+
239
+ for h_ in h:
240
+
241
+ h_ = h_.view(h_.size(0), h_.shape[1]*2, 4, -1)
242
+ h_ = h_.permute(0, 3, 2, 1)
243
+
244
+ #h_ = self.DEC(h_)
245
+
246
+ h_list.append(h_)
247
+
248
+ if mode == 'test' or (not IS_CYCLE and not IS_KLD):
249
+
250
+ return h
251
+
252
+
253
+ OUT_IMGS = [h]
254
+
255
+ for QR in QRs:
256
+
257
+ QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
258
+
259
+ tgt = torch.zeros_like(QR_EMB)
260
+
261
+ hs = self.decoder(tgt, memory, query_pos=QR_EMB)
262
+
263
+
264
+ if IS_KLD:
265
+
266
+ Dx = hs[0].permute(1,0,2)
267
+
268
+ hs_mu = self._muD(Dx)
269
+ hs_logvar = self._logvarD(Dx)
270
+
271
+ hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
272
+
273
+ OUT_Feats1_mu.append(hs_mu)
274
+ OUT_Feats1_logvar.append(hs_logvar)
275
+
276
+
277
+ OUT_Feats1.append(hs)
278
+
279
+
280
+ h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
281
+ if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
282
+
283
+ h = self.linear_q(h)
284
+ h = h.contiguous()
285
+
286
+ h = h.view(h.size(0), h.shape[1]*2, 4, -1)
287
+ h = h.permute(0, 3, 2, 1)
288
+
289
+ h = self.DEC(h)
290
+
291
+ OUT_IMGS.append(h)
292
+
293
+
294
+ if (not IS_CYCLE) and IS_KLD:
295
+
296
+ OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0]
297
+
298
+ OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1);
299
+
300
+
301
+ KLD = (0.5 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \
302
+ + (0.5 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp()))
303
+
304
+
305
+
306
+ def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar):
307
+ return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \
308
+ torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum()
309
+
310
+
311
+ lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])]
312
+
313
+
314
+ lda1 = torch.stack(lda1).mean()
315
+
316
+
317
+
318
+ return OUT_IMGS[0], lda1, KLD
319
+
320
+
321
+ with torch.no_grad():
322
+
323
+ if IS_SEQ:
324
+
325
+ FEAT_ST_T = torch.cat([self.Feat_Encoder(IM) for IM in OUT_IMGS], -1)
326
+
327
+ else:
328
+
329
+ max_width_ = max([i_.shape[-1] for i_ in OUT_IMGS])
330
+
331
+ FEAT_ST_T = self.Feat_Encoder(torch.cat([torch.cat([i_, torch.ones((i_.shape[0], i_.shape[1],i_.shape[2], max_width_-i_.shape[3])).to(DEVICE)], -1) for i_ in OUT_IMGS], 1))
332
+
333
+ FEAT_ST_ENC_T = FEAT_ST_T.flatten(2).permute(2,0,1)
334
+
335
+ memory_T = self.encoder(FEAT_ST_ENC_T)
336
+
337
+ if IS_KLD:
338
+
339
+ Ex = memory_T.permute(1,0,2)
340
+
341
+ memory_T_mu = self._muE(Ex)
342
+ memory_T_logvar = self._logvarE(Ex)
343
+
344
+ memory_T = self.reparameterize(memory_T_mu, memory_T_logvar).permute(1,0,2)
345
+
346
+
347
+ QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
348
+
349
+ tgt = torch.zeros_like(QR_EMB)
350
+
351
+ hs = self.decoder(tgt, memory_T, query_pos=QR_EMB)
352
+
353
+ if IS_KLD:
354
+
355
+ Dx = hs[0].permute(1,0,2)
356
+
357
+ hs_mu = self._muD(Dx)
358
+ hs_logvar = self._logvarD(Dx)
359
+
360
+ hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
361
+
362
+ OUT_Feats2_mu = [hs_mu]
363
+ OUT_Feats2_logvar = [hs_logvar]
364
+
365
+
366
+ OUT_Feats2 = [hs]
367
+
368
+
369
+
370
+ for QR in QRs:
371
+
372
+ QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
373
+
374
+ tgt = torch.zeros_like(QR_EMB)
375
+
376
+ hs = self.decoder(tgt, memory_T, query_pos=QR_EMB)
377
+
378
+ if IS_KLD:
379
+
380
+ Dx = hs[0].permute(1,0,2)
381
+
382
+ hs_mu = self._muD(Dx)
383
+ hs_logvar = self._logvarD(Dx)
384
+
385
+ hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
386
+
387
+ OUT_Feats2_mu.append(hs_mu)
388
+ OUT_Feats2_logvar.append(hs_logvar)
389
+
390
+
391
+ OUT_Feats2.append(hs)
392
+
393
+
394
+
395
+
396
+ Lcycle1 = np.sum([self.l1loss(memory[m_i], memory_T[m_j]) for m_i in range(memory.shape[0]) for m_j in range(memory_T.shape[0])])/(memory.shape[0]*memory_T.shape[0])
397
+ OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0]; OUT_Feats2 = torch.cat(OUT_Feats2, 1)[0]
398
+
399
+ Lcycle2 = np.sum([self.l1loss(OUT_Feats1[f_i], OUT_Feats2[f_j]) for f_i in range(OUT_Feats1.shape[0]) for f_j in range(OUT_Feats2.shape[0])])/(OUT_Feats1.shape[0]*OUT_Feats2.shape[0])
400
+
401
+ if IS_KLD:
402
+
403
+ OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1);
404
+ OUT_Feats2_mu = torch.cat(OUT_Feats2_mu, 1); OUT_Feats2_logvar = torch.cat(OUT_Feats2_logvar, 1);
405
+
406
+ KLD = (0.25 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \
407
+ + (0.25 * torch.mean(1 + memory_T_logvar - memory_T_mu.pow(2) - memory_T_logvar.exp()))\
408
+ + (0.25 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp()))\
409
+ + (0.25 * torch.mean(1 + OUT_Feats2_logvar - OUT_Feats2_mu.pow(2) - OUT_Feats2_logvar.exp()))
410
+
411
+
412
+ def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar):
413
+ return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \
414
+ torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum()
415
+
416
+
417
+ lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])]
418
+ lda2 = [_get_lda(memory_T_mu[:,idi,:], OUT_Feats2_mu[:,idj,:], memory_T_logvar[:,idi,:], OUT_Feats2_logvar[:,idj,:]) for idi in range(memory_T.shape[0]) for idj in range(OUT_Feats2.shape[0])]
419
+
420
+ lda1 = torch.stack(lda1).mean()
421
+ lda2 = torch.stack(lda2).mean()
422
+
423
+
424
+ return OUT_IMGS[0], Lcycle1, Lcycle2, lda1, lda2, KLD
425
+
426
+
427
+ return OUT_IMGS[0], Lcycle1, Lcycle2
428
+
429
+
430
+
431
+ class TRGAN(nn.Module):
432
+
433
+ def __init__(self):
434
+ super(TRGAN, self).__init__()
435
+
436
+
437
+ self.epsilon = 1e-7
438
+ self.netG = Generator().to(DEVICE)
439
+ self.netD = nn.DataParallel(Discriminator()).to(DEVICE)
440
+ self.netW = nn.DataParallel(WDiscriminator()).to(DEVICE)
441
+ self.netconverter = strLabelConverter(ALPHABET)
442
+ self.netOCR = CRNN().to(DEVICE)
443
+ self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none')
444
+
445
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
446
+ self.inception = InceptionV3([block_idx]).to(DEVICE)
447
+
448
+
449
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
450
+ lr=G_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
451
+ self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(),
452
+ lr=OCR_LR, betas=(0.0, 0.999), weight_decay=0,
453
+ eps=1e-8)
454
+
455
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
456
+ lr=D_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
457
+
458
+
459
+ self.optimizer_wl = torch.optim.Adam(self.netW.parameters(),
460
+ lr=W_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
461
+
462
+
463
+ self.optimizers = [self.optimizer_G, self.optimizer_OCR, self.optimizer_D, self.optimizer_wl]
464
+
465
+
466
+ self.optimizer_G.zero_grad()
467
+ self.optimizer_OCR.zero_grad()
468
+ self.optimizer_D.zero_grad()
469
+ self.optimizer_wl.zero_grad()
470
+
471
+ self.loss_G = 0
472
+ self.loss_D = 0
473
+ self.loss_Dfake = 0
474
+ self.loss_Dreal = 0
475
+ self.loss_OCR_fake = 0
476
+ self.loss_OCR_real = 0
477
+ self.loss_w_fake = 0
478
+ self.loss_w_real = 0
479
+ self.Lcycle1 = 0
480
+ self.Lcycle2 = 0
481
+ self.lda1 = 0
482
+ self.lda2 = 0
483
+ self.KLD = 0
484
+
485
+
486
+ with open('../Lexicon/english_words.txt', 'rb') as f:
487
+ self.lex = f.read().splitlines()
488
+ lex=[]
489
+ for word in self.lex:
490
+ try:
491
+ word=word.decode("utf-8")
492
+ except:
493
+ continue
494
+ if len(word)<20:
495
+ lex.append(word)
496
+ self.lex = lex
497
+
498
+
499
+ f = open('mytext.txt', 'r')
500
+
501
+ self.text = [j.encode() for j in sum([i.split(' ') for i in f.readlines()], [])][:NUM_EXAMPLES]
502
+ self.eval_text_encode, self.eval_len_text = self.netconverter.encode(self.text)
503
+ self.eval_text_encode = self.eval_text_encode.to(DEVICE).repeat(batch_size, 1, 1)
504
+
505
+
506
+ def _generate_page(self):
507
+
508
+ self.fakes = self.netG.Eval(self.sdata, self.eval_text_encode)
509
+
510
+ word_t = []
511
+ word_l = []
512
+
513
+ gap = np.ones([32,16])
514
+
515
+ line_wids = []
516
+
517
+
518
+ for idx, fake_ in enumerate(self.fakes):
519
+
520
+ word_t.append((fake_[0,0,:,:self.eval_len_text[idx]*resolution].cpu().numpy()+1)/2)
521
+
522
+ word_t.append(gap)
523
+
524
+ if len(word_t) == 16 or idx == len(self.fakes) - 1:
525
+
526
+ line_ = np.concatenate(word_t, -1)
527
+
528
+ word_l.append(line_)
529
+ line_wids.append(line_.shape[1])
530
+
531
+ word_t = []
532
+
533
+
534
+ gap_h = np.ones([16,max(line_wids)])
535
+
536
+ page_= []
537
+
538
+ for l in word_l:
539
+
540
+ pad_ = np.ones([32,max(line_wids) - l.shape[1]])
541
+
542
+ page_.append(np.concatenate([l, pad_], 1))
543
+ page_.append(gap_h)
544
+
545
+
546
+
547
+ page1 = np.concatenate(page_, 0)
548
+
549
+
550
+ word_t = []
551
+ word_l = []
552
+
553
+ gap = np.ones([32,16])
554
+
555
+ line_wids = []
556
+
557
+ sdata_ = [i.unsqueeze(1) for i in torch.unbind(self.sdata, 1)]
558
+
559
+ for idx, st in enumerate((sdata_)):
560
+
561
+ word_t.append((st[0,0,:,:int(self.input['swids'].cpu().numpy()[0][idx])
562
+ ].cpu().numpy()+1)/2)
563
+
564
+ word_t.append(gap)
565
+
566
+ if len(word_t) == 16 or idx == len(self.fakes) - 1:
567
+
568
+ line_ = np.concatenate(word_t, -1)
569
+
570
+ word_l.append(line_)
571
+ line_wids.append(line_.shape[1])
572
+
573
+ word_t = []
574
+
575
+
576
+ gap_h = np.ones([16,max(line_wids)])
577
+
578
+ page_= []
579
+
580
+ for l in word_l:
581
+
582
+ pad_ = np.ones([32,max(line_wids) - l.shape[1]])
583
+
584
+ page_.append(np.concatenate([l, pad_], 1))
585
+ page_.append(gap_h)
586
+
587
+
588
+
589
+ page2 = np.concatenate(page_, 0)
590
+
591
+ merge_w_size = max(page1.shape[0], page2.shape[0])
592
+
593
+ if page1.shape[0] != merge_w_size:
594
+
595
+ page1 = np.concatenate([page1, np.ones([merge_w_size-page1.shape[0], page1.shape[1]])], 0)
596
+
597
+ if page2.shape[0] != merge_w_size:
598
+
599
+ page2 = np.concatenate([page2, np.ones([merge_w_size-page2.shape[0], page2.shape[1]])], 0)
600
+
601
+
602
+ page = np.concatenate([page2, page1], 1)
603
+
604
+
605
+ return page
606
+
607
+
608
+
609
+
610
+
611
+
612
+
613
+
614
+
615
+
616
+
617
+
618
+
619
+
620
+
621
+
622
+
623
+
624
+
625
+ #FEAT1 = self.inception(torch.cat(self.fakes, 0).repeat(1,3,1,1))[0].detach().view(batch_size, len(self.fakes), -1).cpu().numpy()
626
+ #FEAT2 = self.inception(self.sdata.view(batch_size*NUM_EXAMPLES, 1, 32, -1).repeat(1,3,1,1))[0].detach().view(batch_size, NUM_EXAMPLES, -1 ).cpu().numpy()
627
+ #muvars1 = [{'mu':np.mean(FEAT1[i], axis=0), 'sigma' : np.cov(FEAT1[i], rowvar=False)} for i in range(FEAT1.shape[0])]
628
+ #muvars2 = [{'mu':np.mean(FEAT2[i], axis=0), 'sigma' : np.cov(FEAT2[i], rowvar=False)} for i in range(FEAT2.shape[0])]
629
+
630
+
631
+
632
+
633
+
634
+
635
+ def get_current_losses(self):
636
+
637
+ losses = {}
638
+
639
+ losses['G'] = self.loss_G
640
+ losses['D'] = self.loss_D
641
+ losses['Dfake'] = self.loss_Dfake
642
+ losses['Dreal'] = self.loss_Dreal
643
+ losses['OCR_fake'] = self.loss_OCR_fake
644
+ losses['OCR_real'] = self.loss_OCR_real
645
+ losses['w_fake'] = self.loss_w_fake
646
+ losses['w_real'] = self.loss_w_real
647
+ losses['cycle1'] = self.Lcycle1
648
+ losses['cycle2'] = self.Lcycle2
649
+ losses['lda1'] = self.lda1
650
+ losses['lda2'] = self.lda2
651
+ losses['KLD'] = self.KLD
652
+
653
+ return losses
654
+
655
+ def visualize_images(self):
656
+
657
+ imgs = {}
658
+
659
+
660
+ imgs['fake-1']=self.netG(self.sdata[0:1], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach()
661
+ imgs['fake-2']=self.netG(self.sdata[0:1], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
662
+ imgs['fake-3']=self.netG(self.sdata[0:1], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
663
+
664
+
665
+ imgs['res-1'] = torch.cat([self.sdata[0, 0],self.sdata[0, 1],self.sdata[0, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
666
+
667
+
668
+ imgs['fake-1']=self.netG(self.sdata[1:2], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach()
669
+ imgs['fake-2']=self.netG(self.sdata[1:2], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
670
+ imgs['fake-3']=self.netG(self.sdata[1:2], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
671
+
672
+
673
+ imgs['res-2'] = torch.cat([self.sdata[1, 0],self.sdata[1, 1],self.sdata[1, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
674
+
675
+
676
+ imgs['fake-1']=self.netG(self.sdata[2:3], self.text_encode_fake[0].unsqueeze(0) , mode = 'test' )[0, 0].detach()
677
+ imgs['fake-2']=self.netG(self.sdata[2:3], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
678
+ imgs['fake-3']=self.netG(self.sdata[2:3], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
679
+
680
+
681
+ imgs['res-3'] = torch.cat([self.sdata[2, 0],self.sdata[2, 1],self.sdata[2, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
682
+
683
+
684
+
685
+
686
+ return imgs
687
+
688
+
689
+ def load_networks(self, epoch):
690
+ BaseModel.load_networks(self, epoch)
691
+ if self.opt.single_writer:
692
+ load_filename = '%s_z.pkl' % (epoch)
693
+ load_path = os.path.join(self.save_dir, load_filename)
694
+ self.z = torch.load(load_path)
695
+
696
+ def _set_input(self, input):
697
+ self.input = input
698
+
699
+ def set_requires_grad(self, nets, requires_grad=False):
700
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
701
+ Parameters:
702
+ nets (network list) -- a list of networks
703
+ requires_grad (bool) -- whether the networks require gradients or not
704
+ """
705
+ if not isinstance(nets, list):
706
+ nets = [nets]
707
+ for net in nets:
708
+ if net is not None:
709
+ for param in net.parameters():
710
+ param.requires_grad = requires_grad
711
+
712
+ def forward(self):
713
+
714
+
715
+ self.real = self.input['img'].to(DEVICE)
716
+ self.label = self.input['label']
717
+ self.sdata = self.input['simg'].to(DEVICE)
718
+ self.ST_LEN = self.input['swids']
719
+ self.text_encode, self.len_text = self.netconverter.encode(self.label)
720
+ self.one_hot_real = make_one_hot(self.text_encode, self.len_text, VOCAB_SIZE).to(DEVICE).detach()
721
+ self.text_encode = self.text_encode.to(DEVICE).detach()
722
+ self.len_text = self.len_text.detach()
723
+
724
+ self.words = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)]
725
+ self.text_encode_fake, self.len_text_fake = self.netconverter.encode(self.words)
726
+ self.text_encode_fake = self.text_encode_fake.to(DEVICE)
727
+ self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, VOCAB_SIZE).to(DEVICE)
728
+
729
+
730
+ self.text_encode_fake_js = []
731
+
732
+ for _ in range(NUM_WORDS - 1):
733
+
734
+ self.words_j = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)]
735
+ self.text_encode_fake_j, self.len_text_fake_j = self.netconverter.encode(self.words_j)
736
+ self.text_encode_fake_j = self.text_encode_fake_j.to(DEVICE)
737
+ self.text_encode_fake_js.append(self.text_encode_fake_j)
738
+
739
+
740
+ if IS_CYCLE and IS_KLD:
741
+
742
+ self.fake, self.Lcycle1, self.Lcycle2, self.lda1, self.lda2, self.KLD = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
743
+
744
+ elif IS_CYCLE and (not IS_KLD):
745
+
746
+ self.fake, self.Lcycle1, self.Lcycle2 = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
747
+
748
+ elif (not IS_CYCLE) and IS_KLD:
749
+
750
+ self.fake, self.lda1, self.KLD = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
751
+
752
+ else:
753
+
754
+ self.fake = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
755
+
756
+
757
+
758
+
759
+ def backward_D_OCR(self):
760
+
761
+ pred_real = self.netD(self.real.detach())
762
+
763
+ pred_fake = self.netD(**{'x': self.fake.detach()})
764
+
765
+
766
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True)
767
+
768
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
769
+
770
+ self.pred_real_OCR = self.netOCR(self.real.detach())
771
+ preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * batch_size).detach()
772
+ loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
773
+ self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
774
+
775
+ loss_total = self.loss_D + self.loss_OCR_real
776
+
777
+ # backward
778
+ loss_total.backward()
779
+ for param in self.netOCR.parameters():
780
+ param.grad[param.grad!=param.grad]=0
781
+ param.grad[torch.isnan(param.grad)]=0
782
+ param.grad[torch.isinf(param.grad)]=0
783
+
784
+
785
+
786
+ return loss_total
787
+
788
+ def backward_D_WL(self):
789
+ # Real
790
+ pred_real = self.netD(self.real.detach())
791
+
792
+ pred_fake = self.netD(**{'x': self.fake.detach()})
793
+
794
+
795
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True)
796
+
797
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
798
+
799
+
800
+ self.loss_w_real = self.netW(self.real.detach(), self.input['wcl'].to(DEVICE)).mean()
801
+ # total loss
802
+ loss_total = self.loss_D + self.loss_w_real
803
+
804
+ # backward
805
+ loss_total.backward()
806
+
807
+
808
+ return loss_total
809
+
810
+ def optimize_D_WL(self):
811
+ self.forward()
812
+ self.set_requires_grad([self.netD], True)
813
+ self.set_requires_grad([self.netOCR], False)
814
+ self.set_requires_grad([self.netW], True)
815
+
816
+ self.optimizer_D.zero_grad()
817
+ self.optimizer_wl.zero_grad()
818
+
819
+ self.backward_D_WL()
820
+
821
+
822
+
823
+
824
+ def backward_D_OCR_WL(self):
825
+ # Real
826
+ if self.real_z_mean is None:
827
+ pred_real = self.netD(self.real.detach())
828
+ else:
829
+ pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()})
830
+ # Fake
831
+ try:
832
+ pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()})
833
+ except:
834
+ print('a')
835
+ # Combined loss
836
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
837
+
838
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
839
+ # OCR loss on real data
840
+ self.pred_real_OCR = self.netOCR(self.real.detach())
841
+ preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach()
842
+ loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
843
+ self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
844
+ # total loss
845
+ self.loss_w_real = self.netW(self.real.detach(), self.wcl)
846
+ loss_total = self.loss_D + self.loss_OCR_real + self.loss_w_real
847
+
848
+ # backward
849
+ loss_total.backward()
850
+ for param in self.netOCR.parameters():
851
+ param.grad[param.grad!=param.grad]=0
852
+ param.grad[torch.isnan(param.grad)]=0
853
+ param.grad[torch.isinf(param.grad)]=0
854
+
855
+
856
+
857
+ return loss_total
858
+
859
+ def optimize_D_WL_step(self):
860
+ self.optimizer_D.step()
861
+ self.optimizer_wl.step()
862
+ self.optimizer_D.zero_grad()
863
+ self.optimizer_wl.zero_grad()
864
+
865
+ def backward_OCR(self):
866
+ # OCR loss on real data
867
+ self.pred_real_OCR = self.netOCR(self.real.detach())
868
+ preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach()
869
+ loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
870
+ self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
871
+
872
+ # backward
873
+ self.loss_OCR_real.backward()
874
+ for param in self.netOCR.parameters():
875
+ param.grad[param.grad!=param.grad]=0
876
+ param.grad[torch.isnan(param.grad)]=0
877
+ param.grad[torch.isinf(param.grad)]=0
878
+
879
+ return self.loss_OCR_real
880
+
881
+
882
+ def backward_D(self):
883
+ # Real
884
+ if self.real_z_mean is None:
885
+ pred_real = self.netD(self.real.detach())
886
+ else:
887
+ pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()})
888
+ pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()})
889
+ # Combined loss
890
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
891
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
892
+ # backward
893
+ self.loss_D.backward()
894
+
895
+
896
+ return self.loss_D
897
+
898
+
899
+ def backward_G_only(self):
900
+
901
+ self.gb_alpha = 0.7
902
+ #self.Lcycle1 = self.Lcycle1.mean()
903
+ #self.Lcycle2 = self.Lcycle2.mean()
904
+ self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
905
+
906
+
907
+ pred_fake_OCR = self.netOCR(self.fake)
908
+ preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * batch_size).detach()
909
+ loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach())
910
+ self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
911
+
912
+ self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD
913
+
914
+ self.loss_T = self.loss_G + self.loss_OCR_fake
915
+
916
+
917
+
918
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
919
+
920
+
921
+ self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
922
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
923
+ self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
924
+
925
+
926
+ self.loss_T.backward(retain_graph=True)
927
+
928
+
929
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
930
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
931
+
932
+
933
+ a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))
934
+
935
+
936
+ if a is None:
937
+ print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
938
+ if a>1000 or a<0.0001:
939
+ print(a)
940
+
941
+
942
+ self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
943
+
944
+ self.loss_T = self.loss_G + self.loss_OCR_fake
945
+
946
+
947
+ self.loss_T.backward(retain_graph=True)
948
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
949
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
950
+ self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
951
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
952
+
953
+ with torch.no_grad():
954
+ self.loss_T.backward()
955
+
956
+ if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G):
957
+ print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
958
+ sys.exit()
959
+
960
+ def backward_G_WL(self):
961
+
962
+ self.gb_alpha = 0.7
963
+ #self.Lcycle1 = self.Lcycle1.mean()
964
+ #self.Lcycle2 = self.Lcycle2.mean()
965
+
966
+ self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
967
+
968
+ self.loss_w_fake = self.netW(self.fake, self.input['wcl'].to(DEVICE)).mean()
969
+
970
+ self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD
971
+
972
+ self.loss_T = self.loss_G + self.loss_w_fake
973
+
974
+
975
+
976
+
977
+ #grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, retain_graph=True)[0]
978
+
979
+
980
+ #self.loss_grad_fake_WL = 10**6*torch.mean(grad_fake_WL**2)
981
+ #grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
982
+ #self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
983
+
984
+
985
+
986
+ self.loss_T.backward(retain_graph=True)
987
+
988
+
989
+ grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
990
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
991
+
992
+
993
+ a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_WL))
994
+
995
+
996
+
997
+ if a is None:
998
+ print(self.loss_w_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_WL))
999
+ if a>1000 or a<0.0001:
1000
+ print(a)
1001
+
1002
+ self.loss_w_fake = a.detach() * self.loss_w_fake
1003
+
1004
+ self.loss_T = self.loss_G + self.loss_w_fake
1005
+
1006
+ self.loss_T.backward(retain_graph=True)
1007
+ grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=False, retain_graph=True)[0]
1008
+ grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
1009
+ self.loss_grad_fake_WL = 10 ** 6 * torch.mean(grad_fake_WL ** 2)
1010
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
1011
+
1012
+ with torch.no_grad():
1013
+ self.loss_T.backward()
1014
+
1015
+ def backward_G(self):
1016
+ self.opt.gb_alpha = 0.7
1017
+ self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(), self.opt.mask_loss)
1018
+ # OCR loss on real data
1019
+
1020
+ pred_fake_OCR = self.netOCR(self.fake)
1021
+ preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.opt.batch_size).detach()
1022
+ loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach())
1023
+ self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
1024
+
1025
+
1026
+ self.loss_w_fake = self.netW(self.fake, self.wcl)
1027
+ #self.loss_OCR_fake = self.loss_OCR_fake + self.loss_w_fake
1028
+ # total loss
1029
+
1030
+ # l1 = self.params[0]*self.loss_G
1031
+ # l2 = self.params[0]*self.loss_OCR_fake
1032
+ #l3 = self.params[0]*self.loss_w_fake
1033
+ self.loss_G_ = 10*self.loss_G + self.loss_w_fake
1034
+ self.loss_T = self.loss_G_ + self.loss_OCR_fake
1035
+
1036
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
1037
+
1038
+
1039
+ self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
1040
+ grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, retain_graph=True)[0]
1041
+ self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
1042
+
1043
+ if not False:
1044
+
1045
+ self.loss_T.backward(retain_graph=True)
1046
+
1047
+
1048
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
1049
+ grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=True, retain_graph=True)[0]
1050
+ #grad_fake_wl = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
1051
+
1052
+
1053
+ a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))
1054
+
1055
+
1056
+ #a0 = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_wl))
1057
+
1058
+ if a is None:
1059
+ print(self.loss_OCR_fake, self.loss_G_, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
1060
+ if a>1000 or a<0.0001:
1061
+ print(a)
1062
+ b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) -
1063
+ torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))*
1064
+ torch.mean(grad_fake_OCR))
1065
+ # self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake)
1066
+ self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
1067
+ #self.loss_w_fake = a0.detach() * self.loss_w_fake
1068
+
1069
+ self.loss_T = (1-1*self.opt.onlyOCR)*self.loss_G_ + self.loss_OCR_fake# + self.loss_w_fake
1070
+ self.loss_T.backward(retain_graph=True)
1071
+ grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
1072
+ grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=False, retain_graph=True)[0]
1073
+ self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
1074
+ self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
1075
+ with torch.no_grad():
1076
+ self.loss_T.backward()
1077
+ else:
1078
+ self.loss_T.backward()
1079
+
1080
+ if self.opt.clip_grad > 0:
1081
+ clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad)
1082
+ if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G_):
1083
+ print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
1084
+ sys.exit()
1085
+
1086
+
1087
+
1088
+ def optimize_D_OCR(self):
1089
+ self.forward()
1090
+ self.set_requires_grad([self.netD], True)
1091
+ self.set_requires_grad([self.netOCR], True)
1092
+ self.optimizer_D.zero_grad()
1093
+ #if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
1094
+ self.optimizer_OCR.zero_grad()
1095
+ self.backward_D_OCR()
1096
+
1097
+ def optimize_OCR(self):
1098
+ self.forward()
1099
+ self.set_requires_grad([self.netD], False)
1100
+ self.set_requires_grad([self.netOCR], True)
1101
+ if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
1102
+ self.optimizer_OCR.zero_grad()
1103
+ self.backward_OCR()
1104
+
1105
+ def optimize_D(self):
1106
+ self.forward()
1107
+ self.set_requires_grad([self.netD], True)
1108
+ self.backward_D()
1109
+
1110
+ def optimize_D_OCR_step(self):
1111
+ self.optimizer_D.step()
1112
+
1113
+ self.optimizer_OCR.step()
1114
+ self.optimizer_D.zero_grad()
1115
+ self.optimizer_OCR.zero_grad()
1116
+
1117
+
1118
+ def optimize_D_OCR_WL(self):
1119
+ self.forward()
1120
+ self.set_requires_grad([self.netD], True)
1121
+ self.set_requires_grad([self.netOCR], True)
1122
+ self.set_requires_grad([self.netW], True)
1123
+ self.optimizer_D.zero_grad()
1124
+ self.optimizer_wl.zero_grad()
1125
+ if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
1126
+ self.optimizer_OCR.zero_grad()
1127
+ self.backward_D_OCR_WL()
1128
+
1129
+ def optimize_D_OCR_WL_step(self):
1130
+ self.optimizer_D.step()
1131
+ if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
1132
+ self.optimizer_OCR.step()
1133
+ self.optimizer_wl.step()
1134
+ self.optimizer_D.zero_grad()
1135
+ self.optimizer_OCR.zero_grad()
1136
+ self.optimizer_wl.zero_grad()
1137
+
1138
+ def optimize_D_step(self):
1139
+ self.optimizer_D.step()
1140
+ if any(torch.isnan(self.netD.infer_img.blocks[0][0].conv1.bias)):
1141
+ print('D is nan')
1142
+ sys.exit()
1143
+ self.optimizer_D.zero_grad()
1144
+
1145
+ def optimize_G(self):
1146
+ self.forward()
1147
+ self.set_requires_grad([self.netD], False)
1148
+ self.set_requires_grad([self.netOCR], False)
1149
+ self.set_requires_grad([self.netW], False)
1150
+ self.backward_G()
1151
+
1152
+ def optimize_G_WL(self):
1153
+ self.forward()
1154
+ self.set_requires_grad([self.netD], False)
1155
+ self.set_requires_grad([self.netOCR], False)
1156
+ self.set_requires_grad([self.netW], False)
1157
+ self.backward_G_WL()
1158
+
1159
+
1160
+ def optimize_G_only(self):
1161
+ self.forward()
1162
+ self.set_requires_grad([self.netD], False)
1163
+ self.set_requires_grad([self.netOCR], False)
1164
+ self.set_requires_grad([self.netW], False)
1165
+ self.backward_G_only()
1166
+
1167
+
1168
+ def optimize_G_step(self):
1169
+
1170
+ self.optimizer_G.step()
1171
+ self.optimizer_G.zero_grad()
1172
+
1173
+ def optimize_ocr(self):
1174
+ self.set_requires_grad([self.netOCR], True)
1175
+ # OCR loss on real data
1176
+ pred_real_OCR = self.netOCR(self.real)
1177
+ preds_size =torch.IntTensor([pred_real_OCR.size(0)] * self.opt.batch_size).detach()
1178
+ self.loss_OCR_real = self.OCR_criterion(pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
1179
+ self.loss_OCR_real.backward()
1180
+ self.optimizer_OCR.step()
1181
+
1182
+ def optimize_z(self):
1183
+ self.set_requires_grad([self.z], True)
1184
+
1185
+
1186
+ def optimize_parameters(self):
1187
+ self.forward()
1188
+ self.set_requires_grad([self.netD], False)
1189
+ self.optimizer_G.zero_grad()
1190
+ self.backward_G()
1191
+ self.optimizer_G.step()
1192
+
1193
+ self.set_requires_grad([self.netD], True)
1194
+ self.optimizer_D.zero_grad()
1195
+ self.backward_D()
1196
+ self.optimizer_D.step()
1197
+
1198
+ def test(self):
1199
+ self.visual_names = ['fake']
1200
+ self.netG.eval()
1201
+ with torch.no_grad():
1202
+ self.forward()
1203
+
1204
+ def train_GD(self):
1205
+ self.netG.train()
1206
+ self.netD.train()
1207
+ self.optimizer_G.zero_grad()
1208
+ self.optimizer_D.zero_grad()
1209
+ # How many chunks to split x and y into?
1210
+ x = torch.split(self.real, self.opt.batch_size)
1211
+ y = torch.split(self.label, self.opt.batch_size)
1212
+ counter = 0
1213
+
1214
+ # Optionally toggle D and G's "require_grad"
1215
+ if self.opt.toggle_grads:
1216
+ toggle_grad(self.netD, True)
1217
+ toggle_grad(self.netG, False)
1218
+
1219
+ for step_index in range(self.opt.num_critic_train):
1220
+ self.optimizer_D.zero_grad()
1221
+ with torch.set_grad_enabled(False):
1222
+ self.forward()
1223
+ D_input = torch.cat([self.fake, x[counter]], 0) if x is not None else self.fake
1224
+ D_class = torch.cat([self.label_fake, y[counter]], 0) if y[counter] is not None else y[counter]
1225
+ # Get Discriminator output
1226
+ D_out = self.netD(D_input, D_class)
1227
+ if x is not None:
1228
+ pred_fake, pred_real = torch.split(D_out, [self.fake.shape[0], x[counter].shape[0]]) # D_fake, D_real
1229
+ else:
1230
+ pred_fake = D_out
1231
+ # Combined loss
1232
+ self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
1233
+ self.loss_D = self.loss_Dreal + self.loss_Dfake
1234
+ self.loss_D.backward()
1235
+ counter += 1
1236
+ self.optimizer_D.step()
1237
+
1238
+ # Optionally toggle D and G's "require_grad"
1239
+ if self.opt.toggle_grads:
1240
+ toggle_grad(self.netD, False)
1241
+ toggle_grad(self.netG, True)
1242
+ # Zero G's gradients by default before training G, for safety
1243
+ self.optimizer_G.zero_grad()
1244
+ self.forward()
1245
+ self.loss_G = loss_hinge_gen(self.netD(self.fake, self.label_fake), self.len_text_fake.detach(), self.opt.mask_loss)
1246
+ self.loss_G.backward()
1247
+ self.optimizer_G.step()
1248
+
1249
+
1250
+
1251
+
1252
+
1253
+
1254
+
1255
+
1256
+
1257
+
1258
+
1259
+
1260
+
1261
+
1262
+
1263
+
1264
+
util/models/networks.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+ from util.util import to_device, load_network
7
+
8
+ ###############################################################################
9
+ # Helper Functions
10
+ ###############################################################################
11
+
12
+
13
+ def init_weights(net, init_type='normal', init_gain=0.02):
14
+ """Initialize network weights.
15
+
16
+ Parameters:
17
+ net (network) -- network to be initialized
18
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
19
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
20
+
21
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
22
+ work better for some applications. Feel free to try yourself.
23
+ """
24
+ def init_func(m): # define the initialization function
25
+ classname = m.__class__.__name__
26
+ if (isinstance(m, nn.Conv2d)
27
+ or isinstance(m, nn.Linear)
28
+ or isinstance(m, nn.Embedding)):
29
+ # if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
30
+ if init_type == 'N02':
31
+ init.normal_(m.weight.data, 0.0, init_gain)
32
+ elif init_type in ['glorot', 'xavier']:
33
+ init.xavier_normal_(m.weight.data, gain=init_gain)
34
+ elif init_type == 'kaiming':
35
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
36
+ elif init_type == 'ortho':
37
+ init.orthogonal_(m.weight.data, gain=init_gain)
38
+ else:
39
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
40
+ # if hasattr(m, 'bias') and m.bias is not None:
41
+ # init.constant_(m.bias.data, 0.0)
42
+ # elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
43
+ # init.normal_(m.weight.data, 1.0, init_gain)
44
+ # init.constant_(m.bias.data, 0.0)
45
+ if init_type in ['N02', 'glorot', 'xavier', 'kaiming', 'ortho']:
46
+ print('initialize network with %s' % init_type)
47
+ net.apply(init_func) # apply the initialization function <init_func>
48
+ else:
49
+ print('loading the model from %s' % init_type)
50
+ net = load_network(net, init_type, 'latest')
51
+ return net
52
+
53
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
54
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
55
+ Parameters:
56
+ net (network) -- the network to be initialized
57
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
58
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
59
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
60
+
61
+ Return an initialized network.
62
+ """
63
+ if len(gpu_ids) > 0:
64
+ assert(torch.cuda.is_available())
65
+ net.to(gpu_ids[0])
66
+ net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
67
+ init_weights(net, init_type, init_gain=init_gain)
68
+ return net
69
+
70
+
71
+ def get_scheduler(optimizer, opt):
72
+ """Return a learning rate scheduler
73
+
74
+ Parameters:
75
+ optimizer -- the optimizer of the network
76
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
77
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
78
+
79
+ For 'linear', we keep the same learning rate for the first <opt.niter> epochs
80
+ and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
81
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
82
+ See https://pytorch.org/docs/stable/optim.html for more details.
83
+ """
84
+ if opt.lr_policy == 'linear':
85
+ def lambda_rule(epoch):
86
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
87
+ return lr_l
88
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
89
+ elif opt.lr_policy == 'step':
90
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
91
+ elif opt.lr_policy == 'plateau':
92
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
93
+ elif opt.lr_policy == 'cosine':
94
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
95
+ else:
96
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
97
+ return scheduler
98
+
util/models/sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12
+ from .replicate import DataParallelWithCallback, patch_replication_callback
util/models/sync_batchnorm/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (380 Bytes). View file
 
util/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (384 Bytes). View file
 
util/models/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc ADDED
Binary file (13.1 kB). View file
 
util/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc ADDED
Binary file (13 kB). View file
 
util/models/sync_batchnorm/__pycache__/comm.cpython-36.pyc ADDED
Binary file (4.77 kB). View file
 
util/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc ADDED
Binary file (4.77 kB). View file
 
util/models/sync_batchnorm/__pycache__/replicate.cpython-36.pyc ADDED
Binary file (3.44 kB). View file
 
util/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc ADDED
Binary file (3.45 kB). View file
 
util/models/sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+
21
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22
+
23
+
24
+ def _sum_ft(tensor):
25
+ """sum over the first and last dimention"""
26
+ return tensor.sum(dim=0).sum(dim=-1)
27
+
28
+
29
+ def _unsqueeze_ft(tensor):
30
+ """add new dementions at the front and the tail"""
31
+ return tensor.unsqueeze(0).unsqueeze(-1)
32
+
33
+
34
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36
+ # _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size'])
37
+
38
+ class _SynchronizedBatchNorm(_BatchNorm):
39
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41
+
42
+ self._sync_master = SyncMaster(self._data_parallel_master)
43
+
44
+ self._is_parallel = False
45
+ self._parallel_id = None
46
+ self._slave_pipe = None
47
+
48
+ def forward(self, input, gain=None, bias=None):
49
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50
+ if not (self._is_parallel and self.training):
51
+ out = F.batch_norm(
52
+ input, self.running_mean, self.running_var, self.weight, self.bias,
53
+ self.training, self.momentum, self.eps)
54
+ if gain is not None:
55
+ out = out + gain
56
+ if bias is not None:
57
+ out = out + bias
58
+ return out
59
+
60
+ # Resize the input to (B, C, -1).
61
+ input_shape = input.size()
62
+ # print(input_shape)
63
+ input = input.view(input.size(0), input.size(1), -1)
64
+
65
+ # Compute the sum and square-sum.
66
+ sum_size = input.size(0) * input.size(2)
67
+ input_sum = _sum_ft(input)
68
+ input_ssum = _sum_ft(input ** 2)
69
+ # Reduce-and-broadcast the statistics.
70
+ # print('it begins')
71
+ if self._parallel_id == 0:
72
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
73
+ else:
74
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
75
+ # if self._parallel_id == 0:
76
+ # # print('here')
77
+ # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
78
+ # else:
79
+ # # print('there')
80
+ # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
81
+
82
+ # print('how2')
83
+ # num = sum_size
84
+ # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu())))
85
+ # Fix the graph
86
+ # sum = (sum.detach() - input_sum.detach()) + input_sum
87
+ # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum
88
+
89
+ # mean = sum / num
90
+ # var = ssum / num - mean ** 2
91
+ # # var = (ssum - mean * sum) / num
92
+ # inv_std = torch.rsqrt(var + self.eps)
93
+
94
+ # Compute the output.
95
+ if gain is not None:
96
+ # print('gaining')
97
+ # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1)
98
+ # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1)
99
+ # output = input * scale - shift
100
+ output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1)
101
+ elif self.affine:
102
+ # MJY:: Fuse the multiplication for speed.
103
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
104
+ else:
105
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
106
+
107
+ # Reshape it.
108
+ return output.view(input_shape)
109
+
110
+ def __data_parallel_replicate__(self, ctx, copy_id):
111
+ self._is_parallel = True
112
+ self._parallel_id = copy_id
113
+
114
+ # parallel_id == 0 means master device.
115
+ if self._parallel_id == 0:
116
+ ctx.sync_master = self._sync_master
117
+ else:
118
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
119
+
120
+ def _data_parallel_master(self, intermediates):
121
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
122
+
123
+ # Always using same "device order" makes the ReduceAdd operation faster.
124
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
125
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
126
+
127
+ to_reduce = [i[1][:2] for i in intermediates]
128
+ to_reduce = [j for i in to_reduce for j in i] # flatten
129
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
130
+
131
+ sum_size = sum([i[1].sum_size for i in intermediates])
132
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
133
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
134
+
135
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
136
+ # print('a')
137
+ # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size)
138
+ # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device))
139
+ # print('b')
140
+ outputs = []
141
+ for i, rec in enumerate(intermediates):
142
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
143
+ # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3])))
144
+
145
+ return outputs
146
+
147
+ def _compute_mean_std(self, sum_, ssum, size):
148
+ """Compute the mean and standard-deviation with sum and square-sum. This method
149
+ also maintains the moving average on the master device."""
150
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
151
+ mean = sum_ / size
152
+ sumvar = ssum - sum_ * mean
153
+ unbias_var = sumvar / (size - 1)
154
+ bias_var = sumvar / size
155
+
156
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
157
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
158
+ return mean, torch.rsqrt(bias_var + self.eps)
159
+ # return mean, bias_var.clamp(self.eps) ** -0.5
160
+
161
+
162
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
163
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
164
+ mini-batch.
165
+
166
+ .. math::
167
+
168
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
169
+
170
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
171
+ standard-deviation are reduced across all devices during training.
172
+
173
+ For example, when one uses `nn.DataParallel` to wrap the network during
174
+ training, PyTorch's implementation normalize the tensor on each device using
175
+ the statistics only on that device, which accelerated the computation and
176
+ is also easy to implement, but the statistics might be inaccurate.
177
+ Instead, in this synchronized version, the statistics will be computed
178
+ over all training samples distributed on multiple devices.
179
+
180
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
181
+ as the built-in PyTorch implementation.
182
+
183
+ The mean and standard-deviation are calculated per-dimension over
184
+ the mini-batches and gamma and beta are learnable parameter vectors
185
+ of size C (where C is the input size).
186
+
187
+ During training, this layer keeps a running estimate of its computed mean
188
+ and variance. The running sum is kept with a default momentum of 0.1.
189
+
190
+ During evaluation, this running mean/variance is used for normalization.
191
+
192
+ Because the BatchNorm is done over the `C` dimension, computing statistics
193
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
194
+
195
+ Args:
196
+ num_features: num_features from an expected input of size
197
+ `batch_size x num_features [x width]`
198
+ eps: a value added to the denominator for numerical stability.
199
+ Default: 1e-5
200
+ momentum: the value used for the running_mean and running_var
201
+ computation. Default: 0.1
202
+ affine: a boolean value that when set to ``True``, gives the layer learnable
203
+ affine parameters. Default: ``True``
204
+
205
+ Shape:
206
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
207
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
208
+
209
+ Examples:
210
+ >>> # With Learnable Parameters
211
+ >>> m = SynchronizedBatchNorm1d(100)
212
+ >>> # Without Learnable Parameters
213
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
214
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
215
+ >>> output = m(input)
216
+ """
217
+
218
+ def _check_input_dim(self, input):
219
+ if input.dim() != 2 and input.dim() != 3:
220
+ raise ValueError('expected 2D or 3D input (got {}D input)'
221
+ .format(input.dim()))
222
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
223
+
224
+
225
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
226
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
227
+ of 3d inputs
228
+
229
+ .. math::
230
+
231
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
232
+
233
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
234
+ standard-deviation are reduced across all devices during training.
235
+
236
+ For example, when one uses `nn.DataParallel` to wrap the network during
237
+ training, PyTorch's implementation normalize the tensor on each device using
238
+ the statistics only on that device, which accelerated the computation and
239
+ is also easy to implement, but the statistics might be inaccurate.
240
+ Instead, in this synchronized version, the statistics will be computed
241
+ over all training samples distributed on multiple devices.
242
+
243
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
244
+ as the built-in PyTorch implementation.
245
+
246
+ The mean and standard-deviation are calculated per-dimension over
247
+ the mini-batches and gamma and beta are learnable parameter vectors
248
+ of size C (where C is the input size).
249
+
250
+ During training, this layer keeps a running estimate of its computed mean
251
+ and variance. The running sum is kept with a default momentum of 0.1.
252
+
253
+ During evaluation, this running mean/variance is used for normalization.
254
+
255
+ Because the BatchNorm is done over the `C` dimension, computing statistics
256
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
257
+
258
+ Args:
259
+ num_features: num_features from an expected input of
260
+ size batch_size x num_features x height x width
261
+ eps: a value added to the denominator for numerical stability.
262
+ Default: 1e-5
263
+ momentum: the value used for the running_mean and running_var
264
+ computation. Default: 0.1
265
+ affine: a boolean value that when set to ``True``, gives the layer learnable
266
+ affine parameters. Default: ``True``
267
+
268
+ Shape:
269
+ - Input: :math:`(N, C, H, W)`
270
+ - Output: :math:`(N, C, H, W)` (same shape as input)
271
+
272
+ Examples:
273
+ >>> # With Learnable Parameters
274
+ >>> m = SynchronizedBatchNorm2d(100)
275
+ >>> # Without Learnable Parameters
276
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
277
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
278
+ >>> output = m(input)
279
+ """
280
+
281
+ def _check_input_dim(self, input):
282
+ if input.dim() != 4:
283
+ raise ValueError('expected 4D input (got {}D input)'
284
+ .format(input.dim()))
285
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
286
+
287
+
288
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
289
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
290
+ of 4d inputs
291
+
292
+ .. math::
293
+
294
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
295
+
296
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
297
+ standard-deviation are reduced across all devices during training.
298
+
299
+ For example, when one uses `nn.DataParallel` to wrap the network during
300
+ training, PyTorch's implementation normalize the tensor on each device using
301
+ the statistics only on that device, which accelerated the computation and
302
+ is also easy to implement, but the statistics might be inaccurate.
303
+ Instead, in this synchronized version, the statistics will be computed
304
+ over all training samples distributed on multiple devices.
305
+
306
+ Note that, for one-GPU or CPU-only case, this module behaves exactly same
307
+ as the built-in PyTorch implementation.
308
+
309
+ The mean and standard-deviation are calculated per-dimension over
310
+ the mini-batches and gamma and beta are learnable parameter vectors
311
+ of size C (where C is the input size).
312
+
313
+ During training, this layer keeps a running estimate of its computed mean
314
+ and variance. The running sum is kept with a default momentum of 0.1.
315
+
316
+ During evaluation, this running mean/variance is used for normalization.
317
+
318
+ Because the BatchNorm is done over the `C` dimension, computing statistics
319
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
320
+ or Spatio-temporal BatchNorm
321
+
322
+ Args:
323
+ num_features: num_features from an expected input of
324
+ size batch_size x num_features x depth x height x width
325
+ eps: a value added to the denominator for numerical stability.
326
+ Default: 1e-5
327
+ momentum: the value used for the running_mean and running_var
328
+ computation. Default: 0.1
329
+ affine: a boolean value that when set to ``True``, gives the layer learnable
330
+ affine parameters. Default: ``True``
331
+
332
+ Shape:
333
+ - Input: :math:`(N, C, D, H, W)`
334
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
335
+
336
+ Examples:
337
+ >>> # With Learnable Parameters
338
+ >>> m = SynchronizedBatchNorm3d(100)
339
+ >>> # Without Learnable Parameters
340
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
341
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
342
+ >>> output = m(input)
343
+ """
344
+
345
+ def _check_input_dim(self, input):
346
+ if input.dim() != 5:
347
+ raise ValueError('expected 5D input (got {}D input)'
348
+ .format(input.dim()))
349
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
util/models/sync_batchnorm/batchnorm_reimpl.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : batchnorm_reimpl.py
4
+ # Author : acgtyrant
5
+ # Date : 11/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.init as init
14
+
15
+ __all__ = ['BatchNormReimpl']
16
+
17
+
18
+ class BatchNorm2dReimpl(nn.Module):
19
+ """
20
+ A re-implementation of batch normalization, used for testing the numerical
21
+ stability.
22
+
23
+ Author: acgtyrant
24
+ See also:
25
+ https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26
+ """
27
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
28
+ super().__init__()
29
+
30
+ self.num_features = num_features
31
+ self.eps = eps
32
+ self.momentum = momentum
33
+ self.weight = nn.Parameter(torch.empty(num_features))
34
+ self.bias = nn.Parameter(torch.empty(num_features))
35
+ self.register_buffer('running_mean', torch.zeros(num_features))
36
+ self.register_buffer('running_var', torch.ones(num_features))
37
+ self.reset_parameters()
38
+
39
+ def reset_running_stats(self):
40
+ self.running_mean.zero_()
41
+ self.running_var.fill_(1)
42
+
43
+ def reset_parameters(self):
44
+ self.reset_running_stats()
45
+ init.uniform_(self.weight)
46
+ init.zeros_(self.bias)
47
+
48
+ def forward(self, input_):
49
+ batchsize, channels, height, width = input_.size()
50
+ numel = batchsize * height * width
51
+ input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52
+ sum_ = input_.sum(1)
53
+ sum_of_square = input_.pow(2).sum(1)
54
+ mean = sum_ / numel
55
+ sumvar = sum_of_square - sum_ * mean
56
+
57
+ self.running_mean = (
58
+ (1 - self.momentum) * self.running_mean
59
+ + self.momentum * mean.detach()
60
+ )
61
+ unbias_var = sumvar / (numel - 1)
62
+ self.running_var = (
63
+ (1 - self.momentum) * self.running_var
64
+ + self.momentum * unbias_var.detach()
65
+ )
66
+
67
+ bias_var = sumvar / numel
68
+ inv_std = 1 / (bias_var + self.eps).pow(0.5)
69
+ output = (
70
+ (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71
+ self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72
+
73
+ return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74
+
util/models/sync_batchnorm/comm.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result has\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
60
+ call `register(id)` and obtain an `SlavePipe` to communicate with the master.
61
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
62
+ and passed to a registered callback.
63
+ - After receiving the messages, the master device should gather the information and determine to message passed
64
+ back to each slave devices.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def __getstate__(self):
79
+ return {'master_callback': self._master_callback}
80
+
81
+ def __setstate__(self, state):
82
+ self.__init__(state['master_callback'])
83
+
84
+ def register_slave(self, identifier):
85
+ """
86
+ Register an slave device.
87
+
88
+ Args:
89
+ identifier: an identifier, usually is the device id.
90
+
91
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
92
+
93
+ """
94
+ if self._activated:
95
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
96
+ self._activated = False
97
+ self._registry.clear()
98
+ future = FutureResult()
99
+ self._registry[identifier] = _MasterRegistry(future)
100
+ return SlavePipe(identifier, self._queue, future)
101
+
102
+ def run_master(self, master_msg):
103
+ """
104
+ Main entry for the master device in each forward pass.
105
+ The messages were first collected from each devices (including the master device), and then
106
+ an callback will be invoked to compute the message to be sent back to each devices
107
+ (including the master device).
108
+
109
+ Args:
110
+ master_msg: the message that the master want to send to itself. This will be placed as the first
111
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112
+
113
+ Returns: the message to be sent back to the master device.
114
+
115
+ """
116
+ self._activated = True
117
+
118
+ intermediates = [(0, master_msg)]
119
+ for i in range(self.nr_slaves):
120
+ intermediates.append(self._queue.get())
121
+
122
+ results = self._master_callback(intermediates)
123
+ assert results[0][0] == 0, 'The first result should belongs to the master.'
124
+
125
+ for i, res in results:
126
+ if i == 0:
127
+ continue
128
+ self._registry[i].result.put(res)
129
+
130
+ for i in range(self.nr_slaves):
131
+ assert self._queue.get() is True
132
+
133
+ return results[0][1]
134
+
135
+ @property
136
+ def nr_slaves(self):
137
+ return len(self._registry)