Spaces:
Running
Running
ankankbhunia
commited on
Upload 58 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- util/__init__.py +1 -0
- util/__init__.pyc +0 -0
- util/__pycache__/__init__.cpython-36.pyc +0 -0
- util/__pycache__/__init__.cpython-37.pyc +0 -0
- util/__pycache__/__init__.cpython-38.pyc +0 -0
- util/__pycache__/__init__.cpython-39.pyc +0 -0
- util/__pycache__/html.cpython-36.pyc +0 -0
- util/__pycache__/html.cpython-37.pyc +0 -0
- util/__pycache__/misc.cpython-36.pyc +0 -0
- util/__pycache__/misc.cpython-37.pyc +0 -0
- util/__pycache__/params.cpython-37.pyc +0 -0
- util/__pycache__/util.cpython-36.pyc +0 -0
- util/__pycache__/util.cpython-37.pyc +0 -0
- util/__pycache__/util.cpython-38.pyc +0 -0
- util/__pycache__/util.cpython-39.pyc +0 -0
- util/__pycache__/visualizer.cpython-36.pyc +0 -0
- util/__pycache__/visualizer.cpython-37.pyc +0 -0
- util/html.py +86 -0
- util/misc.py +465 -0
- util/models/BigGAN_layers.py +469 -0
- util/models/BigGAN_networks.py +841 -0
- util/models/OCR_network.py +304 -0
- util/models/__init__.py +65 -0
- util/models/__pycache__/BigGAN_layers.cpython-36.pyc +0 -0
- util/models/__pycache__/BigGAN_networks.cpython-36.pyc +0 -0
- util/models/__pycache__/OCR_network.cpython-36.pyc +0 -0
- util/models/__pycache__/__init__.cpython-36.pyc +0 -0
- util/models/__pycache__/blocks.cpython-36.pyc +0 -0
- util/models/__pycache__/inception.cpython-36.pyc +0 -0
- util/models/__pycache__/model.cpython-36.pyc +0 -0
- util/models/__pycache__/model_.cpython-36.pyc +0 -0
- util/models/__pycache__/networks.cpython-36.pyc +0 -0
- util/models/__pycache__/transformer.cpython-36.pyc +0 -0
- util/models/blocks.py +190 -0
- util/models/inception.py +363 -0
- util/models/model.py +1389 -0
- util/models/model_.py +1264 -0
- util/models/networks.py +98 -0
- util/models/sync_batchnorm/__init__.py +12 -0
- util/models/sync_batchnorm/__pycache__/__init__.cpython-36.pyc +0 -0
- util/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc +0 -0
- util/models/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc +0 -0
- util/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc +0 -0
- util/models/sync_batchnorm/__pycache__/comm.cpython-36.pyc +0 -0
- util/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc +0 -0
- util/models/sync_batchnorm/__pycache__/replicate.cpython-36.pyc +0 -0
- util/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc +0 -0
- util/models/sync_batchnorm/batchnorm.py +349 -0
- util/models/sync_batchnorm/batchnorm_reimpl.py +74 -0
- util/models/sync_batchnorm/comm.py +137 -0
util/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
"""This package includes a miscellaneous collection of useful helper functions."""
|
util/__init__.pyc
ADDED
Binary file (265 Bytes). View file
|
|
util/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (215 Bytes). View file
|
|
util/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (261 Bytes). View file
|
|
util/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (259 Bytes). View file
|
|
util/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (252 Bytes). View file
|
|
util/__pycache__/html.cpython-36.pyc
ADDED
Binary file (3.56 kB). View file
|
|
util/__pycache__/html.cpython-37.pyc
ADDED
Binary file (3.57 kB). View file
|
|
util/__pycache__/misc.cpython-36.pyc
ADDED
Binary file (14.4 kB). View file
|
|
util/__pycache__/misc.cpython-37.pyc
ADDED
Binary file (14.3 kB). View file
|
|
util/__pycache__/params.cpython-37.pyc
ADDED
Binary file (1.15 kB). View file
|
|
util/__pycache__/util.cpython-36.pyc
ADDED
Binary file (9.98 kB). View file
|
|
util/__pycache__/util.cpython-37.pyc
ADDED
Binary file (10 kB). View file
|
|
util/__pycache__/util.cpython-38.pyc
ADDED
Binary file (10.1 kB). View file
|
|
util/__pycache__/util.cpython-39.pyc
ADDED
Binary file (10.1 kB). View file
|
|
util/__pycache__/visualizer.cpython-36.pyc
ADDED
Binary file (8.55 kB). View file
|
|
util/__pycache__/visualizer.cpython-37.pyc
ADDED
Binary file (8.55 kB). View file
|
|
util/html.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dominate
|
2 |
+
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class HTML:
|
7 |
+
"""This HTML class allows us to save images and write texts into a single HTML file.
|
8 |
+
|
9 |
+
It consists of functions such as <add_header> (add a text header to the HTML file),
|
10 |
+
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
|
11 |
+
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, web_dir, title, refresh=0):
|
15 |
+
"""Initialize the HTML classes
|
16 |
+
|
17 |
+
Parameters:
|
18 |
+
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
|
19 |
+
title (str) -- the webpage name
|
20 |
+
refresh (int) -- how often the website refresh itself; if 0; no refreshing
|
21 |
+
"""
|
22 |
+
self.title = title
|
23 |
+
self.web_dir = web_dir
|
24 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
25 |
+
if not os.path.exists(self.web_dir):
|
26 |
+
os.makedirs(self.web_dir)
|
27 |
+
if not os.path.exists(self.img_dir):
|
28 |
+
os.makedirs(self.img_dir)
|
29 |
+
|
30 |
+
self.doc = dominate.document(title=title)
|
31 |
+
if refresh > 0:
|
32 |
+
with self.doc.head:
|
33 |
+
meta(http_equiv="refresh", content=str(refresh))
|
34 |
+
|
35 |
+
def get_image_dir(self):
|
36 |
+
"""Return the directory that stores images"""
|
37 |
+
return self.img_dir
|
38 |
+
|
39 |
+
def add_header(self, text):
|
40 |
+
"""Insert a header to the HTML file
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
text (str) -- the header text
|
44 |
+
"""
|
45 |
+
with self.doc:
|
46 |
+
h3(text)
|
47 |
+
|
48 |
+
def add_images(self, ims, txts, links, width=400):
|
49 |
+
"""add images to the HTML file
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
ims (str list) -- a list of image paths
|
53 |
+
txts (str list) -- a list of image names shown on the website
|
54 |
+
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
|
55 |
+
"""
|
56 |
+
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
|
57 |
+
self.doc.add(self.t)
|
58 |
+
with self.t:
|
59 |
+
with tr():
|
60 |
+
for im, txt, link in zip(ims, txts, links):
|
61 |
+
with td(style="word-wrap: break-word;", halign="center", valign="top"):
|
62 |
+
with p():
|
63 |
+
with a(href=os.path.join('images', link)):
|
64 |
+
img(style="width:%dpx" % width, src=os.path.join('images', im))
|
65 |
+
br()
|
66 |
+
p(txt)
|
67 |
+
|
68 |
+
def save(self):
|
69 |
+
"""save the current content to the HMTL file"""
|
70 |
+
html_file = '%s/index.html' % self.web_dir
|
71 |
+
f = open(html_file, 'wt')
|
72 |
+
f.write(self.doc.render())
|
73 |
+
f.close()
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__': # we show an example usage here.
|
77 |
+
html = HTML('web/', 'test_html')
|
78 |
+
html.add_header('hello world')
|
79 |
+
|
80 |
+
ims, txts, links = [], [], []
|
81 |
+
for n in range(4):
|
82 |
+
ims.append('image_%d.png' % n)
|
83 |
+
txts.append('text_%d' % n)
|
84 |
+
links.append('image_%d.png' % n)
|
85 |
+
html.add_images(ims, txts, links)
|
86 |
+
html.save()
|
util/misc.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Misc functions, including distributed helpers.
|
4 |
+
|
5 |
+
Mostly copy-paste from torchvision references.
|
6 |
+
"""
|
7 |
+
import os
|
8 |
+
import subprocess
|
9 |
+
import time
|
10 |
+
from collections import defaultdict, deque
|
11 |
+
import datetime
|
12 |
+
import pickle
|
13 |
+
from typing import Optional, List
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.distributed as dist
|
17 |
+
from torch import Tensor
|
18 |
+
|
19 |
+
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
20 |
+
import torchvision
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
class SmoothedValue(object):
|
25 |
+
"""Track a series of values and provide access to smoothed values over a
|
26 |
+
window or the global series average.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, window_size=20, fmt=None):
|
30 |
+
if fmt is None:
|
31 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
32 |
+
self.deque = deque(maxlen=window_size)
|
33 |
+
self.total = 0.0
|
34 |
+
self.count = 0
|
35 |
+
self.fmt = fmt
|
36 |
+
|
37 |
+
def update(self, value, n=1):
|
38 |
+
self.deque.append(value)
|
39 |
+
self.count += n
|
40 |
+
self.total += value * n
|
41 |
+
|
42 |
+
def synchronize_between_processes(self):
|
43 |
+
"""
|
44 |
+
Warning: does not synchronize the deque!
|
45 |
+
"""
|
46 |
+
if not is_dist_avail_and_initialized():
|
47 |
+
return
|
48 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
49 |
+
dist.barrier()
|
50 |
+
dist.all_reduce(t)
|
51 |
+
t = t.tolist()
|
52 |
+
self.count = int(t[0])
|
53 |
+
self.total = t[1]
|
54 |
+
|
55 |
+
@property
|
56 |
+
def median(self):
|
57 |
+
d = torch.tensor(list(self.deque))
|
58 |
+
return d.median().item()
|
59 |
+
|
60 |
+
@property
|
61 |
+
def avg(self):
|
62 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
63 |
+
return d.mean().item()
|
64 |
+
|
65 |
+
@property
|
66 |
+
def global_avg(self):
|
67 |
+
return self.total / self.count
|
68 |
+
|
69 |
+
@property
|
70 |
+
def max(self):
|
71 |
+
return max(self.deque)
|
72 |
+
|
73 |
+
@property
|
74 |
+
def value(self):
|
75 |
+
return self.deque[-1]
|
76 |
+
|
77 |
+
def __str__(self):
|
78 |
+
return self.fmt.format(
|
79 |
+
median=self.median,
|
80 |
+
avg=self.avg,
|
81 |
+
global_avg=self.global_avg,
|
82 |
+
max=self.max,
|
83 |
+
value=self.value)
|
84 |
+
|
85 |
+
|
86 |
+
def all_gather(data):
|
87 |
+
"""
|
88 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
89 |
+
Args:
|
90 |
+
data: any picklable object
|
91 |
+
Returns:
|
92 |
+
list[data]: list of data gathered from each rank
|
93 |
+
"""
|
94 |
+
world_size = get_world_size()
|
95 |
+
if world_size == 1:
|
96 |
+
return [data]
|
97 |
+
|
98 |
+
# serialized to a Tensor
|
99 |
+
buffer = pickle.dumps(data)
|
100 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
101 |
+
tensor = torch.ByteTensor(storage).to("cuda")
|
102 |
+
|
103 |
+
# obtain Tensor size of each rank
|
104 |
+
local_size = torch.tensor([tensor.numel()], device="cuda")
|
105 |
+
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
106 |
+
dist.all_gather(size_list, local_size)
|
107 |
+
size_list = [int(size.item()) for size in size_list]
|
108 |
+
max_size = max(size_list)
|
109 |
+
|
110 |
+
# receiving Tensor from all ranks
|
111 |
+
# we pad the tensor because torch all_gather does not support
|
112 |
+
# gathering tensors of different shapes
|
113 |
+
tensor_list = []
|
114 |
+
for _ in size_list:
|
115 |
+
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
116 |
+
if local_size != max_size:
|
117 |
+
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
118 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
119 |
+
dist.all_gather(tensor_list, tensor)
|
120 |
+
|
121 |
+
data_list = []
|
122 |
+
for size, tensor in zip(size_list, tensor_list):
|
123 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
124 |
+
data_list.append(pickle.loads(buffer))
|
125 |
+
|
126 |
+
return data_list
|
127 |
+
|
128 |
+
|
129 |
+
def reduce_dict(input_dict, average=True):
|
130 |
+
"""
|
131 |
+
Args:
|
132 |
+
input_dict (dict): all the values will be reduced
|
133 |
+
average (bool): whether to do average or sum
|
134 |
+
Reduce the values in the dictionary from all processes so that all processes
|
135 |
+
have the averaged results. Returns a dict with the same fields as
|
136 |
+
input_dict, after reduction.
|
137 |
+
"""
|
138 |
+
world_size = get_world_size()
|
139 |
+
if world_size < 2:
|
140 |
+
return input_dict
|
141 |
+
with torch.no_grad():
|
142 |
+
names = []
|
143 |
+
values = []
|
144 |
+
# sort the keys so that they are consistent across processes
|
145 |
+
for k in sorted(input_dict.keys()):
|
146 |
+
names.append(k)
|
147 |
+
values.append(input_dict[k])
|
148 |
+
values = torch.stack(values, dim=0)
|
149 |
+
dist.all_reduce(values)
|
150 |
+
if average:
|
151 |
+
values /= world_size
|
152 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
153 |
+
return reduced_dict
|
154 |
+
|
155 |
+
|
156 |
+
class MetricLogger(object):
|
157 |
+
def __init__(self, delimiter="\t"):
|
158 |
+
self.meters = defaultdict(SmoothedValue)
|
159 |
+
self.delimiter = delimiter
|
160 |
+
|
161 |
+
def update(self, **kwargs):
|
162 |
+
for k, v in kwargs.items():
|
163 |
+
if isinstance(v, torch.Tensor):
|
164 |
+
v = v.item()
|
165 |
+
assert isinstance(v, (float, int))
|
166 |
+
self.meters[k].update(v)
|
167 |
+
|
168 |
+
def __getattr__(self, attr):
|
169 |
+
if attr in self.meters:
|
170 |
+
return self.meters[attr]
|
171 |
+
if attr in self.__dict__:
|
172 |
+
return self.__dict__[attr]
|
173 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
174 |
+
type(self).__name__, attr))
|
175 |
+
|
176 |
+
def __str__(self):
|
177 |
+
loss_str = []
|
178 |
+
for name, meter in self.meters.items():
|
179 |
+
loss_str.append(
|
180 |
+
"{}: {}".format(name, str(meter))
|
181 |
+
)
|
182 |
+
return self.delimiter.join(loss_str)
|
183 |
+
|
184 |
+
def synchronize_between_processes(self):
|
185 |
+
for meter in self.meters.values():
|
186 |
+
meter.synchronize_between_processes()
|
187 |
+
|
188 |
+
def add_meter(self, name, meter):
|
189 |
+
self.meters[name] = meter
|
190 |
+
|
191 |
+
def log_every(self, iterable, print_freq, header=None):
|
192 |
+
i = 0
|
193 |
+
if not header:
|
194 |
+
header = ''
|
195 |
+
start_time = time.time()
|
196 |
+
end = time.time()
|
197 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
198 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
199 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
200 |
+
if torch.cuda.is_available():
|
201 |
+
log_msg = self.delimiter.join([
|
202 |
+
header,
|
203 |
+
'[{0' + space_fmt + '}/{1}]',
|
204 |
+
'eta: {eta}',
|
205 |
+
'{meters}',
|
206 |
+
'time: {time}',
|
207 |
+
'data: {data}',
|
208 |
+
'max mem: {memory:.0f}'
|
209 |
+
])
|
210 |
+
else:
|
211 |
+
log_msg = self.delimiter.join([
|
212 |
+
header,
|
213 |
+
'[{0' + space_fmt + '}/{1}]',
|
214 |
+
'eta: {eta}',
|
215 |
+
'{meters}',
|
216 |
+
'time: {time}',
|
217 |
+
'data: {data}'
|
218 |
+
])
|
219 |
+
MB = 1024.0 * 1024.0
|
220 |
+
for obj in iterable:
|
221 |
+
data_time.update(time.time() - end)
|
222 |
+
yield obj
|
223 |
+
iter_time.update(time.time() - end)
|
224 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
225 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
226 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
227 |
+
if torch.cuda.is_available():
|
228 |
+
print(log_msg.format(
|
229 |
+
i, len(iterable), eta=eta_string,
|
230 |
+
meters=str(self),
|
231 |
+
time=str(iter_time), data=str(data_time),
|
232 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
233 |
+
else:
|
234 |
+
print(log_msg.format(
|
235 |
+
i, len(iterable), eta=eta_string,
|
236 |
+
meters=str(self),
|
237 |
+
time=str(iter_time), data=str(data_time)))
|
238 |
+
i += 1
|
239 |
+
end = time.time()
|
240 |
+
total_time = time.time() - start_time
|
241 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
242 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
243 |
+
header, total_time_str, total_time / len(iterable)))
|
244 |
+
|
245 |
+
|
246 |
+
def get_sha():
|
247 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
248 |
+
|
249 |
+
def _run(command):
|
250 |
+
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
251 |
+
sha = 'N/A'
|
252 |
+
diff = "clean"
|
253 |
+
branch = 'N/A'
|
254 |
+
try:
|
255 |
+
sha = _run(['git', 'rev-parse', 'HEAD'])
|
256 |
+
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
257 |
+
diff = _run(['git', 'diff-index', 'HEAD'])
|
258 |
+
diff = "has uncommited changes" if diff else "clean"
|
259 |
+
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
260 |
+
except Exception:
|
261 |
+
pass
|
262 |
+
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
263 |
+
return message
|
264 |
+
|
265 |
+
|
266 |
+
def collate_fn(batch):
|
267 |
+
batch = list(zip(*batch))
|
268 |
+
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
269 |
+
return tuple(batch)
|
270 |
+
|
271 |
+
|
272 |
+
def _max_by_axis(the_list):
|
273 |
+
# type: (List[List[int]]) -> List[int]
|
274 |
+
maxes = the_list[0]
|
275 |
+
for sublist in the_list[1:]:
|
276 |
+
for index, item in enumerate(sublist):
|
277 |
+
maxes[index] = max(maxes[index], item)
|
278 |
+
return maxes
|
279 |
+
|
280 |
+
|
281 |
+
class NestedTensor(object):
|
282 |
+
def __init__(self, tensors, mask: Optional[Tensor]):
|
283 |
+
self.tensors = tensors
|
284 |
+
self.mask = mask
|
285 |
+
|
286 |
+
def to(self, device):
|
287 |
+
# type: (Device) -> NestedTensor # noqa
|
288 |
+
cast_tensor = self.tensors.to(device)
|
289 |
+
mask = self.mask
|
290 |
+
if mask is not None:
|
291 |
+
assert mask is not None
|
292 |
+
cast_mask = mask.to(device)
|
293 |
+
else:
|
294 |
+
cast_mask = None
|
295 |
+
return NestedTensor(cast_tensor, cast_mask)
|
296 |
+
|
297 |
+
def decompose(self):
|
298 |
+
return self.tensors, self.mask
|
299 |
+
|
300 |
+
def __repr__(self):
|
301 |
+
return str(self.tensors)
|
302 |
+
|
303 |
+
|
304 |
+
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
305 |
+
# TODO make this more general
|
306 |
+
if tensor_list[0].ndim == 3:
|
307 |
+
if torchvision._is_tracing():
|
308 |
+
# nested_tensor_from_tensor_list() does not export well to ONNX
|
309 |
+
# call _onnx_nested_tensor_from_tensor_list() instead
|
310 |
+
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
311 |
+
|
312 |
+
# TODO make it support different-sized images
|
313 |
+
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
314 |
+
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
315 |
+
batch_shape = [len(tensor_list)] + max_size
|
316 |
+
b, c, h, w = batch_shape
|
317 |
+
dtype = tensor_list[0].dtype
|
318 |
+
device = tensor_list[0].device
|
319 |
+
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
320 |
+
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
321 |
+
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
322 |
+
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
323 |
+
m[: img.shape[1], :img.shape[2]] = False
|
324 |
+
else:
|
325 |
+
raise ValueError('not supported')
|
326 |
+
return NestedTensor(tensor, mask)
|
327 |
+
|
328 |
+
|
329 |
+
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
330 |
+
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
331 |
+
@torch.jit.unused
|
332 |
+
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
333 |
+
max_size = []
|
334 |
+
for i in range(tensor_list[0].dim()):
|
335 |
+
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
336 |
+
max_size.append(max_size_i)
|
337 |
+
max_size = tuple(max_size)
|
338 |
+
|
339 |
+
# work around for
|
340 |
+
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
341 |
+
# m[: img.shape[1], :img.shape[2]] = False
|
342 |
+
# which is not yet supported in onnx
|
343 |
+
padded_imgs = []
|
344 |
+
padded_masks = []
|
345 |
+
for img in tensor_list:
|
346 |
+
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
347 |
+
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
348 |
+
padded_imgs.append(padded_img)
|
349 |
+
|
350 |
+
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
351 |
+
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
352 |
+
padded_masks.append(padded_mask.to(torch.bool))
|
353 |
+
|
354 |
+
tensor = torch.stack(padded_imgs)
|
355 |
+
mask = torch.stack(padded_masks)
|
356 |
+
|
357 |
+
return NestedTensor(tensor, mask=mask)
|
358 |
+
|
359 |
+
|
360 |
+
def setup_for_distributed(is_master):
|
361 |
+
"""
|
362 |
+
This function disables printing when not in master process
|
363 |
+
"""
|
364 |
+
import builtins as __builtin__
|
365 |
+
builtin_print = __builtin__.print
|
366 |
+
|
367 |
+
def print(*args, **kwargs):
|
368 |
+
force = kwargs.pop('force', False)
|
369 |
+
if is_master or force:
|
370 |
+
builtin_print(*args, **kwargs)
|
371 |
+
|
372 |
+
__builtin__.print = print
|
373 |
+
|
374 |
+
|
375 |
+
def is_dist_avail_and_initialized():
|
376 |
+
if not dist.is_available():
|
377 |
+
return False
|
378 |
+
if not dist.is_initialized():
|
379 |
+
return False
|
380 |
+
return True
|
381 |
+
|
382 |
+
|
383 |
+
def get_world_size():
|
384 |
+
if not is_dist_avail_and_initialized():
|
385 |
+
return 1
|
386 |
+
return dist.get_world_size()
|
387 |
+
|
388 |
+
|
389 |
+
def get_rank():
|
390 |
+
if not is_dist_avail_and_initialized():
|
391 |
+
return 0
|
392 |
+
return dist.get_rank()
|
393 |
+
|
394 |
+
|
395 |
+
def is_main_process():
|
396 |
+
return get_rank() == 0
|
397 |
+
|
398 |
+
|
399 |
+
def save_on_master(*args, **kwargs):
|
400 |
+
if is_main_process():
|
401 |
+
torch.save(*args, **kwargs)
|
402 |
+
|
403 |
+
|
404 |
+
def init_distributed_mode(args):
|
405 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
406 |
+
args.rank = int(os.environ["RANK"])
|
407 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
408 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
409 |
+
elif 'SLURM_PROCID' in os.environ:
|
410 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
411 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
412 |
+
else:
|
413 |
+
print('Not using distributed mode')
|
414 |
+
args.distributed = False
|
415 |
+
return
|
416 |
+
|
417 |
+
args.distributed = True
|
418 |
+
|
419 |
+
torch.cuda.set_device(args.gpu)
|
420 |
+
args.dist_backend = 'nccl'
|
421 |
+
print('| distributed init (rank {}): {}'.format(
|
422 |
+
args.rank, args.dist_url), flush=True)
|
423 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
424 |
+
world_size=args.world_size, rank=args.rank)
|
425 |
+
torch.distributed.barrier()
|
426 |
+
setup_for_distributed(args.rank == 0)
|
427 |
+
|
428 |
+
|
429 |
+
@torch.no_grad()
|
430 |
+
def accuracy(output, target, topk=(1,)):
|
431 |
+
"""Computes the precision@k for the specified values of k"""
|
432 |
+
if target.numel() == 0:
|
433 |
+
return [torch.zeros([], device=output.device)]
|
434 |
+
maxk = max(topk)
|
435 |
+
batch_size = target.size(0)
|
436 |
+
|
437 |
+
_, pred = output.topk(maxk, 1, True, True)
|
438 |
+
pred = pred.t()
|
439 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
440 |
+
|
441 |
+
res = []
|
442 |
+
for k in topk:
|
443 |
+
correct_k = correct[:k].view(-1).float().sum(0)
|
444 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
445 |
+
return res
|
446 |
+
|
447 |
+
|
448 |
+
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
449 |
+
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
450 |
+
"""
|
451 |
+
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
452 |
+
This will eventually be supported natively by PyTorch, and this
|
453 |
+
class can go away.
|
454 |
+
"""
|
455 |
+
if float(torchvision.__version__[:3]) < 0.7:
|
456 |
+
if input.numel() > 0:
|
457 |
+
return torch.nn.functional.interpolate(
|
458 |
+
input, size, scale_factor, mode, align_corners
|
459 |
+
)
|
460 |
+
|
461 |
+
output_shape = _output_size(2, input, size, scale_factor)
|
462 |
+
output_shape = list(input.shape[:-2]) + list(output_shape)
|
463 |
+
return _new_empty_tensor(input, output_shape)
|
464 |
+
else:
|
465 |
+
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
util/models/BigGAN_layers.py
ADDED
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
''' Layers
|
2 |
+
This file contains various layers for the BigGAN models.
|
3 |
+
'''
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn import init
|
8 |
+
import torch.optim as optim
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.nn import Parameter as P
|
11 |
+
|
12 |
+
from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
|
13 |
+
|
14 |
+
# Projection of x onto y
|
15 |
+
def proj(x, y):
|
16 |
+
return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
|
17 |
+
|
18 |
+
|
19 |
+
# Orthogonalize x wrt list of vectors ys
|
20 |
+
def gram_schmidt(x, ys):
|
21 |
+
for y in ys:
|
22 |
+
x = x - proj(x, y)
|
23 |
+
return x
|
24 |
+
|
25 |
+
|
26 |
+
# Apply num_itrs steps of the power method to estimate top N singular values.
|
27 |
+
def power_iteration(W, u_, update=True, eps=1e-12):
|
28 |
+
# Lists holding singular vectors and values
|
29 |
+
us, vs, svs = [], [], []
|
30 |
+
for i, u in enumerate(u_):
|
31 |
+
# Run one step of the power iteration
|
32 |
+
with torch.no_grad():
|
33 |
+
v = torch.matmul(u, W)
|
34 |
+
# Run Gram-Schmidt to subtract components of all other singular vectors
|
35 |
+
v = F.normalize(gram_schmidt(v, vs), eps=eps)
|
36 |
+
# Add to the list
|
37 |
+
vs += [v]
|
38 |
+
# Update the other singular vector
|
39 |
+
u = torch.matmul(v, W.t())
|
40 |
+
# Run Gram-Schmidt to subtract components of all other singular vectors
|
41 |
+
u = F.normalize(gram_schmidt(u, us), eps=eps)
|
42 |
+
# Add to the list
|
43 |
+
us += [u]
|
44 |
+
if update:
|
45 |
+
u_[i][:] = u
|
46 |
+
# Compute this singular value and add it to the list
|
47 |
+
svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
|
48 |
+
# svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
|
49 |
+
return svs, us, vs
|
50 |
+
|
51 |
+
|
52 |
+
# Convenience passthrough function
|
53 |
+
class identity(nn.Module):
|
54 |
+
def forward(self, input):
|
55 |
+
return input
|
56 |
+
|
57 |
+
|
58 |
+
# Spectral normalization base class
|
59 |
+
class SN(object):
|
60 |
+
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
|
61 |
+
# Number of power iterations per step
|
62 |
+
self.num_itrs = num_itrs
|
63 |
+
# Number of singular values
|
64 |
+
self.num_svs = num_svs
|
65 |
+
# Transposed?
|
66 |
+
self.transpose = transpose
|
67 |
+
# Epsilon value for avoiding divide-by-0
|
68 |
+
self.eps = eps
|
69 |
+
# Register a singular vector for each sv
|
70 |
+
for i in range(self.num_svs):
|
71 |
+
self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
|
72 |
+
self.register_buffer('sv%d' % i, torch.ones(1))
|
73 |
+
|
74 |
+
# Singular vectors (u side)
|
75 |
+
@property
|
76 |
+
def u(self):
|
77 |
+
return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
|
78 |
+
|
79 |
+
# Singular values;
|
80 |
+
# note that these buffers are just for logging and are not used in training.
|
81 |
+
@property
|
82 |
+
def sv(self):
|
83 |
+
return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
|
84 |
+
|
85 |
+
# Compute the spectrally-normalized weight
|
86 |
+
def W_(self):
|
87 |
+
W_mat = self.weight.view(self.weight.size(0), -1)
|
88 |
+
if self.transpose:
|
89 |
+
W_mat = W_mat.t()
|
90 |
+
# Apply num_itrs power iterations
|
91 |
+
for _ in range(self.num_itrs):
|
92 |
+
svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
|
93 |
+
# Update the svs
|
94 |
+
if self.training:
|
95 |
+
with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
|
96 |
+
for i, sv in enumerate(svs):
|
97 |
+
self.sv[i][:] = sv
|
98 |
+
return self.weight / svs[0]
|
99 |
+
|
100 |
+
|
101 |
+
# 2D Conv layer with spectral norm
|
102 |
+
class SNConv2d(nn.Conv2d, SN):
|
103 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
104 |
+
padding=0, dilation=1, groups=1, bias=True,
|
105 |
+
num_svs=1, num_itrs=1, eps=1e-12):
|
106 |
+
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
|
107 |
+
padding, dilation, groups, bias)
|
108 |
+
SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
return F.conv2d(x, self.W_(), self.bias, self.stride,
|
112 |
+
self.padding, self.dilation, self.groups)
|
113 |
+
|
114 |
+
|
115 |
+
# Linear layer with spectral norm
|
116 |
+
class SNLinear(nn.Linear, SN):
|
117 |
+
def __init__(self, in_features, out_features, bias=True,
|
118 |
+
num_svs=1, num_itrs=1, eps=1e-12):
|
119 |
+
nn.Linear.__init__(self, in_features, out_features, bias)
|
120 |
+
SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
return F.linear(x, self.W_(), self.bias)
|
124 |
+
|
125 |
+
|
126 |
+
# Embedding layer with spectral norm
|
127 |
+
# We use num_embeddings as the dim instead of embedding_dim here
|
128 |
+
# for convenience sake
|
129 |
+
class SNEmbedding(nn.Embedding, SN):
|
130 |
+
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
|
131 |
+
max_norm=None, norm_type=2, scale_grad_by_freq=False,
|
132 |
+
sparse=False, _weight=None,
|
133 |
+
num_svs=1, num_itrs=1, eps=1e-12):
|
134 |
+
nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
|
135 |
+
max_norm, norm_type, scale_grad_by_freq,
|
136 |
+
sparse, _weight)
|
137 |
+
SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
return F.embedding(x, self.W_())
|
141 |
+
|
142 |
+
|
143 |
+
# A non-local block as used in SA-GAN
|
144 |
+
# Note that the implementation as described in the paper is largely incorrect;
|
145 |
+
# refer to the released code for the actual implementation.
|
146 |
+
class Attention(nn.Module):
|
147 |
+
def __init__(self, ch, which_conv=SNConv2d, name='attention'):
|
148 |
+
super(Attention, self).__init__()
|
149 |
+
# Channel multiplier
|
150 |
+
self.ch = ch
|
151 |
+
self.which_conv = which_conv
|
152 |
+
self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
|
153 |
+
self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
|
154 |
+
self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
|
155 |
+
self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
|
156 |
+
# Learnable gain parameter
|
157 |
+
self.gamma = P(torch.tensor(0.), requires_grad=True)
|
158 |
+
|
159 |
+
def forward(self, x, y=None):
|
160 |
+
# Apply convs
|
161 |
+
theta = self.theta(x)
|
162 |
+
phi = F.max_pool2d(self.phi(x), [2, 2])
|
163 |
+
g = F.max_pool2d(self.g(x), [2, 2])
|
164 |
+
# Perform reshapes
|
165 |
+
theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
|
166 |
+
try:
|
167 |
+
phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
|
168 |
+
except:
|
169 |
+
print(phi.shape)
|
170 |
+
g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
|
171 |
+
# Matmul and softmax to get attention maps
|
172 |
+
beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
|
173 |
+
# Attention map times g path
|
174 |
+
o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
|
175 |
+
return self.gamma * o + x
|
176 |
+
|
177 |
+
|
178 |
+
# Fused batchnorm op
|
179 |
+
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
|
180 |
+
# Apply scale and shift--if gain and bias are provided, fuse them here
|
181 |
+
# Prepare scale
|
182 |
+
scale = torch.rsqrt(var + eps)
|
183 |
+
# If a gain is provided, use it
|
184 |
+
if gain is not None:
|
185 |
+
scale = scale * gain
|
186 |
+
# Prepare shift
|
187 |
+
shift = mean * scale
|
188 |
+
# If bias is provided, use it
|
189 |
+
if bias is not None:
|
190 |
+
shift = shift - bias
|
191 |
+
return x * scale - shift
|
192 |
+
# return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
|
193 |
+
|
194 |
+
|
195 |
+
# Manual BN
|
196 |
+
# Calculate means and variances using mean-of-squares minus mean-squared
|
197 |
+
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
|
198 |
+
# Cast x to float32 if necessary
|
199 |
+
float_x = x.float()
|
200 |
+
# Calculate expected value of x (m) and expected value of x**2 (m2)
|
201 |
+
# Mean of x
|
202 |
+
m = torch.mean(float_x, [0, 2, 3], keepdim=True)
|
203 |
+
# Mean of x squared
|
204 |
+
m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
|
205 |
+
# Calculate variance as mean of squared minus mean squared.
|
206 |
+
var = (m2 - m ** 2)
|
207 |
+
# Cast back to float 16 if necessary
|
208 |
+
var = var.type(x.type())
|
209 |
+
m = m.type(x.type())
|
210 |
+
# Return mean and variance for updating stored mean/var if requested
|
211 |
+
if return_mean_var:
|
212 |
+
return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
|
213 |
+
else:
|
214 |
+
return fused_bn(x, m, var, gain, bias, eps)
|
215 |
+
|
216 |
+
|
217 |
+
# My batchnorm, supports standing stats
|
218 |
+
class myBN(nn.Module):
|
219 |
+
def __init__(self, num_channels, eps=1e-5, momentum=0.1):
|
220 |
+
super(myBN, self).__init__()
|
221 |
+
# momentum for updating running stats
|
222 |
+
self.momentum = momentum
|
223 |
+
# epsilon to avoid dividing by 0
|
224 |
+
self.eps = eps
|
225 |
+
# Momentum
|
226 |
+
self.momentum = momentum
|
227 |
+
# Register buffers
|
228 |
+
self.register_buffer('stored_mean', torch.zeros(num_channels))
|
229 |
+
self.register_buffer('stored_var', torch.ones(num_channels))
|
230 |
+
self.register_buffer('accumulation_counter', torch.zeros(1))
|
231 |
+
# Accumulate running means and vars
|
232 |
+
self.accumulate_standing = False
|
233 |
+
|
234 |
+
# reset standing stats
|
235 |
+
def reset_stats(self):
|
236 |
+
self.stored_mean[:] = 0
|
237 |
+
self.stored_var[:] = 0
|
238 |
+
self.accumulation_counter[:] = 0
|
239 |
+
|
240 |
+
def forward(self, x, gain, bias):
|
241 |
+
if self.training:
|
242 |
+
out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
|
243 |
+
# If accumulating standing stats, increment them
|
244 |
+
if self.accumulate_standing:
|
245 |
+
self.stored_mean[:] = self.stored_mean + mean.data
|
246 |
+
self.stored_var[:] = self.stored_var + var.data
|
247 |
+
self.accumulation_counter += 1.0
|
248 |
+
# If not accumulating standing stats, take running averages
|
249 |
+
else:
|
250 |
+
self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
|
251 |
+
self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
|
252 |
+
return out
|
253 |
+
# If not in training mode, use the stored statistics
|
254 |
+
else:
|
255 |
+
mean = self.stored_mean.view(1, -1, 1, 1)
|
256 |
+
var = self.stored_var.view(1, -1, 1, 1)
|
257 |
+
# If using standing stats, divide them by the accumulation counter
|
258 |
+
if self.accumulate_standing:
|
259 |
+
mean = mean / self.accumulation_counter
|
260 |
+
var = var / self.accumulation_counter
|
261 |
+
return fused_bn(x, mean, var, gain, bias, self.eps)
|
262 |
+
|
263 |
+
|
264 |
+
# Simple function to handle groupnorm norm stylization
|
265 |
+
def groupnorm(x, norm_style):
|
266 |
+
# If number of channels specified in norm_style:
|
267 |
+
if 'ch' in norm_style:
|
268 |
+
ch = int(norm_style.split('_')[-1])
|
269 |
+
groups = max(int(x.shape[1]) // ch, 1)
|
270 |
+
# If number of groups specified in norm style
|
271 |
+
elif 'grp' in norm_style:
|
272 |
+
groups = int(norm_style.split('_')[-1])
|
273 |
+
# If neither, default to groups = 16
|
274 |
+
else:
|
275 |
+
groups = 16
|
276 |
+
return F.group_norm(x, groups)
|
277 |
+
|
278 |
+
|
279 |
+
# Class-conditional bn
|
280 |
+
# output size is the number of channels, input size is for the linear layers
|
281 |
+
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
|
282 |
+
# Suggestions welcome! (By which I mean, refactor this and make a pull request
|
283 |
+
# if you want to make this more readable/usable).
|
284 |
+
class ccbn(nn.Module):
|
285 |
+
def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
|
286 |
+
cross_replica=False, mybn=False, norm_style='bn', ):
|
287 |
+
super(ccbn, self).__init__()
|
288 |
+
self.output_size, self.input_size = output_size, input_size
|
289 |
+
# Prepare gain and bias layers
|
290 |
+
self.gain = which_linear(input_size, output_size)
|
291 |
+
self.bias = which_linear(input_size, output_size)
|
292 |
+
# epsilon to avoid dividing by 0
|
293 |
+
self.eps = eps
|
294 |
+
# Momentum
|
295 |
+
self.momentum = momentum
|
296 |
+
# Use cross-replica batchnorm?
|
297 |
+
self.cross_replica = cross_replica
|
298 |
+
# Use my batchnorm?
|
299 |
+
self.mybn = mybn
|
300 |
+
# Norm style?
|
301 |
+
self.norm_style = norm_style
|
302 |
+
|
303 |
+
if self.cross_replica:
|
304 |
+
self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
|
305 |
+
elif self.mybn:
|
306 |
+
self.bn = myBN(output_size, self.eps, self.momentum)
|
307 |
+
elif self.norm_style in ['bn', 'in']:
|
308 |
+
self.register_buffer('stored_mean', torch.zeros(output_size))
|
309 |
+
self.register_buffer('stored_var', torch.ones(output_size))
|
310 |
+
|
311 |
+
def forward(self, x, y):
|
312 |
+
# Calculate class-conditional gains and biases
|
313 |
+
gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
|
314 |
+
bias = self.bias(y).view(y.size(0), -1, 1, 1)
|
315 |
+
# If using my batchnorm
|
316 |
+
if self.mybn or self.cross_replica:
|
317 |
+
return self.bn(x, gain=gain, bias=bias)
|
318 |
+
# else:
|
319 |
+
else:
|
320 |
+
if self.norm_style == 'bn':
|
321 |
+
out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
|
322 |
+
self.training, 0.1, self.eps)
|
323 |
+
elif self.norm_style == 'in':
|
324 |
+
out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
|
325 |
+
self.training, 0.1, self.eps)
|
326 |
+
elif self.norm_style == 'gn':
|
327 |
+
out = groupnorm(x, self.normstyle)
|
328 |
+
elif self.norm_style == 'nonorm':
|
329 |
+
out = x
|
330 |
+
return out * gain + bias
|
331 |
+
|
332 |
+
def extra_repr(self):
|
333 |
+
s = 'out: {output_size}, in: {input_size},'
|
334 |
+
s += ' cross_replica={cross_replica}'
|
335 |
+
return s.format(**self.__dict__)
|
336 |
+
|
337 |
+
|
338 |
+
# Normal, non-class-conditional BN
|
339 |
+
class bn(nn.Module):
|
340 |
+
def __init__(self, output_size, eps=1e-5, momentum=0.1,
|
341 |
+
cross_replica=False, mybn=False):
|
342 |
+
super(bn, self).__init__()
|
343 |
+
self.output_size = output_size
|
344 |
+
# Prepare gain and bias layers
|
345 |
+
self.gain = P(torch.ones(output_size), requires_grad=True)
|
346 |
+
self.bias = P(torch.zeros(output_size), requires_grad=True)
|
347 |
+
# epsilon to avoid dividing by 0
|
348 |
+
self.eps = eps
|
349 |
+
# Momentum
|
350 |
+
self.momentum = momentum
|
351 |
+
# Use cross-replica batchnorm?
|
352 |
+
self.cross_replica = cross_replica
|
353 |
+
# Use my batchnorm?
|
354 |
+
self.mybn = mybn
|
355 |
+
|
356 |
+
if self.cross_replica:
|
357 |
+
self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
|
358 |
+
elif mybn:
|
359 |
+
self.bn = myBN(output_size, self.eps, self.momentum)
|
360 |
+
# Register buffers if neither of the above
|
361 |
+
else:
|
362 |
+
self.register_buffer('stored_mean', torch.zeros(output_size))
|
363 |
+
self.register_buffer('stored_var', torch.ones(output_size))
|
364 |
+
|
365 |
+
def forward(self, x, y=None):
|
366 |
+
if self.cross_replica or self.mybn:
|
367 |
+
gain = self.gain.view(1, -1, 1, 1)
|
368 |
+
bias = self.bias.view(1, -1, 1, 1)
|
369 |
+
return self.bn(x, gain=gain, bias=bias)
|
370 |
+
else:
|
371 |
+
return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
|
372 |
+
self.bias, self.training, self.momentum, self.eps)
|
373 |
+
|
374 |
+
|
375 |
+
# Generator blocks
|
376 |
+
# Note that this class assumes the kernel size and padding (and any other
|
377 |
+
# settings) have been selected in the main generator module and passed in
|
378 |
+
# through the which_conv arg. Similar rules apply with which_bn (the input
|
379 |
+
# size [which is actually the number of channels of the conditional info] must
|
380 |
+
# be preselected)
|
381 |
+
class GBlock(nn.Module):
|
382 |
+
def __init__(self, in_channels, out_channels,
|
383 |
+
which_conv1=nn.Conv2d, which_conv2=nn.Conv2d, which_bn=bn, activation=None,
|
384 |
+
upsample=None):
|
385 |
+
super(GBlock, self).__init__()
|
386 |
+
|
387 |
+
self.in_channels, self.out_channels = in_channels, out_channels
|
388 |
+
self.which_conv1, self.which_conv2, self.which_bn = which_conv1, which_conv2, which_bn
|
389 |
+
self.activation = activation
|
390 |
+
self.upsample = upsample
|
391 |
+
# Conv layers
|
392 |
+
self.conv1 = self.which_conv1(self.in_channels, self.out_channels)
|
393 |
+
self.conv2 = self.which_conv2(self.out_channels, self.out_channels)
|
394 |
+
self.learnable_sc = in_channels != out_channels or upsample
|
395 |
+
if self.learnable_sc:
|
396 |
+
self.conv_sc = self.which_conv1(in_channels, out_channels,
|
397 |
+
kernel_size=1, padding=0)
|
398 |
+
# Batchnorm layers
|
399 |
+
self.bn1 = self.which_bn(in_channels)
|
400 |
+
self.bn2 = self.which_bn(out_channels)
|
401 |
+
# upsample layers
|
402 |
+
self.upsample = upsample
|
403 |
+
|
404 |
+
def forward(self, x, y):
|
405 |
+
h = self.activation(self.bn1(x, y))
|
406 |
+
# h = self.activation(x)
|
407 |
+
# h=x
|
408 |
+
if self.upsample:
|
409 |
+
h = self.upsample(h)
|
410 |
+
x = self.upsample(x)
|
411 |
+
h = self.conv1(h)
|
412 |
+
h = self.activation(self.bn2(h, y))
|
413 |
+
# h = self.activation(h)
|
414 |
+
h = self.conv2(h)
|
415 |
+
if self.learnable_sc:
|
416 |
+
x = self.conv_sc(x)
|
417 |
+
return h + x
|
418 |
+
|
419 |
+
|
420 |
+
# Residual block for the discriminator
|
421 |
+
class DBlock(nn.Module):
|
422 |
+
def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
|
423 |
+
preactivation=False, activation=None, downsample=None, ):
|
424 |
+
super(DBlock, self).__init__()
|
425 |
+
self.in_channels, self.out_channels = in_channels, out_channels
|
426 |
+
# If using wide D (as in SA-GAN and BigGAN), change the channel pattern
|
427 |
+
self.hidden_channels = self.out_channels if wide else self.in_channels
|
428 |
+
self.which_conv = which_conv
|
429 |
+
self.preactivation = preactivation
|
430 |
+
self.activation = activation
|
431 |
+
self.downsample = downsample
|
432 |
+
|
433 |
+
# Conv layers
|
434 |
+
self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
|
435 |
+
self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
|
436 |
+
self.learnable_sc = True if (in_channels != out_channels) or downsample else False
|
437 |
+
if self.learnable_sc:
|
438 |
+
self.conv_sc = self.which_conv(in_channels, out_channels,
|
439 |
+
kernel_size=1, padding=0)
|
440 |
+
|
441 |
+
def shortcut(self, x):
|
442 |
+
if self.preactivation:
|
443 |
+
if self.learnable_sc:
|
444 |
+
x = self.conv_sc(x)
|
445 |
+
if self.downsample:
|
446 |
+
x = self.downsample(x)
|
447 |
+
else:
|
448 |
+
if self.downsample:
|
449 |
+
x = self.downsample(x)
|
450 |
+
if self.learnable_sc:
|
451 |
+
x = self.conv_sc(x)
|
452 |
+
return x
|
453 |
+
|
454 |
+
def forward(self, x):
|
455 |
+
if self.preactivation:
|
456 |
+
# h = self.activation(x) # NOT TODAY SATAN
|
457 |
+
# Andy's note: This line *must* be an out-of-place ReLU or it
|
458 |
+
# will negatively affect the shortcut connection.
|
459 |
+
h = F.relu(x)
|
460 |
+
else:
|
461 |
+
h = x
|
462 |
+
h = self.conv1(h)
|
463 |
+
h = self.conv2(self.activation(h))
|
464 |
+
if self.downsample:
|
465 |
+
h = self.downsample(h)
|
466 |
+
|
467 |
+
return h + self.shortcut(x)
|
468 |
+
|
469 |
+
# dogball
|
util/models/BigGAN_networks.py
ADDED
@@ -0,0 +1,841 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
2 |
+
# SPDX-License-Identifier: MIT
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import math
|
6 |
+
import functools
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.nn import init
|
11 |
+
import torch.optim as optim
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.nn import Parameter as P
|
14 |
+
from .transformer import Transformer
|
15 |
+
from . import BigGAN_layers as layers
|
16 |
+
from .sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
|
17 |
+
from util.util import to_device, load_network
|
18 |
+
from .networks import init_weights
|
19 |
+
from params import *
|
20 |
+
# Attention is passed in in the format '32_64' to mean applying an attention
|
21 |
+
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
|
22 |
+
|
23 |
+
from models.blocks import LinearBlock, Conv2dBlock, ResBlocks, ActFirstResBlock
|
24 |
+
|
25 |
+
class Decoder(nn.Module):
|
26 |
+
def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'):
|
27 |
+
super(Decoder, self).__init__()
|
28 |
+
|
29 |
+
self.model = []
|
30 |
+
self.model += [ResBlocks(n_res, dim, res_norm,
|
31 |
+
activ, pad_type=pad_type)]
|
32 |
+
for i in range(ups):
|
33 |
+
self.model += [nn.Upsample(scale_factor=2),
|
34 |
+
Conv2dBlock(dim, dim // 2, 5, 1, 2,
|
35 |
+
norm='in',
|
36 |
+
activation=activ,
|
37 |
+
pad_type=pad_type)]
|
38 |
+
dim //= 2
|
39 |
+
self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3,
|
40 |
+
norm='none',
|
41 |
+
activation='tanh',
|
42 |
+
pad_type=pad_type)]
|
43 |
+
self.model = nn.Sequential(*self.model)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
y = self.model(x)
|
47 |
+
|
48 |
+
return y
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
|
53 |
+
arch = {}
|
54 |
+
arch[512] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
|
55 |
+
'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
|
56 |
+
'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],
|
57 |
+
'resolution': [8, 16, 32, 64, 128, 256, 512],
|
58 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
59 |
+
for i in range(3, 10)}}
|
60 |
+
arch[256] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]],
|
61 |
+
'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]],
|
62 |
+
'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],
|
63 |
+
'resolution': [8, 16, 32, 64, 128, 256],
|
64 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
65 |
+
for i in range(3, 9)}}
|
66 |
+
arch[128] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
|
67 |
+
'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
|
68 |
+
'upsample': [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],
|
69 |
+
'resolution': [8, 16, 32, 64, 128],
|
70 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
71 |
+
for i in range(3, 8)}}
|
72 |
+
arch[64] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
|
73 |
+
'out_channels': [ch * item for item in [16, 8, 4, 2]],
|
74 |
+
'upsample': [(2, 2), (2, 2), (2, 2), (2, 2)],
|
75 |
+
'resolution': [8, 16, 32, 64],
|
76 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
77 |
+
for i in range(3, 7)}}
|
78 |
+
|
79 |
+
arch[63] = {'in_channels': [ch * item for item in [16, 16, 8, 4]],
|
80 |
+
'out_channels': [ch * item for item in [16, 8, 4, 2]],
|
81 |
+
'upsample': [(2, 2), (2, 2), (2, 2), (2,1)],
|
82 |
+
'resolution': [8, 16, 32, 64],
|
83 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
84 |
+
for i in range(3, 7)},
|
85 |
+
'kernel1': [3, 3, 3, 3],
|
86 |
+
'kernel2': [3, 3, 1, 1]
|
87 |
+
}
|
88 |
+
|
89 |
+
arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
|
90 |
+
'out_channels': [ch * item for item in [4, 4, 4]],
|
91 |
+
'upsample': [(2, 2), (2, 2), (2, 2)],
|
92 |
+
'resolution': [8, 16, 32],
|
93 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
94 |
+
for i in range(3, 6)}}
|
95 |
+
|
96 |
+
arch[32] = {'in_channels': [ch * item for item in [4, 4, 4]],
|
97 |
+
'out_channels': [ch * item for item in [4, 4, 4]],
|
98 |
+
'upsample': [(2, 2), (2, 2), (2, 2)],
|
99 |
+
'resolution': [8, 16, 32],
|
100 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
101 |
+
for i in range(3, 6)},
|
102 |
+
'kernel1': [3, 3, 3],
|
103 |
+
'kernel2': [3, 3, 1]
|
104 |
+
}
|
105 |
+
|
106 |
+
arch[129] = {'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
|
107 |
+
'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
|
108 |
+
'upsample': [(2,2), (2,2), (2,2), (2,2), (2,2), (1,2), (1,2)],
|
109 |
+
'resolution': [8, 16, 32, 64, 128, 256, 512],
|
110 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
111 |
+
for i in range(3, 10)}}
|
112 |
+
|
113 |
+
arch[33] = {'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
|
114 |
+
'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
|
115 |
+
'upsample': [(2,2), (2,2), (2,2), (1,2), (1,2)],
|
116 |
+
'resolution': [8, 16, 32, 64, 128],
|
117 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
118 |
+
for i in range(3, 8)}}
|
119 |
+
|
120 |
+
arch[31] = {'in_channels': [ch * item for item in [16, 16, 4, 2]],
|
121 |
+
'out_channels': [ch * item for item in [16, 4, 2, 1]],
|
122 |
+
'upsample': [(2,2), (2,2), (2,2), (1,2)],
|
123 |
+
'resolution': [8, 16, 32, 64],
|
124 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
125 |
+
for i in range(3, 7)},
|
126 |
+
'kernel1':[3, 3, 3, 3],
|
127 |
+
'kernel2': [3, 1, 1, 1]}
|
128 |
+
|
129 |
+
arch[16] = {'in_channels': [ch * item for item in [8, 4, 2]],
|
130 |
+
'out_channels': [ch * item for item in [4, 2, 1]],
|
131 |
+
'upsample': [(2,2), (2,2), (2,1)],
|
132 |
+
'resolution': [8, 16, 16],
|
133 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
134 |
+
for i in range(3, 6)},
|
135 |
+
'kernel1':[3, 3, 3],
|
136 |
+
'kernel2': [3, 3, 1]}
|
137 |
+
|
138 |
+
arch[17] = {'in_channels': [ch * item for item in [8, 4, 2]],
|
139 |
+
'out_channels': [ch * item for item in [4, 2, 1]],
|
140 |
+
'upsample': [(2,2), (2,2), (2,1)],
|
141 |
+
'resolution': [8, 16, 16],
|
142 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
143 |
+
for i in range(3, 6)},
|
144 |
+
'kernel1':[3, 3, 3],
|
145 |
+
'kernel2': [3, 3, 1]}
|
146 |
+
|
147 |
+
arch[20] = {'in_channels': [ch * item for item in [8, 4, 2]],
|
148 |
+
'out_channels': [ch * item for item in [4, 2, 1]],
|
149 |
+
'upsample': [(2,2), (2,2), (2,1)],
|
150 |
+
'resolution': [8, 16, 16],
|
151 |
+
'attention': {2 ** i: (2 ** i in [int(item) for item in attention.split('_')])
|
152 |
+
for i in range(3, 6)},
|
153 |
+
'kernel1':[3, 3, 3],
|
154 |
+
'kernel2': [3, 1, 1]}
|
155 |
+
|
156 |
+
return arch
|
157 |
+
|
158 |
+
|
159 |
+
class Generator(nn.Module):
|
160 |
+
def __init__(self, G_ch=64, dim_z=128, bottom_width=4, bottom_height=4,resolution=128,
|
161 |
+
G_kernel_size=3, G_attn='64', n_classes=1000,
|
162 |
+
num_G_SVs=1, num_G_SV_itrs=1,
|
163 |
+
G_shared=True, shared_dim=0, no_hier=False,
|
164 |
+
cross_replica=False, mybn=False,
|
165 |
+
G_activation=nn.ReLU(inplace=False),
|
166 |
+
BN_eps=1e-5, SN_eps=1e-12, G_fp16=False,
|
167 |
+
G_init='ortho', skip_init=False,
|
168 |
+
G_param='SN', norm_style='bn',gpu_ids=[], bn_linear='embed', input_nc=3,
|
169 |
+
one_hot=False, first_layer=False, one_hot_k=1,
|
170 |
+
**kwargs):
|
171 |
+
super(Generator, self).__init__()
|
172 |
+
self.name = 'G'
|
173 |
+
# Use class only in first layer
|
174 |
+
self.first_layer = first_layer
|
175 |
+
# gpu-ids
|
176 |
+
self.gpu_ids = gpu_ids
|
177 |
+
# Use one hot vector representation for input class
|
178 |
+
self.one_hot = one_hot
|
179 |
+
# Use one hot k vector representation for input class if k is larger than 0. If it's 0, simly use the class number and not a k-hot encoding.
|
180 |
+
self.one_hot_k = one_hot_k
|
181 |
+
# Channel width mulitplier
|
182 |
+
self.ch = G_ch
|
183 |
+
# Dimensionality of the latent space
|
184 |
+
self.dim_z = dim_z
|
185 |
+
# The initial width dimensions
|
186 |
+
self.bottom_width = bottom_width
|
187 |
+
# The initial height dimension
|
188 |
+
self.bottom_height = bottom_height
|
189 |
+
# Resolution of the output
|
190 |
+
self.resolution = resolution
|
191 |
+
# Kernel size?
|
192 |
+
self.kernel_size = G_kernel_size
|
193 |
+
# Attention?
|
194 |
+
self.attention = G_attn
|
195 |
+
# number of classes, for use in categorical conditional generation
|
196 |
+
self.n_classes = n_classes
|
197 |
+
# Use shared embeddings?
|
198 |
+
self.G_shared = G_shared
|
199 |
+
# Dimensionality of the shared embedding? Unused if not using G_shared
|
200 |
+
self.shared_dim = shared_dim if shared_dim > 0 else dim_z
|
201 |
+
# Hierarchical latent space?
|
202 |
+
self.hier = not no_hier
|
203 |
+
# Cross replica batchnorm?
|
204 |
+
self.cross_replica = cross_replica
|
205 |
+
# Use my batchnorm?
|
206 |
+
self.mybn = mybn
|
207 |
+
# nonlinearity for residual blocks
|
208 |
+
self.activation = G_activation
|
209 |
+
# Initialization style
|
210 |
+
self.init = G_init
|
211 |
+
# Parameterization style
|
212 |
+
self.G_param = G_param
|
213 |
+
# Normalization style
|
214 |
+
self.norm_style = norm_style
|
215 |
+
# Epsilon for BatchNorm?
|
216 |
+
self.BN_eps = BN_eps
|
217 |
+
# Epsilon for Spectral Norm?
|
218 |
+
self.SN_eps = SN_eps
|
219 |
+
# fp16?
|
220 |
+
self.fp16 = G_fp16
|
221 |
+
# Architecture dict
|
222 |
+
self.arch = G_arch(self.ch, self.attention)[resolution]
|
223 |
+
self.bn_linear = bn_linear
|
224 |
+
|
225 |
+
#self.transformer = Transformer(d_model = 2560)
|
226 |
+
#self.input_proj = nn.Conv2d(512, 2560, kernel_size=1)
|
227 |
+
self.linear_q = nn.Linear(512,2048*2)
|
228 |
+
|
229 |
+
self.DETR = build()
|
230 |
+
self.DEC = Decoder(res_norm = 'in')
|
231 |
+
# If using hierarchical latents, adjust z
|
232 |
+
if self.hier:
|
233 |
+
# Number of places z slots into
|
234 |
+
self.num_slots = len(self.arch['in_channels']) + 1
|
235 |
+
self.z_chunk_size = (self.dim_z // self.num_slots)
|
236 |
+
# Recalculate latent dimensionality for even splitting into chunks
|
237 |
+
self.dim_z = self.z_chunk_size * self.num_slots
|
238 |
+
else:
|
239 |
+
self.num_slots = 1
|
240 |
+
self.z_chunk_size = 0
|
241 |
+
|
242 |
+
# Which convs, batchnorms, and linear layers to use
|
243 |
+
if self.G_param == 'SN':
|
244 |
+
self.which_conv = functools.partial(layers.SNConv2d,
|
245 |
+
kernel_size=3, padding=1,
|
246 |
+
num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
|
247 |
+
eps=self.SN_eps)
|
248 |
+
self.which_linear = functools.partial(layers.SNLinear,
|
249 |
+
num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
|
250 |
+
eps=self.SN_eps)
|
251 |
+
else:
|
252 |
+
self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
|
253 |
+
self.which_linear = nn.Linear
|
254 |
+
|
255 |
+
# We use a non-spectral-normed embedding here regardless;
|
256 |
+
# For some reason applying SN to G's embedding seems to randomly cripple G
|
257 |
+
if one_hot:
|
258 |
+
self.which_embedding = functools.partial(layers.SNLinear,
|
259 |
+
num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
|
260 |
+
eps=self.SN_eps)
|
261 |
+
else:
|
262 |
+
self.which_embedding = nn.Embedding
|
263 |
+
|
264 |
+
bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
|
265 |
+
else self.which_embedding)
|
266 |
+
if self.bn_linear=='SN':
|
267 |
+
bn_linear = functools.partial(self.which_linear, bias=False)
|
268 |
+
if self.G_shared:
|
269 |
+
input_size = self.shared_dim + self.z_chunk_size
|
270 |
+
elif self.hier:
|
271 |
+
if self.first_layer:
|
272 |
+
input_size = self.z_chunk_size
|
273 |
+
else:
|
274 |
+
input_size = self.n_classes + self.z_chunk_size
|
275 |
+
self.which_bn = functools.partial(layers.ccbn,
|
276 |
+
which_linear=bn_linear,
|
277 |
+
cross_replica=self.cross_replica,
|
278 |
+
mybn=self.mybn,
|
279 |
+
input_size=input_size,
|
280 |
+
norm_style=self.norm_style,
|
281 |
+
eps=self.BN_eps)
|
282 |
+
else:
|
283 |
+
input_size = self.n_classes
|
284 |
+
self.which_bn = functools.partial(layers.bn,
|
285 |
+
cross_replica=self.cross_replica,
|
286 |
+
mybn=self.mybn,
|
287 |
+
eps=self.BN_eps)
|
288 |
+
|
289 |
+
|
290 |
+
|
291 |
+
|
292 |
+
# Prepare model
|
293 |
+
# If not using shared embeddings, self.shared is just a passthrough
|
294 |
+
self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
|
295 |
+
else layers.identity())
|
296 |
+
# First linear layer
|
297 |
+
# The parameters for the first linear layer depend on the different input variations.
|
298 |
+
if self.first_layer:
|
299 |
+
if self.one_hot:
|
300 |
+
self.linear = self.which_linear(self.dim_z // self.num_slots + self.n_classes,
|
301 |
+
self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
|
302 |
+
else:
|
303 |
+
self.linear = self.which_linear(self.dim_z // self.num_slots + 1,
|
304 |
+
self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
|
305 |
+
if self.one_hot_k==1:
|
306 |
+
self.linear = self.which_linear((self.dim_z // self.num_slots) * self.n_classes,
|
307 |
+
self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
|
308 |
+
if self.one_hot_k>1:
|
309 |
+
self.linear = self.which_linear(self.dim_z // self.num_slots + self.n_classes*self.one_hot_k,
|
310 |
+
self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
|
311 |
+
|
312 |
+
|
313 |
+
else:
|
314 |
+
self.linear = self.which_linear(self.dim_z // self.num_slots,
|
315 |
+
self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
|
316 |
+
# self.blocks is a doubly-nested list of modules, the outer loop intended
|
317 |
+
# to be over blocks at a given resolution (resblocks and/or self-attention)
|
318 |
+
# while the inner loop is over a given block
|
319 |
+
self.blocks = []
|
320 |
+
for index in range(len(self.arch['out_channels'])):
|
321 |
+
if 'kernel1' in self.arch.keys():
|
322 |
+
padd1 = 1 if self.arch['kernel1'][index]>1 else 0
|
323 |
+
padd2 = 1 if self.arch['kernel2'][index]>1 else 0
|
324 |
+
conv1 = functools.partial(layers.SNConv2d,
|
325 |
+
kernel_size=self.arch['kernel1'][index], padding=padd1,
|
326 |
+
num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
|
327 |
+
eps=self.SN_eps)
|
328 |
+
conv2 = functools.partial(layers.SNConv2d,
|
329 |
+
kernel_size=self.arch['kernel2'][index], padding=padd2,
|
330 |
+
num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
|
331 |
+
eps=self.SN_eps)
|
332 |
+
self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
|
333 |
+
out_channels=self.arch['out_channels'][index],
|
334 |
+
which_conv1=conv1,
|
335 |
+
which_conv2=conv2,
|
336 |
+
which_bn=self.which_bn,
|
337 |
+
activation=self.activation,
|
338 |
+
upsample=(functools.partial(F.interpolate,
|
339 |
+
scale_factor=self.arch['upsample'][index])
|
340 |
+
if index < len(self.arch['upsample']) else None))]]
|
341 |
+
else:
|
342 |
+
self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
|
343 |
+
out_channels=self.arch['out_channels'][index],
|
344 |
+
which_conv1=self.which_conv,
|
345 |
+
which_conv2=self.which_conv,
|
346 |
+
which_bn=self.which_bn,
|
347 |
+
activation=self.activation,
|
348 |
+
upsample=(functools.partial(F.interpolate, scale_factor=self.arch['upsample'][index])
|
349 |
+
if index < len(self.arch['upsample']) else None))]]
|
350 |
+
|
351 |
+
# If attention on this block, attach it to the end
|
352 |
+
if self.arch['attention'][self.arch['resolution'][index]]:
|
353 |
+
print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
|
354 |
+
self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]
|
355 |
+
|
356 |
+
# Turn self.blocks into a ModuleList so that it's all properly registered.
|
357 |
+
self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
|
358 |
+
|
359 |
+
# output layer: batchnorm-relu-conv.
|
360 |
+
# Consider using a non-spectral conv here
|
361 |
+
self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
|
362 |
+
cross_replica=self.cross_replica,
|
363 |
+
mybn=self.mybn),
|
364 |
+
self.activation,
|
365 |
+
self.which_conv(self.arch['out_channels'][-1], input_nc))
|
366 |
+
|
367 |
+
# Initialize weights. Optionally skip init for testing.
|
368 |
+
if not skip_init:
|
369 |
+
self = init_weights(self, G_init)
|
370 |
+
|
371 |
+
# Note on this forward function: we pass in a y vector which has
|
372 |
+
# already been passed through G.shared to enable easy class-wise
|
373 |
+
# interpolation later. If we passed in the one-hot and then ran it through
|
374 |
+
# G.shared in this forward function, it would be harder to handle.
|
375 |
+
def forward(self, x, y_ind, y):
|
376 |
+
# If hierarchical, concatenate zs and ys
|
377 |
+
|
378 |
+
|
379 |
+
h_all = self.DETR(x, y_ind)
|
380 |
+
#h_all = torch.stack([h_all, h_all, h_all])
|
381 |
+
|
382 |
+
#h_all_bs = torch.unbind(h_all[-1], 0)
|
383 |
+
#y_bs = torch.unbind(y_ind, 0)
|
384 |
+
|
385 |
+
#h = torch.stack([h_i[y_j] for h_i,y_j in zip(h_all_bs, y_bs)], 0)
|
386 |
+
|
387 |
+
|
388 |
+
|
389 |
+
|
390 |
+
h = self.linear_q(h_all)
|
391 |
+
|
392 |
+
|
393 |
+
h = h.contiguous()
|
394 |
+
# Reshape - when y is not a single class value but rather an array of classes, the reshape is needed to create
|
395 |
+
# a separate vertical patch for each input.
|
396 |
+
if self.first_layer:
|
397 |
+
# correct reshape
|
398 |
+
h = h.view(h.size(0), h.shape[1]*2, 4, -1)
|
399 |
+
h = h.permute(0, 3, 2, 1)
|
400 |
+
|
401 |
+
else:
|
402 |
+
h = h.view(h.size(0), -1, self.bottom_width, self.bottom_height)
|
403 |
+
|
404 |
+
|
405 |
+
#for index, blocklist in enumerate(self.blocks):
|
406 |
+
# Second inner loop in case block has multiple layers
|
407 |
+
# for block in blocklist:
|
408 |
+
# h = block(h, ys[index])
|
409 |
+
|
410 |
+
#Apply batchnorm-relu-conv-tanh at output
|
411 |
+
# h = torch.tanh(self.output_layer(h))
|
412 |
+
|
413 |
+
h = self.DEC(h)
|
414 |
+
return h
|
415 |
+
|
416 |
+
|
417 |
+
|
418 |
+
|
419 |
+
|
420 |
+
|
421 |
+
|
422 |
+
# Discriminator architecture, same paradigm as G's above
|
423 |
+
def D_arch(ch=64, attention='64', input_nc=3, ksize='333333', dilation='111111'):
|
424 |
+
arch = {}
|
425 |
+
arch[256] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
|
426 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
|
427 |
+
'downsample': [True] * 6 + [False],
|
428 |
+
'resolution': [128, 64, 32, 16, 8, 4, 4],
|
429 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
430 |
+
for i in range(2, 8)}}
|
431 |
+
arch[128] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
|
432 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
|
433 |
+
'downsample': [True] * 5 + [False],
|
434 |
+
'resolution': [64, 32, 16, 8, 4, 4],
|
435 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
436 |
+
for i in range(2, 8)}}
|
437 |
+
arch[64] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]],
|
438 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
|
439 |
+
'downsample': [True] * 4 + [False],
|
440 |
+
'resolution': [32, 16, 8, 4, 4],
|
441 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
442 |
+
for i in range(2, 7)}}
|
443 |
+
arch[63] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8]],
|
444 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
|
445 |
+
'downsample': [True] * 4 + [False],
|
446 |
+
'resolution': [32, 16, 8, 4, 4],
|
447 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
448 |
+
for i in range(2, 7)}}
|
449 |
+
arch[32] = {'in_channels': [input_nc] + [item * ch for item in [4, 4, 4]],
|
450 |
+
'out_channels': [item * ch for item in [4, 4, 4, 4]],
|
451 |
+
'downsample': [True, True, False, False],
|
452 |
+
'resolution': [16, 16, 16, 16],
|
453 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
454 |
+
for i in range(2, 6)}}
|
455 |
+
arch[129] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
|
456 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
|
457 |
+
'downsample': [True] * 6 + [False],
|
458 |
+
'resolution': [128, 64, 32, 16, 8, 4, 4],
|
459 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
460 |
+
for i in range(2, 8)}}
|
461 |
+
arch[33] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
|
462 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
|
463 |
+
'downsample': [True] * 5 + [False],
|
464 |
+
'resolution': [64, 32, 16, 8, 4, 4],
|
465 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
466 |
+
for i in range(2, 10)}}
|
467 |
+
arch[31] = {'in_channels': [input_nc] + [ch * item for item in [1, 2, 4, 8, 16]],
|
468 |
+
'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
|
469 |
+
'downsample': [True] * 5 + [False],
|
470 |
+
'resolution': [64, 32, 16, 8, 4, 4],
|
471 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
472 |
+
for i in range(2, 10)}}
|
473 |
+
arch[16] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]],
|
474 |
+
'out_channels': [item * ch for item in [1, 8, 16, 16]],
|
475 |
+
'downsample': [True] * 3 + [False],
|
476 |
+
'resolution': [16, 8, 4, 4],
|
477 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
478 |
+
for i in range(2, 5)}}
|
479 |
+
|
480 |
+
arch[17] = {'in_channels': [input_nc] + [ch * item for item in [1, 4]],
|
481 |
+
'out_channels': [item * ch for item in [1, 4, 8]],
|
482 |
+
'downsample': [True] * 3,
|
483 |
+
'resolution': [16, 8, 4],
|
484 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
485 |
+
for i in range(2, 5)}}
|
486 |
+
|
487 |
+
|
488 |
+
arch[20] = {'in_channels': [input_nc] + [ch * item for item in [1, 8, 16]],
|
489 |
+
'out_channels': [item * ch for item in [1, 8, 16, 16]],
|
490 |
+
'downsample': [True] * 3 + [False],
|
491 |
+
'resolution': [16, 8, 4, 4],
|
492 |
+
'attention': {2 ** i: 2 ** i in [int(item) for item in attention.split('_')]
|
493 |
+
for i in range(2, 5)}}
|
494 |
+
return arch
|
495 |
+
|
496 |
+
|
497 |
+
class Discriminator(nn.Module):
|
498 |
+
|
499 |
+
def __init__(self, D_ch=64, D_wide=True, resolution=resolution,
|
500 |
+
D_kernel_size=3, D_attn='64', n_classes=VOCAB_SIZE,
|
501 |
+
num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
|
502 |
+
SN_eps=1e-8, output_dim=1, D_mixed_precision=False, D_fp16=False,
|
503 |
+
D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0],bn_linear='SN', input_nc=1, one_hot=False, **kwargs):
|
504 |
+
|
505 |
+
super(Discriminator, self).__init__()
|
506 |
+
self.name = 'D'
|
507 |
+
# gpu_ids
|
508 |
+
self.gpu_ids = gpu_ids
|
509 |
+
# one_hot representation
|
510 |
+
self.one_hot = one_hot
|
511 |
+
# Width multiplier
|
512 |
+
self.ch = D_ch
|
513 |
+
# Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
|
514 |
+
self.D_wide = D_wide
|
515 |
+
# Resolution
|
516 |
+
self.resolution = resolution
|
517 |
+
# Kernel size
|
518 |
+
self.kernel_size = D_kernel_size
|
519 |
+
# Attention?
|
520 |
+
self.attention = D_attn
|
521 |
+
# Number of classes
|
522 |
+
self.n_classes = n_classes
|
523 |
+
# Activation
|
524 |
+
self.activation = D_activation
|
525 |
+
# Initialization style
|
526 |
+
self.init = D_init
|
527 |
+
# Parameterization style
|
528 |
+
self.D_param = D_param
|
529 |
+
# Epsilon for Spectral Norm?
|
530 |
+
self.SN_eps = SN_eps
|
531 |
+
# Fp16?
|
532 |
+
self.fp16 = D_fp16
|
533 |
+
# Architecture
|
534 |
+
self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]
|
535 |
+
|
536 |
+
# Which convs, batchnorms, and linear layers to use
|
537 |
+
# No option to turn off SN in D right now
|
538 |
+
if self.D_param == 'SN':
|
539 |
+
self.which_conv = functools.partial(layers.SNConv2d,
|
540 |
+
kernel_size=3, padding=1,
|
541 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
542 |
+
eps=self.SN_eps)
|
543 |
+
self.which_linear = functools.partial(layers.SNLinear,
|
544 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
545 |
+
eps=self.SN_eps)
|
546 |
+
self.which_embedding = functools.partial(layers.SNEmbedding,
|
547 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
548 |
+
eps=self.SN_eps)
|
549 |
+
if bn_linear=='SN':
|
550 |
+
self.which_embedding = functools.partial(layers.SNLinear,
|
551 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
552 |
+
eps=self.SN_eps)
|
553 |
+
else:
|
554 |
+
self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
|
555 |
+
self.which_linear = nn.Linear
|
556 |
+
# We use a non-spectral-normed embedding here regardless;
|
557 |
+
# For some reason applying SN to G's embedding seems to randomly cripple G
|
558 |
+
self.which_embedding = nn.Embedding
|
559 |
+
if one_hot:
|
560 |
+
self.which_embedding = functools.partial(layers.SNLinear,
|
561 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
562 |
+
eps=self.SN_eps)
|
563 |
+
# Prepare model
|
564 |
+
# self.blocks is a doubly-nested list of modules, the outer loop intended
|
565 |
+
# to be over blocks at a given resolution (resblocks and/or self-attention)
|
566 |
+
self.blocks = []
|
567 |
+
for index in range(len(self.arch['out_channels'])):
|
568 |
+
self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
|
569 |
+
out_channels=self.arch['out_channels'][index],
|
570 |
+
which_conv=self.which_conv,
|
571 |
+
wide=self.D_wide,
|
572 |
+
activation=self.activation,
|
573 |
+
preactivation=(index > 0),
|
574 |
+
downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
|
575 |
+
# If attention on this block, attach it to the end
|
576 |
+
if self.arch['attention'][self.arch['resolution'][index]]:
|
577 |
+
print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
|
578 |
+
self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
|
579 |
+
self.which_conv)]
|
580 |
+
# Turn self.blocks into a ModuleList so that it's all properly registered.
|
581 |
+
self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
|
582 |
+
# Linear output layer. The output dimension is typically 1, but may be
|
583 |
+
# larger if we're e.g. turning this into a VAE with an inference output
|
584 |
+
self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
|
585 |
+
# Embedding for projection discrimination
|
586 |
+
self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
|
587 |
+
|
588 |
+
# Initialize weights
|
589 |
+
if not skip_init:
|
590 |
+
self = init_weights(self, D_init)
|
591 |
+
|
592 |
+
def forward(self, x, y=None, **kwargs):
|
593 |
+
# Stick x into h for cleaner for loops without flow control
|
594 |
+
h = x
|
595 |
+
# Loop over blocks
|
596 |
+
for index, blocklist in enumerate(self.blocks):
|
597 |
+
for block in blocklist:
|
598 |
+
h = block(h)
|
599 |
+
# Apply global sum pooling as in SN-GAN
|
600 |
+
h = torch.sum(self.activation(h), [2, 3])
|
601 |
+
# Get initial class-unconditional output
|
602 |
+
out = self.linear(h)
|
603 |
+
# Get projection of final featureset onto class vectors and add to evidence
|
604 |
+
if y is not None:
|
605 |
+
out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
|
606 |
+
return out
|
607 |
+
|
608 |
+
def return_features(self, x, y=None):
|
609 |
+
# Stick x into h for cleaner for loops without flow control
|
610 |
+
h = x
|
611 |
+
block_output = []
|
612 |
+
# Loop over blocks
|
613 |
+
for index, blocklist in enumerate(self.blocks):
|
614 |
+
for block in blocklist:
|
615 |
+
h = block(h)
|
616 |
+
block_output.append(h)
|
617 |
+
# Apply global sum pooling as in SN-GAN
|
618 |
+
# h = torch.sum(self.activation(h), [2, 3])
|
619 |
+
return block_output
|
620 |
+
|
621 |
+
|
622 |
+
|
623 |
+
|
624 |
+
class WDiscriminator(nn.Module):
|
625 |
+
|
626 |
+
def __init__(self, D_ch=64, D_wide=True, resolution=resolution,
|
627 |
+
D_kernel_size=3, D_attn='64', n_classes=VOCAB_SIZE,
|
628 |
+
num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
|
629 |
+
SN_eps=1e-8, output_dim=NUM_WRITERS, D_mixed_precision=False, D_fp16=False,
|
630 |
+
D_init='N02', skip_init=False, D_param='SN', gpu_ids=[0],bn_linear='SN', input_nc=1, one_hot=False, **kwargs):
|
631 |
+
super(WDiscriminator, self).__init__()
|
632 |
+
self.name = 'D'
|
633 |
+
# gpu_ids
|
634 |
+
self.gpu_ids = gpu_ids
|
635 |
+
# one_hot representation
|
636 |
+
self.one_hot = one_hot
|
637 |
+
# Width multiplier
|
638 |
+
self.ch = D_ch
|
639 |
+
# Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
|
640 |
+
self.D_wide = D_wide
|
641 |
+
# Resolution
|
642 |
+
self.resolution = resolution
|
643 |
+
# Kernel size
|
644 |
+
self.kernel_size = D_kernel_size
|
645 |
+
# Attention?
|
646 |
+
self.attention = D_attn
|
647 |
+
# Number of classes
|
648 |
+
self.n_classes = n_classes
|
649 |
+
# Activation
|
650 |
+
self.activation = D_activation
|
651 |
+
# Initialization style
|
652 |
+
self.init = D_init
|
653 |
+
# Parameterization style
|
654 |
+
self.D_param = D_param
|
655 |
+
# Epsilon for Spectral Norm?
|
656 |
+
self.SN_eps = SN_eps
|
657 |
+
# Fp16?
|
658 |
+
self.fp16 = D_fp16
|
659 |
+
# Architecture
|
660 |
+
self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]
|
661 |
+
|
662 |
+
# Which convs, batchnorms, and linear layers to use
|
663 |
+
# No option to turn off SN in D right now
|
664 |
+
if self.D_param == 'SN':
|
665 |
+
self.which_conv = functools.partial(layers.SNConv2d,
|
666 |
+
kernel_size=3, padding=1,
|
667 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
668 |
+
eps=self.SN_eps)
|
669 |
+
self.which_linear = functools.partial(layers.SNLinear,
|
670 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
671 |
+
eps=self.SN_eps)
|
672 |
+
self.which_embedding = functools.partial(layers.SNEmbedding,
|
673 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
674 |
+
eps=self.SN_eps)
|
675 |
+
if bn_linear=='SN':
|
676 |
+
self.which_embedding = functools.partial(layers.SNLinear,
|
677 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
678 |
+
eps=self.SN_eps)
|
679 |
+
else:
|
680 |
+
self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
|
681 |
+
self.which_linear = nn.Linear
|
682 |
+
# We use a non-spectral-normed embedding here regardless;
|
683 |
+
# For some reason applying SN to G's embedding seems to randomly cripple G
|
684 |
+
self.which_embedding = nn.Embedding
|
685 |
+
if one_hot:
|
686 |
+
self.which_embedding = functools.partial(layers.SNLinear,
|
687 |
+
num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
|
688 |
+
eps=self.SN_eps)
|
689 |
+
# Prepare model
|
690 |
+
# self.blocks is a doubly-nested list of modules, the outer loop intended
|
691 |
+
# to be over blocks at a given resolution (resblocks and/or self-attention)
|
692 |
+
self.blocks = []
|
693 |
+
for index in range(len(self.arch['out_channels'])):
|
694 |
+
self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
|
695 |
+
out_channels=self.arch['out_channels'][index],
|
696 |
+
which_conv=self.which_conv,
|
697 |
+
wide=self.D_wide,
|
698 |
+
activation=self.activation,
|
699 |
+
preactivation=(index > 0),
|
700 |
+
downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
|
701 |
+
# If attention on this block, attach it to the end
|
702 |
+
if self.arch['attention'][self.arch['resolution'][index]]:
|
703 |
+
print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
|
704 |
+
self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
|
705 |
+
self.which_conv)]
|
706 |
+
# Turn self.blocks into a ModuleList so that it's all properly registered.
|
707 |
+
self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
|
708 |
+
# Linear output layer. The output dimension is typically 1, but may be
|
709 |
+
# larger if we're e.g. turning this into a VAE with an inference output
|
710 |
+
self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
|
711 |
+
# Embedding for projection discrimination
|
712 |
+
self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
|
713 |
+
self.cross_entropy = nn.CrossEntropyLoss()
|
714 |
+
# Initialize weights
|
715 |
+
if not skip_init:
|
716 |
+
self = init_weights(self, D_init)
|
717 |
+
|
718 |
+
def forward(self, x, y=None, **kwargs):
|
719 |
+
# Stick x into h for cleaner for loops without flow control
|
720 |
+
h = x
|
721 |
+
# Loop over blocks
|
722 |
+
for index, blocklist in enumerate(self.blocks):
|
723 |
+
for block in blocklist:
|
724 |
+
h = block(h)
|
725 |
+
# Apply global sum pooling as in SN-GAN
|
726 |
+
h = torch.sum(self.activation(h), [2, 3])
|
727 |
+
# Get initial class-unconditional output
|
728 |
+
out = self.linear(h)
|
729 |
+
# Get projection of final featureset onto class vectors and add to evidence
|
730 |
+
#if y is not None:
|
731 |
+
#out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
|
732 |
+
|
733 |
+
loss = self.cross_entropy(out, y.long())
|
734 |
+
|
735 |
+
return loss
|
736 |
+
|
737 |
+
def return_features(self, x, y=None):
|
738 |
+
# Stick x into h for cleaner for loops without flow control
|
739 |
+
h = x
|
740 |
+
block_output = []
|
741 |
+
# Loop over blocks
|
742 |
+
for index, blocklist in enumerate(self.blocks):
|
743 |
+
for block in blocklist:
|
744 |
+
h = block(h)
|
745 |
+
block_output.append(h)
|
746 |
+
# Apply global sum pooling as in SN-GAN
|
747 |
+
# h = torch.sum(self.activation(h), [2, 3])
|
748 |
+
return block_output
|
749 |
+
|
750 |
+
class Encoder(Discriminator):
|
751 |
+
def __init__(self, opt, output_dim, **kwargs):
|
752 |
+
super(Encoder, self).__init__(**vars(opt))
|
753 |
+
self.output_layer = nn.Sequential(self.activation,
|
754 |
+
nn.Conv2d(self.arch['out_channels'][-1], output_dim, kernel_size=(4,2), padding=0, stride=2))
|
755 |
+
|
756 |
+
def forward(self, x):
|
757 |
+
# Stick x into h for cleaner for loops without flow control
|
758 |
+
h = x
|
759 |
+
# Loop over blocks
|
760 |
+
for index, blocklist in enumerate(self.blocks):
|
761 |
+
for block in blocklist:
|
762 |
+
h = block(h)
|
763 |
+
out = self.output_layer(h)
|
764 |
+
return out
|
765 |
+
|
766 |
+
class BiDiscriminator(nn.Module):
|
767 |
+
def __init__(self, opt):
|
768 |
+
super(BiDiscriminator, self).__init__()
|
769 |
+
self.infer_img = Encoder(opt, output_dim=opt.nimg_features)
|
770 |
+
# self.infer_z = nn.Sequential(
|
771 |
+
# nn.Conv2d(opt.dim_z, 512, 1, stride=1, bias=False),
|
772 |
+
# nn.LeakyReLU(inplace=True),
|
773 |
+
# nn.Dropout2d(p=self.dropout),
|
774 |
+
# nn.Conv2d(512, opt.nz_features, 1, stride=1, bias=False),
|
775 |
+
# nn.LeakyReLU(inplace=True),
|
776 |
+
# nn.Dropout2d(p=self.dropout)
|
777 |
+
# )
|
778 |
+
self.infer_joint = nn.Sequential(
|
779 |
+
nn.Conv2d(opt.dim_z+opt.nimg_features, 1024, 1, stride=1, bias=True),
|
780 |
+
nn.ReLU(inplace=True),
|
781 |
+
|
782 |
+
nn.Conv2d(1024, 1024, 1, stride=1, bias=True),
|
783 |
+
nn.ReLU(inplace=True)
|
784 |
+
)
|
785 |
+
self.final = nn.Conv2d(1024, 1, 1, stride=1, bias=True)
|
786 |
+
|
787 |
+
def forward(self, x, z, **kwargs):
|
788 |
+
output_x = self.infer_img(x)
|
789 |
+
# output_z = self.infer_z(z)
|
790 |
+
if len(z.shape)==2:
|
791 |
+
z = z.unsqueeze(2).unsqueeze(2).repeat((1,1,1,output_x.shape[3]))
|
792 |
+
output_features = self.infer_joint(torch.cat([output_x, z], dim=1))
|
793 |
+
output = self.final(output_features)
|
794 |
+
return output
|
795 |
+
|
796 |
+
# Parallelized G_D to minimize cross-gpu communication
|
797 |
+
# Without this, Generator outputs would get all-gathered and then rebroadcast.
|
798 |
+
class G_D(nn.Module):
|
799 |
+
def __init__(self, G, D):
|
800 |
+
super(G_D, self).__init__()
|
801 |
+
self.G = G
|
802 |
+
self.D = D
|
803 |
+
|
804 |
+
def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
|
805 |
+
split_D=False):
|
806 |
+
# If training G, enable grad tape
|
807 |
+
with torch.set_grad_enabled(train_G):
|
808 |
+
# Get Generator output given noise
|
809 |
+
G_z = self.G(z, self.G.shared(gy))
|
810 |
+
# Cast as necessary
|
811 |
+
if self.G.fp16 and not self.D.fp16:
|
812 |
+
G_z = G_z.float()
|
813 |
+
if self.D.fp16 and not self.G.fp16:
|
814 |
+
G_z = G_z.half()
|
815 |
+
# Split_D means to run D once with real data and once with fake,
|
816 |
+
# rather than concatenating along the batch dimension.
|
817 |
+
if split_D:
|
818 |
+
D_fake = self.D(G_z, gy)
|
819 |
+
if x is not None:
|
820 |
+
D_real = self.D(x, dy)
|
821 |
+
return D_fake, D_real
|
822 |
+
else:
|
823 |
+
if return_G_z:
|
824 |
+
return D_fake, G_z
|
825 |
+
else:
|
826 |
+
return D_fake
|
827 |
+
# If real data is provided, concatenate it with the Generator's output
|
828 |
+
# along the batch dimension for improved efficiency.
|
829 |
+
else:
|
830 |
+
D_input = torch.cat([G_z, x], 0) if x is not None else G_z
|
831 |
+
D_class = torch.cat([gy, dy], 0) if dy is not None else gy
|
832 |
+
# Get Discriminator output
|
833 |
+
D_out = self.D(D_input, D_class)
|
834 |
+
if x is not None:
|
835 |
+
return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real
|
836 |
+
else:
|
837 |
+
if return_G_z:
|
838 |
+
return D_out, G_z
|
839 |
+
else:
|
840 |
+
return D_out
|
841 |
+
|
util/models/OCR_network.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from util.util import to_device
|
3 |
+
from torch.nn import init
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from .networks import *
|
7 |
+
from params import *
|
8 |
+
|
9 |
+
class BidirectionalLSTM(nn.Module):
|
10 |
+
|
11 |
+
def __init__(self, nIn, nHidden, nOut):
|
12 |
+
super(BidirectionalLSTM, self).__init__()
|
13 |
+
|
14 |
+
self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
|
15 |
+
self.embedding = nn.Linear(nHidden * 2, nOut)
|
16 |
+
|
17 |
+
|
18 |
+
def forward(self, input):
|
19 |
+
recurrent, _ = self.rnn(input)
|
20 |
+
T, b, h = recurrent.size()
|
21 |
+
t_rec = recurrent.view(T * b, h)
|
22 |
+
|
23 |
+
output = self.embedding(t_rec) # [T * b, nOut]
|
24 |
+
output = output.view(T, b, -1)
|
25 |
+
|
26 |
+
return output
|
27 |
+
|
28 |
+
|
29 |
+
class CRNN(nn.Module):
|
30 |
+
|
31 |
+
def __init__(self, leakyRelu=False):
|
32 |
+
super(CRNN, self).__init__()
|
33 |
+
self.name = 'OCR'
|
34 |
+
#assert opt.imgH % 16 == 0, 'imgH has to be a multiple of 16'
|
35 |
+
|
36 |
+
ks = [3, 3, 3, 3, 3, 3, 2]
|
37 |
+
ps = [1, 1, 1, 1, 1, 1, 0]
|
38 |
+
ss = [1, 1, 1, 1, 1, 1, 1]
|
39 |
+
nm = [64, 128, 256, 256, 512, 512, 512]
|
40 |
+
|
41 |
+
cnn = nn.Sequential()
|
42 |
+
nh = 256
|
43 |
+
dealwith_lossnone=False # whether to replace all nan/inf in gradients to zero
|
44 |
+
|
45 |
+
def convRelu(i, batchNormalization=False):
|
46 |
+
nIn = 1 if i == 0 else nm[i - 1]
|
47 |
+
nOut = nm[i]
|
48 |
+
cnn.add_module('conv{0}'.format(i),
|
49 |
+
nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
|
50 |
+
if batchNormalization:
|
51 |
+
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
|
52 |
+
if leakyRelu:
|
53 |
+
cnn.add_module('relu{0}'.format(i),
|
54 |
+
nn.LeakyReLU(0.2, inplace=True))
|
55 |
+
else:
|
56 |
+
cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
|
57 |
+
|
58 |
+
convRelu(0)
|
59 |
+
cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
|
60 |
+
convRelu(1)
|
61 |
+
cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
|
62 |
+
convRelu(2, True)
|
63 |
+
convRelu(3)
|
64 |
+
cnn.add_module('pooling{0}'.format(2),
|
65 |
+
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
|
66 |
+
convRelu(4, True)
|
67 |
+
if resolution==63:
|
68 |
+
cnn.add_module('pooling{0}'.format(3),
|
69 |
+
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
|
70 |
+
convRelu(5)
|
71 |
+
cnn.add_module('pooling{0}'.format(4),
|
72 |
+
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
|
73 |
+
convRelu(6, True) # 512x1x16
|
74 |
+
|
75 |
+
self.cnn = cnn
|
76 |
+
self.use_rnn = False
|
77 |
+
if self.use_rnn:
|
78 |
+
self.rnn = nn.Sequential(
|
79 |
+
BidirectionalLSTM(512, nh, nh),
|
80 |
+
BidirectionalLSTM(nh, nh, ))
|
81 |
+
else:
|
82 |
+
self.linear = nn.Linear(512, VOCAB_SIZE)
|
83 |
+
|
84 |
+
# replace all nan/inf in gradients to zero
|
85 |
+
if dealwith_lossnone:
|
86 |
+
self.register_backward_hook(self.backward_hook)
|
87 |
+
|
88 |
+
self.device = torch.device('cuda:{}'.format(0))
|
89 |
+
self.init = 'N02'
|
90 |
+
# Initialize weights
|
91 |
+
|
92 |
+
self = init_weights(self, self.init)
|
93 |
+
|
94 |
+
def forward(self, input):
|
95 |
+
# conv features
|
96 |
+
conv = self.cnn(input)
|
97 |
+
b, c, h, w = conv.size()
|
98 |
+
if h!=1:
|
99 |
+
print('a')
|
100 |
+
assert h == 1, "the height of conv must be 1"
|
101 |
+
conv = conv.squeeze(2)
|
102 |
+
conv = conv.permute(2, 0, 1) # [w, b, c]
|
103 |
+
|
104 |
+
if self.use_rnn:
|
105 |
+
# rnn features
|
106 |
+
output = self.rnn(conv)
|
107 |
+
else:
|
108 |
+
output = self.linear(conv)
|
109 |
+
return output
|
110 |
+
|
111 |
+
def backward_hook(self, module, grad_input, grad_output):
|
112 |
+
for g in grad_input:
|
113 |
+
g[g != g] = 0 # replace all nan/inf in gradients to zero
|
114 |
+
|
115 |
+
|
116 |
+
class OCRLabelConverter(object):
|
117 |
+
"""Convert between str and label.
|
118 |
+
|
119 |
+
NOTE:
|
120 |
+
Insert `blank` to the alphabet for CTC.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
alphabet (str): set of the possible characters.
|
124 |
+
ignore_case (bool, default=True): whether or not to ignore all of the case.
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(self, alphabet, ignore_case=False):
|
128 |
+
self._ignore_case = ignore_case
|
129 |
+
if self._ignore_case:
|
130 |
+
alphabet = alphabet.lower()
|
131 |
+
self.alphabet = alphabet + '-' # for `-1` index
|
132 |
+
|
133 |
+
self.dict = {}
|
134 |
+
for i, char in enumerate(alphabet):
|
135 |
+
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
|
136 |
+
self.dict[char] = i + 1
|
137 |
+
|
138 |
+
def encode(self, text):
|
139 |
+
"""Support batch or single str.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
text (str or list of str): texts to convert.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
|
146 |
+
torch.IntTensor [n]: length of each text.
|
147 |
+
"""
|
148 |
+
'''
|
149 |
+
if isinstance(text, str):
|
150 |
+
text = [
|
151 |
+
self.dict[char.lower() if self._ignore_case else char]
|
152 |
+
for char in text
|
153 |
+
]
|
154 |
+
length = [len(text)]
|
155 |
+
elif isinstance(text, collections.Iterable):
|
156 |
+
length = [len(s) for s in text]
|
157 |
+
text = ''.join(text)
|
158 |
+
text, _ = self.encode(text)
|
159 |
+
return (torch.IntTensor(text), torch.IntTensor(length))
|
160 |
+
'''
|
161 |
+
length = []
|
162 |
+
result = []
|
163 |
+
for item in text:
|
164 |
+
item = item.decode('utf-8', 'strict')
|
165 |
+
length.append(len(item))
|
166 |
+
for char in item:
|
167 |
+
index = self.dict[char]
|
168 |
+
result.append(index)
|
169 |
+
|
170 |
+
text = result
|
171 |
+
return (torch.IntTensor(text), torch.IntTensor(length))
|
172 |
+
|
173 |
+
def decode(self, t, length, raw=False):
|
174 |
+
"""Decode encoded texts back into strs.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
|
178 |
+
torch.IntTensor [n]: length of each text.
|
179 |
+
|
180 |
+
Raises:
|
181 |
+
AssertionError: when the texts and its length does not match.
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
text (str or list of str): texts to convert.
|
185 |
+
"""
|
186 |
+
if length.numel() == 1:
|
187 |
+
length = length[0]
|
188 |
+
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
|
189 |
+
length)
|
190 |
+
if raw:
|
191 |
+
return ''.join([self.alphabet[i - 1] for i in t])
|
192 |
+
else:
|
193 |
+
char_list = []
|
194 |
+
for i in range(length):
|
195 |
+
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
|
196 |
+
char_list.append(self.alphabet[t[i] - 1])
|
197 |
+
return ''.join(char_list)
|
198 |
+
else:
|
199 |
+
# batch mode
|
200 |
+
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
|
201 |
+
t.numel(), length.sum())
|
202 |
+
texts = []
|
203 |
+
index = 0
|
204 |
+
for i in range(length.numel()):
|
205 |
+
l = length[i]
|
206 |
+
texts.append(
|
207 |
+
self.decode(
|
208 |
+
t[index:index + l], torch.IntTensor([l]), raw=raw))
|
209 |
+
index += l
|
210 |
+
return texts
|
211 |
+
|
212 |
+
|
213 |
+
class strLabelConverter(object):
|
214 |
+
"""Convert between str and label.
|
215 |
+
NOTE:
|
216 |
+
Insert `blank` to the alphabet for CTC.
|
217 |
+
Args:
|
218 |
+
alphabet (str): set of the possible characters.
|
219 |
+
ignore_case (bool, default=True): whether or not to ignore all of the case.
|
220 |
+
"""
|
221 |
+
|
222 |
+
def __init__(self, alphabet, ignore_case=False):
|
223 |
+
self._ignore_case = ignore_case
|
224 |
+
if self._ignore_case:
|
225 |
+
alphabet = alphabet.lower()
|
226 |
+
self.alphabet = alphabet + '-' # for `-1` index
|
227 |
+
|
228 |
+
self.dict = {}
|
229 |
+
for i, char in enumerate(alphabet):
|
230 |
+
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
|
231 |
+
self.dict[char] = i + 1
|
232 |
+
|
233 |
+
def encode(self, text):
|
234 |
+
"""Support batch or single str.
|
235 |
+
Args:
|
236 |
+
text (str or list of str): texts to convert.
|
237 |
+
Returns:
|
238 |
+
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
|
239 |
+
torch.IntTensor [n]: length of each text.
|
240 |
+
"""
|
241 |
+
'''
|
242 |
+
if isinstance(text, str):
|
243 |
+
text = [
|
244 |
+
self.dict[char.lower() if self._ignore_case else char]
|
245 |
+
for char in text
|
246 |
+
]
|
247 |
+
length = [len(text)]
|
248 |
+
elif isinstance(text, collections.Iterable):
|
249 |
+
length = [len(s) for s in text]
|
250 |
+
text = ''.join(text)
|
251 |
+
text, _ = self.encode(text)
|
252 |
+
return (torch.IntTensor(text), torch.IntTensor(length))
|
253 |
+
'''
|
254 |
+
length = []
|
255 |
+
result = []
|
256 |
+
results = []
|
257 |
+
for item in text:
|
258 |
+
item = item.decode('utf-8', 'strict')
|
259 |
+
length.append(len(item))
|
260 |
+
for char in item:
|
261 |
+
index = self.dict[char]
|
262 |
+
result.append(index)
|
263 |
+
results.append(result)
|
264 |
+
result = []
|
265 |
+
|
266 |
+
return (torch.nn.utils.rnn.pad_sequence([torch.LongTensor(text) for text in results], batch_first=True), torch.IntTensor(length))
|
267 |
+
|
268 |
+
def decode(self, t, length, raw=False):
|
269 |
+
"""Decode encoded texts back into strs.
|
270 |
+
Args:
|
271 |
+
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
|
272 |
+
torch.IntTensor [n]: length of each text.
|
273 |
+
Raises:
|
274 |
+
AssertionError: when the texts and its length does not match.
|
275 |
+
Returns:
|
276 |
+
text (str or list of str): texts to convert.
|
277 |
+
"""
|
278 |
+
if length.numel() == 1:
|
279 |
+
length = length[0]
|
280 |
+
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
|
281 |
+
length)
|
282 |
+
if raw:
|
283 |
+
return ''.join([self.alphabet[i - 1] for i in t])
|
284 |
+
else:
|
285 |
+
char_list = []
|
286 |
+
for i in range(length):
|
287 |
+
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
|
288 |
+
char_list.append(self.alphabet[t[i] - 1])
|
289 |
+
return ''.join(char_list)
|
290 |
+
else:
|
291 |
+
# batch mode
|
292 |
+
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
|
293 |
+
t.numel(), length.sum())
|
294 |
+
texts = []
|
295 |
+
index = 0
|
296 |
+
for i in range(length.numel()):
|
297 |
+
l = length[i]
|
298 |
+
texts.append(
|
299 |
+
self.decode(
|
300 |
+
t[index:index + l], torch.IntTensor([l]), raw=raw))
|
301 |
+
index += l
|
302 |
+
return texts
|
303 |
+
|
304 |
+
|
util/models/__init__.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
2 |
+
|
3 |
+
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
4 |
+
You need to implement the following five functions:
|
5 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
6 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
7 |
+
-- <forward>: produce intermediate results.
|
8 |
+
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
9 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
10 |
+
|
11 |
+
In the function <__init__>, you need to define four lists:
|
12 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
13 |
+
-- self.model_names (str list): define networks used in our training.
|
14 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
15 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
|
16 |
+
|
17 |
+
Now you can use the model class by specifying flag '--model dummy'.
|
18 |
+
"""
|
19 |
+
|
20 |
+
import importlib
|
21 |
+
|
22 |
+
|
23 |
+
def find_model_using_name(model_name):
|
24 |
+
"""Import the module "models/[model_name]_model.py".
|
25 |
+
|
26 |
+
In the file, the class called DatasetNameModel() will
|
27 |
+
be instantiated. It has to be a subclass of BaseModel,
|
28 |
+
and it is case-insensitive.
|
29 |
+
"""
|
30 |
+
model_filename = "models." + model_name + "_model"
|
31 |
+
modellib = importlib.import_module(model_filename)
|
32 |
+
model = None
|
33 |
+
target_model_name = model_name.replace('_', '') + 'model'
|
34 |
+
for name, cls in modellib.__dict__.items():
|
35 |
+
if name.lower() == target_model_name.lower() \
|
36 |
+
and issubclass(cls, BaseModel):
|
37 |
+
model = cls
|
38 |
+
|
39 |
+
if model is None:
|
40 |
+
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
41 |
+
exit(0)
|
42 |
+
|
43 |
+
return model
|
44 |
+
|
45 |
+
|
46 |
+
def get_option_setter(model_name):
|
47 |
+
"""Return the static method <modify_commandline_options> of the model class."""
|
48 |
+
model_class = find_model_using_name(model_name)
|
49 |
+
return model_class.modify_commandline_options
|
50 |
+
|
51 |
+
|
52 |
+
def create_model(opt):
|
53 |
+
"""Create a model given the option.
|
54 |
+
|
55 |
+
This function warps the class CustomDatasetDataLoader.
|
56 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
57 |
+
|
58 |
+
Example:
|
59 |
+
>>> from models import create_model
|
60 |
+
>>> model = create_model(opt)
|
61 |
+
"""
|
62 |
+
model = find_model_using_name(opt.model)
|
63 |
+
instance = model(opt)
|
64 |
+
print("model [%s] was created" % type(instance).__name__)
|
65 |
+
return instance
|
util/models/__pycache__/BigGAN_layers.cpython-36.pyc
ADDED
Binary file (12.8 kB). View file
|
|
util/models/__pycache__/BigGAN_networks.cpython-36.pyc
ADDED
Binary file (30 kB). View file
|
|
util/models/__pycache__/OCR_network.cpython-36.pyc
ADDED
Binary file (8.62 kB). View file
|
|
util/models/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (3.1 kB). View file
|
|
util/models/__pycache__/blocks.cpython-36.pyc
ADDED
Binary file (6 kB). View file
|
|
util/models/__pycache__/inception.cpython-36.pyc
ADDED
Binary file (10.7 kB). View file
|
|
util/models/__pycache__/model.cpython-36.pyc
ADDED
Binary file (31.1 kB). View file
|
|
util/models/__pycache__/model_.cpython-36.pyc
ADDED
Binary file (28.8 kB). View file
|
|
util/models/__pycache__/networks.cpython-36.pyc
ADDED
Binary file (4.08 kB). View file
|
|
util/models/__pycache__/transformer.cpython-36.pyc
ADDED
Binary file (8.92 kB). View file
|
|
util/models/blocks.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class ResBlocks(nn.Module):
|
7 |
+
def __init__(self, num_blocks, dim, norm, activation, pad_type):
|
8 |
+
super(ResBlocks, self).__init__()
|
9 |
+
self.model = []
|
10 |
+
for i in range(num_blocks):
|
11 |
+
self.model += [ResBlock(dim,
|
12 |
+
norm=norm,
|
13 |
+
activation=activation,
|
14 |
+
pad_type=pad_type)]
|
15 |
+
self.model = nn.Sequential(*self.model)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
return self.model(x)
|
19 |
+
|
20 |
+
|
21 |
+
class ResBlock(nn.Module):
|
22 |
+
def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
|
23 |
+
super(ResBlock, self).__init__()
|
24 |
+
model = []
|
25 |
+
model += [Conv2dBlock(dim, dim, 3, 1, 1,
|
26 |
+
norm=norm,
|
27 |
+
activation=activation,
|
28 |
+
pad_type=pad_type)]
|
29 |
+
model += [Conv2dBlock(dim, dim, 3, 1, 1,
|
30 |
+
norm=norm,
|
31 |
+
activation='none',
|
32 |
+
pad_type=pad_type)]
|
33 |
+
self.model = nn.Sequential(*model)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
residual = x
|
37 |
+
out = self.model(x)
|
38 |
+
out += residual
|
39 |
+
return out
|
40 |
+
|
41 |
+
|
42 |
+
class ActFirstResBlock(nn.Module):
|
43 |
+
def __init__(self, fin, fout, fhid=None,
|
44 |
+
activation='lrelu', norm='none'):
|
45 |
+
super().__init__()
|
46 |
+
self.learned_shortcut = (fin != fout)
|
47 |
+
self.fin = fin
|
48 |
+
self.fout = fout
|
49 |
+
self.fhid = min(fin, fout) if fhid is None else fhid
|
50 |
+
self.conv_0 = Conv2dBlock(self.fin, self.fhid, 3, 1,
|
51 |
+
padding=1, pad_type='reflect', norm=norm,
|
52 |
+
activation=activation, activation_first=True)
|
53 |
+
self.conv_1 = Conv2dBlock(self.fhid, self.fout, 3, 1,
|
54 |
+
padding=1, pad_type='reflect', norm=norm,
|
55 |
+
activation=activation, activation_first=True)
|
56 |
+
if self.learned_shortcut:
|
57 |
+
self.conv_s = Conv2dBlock(self.fin, self.fout, 1, 1,
|
58 |
+
activation='none', use_bias=False)
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
x_s = self.conv_s(x) if self.learned_shortcut else x
|
62 |
+
dx = self.conv_0(x)
|
63 |
+
dx = self.conv_1(dx)
|
64 |
+
out = x_s + dx
|
65 |
+
return out
|
66 |
+
|
67 |
+
|
68 |
+
class LinearBlock(nn.Module):
|
69 |
+
def __init__(self, in_dim, out_dim, norm='none', activation='relu'):
|
70 |
+
super(LinearBlock, self).__init__()
|
71 |
+
use_bias = True
|
72 |
+
self.fc = nn.Linear(in_dim, out_dim, bias=use_bias)
|
73 |
+
|
74 |
+
# initialize normalization
|
75 |
+
norm_dim = out_dim
|
76 |
+
if norm == 'bn':
|
77 |
+
self.norm = nn.BatchNorm1d(norm_dim)
|
78 |
+
elif norm == 'in':
|
79 |
+
self.norm = nn.InstanceNorm1d(norm_dim)
|
80 |
+
elif norm == 'none':
|
81 |
+
self.norm = None
|
82 |
+
else:
|
83 |
+
assert 0, "Unsupported normalization: {}".format(norm)
|
84 |
+
|
85 |
+
# initialize activation
|
86 |
+
if activation == 'relu':
|
87 |
+
self.activation = nn.ReLU(inplace=False)
|
88 |
+
elif activation == 'lrelu':
|
89 |
+
self.activation = nn.LeakyReLU(0.2, inplace=False)
|
90 |
+
elif activation == 'tanh':
|
91 |
+
self.activation = nn.Tanh()
|
92 |
+
elif activation == 'none':
|
93 |
+
self.activation = None
|
94 |
+
else:
|
95 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
out = self.fc(x)
|
99 |
+
if self.norm:
|
100 |
+
out = self.norm(out)
|
101 |
+
if self.activation:
|
102 |
+
out = self.activation(out)
|
103 |
+
return out
|
104 |
+
|
105 |
+
|
106 |
+
class Conv2dBlock(nn.Module):
|
107 |
+
def __init__(self, in_dim, out_dim, ks, st, padding=0,
|
108 |
+
norm='none', activation='relu', pad_type='zero',
|
109 |
+
use_bias=True, activation_first=False):
|
110 |
+
super(Conv2dBlock, self).__init__()
|
111 |
+
self.use_bias = use_bias
|
112 |
+
self.activation_first = activation_first
|
113 |
+
# initialize padding
|
114 |
+
if pad_type == 'reflect':
|
115 |
+
self.pad = nn.ReflectionPad2d(padding)
|
116 |
+
elif pad_type == 'replicate':
|
117 |
+
self.pad = nn.ReplicationPad2d(padding)
|
118 |
+
elif pad_type == 'zero':
|
119 |
+
self.pad = nn.ZeroPad2d(padding)
|
120 |
+
else:
|
121 |
+
assert 0, "Unsupported padding type: {}".format(pad_type)
|
122 |
+
|
123 |
+
# initialize normalization
|
124 |
+
norm_dim = out_dim
|
125 |
+
if norm == 'bn':
|
126 |
+
self.norm = nn.BatchNorm2d(norm_dim)
|
127 |
+
elif norm == 'in':
|
128 |
+
self.norm = nn.InstanceNorm2d(norm_dim)
|
129 |
+
elif norm == 'adain':
|
130 |
+
self.norm = AdaptiveInstanceNorm2d(norm_dim)
|
131 |
+
elif norm == 'none':
|
132 |
+
self.norm = None
|
133 |
+
else:
|
134 |
+
assert 0, "Unsupported normalization: {}".format(norm)
|
135 |
+
|
136 |
+
# initialize activation
|
137 |
+
if activation == 'relu':
|
138 |
+
self.activation = nn.ReLU(inplace=False)
|
139 |
+
elif activation == 'lrelu':
|
140 |
+
self.activation = nn.LeakyReLU(0.2, inplace=False)
|
141 |
+
elif activation == 'tanh':
|
142 |
+
self.activation = nn.Tanh()
|
143 |
+
elif activation == 'none':
|
144 |
+
self.activation = None
|
145 |
+
else:
|
146 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
147 |
+
|
148 |
+
self.conv = nn.Conv2d(in_dim, out_dim, ks, st, bias=self.use_bias)
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
if self.activation_first:
|
152 |
+
if self.activation:
|
153 |
+
x = self.activation(x)
|
154 |
+
x = self.conv(self.pad(x))
|
155 |
+
if self.norm:
|
156 |
+
x = self.norm(x)
|
157 |
+
else:
|
158 |
+
x = self.conv(self.pad(x))
|
159 |
+
if self.norm:
|
160 |
+
x = self.norm(x)
|
161 |
+
if self.activation:
|
162 |
+
x = self.activation(x)
|
163 |
+
return x
|
164 |
+
|
165 |
+
|
166 |
+
class AdaptiveInstanceNorm2d(nn.Module):
|
167 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1):
|
168 |
+
super(AdaptiveInstanceNorm2d, self).__init__()
|
169 |
+
self.num_features = num_features
|
170 |
+
self.eps = eps
|
171 |
+
self.momentum = momentum
|
172 |
+
self.weight = None
|
173 |
+
self.bias = None
|
174 |
+
self.register_buffer('running_mean', torch.zeros(num_features))
|
175 |
+
self.register_buffer('running_var', torch.ones(num_features))
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
assert self.weight is not None and \
|
179 |
+
self.bias is not None, "Please assign AdaIN weight first"
|
180 |
+
b, c = x.size(0), x.size(1)
|
181 |
+
running_mean = self.running_mean.repeat(b)
|
182 |
+
running_var = self.running_var.repeat(b)
|
183 |
+
x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
|
184 |
+
out = F.batch_norm(
|
185 |
+
x_reshaped, running_mean, running_var, self.weight, self.bias,
|
186 |
+
True, self.momentum, self.eps)
|
187 |
+
return out.view(b, c, *x.size()[2:])
|
188 |
+
|
189 |
+
def __repr__(self):
|
190 |
+
return self.__class__.__name__ + '(' + str(self.num_features) + ')'
|
util/models/inception.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision import models
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from itertools import cycle
|
8 |
+
from scipy import linalg
|
9 |
+
|
10 |
+
|
11 |
+
try:
|
12 |
+
from torchvision.models.utils import load_state_dict_from_url
|
13 |
+
except ImportError:
|
14 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
15 |
+
|
16 |
+
# Inception weights ported to Pytorch from
|
17 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
18 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
|
19 |
+
|
20 |
+
|
21 |
+
class InceptionV3(nn.Module):
|
22 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
23 |
+
|
24 |
+
# Index of default block of inception to return,
|
25 |
+
# corresponds to output of final average pooling
|
26 |
+
DEFAULT_BLOCK_INDEX = 3
|
27 |
+
|
28 |
+
# Maps feature dimensionality to their output blocks indices
|
29 |
+
BLOCK_INDEX_BY_DIM = {
|
30 |
+
64: 0, # First max pooling features
|
31 |
+
192: 1, # Second max pooling featurs
|
32 |
+
768: 2, # Pre-aux classifier features
|
33 |
+
2048: 3 # Final average pooling features
|
34 |
+
}
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
output_blocks=[DEFAULT_BLOCK_INDEX],
|
38 |
+
resize_input=True,
|
39 |
+
normalize_input=True,
|
40 |
+
requires_grad=False,
|
41 |
+
use_fid_inception=True):
|
42 |
+
"""Build pretrained InceptionV3
|
43 |
+
Parameters
|
44 |
+
----------
|
45 |
+
output_blocks : list of int
|
46 |
+
Indices of blocks to return features of. Possible values are:
|
47 |
+
- 0: corresponds to output of first max pooling
|
48 |
+
- 1: corresponds to output of second max pooling
|
49 |
+
- 2: corresponds to output which is fed to aux classifier
|
50 |
+
- 3: corresponds to output of final average pooling
|
51 |
+
resize_input : bool
|
52 |
+
If true, bilinearly resizes input to width and height 299 before
|
53 |
+
feeding input to model. As the network without fully connected
|
54 |
+
layers is fully convolutional, it should be able to handle inputs
|
55 |
+
of arbitrary size, so resizing might not be strictly needed
|
56 |
+
normalize_input : bool
|
57 |
+
If true, scales the input from range (0, 1) to the range the
|
58 |
+
pretrained Inception network expects, namely (-1, 1)
|
59 |
+
requires_grad : bool
|
60 |
+
If true, parameters of the model require gradients. Possibly useful
|
61 |
+
for finetuning the network
|
62 |
+
use_fid_inception : bool
|
63 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
64 |
+
FID implementation. If false, uses the pretrained Inception model
|
65 |
+
available in torchvision. The FID Inception model has different
|
66 |
+
weights and a slightly different structure from torchvision's
|
67 |
+
Inception model. If you want to compute FID scores, you are
|
68 |
+
strongly advised to set this parameter to true to get comparable
|
69 |
+
results.
|
70 |
+
"""
|
71 |
+
super(InceptionV3, self).__init__()
|
72 |
+
|
73 |
+
self.resize_input = resize_input
|
74 |
+
self.normalize_input = normalize_input
|
75 |
+
self.output_blocks = sorted(output_blocks)
|
76 |
+
self.last_needed_block = max(output_blocks)
|
77 |
+
|
78 |
+
assert self.last_needed_block <= 3, \
|
79 |
+
'Last possible output block index is 3'
|
80 |
+
|
81 |
+
self.blocks = nn.ModuleList()
|
82 |
+
|
83 |
+
if use_fid_inception:
|
84 |
+
inception = fid_inception_v3()
|
85 |
+
else:
|
86 |
+
inception = models.inception_v3(pretrained=True)
|
87 |
+
|
88 |
+
# Block 0: input to maxpool1
|
89 |
+
block0 = [
|
90 |
+
inception.Conv2d_1a_3x3,
|
91 |
+
inception.Conv2d_2a_3x3,
|
92 |
+
inception.Conv2d_2b_3x3,
|
93 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
94 |
+
]
|
95 |
+
self.blocks.append(nn.Sequential(*block0))
|
96 |
+
|
97 |
+
# Block 1: maxpool1 to maxpool2
|
98 |
+
if self.last_needed_block >= 1:
|
99 |
+
block1 = [
|
100 |
+
inception.Conv2d_3b_1x1,
|
101 |
+
inception.Conv2d_4a_3x3,
|
102 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
103 |
+
]
|
104 |
+
self.blocks.append(nn.Sequential(*block1))
|
105 |
+
|
106 |
+
# Block 2: maxpool2 to aux classifier
|
107 |
+
if self.last_needed_block >= 2:
|
108 |
+
block2 = [
|
109 |
+
inception.Mixed_5b,
|
110 |
+
inception.Mixed_5c,
|
111 |
+
inception.Mixed_5d,
|
112 |
+
inception.Mixed_6a,
|
113 |
+
inception.Mixed_6b,
|
114 |
+
inception.Mixed_6c,
|
115 |
+
inception.Mixed_6d,
|
116 |
+
inception.Mixed_6e,
|
117 |
+
]
|
118 |
+
self.blocks.append(nn.Sequential(*block2))
|
119 |
+
|
120 |
+
# Block 3: aux classifier to final avgpool
|
121 |
+
if self.last_needed_block >= 3:
|
122 |
+
block3 = [
|
123 |
+
inception.Mixed_7a,
|
124 |
+
inception.Mixed_7b,
|
125 |
+
inception.Mixed_7c,
|
126 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
127 |
+
]
|
128 |
+
self.blocks.append(nn.Sequential(*block3))
|
129 |
+
|
130 |
+
for param in self.parameters():
|
131 |
+
param.requires_grad = requires_grad
|
132 |
+
|
133 |
+
def forward(self, inp):
|
134 |
+
"""Get Inception feature maps
|
135 |
+
Parameters
|
136 |
+
----------
|
137 |
+
inp : torch.autograd.Variable
|
138 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
139 |
+
range (0, 1)
|
140 |
+
Returns
|
141 |
+
-------
|
142 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
143 |
+
block, sorted ascending by index
|
144 |
+
"""
|
145 |
+
outp = []
|
146 |
+
x = inp
|
147 |
+
|
148 |
+
if self.resize_input:
|
149 |
+
x = F.interpolate(x,
|
150 |
+
size=(299, 299),
|
151 |
+
mode='bilinear',
|
152 |
+
align_corners=False)
|
153 |
+
|
154 |
+
if self.normalize_input:
|
155 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
156 |
+
|
157 |
+
for idx, block in enumerate(self.blocks):
|
158 |
+
x = block(x)
|
159 |
+
if idx in self.output_blocks:
|
160 |
+
outp.append(x)
|
161 |
+
|
162 |
+
if idx == self.last_needed_block:
|
163 |
+
break
|
164 |
+
|
165 |
+
return outp
|
166 |
+
|
167 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
168 |
+
"""Numpy implementation of the Frechet Distance.
|
169 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
170 |
+
and X_2 ~ N(mu_2, C_2) is
|
171 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
172 |
+
Stable version by Dougal J. Sutherland.
|
173 |
+
Params:
|
174 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
175 |
+
inception net (like returned by the function 'get_predictions')
|
176 |
+
for generated samples.
|
177 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
178 |
+
representative data set.
|
179 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
180 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
181 |
+
representative data set.
|
182 |
+
Returns:
|
183 |
+
-- : The Frechet Distance.
|
184 |
+
"""
|
185 |
+
|
186 |
+
mu1 = np.atleast_1d(mu1)
|
187 |
+
mu2 = np.atleast_1d(mu2)
|
188 |
+
|
189 |
+
sigma1 = np.atleast_2d(sigma1)
|
190 |
+
sigma2 = np.atleast_2d(sigma2)
|
191 |
+
|
192 |
+
assert mu1.shape == mu2.shape, \
|
193 |
+
'Training and test mean vectors have different lengths'
|
194 |
+
assert sigma1.shape == sigma2.shape, \
|
195 |
+
'Training and test covariances have different dimensions'
|
196 |
+
|
197 |
+
diff = mu1 - mu2
|
198 |
+
|
199 |
+
# Product might be almost singular
|
200 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
201 |
+
if not np.isfinite(covmean).all():
|
202 |
+
msg = ('fid calculation produces singular product; '
|
203 |
+
'adding %s to diagonal of cov estimates') % eps
|
204 |
+
print(msg)
|
205 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
206 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
207 |
+
|
208 |
+
# Numerical error might give slight imaginary component
|
209 |
+
if np.iscomplexobj(covmean):
|
210 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
211 |
+
m = np.max(np.abs(covmean.imag))
|
212 |
+
raise ValueError('Imaginary component {}'.format(m))
|
213 |
+
covmean = covmean.real
|
214 |
+
|
215 |
+
tr_covmean = np.trace(covmean)
|
216 |
+
|
217 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
218 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
219 |
+
|
220 |
+
|
221 |
+
def fid_inception_v3():
|
222 |
+
"""Build pretrained Inception model for FID computation
|
223 |
+
The Inception model for FID computation uses a different set of weights
|
224 |
+
and has a slightly different structure than torchvision's Inception.
|
225 |
+
This method first constructs torchvision's Inception and then patches the
|
226 |
+
necessary parts that are different in the FID Inception model.
|
227 |
+
"""
|
228 |
+
inception = models.inception_v3(num_classes=1008,
|
229 |
+
aux_logits=False,
|
230 |
+
pretrained=False)
|
231 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
232 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
233 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
234 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
235 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
236 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
237 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
238 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
239 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
240 |
+
|
241 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
242 |
+
inception.load_state_dict(state_dict)
|
243 |
+
return inception
|
244 |
+
|
245 |
+
|
246 |
+
class FIDInceptionA(models.inception.InceptionA):
|
247 |
+
"""InceptionA block patched for FID computation"""
|
248 |
+
def __init__(self, in_channels, pool_features):
|
249 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
250 |
+
|
251 |
+
def forward(self, x):
|
252 |
+
branch1x1 = self.branch1x1(x)
|
253 |
+
|
254 |
+
branch5x5 = self.branch5x5_1(x)
|
255 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
256 |
+
|
257 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
258 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
259 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
260 |
+
|
261 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
262 |
+
# its average calculation
|
263 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
264 |
+
count_include_pad=False)
|
265 |
+
branch_pool = self.branch_pool(branch_pool)
|
266 |
+
|
267 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
268 |
+
return torch.cat(outputs, 1)
|
269 |
+
|
270 |
+
|
271 |
+
class FIDInceptionC(models.inception.InceptionC):
|
272 |
+
"""InceptionC block patched for FID computation"""
|
273 |
+
def __init__(self, in_channels, channels_7x7):
|
274 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
275 |
+
|
276 |
+
def forward(self, x):
|
277 |
+
branch1x1 = self.branch1x1(x)
|
278 |
+
|
279 |
+
branch7x7 = self.branch7x7_1(x)
|
280 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
281 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
282 |
+
|
283 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
284 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
285 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
286 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
287 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
288 |
+
|
289 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
290 |
+
# its average calculation
|
291 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
292 |
+
count_include_pad=False)
|
293 |
+
branch_pool = self.branch_pool(branch_pool)
|
294 |
+
|
295 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
296 |
+
return torch.cat(outputs, 1)
|
297 |
+
|
298 |
+
|
299 |
+
class FIDInceptionE_1(models.inception.InceptionE):
|
300 |
+
"""First InceptionE block patched for FID computation"""
|
301 |
+
def __init__(self, in_channels):
|
302 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
303 |
+
|
304 |
+
def forward(self, x):
|
305 |
+
branch1x1 = self.branch1x1(x)
|
306 |
+
|
307 |
+
branch3x3 = self.branch3x3_1(x)
|
308 |
+
branch3x3 = [
|
309 |
+
self.branch3x3_2a(branch3x3),
|
310 |
+
self.branch3x3_2b(branch3x3),
|
311 |
+
]
|
312 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
313 |
+
|
314 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
315 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
316 |
+
branch3x3dbl = [
|
317 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
318 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
319 |
+
]
|
320 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
321 |
+
|
322 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
323 |
+
# its average calculation
|
324 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
325 |
+
count_include_pad=False)
|
326 |
+
branch_pool = self.branch_pool(branch_pool)
|
327 |
+
|
328 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
329 |
+
return torch.cat(outputs, 1)
|
330 |
+
|
331 |
+
|
332 |
+
class FIDInceptionE_2(models.inception.InceptionE):
|
333 |
+
"""Second InceptionE block patched for FID computation"""
|
334 |
+
def __init__(self, in_channels):
|
335 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
336 |
+
|
337 |
+
def forward(self, x):
|
338 |
+
branch1x1 = self.branch1x1(x)
|
339 |
+
|
340 |
+
branch3x3 = self.branch3x3_1(x)
|
341 |
+
branch3x3 = [
|
342 |
+
self.branch3x3_2a(branch3x3),
|
343 |
+
self.branch3x3_2b(branch3x3),
|
344 |
+
]
|
345 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
346 |
+
|
347 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
348 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
349 |
+
branch3x3dbl = [
|
350 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
351 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
352 |
+
]
|
353 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
354 |
+
|
355 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
356 |
+
# pooling. This is likely an error in this specific Inception
|
357 |
+
# implementation, as other Inception models use average pooling here
|
358 |
+
# (which matches the description in the paper).
|
359 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
360 |
+
branch_pool = self.branch_pool(branch_pool)
|
361 |
+
|
362 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
363 |
+
return torch.cat(outputs, 1)
|
util/models/model.py
ADDED
@@ -0,0 +1,1389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import pandas as pd
|
3 |
+
from .OCR_network import *
|
4 |
+
from torch.nn import CTCLoss, MSELoss, L1Loss
|
5 |
+
from torch.nn.utils import clip_grad_norm_
|
6 |
+
import random
|
7 |
+
import unicodedata
|
8 |
+
import sys
|
9 |
+
import torchvision.models as models
|
10 |
+
from models.transformer import *
|
11 |
+
from .BigGAN_networks import *
|
12 |
+
from params import *
|
13 |
+
from .OCR_network import *
|
14 |
+
from models.blocks import LinearBlock, Conv2dBlock, ResBlocks, ActFirstResBlock
|
15 |
+
from util.util import toggle_grad, loss_hinge_dis, loss_hinge_gen, ortho, default_ortho, toggle_grad, prepare_z_y, \
|
16 |
+
make_one_hot, to_device, multiple_replace, random_word
|
17 |
+
from models.inception import InceptionV3, calculate_frechet_distance
|
18 |
+
import cv2
|
19 |
+
|
20 |
+
class FCNDecoder(nn.Module):
|
21 |
+
def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'):
|
22 |
+
super(FCNDecoder, self).__init__()
|
23 |
+
|
24 |
+
self.model = []
|
25 |
+
self.model += [ResBlocks(n_res, dim, res_norm,
|
26 |
+
activ, pad_type=pad_type)]
|
27 |
+
for i in range(ups):
|
28 |
+
self.model += [nn.Upsample(scale_factor=2),
|
29 |
+
Conv2dBlock(dim, dim // 2, 5, 1, 2,
|
30 |
+
norm='in',
|
31 |
+
activation=activ,
|
32 |
+
pad_type=pad_type)]
|
33 |
+
dim //= 2
|
34 |
+
self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3,
|
35 |
+
norm='none',
|
36 |
+
activation='tanh',
|
37 |
+
pad_type=pad_type)]
|
38 |
+
self.model = nn.Sequential(*self.model)
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
y = self.model(x)
|
42 |
+
|
43 |
+
return y
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
class Generator(nn.Module):
|
48 |
+
|
49 |
+
def __init__(self):
|
50 |
+
super(Generator, self).__init__()
|
51 |
+
|
52 |
+
INP_CHANNEL = NUM_EXAMPLES
|
53 |
+
if IS_SEQ: INP_CHANNEL = 1
|
54 |
+
|
55 |
+
|
56 |
+
encoder_layer = TransformerEncoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD,
|
57 |
+
TN_DROPOUT, "relu", True)
|
58 |
+
encoder_norm = nn.LayerNorm(TN_HIDDEN_DIM) if True else None
|
59 |
+
self.encoder = TransformerEncoder(encoder_layer, TN_ENC_LAYERS, encoder_norm)
|
60 |
+
|
61 |
+
decoder_layer = TransformerDecoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD,
|
62 |
+
TN_DROPOUT, "relu", True)
|
63 |
+
decoder_norm = nn.LayerNorm(TN_HIDDEN_DIM)
|
64 |
+
self.decoder = TransformerDecoder(decoder_layer, TN_DEC_LAYERS, decoder_norm,
|
65 |
+
return_intermediate=True)
|
66 |
+
|
67 |
+
self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(INP_CHANNEL, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2]))
|
68 |
+
|
69 |
+
self.query_embed = nn.Embedding(VOCAB_SIZE, TN_HIDDEN_DIM)
|
70 |
+
|
71 |
+
|
72 |
+
self.linear_q = nn.Linear(TN_DIM_FEEDFORWARD*2, TN_DIM_FEEDFORWARD*8)
|
73 |
+
|
74 |
+
self.DEC = FCNDecoder(res_norm = 'in')
|
75 |
+
|
76 |
+
|
77 |
+
self._muE = nn.Linear(512,512)
|
78 |
+
self._logvarE = nn.Linear(512,512)
|
79 |
+
|
80 |
+
self._muD = nn.Linear(512,512)
|
81 |
+
self._logvarD = nn.Linear(512,512)
|
82 |
+
|
83 |
+
|
84 |
+
self.l1loss = nn.L1Loss()
|
85 |
+
|
86 |
+
self.noise = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0]))
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
def reparameterize(self, mu, logvar):
|
94 |
+
|
95 |
+
mu = torch.unbind(mu , 1)
|
96 |
+
logvar = torch.unbind(logvar , 1)
|
97 |
+
|
98 |
+
outs = []
|
99 |
+
|
100 |
+
for m,l in zip(mu, logvar):
|
101 |
+
|
102 |
+
sigma = torch.exp(l)
|
103 |
+
eps = torch.cuda.FloatTensor(l.size()[0],1).normal_(0,1)
|
104 |
+
eps = eps.expand(sigma.size())
|
105 |
+
|
106 |
+
out = m + sigma*eps
|
107 |
+
|
108 |
+
outs.append(out)
|
109 |
+
|
110 |
+
|
111 |
+
return torch.stack(outs, 1)
|
112 |
+
|
113 |
+
|
114 |
+
def Eval(self, ST, QRS):
|
115 |
+
|
116 |
+
if IS_SEQ:
|
117 |
+
B, N, R, C = ST.shape
|
118 |
+
FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C))
|
119 |
+
FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
|
120 |
+
else:
|
121 |
+
FEAT_ST = self.Feat_Encoder(ST)
|
122 |
+
|
123 |
+
|
124 |
+
FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1)
|
125 |
+
|
126 |
+
memory = self.encoder(FEAT_ST_ENC)
|
127 |
+
|
128 |
+
if IS_KLD:
|
129 |
+
|
130 |
+
Ex = memory.permute(1,0,2)
|
131 |
+
|
132 |
+
memory_mu = self._muE(Ex)
|
133 |
+
memory_logvar = self._logvarE(Ex)
|
134 |
+
|
135 |
+
memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2)
|
136 |
+
|
137 |
+
|
138 |
+
OUT_IMGS = []
|
139 |
+
|
140 |
+
for i in range(QRS.shape[1]):
|
141 |
+
|
142 |
+
QR = QRS[:, i, :]
|
143 |
+
|
144 |
+
QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
|
145 |
+
|
146 |
+
tgt = torch.zeros_like(QR_EMB)
|
147 |
+
|
148 |
+
hs = self.decoder(tgt, memory, query_pos=QR_EMB)
|
149 |
+
|
150 |
+
if IS_KLD:
|
151 |
+
|
152 |
+
Dx = hs[0].permute(1,0,2)
|
153 |
+
|
154 |
+
hs_mu = self._muD(Dx)
|
155 |
+
hs_logvar = self._logvarD(Dx)
|
156 |
+
|
157 |
+
hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
|
158 |
+
|
159 |
+
|
160 |
+
h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
|
161 |
+
if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
|
162 |
+
|
163 |
+
h = self.linear_q(h)
|
164 |
+
h = h.contiguous()
|
165 |
+
|
166 |
+
h = h.view(h.size(0), h.shape[1]*2, 4, -1)
|
167 |
+
h = h.permute(0, 3, 2, 1)
|
168 |
+
|
169 |
+
h = self.DEC(h)
|
170 |
+
|
171 |
+
|
172 |
+
OUT_IMGS.append(h.detach())
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
return OUT_IMGS
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
def forward(self, ST, QR, QRs = None, mode = 'train'):
|
184 |
+
|
185 |
+
#Attention Visualization Init
|
186 |
+
|
187 |
+
|
188 |
+
enc_attn_weights, dec_attn_weights = [], []
|
189 |
+
|
190 |
+
self.hooks = [
|
191 |
+
|
192 |
+
self.encoder.layers[-1].self_attn.register_forward_hook(
|
193 |
+
lambda self, input, output: enc_attn_weights.append(output[1])
|
194 |
+
),
|
195 |
+
self.decoder.layers[-1].multihead_attn.register_forward_hook(
|
196 |
+
lambda self, input, output: dec_attn_weights.append(output[1])
|
197 |
+
),
|
198 |
+
]
|
199 |
+
|
200 |
+
|
201 |
+
#Attention Visualization Init
|
202 |
+
|
203 |
+
if IS_SEQ:
|
204 |
+
B, N, R, C = ST.shape
|
205 |
+
FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C))
|
206 |
+
FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
|
207 |
+
else:
|
208 |
+
FEAT_ST = self.Feat_Encoder(ST)
|
209 |
+
|
210 |
+
|
211 |
+
FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1)
|
212 |
+
|
213 |
+
memory = self.encoder(FEAT_ST_ENC)
|
214 |
+
|
215 |
+
if IS_KLD:
|
216 |
+
|
217 |
+
Ex = memory.permute(1,0,2)
|
218 |
+
|
219 |
+
memory_mu = self._muE(Ex)
|
220 |
+
memory_logvar = self._logvarE(Ex)
|
221 |
+
|
222 |
+
memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2)
|
223 |
+
|
224 |
+
|
225 |
+
QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
|
226 |
+
|
227 |
+
tgt = torch.zeros_like(QR_EMB)
|
228 |
+
|
229 |
+
hs = self.decoder(tgt, memory, query_pos=QR_EMB)
|
230 |
+
|
231 |
+
if IS_KLD:
|
232 |
+
|
233 |
+
Dx = hs[0].permute(1,0,2)
|
234 |
+
|
235 |
+
hs_mu = self._muD(Dx)
|
236 |
+
hs_logvar = self._logvarD(Dx)
|
237 |
+
|
238 |
+
hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
|
239 |
+
|
240 |
+
OUT_Feats1_mu = [hs_mu]
|
241 |
+
OUT_Feats1_logvar = [hs_logvar]
|
242 |
+
|
243 |
+
|
244 |
+
OUT_Feats1 = [hs]
|
245 |
+
|
246 |
+
|
247 |
+
h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
|
248 |
+
|
249 |
+
if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
|
250 |
+
|
251 |
+
h = self.linear_q(h)
|
252 |
+
h = h.contiguous()
|
253 |
+
|
254 |
+
h = h.view(h.size(0), h.shape[1]*2, 4, -1)
|
255 |
+
h = h.permute(0, 3, 2, 1)
|
256 |
+
|
257 |
+
h = self.DEC(h)
|
258 |
+
|
259 |
+
self.dec_attn_weights = dec_attn_weights[-1].detach()
|
260 |
+
self.enc_attn_weights = enc_attn_weights[-1].detach()
|
261 |
+
|
262 |
+
|
263 |
+
|
264 |
+
for hook in self.hooks:
|
265 |
+
hook.remove()
|
266 |
+
|
267 |
+
if mode == 'test' or (not IS_CYCLE and not IS_KLD):
|
268 |
+
|
269 |
+
return h
|
270 |
+
|
271 |
+
|
272 |
+
OUT_IMGS = [h]
|
273 |
+
|
274 |
+
for QR in QRs:
|
275 |
+
|
276 |
+
QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
|
277 |
+
|
278 |
+
tgt = torch.zeros_like(QR_EMB)
|
279 |
+
|
280 |
+
hs = self.decoder(tgt, memory, query_pos=QR_EMB)
|
281 |
+
|
282 |
+
|
283 |
+
if IS_KLD:
|
284 |
+
|
285 |
+
Dx = hs[0].permute(1,0,2)
|
286 |
+
|
287 |
+
hs_mu = self._muD(Dx)
|
288 |
+
hs_logvar = self._logvarD(Dx)
|
289 |
+
|
290 |
+
hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
|
291 |
+
|
292 |
+
OUT_Feats1_mu.append(hs_mu)
|
293 |
+
OUT_Feats1_logvar.append(hs_logvar)
|
294 |
+
|
295 |
+
|
296 |
+
OUT_Feats1.append(hs)
|
297 |
+
|
298 |
+
|
299 |
+
h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
|
300 |
+
if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
|
301 |
+
|
302 |
+
h = self.linear_q(h)
|
303 |
+
h = h.contiguous()
|
304 |
+
|
305 |
+
h = h.view(h.size(0), h.shape[1]*2, 4, -1)
|
306 |
+
h = h.permute(0, 3, 2, 1)
|
307 |
+
|
308 |
+
h = self.DEC(h)
|
309 |
+
|
310 |
+
OUT_IMGS.append(h)
|
311 |
+
|
312 |
+
|
313 |
+
if (not IS_CYCLE) and IS_KLD:
|
314 |
+
|
315 |
+
OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0]
|
316 |
+
|
317 |
+
OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1);
|
318 |
+
|
319 |
+
|
320 |
+
KLD = (0.5 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \
|
321 |
+
+ (0.5 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp()))
|
322 |
+
|
323 |
+
|
324 |
+
|
325 |
+
def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar):
|
326 |
+
return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \
|
327 |
+
torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum()
|
328 |
+
|
329 |
+
|
330 |
+
lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])]
|
331 |
+
|
332 |
+
|
333 |
+
lda1 = torch.stack(lda1).mean()
|
334 |
+
|
335 |
+
|
336 |
+
|
337 |
+
return OUT_IMGS[0], lda1, KLD
|
338 |
+
|
339 |
+
|
340 |
+
with torch.no_grad():
|
341 |
+
|
342 |
+
if IS_SEQ:
|
343 |
+
|
344 |
+
FEAT_ST_T = torch.cat([self.Feat_Encoder(IM) for IM in OUT_IMGS], -1)
|
345 |
+
|
346 |
+
else:
|
347 |
+
|
348 |
+
max_width_ = max([i_.shape[-1] for i_ in OUT_IMGS])
|
349 |
+
|
350 |
+
FEAT_ST_T = self.Feat_Encoder(torch.cat([torch.cat([i_, torch.ones((i_.shape[0], i_.shape[1],i_.shape[2], max_width_-i_.shape[3])).to(DEVICE)], -1) for i_ in OUT_IMGS], 1))
|
351 |
+
|
352 |
+
FEAT_ST_ENC_T = FEAT_ST_T.flatten(2).permute(2,0,1)
|
353 |
+
|
354 |
+
memory_T = self.encoder(FEAT_ST_ENC_T)
|
355 |
+
|
356 |
+
if IS_KLD:
|
357 |
+
|
358 |
+
Ex = memory_T.permute(1,0,2)
|
359 |
+
|
360 |
+
memory_T_mu = self._muE(Ex)
|
361 |
+
memory_T_logvar = self._logvarE(Ex)
|
362 |
+
|
363 |
+
memory_T = self.reparameterize(memory_T_mu, memory_T_logvar).permute(1,0,2)
|
364 |
+
|
365 |
+
|
366 |
+
QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
|
367 |
+
|
368 |
+
tgt = torch.zeros_like(QR_EMB)
|
369 |
+
|
370 |
+
hs = self.decoder(tgt, memory_T, query_pos=QR_EMB)
|
371 |
+
|
372 |
+
if IS_KLD:
|
373 |
+
|
374 |
+
Dx = hs[0].permute(1,0,2)
|
375 |
+
|
376 |
+
hs_mu = self._muD(Dx)
|
377 |
+
hs_logvar = self._logvarD(Dx)
|
378 |
+
|
379 |
+
hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
|
380 |
+
|
381 |
+
OUT_Feats2_mu = [hs_mu]
|
382 |
+
OUT_Feats2_logvar = [hs_logvar]
|
383 |
+
|
384 |
+
|
385 |
+
OUT_Feats2 = [hs]
|
386 |
+
|
387 |
+
|
388 |
+
|
389 |
+
for QR in QRs:
|
390 |
+
|
391 |
+
QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
|
392 |
+
|
393 |
+
tgt = torch.zeros_like(QR_EMB)
|
394 |
+
|
395 |
+
hs = self.decoder(tgt, memory_T, query_pos=QR_EMB)
|
396 |
+
|
397 |
+
if IS_KLD:
|
398 |
+
|
399 |
+
Dx = hs[0].permute(1,0,2)
|
400 |
+
|
401 |
+
hs_mu = self._muD(Dx)
|
402 |
+
hs_logvar = self._logvarD(Dx)
|
403 |
+
|
404 |
+
hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
|
405 |
+
|
406 |
+
OUT_Feats2_mu.append(hs_mu)
|
407 |
+
OUT_Feats2_logvar.append(hs_logvar)
|
408 |
+
|
409 |
+
|
410 |
+
OUT_Feats2.append(hs)
|
411 |
+
|
412 |
+
|
413 |
+
|
414 |
+
|
415 |
+
Lcycle1 = np.sum([self.l1loss(memory[m_i], memory_T[m_j]) for m_i in range(memory.shape[0]) for m_j in range(memory_T.shape[0])])/(memory.shape[0]*memory_T.shape[0])
|
416 |
+
OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0]; OUT_Feats2 = torch.cat(OUT_Feats2, 1)[0]
|
417 |
+
|
418 |
+
Lcycle2 = np.sum([self.l1loss(OUT_Feats1[f_i], OUT_Feats2[f_j]) for f_i in range(OUT_Feats1.shape[0]) for f_j in range(OUT_Feats2.shape[0])])/(OUT_Feats1.shape[0]*OUT_Feats2.shape[0])
|
419 |
+
|
420 |
+
if IS_KLD:
|
421 |
+
|
422 |
+
OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1);
|
423 |
+
OUT_Feats2_mu = torch.cat(OUT_Feats2_mu, 1); OUT_Feats2_logvar = torch.cat(OUT_Feats2_logvar, 1);
|
424 |
+
|
425 |
+
KLD = (0.25 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \
|
426 |
+
+ (0.25 * torch.mean(1 + memory_T_logvar - memory_T_mu.pow(2) - memory_T_logvar.exp()))\
|
427 |
+
+ (0.25 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp()))\
|
428 |
+
+ (0.25 * torch.mean(1 + OUT_Feats2_logvar - OUT_Feats2_mu.pow(2) - OUT_Feats2_logvar.exp()))
|
429 |
+
|
430 |
+
|
431 |
+
def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar):
|
432 |
+
return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \
|
433 |
+
torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum()
|
434 |
+
|
435 |
+
|
436 |
+
lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])]
|
437 |
+
lda2 = [_get_lda(memory_T_mu[:,idi,:], OUT_Feats2_mu[:,idj,:], memory_T_logvar[:,idi,:], OUT_Feats2_logvar[:,idj,:]) for idi in range(memory_T.shape[0]) for idj in range(OUT_Feats2.shape[0])]
|
438 |
+
|
439 |
+
lda1 = torch.stack(lda1).mean()
|
440 |
+
lda2 = torch.stack(lda2).mean()
|
441 |
+
|
442 |
+
|
443 |
+
return OUT_IMGS[0], Lcycle1, Lcycle2, lda1, lda2, KLD
|
444 |
+
|
445 |
+
|
446 |
+
return OUT_IMGS[0], Lcycle1, Lcycle2
|
447 |
+
|
448 |
+
|
449 |
+
|
450 |
+
class TRGAN(nn.Module):
|
451 |
+
|
452 |
+
def __init__(self):
|
453 |
+
super(TRGAN, self).__init__()
|
454 |
+
|
455 |
+
|
456 |
+
self.epsilon = 1e-7
|
457 |
+
self.netG = Generator().to(DEVICE)
|
458 |
+
self.netD = nn.DataParallel(Discriminator()).to(DEVICE)
|
459 |
+
self.netW = nn.DataParallel(WDiscriminator()).to(DEVICE)
|
460 |
+
self.netconverter = strLabelConverter(ALPHABET)
|
461 |
+
self.netOCR = CRNN().to(DEVICE)
|
462 |
+
self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none')
|
463 |
+
|
464 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
|
465 |
+
self.inception = InceptionV3([block_idx]).to(DEVICE)
|
466 |
+
|
467 |
+
|
468 |
+
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
|
469 |
+
lr=G_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
|
470 |
+
self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(),
|
471 |
+
lr=OCR_LR, betas=(0.0, 0.999), weight_decay=0,
|
472 |
+
eps=1e-8)
|
473 |
+
|
474 |
+
self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
|
475 |
+
lr=D_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
|
476 |
+
|
477 |
+
|
478 |
+
self.optimizer_wl = torch.optim.Adam(self.netW.parameters(),
|
479 |
+
lr=W_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
|
480 |
+
|
481 |
+
|
482 |
+
self.optimizers = [self.optimizer_G, self.optimizer_OCR, self.optimizer_D, self.optimizer_wl]
|
483 |
+
|
484 |
+
|
485 |
+
self.optimizer_G.zero_grad()
|
486 |
+
self.optimizer_OCR.zero_grad()
|
487 |
+
self.optimizer_D.zero_grad()
|
488 |
+
self.optimizer_wl.zero_grad()
|
489 |
+
|
490 |
+
self.loss_G = 0
|
491 |
+
self.loss_D = 0
|
492 |
+
self.loss_Dfake = 0
|
493 |
+
self.loss_Dreal = 0
|
494 |
+
self.loss_OCR_fake = 0
|
495 |
+
self.loss_OCR_real = 0
|
496 |
+
self.loss_w_fake = 0
|
497 |
+
self.loss_w_real = 0
|
498 |
+
self.Lcycle1 = 0
|
499 |
+
self.Lcycle2 = 0
|
500 |
+
self.lda1 = 0
|
501 |
+
self.lda2 = 0
|
502 |
+
self.KLD = 0
|
503 |
+
|
504 |
+
|
505 |
+
with open('../Lexicon/english_words.txt', 'rb') as f:
|
506 |
+
self.lex = f.read().splitlines()
|
507 |
+
lex=[]
|
508 |
+
for word in self.lex:
|
509 |
+
try:
|
510 |
+
word=word.decode("utf-8")
|
511 |
+
except:
|
512 |
+
continue
|
513 |
+
if len(word)<20:
|
514 |
+
lex.append(word)
|
515 |
+
self.lex = lex
|
516 |
+
|
517 |
+
|
518 |
+
f = open('mytext.txt', 'r')
|
519 |
+
|
520 |
+
self.text = [j.encode() for j in sum([i.split(' ') for i in f.readlines()], [])][:NUM_EXAMPLES]
|
521 |
+
self.eval_text_encode, self.eval_len_text = self.netconverter.encode(self.text)
|
522 |
+
self.eval_text_encode = self.eval_text_encode.to(DEVICE).repeat(batch_size, 1, 1)
|
523 |
+
|
524 |
+
|
525 |
+
def _generate_page(self):
|
526 |
+
|
527 |
+
self.fakes = self.netG.Eval(self.sdata, self.eval_text_encode)
|
528 |
+
|
529 |
+
word_t = []
|
530 |
+
word_l = []
|
531 |
+
|
532 |
+
gap = np.ones([32,16])
|
533 |
+
|
534 |
+
line_wids = []
|
535 |
+
|
536 |
+
|
537 |
+
for idx, fake_ in enumerate(self.fakes):
|
538 |
+
|
539 |
+
word_t.append((fake_[0,0,:,:self.eval_len_text[idx]*resolution].cpu().numpy()+1)/2)
|
540 |
+
|
541 |
+
word_t.append(gap)
|
542 |
+
|
543 |
+
if len(word_t) == 16 or idx == len(self.fakes) - 1:
|
544 |
+
|
545 |
+
line_ = np.concatenate(word_t, -1)
|
546 |
+
|
547 |
+
word_l.append(line_)
|
548 |
+
line_wids.append(line_.shape[1])
|
549 |
+
|
550 |
+
word_t = []
|
551 |
+
|
552 |
+
|
553 |
+
gap_h = np.ones([16,max(line_wids)])
|
554 |
+
|
555 |
+
page_= []
|
556 |
+
|
557 |
+
for l in word_l:
|
558 |
+
|
559 |
+
pad_ = np.ones([32,max(line_wids) - l.shape[1]])
|
560 |
+
|
561 |
+
page_.append(np.concatenate([l, pad_], 1))
|
562 |
+
page_.append(gap_h)
|
563 |
+
|
564 |
+
|
565 |
+
|
566 |
+
page1 = np.concatenate(page_, 0)
|
567 |
+
|
568 |
+
|
569 |
+
word_t = []
|
570 |
+
word_l = []
|
571 |
+
|
572 |
+
gap = np.ones([32,16])
|
573 |
+
|
574 |
+
line_wids = []
|
575 |
+
|
576 |
+
sdata_ = [i.unsqueeze(1) for i in torch.unbind(self.sdata, 1)]
|
577 |
+
|
578 |
+
for idx, st in enumerate((sdata_)):
|
579 |
+
|
580 |
+
word_t.append((st[0,0,:,:int(self.input['swids'].cpu().numpy()[0][idx])
|
581 |
+
].cpu().numpy()+1)/2)
|
582 |
+
|
583 |
+
word_t.append(gap)
|
584 |
+
|
585 |
+
if len(word_t) == 16 or idx == len(self.fakes) - 1:
|
586 |
+
|
587 |
+
line_ = np.concatenate(word_t, -1)
|
588 |
+
|
589 |
+
word_l.append(line_)
|
590 |
+
line_wids.append(line_.shape[1])
|
591 |
+
|
592 |
+
word_t = []
|
593 |
+
|
594 |
+
|
595 |
+
gap_h = np.ones([16,max(line_wids)])
|
596 |
+
|
597 |
+
page_= []
|
598 |
+
|
599 |
+
for l in word_l:
|
600 |
+
|
601 |
+
pad_ = np.ones([32,max(line_wids) - l.shape[1]])
|
602 |
+
|
603 |
+
page_.append(np.concatenate([l, pad_], 1))
|
604 |
+
page_.append(gap_h)
|
605 |
+
|
606 |
+
|
607 |
+
|
608 |
+
page2 = np.concatenate(page_, 0)
|
609 |
+
|
610 |
+
merge_w_size = max(page1.shape[0], page2.shape[0])
|
611 |
+
|
612 |
+
if page1.shape[0] != merge_w_size:
|
613 |
+
|
614 |
+
page1 = np.concatenate([page1, np.ones([merge_w_size-page1.shape[0], page1.shape[1]])], 0)
|
615 |
+
|
616 |
+
if page2.shape[0] != merge_w_size:
|
617 |
+
|
618 |
+
page2 = np.concatenate([page2, np.ones([merge_w_size-page2.shape[0], page2.shape[1]])], 0)
|
619 |
+
|
620 |
+
|
621 |
+
page = np.concatenate([page2, page1], 1)
|
622 |
+
|
623 |
+
|
624 |
+
return page
|
625 |
+
|
626 |
+
|
627 |
+
|
628 |
+
|
629 |
+
|
630 |
+
|
631 |
+
|
632 |
+
|
633 |
+
|
634 |
+
|
635 |
+
|
636 |
+
|
637 |
+
|
638 |
+
|
639 |
+
|
640 |
+
|
641 |
+
|
642 |
+
|
643 |
+
|
644 |
+
#FEAT1 = self.inception(torch.cat(self.fakes, 0).repeat(1,3,1,1))[0].detach().view(batch_size, len(self.fakes), -1).cpu().numpy()
|
645 |
+
#FEAT2 = self.inception(self.sdata.view(batch_size*NUM_EXAMPLES, 1, 32, -1).repeat(1,3,1,1))[0].detach().view(batch_size, NUM_EXAMPLES, -1 ).cpu().numpy()
|
646 |
+
#muvars1 = [{'mu':np.mean(FEAT1[i], axis=0), 'sigma' : np.cov(FEAT1[i], rowvar=False)} for i in range(FEAT1.shape[0])]
|
647 |
+
#muvars2 = [{'mu':np.mean(FEAT2[i], axis=0), 'sigma' : np.cov(FEAT2[i], rowvar=False)} for i in range(FEAT2.shape[0])]
|
648 |
+
|
649 |
+
|
650 |
+
|
651 |
+
|
652 |
+
|
653 |
+
|
654 |
+
def get_current_losses(self):
|
655 |
+
|
656 |
+
losses = {}
|
657 |
+
|
658 |
+
losses['G'] = self.loss_G
|
659 |
+
losses['D'] = self.loss_D
|
660 |
+
losses['Dfake'] = self.loss_Dfake
|
661 |
+
losses['Dreal'] = self.loss_Dreal
|
662 |
+
losses['OCR_fake'] = self.loss_OCR_fake
|
663 |
+
losses['OCR_real'] = self.loss_OCR_real
|
664 |
+
losses['w_fake'] = self.loss_w_fake
|
665 |
+
losses['w_real'] = self.loss_w_real
|
666 |
+
losses['cycle1'] = self.Lcycle1
|
667 |
+
losses['cycle2'] = self.Lcycle2
|
668 |
+
losses['lda1'] = self.lda1
|
669 |
+
losses['lda2'] = self.lda2
|
670 |
+
losses['KLD'] = self.KLD
|
671 |
+
|
672 |
+
return losses
|
673 |
+
|
674 |
+
def visualize_images(self):
|
675 |
+
|
676 |
+
imgs = {}
|
677 |
+
|
678 |
+
|
679 |
+
imgs['fake-1']=self.netG(self.sdata[0:1], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach()
|
680 |
+
imgs['fake-2']=self.netG(self.sdata[0:1], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
681 |
+
imgs['fake-3']=self.netG(self.sdata[0:1], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
682 |
+
|
683 |
+
|
684 |
+
imgs['res-1'] = torch.cat([self.sdata[0, 0],self.sdata[0, 1],self.sdata[0, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
|
685 |
+
|
686 |
+
|
687 |
+
imgs['fake-1']=self.netG(self.sdata[1:2], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach()
|
688 |
+
imgs['fake-2']=self.netG(self.sdata[1:2], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
689 |
+
imgs['fake-3']=self.netG(self.sdata[1:2], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
690 |
+
|
691 |
+
|
692 |
+
imgs['res-2'] = torch.cat([self.sdata[1, 0],self.sdata[1, 1],self.sdata[1, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
|
693 |
+
|
694 |
+
|
695 |
+
imgs['fake-1']=self.netG(self.sdata[2:3], self.text_encode_fake[0].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
696 |
+
imgs['fake-2']=self.netG(self.sdata[2:3], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
697 |
+
imgs['fake-3']=self.netG(self.sdata[2:3], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
698 |
+
|
699 |
+
|
700 |
+
imgs['res-3'] = torch.cat([self.sdata[2, 0],self.sdata[2, 1],self.sdata[2, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
|
701 |
+
|
702 |
+
|
703 |
+
|
704 |
+
|
705 |
+
return imgs
|
706 |
+
|
707 |
+
|
708 |
+
def load_networks(self, epoch):
|
709 |
+
BaseModel.load_networks(self, epoch)
|
710 |
+
if self.opt.single_writer:
|
711 |
+
load_filename = '%s_z.pkl' % (epoch)
|
712 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
713 |
+
self.z = torch.load(load_path)
|
714 |
+
|
715 |
+
def _set_input(self, input):
|
716 |
+
self.input = input
|
717 |
+
|
718 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
719 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
720 |
+
Parameters:
|
721 |
+
nets (network list) -- a list of networks
|
722 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
723 |
+
"""
|
724 |
+
if not isinstance(nets, list):
|
725 |
+
nets = [nets]
|
726 |
+
for net in nets:
|
727 |
+
if net is not None:
|
728 |
+
for param in net.parameters():
|
729 |
+
param.requires_grad = requires_grad
|
730 |
+
|
731 |
+
def forward(self):
|
732 |
+
|
733 |
+
|
734 |
+
self.real = self.input['img'].to(DEVICE)
|
735 |
+
self.label = self.input['label']
|
736 |
+
self.sdata = self.input['simg'].to(DEVICE)
|
737 |
+
self.ST_LEN = self.input['swids']
|
738 |
+
self.text_encode, self.len_text = self.netconverter.encode(self.label)
|
739 |
+
self.one_hot_real = make_one_hot(self.text_encode, self.len_text, VOCAB_SIZE).to(DEVICE).detach()
|
740 |
+
self.text_encode = self.text_encode.to(DEVICE).detach()
|
741 |
+
self.len_text = self.len_text.detach()
|
742 |
+
|
743 |
+
self.words = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)]
|
744 |
+
self.text_encode_fake, self.len_text_fake = self.netconverter.encode(self.words)
|
745 |
+
self.text_encode_fake = self.text_encode_fake.to(DEVICE)
|
746 |
+
self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, VOCAB_SIZE).to(DEVICE)
|
747 |
+
|
748 |
+
|
749 |
+
self.text_encode_fake_js = []
|
750 |
+
|
751 |
+
for _ in range(NUM_WORDS - 1):
|
752 |
+
|
753 |
+
self.words_j = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)]
|
754 |
+
self.text_encode_fake_j, self.len_text_fake_j = self.netconverter.encode(self.words_j)
|
755 |
+
self.text_encode_fake_j = self.text_encode_fake_j.to(DEVICE)
|
756 |
+
self.text_encode_fake_js.append(self.text_encode_fake_j)
|
757 |
+
|
758 |
+
|
759 |
+
if IS_CYCLE and IS_KLD:
|
760 |
+
|
761 |
+
self.fake, self.Lcycle1, self.Lcycle2, self.lda1, self.lda2, self.KLD = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
|
762 |
+
|
763 |
+
elif IS_CYCLE and (not IS_KLD):
|
764 |
+
|
765 |
+
self.fake, self.Lcycle1, self.Lcycle2 = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
|
766 |
+
|
767 |
+
elif (not IS_CYCLE) and IS_KLD:
|
768 |
+
|
769 |
+
self.fake, self.lda1, self.KLD = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
|
770 |
+
|
771 |
+
else:
|
772 |
+
|
773 |
+
self.fake = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
|
774 |
+
|
775 |
+
|
776 |
+
|
777 |
+
def visualize_attention(self):
|
778 |
+
|
779 |
+
def _norm_scores(arr):
|
780 |
+
return (arr - min(arr))/(max(arr) - min(arr))
|
781 |
+
|
782 |
+
simgs = self.sdata[0].detach().cpu().numpy()
|
783 |
+
fake = self.fake[0,0].detach().cpu().numpy()
|
784 |
+
slen = self.ST_LEN[0].detach().cpu().numpy()
|
785 |
+
selfatt = self.netG.enc_attn_weights[0].detach().cpu().numpy()
|
786 |
+
selfatt = np.stack([_norm_scores(i) for i in selfatt], 1)
|
787 |
+
fake_lab = self.words[0].decode()
|
788 |
+
|
789 |
+
decatt = self.netG.dec_attn_weights[0].detach().cpu().numpy()
|
790 |
+
decatt = np.stack([_norm_scores(i) for i in decatt], 0)
|
791 |
+
|
792 |
+
STdict = {}
|
793 |
+
FAKEdict = {}
|
794 |
+
count = 0
|
795 |
+
|
796 |
+
for sim_, sle_ in zip(simgs,slen):
|
797 |
+
|
798 |
+
for pi in range(sim_.shape[1]//sim_.shape[0]):
|
799 |
+
|
800 |
+
STdict[count] = {'patch':sim_[:, pi*32:(pi+1)*32], 'ischar': sle_>=pi*32, 'encoder_attention_score': selfatt[count], 'decoder_attention_score': decatt[:,count]}
|
801 |
+
count = count + 1
|
802 |
+
|
803 |
+
|
804 |
+
for pi in range(fake.shape[1]//resolution):
|
805 |
+
|
806 |
+
FAKEdict[pi] = {'patch': fake[:, pi*resolution:(pi+1)*resolution]}
|
807 |
+
|
808 |
+
show_ims = []
|
809 |
+
|
810 |
+
for idx in range(len(fake_lab)):
|
811 |
+
|
812 |
+
viz_pats = []
|
813 |
+
viz_lin = []
|
814 |
+
|
815 |
+
for i in STdict.keys():
|
816 |
+
|
817 |
+
if STdict[i]['ischar']:
|
818 |
+
|
819 |
+
viz_pats.append(cv2.addWeighted(STdict[i]['patch'], 0.5, np.ones_like(STdict[i]['patch'])*STdict[i]['decoder_attention_score'][idx], 0.5, 0))
|
820 |
+
|
821 |
+
if len(viz_pats) >= 20:
|
822 |
+
|
823 |
+
viz_lin.append(np.concatenate(viz_pats, -1))
|
824 |
+
|
825 |
+
viz_pats = []
|
826 |
+
|
827 |
+
|
828 |
+
|
829 |
+
|
830 |
+
src = np.concatenate(viz_lin[:-2], 0)*255
|
831 |
+
|
832 |
+
viz_gts = []
|
833 |
+
|
834 |
+
for i in range(len(fake_lab)):
|
835 |
+
|
836 |
+
|
837 |
+
|
838 |
+
#if i == idx:
|
839 |
+
|
840 |
+
#bordersize = 5
|
841 |
+
|
842 |
+
#FAKEdict[i]['patch'] = cv2.addWeighted(FAKEdict[i]['patch'] , 0.5, np.ones_like(FAKEdict[i]['patch'] ), 0.5, 0)
|
843 |
+
|
844 |
+
|
845 |
+
|
846 |
+
|
847 |
+
|
848 |
+
|
849 |
+
img = np.zeros((54,16))
|
850 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
851 |
+
text = fake_lab[i]
|
852 |
+
|
853 |
+
# get boundary of this text
|
854 |
+
textsize = cv2.getTextSize(text, font, 1, 2)[0]
|
855 |
+
|
856 |
+
# get coords based on boundary
|
857 |
+
textX = (img.shape[1] - textsize[0]) // 2
|
858 |
+
textY = (img.shape[0] + textsize[1]) // 2
|
859 |
+
|
860 |
+
# add text centered on image
|
861 |
+
cv2.putText(img, text, (textX, textY ), font, 1, (255, 255, 255), 2)
|
862 |
+
|
863 |
+
img = (255 - img)/255
|
864 |
+
|
865 |
+
if i == idx:
|
866 |
+
|
867 |
+
img = (1 - img)
|
868 |
+
|
869 |
+
viz_gts.append(img)
|
870 |
+
|
871 |
+
|
872 |
+
|
873 |
+
tgt = np.concatenate([fake[:,:len(fake_lab)*16],np.concatenate(viz_gts, -1)], 0)
|
874 |
+
pad_ = np.ones((tgt.shape[0], (src.shape[1]-tgt.shape[1])//2))
|
875 |
+
tgt = np.concatenate([pad_, tgt, pad_], -1)*255
|
876 |
+
final = np.concatenate([src, tgt], 0)
|
877 |
+
|
878 |
+
|
879 |
+
show_ims.append(final)
|
880 |
+
|
881 |
+
return show_ims
|
882 |
+
|
883 |
+
|
884 |
+
def backward_D_OCR(self):
|
885 |
+
|
886 |
+
pred_real = self.netD(self.real.detach())
|
887 |
+
|
888 |
+
pred_fake = self.netD(**{'x': self.fake.detach()})
|
889 |
+
|
890 |
+
|
891 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True)
|
892 |
+
|
893 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
894 |
+
|
895 |
+
self.pred_real_OCR = self.netOCR(self.real.detach())
|
896 |
+
preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * batch_size).detach()
|
897 |
+
loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
|
898 |
+
self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
|
899 |
+
|
900 |
+
loss_total = self.loss_D + self.loss_OCR_real
|
901 |
+
|
902 |
+
# backward
|
903 |
+
loss_total.backward()
|
904 |
+
for param in self.netOCR.parameters():
|
905 |
+
param.grad[param.grad!=param.grad]=0
|
906 |
+
param.grad[torch.isnan(param.grad)]=0
|
907 |
+
param.grad[torch.isinf(param.grad)]=0
|
908 |
+
|
909 |
+
|
910 |
+
|
911 |
+
return loss_total
|
912 |
+
|
913 |
+
def backward_D_WL(self):
|
914 |
+
# Real
|
915 |
+
pred_real = self.netD(self.real.detach())
|
916 |
+
|
917 |
+
pred_fake = self.netD(**{'x': self.fake.detach()})
|
918 |
+
|
919 |
+
|
920 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True)
|
921 |
+
|
922 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
923 |
+
|
924 |
+
|
925 |
+
self.loss_w_real = self.netW(self.real.detach(), self.input['wcl'].to(DEVICE)).mean()
|
926 |
+
# total loss
|
927 |
+
loss_total = self.loss_D + self.loss_w_real
|
928 |
+
|
929 |
+
# backward
|
930 |
+
loss_total.backward()
|
931 |
+
|
932 |
+
|
933 |
+
return loss_total
|
934 |
+
|
935 |
+
def optimize_D_WL(self):
|
936 |
+
self.forward()
|
937 |
+
self.set_requires_grad([self.netD], True)
|
938 |
+
self.set_requires_grad([self.netOCR], False)
|
939 |
+
self.set_requires_grad([self.netW], True)
|
940 |
+
|
941 |
+
self.optimizer_D.zero_grad()
|
942 |
+
self.optimizer_wl.zero_grad()
|
943 |
+
|
944 |
+
self.backward_D_WL()
|
945 |
+
|
946 |
+
|
947 |
+
|
948 |
+
|
949 |
+
def backward_D_OCR_WL(self):
|
950 |
+
# Real
|
951 |
+
if self.real_z_mean is None:
|
952 |
+
pred_real = self.netD(self.real.detach())
|
953 |
+
else:
|
954 |
+
pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()})
|
955 |
+
# Fake
|
956 |
+
try:
|
957 |
+
pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()})
|
958 |
+
except:
|
959 |
+
print('a')
|
960 |
+
# Combined loss
|
961 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
|
962 |
+
|
963 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
964 |
+
# OCR loss on real data
|
965 |
+
self.pred_real_OCR = self.netOCR(self.real.detach())
|
966 |
+
preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach()
|
967 |
+
loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
|
968 |
+
self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
|
969 |
+
# total loss
|
970 |
+
self.loss_w_real = self.netW(self.real.detach(), self.wcl)
|
971 |
+
loss_total = self.loss_D + self.loss_OCR_real + self.loss_w_real
|
972 |
+
|
973 |
+
# backward
|
974 |
+
loss_total.backward()
|
975 |
+
for param in self.netOCR.parameters():
|
976 |
+
param.grad[param.grad!=param.grad]=0
|
977 |
+
param.grad[torch.isnan(param.grad)]=0
|
978 |
+
param.grad[torch.isinf(param.grad)]=0
|
979 |
+
|
980 |
+
|
981 |
+
|
982 |
+
return loss_total
|
983 |
+
|
984 |
+
def optimize_D_WL_step(self):
|
985 |
+
self.optimizer_D.step()
|
986 |
+
self.optimizer_wl.step()
|
987 |
+
self.optimizer_D.zero_grad()
|
988 |
+
self.optimizer_wl.zero_grad()
|
989 |
+
|
990 |
+
def backward_OCR(self):
|
991 |
+
# OCR loss on real data
|
992 |
+
self.pred_real_OCR = self.netOCR(self.real.detach())
|
993 |
+
preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach()
|
994 |
+
loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
|
995 |
+
self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
|
996 |
+
|
997 |
+
# backward
|
998 |
+
self.loss_OCR_real.backward()
|
999 |
+
for param in self.netOCR.parameters():
|
1000 |
+
param.grad[param.grad!=param.grad]=0
|
1001 |
+
param.grad[torch.isnan(param.grad)]=0
|
1002 |
+
param.grad[torch.isinf(param.grad)]=0
|
1003 |
+
|
1004 |
+
return self.loss_OCR_real
|
1005 |
+
|
1006 |
+
|
1007 |
+
def backward_D(self):
|
1008 |
+
# Real
|
1009 |
+
if self.real_z_mean is None:
|
1010 |
+
pred_real = self.netD(self.real.detach())
|
1011 |
+
else:
|
1012 |
+
pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()})
|
1013 |
+
pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()})
|
1014 |
+
# Combined loss
|
1015 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
|
1016 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
1017 |
+
# backward
|
1018 |
+
self.loss_D.backward()
|
1019 |
+
|
1020 |
+
|
1021 |
+
return self.loss_D
|
1022 |
+
|
1023 |
+
|
1024 |
+
def backward_G_only(self):
|
1025 |
+
|
1026 |
+
self.gb_alpha = 0.7
|
1027 |
+
#self.Lcycle1 = self.Lcycle1.mean()
|
1028 |
+
#self.Lcycle2 = self.Lcycle2.mean()
|
1029 |
+
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
|
1030 |
+
|
1031 |
+
|
1032 |
+
pred_fake_OCR = self.netOCR(self.fake)
|
1033 |
+
preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * batch_size).detach()
|
1034 |
+
loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach())
|
1035 |
+
self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
|
1036 |
+
|
1037 |
+
self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD
|
1038 |
+
|
1039 |
+
self.loss_T = self.loss_G + self.loss_OCR_fake
|
1040 |
+
|
1041 |
+
|
1042 |
+
|
1043 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
|
1044 |
+
|
1045 |
+
|
1046 |
+
self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
|
1047 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
|
1048 |
+
self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
|
1049 |
+
|
1050 |
+
|
1051 |
+
self.loss_T.backward(retain_graph=True)
|
1052 |
+
|
1053 |
+
|
1054 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
|
1055 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
|
1056 |
+
|
1057 |
+
|
1058 |
+
a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))
|
1059 |
+
|
1060 |
+
|
1061 |
+
if a is None:
|
1062 |
+
print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
|
1063 |
+
if a>1000 or a<0.0001:
|
1064 |
+
print(a)
|
1065 |
+
|
1066 |
+
|
1067 |
+
self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
|
1068 |
+
|
1069 |
+
self.loss_T = self.loss_G + self.loss_OCR_fake
|
1070 |
+
|
1071 |
+
|
1072 |
+
self.loss_T.backward(retain_graph=True)
|
1073 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
|
1074 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
|
1075 |
+
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
|
1076 |
+
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
|
1077 |
+
|
1078 |
+
with torch.no_grad():
|
1079 |
+
self.loss_T.backward()
|
1080 |
+
|
1081 |
+
if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G):
|
1082 |
+
print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
|
1083 |
+
sys.exit()
|
1084 |
+
|
1085 |
+
def backward_G_WL(self):
|
1086 |
+
|
1087 |
+
self.gb_alpha = 0.7
|
1088 |
+
#self.Lcycle1 = self.Lcycle1.mean()
|
1089 |
+
#self.Lcycle2 = self.Lcycle2.mean()
|
1090 |
+
|
1091 |
+
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
|
1092 |
+
|
1093 |
+
self.loss_w_fake = self.netW(self.fake, self.input['wcl'].to(DEVICE)).mean()
|
1094 |
+
|
1095 |
+
self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD
|
1096 |
+
|
1097 |
+
self.loss_T = self.loss_G + self.loss_w_fake
|
1098 |
+
|
1099 |
+
|
1100 |
+
|
1101 |
+
|
1102 |
+
#grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, retain_graph=True)[0]
|
1103 |
+
|
1104 |
+
|
1105 |
+
#self.loss_grad_fake_WL = 10**6*torch.mean(grad_fake_WL**2)
|
1106 |
+
#grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
|
1107 |
+
#self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
|
1108 |
+
|
1109 |
+
|
1110 |
+
|
1111 |
+
self.loss_T.backward(retain_graph=True)
|
1112 |
+
|
1113 |
+
|
1114 |
+
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
|
1115 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
|
1116 |
+
|
1117 |
+
|
1118 |
+
a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_WL))
|
1119 |
+
|
1120 |
+
|
1121 |
+
|
1122 |
+
if a is None:
|
1123 |
+
print(self.loss_w_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_WL))
|
1124 |
+
if a>1000 or a<0.0001:
|
1125 |
+
print(a)
|
1126 |
+
|
1127 |
+
self.loss_w_fake = a.detach() * self.loss_w_fake
|
1128 |
+
|
1129 |
+
self.loss_T = self.loss_G + self.loss_w_fake
|
1130 |
+
|
1131 |
+
self.loss_T.backward(retain_graph=True)
|
1132 |
+
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=False, retain_graph=True)[0]
|
1133 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
|
1134 |
+
self.loss_grad_fake_WL = 10 ** 6 * torch.mean(grad_fake_WL ** 2)
|
1135 |
+
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
|
1136 |
+
|
1137 |
+
with torch.no_grad():
|
1138 |
+
self.loss_T.backward()
|
1139 |
+
|
1140 |
+
def backward_G(self):
|
1141 |
+
self.opt.gb_alpha = 0.7
|
1142 |
+
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(), self.opt.mask_loss)
|
1143 |
+
# OCR loss on real data
|
1144 |
+
|
1145 |
+
pred_fake_OCR = self.netOCR(self.fake)
|
1146 |
+
preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.opt.batch_size).detach()
|
1147 |
+
loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach())
|
1148 |
+
self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
|
1149 |
+
|
1150 |
+
|
1151 |
+
self.loss_w_fake = self.netW(self.fake, self.wcl)
|
1152 |
+
#self.loss_OCR_fake = self.loss_OCR_fake + self.loss_w_fake
|
1153 |
+
# total loss
|
1154 |
+
|
1155 |
+
# l1 = self.params[0]*self.loss_G
|
1156 |
+
# l2 = self.params[0]*self.loss_OCR_fake
|
1157 |
+
#l3 = self.params[0]*self.loss_w_fake
|
1158 |
+
self.loss_G_ = 10*self.loss_G + self.loss_w_fake
|
1159 |
+
self.loss_T = self.loss_G_ + self.loss_OCR_fake
|
1160 |
+
|
1161 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
|
1162 |
+
|
1163 |
+
|
1164 |
+
self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
|
1165 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, retain_graph=True)[0]
|
1166 |
+
self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
|
1167 |
+
|
1168 |
+
if not False:
|
1169 |
+
|
1170 |
+
self.loss_T.backward(retain_graph=True)
|
1171 |
+
|
1172 |
+
|
1173 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
|
1174 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=True, retain_graph=True)[0]
|
1175 |
+
#grad_fake_wl = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
|
1176 |
+
|
1177 |
+
|
1178 |
+
a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))
|
1179 |
+
|
1180 |
+
|
1181 |
+
#a0 = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_wl))
|
1182 |
+
|
1183 |
+
if a is None:
|
1184 |
+
print(self.loss_OCR_fake, self.loss_G_, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
|
1185 |
+
if a>1000 or a<0.0001:
|
1186 |
+
print(a)
|
1187 |
+
b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) -
|
1188 |
+
torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))*
|
1189 |
+
torch.mean(grad_fake_OCR))
|
1190 |
+
# self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake)
|
1191 |
+
self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
|
1192 |
+
#self.loss_w_fake = a0.detach() * self.loss_w_fake
|
1193 |
+
|
1194 |
+
self.loss_T = (1-1*self.opt.onlyOCR)*self.loss_G_ + self.loss_OCR_fake# + self.loss_w_fake
|
1195 |
+
self.loss_T.backward(retain_graph=True)
|
1196 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
|
1197 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=False, retain_graph=True)[0]
|
1198 |
+
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
|
1199 |
+
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
|
1200 |
+
with torch.no_grad():
|
1201 |
+
self.loss_T.backward()
|
1202 |
+
else:
|
1203 |
+
self.loss_T.backward()
|
1204 |
+
|
1205 |
+
if self.opt.clip_grad > 0:
|
1206 |
+
clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad)
|
1207 |
+
if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G_):
|
1208 |
+
print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
|
1209 |
+
sys.exit()
|
1210 |
+
|
1211 |
+
|
1212 |
+
|
1213 |
+
def optimize_D_OCR(self):
|
1214 |
+
self.forward()
|
1215 |
+
self.set_requires_grad([self.netD], True)
|
1216 |
+
self.set_requires_grad([self.netOCR], True)
|
1217 |
+
self.optimizer_D.zero_grad()
|
1218 |
+
#if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
|
1219 |
+
self.optimizer_OCR.zero_grad()
|
1220 |
+
self.backward_D_OCR()
|
1221 |
+
|
1222 |
+
def optimize_OCR(self):
|
1223 |
+
self.forward()
|
1224 |
+
self.set_requires_grad([self.netD], False)
|
1225 |
+
self.set_requires_grad([self.netOCR], True)
|
1226 |
+
if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
|
1227 |
+
self.optimizer_OCR.zero_grad()
|
1228 |
+
self.backward_OCR()
|
1229 |
+
|
1230 |
+
def optimize_D(self):
|
1231 |
+
self.forward()
|
1232 |
+
self.set_requires_grad([self.netD], True)
|
1233 |
+
self.backward_D()
|
1234 |
+
|
1235 |
+
def optimize_D_OCR_step(self):
|
1236 |
+
self.optimizer_D.step()
|
1237 |
+
|
1238 |
+
self.optimizer_OCR.step()
|
1239 |
+
self.optimizer_D.zero_grad()
|
1240 |
+
self.optimizer_OCR.zero_grad()
|
1241 |
+
|
1242 |
+
|
1243 |
+
def optimize_D_OCR_WL(self):
|
1244 |
+
self.forward()
|
1245 |
+
self.set_requires_grad([self.netD], True)
|
1246 |
+
self.set_requires_grad([self.netOCR], True)
|
1247 |
+
self.set_requires_grad([self.netW], True)
|
1248 |
+
self.optimizer_D.zero_grad()
|
1249 |
+
self.optimizer_wl.zero_grad()
|
1250 |
+
if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
|
1251 |
+
self.optimizer_OCR.zero_grad()
|
1252 |
+
self.backward_D_OCR_WL()
|
1253 |
+
|
1254 |
+
def optimize_D_OCR_WL_step(self):
|
1255 |
+
self.optimizer_D.step()
|
1256 |
+
if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
|
1257 |
+
self.optimizer_OCR.step()
|
1258 |
+
self.optimizer_wl.step()
|
1259 |
+
self.optimizer_D.zero_grad()
|
1260 |
+
self.optimizer_OCR.zero_grad()
|
1261 |
+
self.optimizer_wl.zero_grad()
|
1262 |
+
|
1263 |
+
def optimize_D_step(self):
|
1264 |
+
self.optimizer_D.step()
|
1265 |
+
if any(torch.isnan(self.netD.infer_img.blocks[0][0].conv1.bias)):
|
1266 |
+
print('D is nan')
|
1267 |
+
sys.exit()
|
1268 |
+
self.optimizer_D.zero_grad()
|
1269 |
+
|
1270 |
+
def optimize_G(self):
|
1271 |
+
self.forward()
|
1272 |
+
self.set_requires_grad([self.netD], False)
|
1273 |
+
self.set_requires_grad([self.netOCR], False)
|
1274 |
+
self.set_requires_grad([self.netW], False)
|
1275 |
+
self.backward_G()
|
1276 |
+
|
1277 |
+
def optimize_G_WL(self):
|
1278 |
+
self.forward()
|
1279 |
+
self.set_requires_grad([self.netD], False)
|
1280 |
+
self.set_requires_grad([self.netOCR], False)
|
1281 |
+
self.set_requires_grad([self.netW], False)
|
1282 |
+
self.backward_G_WL()
|
1283 |
+
|
1284 |
+
|
1285 |
+
def optimize_G_only(self):
|
1286 |
+
self.forward()
|
1287 |
+
self.set_requires_grad([self.netD], False)
|
1288 |
+
self.set_requires_grad([self.netOCR], False)
|
1289 |
+
self.set_requires_grad([self.netW], False)
|
1290 |
+
self.backward_G_only()
|
1291 |
+
|
1292 |
+
|
1293 |
+
def optimize_G_step(self):
|
1294 |
+
|
1295 |
+
self.optimizer_G.step()
|
1296 |
+
self.optimizer_G.zero_grad()
|
1297 |
+
|
1298 |
+
def optimize_ocr(self):
|
1299 |
+
self.set_requires_grad([self.netOCR], True)
|
1300 |
+
# OCR loss on real data
|
1301 |
+
pred_real_OCR = self.netOCR(self.real)
|
1302 |
+
preds_size =torch.IntTensor([pred_real_OCR.size(0)] * self.opt.batch_size).detach()
|
1303 |
+
self.loss_OCR_real = self.OCR_criterion(pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
|
1304 |
+
self.loss_OCR_real.backward()
|
1305 |
+
self.optimizer_OCR.step()
|
1306 |
+
|
1307 |
+
def optimize_z(self):
|
1308 |
+
self.set_requires_grad([self.z], True)
|
1309 |
+
|
1310 |
+
|
1311 |
+
def optimize_parameters(self):
|
1312 |
+
self.forward()
|
1313 |
+
self.set_requires_grad([self.netD], False)
|
1314 |
+
self.optimizer_G.zero_grad()
|
1315 |
+
self.backward_G()
|
1316 |
+
self.optimizer_G.step()
|
1317 |
+
|
1318 |
+
self.set_requires_grad([self.netD], True)
|
1319 |
+
self.optimizer_D.zero_grad()
|
1320 |
+
self.backward_D()
|
1321 |
+
self.optimizer_D.step()
|
1322 |
+
|
1323 |
+
def test(self):
|
1324 |
+
self.visual_names = ['fake']
|
1325 |
+
self.netG.eval()
|
1326 |
+
with torch.no_grad():
|
1327 |
+
self.forward()
|
1328 |
+
|
1329 |
+
def train_GD(self):
|
1330 |
+
self.netG.train()
|
1331 |
+
self.netD.train()
|
1332 |
+
self.optimizer_G.zero_grad()
|
1333 |
+
self.optimizer_D.zero_grad()
|
1334 |
+
# How many chunks to split x and y into?
|
1335 |
+
x = torch.split(self.real, self.opt.batch_size)
|
1336 |
+
y = torch.split(self.label, self.opt.batch_size)
|
1337 |
+
counter = 0
|
1338 |
+
|
1339 |
+
# Optionally toggle D and G's "require_grad"
|
1340 |
+
if self.opt.toggle_grads:
|
1341 |
+
toggle_grad(self.netD, True)
|
1342 |
+
toggle_grad(self.netG, False)
|
1343 |
+
|
1344 |
+
for step_index in range(self.opt.num_critic_train):
|
1345 |
+
self.optimizer_D.zero_grad()
|
1346 |
+
with torch.set_grad_enabled(False):
|
1347 |
+
self.forward()
|
1348 |
+
D_input = torch.cat([self.fake, x[counter]], 0) if x is not None else self.fake
|
1349 |
+
D_class = torch.cat([self.label_fake, y[counter]], 0) if y[counter] is not None else y[counter]
|
1350 |
+
# Get Discriminator output
|
1351 |
+
D_out = self.netD(D_input, D_class)
|
1352 |
+
if x is not None:
|
1353 |
+
pred_fake, pred_real = torch.split(D_out, [self.fake.shape[0], x[counter].shape[0]]) # D_fake, D_real
|
1354 |
+
else:
|
1355 |
+
pred_fake = D_out
|
1356 |
+
# Combined loss
|
1357 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
|
1358 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
1359 |
+
self.loss_D.backward()
|
1360 |
+
counter += 1
|
1361 |
+
self.optimizer_D.step()
|
1362 |
+
|
1363 |
+
# Optionally toggle D and G's "require_grad"
|
1364 |
+
if self.opt.toggle_grads:
|
1365 |
+
toggle_grad(self.netD, False)
|
1366 |
+
toggle_grad(self.netG, True)
|
1367 |
+
# Zero G's gradients by default before training G, for safety
|
1368 |
+
self.optimizer_G.zero_grad()
|
1369 |
+
self.forward()
|
1370 |
+
self.loss_G = loss_hinge_gen(self.netD(self.fake, self.label_fake), self.len_text_fake.detach(), self.opt.mask_loss)
|
1371 |
+
self.loss_G.backward()
|
1372 |
+
self.optimizer_G.step()
|
1373 |
+
|
1374 |
+
|
1375 |
+
|
1376 |
+
|
1377 |
+
|
1378 |
+
|
1379 |
+
|
1380 |
+
|
1381 |
+
|
1382 |
+
|
1383 |
+
|
1384 |
+
|
1385 |
+
|
1386 |
+
|
1387 |
+
|
1388 |
+
|
1389 |
+
|
util/models/model_.py
ADDED
@@ -0,0 +1,1264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import pandas as pd
|
3 |
+
from .OCR_network import *
|
4 |
+
from torch.nn import CTCLoss, MSELoss, L1Loss
|
5 |
+
from torch.nn.utils import clip_grad_norm_
|
6 |
+
import random
|
7 |
+
import unicodedata
|
8 |
+
import sys
|
9 |
+
import torchvision.models as models
|
10 |
+
from models.transformer import *
|
11 |
+
from .BigGAN_networks import *
|
12 |
+
from params import *
|
13 |
+
from .OCR_network import *
|
14 |
+
from models.blocks import LinearBlock, Conv2dBlock, ResBlocks, ActFirstResBlock
|
15 |
+
from util.util import toggle_grad, loss_hinge_dis, loss_hinge_gen, ortho, default_ortho, toggle_grad, prepare_z_y, \
|
16 |
+
make_one_hot, to_device, multiple_replace, random_word
|
17 |
+
from models.inception import InceptionV3, calculate_frechet_distance
|
18 |
+
|
19 |
+
class FCNDecoder(nn.Module):
|
20 |
+
def __init__(self, ups=3, n_res=2, dim=512, out_dim=1, res_norm='adain', activ='relu', pad_type='reflect'):
|
21 |
+
super(FCNDecoder, self).__init__()
|
22 |
+
|
23 |
+
self.model = []
|
24 |
+
self.model += [ResBlocks(n_res, dim, res_norm,
|
25 |
+
activ, pad_type=pad_type)]
|
26 |
+
for i in range(ups):
|
27 |
+
self.model += [nn.Upsample(scale_factor=2),
|
28 |
+
Conv2dBlock(dim, dim // 2, 5, 1, 2,
|
29 |
+
norm='in',
|
30 |
+
activation=activ,
|
31 |
+
pad_type=pad_type)]
|
32 |
+
dim //= 2
|
33 |
+
self.model += [Conv2dBlock(dim, out_dim, 7, 1, 3,
|
34 |
+
norm='none',
|
35 |
+
activation='tanh',
|
36 |
+
pad_type=pad_type)]
|
37 |
+
self.model = nn.Sequential(*self.model)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
y = self.model(x)
|
41 |
+
|
42 |
+
return y
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
class Generator(nn.Module):
|
47 |
+
|
48 |
+
def __init__(self):
|
49 |
+
super(Generator, self).__init__()
|
50 |
+
|
51 |
+
INP_CHANNEL = NUM_EXAMPLES
|
52 |
+
if IS_SEQ: INP_CHANNEL = 1
|
53 |
+
|
54 |
+
|
55 |
+
encoder_layer = TransformerEncoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD,
|
56 |
+
TN_DROPOUT, "relu", True)
|
57 |
+
encoder_norm = nn.LayerNorm(TN_HIDDEN_DIM) if True else None
|
58 |
+
self.encoder = TransformerEncoder(encoder_layer, TN_ENC_LAYERS, encoder_norm)
|
59 |
+
|
60 |
+
decoder_layer = TransformerDecoderLayer(TN_HIDDEN_DIM, TN_NHEADS, TN_DIM_FEEDFORWARD,
|
61 |
+
TN_DROPOUT, "relu", True)
|
62 |
+
decoder_norm = nn.LayerNorm(TN_HIDDEN_DIM)
|
63 |
+
self.decoder = TransformerDecoder(decoder_layer, TN_DEC_LAYERS, decoder_norm,
|
64 |
+
return_intermediate=True)
|
65 |
+
|
66 |
+
self.Feat_Encoder = nn.Sequential(*([nn.Conv2d(INP_CHANNEL, 64, kernel_size=7, stride=2, padding=3, bias=False)] +list(models.resnet18(pretrained=True).children())[1:-2]))
|
67 |
+
|
68 |
+
self.query_embed = nn.Embedding(VOCAB_SIZE, TN_HIDDEN_DIM)
|
69 |
+
|
70 |
+
|
71 |
+
self.linear_q = nn.Linear(TN_DIM_FEEDFORWARD*2, TN_DIM_FEEDFORWARD*8)
|
72 |
+
|
73 |
+
self.DEC = FCNDecoder(res_norm = 'in')
|
74 |
+
|
75 |
+
|
76 |
+
self._muE = nn.Linear(512,512)
|
77 |
+
self._logvarE = nn.Linear(512,512)
|
78 |
+
|
79 |
+
self._muD = nn.Linear(512,512)
|
80 |
+
self._logvarD = nn.Linear(512,512)
|
81 |
+
|
82 |
+
|
83 |
+
self.l1loss = nn.L1Loss()
|
84 |
+
|
85 |
+
self.noise = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0]))
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
def reparameterize(self, mu, logvar):
|
90 |
+
|
91 |
+
mu = torch.unbind(mu , 1)
|
92 |
+
logvar = torch.unbind(logvar , 1)
|
93 |
+
|
94 |
+
outs = []
|
95 |
+
|
96 |
+
for m,l in zip(mu, logvar):
|
97 |
+
|
98 |
+
sigma = torch.exp(l)
|
99 |
+
eps = torch.cuda.FloatTensor(l.size()[0],1).normal_(0,1)
|
100 |
+
eps = eps.expand(sigma.size())
|
101 |
+
|
102 |
+
out = m + sigma*eps
|
103 |
+
|
104 |
+
outs.append(out)
|
105 |
+
|
106 |
+
|
107 |
+
return torch.stack(outs, 1)
|
108 |
+
|
109 |
+
|
110 |
+
def Eval(self, ST, QRS):
|
111 |
+
|
112 |
+
if IS_SEQ:
|
113 |
+
B, N, R, C = ST.shape
|
114 |
+
FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C))
|
115 |
+
FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
|
116 |
+
else:
|
117 |
+
FEAT_ST = self.Feat_Encoder(ST)
|
118 |
+
|
119 |
+
|
120 |
+
FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1)
|
121 |
+
|
122 |
+
memory = self.encoder(FEAT_ST_ENC)
|
123 |
+
|
124 |
+
if IS_KLD:
|
125 |
+
|
126 |
+
Ex = memory.permute(1,0,2)
|
127 |
+
|
128 |
+
memory_mu = self._muE(Ex)
|
129 |
+
memory_logvar = self._logvarE(Ex)
|
130 |
+
|
131 |
+
memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2)
|
132 |
+
|
133 |
+
|
134 |
+
OUT_IMGS = []
|
135 |
+
|
136 |
+
for i in range(QRS.shape[1]):
|
137 |
+
|
138 |
+
QR = QRS[:, i, :]
|
139 |
+
|
140 |
+
QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
|
141 |
+
|
142 |
+
tgt = torch.zeros_like(QR_EMB)
|
143 |
+
|
144 |
+
hs = self.decoder(tgt, memory, query_pos=QR_EMB)
|
145 |
+
|
146 |
+
if IS_KLD:
|
147 |
+
|
148 |
+
Dx = hs[0].permute(1,0,2)
|
149 |
+
|
150 |
+
hs_mu = self._muD(Dx)
|
151 |
+
hs_logvar = self._logvarD(Dx)
|
152 |
+
|
153 |
+
hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
|
154 |
+
|
155 |
+
|
156 |
+
h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
|
157 |
+
if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
|
158 |
+
|
159 |
+
h = self.linear_q(h)
|
160 |
+
h = h.contiguous()
|
161 |
+
|
162 |
+
h = h.view(h.size(0), h.shape[1]*2, 4, -1)
|
163 |
+
h = h.permute(0, 3, 2, 1)
|
164 |
+
|
165 |
+
h = self.DEC(h)
|
166 |
+
|
167 |
+
|
168 |
+
OUT_IMGS.append(h.detach())
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
return OUT_IMGS
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
def forward(self, ST, QR, QRs = None, mode = 'train'):
|
180 |
+
|
181 |
+
if IS_SEQ:
|
182 |
+
B, N, R, C = ST.shape
|
183 |
+
FEAT_ST = self.Feat_Encoder(ST.view(B*N, 1, R, C))
|
184 |
+
FEAT_ST = FEAT_ST.view(B, 512, 1, -1)
|
185 |
+
else:
|
186 |
+
FEAT_ST = self.Feat_Encoder(ST)
|
187 |
+
|
188 |
+
|
189 |
+
FEAT_ST_ENC = FEAT_ST.flatten(2).permute(2,0,1)
|
190 |
+
|
191 |
+
memory = self.encoder(FEAT_ST_ENC)
|
192 |
+
|
193 |
+
if IS_KLD:
|
194 |
+
|
195 |
+
Ex = memory.permute(1,0,2)
|
196 |
+
|
197 |
+
memory_mu = self._muE(Ex)
|
198 |
+
memory_logvar = self._logvarE(Ex)
|
199 |
+
|
200 |
+
memory = self.reparameterize(memory_mu, memory_logvar).permute(1,0,2)
|
201 |
+
|
202 |
+
|
203 |
+
QR_EMB = self.query_embed.weight.repeat(batch_size,1,1).permute(1,0,2)
|
204 |
+
|
205 |
+
tgt = torch.zeros_like(QR_EMB)
|
206 |
+
|
207 |
+
hs = self.decoder(tgt, memory, query_pos=QR_EMB)
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
if IS_KLD:
|
212 |
+
|
213 |
+
Dx = hs[0].permute(1,0,2)
|
214 |
+
|
215 |
+
hs_mu = self._muD(Dx)
|
216 |
+
hs_logvar = self._logvarD(Dx)
|
217 |
+
|
218 |
+
hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
|
219 |
+
|
220 |
+
OUT_Feats1_mu = [hs_mu]
|
221 |
+
OUT_Feats1_logvar = [hs_logvar]
|
222 |
+
|
223 |
+
|
224 |
+
OUT_Feats1 = [hs]
|
225 |
+
|
226 |
+
|
227 |
+
h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
|
228 |
+
|
229 |
+
if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
|
230 |
+
|
231 |
+
h = self.linear_q(h)
|
232 |
+
|
233 |
+
h = h.contiguous()
|
234 |
+
|
235 |
+
h = [torch.stack([h[i][QR[i]] for i in range(batch_size)], 0) for QR in QRs]
|
236 |
+
|
237 |
+
h_list = []
|
238 |
+
|
239 |
+
for h_ in h:
|
240 |
+
|
241 |
+
h_ = h_.view(h_.size(0), h_.shape[1]*2, 4, -1)
|
242 |
+
h_ = h_.permute(0, 3, 2, 1)
|
243 |
+
|
244 |
+
#h_ = self.DEC(h_)
|
245 |
+
|
246 |
+
h_list.append(h_)
|
247 |
+
|
248 |
+
if mode == 'test' or (not IS_CYCLE and not IS_KLD):
|
249 |
+
|
250 |
+
return h
|
251 |
+
|
252 |
+
|
253 |
+
OUT_IMGS = [h]
|
254 |
+
|
255 |
+
for QR in QRs:
|
256 |
+
|
257 |
+
QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
|
258 |
+
|
259 |
+
tgt = torch.zeros_like(QR_EMB)
|
260 |
+
|
261 |
+
hs = self.decoder(tgt, memory, query_pos=QR_EMB)
|
262 |
+
|
263 |
+
|
264 |
+
if IS_KLD:
|
265 |
+
|
266 |
+
Dx = hs[0].permute(1,0,2)
|
267 |
+
|
268 |
+
hs_mu = self._muD(Dx)
|
269 |
+
hs_logvar = self._logvarD(Dx)
|
270 |
+
|
271 |
+
hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
|
272 |
+
|
273 |
+
OUT_Feats1_mu.append(hs_mu)
|
274 |
+
OUT_Feats1_logvar.append(hs_logvar)
|
275 |
+
|
276 |
+
|
277 |
+
OUT_Feats1.append(hs)
|
278 |
+
|
279 |
+
|
280 |
+
h = torch.cat([hs.transpose(1, 2)[-1], QR_EMB.permute(1,0,2)], -1)
|
281 |
+
if ADD_NOISE: h = h + self.noise.sample(h.size()).squeeze(-1).to(DEVICE)
|
282 |
+
|
283 |
+
h = self.linear_q(h)
|
284 |
+
h = h.contiguous()
|
285 |
+
|
286 |
+
h = h.view(h.size(0), h.shape[1]*2, 4, -1)
|
287 |
+
h = h.permute(0, 3, 2, 1)
|
288 |
+
|
289 |
+
h = self.DEC(h)
|
290 |
+
|
291 |
+
OUT_IMGS.append(h)
|
292 |
+
|
293 |
+
|
294 |
+
if (not IS_CYCLE) and IS_KLD:
|
295 |
+
|
296 |
+
OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0]
|
297 |
+
|
298 |
+
OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1);
|
299 |
+
|
300 |
+
|
301 |
+
KLD = (0.5 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \
|
302 |
+
+ (0.5 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp()))
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar):
|
307 |
+
return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \
|
308 |
+
torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum()
|
309 |
+
|
310 |
+
|
311 |
+
lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])]
|
312 |
+
|
313 |
+
|
314 |
+
lda1 = torch.stack(lda1).mean()
|
315 |
+
|
316 |
+
|
317 |
+
|
318 |
+
return OUT_IMGS[0], lda1, KLD
|
319 |
+
|
320 |
+
|
321 |
+
with torch.no_grad():
|
322 |
+
|
323 |
+
if IS_SEQ:
|
324 |
+
|
325 |
+
FEAT_ST_T = torch.cat([self.Feat_Encoder(IM) for IM in OUT_IMGS], -1)
|
326 |
+
|
327 |
+
else:
|
328 |
+
|
329 |
+
max_width_ = max([i_.shape[-1] for i_ in OUT_IMGS])
|
330 |
+
|
331 |
+
FEAT_ST_T = self.Feat_Encoder(torch.cat([torch.cat([i_, torch.ones((i_.shape[0], i_.shape[1],i_.shape[2], max_width_-i_.shape[3])).to(DEVICE)], -1) for i_ in OUT_IMGS], 1))
|
332 |
+
|
333 |
+
FEAT_ST_ENC_T = FEAT_ST_T.flatten(2).permute(2,0,1)
|
334 |
+
|
335 |
+
memory_T = self.encoder(FEAT_ST_ENC_T)
|
336 |
+
|
337 |
+
if IS_KLD:
|
338 |
+
|
339 |
+
Ex = memory_T.permute(1,0,2)
|
340 |
+
|
341 |
+
memory_T_mu = self._muE(Ex)
|
342 |
+
memory_T_logvar = self._logvarE(Ex)
|
343 |
+
|
344 |
+
memory_T = self.reparameterize(memory_T_mu, memory_T_logvar).permute(1,0,2)
|
345 |
+
|
346 |
+
|
347 |
+
QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
|
348 |
+
|
349 |
+
tgt = torch.zeros_like(QR_EMB)
|
350 |
+
|
351 |
+
hs = self.decoder(tgt, memory_T, query_pos=QR_EMB)
|
352 |
+
|
353 |
+
if IS_KLD:
|
354 |
+
|
355 |
+
Dx = hs[0].permute(1,0,2)
|
356 |
+
|
357 |
+
hs_mu = self._muD(Dx)
|
358 |
+
hs_logvar = self._logvarD(Dx)
|
359 |
+
|
360 |
+
hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
|
361 |
+
|
362 |
+
OUT_Feats2_mu = [hs_mu]
|
363 |
+
OUT_Feats2_logvar = [hs_logvar]
|
364 |
+
|
365 |
+
|
366 |
+
OUT_Feats2 = [hs]
|
367 |
+
|
368 |
+
|
369 |
+
|
370 |
+
for QR in QRs:
|
371 |
+
|
372 |
+
QR_EMB = self.query_embed.weight[QR].permute(1,0,2)
|
373 |
+
|
374 |
+
tgt = torch.zeros_like(QR_EMB)
|
375 |
+
|
376 |
+
hs = self.decoder(tgt, memory_T, query_pos=QR_EMB)
|
377 |
+
|
378 |
+
if IS_KLD:
|
379 |
+
|
380 |
+
Dx = hs[0].permute(1,0,2)
|
381 |
+
|
382 |
+
hs_mu = self._muD(Dx)
|
383 |
+
hs_logvar = self._logvarD(Dx)
|
384 |
+
|
385 |
+
hs = self.reparameterize(hs_mu, hs_logvar).permute(1,0,2).unsqueeze(0)
|
386 |
+
|
387 |
+
OUT_Feats2_mu.append(hs_mu)
|
388 |
+
OUT_Feats2_logvar.append(hs_logvar)
|
389 |
+
|
390 |
+
|
391 |
+
OUT_Feats2.append(hs)
|
392 |
+
|
393 |
+
|
394 |
+
|
395 |
+
|
396 |
+
Lcycle1 = np.sum([self.l1loss(memory[m_i], memory_T[m_j]) for m_i in range(memory.shape[0]) for m_j in range(memory_T.shape[0])])/(memory.shape[0]*memory_T.shape[0])
|
397 |
+
OUT_Feats1 = torch.cat(OUT_Feats1, 1)[0]; OUT_Feats2 = torch.cat(OUT_Feats2, 1)[0]
|
398 |
+
|
399 |
+
Lcycle2 = np.sum([self.l1loss(OUT_Feats1[f_i], OUT_Feats2[f_j]) for f_i in range(OUT_Feats1.shape[0]) for f_j in range(OUT_Feats2.shape[0])])/(OUT_Feats1.shape[0]*OUT_Feats2.shape[0])
|
400 |
+
|
401 |
+
if IS_KLD:
|
402 |
+
|
403 |
+
OUT_Feats1_mu = torch.cat(OUT_Feats1_mu, 1); OUT_Feats1_logvar = torch.cat(OUT_Feats1_logvar, 1);
|
404 |
+
OUT_Feats2_mu = torch.cat(OUT_Feats2_mu, 1); OUT_Feats2_logvar = torch.cat(OUT_Feats2_logvar, 1);
|
405 |
+
|
406 |
+
KLD = (0.25 * torch.mean(1 + memory_logvar - memory_mu.pow(2) - memory_logvar.exp())) \
|
407 |
+
+ (0.25 * torch.mean(1 + memory_T_logvar - memory_T_mu.pow(2) - memory_T_logvar.exp()))\
|
408 |
+
+ (0.25 * torch.mean(1 + OUT_Feats1_logvar - OUT_Feats1_mu.pow(2) - OUT_Feats1_logvar.exp()))\
|
409 |
+
+ (0.25 * torch.mean(1 + OUT_Feats2_logvar - OUT_Feats2_mu.pow(2) - OUT_Feats2_logvar.exp()))
|
410 |
+
|
411 |
+
|
412 |
+
def _get_lda(Ex_mu, Dx_mu, Ex_logvar, Dx_logvar):
|
413 |
+
return torch.sqrt(torch.sum((Ex_mu - Dx_mu) ** 2, dim=1) + \
|
414 |
+
torch.sum((torch.sqrt(Ex_logvar.exp()) - torch.sqrt(Dx_logvar.exp())) ** 2, dim=1)).sum()
|
415 |
+
|
416 |
+
|
417 |
+
lda1 = [_get_lda(memory_mu[:,idi,:], OUT_Feats1_mu[:,idj,:], memory_logvar[:,idi,:], OUT_Feats1_logvar[:,idj,:]) for idi in range(memory.shape[0]) for idj in range(OUT_Feats1.shape[0])]
|
418 |
+
lda2 = [_get_lda(memory_T_mu[:,idi,:], OUT_Feats2_mu[:,idj,:], memory_T_logvar[:,idi,:], OUT_Feats2_logvar[:,idj,:]) for idi in range(memory_T.shape[0]) for idj in range(OUT_Feats2.shape[0])]
|
419 |
+
|
420 |
+
lda1 = torch.stack(lda1).mean()
|
421 |
+
lda2 = torch.stack(lda2).mean()
|
422 |
+
|
423 |
+
|
424 |
+
return OUT_IMGS[0], Lcycle1, Lcycle2, lda1, lda2, KLD
|
425 |
+
|
426 |
+
|
427 |
+
return OUT_IMGS[0], Lcycle1, Lcycle2
|
428 |
+
|
429 |
+
|
430 |
+
|
431 |
+
class TRGAN(nn.Module):
|
432 |
+
|
433 |
+
def __init__(self):
|
434 |
+
super(TRGAN, self).__init__()
|
435 |
+
|
436 |
+
|
437 |
+
self.epsilon = 1e-7
|
438 |
+
self.netG = Generator().to(DEVICE)
|
439 |
+
self.netD = nn.DataParallel(Discriminator()).to(DEVICE)
|
440 |
+
self.netW = nn.DataParallel(WDiscriminator()).to(DEVICE)
|
441 |
+
self.netconverter = strLabelConverter(ALPHABET)
|
442 |
+
self.netOCR = CRNN().to(DEVICE)
|
443 |
+
self.OCR_criterion = CTCLoss(zero_infinity=True, reduction='none')
|
444 |
+
|
445 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
|
446 |
+
self.inception = InceptionV3([block_idx]).to(DEVICE)
|
447 |
+
|
448 |
+
|
449 |
+
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
|
450 |
+
lr=G_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
|
451 |
+
self.optimizer_OCR = torch.optim.Adam(self.netOCR.parameters(),
|
452 |
+
lr=OCR_LR, betas=(0.0, 0.999), weight_decay=0,
|
453 |
+
eps=1e-8)
|
454 |
+
|
455 |
+
self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
|
456 |
+
lr=D_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
|
457 |
+
|
458 |
+
|
459 |
+
self.optimizer_wl = torch.optim.Adam(self.netW.parameters(),
|
460 |
+
lr=W_LR, betas=(0.0, 0.999), weight_decay=0, eps=1e-8)
|
461 |
+
|
462 |
+
|
463 |
+
self.optimizers = [self.optimizer_G, self.optimizer_OCR, self.optimizer_D, self.optimizer_wl]
|
464 |
+
|
465 |
+
|
466 |
+
self.optimizer_G.zero_grad()
|
467 |
+
self.optimizer_OCR.zero_grad()
|
468 |
+
self.optimizer_D.zero_grad()
|
469 |
+
self.optimizer_wl.zero_grad()
|
470 |
+
|
471 |
+
self.loss_G = 0
|
472 |
+
self.loss_D = 0
|
473 |
+
self.loss_Dfake = 0
|
474 |
+
self.loss_Dreal = 0
|
475 |
+
self.loss_OCR_fake = 0
|
476 |
+
self.loss_OCR_real = 0
|
477 |
+
self.loss_w_fake = 0
|
478 |
+
self.loss_w_real = 0
|
479 |
+
self.Lcycle1 = 0
|
480 |
+
self.Lcycle2 = 0
|
481 |
+
self.lda1 = 0
|
482 |
+
self.lda2 = 0
|
483 |
+
self.KLD = 0
|
484 |
+
|
485 |
+
|
486 |
+
with open('../Lexicon/english_words.txt', 'rb') as f:
|
487 |
+
self.lex = f.read().splitlines()
|
488 |
+
lex=[]
|
489 |
+
for word in self.lex:
|
490 |
+
try:
|
491 |
+
word=word.decode("utf-8")
|
492 |
+
except:
|
493 |
+
continue
|
494 |
+
if len(word)<20:
|
495 |
+
lex.append(word)
|
496 |
+
self.lex = lex
|
497 |
+
|
498 |
+
|
499 |
+
f = open('mytext.txt', 'r')
|
500 |
+
|
501 |
+
self.text = [j.encode() for j in sum([i.split(' ') for i in f.readlines()], [])][:NUM_EXAMPLES]
|
502 |
+
self.eval_text_encode, self.eval_len_text = self.netconverter.encode(self.text)
|
503 |
+
self.eval_text_encode = self.eval_text_encode.to(DEVICE).repeat(batch_size, 1, 1)
|
504 |
+
|
505 |
+
|
506 |
+
def _generate_page(self):
|
507 |
+
|
508 |
+
self.fakes = self.netG.Eval(self.sdata, self.eval_text_encode)
|
509 |
+
|
510 |
+
word_t = []
|
511 |
+
word_l = []
|
512 |
+
|
513 |
+
gap = np.ones([32,16])
|
514 |
+
|
515 |
+
line_wids = []
|
516 |
+
|
517 |
+
|
518 |
+
for idx, fake_ in enumerate(self.fakes):
|
519 |
+
|
520 |
+
word_t.append((fake_[0,0,:,:self.eval_len_text[idx]*resolution].cpu().numpy()+1)/2)
|
521 |
+
|
522 |
+
word_t.append(gap)
|
523 |
+
|
524 |
+
if len(word_t) == 16 or idx == len(self.fakes) - 1:
|
525 |
+
|
526 |
+
line_ = np.concatenate(word_t, -1)
|
527 |
+
|
528 |
+
word_l.append(line_)
|
529 |
+
line_wids.append(line_.shape[1])
|
530 |
+
|
531 |
+
word_t = []
|
532 |
+
|
533 |
+
|
534 |
+
gap_h = np.ones([16,max(line_wids)])
|
535 |
+
|
536 |
+
page_= []
|
537 |
+
|
538 |
+
for l in word_l:
|
539 |
+
|
540 |
+
pad_ = np.ones([32,max(line_wids) - l.shape[1]])
|
541 |
+
|
542 |
+
page_.append(np.concatenate([l, pad_], 1))
|
543 |
+
page_.append(gap_h)
|
544 |
+
|
545 |
+
|
546 |
+
|
547 |
+
page1 = np.concatenate(page_, 0)
|
548 |
+
|
549 |
+
|
550 |
+
word_t = []
|
551 |
+
word_l = []
|
552 |
+
|
553 |
+
gap = np.ones([32,16])
|
554 |
+
|
555 |
+
line_wids = []
|
556 |
+
|
557 |
+
sdata_ = [i.unsqueeze(1) for i in torch.unbind(self.sdata, 1)]
|
558 |
+
|
559 |
+
for idx, st in enumerate((sdata_)):
|
560 |
+
|
561 |
+
word_t.append((st[0,0,:,:int(self.input['swids'].cpu().numpy()[0][idx])
|
562 |
+
].cpu().numpy()+1)/2)
|
563 |
+
|
564 |
+
word_t.append(gap)
|
565 |
+
|
566 |
+
if len(word_t) == 16 or idx == len(self.fakes) - 1:
|
567 |
+
|
568 |
+
line_ = np.concatenate(word_t, -1)
|
569 |
+
|
570 |
+
word_l.append(line_)
|
571 |
+
line_wids.append(line_.shape[1])
|
572 |
+
|
573 |
+
word_t = []
|
574 |
+
|
575 |
+
|
576 |
+
gap_h = np.ones([16,max(line_wids)])
|
577 |
+
|
578 |
+
page_= []
|
579 |
+
|
580 |
+
for l in word_l:
|
581 |
+
|
582 |
+
pad_ = np.ones([32,max(line_wids) - l.shape[1]])
|
583 |
+
|
584 |
+
page_.append(np.concatenate([l, pad_], 1))
|
585 |
+
page_.append(gap_h)
|
586 |
+
|
587 |
+
|
588 |
+
|
589 |
+
page2 = np.concatenate(page_, 0)
|
590 |
+
|
591 |
+
merge_w_size = max(page1.shape[0], page2.shape[0])
|
592 |
+
|
593 |
+
if page1.shape[0] != merge_w_size:
|
594 |
+
|
595 |
+
page1 = np.concatenate([page1, np.ones([merge_w_size-page1.shape[0], page1.shape[1]])], 0)
|
596 |
+
|
597 |
+
if page2.shape[0] != merge_w_size:
|
598 |
+
|
599 |
+
page2 = np.concatenate([page2, np.ones([merge_w_size-page2.shape[0], page2.shape[1]])], 0)
|
600 |
+
|
601 |
+
|
602 |
+
page = np.concatenate([page2, page1], 1)
|
603 |
+
|
604 |
+
|
605 |
+
return page
|
606 |
+
|
607 |
+
|
608 |
+
|
609 |
+
|
610 |
+
|
611 |
+
|
612 |
+
|
613 |
+
|
614 |
+
|
615 |
+
|
616 |
+
|
617 |
+
|
618 |
+
|
619 |
+
|
620 |
+
|
621 |
+
|
622 |
+
|
623 |
+
|
624 |
+
|
625 |
+
#FEAT1 = self.inception(torch.cat(self.fakes, 0).repeat(1,3,1,1))[0].detach().view(batch_size, len(self.fakes), -1).cpu().numpy()
|
626 |
+
#FEAT2 = self.inception(self.sdata.view(batch_size*NUM_EXAMPLES, 1, 32, -1).repeat(1,3,1,1))[0].detach().view(batch_size, NUM_EXAMPLES, -1 ).cpu().numpy()
|
627 |
+
#muvars1 = [{'mu':np.mean(FEAT1[i], axis=0), 'sigma' : np.cov(FEAT1[i], rowvar=False)} for i in range(FEAT1.shape[0])]
|
628 |
+
#muvars2 = [{'mu':np.mean(FEAT2[i], axis=0), 'sigma' : np.cov(FEAT2[i], rowvar=False)} for i in range(FEAT2.shape[0])]
|
629 |
+
|
630 |
+
|
631 |
+
|
632 |
+
|
633 |
+
|
634 |
+
|
635 |
+
def get_current_losses(self):
|
636 |
+
|
637 |
+
losses = {}
|
638 |
+
|
639 |
+
losses['G'] = self.loss_G
|
640 |
+
losses['D'] = self.loss_D
|
641 |
+
losses['Dfake'] = self.loss_Dfake
|
642 |
+
losses['Dreal'] = self.loss_Dreal
|
643 |
+
losses['OCR_fake'] = self.loss_OCR_fake
|
644 |
+
losses['OCR_real'] = self.loss_OCR_real
|
645 |
+
losses['w_fake'] = self.loss_w_fake
|
646 |
+
losses['w_real'] = self.loss_w_real
|
647 |
+
losses['cycle1'] = self.Lcycle1
|
648 |
+
losses['cycle2'] = self.Lcycle2
|
649 |
+
losses['lda1'] = self.lda1
|
650 |
+
losses['lda2'] = self.lda2
|
651 |
+
losses['KLD'] = self.KLD
|
652 |
+
|
653 |
+
return losses
|
654 |
+
|
655 |
+
def visualize_images(self):
|
656 |
+
|
657 |
+
imgs = {}
|
658 |
+
|
659 |
+
|
660 |
+
imgs['fake-1']=self.netG(self.sdata[0:1], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach()
|
661 |
+
imgs['fake-2']=self.netG(self.sdata[0:1], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
662 |
+
imgs['fake-3']=self.netG(self.sdata[0:1], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
663 |
+
|
664 |
+
|
665 |
+
imgs['res-1'] = torch.cat([self.sdata[0, 0],self.sdata[0, 1],self.sdata[0, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
|
666 |
+
|
667 |
+
|
668 |
+
imgs['fake-1']=self.netG(self.sdata[1:2], self.text_encode_fake[0].unsqueeze(0), mode = 'test' )[0, 0].detach()
|
669 |
+
imgs['fake-2']=self.netG(self.sdata[1:2], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
670 |
+
imgs['fake-3']=self.netG(self.sdata[1:2], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
671 |
+
|
672 |
+
|
673 |
+
imgs['res-2'] = torch.cat([self.sdata[1, 0],self.sdata[1, 1],self.sdata[1, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
|
674 |
+
|
675 |
+
|
676 |
+
imgs['fake-1']=self.netG(self.sdata[2:3], self.text_encode_fake[0].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
677 |
+
imgs['fake-2']=self.netG(self.sdata[2:3], self.text_encode_fake[1].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
678 |
+
imgs['fake-3']=self.netG(self.sdata[2:3], self.text_encode_fake[2].unsqueeze(0) , mode = 'test' )[0, 0].detach()
|
679 |
+
|
680 |
+
|
681 |
+
imgs['res-3'] = torch.cat([self.sdata[2, 0],self.sdata[2, 1],self.sdata[2, 2], imgs['fake-1'], imgs['fake-2'], imgs['fake-3']], -1)
|
682 |
+
|
683 |
+
|
684 |
+
|
685 |
+
|
686 |
+
return imgs
|
687 |
+
|
688 |
+
|
689 |
+
def load_networks(self, epoch):
|
690 |
+
BaseModel.load_networks(self, epoch)
|
691 |
+
if self.opt.single_writer:
|
692 |
+
load_filename = '%s_z.pkl' % (epoch)
|
693 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
694 |
+
self.z = torch.load(load_path)
|
695 |
+
|
696 |
+
def _set_input(self, input):
|
697 |
+
self.input = input
|
698 |
+
|
699 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
700 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
701 |
+
Parameters:
|
702 |
+
nets (network list) -- a list of networks
|
703 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
704 |
+
"""
|
705 |
+
if not isinstance(nets, list):
|
706 |
+
nets = [nets]
|
707 |
+
for net in nets:
|
708 |
+
if net is not None:
|
709 |
+
for param in net.parameters():
|
710 |
+
param.requires_grad = requires_grad
|
711 |
+
|
712 |
+
def forward(self):
|
713 |
+
|
714 |
+
|
715 |
+
self.real = self.input['img'].to(DEVICE)
|
716 |
+
self.label = self.input['label']
|
717 |
+
self.sdata = self.input['simg'].to(DEVICE)
|
718 |
+
self.ST_LEN = self.input['swids']
|
719 |
+
self.text_encode, self.len_text = self.netconverter.encode(self.label)
|
720 |
+
self.one_hot_real = make_one_hot(self.text_encode, self.len_text, VOCAB_SIZE).to(DEVICE).detach()
|
721 |
+
self.text_encode = self.text_encode.to(DEVICE).detach()
|
722 |
+
self.len_text = self.len_text.detach()
|
723 |
+
|
724 |
+
self.words = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)]
|
725 |
+
self.text_encode_fake, self.len_text_fake = self.netconverter.encode(self.words)
|
726 |
+
self.text_encode_fake = self.text_encode_fake.to(DEVICE)
|
727 |
+
self.one_hot_fake = make_one_hot(self.text_encode_fake, self.len_text_fake, VOCAB_SIZE).to(DEVICE)
|
728 |
+
|
729 |
+
|
730 |
+
self.text_encode_fake_js = []
|
731 |
+
|
732 |
+
for _ in range(NUM_WORDS - 1):
|
733 |
+
|
734 |
+
self.words_j = [word.encode('utf-8') for word in np.random.choice(self.lex, batch_size)]
|
735 |
+
self.text_encode_fake_j, self.len_text_fake_j = self.netconverter.encode(self.words_j)
|
736 |
+
self.text_encode_fake_j = self.text_encode_fake_j.to(DEVICE)
|
737 |
+
self.text_encode_fake_js.append(self.text_encode_fake_j)
|
738 |
+
|
739 |
+
|
740 |
+
if IS_CYCLE and IS_KLD:
|
741 |
+
|
742 |
+
self.fake, self.Lcycle1, self.Lcycle2, self.lda1, self.lda2, self.KLD = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
|
743 |
+
|
744 |
+
elif IS_CYCLE and (not IS_KLD):
|
745 |
+
|
746 |
+
self.fake, self.Lcycle1, self.Lcycle2 = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
|
747 |
+
|
748 |
+
elif (not IS_CYCLE) and IS_KLD:
|
749 |
+
|
750 |
+
self.fake, self.lda1, self.KLD = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
|
751 |
+
|
752 |
+
else:
|
753 |
+
|
754 |
+
self.fake = self.netG(self.sdata, self.text_encode_fake, self.text_encode_fake_js)
|
755 |
+
|
756 |
+
|
757 |
+
|
758 |
+
|
759 |
+
def backward_D_OCR(self):
|
760 |
+
|
761 |
+
pred_real = self.netD(self.real.detach())
|
762 |
+
|
763 |
+
pred_fake = self.netD(**{'x': self.fake.detach()})
|
764 |
+
|
765 |
+
|
766 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True)
|
767 |
+
|
768 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
769 |
+
|
770 |
+
self.pred_real_OCR = self.netOCR(self.real.detach())
|
771 |
+
preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * batch_size).detach()
|
772 |
+
loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
|
773 |
+
self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
|
774 |
+
|
775 |
+
loss_total = self.loss_D + self.loss_OCR_real
|
776 |
+
|
777 |
+
# backward
|
778 |
+
loss_total.backward()
|
779 |
+
for param in self.netOCR.parameters():
|
780 |
+
param.grad[param.grad!=param.grad]=0
|
781 |
+
param.grad[torch.isnan(param.grad)]=0
|
782 |
+
param.grad[torch.isinf(param.grad)]=0
|
783 |
+
|
784 |
+
|
785 |
+
|
786 |
+
return loss_total
|
787 |
+
|
788 |
+
def backward_D_WL(self):
|
789 |
+
# Real
|
790 |
+
pred_real = self.netD(self.real.detach())
|
791 |
+
|
792 |
+
pred_fake = self.netD(**{'x': self.fake.detach()})
|
793 |
+
|
794 |
+
|
795 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), True)
|
796 |
+
|
797 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
798 |
+
|
799 |
+
|
800 |
+
self.loss_w_real = self.netW(self.real.detach(), self.input['wcl'].to(DEVICE)).mean()
|
801 |
+
# total loss
|
802 |
+
loss_total = self.loss_D + self.loss_w_real
|
803 |
+
|
804 |
+
# backward
|
805 |
+
loss_total.backward()
|
806 |
+
|
807 |
+
|
808 |
+
return loss_total
|
809 |
+
|
810 |
+
def optimize_D_WL(self):
|
811 |
+
self.forward()
|
812 |
+
self.set_requires_grad([self.netD], True)
|
813 |
+
self.set_requires_grad([self.netOCR], False)
|
814 |
+
self.set_requires_grad([self.netW], True)
|
815 |
+
|
816 |
+
self.optimizer_D.zero_grad()
|
817 |
+
self.optimizer_wl.zero_grad()
|
818 |
+
|
819 |
+
self.backward_D_WL()
|
820 |
+
|
821 |
+
|
822 |
+
|
823 |
+
|
824 |
+
def backward_D_OCR_WL(self):
|
825 |
+
# Real
|
826 |
+
if self.real_z_mean is None:
|
827 |
+
pred_real = self.netD(self.real.detach())
|
828 |
+
else:
|
829 |
+
pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()})
|
830 |
+
# Fake
|
831 |
+
try:
|
832 |
+
pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()})
|
833 |
+
except:
|
834 |
+
print('a')
|
835 |
+
# Combined loss
|
836 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
|
837 |
+
|
838 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
839 |
+
# OCR loss on real data
|
840 |
+
self.pred_real_OCR = self.netOCR(self.real.detach())
|
841 |
+
preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach()
|
842 |
+
loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
|
843 |
+
self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
|
844 |
+
# total loss
|
845 |
+
self.loss_w_real = self.netW(self.real.detach(), self.wcl)
|
846 |
+
loss_total = self.loss_D + self.loss_OCR_real + self.loss_w_real
|
847 |
+
|
848 |
+
# backward
|
849 |
+
loss_total.backward()
|
850 |
+
for param in self.netOCR.parameters():
|
851 |
+
param.grad[param.grad!=param.grad]=0
|
852 |
+
param.grad[torch.isnan(param.grad)]=0
|
853 |
+
param.grad[torch.isinf(param.grad)]=0
|
854 |
+
|
855 |
+
|
856 |
+
|
857 |
+
return loss_total
|
858 |
+
|
859 |
+
def optimize_D_WL_step(self):
|
860 |
+
self.optimizer_D.step()
|
861 |
+
self.optimizer_wl.step()
|
862 |
+
self.optimizer_D.zero_grad()
|
863 |
+
self.optimizer_wl.zero_grad()
|
864 |
+
|
865 |
+
def backward_OCR(self):
|
866 |
+
# OCR loss on real data
|
867 |
+
self.pred_real_OCR = self.netOCR(self.real.detach())
|
868 |
+
preds_size = torch.IntTensor([self.pred_real_OCR.size(0)] * self.opt.batch_size).detach()
|
869 |
+
loss_OCR_real = self.OCR_criterion(self.pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
|
870 |
+
self.loss_OCR_real = torch.mean(loss_OCR_real[~torch.isnan(loss_OCR_real)])
|
871 |
+
|
872 |
+
# backward
|
873 |
+
self.loss_OCR_real.backward()
|
874 |
+
for param in self.netOCR.parameters():
|
875 |
+
param.grad[param.grad!=param.grad]=0
|
876 |
+
param.grad[torch.isnan(param.grad)]=0
|
877 |
+
param.grad[torch.isinf(param.grad)]=0
|
878 |
+
|
879 |
+
return self.loss_OCR_real
|
880 |
+
|
881 |
+
|
882 |
+
def backward_D(self):
|
883 |
+
# Real
|
884 |
+
if self.real_z_mean is None:
|
885 |
+
pred_real = self.netD(self.real.detach())
|
886 |
+
else:
|
887 |
+
pred_real = self.netD(**{'x': self.real.detach(), 'z': self.real_z_mean.detach()})
|
888 |
+
pred_fake = self.netD(**{'x': self.fake.detach(), 'z': self.z.detach()})
|
889 |
+
# Combined loss
|
890 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
|
891 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
892 |
+
# backward
|
893 |
+
self.loss_D.backward()
|
894 |
+
|
895 |
+
|
896 |
+
return self.loss_D
|
897 |
+
|
898 |
+
|
899 |
+
def backward_G_only(self):
|
900 |
+
|
901 |
+
self.gb_alpha = 0.7
|
902 |
+
#self.Lcycle1 = self.Lcycle1.mean()
|
903 |
+
#self.Lcycle2 = self.Lcycle2.mean()
|
904 |
+
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
|
905 |
+
|
906 |
+
|
907 |
+
pred_fake_OCR = self.netOCR(self.fake)
|
908 |
+
preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * batch_size).detach()
|
909 |
+
loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach())
|
910 |
+
self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
|
911 |
+
|
912 |
+
self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD
|
913 |
+
|
914 |
+
self.loss_T = self.loss_G + self.loss_OCR_fake
|
915 |
+
|
916 |
+
|
917 |
+
|
918 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
|
919 |
+
|
920 |
+
|
921 |
+
self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
|
922 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
|
923 |
+
self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
|
924 |
+
|
925 |
+
|
926 |
+
self.loss_T.backward(retain_graph=True)
|
927 |
+
|
928 |
+
|
929 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
|
930 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
|
931 |
+
|
932 |
+
|
933 |
+
a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))
|
934 |
+
|
935 |
+
|
936 |
+
if a is None:
|
937 |
+
print(self.loss_OCR_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
|
938 |
+
if a>1000 or a<0.0001:
|
939 |
+
print(a)
|
940 |
+
|
941 |
+
|
942 |
+
self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
|
943 |
+
|
944 |
+
self.loss_T = self.loss_G + self.loss_OCR_fake
|
945 |
+
|
946 |
+
|
947 |
+
self.loss_T.backward(retain_graph=True)
|
948 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
|
949 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
|
950 |
+
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
|
951 |
+
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
|
952 |
+
|
953 |
+
with torch.no_grad():
|
954 |
+
self.loss_T.backward()
|
955 |
+
|
956 |
+
if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G):
|
957 |
+
print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
|
958 |
+
sys.exit()
|
959 |
+
|
960 |
+
def backward_G_WL(self):
|
961 |
+
|
962 |
+
self.gb_alpha = 0.7
|
963 |
+
#self.Lcycle1 = self.Lcycle1.mean()
|
964 |
+
#self.Lcycle2 = self.Lcycle2.mean()
|
965 |
+
|
966 |
+
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake}), self.len_text_fake.detach(), True).mean()
|
967 |
+
|
968 |
+
self.loss_w_fake = self.netW(self.fake, self.input['wcl'].to(DEVICE)).mean()
|
969 |
+
|
970 |
+
self.loss_G = self.loss_G + self.Lcycle1 + self.Lcycle2 + self.lda1 + self.lda2 - self.KLD
|
971 |
+
|
972 |
+
self.loss_T = self.loss_G + self.loss_w_fake
|
973 |
+
|
974 |
+
|
975 |
+
|
976 |
+
|
977 |
+
#grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, retain_graph=True)[0]
|
978 |
+
|
979 |
+
|
980 |
+
#self.loss_grad_fake_WL = 10**6*torch.mean(grad_fake_WL**2)
|
981 |
+
#grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, retain_graph=True)[0]
|
982 |
+
#self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
|
983 |
+
|
984 |
+
|
985 |
+
|
986 |
+
self.loss_T.backward(retain_graph=True)
|
987 |
+
|
988 |
+
|
989 |
+
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
|
990 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=True, retain_graph=True)[0]
|
991 |
+
|
992 |
+
|
993 |
+
a = self.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_WL))
|
994 |
+
|
995 |
+
|
996 |
+
|
997 |
+
if a is None:
|
998 |
+
print(self.loss_w_fake, self.loss_G, torch.std(grad_fake_adv), torch.std(grad_fake_WL))
|
999 |
+
if a>1000 or a<0.0001:
|
1000 |
+
print(a)
|
1001 |
+
|
1002 |
+
self.loss_w_fake = a.detach() * self.loss_w_fake
|
1003 |
+
|
1004 |
+
self.loss_T = self.loss_G + self.loss_w_fake
|
1005 |
+
|
1006 |
+
self.loss_T.backward(retain_graph=True)
|
1007 |
+
grad_fake_WL = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=False, retain_graph=True)[0]
|
1008 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G, self.fake, create_graph=False, retain_graph=True)[0]
|
1009 |
+
self.loss_grad_fake_WL = 10 ** 6 * torch.mean(grad_fake_WL ** 2)
|
1010 |
+
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
|
1011 |
+
|
1012 |
+
with torch.no_grad():
|
1013 |
+
self.loss_T.backward()
|
1014 |
+
|
1015 |
+
def backward_G(self):
|
1016 |
+
self.opt.gb_alpha = 0.7
|
1017 |
+
self.loss_G = loss_hinge_gen(self.netD(**{'x': self.fake, 'z': self.z}), self.len_text_fake.detach(), self.opt.mask_loss)
|
1018 |
+
# OCR loss on real data
|
1019 |
+
|
1020 |
+
pred_fake_OCR = self.netOCR(self.fake)
|
1021 |
+
preds_size = torch.IntTensor([pred_fake_OCR.size(0)] * self.opt.batch_size).detach()
|
1022 |
+
loss_OCR_fake = self.OCR_criterion(pred_fake_OCR, self.text_encode_fake.detach(), preds_size, self.len_text_fake.detach())
|
1023 |
+
self.loss_OCR_fake = torch.mean(loss_OCR_fake[~torch.isnan(loss_OCR_fake)])
|
1024 |
+
|
1025 |
+
|
1026 |
+
self.loss_w_fake = self.netW(self.fake, self.wcl)
|
1027 |
+
#self.loss_OCR_fake = self.loss_OCR_fake + self.loss_w_fake
|
1028 |
+
# total loss
|
1029 |
+
|
1030 |
+
# l1 = self.params[0]*self.loss_G
|
1031 |
+
# l2 = self.params[0]*self.loss_OCR_fake
|
1032 |
+
#l3 = self.params[0]*self.loss_w_fake
|
1033 |
+
self.loss_G_ = 10*self.loss_G + self.loss_w_fake
|
1034 |
+
self.loss_T = self.loss_G_ + self.loss_OCR_fake
|
1035 |
+
|
1036 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, retain_graph=True)[0]
|
1037 |
+
|
1038 |
+
|
1039 |
+
self.loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
|
1040 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, retain_graph=True)[0]
|
1041 |
+
self.loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
|
1042 |
+
|
1043 |
+
if not False:
|
1044 |
+
|
1045 |
+
self.loss_T.backward(retain_graph=True)
|
1046 |
+
|
1047 |
+
|
1048 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=True, retain_graph=True)[0]
|
1049 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=True, retain_graph=True)[0]
|
1050 |
+
#grad_fake_wl = torch.autograd.grad(self.loss_w_fake, self.fake, create_graph=True, retain_graph=True)[0]
|
1051 |
+
|
1052 |
+
|
1053 |
+
a = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))
|
1054 |
+
|
1055 |
+
|
1056 |
+
#a0 = self.opt.gb_alpha * torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_wl))
|
1057 |
+
|
1058 |
+
if a is None:
|
1059 |
+
print(self.loss_OCR_fake, self.loss_G_, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
|
1060 |
+
if a>1000 or a<0.0001:
|
1061 |
+
print(a)
|
1062 |
+
b = self.opt.gb_alpha * (torch.mean(grad_fake_adv) -
|
1063 |
+
torch.div(torch.std(grad_fake_adv), self.epsilon+torch.std(grad_fake_OCR))*
|
1064 |
+
torch.mean(grad_fake_OCR))
|
1065 |
+
# self.loss_OCR_fake = a.detach() * self.loss_OCR_fake + b.detach() * torch.sum(self.fake)
|
1066 |
+
self.loss_OCR_fake = a.detach() * self.loss_OCR_fake
|
1067 |
+
#self.loss_w_fake = a0.detach() * self.loss_w_fake
|
1068 |
+
|
1069 |
+
self.loss_T = (1-1*self.opt.onlyOCR)*self.loss_G_ + self.loss_OCR_fake# + self.loss_w_fake
|
1070 |
+
self.loss_T.backward(retain_graph=True)
|
1071 |
+
grad_fake_OCR = torch.autograd.grad(self.loss_OCR_fake, self.fake, create_graph=False, retain_graph=True)[0]
|
1072 |
+
grad_fake_adv = torch.autograd.grad(self.loss_G_, self.fake, create_graph=False, retain_graph=True)[0]
|
1073 |
+
self.loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
|
1074 |
+
self.loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
|
1075 |
+
with torch.no_grad():
|
1076 |
+
self.loss_T.backward()
|
1077 |
+
else:
|
1078 |
+
self.loss_T.backward()
|
1079 |
+
|
1080 |
+
if self.opt.clip_grad > 0:
|
1081 |
+
clip_grad_norm_(self.netG.parameters(), self.opt.clip_grad)
|
1082 |
+
if any(torch.isnan(loss_OCR_fake)) or torch.isnan(self.loss_G_):
|
1083 |
+
print('loss OCR fake: ', loss_OCR_fake, ' loss_G: ', self.loss_G, ' words: ', self.words)
|
1084 |
+
sys.exit()
|
1085 |
+
|
1086 |
+
|
1087 |
+
|
1088 |
+
def optimize_D_OCR(self):
|
1089 |
+
self.forward()
|
1090 |
+
self.set_requires_grad([self.netD], True)
|
1091 |
+
self.set_requires_grad([self.netOCR], True)
|
1092 |
+
self.optimizer_D.zero_grad()
|
1093 |
+
#if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
|
1094 |
+
self.optimizer_OCR.zero_grad()
|
1095 |
+
self.backward_D_OCR()
|
1096 |
+
|
1097 |
+
def optimize_OCR(self):
|
1098 |
+
self.forward()
|
1099 |
+
self.set_requires_grad([self.netD], False)
|
1100 |
+
self.set_requires_grad([self.netOCR], True)
|
1101 |
+
if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
|
1102 |
+
self.optimizer_OCR.zero_grad()
|
1103 |
+
self.backward_OCR()
|
1104 |
+
|
1105 |
+
def optimize_D(self):
|
1106 |
+
self.forward()
|
1107 |
+
self.set_requires_grad([self.netD], True)
|
1108 |
+
self.backward_D()
|
1109 |
+
|
1110 |
+
def optimize_D_OCR_step(self):
|
1111 |
+
self.optimizer_D.step()
|
1112 |
+
|
1113 |
+
self.optimizer_OCR.step()
|
1114 |
+
self.optimizer_D.zero_grad()
|
1115 |
+
self.optimizer_OCR.zero_grad()
|
1116 |
+
|
1117 |
+
|
1118 |
+
def optimize_D_OCR_WL(self):
|
1119 |
+
self.forward()
|
1120 |
+
self.set_requires_grad([self.netD], True)
|
1121 |
+
self.set_requires_grad([self.netOCR], True)
|
1122 |
+
self.set_requires_grad([self.netW], True)
|
1123 |
+
self.optimizer_D.zero_grad()
|
1124 |
+
self.optimizer_wl.zero_grad()
|
1125 |
+
if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
|
1126 |
+
self.optimizer_OCR.zero_grad()
|
1127 |
+
self.backward_D_OCR_WL()
|
1128 |
+
|
1129 |
+
def optimize_D_OCR_WL_step(self):
|
1130 |
+
self.optimizer_D.step()
|
1131 |
+
if self.opt.OCR_init in ['glorot', 'xavier', 'ortho', 'N02']:
|
1132 |
+
self.optimizer_OCR.step()
|
1133 |
+
self.optimizer_wl.step()
|
1134 |
+
self.optimizer_D.zero_grad()
|
1135 |
+
self.optimizer_OCR.zero_grad()
|
1136 |
+
self.optimizer_wl.zero_grad()
|
1137 |
+
|
1138 |
+
def optimize_D_step(self):
|
1139 |
+
self.optimizer_D.step()
|
1140 |
+
if any(torch.isnan(self.netD.infer_img.blocks[0][0].conv1.bias)):
|
1141 |
+
print('D is nan')
|
1142 |
+
sys.exit()
|
1143 |
+
self.optimizer_D.zero_grad()
|
1144 |
+
|
1145 |
+
def optimize_G(self):
|
1146 |
+
self.forward()
|
1147 |
+
self.set_requires_grad([self.netD], False)
|
1148 |
+
self.set_requires_grad([self.netOCR], False)
|
1149 |
+
self.set_requires_grad([self.netW], False)
|
1150 |
+
self.backward_G()
|
1151 |
+
|
1152 |
+
def optimize_G_WL(self):
|
1153 |
+
self.forward()
|
1154 |
+
self.set_requires_grad([self.netD], False)
|
1155 |
+
self.set_requires_grad([self.netOCR], False)
|
1156 |
+
self.set_requires_grad([self.netW], False)
|
1157 |
+
self.backward_G_WL()
|
1158 |
+
|
1159 |
+
|
1160 |
+
def optimize_G_only(self):
|
1161 |
+
self.forward()
|
1162 |
+
self.set_requires_grad([self.netD], False)
|
1163 |
+
self.set_requires_grad([self.netOCR], False)
|
1164 |
+
self.set_requires_grad([self.netW], False)
|
1165 |
+
self.backward_G_only()
|
1166 |
+
|
1167 |
+
|
1168 |
+
def optimize_G_step(self):
|
1169 |
+
|
1170 |
+
self.optimizer_G.step()
|
1171 |
+
self.optimizer_G.zero_grad()
|
1172 |
+
|
1173 |
+
def optimize_ocr(self):
|
1174 |
+
self.set_requires_grad([self.netOCR], True)
|
1175 |
+
# OCR loss on real data
|
1176 |
+
pred_real_OCR = self.netOCR(self.real)
|
1177 |
+
preds_size =torch.IntTensor([pred_real_OCR.size(0)] * self.opt.batch_size).detach()
|
1178 |
+
self.loss_OCR_real = self.OCR_criterion(pred_real_OCR, self.text_encode.detach(), preds_size, self.len_text.detach())
|
1179 |
+
self.loss_OCR_real.backward()
|
1180 |
+
self.optimizer_OCR.step()
|
1181 |
+
|
1182 |
+
def optimize_z(self):
|
1183 |
+
self.set_requires_grad([self.z], True)
|
1184 |
+
|
1185 |
+
|
1186 |
+
def optimize_parameters(self):
|
1187 |
+
self.forward()
|
1188 |
+
self.set_requires_grad([self.netD], False)
|
1189 |
+
self.optimizer_G.zero_grad()
|
1190 |
+
self.backward_G()
|
1191 |
+
self.optimizer_G.step()
|
1192 |
+
|
1193 |
+
self.set_requires_grad([self.netD], True)
|
1194 |
+
self.optimizer_D.zero_grad()
|
1195 |
+
self.backward_D()
|
1196 |
+
self.optimizer_D.step()
|
1197 |
+
|
1198 |
+
def test(self):
|
1199 |
+
self.visual_names = ['fake']
|
1200 |
+
self.netG.eval()
|
1201 |
+
with torch.no_grad():
|
1202 |
+
self.forward()
|
1203 |
+
|
1204 |
+
def train_GD(self):
|
1205 |
+
self.netG.train()
|
1206 |
+
self.netD.train()
|
1207 |
+
self.optimizer_G.zero_grad()
|
1208 |
+
self.optimizer_D.zero_grad()
|
1209 |
+
# How many chunks to split x and y into?
|
1210 |
+
x = torch.split(self.real, self.opt.batch_size)
|
1211 |
+
y = torch.split(self.label, self.opt.batch_size)
|
1212 |
+
counter = 0
|
1213 |
+
|
1214 |
+
# Optionally toggle D and G's "require_grad"
|
1215 |
+
if self.opt.toggle_grads:
|
1216 |
+
toggle_grad(self.netD, True)
|
1217 |
+
toggle_grad(self.netG, False)
|
1218 |
+
|
1219 |
+
for step_index in range(self.opt.num_critic_train):
|
1220 |
+
self.optimizer_D.zero_grad()
|
1221 |
+
with torch.set_grad_enabled(False):
|
1222 |
+
self.forward()
|
1223 |
+
D_input = torch.cat([self.fake, x[counter]], 0) if x is not None else self.fake
|
1224 |
+
D_class = torch.cat([self.label_fake, y[counter]], 0) if y[counter] is not None else y[counter]
|
1225 |
+
# Get Discriminator output
|
1226 |
+
D_out = self.netD(D_input, D_class)
|
1227 |
+
if x is not None:
|
1228 |
+
pred_fake, pred_real = torch.split(D_out, [self.fake.shape[0], x[counter].shape[0]]) # D_fake, D_real
|
1229 |
+
else:
|
1230 |
+
pred_fake = D_out
|
1231 |
+
# Combined loss
|
1232 |
+
self.loss_Dreal, self.loss_Dfake = loss_hinge_dis(pred_fake, pred_real, self.len_text_fake.detach(), self.len_text.detach(), self.opt.mask_loss)
|
1233 |
+
self.loss_D = self.loss_Dreal + self.loss_Dfake
|
1234 |
+
self.loss_D.backward()
|
1235 |
+
counter += 1
|
1236 |
+
self.optimizer_D.step()
|
1237 |
+
|
1238 |
+
# Optionally toggle D and G's "require_grad"
|
1239 |
+
if self.opt.toggle_grads:
|
1240 |
+
toggle_grad(self.netD, False)
|
1241 |
+
toggle_grad(self.netG, True)
|
1242 |
+
# Zero G's gradients by default before training G, for safety
|
1243 |
+
self.optimizer_G.zero_grad()
|
1244 |
+
self.forward()
|
1245 |
+
self.loss_G = loss_hinge_gen(self.netD(self.fake, self.label_fake), self.len_text_fake.detach(), self.opt.mask_loss)
|
1246 |
+
self.loss_G.backward()
|
1247 |
+
self.optimizer_G.step()
|
1248 |
+
|
1249 |
+
|
1250 |
+
|
1251 |
+
|
1252 |
+
|
1253 |
+
|
1254 |
+
|
1255 |
+
|
1256 |
+
|
1257 |
+
|
1258 |
+
|
1259 |
+
|
1260 |
+
|
1261 |
+
|
1262 |
+
|
1263 |
+
|
1264 |
+
|
util/models/networks.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import init
|
4 |
+
import functools
|
5 |
+
from torch.optim import lr_scheduler
|
6 |
+
from util.util import to_device, load_network
|
7 |
+
|
8 |
+
###############################################################################
|
9 |
+
# Helper Functions
|
10 |
+
###############################################################################
|
11 |
+
|
12 |
+
|
13 |
+
def init_weights(net, init_type='normal', init_gain=0.02):
|
14 |
+
"""Initialize network weights.
|
15 |
+
|
16 |
+
Parameters:
|
17 |
+
net (network) -- network to be initialized
|
18 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
19 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
20 |
+
|
21 |
+
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
|
22 |
+
work better for some applications. Feel free to try yourself.
|
23 |
+
"""
|
24 |
+
def init_func(m): # define the initialization function
|
25 |
+
classname = m.__class__.__name__
|
26 |
+
if (isinstance(m, nn.Conv2d)
|
27 |
+
or isinstance(m, nn.Linear)
|
28 |
+
or isinstance(m, nn.Embedding)):
|
29 |
+
# if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
30 |
+
if init_type == 'N02':
|
31 |
+
init.normal_(m.weight.data, 0.0, init_gain)
|
32 |
+
elif init_type in ['glorot', 'xavier']:
|
33 |
+
init.xavier_normal_(m.weight.data, gain=init_gain)
|
34 |
+
elif init_type == 'kaiming':
|
35 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
36 |
+
elif init_type == 'ortho':
|
37 |
+
init.orthogonal_(m.weight.data, gain=init_gain)
|
38 |
+
else:
|
39 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
40 |
+
# if hasattr(m, 'bias') and m.bias is not None:
|
41 |
+
# init.constant_(m.bias.data, 0.0)
|
42 |
+
# elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
|
43 |
+
# init.normal_(m.weight.data, 1.0, init_gain)
|
44 |
+
# init.constant_(m.bias.data, 0.0)
|
45 |
+
if init_type in ['N02', 'glorot', 'xavier', 'kaiming', 'ortho']:
|
46 |
+
print('initialize network with %s' % init_type)
|
47 |
+
net.apply(init_func) # apply the initialization function <init_func>
|
48 |
+
else:
|
49 |
+
print('loading the model from %s' % init_type)
|
50 |
+
net = load_network(net, init_type, 'latest')
|
51 |
+
return net
|
52 |
+
|
53 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
|
54 |
+
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
|
55 |
+
Parameters:
|
56 |
+
net (network) -- the network to be initialized
|
57 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
58 |
+
gain (float) -- scaling factor for normal, xavier and orthogonal.
|
59 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
60 |
+
|
61 |
+
Return an initialized network.
|
62 |
+
"""
|
63 |
+
if len(gpu_ids) > 0:
|
64 |
+
assert(torch.cuda.is_available())
|
65 |
+
net.to(gpu_ids[0])
|
66 |
+
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
|
67 |
+
init_weights(net, init_type, init_gain=init_gain)
|
68 |
+
return net
|
69 |
+
|
70 |
+
|
71 |
+
def get_scheduler(optimizer, opt):
|
72 |
+
"""Return a learning rate scheduler
|
73 |
+
|
74 |
+
Parameters:
|
75 |
+
optimizer -- the optimizer of the network
|
76 |
+
opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
|
77 |
+
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
|
78 |
+
|
79 |
+
For 'linear', we keep the same learning rate for the first <opt.niter> epochs
|
80 |
+
and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
|
81 |
+
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
|
82 |
+
See https://pytorch.org/docs/stable/optim.html for more details.
|
83 |
+
"""
|
84 |
+
if opt.lr_policy == 'linear':
|
85 |
+
def lambda_rule(epoch):
|
86 |
+
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
|
87 |
+
return lr_l
|
88 |
+
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
|
89 |
+
elif opt.lr_policy == 'step':
|
90 |
+
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
|
91 |
+
elif opt.lr_policy == 'plateau':
|
92 |
+
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
|
93 |
+
elif opt.lr_policy == 'cosine':
|
94 |
+
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
|
95 |
+
else:
|
96 |
+
return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
|
97 |
+
return scheduler
|
98 |
+
|
util/models/sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : __init__.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
|
12 |
+
from .replicate import DataParallelWithCallback, patch_replication_callback
|
util/models/sync_batchnorm/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (380 Bytes). View file
|
|
util/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (384 Bytes). View file
|
|
util/models/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc
ADDED
Binary file (13.1 kB). View file
|
|
util/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc
ADDED
Binary file (13 kB). View file
|
|
util/models/sync_batchnorm/__pycache__/comm.cpython-36.pyc
ADDED
Binary file (4.77 kB). View file
|
|
util/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc
ADDED
Binary file (4.77 kB). View file
|
|
util/models/sync_batchnorm/__pycache__/replicate.cpython-36.pyc
ADDED
Binary file (3.44 kB). View file
|
|
util/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc
ADDED
Binary file (3.45 kB). View file
|
|
util/models/sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import collections
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
17 |
+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
18 |
+
|
19 |
+
from .comm import SyncMaster
|
20 |
+
|
21 |
+
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
|
22 |
+
|
23 |
+
|
24 |
+
def _sum_ft(tensor):
|
25 |
+
"""sum over the first and last dimention"""
|
26 |
+
return tensor.sum(dim=0).sum(dim=-1)
|
27 |
+
|
28 |
+
|
29 |
+
def _unsqueeze_ft(tensor):
|
30 |
+
"""add new dementions at the front and the tail"""
|
31 |
+
return tensor.unsqueeze(0).unsqueeze(-1)
|
32 |
+
|
33 |
+
|
34 |
+
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
35 |
+
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
36 |
+
# _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size'])
|
37 |
+
|
38 |
+
class _SynchronizedBatchNorm(_BatchNorm):
|
39 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
|
40 |
+
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
|
41 |
+
|
42 |
+
self._sync_master = SyncMaster(self._data_parallel_master)
|
43 |
+
|
44 |
+
self._is_parallel = False
|
45 |
+
self._parallel_id = None
|
46 |
+
self._slave_pipe = None
|
47 |
+
|
48 |
+
def forward(self, input, gain=None, bias=None):
|
49 |
+
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
50 |
+
if not (self._is_parallel and self.training):
|
51 |
+
out = F.batch_norm(
|
52 |
+
input, self.running_mean, self.running_var, self.weight, self.bias,
|
53 |
+
self.training, self.momentum, self.eps)
|
54 |
+
if gain is not None:
|
55 |
+
out = out + gain
|
56 |
+
if bias is not None:
|
57 |
+
out = out + bias
|
58 |
+
return out
|
59 |
+
|
60 |
+
# Resize the input to (B, C, -1).
|
61 |
+
input_shape = input.size()
|
62 |
+
# print(input_shape)
|
63 |
+
input = input.view(input.size(0), input.size(1), -1)
|
64 |
+
|
65 |
+
# Compute the sum and square-sum.
|
66 |
+
sum_size = input.size(0) * input.size(2)
|
67 |
+
input_sum = _sum_ft(input)
|
68 |
+
input_ssum = _sum_ft(input ** 2)
|
69 |
+
# Reduce-and-broadcast the statistics.
|
70 |
+
# print('it begins')
|
71 |
+
if self._parallel_id == 0:
|
72 |
+
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
73 |
+
else:
|
74 |
+
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
75 |
+
# if self._parallel_id == 0:
|
76 |
+
# # print('here')
|
77 |
+
# sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
78 |
+
# else:
|
79 |
+
# # print('there')
|
80 |
+
# sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
81 |
+
|
82 |
+
# print('how2')
|
83 |
+
# num = sum_size
|
84 |
+
# print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu())))
|
85 |
+
# Fix the graph
|
86 |
+
# sum = (sum.detach() - input_sum.detach()) + input_sum
|
87 |
+
# ssum = (ssum.detach() - input_ssum.detach()) + input_ssum
|
88 |
+
|
89 |
+
# mean = sum / num
|
90 |
+
# var = ssum / num - mean ** 2
|
91 |
+
# # var = (ssum - mean * sum) / num
|
92 |
+
# inv_std = torch.rsqrt(var + self.eps)
|
93 |
+
|
94 |
+
# Compute the output.
|
95 |
+
if gain is not None:
|
96 |
+
# print('gaining')
|
97 |
+
# scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1)
|
98 |
+
# shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1)
|
99 |
+
# output = input * scale - shift
|
100 |
+
output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1)
|
101 |
+
elif self.affine:
|
102 |
+
# MJY:: Fuse the multiplication for speed.
|
103 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
104 |
+
else:
|
105 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
106 |
+
|
107 |
+
# Reshape it.
|
108 |
+
return output.view(input_shape)
|
109 |
+
|
110 |
+
def __data_parallel_replicate__(self, ctx, copy_id):
|
111 |
+
self._is_parallel = True
|
112 |
+
self._parallel_id = copy_id
|
113 |
+
|
114 |
+
# parallel_id == 0 means master device.
|
115 |
+
if self._parallel_id == 0:
|
116 |
+
ctx.sync_master = self._sync_master
|
117 |
+
else:
|
118 |
+
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
119 |
+
|
120 |
+
def _data_parallel_master(self, intermediates):
|
121 |
+
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
122 |
+
|
123 |
+
# Always using same "device order" makes the ReduceAdd operation faster.
|
124 |
+
# Thanks to:: Tete Xiao (http://tetexiao.com/)
|
125 |
+
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
126 |
+
|
127 |
+
to_reduce = [i[1][:2] for i in intermediates]
|
128 |
+
to_reduce = [j for i in to_reduce for j in i] # flatten
|
129 |
+
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
130 |
+
|
131 |
+
sum_size = sum([i[1].sum_size for i in intermediates])
|
132 |
+
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
133 |
+
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
134 |
+
|
135 |
+
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
136 |
+
# print('a')
|
137 |
+
# print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size)
|
138 |
+
# broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device))
|
139 |
+
# print('b')
|
140 |
+
outputs = []
|
141 |
+
for i, rec in enumerate(intermediates):
|
142 |
+
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
|
143 |
+
# outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3])))
|
144 |
+
|
145 |
+
return outputs
|
146 |
+
|
147 |
+
def _compute_mean_std(self, sum_, ssum, size):
|
148 |
+
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
149 |
+
also maintains the moving average on the master device."""
|
150 |
+
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
151 |
+
mean = sum_ / size
|
152 |
+
sumvar = ssum - sum_ * mean
|
153 |
+
unbias_var = sumvar / (size - 1)
|
154 |
+
bias_var = sumvar / size
|
155 |
+
|
156 |
+
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
157 |
+
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
158 |
+
return mean, torch.rsqrt(bias_var + self.eps)
|
159 |
+
# return mean, bias_var.clamp(self.eps) ** -0.5
|
160 |
+
|
161 |
+
|
162 |
+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
163 |
+
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
164 |
+
mini-batch.
|
165 |
+
|
166 |
+
.. math::
|
167 |
+
|
168 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
169 |
+
|
170 |
+
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
171 |
+
standard-deviation are reduced across all devices during training.
|
172 |
+
|
173 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
174 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
175 |
+
the statistics only on that device, which accelerated the computation and
|
176 |
+
is also easy to implement, but the statistics might be inaccurate.
|
177 |
+
Instead, in this synchronized version, the statistics will be computed
|
178 |
+
over all training samples distributed on multiple devices.
|
179 |
+
|
180 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
181 |
+
as the built-in PyTorch implementation.
|
182 |
+
|
183 |
+
The mean and standard-deviation are calculated per-dimension over
|
184 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
185 |
+
of size C (where C is the input size).
|
186 |
+
|
187 |
+
During training, this layer keeps a running estimate of its computed mean
|
188 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
189 |
+
|
190 |
+
During evaluation, this running mean/variance is used for normalization.
|
191 |
+
|
192 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
193 |
+
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
194 |
+
|
195 |
+
Args:
|
196 |
+
num_features: num_features from an expected input of size
|
197 |
+
`batch_size x num_features [x width]`
|
198 |
+
eps: a value added to the denominator for numerical stability.
|
199 |
+
Default: 1e-5
|
200 |
+
momentum: the value used for the running_mean and running_var
|
201 |
+
computation. Default: 0.1
|
202 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
203 |
+
affine parameters. Default: ``True``
|
204 |
+
|
205 |
+
Shape:
|
206 |
+
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
207 |
+
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
208 |
+
|
209 |
+
Examples:
|
210 |
+
>>> # With Learnable Parameters
|
211 |
+
>>> m = SynchronizedBatchNorm1d(100)
|
212 |
+
>>> # Without Learnable Parameters
|
213 |
+
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
214 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
215 |
+
>>> output = m(input)
|
216 |
+
"""
|
217 |
+
|
218 |
+
def _check_input_dim(self, input):
|
219 |
+
if input.dim() != 2 and input.dim() != 3:
|
220 |
+
raise ValueError('expected 2D or 3D input (got {}D input)'
|
221 |
+
.format(input.dim()))
|
222 |
+
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
|
223 |
+
|
224 |
+
|
225 |
+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
|
226 |
+
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
|
227 |
+
of 3d inputs
|
228 |
+
|
229 |
+
.. math::
|
230 |
+
|
231 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
232 |
+
|
233 |
+
This module differs from the built-in PyTorch BatchNorm2d as the mean and
|
234 |
+
standard-deviation are reduced across all devices during training.
|
235 |
+
|
236 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
237 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
238 |
+
the statistics only on that device, which accelerated the computation and
|
239 |
+
is also easy to implement, but the statistics might be inaccurate.
|
240 |
+
Instead, in this synchronized version, the statistics will be computed
|
241 |
+
over all training samples distributed on multiple devices.
|
242 |
+
|
243 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
244 |
+
as the built-in PyTorch implementation.
|
245 |
+
|
246 |
+
The mean and standard-deviation are calculated per-dimension over
|
247 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
248 |
+
of size C (where C is the input size).
|
249 |
+
|
250 |
+
During training, this layer keeps a running estimate of its computed mean
|
251 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
252 |
+
|
253 |
+
During evaluation, this running mean/variance is used for normalization.
|
254 |
+
|
255 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
256 |
+
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
|
257 |
+
|
258 |
+
Args:
|
259 |
+
num_features: num_features from an expected input of
|
260 |
+
size batch_size x num_features x height x width
|
261 |
+
eps: a value added to the denominator for numerical stability.
|
262 |
+
Default: 1e-5
|
263 |
+
momentum: the value used for the running_mean and running_var
|
264 |
+
computation. Default: 0.1
|
265 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
266 |
+
affine parameters. Default: ``True``
|
267 |
+
|
268 |
+
Shape:
|
269 |
+
- Input: :math:`(N, C, H, W)`
|
270 |
+
- Output: :math:`(N, C, H, W)` (same shape as input)
|
271 |
+
|
272 |
+
Examples:
|
273 |
+
>>> # With Learnable Parameters
|
274 |
+
>>> m = SynchronizedBatchNorm2d(100)
|
275 |
+
>>> # Without Learnable Parameters
|
276 |
+
>>> m = SynchronizedBatchNorm2d(100, affine=False)
|
277 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
|
278 |
+
>>> output = m(input)
|
279 |
+
"""
|
280 |
+
|
281 |
+
def _check_input_dim(self, input):
|
282 |
+
if input.dim() != 4:
|
283 |
+
raise ValueError('expected 4D input (got {}D input)'
|
284 |
+
.format(input.dim()))
|
285 |
+
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
|
286 |
+
|
287 |
+
|
288 |
+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
|
289 |
+
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
|
290 |
+
of 4d inputs
|
291 |
+
|
292 |
+
.. math::
|
293 |
+
|
294 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
295 |
+
|
296 |
+
This module differs from the built-in PyTorch BatchNorm3d as the mean and
|
297 |
+
standard-deviation are reduced across all devices during training.
|
298 |
+
|
299 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
300 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
301 |
+
the statistics only on that device, which accelerated the computation and
|
302 |
+
is also easy to implement, but the statistics might be inaccurate.
|
303 |
+
Instead, in this synchronized version, the statistics will be computed
|
304 |
+
over all training samples distributed on multiple devices.
|
305 |
+
|
306 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
307 |
+
as the built-in PyTorch implementation.
|
308 |
+
|
309 |
+
The mean and standard-deviation are calculated per-dimension over
|
310 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
311 |
+
of size C (where C is the input size).
|
312 |
+
|
313 |
+
During training, this layer keeps a running estimate of its computed mean
|
314 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
315 |
+
|
316 |
+
During evaluation, this running mean/variance is used for normalization.
|
317 |
+
|
318 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
319 |
+
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
|
320 |
+
or Spatio-temporal BatchNorm
|
321 |
+
|
322 |
+
Args:
|
323 |
+
num_features: num_features from an expected input of
|
324 |
+
size batch_size x num_features x depth x height x width
|
325 |
+
eps: a value added to the denominator for numerical stability.
|
326 |
+
Default: 1e-5
|
327 |
+
momentum: the value used for the running_mean and running_var
|
328 |
+
computation. Default: 0.1
|
329 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
330 |
+
affine parameters. Default: ``True``
|
331 |
+
|
332 |
+
Shape:
|
333 |
+
- Input: :math:`(N, C, D, H, W)`
|
334 |
+
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
335 |
+
|
336 |
+
Examples:
|
337 |
+
>>> # With Learnable Parameters
|
338 |
+
>>> m = SynchronizedBatchNorm3d(100)
|
339 |
+
>>> # Without Learnable Parameters
|
340 |
+
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
341 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
342 |
+
>>> output = m(input)
|
343 |
+
"""
|
344 |
+
|
345 |
+
def _check_input_dim(self, input):
|
346 |
+
if input.dim() != 5:
|
347 |
+
raise ValueError('expected 5D input (got {}D input)'
|
348 |
+
.format(input.dim()))
|
349 |
+
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
|
util/models/sync_batchnorm/batchnorm_reimpl.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# File : batchnorm_reimpl.py
|
4 |
+
# Author : acgtyrant
|
5 |
+
# Date : 11/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.init as init
|
14 |
+
|
15 |
+
__all__ = ['BatchNormReimpl']
|
16 |
+
|
17 |
+
|
18 |
+
class BatchNorm2dReimpl(nn.Module):
|
19 |
+
"""
|
20 |
+
A re-implementation of batch normalization, used for testing the numerical
|
21 |
+
stability.
|
22 |
+
|
23 |
+
Author: acgtyrant
|
24 |
+
See also:
|
25 |
+
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
|
26 |
+
"""
|
27 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.num_features = num_features
|
31 |
+
self.eps = eps
|
32 |
+
self.momentum = momentum
|
33 |
+
self.weight = nn.Parameter(torch.empty(num_features))
|
34 |
+
self.bias = nn.Parameter(torch.empty(num_features))
|
35 |
+
self.register_buffer('running_mean', torch.zeros(num_features))
|
36 |
+
self.register_buffer('running_var', torch.ones(num_features))
|
37 |
+
self.reset_parameters()
|
38 |
+
|
39 |
+
def reset_running_stats(self):
|
40 |
+
self.running_mean.zero_()
|
41 |
+
self.running_var.fill_(1)
|
42 |
+
|
43 |
+
def reset_parameters(self):
|
44 |
+
self.reset_running_stats()
|
45 |
+
init.uniform_(self.weight)
|
46 |
+
init.zeros_(self.bias)
|
47 |
+
|
48 |
+
def forward(self, input_):
|
49 |
+
batchsize, channels, height, width = input_.size()
|
50 |
+
numel = batchsize * height * width
|
51 |
+
input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
|
52 |
+
sum_ = input_.sum(1)
|
53 |
+
sum_of_square = input_.pow(2).sum(1)
|
54 |
+
mean = sum_ / numel
|
55 |
+
sumvar = sum_of_square - sum_ * mean
|
56 |
+
|
57 |
+
self.running_mean = (
|
58 |
+
(1 - self.momentum) * self.running_mean
|
59 |
+
+ self.momentum * mean.detach()
|
60 |
+
)
|
61 |
+
unbias_var = sumvar / (numel - 1)
|
62 |
+
self.running_var = (
|
63 |
+
(1 - self.momentum) * self.running_var
|
64 |
+
+ self.momentum * unbias_var.detach()
|
65 |
+
)
|
66 |
+
|
67 |
+
bias_var = sumvar / numel
|
68 |
+
inv_std = 1 / (bias_var + self.eps).pow(0.5)
|
69 |
+
output = (
|
70 |
+
(input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
|
71 |
+
self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
|
72 |
+
|
73 |
+
return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
|
74 |
+
|
util/models/sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : comm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import queue
|
12 |
+
import collections
|
13 |
+
import threading
|
14 |
+
|
15 |
+
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
16 |
+
|
17 |
+
|
18 |
+
class FutureResult(object):
|
19 |
+
"""A thread-safe future implementation. Used only as one-to-one pipe."""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
self._result = None
|
23 |
+
self._lock = threading.Lock()
|
24 |
+
self._cond = threading.Condition(self._lock)
|
25 |
+
|
26 |
+
def put(self, result):
|
27 |
+
with self._lock:
|
28 |
+
assert self._result is None, 'Previous result has\'t been fetched.'
|
29 |
+
self._result = result
|
30 |
+
self._cond.notify()
|
31 |
+
|
32 |
+
def get(self):
|
33 |
+
with self._lock:
|
34 |
+
if self._result is None:
|
35 |
+
self._cond.wait()
|
36 |
+
|
37 |
+
res = self._result
|
38 |
+
self._result = None
|
39 |
+
return res
|
40 |
+
|
41 |
+
|
42 |
+
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
|
43 |
+
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
|
44 |
+
|
45 |
+
|
46 |
+
class SlavePipe(_SlavePipeBase):
|
47 |
+
"""Pipe for master-slave communication."""
|
48 |
+
|
49 |
+
def run_slave(self, msg):
|
50 |
+
self.queue.put((self.identifier, msg))
|
51 |
+
ret = self.result.get()
|
52 |
+
self.queue.put(True)
|
53 |
+
return ret
|
54 |
+
|
55 |
+
|
56 |
+
class SyncMaster(object):
|
57 |
+
"""An abstract `SyncMaster` object.
|
58 |
+
|
59 |
+
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should
|
60 |
+
call `register(id)` and obtain an `SlavePipe` to communicate with the master.
|
61 |
+
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
|
62 |
+
and passed to a registered callback.
|
63 |
+
- After receiving the messages, the master device should gather the information and determine to message passed
|
64 |
+
back to each slave devices.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, master_callback):
|
68 |
+
"""
|
69 |
+
|
70 |
+
Args:
|
71 |
+
master_callback: a callback to be invoked after having collected messages from slave devices.
|
72 |
+
"""
|
73 |
+
self._master_callback = master_callback
|
74 |
+
self._queue = queue.Queue()
|
75 |
+
self._registry = collections.OrderedDict()
|
76 |
+
self._activated = False
|
77 |
+
|
78 |
+
def __getstate__(self):
|
79 |
+
return {'master_callback': self._master_callback}
|
80 |
+
|
81 |
+
def __setstate__(self, state):
|
82 |
+
self.__init__(state['master_callback'])
|
83 |
+
|
84 |
+
def register_slave(self, identifier):
|
85 |
+
"""
|
86 |
+
Register an slave device.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
identifier: an identifier, usually is the device id.
|
90 |
+
|
91 |
+
Returns: a `SlavePipe` object which can be used to communicate with the master device.
|
92 |
+
|
93 |
+
"""
|
94 |
+
if self._activated:
|
95 |
+
assert self._queue.empty(), 'Queue is not clean before next initialization.'
|
96 |
+
self._activated = False
|
97 |
+
self._registry.clear()
|
98 |
+
future = FutureResult()
|
99 |
+
self._registry[identifier] = _MasterRegistry(future)
|
100 |
+
return SlavePipe(identifier, self._queue, future)
|
101 |
+
|
102 |
+
def run_master(self, master_msg):
|
103 |
+
"""
|
104 |
+
Main entry for the master device in each forward pass.
|
105 |
+
The messages were first collected from each devices (including the master device), and then
|
106 |
+
an callback will be invoked to compute the message to be sent back to each devices
|
107 |
+
(including the master device).
|
108 |
+
|
109 |
+
Args:
|
110 |
+
master_msg: the message that the master want to send to itself. This will be placed as the first
|
111 |
+
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
|
112 |
+
|
113 |
+
Returns: the message to be sent back to the master device.
|
114 |
+
|
115 |
+
"""
|
116 |
+
self._activated = True
|
117 |
+
|
118 |
+
intermediates = [(0, master_msg)]
|
119 |
+
for i in range(self.nr_slaves):
|
120 |
+
intermediates.append(self._queue.get())
|
121 |
+
|
122 |
+
results = self._master_callback(intermediates)
|
123 |
+
assert results[0][0] == 0, 'The first result should belongs to the master.'
|
124 |
+
|
125 |
+
for i, res in results:
|
126 |
+
if i == 0:
|
127 |
+
continue
|
128 |
+
self._registry[i].result.put(res)
|
129 |
+
|
130 |
+
for i in range(self.nr_slaves):
|
131 |
+
assert self._queue.get() is True
|
132 |
+
|
133 |
+
return results[0][1]
|
134 |
+
|
135 |
+
@property
|
136 |
+
def nr_slaves(self):
|
137 |
+
return len(self._registry)
|