Emaad committed on
Commit 548170b • 1 Parent(s): 65b1781

file upload

README.md CHANGED
@@ -1,12 +1,15 @@
  ---
- title: CELL-E 2-Sequence Prediction
- emoji: 💻
- colorFrom: blue
- colorTo: green
  sdk: gradio
- sdk_version: 3.29.0
  app_file: app.py
- pinned: false
  license: mit
  ---

  ---
+ title: CELL-E 2 - Sequence Prediction
+ emoji: 🔬
+ colorFrom: red
+ colorTo: purple
  sdk: gradio
+ python_version: 3.11
+ sdk_version: 3.30.0
  app_file: app.py
+ tags: [proteins, image-to-text]
+ fullWidth: true
+ pinned: true
  license: mit
  ---
app.py ADDED
@@ -0,0 +1,123 @@
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ from prediction import run_sequence_prediction
+ import torch
+ import torchvision.transforms as T
+ from celle.utils import process_image
+ from PIL import Image
+ from matplotlib import pyplot as plt
+
+
+ def gradio_demo(model_name, sequence_input, image):
+     model = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
+     config = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
+     hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
+     hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     if "Finetuned" in model_name:
+         dataset = "OpenCell"
+     else:
+         dataset = "HPA"
+
+     nucleus_image = image["image"]
+     protein_image = image["mask"]
+
+     nucleus_image = process_image(nucleus_image, dataset, "nucleus")
+     protein_image = process_image(protein_image, dataset, "nucleus")
+     protein_image = 1.0 * (protein_image > 0.5)
+     print(f"{nucleus_image=}")
+     print(f"{protein_image.shape=}")
+
+     threshold, heatmap = run_sequence_prediction(
+         sequence_input=sequence_input,
+         nucleus_image=nucleus_image,
+         protein_image=protein_image,
+         model_ckpt_path=model,
+         model_config_path=config,
+         device=device,
+     )
+
+     protein_image = protein_image[0, 0]
+     protein_image = protein_image * 1.0
+
+     # Plot the heatmap
+     plt.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
+     plt.axis("off")
+
+     # Save the plot to a temporary file
+     plt.savefig("temp.png", bbox_inches="tight", dpi=256)
+
+     # Open the temporary file as a PIL image
+     heatmap = Image.open("temp.png")
+
+     return (
+         T.ToPILImage()(nucleus_image[0, 0]),
+         T.ToPILImage()(protein_image),
+         T.ToPILImage()(threshold),
+         heatmap,
+     )
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("Select the prediction model.")
+     gr.Markdown(
+         "CELL-E_2_HPA_2560 is a good general-purpose model for various cell types using ICC-IF."
+     )
+     gr.Markdown(
+         "CELL-E_2_OpenCell_2560 is trained on OpenCell and is better suited for live-cell predictions on HEK cells."
+     )
+     with gr.Row():
+         model_name = gr.Dropdown(
+             ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
+             value="CELL-E_2_HPA_2560",
+             label="Model Name",
+         )
+     with gr.Row():
+         gr.Markdown(
+             "Input the desired amino acid sequence. GFP is shown below by default."
+         )
+
+     with gr.Row():
+         sequence_input = gr.Textbox(
+             value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
+             label="Sequence",
+         )
+     with gr.Row():
+         gr.Markdown(
+             "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images)."
+         )
+         gr.Markdown("The protein image is optional and is only used for display.")
+
+     with gr.Row().style(equal_height=True):
+         nucleus_image = gr.Image(
+             source="upload",
+             tool="sketch",
+             label="Nucleus Image",
+             line_color="white",
+             interactive=True,
+             image_mode="L",
+             type="pil",
+         )
+
+     with gr.Row():
+         gr.Markdown("Image predictions are shown below.")
+
+     with gr.Row().style(equal_height=True):
+         predicted_sequence = gr.Textbox(
+             label="Predicted Sequence",
+         )
+
+     with gr.Row():
+         button = gr.Button("Run Model")
+
+     inputs = [model_name, sequence_input, nucleus_image]
+
+     outputs = [predicted_sequence]
+
+     button.click(gradio_demo, inputs, outputs)
+
+ demo.launch(share=True)
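For quick local testing, the following is a minimal sketch (not part of this commit) of driving the same prediction path as gradio_demo() above without the Gradio UI. It only reuses names visible in this diff (the hf_hub_download filenames, process_image, and run_sequence_prediction with its keyword arguments); the input image path, the toy sequence, and the all-zero protein channel are placeholders or assumptions, since the demo treats the protein image as display-only.

# Hypothetical standalone driver mirroring gradio_demo() above; not part of this commit.
import torch
from PIL import Image
from huggingface_hub import hf_hub_download

from celle.utils import process_image
from prediction import run_sequence_prediction

model_name = "CELL-E_2_HPA_2560"  # one of the models offered in the dropdown above
ckpt = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="model.ckpt")
config = hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="config.yaml")
# VQGAN configs cached alongside the checkpoint, as in gradio_demo()
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# "my_nucleus.png" is a placeholder path; the demo feeds a PIL image here.
nucleus = process_image(Image.open("my_nucleus.png"), "HPA", "nucleus")
protein = torch.zeros_like(nucleus)  # assumed: the demo only uses this channel for display

threshold, heatmap = run_sequence_prediction(
    sequence_input="MSKGEELFTGVVPILVELDGDVNGHK",  # placeholder amino acid string
    nucleus_image=nucleus,
    protein_image=protein,
    model_ckpt_path=ckpt,
    model_config_path=config,
    device=device,
)
# The demo converts `threshold` with T.ToPILImage() and plots `heatmap` with matplotlib.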
celle/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from celle.celle import CELLE
+ from celle.vae import VQGanVAE
+
+ __version__ = "2.0.0"
celle/attention.py ADDED
@@ -0,0 +1,253 @@
1
+ import torch
2
+ from torch import nn, einsum
3
+ import torch.nn.functional as F
4
+ from einops import rearrange, repeat
5
+
6
+ from rotary_embedding_torch import apply_rotary_emb
7
+ from celle.utils import exists, default, max_neg_value
8
+
9
+
10
+ # helpers
11
+ def stable_softmax(t, dim=-1, alpha=32**2):
12
+ t = t / alpha
13
+ t = t - torch.amax(t, dim=dim, keepdim=True).detach()
14
+ return (t * alpha).softmax(dim=dim)
15
+
16
+
17
+ def apply_pos_emb(pos_emb, qkv):
18
+ n = qkv[0].shape[-2]
19
+ pos_emb = pos_emb[..., :n, :]
20
+ return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))
21
+
22
+
23
+ # classes
24
+ class Attention(nn.Module):
25
+ def __init__(
26
+ self,
27
+ dim,
28
+ seq_len,
29
+ causal=False,
30
+ heads=8,
31
+ dim_head=64,
32
+ dropout=0.0,
33
+ stable=False,
34
+ static_mask=None,
35
+ ):
36
+ super().__init__()
37
+ inner_dim = dim_head * heads
38
+ self.heads = heads
39
+ self.seq_len = seq_len
40
+ self.scale = dim_head**-0.5
41
+ self.stable = stable
42
+ self.causal = causal
43
+ self.register_buffer("static_mask", static_mask, persistent=False)
44
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
45
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
46
+ self.save_attn = nn.Identity()
47
+
48
+ def forward(self, x, context_mask=None, rotary_pos_emb=None):
49
+ # x: [batch_size, seq_len, dim]
50
+ b, n, _, h = *x.shape, self.heads
51
+ device = x.device
52
+
53
+ softmax = torch.softmax if not self.stable else stable_softmax
54
+
55
+ # qkv: 3 tensors of shape [batch_size, seq_len, inner_dim]
56
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
57
+
58
+ # q,k,v: [batch_size, heads, seq_len, dim_head]
59
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
60
+
61
+ if exists(rotary_pos_emb):
62
+ q, k, v = apply_pos_emb(rotary_pos_emb[..., :, :], (q, k, v))
63
+
64
+ q *= self.scale
65
+
66
+ # dots: [batch_size, heads, seq_len_i ,seq_len_j]
67
+ dots = torch.einsum("b h i d, b h j d -> b h i j", q, k)
68
+ mask_value = max_neg_value(dots)
69
+
70
+ if exists(context_mask):
71
+ # context_mask: [batch_size ,1 ,1 ,seq_len_j]
72
+ context_mask = rearrange(context_mask, "b j -> b 1 1 j")
73
+ context_mask = F.pad(context_mask, (1, 0), value=True)
74
+
75
+ mask_value = -torch.finfo(dots.dtype).max
76
+ dots = dots.masked_fill(~context_mask, mask_value)
77
+
78
+ if self.causal:
79
+ i, j = dots.shape[-2:]
80
+ context_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
81
+ dots.masked_fill_(context_mask, mask_value)
82
+
83
+ if exists(self.static_mask):
84
+ dots.masked_fill_(~self.static_mask[:n, :n], mask_value)
85
+
86
+ # attn: [batch_size ,heads ,seq_len_i ,seq_len_j]
87
+ attn = softmax(dots, dim=-1)
88
+ attn = self.save_attn(attn)
89
+
90
+ # out: [batch_size ,heads ,seq_len_i ,dim_head]
91
+ out = torch.einsum("b h n j, b h j d -> b h n d", attn, v)
92
+
93
+ # out: [batch_size ,seq_len_i ,(heads*dim_head)]
94
+ out = rearrange(out, "b h n d -> b n (h d)")
95
+
96
+ # out: [batch_size ,seq_len_i ,dim]
97
+ out = self.to_out(out)
98
+
99
+ return out
100
+
101
+
102
+ # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
103
+
104
+
105
+ class SparseConvCausalAttention(nn.Module):
106
+ def __init__(
107
+ self,
108
+ dim,
109
+ seq_len,
110
+ image_size=32,
111
+ kernel_size=5,
112
+ dilation=1,
113
+ heads=8,
114
+ dim_head=64,
115
+ dropout=0.0,
116
+ stable=False,
117
+ **kwargs,
118
+ ):
119
+ super().__init__()
120
+ assert kernel_size % 2 == 1, "kernel size must be odd"
121
+
122
+ inner_dim = dim_head * heads
123
+ self.seq_len = seq_len
124
+ self.heads = heads
125
+ self.scale = dim_head**-0.5
126
+ self.image_size = image_size
127
+ self.kernel_size = kernel_size
128
+ self.dilation = dilation
129
+
130
+ self.stable = stable
131
+
132
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
133
+
134
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
135
+
136
+ def forward(self, x, mask=None, rotary_pos_emb=None):
137
+ b, n, _, h, img_size, kernel_size, dilation, seq_len, device = (
138
+ *x.shape,
139
+ self.heads,
140
+ self.image_size,
141
+ self.kernel_size,
142
+ self.dilation,
143
+ self.seq_len,
144
+ x.device,
145
+ )
146
+ softmax = torch.softmax if not self.stable else stable_softmax
147
+
148
+ img_seq_len = img_size**2
149
+ text_len = seq_len + 1 - img_seq_len
150
+
151
+ # padding
152
+
153
+ padding = seq_len - n + 1
154
+ mask = default(mask, lambda: torch.ones(b, text_len, device=device).bool())
155
+
156
+ x = F.pad(x, (0, 0, 0, padding), value=0)
157
+ mask = mask[:, :text_len]
158
+
159
+ # derive query / keys / values
160
+
161
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
162
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), qkv)
163
+
164
+ if exists(rotary_pos_emb):
165
+ q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
166
+
167
+ q *= self.scale
168
+
169
+ ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
170
+ lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)
171
+ )
172
+
173
+ # text attention
174
+
175
+ dots_text = einsum("b i d, b j d -> b i j", q_text, k_text)
176
+ mask_value = max_neg_value(dots_text)
177
+
178
+ i, j = dots_text.shape[-2:]
179
+ text_causal_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
180
+ dots_text.masked_fill_(text_causal_mask, mask_value)
181
+
182
+ attn_text = softmax(dots_text, dim=-1)
183
+ out_text = einsum("b i j, b j d -> b i d", attn_text, v_text)
184
+
185
+ # image attention
186
+
187
+ effective_kernel_size = (kernel_size - 1) * dilation + 1
188
+ padding = effective_kernel_size // 2
189
+
190
+ k_img, v_img = map(
191
+ lambda t: rearrange(t, "b (h w) c -> b c h w", h=img_size), (k_img, v_img)
192
+ )
193
+ k_img, v_img = map(
194
+ lambda t: F.unfold(t, kernel_size, padding=padding, dilation=dilation),
195
+ (k_img, v_img),
196
+ )
197
+ k_img, v_img = map(
198
+ lambda t: rearrange(t, "b (d j) i -> b i j d", j=kernel_size**2),
199
+ (k_img, v_img),
200
+ )
201
+
202
+ # let image attend to all of text
203
+
204
+ dots_image = einsum("b i d, b i j d -> b i j", q_img, k_img)
205
+ dots_image_to_text = einsum("b i d, b j d -> b i j", q_img, k_text)
206
+
207
+ # calculate causal attention for local convolution
208
+
209
+ i, j = dots_image.shape[-2:]
210
+ img_seq = torch.arange(img_seq_len, device=device)
211
+ k_img_indices = rearrange(img_seq.float(), "(h w) -> () () h w", h=img_size)
212
+ k_img_indices = F.pad(
213
+ k_img_indices, (padding,) * 4, value=img_seq_len
214
+ ) # padding set to be max, so it is never attended to
215
+ k_img_indices = F.unfold(k_img_indices, kernel_size, dilation=dilation)
216
+ k_img_indices = rearrange(k_img_indices, "b j i -> b i j")
217
+
218
+ # mask image attention
219
+
220
+ q_img_indices = rearrange(img_seq, "i -> () i ()")
221
+ causal_mask = q_img_indices < k_img_indices
222
+
223
+ # concat text mask with image causal mask
224
+
225
+ causal_mask = repeat(causal_mask, "() i j -> b i j", b=b * h)
226
+ mask = repeat(mask, "b j -> (b h) i j", i=i, h=h)
227
+ mask = torch.cat((~mask, causal_mask), dim=-1)
228
+
229
+ # image can attend to all of text
230
+
231
+ dots = torch.cat((dots_image_to_text, dots_image), dim=-1)
232
+ dots.masked_fill_(mask, mask_value)
233
+
234
+ attn = softmax(dots, dim=-1)
235
+
236
+ # aggregate
237
+
238
+ attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]
239
+
240
+ out_image_to_image = einsum("b i j, b i j d -> b i d", attn_image, v_img)
241
+ out_image_to_text = einsum("b i j, b j d -> b i d", attn_image_to_text, v_text)
242
+
243
+ out_image = out_image_to_image + out_image_to_text
244
+
245
+ # combine attended values for both text and image
246
+
247
+ out = torch.cat((out_text, out_image), dim=1)
248
+
249
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
250
+
251
+ out = self.to_out(out)
252
+
253
+ return out[:, :n]
celle/celle.py ADDED
@@ -0,0 +1,1061 @@
1
+ # Import necessary packages and modules
2
+ from math import floor, ceil
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from axial_positional_embedding import AxialPositionalEmbedding
7
+ from einops import rearrange
8
+ from celle.utils import (
9
+ exists,
10
+ always,
11
+ eval_decorator,
12
+ gumbel_sample,
13
+ top_k,
14
+ gamma_func,
15
+ DivideMax,
16
+ )
17
+ from tqdm import tqdm
18
+
19
+ # Import additional modules from within the codebase
20
+ from celle.transformer import Transformer
21
+
22
+
23
+ def generate_mask(gamma_func, batch_size, length, device):
24
+ # Get the number of `True` values in the mask for each batch element
25
+ num_true_values = floor(gamma_func(torch.rand(1)) * length)
26
+
27
+ # Generate a random sample of indices to set to `True` in the mask
28
+ # The number of indices in the sample is determined by `num_true_values`
29
+ indices = (
30
+ torch.rand((batch_size, length), device=device)
31
+ .topk(num_true_values, dim=1)
32
+ .indices
33
+ )
34
+
35
+ # Create a binary mask tensor with `True` values at the sampled indices
36
+ mask = torch.zeros((batch_size, length), dtype=torch.bool, device=device)
37
+ mask.scatter_(dim=1, index=indices, value=True)
38
+
39
+ return mask
40
+
41
+
42
+ def match_batch_size(text, condition, image, batch_size):
43
+ """
44
+ This function ensures all inputs to the sample function have the same batch size.
45
+ """
46
+ if text.shape[0] != batch_size:
47
+ text = text.repeat(batch_size, 1)
48
+
49
+ if condition.shape[0] != batch_size:
50
+ condition = condition.repeat(batch_size, 1)
51
+
52
+ if image.shape[0] != batch_size:
53
+ image = image.repeat(batch_size, 1)
54
+
55
+ return text, condition, image
56
+
57
+
58
+ def calc_unmask_probs(timestep, timesteps, gamma_func):
59
+ if timestep == 1 or timesteps == 1:
60
+ unmask_prob = 1
61
+ else:
62
+ unmask_prob = 1 - gamma_func(timestep)
63
+ return unmask_prob
64
+
65
+
66
+ def calculate_logits(
67
+ input_tokens, input_mask, logits_function, filter_thres, temperature
68
+ ):
69
+ logits, _, _ = logits_function(input_tokens, input_mask, return_encoding=False)
70
+ filtered_logits = top_k(logits, thres=filter_thres)
71
+ sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)
72
+
73
+ return logits, sample
74
+
75
+
76
+ def unmask_tokens(
77
+ input_tokens,
78
+ input_mask,
79
+ num_masked_tokens,
80
+ logits,
81
+ sample,
82
+ timestep,
83
+ timesteps,
84
+ gamma,
85
+ filter_func=None,
86
+ pad_token=None,
87
+ mask_token=None,
88
+ force_aas=True,
89
+ ):
90
+ sample = sample.masked_fill(~input_mask.unsqueeze(-1), -torch.inf)
91
+ if filter_func:
92
+ sample = filter_func(
93
+ input_tokens, sample, force_aas, pad_token=pad_token, mask_token=mask_token
94
+ )
95
+ selected_token_probs, selected_tokens = torch.max(sample, dim=-1)
96
+
97
+ unmask_prob = calc_unmask_probs(timestep, timesteps, gamma)
98
+ num_tokens_to_unmask = max(1, ceil(unmask_prob * num_masked_tokens))
99
+
100
+ _, top_k_indices = torch.topk(selected_token_probs, num_tokens_to_unmask, dim=-1)
101
+
102
+ sample_mask = torch.zeros(
103
+ input_tokens.shape, dtype=torch.bool, device=input_tokens.device
104
+ )
105
+ sample_mask.scatter_(dim=1, index=top_k_indices, value=True)
106
+
107
+ unmasked_tokens = torch.where(sample_mask, selected_tokens, input_tokens)
108
+ full_logits = torch.where(
109
+ sample_mask.unsqueeze(-1), logits, torch.zeros_like(logits)
110
+ )
111
+ return unmasked_tokens, full_logits
112
+
113
+
114
+ def suppress_invalid_text_tokens(
115
+ text,
116
+ logits,
117
+ start_token=None,
118
+ end_token=None,
119
+ pad_token=None,
120
+ mask_token=None,
121
+ force_aas=False,
122
+ ):
123
+ # Find the indices of start_token and end_token in tensor text along axis=1
124
+ idx_start = (text == start_token).nonzero(as_tuple=True)[1]
125
+ idx_end = (text == end_token).nonzero(as_tuple=True)[1]
126
+
127
+ # For every position other than the index corresponding to the start index, set the values on the start index of dimension=2 to -torch.inf
128
+ if idx_start.nelement() != start_token:
129
+ try:
130
+ mask = idx_start.unsqueeze(1) != torch.arange(
131
+ logits.size(1), device=text.device
132
+ )
133
+ indices = torch.where(mask)
134
+ logits[indices[0], indices[1], start_token] = -torch.inf
135
+ except:
136
+ pass
137
+
138
+ # else:
139
+ # idx_start = torch.zeros(text.size(0), dtype=torch.long)
140
+
141
+ # Similarly, for every position other than the index corresponding to the end index, set the values on the end index of dimension=2 to -torch.inf
142
+ if idx_end.nelement() != 0:
143
+ try:
144
+ mask = idx_end.unsqueeze(1) != torch.arange(
145
+ logits.size(1), device=text.device
146
+ )
147
+ indices = torch.where(mask)
148
+ logits[indices[0], indices[1], end_token] = -torch.inf
149
+ except:
150
+ pass
151
+
152
+ # else:
153
+ # idx_end = torch.full((text.size(0),), text.size(1) - 1, dtype=torch.long)
154
+
155
+ if pad_token:
156
+ if idx_start.nelement() != 0 and idx_end.nelement() != 0:
157
+ try:
158
+ # For every position between the indices of start_token and end_token, set the values for 1st index of dimension=2 equal to -torch.inf. Any value outside of that range should be set to torch.inf.
159
+ mask = (
160
+ torch.arange(logits.size(1), device=text.device)
161
+ >= idx_start.unsqueeze(1)
162
+ ) & (
163
+ torch.arange(logits.size(1), device=text.device)
164
+ <= idx_end.unsqueeze(1)
165
+ )
166
+
167
+ indices = torch.where(mask)
168
+ logits[indices[0], indices[1], pad_token] = -torch.inf
169
+
170
+ indices = torch.where(~mask)
171
+ logits[indices[0], indices[1], pad_token] = torch.inf
172
+
173
+ except:
174
+ pass
175
+
176
+ elif idx_start.nelement() != 0:
177
+ try:
178
+ mask = torch.arange(
179
+ logits.size(1), device=text.device
180
+ ) < idx_start.unsqueeze(1)
181
+ logits[indices[0], indices[1], pad_token] = torch.inf
182
+ except:
183
+ pass
184
+
185
+ elif idx_end.nelement() != 0:
186
+ try:
187
+ mask = torch.arange(
188
+ logits.size(1), device=text.device
189
+ ) > idx_end.unsqueeze(1)
190
+ logits[indices[0], indices[1], pad_token] = torch.inf
191
+ except:
192
+ pass
193
+
194
+ if force_aas:
195
+ if pad_token:
196
+ logits[:, :, pad_token] = -torch.inf
197
+ logits[:, :, 3] = -torch.inf
198
+ logits[:, :, 29:] = -torch.inf
199
+
200
+ if mask_token:
201
+ logits[:, :, mask_token] = -torch.inf
202
+
203
+ return logits
204
+
205
+
206
+ def detokenize_text(text_embedding, sequence):
207
+ if text_embedding == "esm1b" or text_embedding == "esm2":
208
+ from esm import Alphabet
209
+
210
+ alphabet = (
211
+ Alphabet.from_architecture("ESM-1b").get_batch_converter().alphabet.all_toks
212
+ )
213
+ else:
214
+ raise NameError("Detokenization is only available for ESM models")
215
+
216
+ output_seqs = []
217
+
218
+ for batch in sequence:
219
+ converted_seq = [alphabet[idx] for idx in batch]
220
+ converted_seq = "".join(converted_seq)
221
+ output_seqs.append(converted_seq)
222
+
223
+ return output_seqs
224
+
225
+ class ImageEmbedding(nn.Module):
226
+ def __init__(self, num_tokens, dim):
227
+ super(ImageEmbedding, self).__init__()
228
+ self.image_embedding = nn.Embedding(num_tokens, dim)
229
+
230
+ def forward(self, image):
231
+ return self.image_embedding(image)
232
+
233
+
234
+ class ModelExtender(nn.Module):
235
+ def __init__(self, vocab, out_features, fixed_embedding=False):
236
+ super(ModelExtender, self).__init__()
237
+
238
+ # Initialize the model according to the given vocabulary
239
+ self.vocab = vocab
240
+
241
+ if vocab == "esm1b":
242
+ from esm import pretrained
243
+
244
+ self.model, _ = pretrained.esm1b_t33_650M_UR50S()
245
+ self.in_features = 1280
246
+ elif vocab == "esm2":
247
+ from esm import pretrained
248
+
249
+ if out_features == 320:
250
+ self.model, _ = pretrained.esm2_t6_8M_UR50D()
251
+ elif out_features == 480:
252
+ self.model, _ = pretrained.esm2_t12_35M_UR50D()
253
+ elif out_features == 640:
254
+ self.model, _ = pretrained.esm2_t30_150M_UR50D()
255
+ elif out_features == 1280:
256
+ self.model, _ = pretrained.esm2_t33_650M_UR50D()
257
+ elif out_features == 2560:
258
+ self.model, _ = pretrained.esm2_t36_3B_UR50D()
259
+ else:
260
+ self.model, _ = pretrained.esm2_t33_650M_UR50D()
261
+ self.in_features = self.model.embed_dim
262
+
263
+ # Set the number of output features and initialize the scaling layer
264
+ self.out_features = out_features
265
+ self.scale_layer = nn.Linear(self.in_features, self.out_features)
266
+
267
+ # Determine whether to freeze the model's parameters
268
+ self.fixed_embedding = fixed_embedding
269
+ if self.fixed_embedding:
270
+ self.model = self.model.eval()
271
+
272
+ def forward(self, x, **kwargs):
273
+ # If the model's parameters are fixed, use torch.no_grad()
274
+ if self.fixed_embedding:
275
+ with torch.no_grad():
276
+ if self.vocab == "esm1b" or self.vocab == "esm2":
277
+ # Reduce sequence length dimension, get top layer representation tensor
278
+ x = self.model(x.squeeze(1), repr_layers=[self.model.num_layers])[
279
+ "representations"
280
+ ][self.model.num_layers]
281
+ # Tensor shape: (batch_size, hidden_size)
282
+ else:
283
+ # Get top layer representation tensor
284
+ x = self.model(x, **kwargs)[0]
285
+ # Tensor shape: (batch_size, sequence_length, hidden_size)
286
+ else:
287
+ if self.vocab == "esm1b" or self.vocab == "esm2":
288
+ # Reduce sequence length dimension, get top layer representation tensor
289
+ x = self.model(x.squeeze(1), repr_layers=[self.model.num_layers])[
290
+ "representations"
291
+ ][self.model.num_layers]
292
+ # Tensor shape: (batch_size, hidden_size)
293
+ else:
294
+ # Get top layer representation tensor
295
+ x = self.model(x, **kwargs)[0]
296
+ # Tensor shape: (batch_size, sequence_length, hidden_size)
297
+
298
+ # Scale the representation tensor if necessary
299
+ if self.out_features != self.in_features:
300
+ x = self.scale_layer(x)
301
+ # Tensor shape: (batch_size, out_features)
302
+
303
+ return x
304
+
305
+ class CELLE(nn.Module):
306
+ def __init__(
307
+ self,
308
+ *,
309
+ dim,
310
+ vae, # The VAE model used to encode/decode images
311
+ condition_vae=None, # An optional VAE model used to condition the image generation
312
+ num_images=2, # Number of images to generate
313
+ num_text_tokens=30, # Number of tokens in the text vocabulary
314
+ text_seq_len=1000, # Maximum length of input text sequence
315
+ depth=16, # Number of layers in the transformer model
316
+ heads=16, # Number of attention heads
317
+ dim_head=64, # Dimensionality of each attention head
318
+ attn_dropout=0.1, # Dropout rate for attention weights
319
+ ff_dropout=0.1, # Dropout rate for feedforward layers
320
+ attn_types=None, # Types of attention to use in the transformer
321
+ causal=False, # Whether to use causal attention
322
+ loss_cond_weight=1, # Weight of conditioning loss
323
+ loss_img_weight=1, # Weight of image generation loss
324
+ stable=False, # Whether to use divide-by-max normalization in the transformer
325
+ rotary_emb=True, # Whether to use rotary positional embeddings
326
+ text_embedding="esm2", # Text embedding to use (esm1b, esm2)
327
+ fixed_embedding=True, # Whether to fix the text embedding or learn it
328
+ sampling_mode="cosine", # Sampling mode for the VAE
329
+ linear_project=False, # Whether to project embeddings linearly
330
+ **kwargs,
331
+ ):
332
+ super().__init__()
333
+
334
+ # Set the stable flag
335
+ self.stable = stable
336
+
337
+ # If the stable flag is set, initialize the DivideMax layer for normalization
338
+ if stable:
339
+ self.norm_by_max = DivideMax(dim=-1)
340
+
341
+ ### Initializing text parameters ###
342
+
343
+ # Initialize the text and fixed embeddings
344
+ self.text_embedding = text_embedding
345
+ self.fixed_embedding = fixed_embedding
346
+
347
+ # Offset logits index and calculate cross entropy loss
348
+ self.num_text_tokens = num_text_tokens
349
+ self.linear_project = linear_project
350
+
351
+ # Add <BOS> and <EOS> tokens to the beginning and end of text sequences
352
+ if text_embedding.lower() in ("esm1b", "esm2"):
353
+ self.text_seq_len = text_seq_len + 2
354
+ else:
355
+ self.text_seq_len = text_seq_len
356
+
357
+ # Initialize embeddings for <SEP> token
358
+ self.sep_emb = nn.Embedding(1, dim)
359
+
360
+ # Initialize positional embeddings for text sequences and <SEP> token
361
+ self.text_pos_emb = (
362
+ nn.Embedding(self.text_seq_len + 1, dim) if not rotary_emb else always(0)
363
+ ) # +1 for <SEP>
364
+
365
+ ### ###
366
+
367
+ self.num_images = num_images
368
+
369
+ ### Initializing condition parameters ###
370
+
371
+ # Initialize the number of condition tokens, condition sequence length, and condition embedding
372
+ if exists(condition_vae):
373
+ condition_size = condition_vae.image_size
374
+ num_condition_tokens = condition_vae.num_tokens
375
+ self.num_condition_tokens = num_condition_tokens
376
+ condition_fmap_size = condition_vae.image_size // (
377
+ 2**condition_vae.num_layers
378
+ )
379
+ condition_seq_len = condition_fmap_size**2
380
+
381
+ # Initialize ImageEmbedding for condition embedding
382
+ self.condition_emb = ImageEmbedding(num_condition_tokens + 1, dim)
383
+
384
+ # Initialize positional embeddings for condition embedding
385
+ self.condition_pos_emb = (
386
+ AxialPositionalEmbedding(
387
+ dim, axial_shape=(condition_fmap_size, condition_fmap_size)
388
+ )
389
+ if not rotary_emb
390
+ else always(0)
391
+ )
392
+
393
+ else:
394
+ condition_fmap_size = 0
395
+ condition_seq_len = 0
396
+ num_condition_tokens = 0
397
+
398
+ ### ####
399
+
400
+ ### Initializing image parameters ###
401
+
402
+ # Initialize the image size, image token size, and sequence length
403
+ self.image_size = vae.image_size
404
+ num_image_tokens = vae.num_tokens
405
+ image_fmap_size = vae.image_size // (2**vae.num_layers)
406
+ image_seq_len = image_fmap_size**2
407
+ self.image_seq_len = image_seq_len
408
+ self.num_image_tokens = num_image_tokens
409
+
410
+ # Initialize ImageEmbedding and positional embeddings for image embedding
411
+ self.image_emb = ImageEmbedding(num_image_tokens + 1, dim) # +1 for <IM_MASK>
412
+
413
+ self.image_pos_emb = (
414
+ AxialPositionalEmbedding(
415
+ dim, axial_shape=(image_fmap_size, image_fmap_size)
416
+ )
417
+ if not rotary_emb
418
+ else always(0)
419
+ )
420
+
421
+ # Set total sequence length and total tokens
422
+ self.num_condition_tokens = num_condition_tokens
423
+ self.condition_seq_len = condition_seq_len
424
+ # Text Length + <SEP> + Condition Tokens + Image Tokens
425
+ seq_len = self.text_seq_len + 1 + self.condition_seq_len + self.image_seq_len
426
+ total_tokens = (
427
+ num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens + 1
428
+ )
429
+ self.total_tokens = total_tokens
430
+ self.total_seq_len = seq_len
431
+
432
+ # Set the VAE and condition VAE for the model
433
+ self.vae = vae.eval()
434
+ self.condition_vae = condition_vae.eval()
435
+
436
+ ### ###
437
+
438
+ ### Setting discrete ids ###
439
+ # Initialize text embedding based on the given text_embedding parameter
440
+ if text_embedding == "esm1b" or text_embedding == "esm2":
441
+ self.text_mask_token = 32
442
+ self.pad_token = 1
443
+ self.text_emb = ModelExtender(text_embedding, dim, fixed_embedding)
444
+ else:
445
+ raise ValueError("Only ESM models are supported.")
446
+
447
+ # Set token indices for text, condition, and image sequences
448
+ self.sep_token = num_text_tokens
449
+ self.cond_mask_token = num_condition_tokens
450
+ self.image_mask_token = num_image_tokens
451
+
452
+ # Create indices for sequence and logits dimensions
453
+ self.seq_range = torch.arange(seq_len)
454
+ self.logits_range = torch.arange(total_tokens)
455
+
456
+ # Reshape sequence and logits indices
457
+ self.seq_range = rearrange(self.seq_range, "n -> () n ()")
458
+ self.logits_range = rearrange(self.logits_range, "d -> () () d")
459
+
460
+ # Create a mask to exclude invalid token positions from the model output
461
+ # e.g. no image tokens where sequence tokens should be
462
+ logits_mask = (
463
+ # Mask text tokens beyond text_seq_len and invalid logits_range
464
+ (
465
+ (self.seq_range < self.text_seq_len)
466
+ & (self.logits_range < num_text_tokens)
467
+ & (self.logits_range != self.text_mask_token)
468
+ )
469
+ |
470
+ # Mask [SEP] token after text
471
+ (
472
+ (self.seq_range == self.text_seq_len)
473
+ & (self.logits_range == num_text_tokens)
474
+ )
475
+ |
476
+ # Mask condition tokens beyond text_seq_len+1 ([SEP]) and invalid logits_range
477
+ (
478
+ (self.seq_range >= self.text_seq_len + 1)
479
+ & (self.seq_range < self.text_seq_len + 1 + condition_seq_len)
480
+ & (self.logits_range >= num_text_tokens + 1)
481
+ & (self.logits_range < num_text_tokens + 1 + num_condition_tokens)
482
+ )
483
+ |
484
+ # Mask image tokens beyond num_text_tokens+num_condition_tokens+1
485
+ (
486
+ (self.seq_range >= self.text_seq_len + 1 + condition_seq_len)
487
+ & (self.logits_range >= num_text_tokens + 1 + num_condition_tokens + 1)
488
+ & (
489
+ self.logits_range
490
+ < num_text_tokens + 1 + num_condition_tokens + 1 + num_image_tokens
491
+ )
492
+ )
493
+ )
494
+
495
+ # Invert the mask
496
+ logits_mask = ~logits_mask
497
+
498
+ # Register the buffer with the logits_mask
499
+ self.register_buffer("logits_mask", logits_mask, persistent=False)
500
+
501
+ ### ###
502
+
503
+ # Initialize the Transformer model with given parameters
504
+ self.transformer = Transformer(
505
+ dim=dim,
506
+ causal=causal,
507
+ seq_len=seq_len,
508
+ depth=depth,
509
+ heads=heads,
510
+ dim_head=dim_head,
511
+ attn_dropout=attn_dropout,
512
+ ff_dropout=ff_dropout,
513
+ image_fmap_size=image_fmap_size + condition_fmap_size,
514
+ num_images=num_images,
515
+ stable=stable,
516
+ rotary_emb=rotary_emb,
517
+ )
518
+
519
+ # Initialize the linear layers for converting transformer output to logits
520
+ self.to_logits = nn.Sequential(
521
+ nn.LayerNorm(dim),
522
+ nn.Linear(dim, self.total_tokens),
523
+ )
524
+
525
+ # Set instance variables for weights and critic
526
+ self.loss_img_weight = loss_img_weight
527
+ self.loss_cond_weight = loss_cond_weight
528
+ self.gamma = gamma_func(sampling_mode)
529
+
530
+ def embed_and_transform(self, inputs, masks, return_encoding=False):
531
+ text, condition, image = inputs
532
+ device = text.device
533
+ text_mask, _, image_mask = masks
534
+
535
+ text_labels = text.clone()
536
+ text = torch.where(
537
+ text_mask, self.text_mask_token * torch.ones_like(text, device=device), text
538
+ )
539
+
540
+ tokens = self.text_emb(text)
541
+
542
+ # Add SEP token
543
+
544
+ sep_token_emb = self.sep_emb(
545
+ torch.zeros((tokens.shape[0], 1), dtype=torch.long, device=device)
546
+ )
547
+ tokens = torch.cat((tokens, sep_token_emb), dim=1)
548
+ tokens += self.text_pos_emb(torch.arange(text.shape[1] + 1, device=device))
549
+
550
+ with torch.no_grad():
551
+ if self.linear_project:
552
+ b = condition.shape[0]
553
+ condition, _, [_, _, condition_labels] = self.condition_vae.encode(
554
+ condition
555
+ )
556
+ condition_labels = rearrange(condition_labels, "(b n) -> b n", b=b)
557
+
558
+ else:
559
+ condition_labels = condition
560
+ if condition.dtype == torch.float:
561
+ condition_labels = self.condition_vae.get_codebook_indices(
562
+ condition
563
+ )
564
+ condition = condition_labels.clone()
565
+
566
+ condition_emb = self.condition_emb(condition)
567
+ condition_emb += self.condition_pos_emb(condition_emb)
568
+ tokens = torch.cat((tokens, condition_emb), dim=1)
569
+
570
+ with torch.no_grad():
571
+ if self.linear_project:
572
+ b = image.shape[0]
573
+ image, _, [_, _, image_labels] = self.vae.encode(image)
574
+ image_labels = rearrange(image_labels, "(b n) -> b n", b=b)
575
+
576
+ else:
577
+ image_labels = image
578
+ if image.dtype == torch.float:
579
+ image_labels = self.vae.get_codebook_indices(image)
580
+ image = torch.where(
581
+ image_mask,
582
+ self.image_mask_token
583
+ * torch.ones_like(image_labels, device=device),
584
+ image_labels,
585
+ )
586
+
587
+ image_emb = self.image_emb(image)
588
+
589
+ image_emb += self.image_pos_emb(image_emb)
590
+ tokens = torch.cat((tokens, image_emb), dim=1)
591
+
592
+ if self.stable:
593
+ alpha = 0.1
594
+ tokens = tokens * alpha + tokens.detach() * (1 - alpha)
595
+
596
+ out = self.transformer(tokens)
597
+
598
+ if self.stable:
599
+ out = self.norm_by_max(out)
600
+
601
+ logits = self.to_logits(out)
602
+
603
+ max_neg_value = -torch.finfo(logits.dtype).max
604
+ logits.masked_fill_(self.logits_mask, max_neg_value)
605
+
606
+ if return_encoding:
607
+ return logits, out, [text_labels, condition_labels, image_labels]
608
+ else:
609
+ return logits, None, [text_labels, condition_labels, image_labels]
610
+
611
+ def forward(
612
+ self,
613
+ text,
614
+ condition=None,
615
+ image=None,
616
+ return_loss=False,
617
+ return_encoding=False,
618
+ ):
619
+ batch_size, device = text.shape[0], text.device
620
+
621
+ # Check that image is supplied when training
622
+ assert exists(image), "when training, image must be supplied"
623
+
624
+ # Check that image dimensions match the expected dimensions
625
+ assert tuple(image.shape[1:]) == (
626
+ self.vae.channels,
627
+ self.image_size,
628
+ self.image_size,
629
+ ), f"invalid image of dimensions {image.shape} passed in during training"
630
+
631
+ # Generate masks for text, condition, and image
632
+
633
+ # text_mask = generate_mask(self.gamma, batch_size, self.text_seq_len, device)
634
+
635
+ text_mask = generate_mask(
636
+ gamma_func("scaled-cosine"), batch_size, self.text_seq_len, device
637
+ )
638
+
639
+ image_mask = generate_mask(self.gamma, batch_size, self.image_seq_len, device)
640
+
641
+ # Embed and transform inputs
642
+ logits, _, labels = self.embed_and_transform(
643
+ [text, condition, image],
644
+ [text_mask, None, image_mask],
645
+ return_encoding,
646
+ device,
647
+ )
648
+
649
+ # If not returning loss, return the logits
650
+ if not return_loss:
651
+ return logits
652
+
653
+ # Separate labels
654
+ text, condition, image = labels
655
+
656
+ # Add SEP token to end of text label
657
+ sep_token = torch.tensor(self.sep_token, device=device).repeat(
658
+ labels.shape[0], 1
659
+ )
660
+ labels = torch.cat([labels, sep_token], dim=1)
661
+
662
+ # If condition exists and condition vae is defined, add the condition to the labels
663
+ if exists(condition) and exists(self.condition_vae):
664
+ offsetted_condition = condition + self.num_text_tokens + 1
665
+ labels = torch.cat((labels, offsetted_condition), dim=1)
666
+
667
+ # Add image to the labels
668
+ offsetted_image = (
669
+ image + self.num_text_tokens + 1 + self.num_condition_tokens + 1
670
+ )
671
+ labels = torch.cat((labels, offsetted_image), dim=1)
672
+
673
+ # Rearrange logits for cross-entropy loss calculation
674
+ # Logits size: (batch_size, vocab_size, total_seq_len)
675
+ # Labels size: (batch_size, total_seq_len)
676
+ logits = rearrange(logits, "b n c -> b c n")
677
+
678
+ # Calculate cross-entropy loss for text and image
679
+ loss_text = F.cross_entropy(
680
+ logits[:, :, : self.text_seq_len],
681
+ labels[:, : self.text_seq_len],
682
+ reduction="none",
683
+ )[text_mask].mean()
684
+
685
+ loss_img = F.cross_entropy(
686
+ logits[:, :, self.text_seq_len + 1 + self.condition_seq_len :],
687
+ labels[:, self.text_seq_len + 1 + self.condition_seq_len :],
688
+ reduction="none",
689
+ )[image_mask].mean()
690
+
691
+ # Calculate total loss
692
+ loss = (loss_text + self.loss_img_weight * loss_img) / (
693
+ self.loss_img_weight + 1
694
+ )
695
+
696
+ loss_dict = {
697
+ "loss_text": loss_text,
698
+ # "loss_cond": loss_cond,
699
+ "loss_img": loss_img,
700
+ "loss": torch.nan_to_num(loss, 0.0, 0.0, 0.0),
701
+ }
702
+
703
+ return loss, loss_dict, None
704
+
705
+ def create_tensors(self, text, condition, image):
706
+ """
707
+ This function creates tensors for text, condition, and image when they are not provided as inputs to the sample function.
708
+ """
709
+ device = next(
710
+ filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
711
+ None,
712
+ ).device
713
+
714
+ if not isinstance(text, torch.Tensor):
715
+ text = (
716
+ torch.ones(1, self.text_seq_len, device=device, dtype=torch.long)
717
+ * self.text_mask_token
718
+ )
719
+
720
+ if not isinstance(condition, torch.Tensor):
721
+ condition = (
722
+ torch.ones(1, self.condition_seq_len, device=device, dtype=torch.long)
723
+ * self.cond_mask_token
724
+ )
725
+ else:
726
+ with torch.no_grad():
727
+ condition = self.condition_vae.get_codebook_indices(condition)
728
+
729
+ if not isinstance(image, torch.Tensor):
730
+ image = (
731
+ torch.ones(1, self.image_seq_len, device=device, dtype=torch.long)
732
+ * self.image_mask_token
733
+ )
734
+ else:
735
+ with torch.no_grad():
736
+ image = self.vae.get_codebook_indices(image)
737
+
738
+ return text, condition, image
739
+
740
+ @torch.no_grad()
741
+ @eval_decorator
742
+ def sample(
743
+ self,
744
+ text=None,
745
+ condition=None,
746
+ image=None,
747
+ temperature=1.0,
748
+ filter_thres=0.9,
749
+ progress=False,
750
+ timesteps=1,
751
+ force_aas=True,
752
+ ):
753
+ # ensure timesteps is a positive integer
754
+ assert int(timesteps) > 0
755
+ # set model and VAEs to evaluation mode
756
+ self.eval()
757
+ vae = self.vae.eval()
758
+ if progress == True:
759
+ progress = tqdm
760
+ else:
761
+ progress = lambda x: x
762
+
763
+
764
+ # ensure that at least one of text, condition, or image is supplied
765
+ assert (
766
+ isinstance(text, torch.Tensor)
767
+ or isinstance(condition, torch.Tensor)
768
+ or isinstance(image, torch.Tensor)
769
+ ), "some data must be supplied"
770
+
771
+ # convert text, condition, and image to tensors if they aren't already
772
+ text, condition, image = self.create_tensors(text, condition, image)
773
+
774
+ # determine the maximum batch size of the input tensors
775
+ batch_size = max(text.shape[0], condition.shape[0], image.shape[0])
776
+
777
+ # match the batch sizes of text, condition, and image
778
+ text, condition, image = match_batch_size(text, condition, image, batch_size)
779
+
780
+ # determine the device of the tensors
781
+ device = next(
782
+ filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
783
+ None,
784
+ ).device
785
+
786
+ assert text.shape[0] == condition.shape[0] == image.shape[0]
787
+
788
+ # Create a tensor of zeros of size (batch_size, image_seq_len, num_image_tokens + 1) and set it to device
789
+
790
+ # full_text_logits = torch.zeros(batch_size, self.text_seq_len, self.num_text_tokens+3).to(device)
791
+ full_text_logits = torch.zeros(
792
+ batch_size, self.text_seq_len, self.num_text_tokens
793
+ ).to(device)
794
+
795
+ # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
796
+ full_text_logits = full_text_logits.scatter_(
797
+ dim=-1, index=text.unsqueeze(-1), value=1
798
+ )
799
+ # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
800
+ full_image_logits = torch.zeros(
801
+ batch_size, self.image_seq_len, self.num_image_tokens + 1
802
+ ).to(device)
803
+
804
+ # Remove the last token from each image sequence by setting full_image_logits to its first num_image_tokens elements
805
+ full_image_logits = full_image_logits.scatter_(
806
+ dim=-1, index=image.unsqueeze(-1), value=1
807
+ )
808
+
809
+ # cut off mask token
810
+ full_image_logits = full_image_logits[:, :, : self.num_image_tokens]
811
+
812
+ count = 0
813
+
814
+ for timestep in progress(torch.linspace(0, 1, timesteps)):
815
+ # Create masks for the text, condition, and image tensors
816
+ text_mask = text == self.text_mask_token
817
+ cond_mask = condition == self.cond_mask_token
818
+ image_mask = image == self.image_mask_token
819
+
820
+ # Calculate logits and samples using the calculate_logits function
821
+ logits, sample = calculate_logits(
822
+ [text, condition, image],
823
+ [text_mask, cond_mask, image_mask],
824
+ self.embed_and_transform,
825
+ filter_thres,
826
+ temperature,
827
+ )
828
+
829
+ # Calculate the number of masked tokens in the text and image tensors
830
+ num_masked_text_tokens = torch.sum(text_mask, dim=1)[0]
831
+ num_masked_image_tokens = torch.sum(image_mask, dim=1)[0]
832
+
833
+ # If there are masked text tokens, unmask them using unmask_tokens and fill the full text logits tensor with -inf for unmasked tokens
834
+ if num_masked_text_tokens.any() > 0:
835
+ text, full_text_logits = unmask_tokens(
836
+ text,
837
+ text_mask,
838
+ num_masked_text_tokens,
839
+ logits[:, : self.text_seq_len, : self.num_text_tokens],
840
+ sample[:, : self.text_seq_len, : self.num_text_tokens],
841
+ timestep,
842
+ timesteps,
843
+ self.gamma,
844
+ suppress_invalid_text_tokens,
845
+ self.pad_token,
846
+ self.text_mask_token,
847
+ force_aas=force_aas,
848
+ )
849
+ full_text_logits = full_text_logits.masked_fill(
850
+ ~text_mask.unsqueeze(-1), -torch.inf
851
+ )
852
+
853
+ # If there are masked image tokens, unmask them using unmask_tokens and fill the full image logits tensor with -inf for unmasked tokens
854
+ if num_masked_image_tokens > 0:
855
+ image, full_image_logits = unmask_tokens(
856
+ image,
857
+ image_mask,
858
+ num_masked_image_tokens,
859
+ logits[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
860
+ sample[:, -self.image_seq_len :, -(self.num_image_tokens + 1) : -1],
861
+ timestep,
862
+ timesteps,
863
+ self.gamma,
864
+ )
865
+ full_text_logits = full_text_logits.masked_fill(
866
+ ~text_mask.unsqueeze(-1), -torch.inf
867
+ )
868
+
869
+ # Generate heatmap
870
+ with torch.no_grad():
871
+ # Normalize full image logits tensor
872
+ full_image_logits /= torch.max(
873
+ torch.abs(full_image_logits), dim=-1, keepdim=True
874
+ ).values
875
+
876
+ # Apply quantize embedding to full image logits tensor
877
+ full_image_logits = torch.matmul(
878
+ full_image_logits, self.vae.model.quantize.embedding.weight
879
+ )
880
+
881
+ # Rearrange full image logits tensor
882
+ h = int(self.image_seq_len**0.5)
883
+ full_image_logits = rearrange(
884
+ full_image_logits, "b (h w) c -> b c h w", h=h
885
+ )
886
+
887
+ # Decode full image logits tensor
888
+ full_image_logits = self.vae.model.decode(full_image_logits)
889
+
890
+ # Add clipping to full image logits tensor
891
+ max_val = torch.max(full_image_logits.view(batch_size, -1), dim=-1)[0]
892
+ min_val = torch.min(full_image_logits.view(batch_size, -1), dim=-1)[0]
893
+ full_image_logits += torch.clip(1 - max_val, 0, float("inf")).view(
894
+ batch_size, 1, 1, 1
895
+ )
896
+ full_image_logits += torch.clip(0 - min_val, float("-inf"), 0).view(
897
+ batch_size, 1, 1, 1
898
+ )
899
+
900
+ # Clip full image logits tensor values to the range [0, 1]
901
+ full_image_logits = torch.clip(full_image_logits, 0, 1)
902
+
903
+ # Return text tensor, detokenized text tensor, full text logits tensor,
904
+ # binary image tensor, and full image logits tensor
905
+ return (
906
+ text,
907
+ detokenize_text(self.text_embedding, text),
908
+ full_text_logits,
909
+ 1.0 * (vae.decode(image) > 0.5),
910
+ full_image_logits,
911
+ )
912
+
913
+ @torch.no_grad()
914
+ @eval_decorator
915
+ def sample_text(
916
+ self,
917
+ text=False,
918
+ condition=False,
919
+ image=False,
920
+ temperature=1.0,
921
+ filter_thres=0.9,
922
+ progress=False,
923
+ n_unmask=1,
924
+ place_amino=True,
925
+ force_aas=False,
926
+ ):
927
+ # set model and VAEs to evaluation mode
928
+ self.eval()
929
+
930
+ # ensure that at least one of text, condition, or image is supplied
931
+ assert (
932
+ isinstance(text, torch.Tensor)
933
+ or isinstance(condition, torch.Tensor)
934
+ or isinstance(image, torch.Tensor)
935
+ ), "some data must be supplied"
936
+
937
+ # convert text, condition, and image to tensors if they aren't already
938
+ text, condition, image = self.create_tensors(text, condition, image)
939
+
940
+ # determine the maximum batch size of the input tensors
941
+ batch_size = max(text.shape[0], condition.shape[0], image.shape[0])
942
+
943
+ # match the batch sizes of text, condition, and image
944
+ text, condition, image = match_batch_size(text, condition, image, batch_size)
945
+
946
+ # determine the device of the tensors
947
+ device = next(
948
+ filter(lambda x: isinstance(x, torch.Tensor), [text, condition, image]),
949
+ None,
950
+ ).device
951
+
952
+ assert text.shape[0] == condition.shape[0] == image.shape[0]
953
+
954
+ # Create a tensor of zeros of size (batch_size, image_seq_len, num_image_tokens + 1) and set it to device
955
+
956
+ # full_text_logits = torch.zeros(batch_size, self.text_seq_len, self.num_text_tokens+3).to(device)
957
+ full_text_logits = torch.zeros(
958
+ batch_size, self.text_seq_len, self.num_text_tokens
959
+ ).to(device)
960
+
961
+ # Use scatter_ to fill the tensor with 1 values at the indices given by the image tensor
962
+ full_text_logits = full_text_logits.scatter_(
963
+ dim=-1, index=text.unsqueeze(-1), value=1
964
+ )
965
+
966
+ text_mask = text == self.text_mask_token
967
+ cond_mask = condition == self.cond_mask_token
968
+ image_mask = image == self.image_mask_token
969
+
970
+ mask_indices = text_mask.nonzero()
971
+ non_mask_indices = (~text_mask).nonzero()
972
+
973
+ # figure out the center of the amino acids to determine generation direction
974
+ central_protein_index = torch.tensor(
975
+ [
976
+ torch.median(
977
+ non_mask_indices[torch.where(non_mask_indices[:, 0] == idx)][:, -1]
978
+ )
979
+ for idx in range(batch_size)
980
+ ]
981
+ )
982
+
983
+ count = 1
984
+
985
+ run_mask = text_mask
986
+ if progress:
987
+ pbar = progress(total=torch.sum(run_mask).item())
988
+ while torch.sum(run_mask) > 0:
989
+ logits, sample = calculate_logits(
990
+ [text, condition, image],
991
+ [text_mask, cond_mask, image_mask],
992
+ self.embed_and_transform,
993
+ filter_thres,
994
+ temperature,
995
+ )
996
+
997
+ # sub_sample: [batch_size ,text_seq_len ,num_text_tokens]
998
+ sub_sample = sample[:, : self.text_seq_len, : self.num_text_tokens]
999
+ sub_sample = sub_sample.masked_fill(~text_mask.unsqueeze(-1), -torch.inf)
1000
+ sub_sample = suppress_invalid_text_tokens(
1001
+ text, sub_sample, 0, 2, self.pad_token, self.text_mask_token, force_aas
1002
+ )
1003
+ # calculate % to unmasked
1004
+ # get most likely token and probability for each position
1005
+
1006
+ for idx in range(batch_size):
1007
+ selected_mask_indices = mask_indices[
1008
+ torch.where(mask_indices[:, 0] == idx)
1009
+ ][:, -1]
1010
+
1011
+ # Generate to the left
1012
+ if selected_mask_indices[-count] < central_protein_index[idx]:
1013
+ unmask_index = selected_mask_indices[-count]
1014
+ left_sample = max(0, (unmask_index + 1) - n_unmask)
1015
+ right_sample = min(unmask_index + 1, self.text_seq_len - 1)
1016
+ central_protein_index[idx] = max(
1017
+ 0, central_protein_index[idx] - 0.5 * n_unmask
1018
+ )
1019
+
1020
+ # Generate to the right
1021
+ elif selected_mask_indices[count - 1] > central_protein_index[idx]:
1022
+ unmask_index = selected_mask_indices[count - 1]
1023
+ left_sample = max(0, unmask_index)
1024
+ right_sample = min(unmask_index + n_unmask, self.text_seq_len - 1)
1025
+ central_protein_index[idx] = min(
1026
+ central_protein_index[idx] + 0.5 * n_unmask,
1027
+ self.text_seq_len - 1,
1028
+ )
1029
+
1030
+ # save logits for relevant position
1031
+ full_text_logits[
1032
+ idx, left_sample:right_sample, : self.text_seq_len - 1
1033
+ ] = logits[idx, left_sample:right_sample, : self.num_text_tokens]
1034
+
1035
+ run_mask[idx, left_sample:right_sample] = False
1036
+
1037
+ # you may want to resample the amion acids or calculate marginal probs
1038
+ # if so, set place_amino to false
1039
+ if place_amino:
1040
+ text[idx, left_sample:right_sample] = torch.where(
1041
+ text[idx, left_sample:right_sample] == self.text_mask_token,
1042
+ sub_sample[
1043
+ idx, left_sample:right_sample, : self.num_text_tokens
1044
+ ].argmax(dim=-1),
1045
+ text[idx, left_sample:right_sample],
1046
+ )
1047
+
1048
+ text_mask = run_mask
1049
+
1050
+ count += n_unmask
1051
+
1052
+ if progress:
1053
+ pbar.update(n_unmask)
1054
+ if progress:
1055
+ pbar.close()
1056
+
1057
+ return (
1058
+ text,
1059
+ detokenize_text(self.text_embedding, text),
1060
+ full_text_logits,
1061
+ )
celle/reversible.py ADDED
@@ -0,0 +1,36 @@
1
+ import torch.nn as nn
2
+
3
+ # for routing arguments into the functions of the reversible layer
4
+ def route_args(router, args, depth):
5
+ routed_args = [(dict(), dict()) for _ in range(depth)]
6
+ matched_keys = [key for key in args.keys() if key in router]
7
+
8
+ for key in matched_keys:
9
+ val = args[key]
10
+ for depth, ((f_args, g_args), routes) in enumerate(
11
+ zip(routed_args, router[key])
12
+ ):
13
+ new_f_args, new_g_args = map(
14
+ lambda route: ({key: val} if route else {}), routes
15
+ )
16
+ routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
17
+ return routed_args
18
+
19
+ class SequentialSequence(nn.Module):
20
+ def __init__(self, layers, args_route={}, layer_dropout=0.0):
21
+ super().__init__()
22
+ assert all(
23
+ len(route) == len(layers) for route in args_route.values()
24
+ ), "each argument route map must have the same depth as the number of sequential layers"
25
+ self.layers = layers
26
+ self.args_route = args_route
27
+ self.layer_dropout = layer_dropout
28
+
29
+ def forward(self, x, **kwargs):
30
+ args = route_args(self.args_route, kwargs, len(self.layers))
31
+ layers_and_args = list(zip(self.layers, args))
32
+
33
+ for (f, g), (f_args, g_args) in layers_and_args:
34
+ x = x + f(x, **f_args)
35
+ x = x + g(x, **g_args)
36
+ return x
celle/transformer.py ADDED
@@ -0,0 +1,213 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+ from celle.reversible import SequentialSequence
9
+ from celle.attention import Attention
10
+
11
+ from rotary_embedding_torch import RotaryEmbedding, broadcat
12
+ from celle.utils import exists, default, cast_tuple
13
+
14
+ # https://arxiv.org/abs/2103.17239
15
+ class LayerScale(nn.Module):
16
+ def __init__(self, dim, depth, fn):
17
+ super().__init__()
18
+ if depth <= 18:
19
+ init_eps = 0.1
20
+ elif depth > 18 and depth <= 24:
21
+ init_eps = 1e-5
22
+ else:
23
+ init_eps = 1e-6
24
+
25
+ scale = torch.zeros(1, 1, dim).fill_(init_eps)
26
+ self.scale = nn.Parameter(scale)
27
+ self.fn = fn
28
+
29
+ def forward(self, x, **kwargs):
30
+ return self.fn(x, **kwargs) * self.scale
31
+
32
+
33
+ # layer norm
34
+ class PreNorm(nn.Module):
35
+ def __init__(self, dim, fn):
36
+ super().__init__()
37
+ self.norm = nn.LayerNorm(dim)
38
+ self.norm_out = nn.Identity()
39
+ self.fn = fn
40
+
41
+ def forward(self, x, **kwargs):
42
+ x = self.norm(x)
43
+ x = self.fn(x, **kwargs)
44
+ return self.norm_out(x)
45
+
46
+
47
+ # feed forward
48
+
49
+
50
+ class GEGLU(nn.Module):
51
+ def forward(self, x):
52
+ x, gates = x.chunk(2, dim=-1)
53
+ return x * F.gelu(gates)
54
+
55
+
56
+ class FeedForward(nn.Module):
57
+ def __init__(self, dim, dropout=0.0, mult=4.0):
58
+ super().__init__()
59
+ self.net = nn.Sequential(
60
+ nn.Linear(dim, dim * mult * 2),
61
+ GEGLU(),
62
+ nn.Dropout(dropout),
63
+ nn.Linear(dim * mult, dim),
64
+ )
65
+
66
+ def forward(self, x):
67
+ return self.net(x)
68
+
69
+
70
+ # main transformer class
71
+ class Transformer(nn.Module):
72
+ def __init__(
73
+ self,
74
+ *,
75
+ dim,
76
+ depth,
77
+ seq_len,
78
+ causal=True,
79
+ heads=8,
80
+ dim_head=64,
81
+ ff_mult=4,
82
+ attn_dropout=0.0,
83
+ ff_dropout=0.0,
84
+ image_fmap_size=None,
85
+ num_images=None,
86
+ stable=False,
87
+ rotary_emb=True,
88
+ ):
89
+ super().__init__()
90
+ layers = nn.ModuleList([])
91
+
92
+ self.seq_len = seq_len
93
+ self.image_fmap_size = image_fmap_size
94
+
95
+ for ind in range(depth):
96
+
97
+ attn_class = partial(Attention, stable=stable)
98
+
99
+ attn = attn_class(
100
+ dim,
101
+ causal=causal,
102
+ seq_len=seq_len,
103
+ heads=heads,
104
+ dim_head=dim_head,
105
+ dropout=attn_dropout,
106
+ )
107
+
108
+ ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
109
+
110
+ layers.append(
111
+ nn.ModuleList(
112
+ [
113
+ LayerScale(
114
+ dim, ind + 1, PreNorm(dim, attn)
115
+ ),
116
+ LayerScale(
117
+ dim, ind + 1, PreNorm(dim, ff)
118
+ ),
119
+ ]
120
+ )
121
+ )
122
+
123
+ # pairs arguments with attention layer
124
+ route_attn = ((True, False),) * depth
125
+ attn_route_map = {
126
+ "mask": route_attn,
127
+ "rotary_pos_emb": route_attn,
128
+ }
129
+
130
+ self.layers = SequentialSequence(layers, args_route=attn_route_map)
131
+
132
+ # generate positional embeddings for rotary
133
+
134
+ pos_emb = None
135
+ if rotary_emb:
136
+ rot_dim = dim_head // 3
137
+ img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images
138
+
139
+ text_len = seq_len - img_seq_len + 1
140
+
141
+ text_pos_emb = RotaryEmbedding(dim=rot_dim)
142
+
143
+ img_axial_pos_emb = RotaryEmbedding(dim=rot_dim, freqs_for="pixel")
144
+
145
+ text_freqs = text_pos_emb(torch.arange(text_len))
146
+
147
+ img_to_text_freqs = text_pos_emb(
148
+ torch.full((img_seq_len,), 8192)
149
+ ) # image is given a position far away from text
150
+
151
+ text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim=0)
152
+
153
+ img_freqs_axial = img_axial_pos_emb(
154
+ torch.linspace(-1, 1, steps=image_fmap_size)
155
+ )
156
+
157
+ if num_images > 1:
158
+ split_img_freqs_axial = torch.split(
159
+ img_freqs_axial, image_fmap_size // num_images, dim=0
160
+ )
161
+
162
+ split_img_freqs = [
163
+ broadcat(
164
+ (
165
+ rearrange(img_freqs_axial_per_image, "i d -> i () d"),
166
+ rearrange(img_freqs_axial_per_image, "j d -> () j d"),
167
+ ),
168
+ dim=-1,
169
+ )
170
+ for img_freqs_axial_per_image in split_img_freqs_axial
171
+ ]
172
+
173
+ split_img_freqs = [
174
+ rearrange(img_freqs_per_image, "h w d -> (h w) d")
175
+ for img_freqs_per_image in split_img_freqs
176
+ ]
177
+
178
+ # concatenate the per-image frequency grids
179
+
180
+ img_freqs = torch.cat(split_img_freqs, dim=0)
181
+
182
+ elif num_images == 1:
183
+ img_freqs = broadcat(
184
+ (
185
+ rearrange(img_freqs_axial, "i d -> i () d"),
186
+ rearrange(img_freqs_axial, "j d -> () j d"),
187
+ ),
188
+ dim=-1,
189
+ )
190
+
191
+ img_freqs = rearrange(img_freqs, "h w d -> (h w) d")
192
+
193
+ else:
194
+ assert False, "num_images must be int greater than 0"
195
+ self.img_axial_pos_emb = img_axial_pos_emb
196
+ self.text_pos_emb = text_pos_emb
197
+
198
+ text_axial_freqs = img_axial_pos_emb(
199
+ torch.full((text_len,), -10.0)
200
+ ) # text is given a position of -10, well outside the image axial positions, which lie in the range [-1, 1]
201
+
202
+ text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim=-1)
203
+
204
+ img_freqs = torch.cat((text_axial_freqs, img_freqs), dim=0)
205
+
206
+ pos_emb = torch.cat((text_freqs, img_freqs), dim=-1)
207
+
208
+ pos_emb = rearrange(pos_emb, "n d -> () n d")
209
+
210
+ self.register_buffer("pos_emb", pos_emb)
211
+
212
+ def forward(self, x, **kwargs):
213
+ return self.layers(x, rotary_pos_emb=self.pos_emb, **kwargs)
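The residual sub-layer built in `Transformer.__init__` is `LayerScale(PreNorm(fn))` around either attention or feed-forward. A small sketch of that composition in isolation, assuming the repository root is importable; the sizes are arbitrary.

```python
import torch
from celle.transformer import FeedForward, LayerScale, PreNorm

dim = 64
# one feed-forward sub-layer, composed exactly as in Transformer.__init__
block = LayerScale(dim, depth=1, fn=PreNorm(dim, FeedForward(dim, mult=4, dropout=0.1)))

x = torch.randn(2, 10, dim)
out = x + block(x)  # residual update, as applied inside SequentialSequence
print(out.shape)    # torch.Size([2, 10, 64])
```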
celle/utils.py ADDED
@@ -0,0 +1,228 @@
1
+ import torch
2
+ from torchvision import transforms
3
+ from math import pi
4
+ import torchvision.transforms.functional as TF
5
+
6
+
7
+ # Define helper functions
8
+ def exists(val):
9
+ """Check if a variable exists"""
10
+ return val is not None
11
+
12
+
13
+ def uniq(arr):
14
+ return {el: True for el in arr}.keys()
15
+
16
+
17
+ def default(val, d):
18
+ """If a value exists, return it; otherwise, return a default value"""
19
+ return val if exists(val) else d
20
+
21
+
22
+ def max_neg_value(t):
23
+ return -torch.finfo(t.dtype).max
24
+
25
+
26
+ def cast_tuple(val, depth=1):
27
+ if isinstance(val, list):
28
+ val = tuple(val)
29
+ return val if isinstance(val, tuple) else (val,) * depth
30
+
31
+
32
+ def is_empty(t):
33
+ """Check if a tensor is empty"""
34
+ # Return True if the number of elements in the tensor is zero, else False
35
+ return t.nelement() == 0
36
+
37
+
38
+ def masked_mean(t, mask, dim=1):
39
+ """
40
+ Compute the mean of a tensor, masked by a given mask
41
+
42
+ Args:
43
+ t (torch.Tensor): input tensor of shape (batch_size, seq_len, hidden_dim)
44
+ mask (torch.Tensor): mask tensor of shape (batch_size, seq_len)
45
+ dim (int): dimension along which to compute the mean (default=1)
46
+
47
+ Returns:
48
+ torch.Tensor: masked mean tensor of shape (batch_size, hidden_dim)
49
+ """
50
+ t = t.masked_fill(~mask[:, :, None], 0.0)
51
+ return t.sum(dim=1) / mask.sum(dim=1)[..., None]
52
+
53
+
54
+ def set_requires_grad(model, value):
55
+ """
56
+ Set whether or not the model's parameters require gradients
57
+
58
+ Args:
59
+ model (torch.nn.Module): the PyTorch model to modify
60
+ value (bool): whether or not to require gradients
61
+ """
62
+ for param in model.parameters():
63
+ param.requires_grad = value
64
+
65
+
66
+ def eval_decorator(fn):
67
+ """
68
+ Decorator function to evaluate a given function
69
+
70
+ Args:
71
+ fn (callable): function to evaluate
72
+
73
+ Returns:
74
+ callable: the decorated function
75
+ """
76
+
77
+ def inner(model, *args, **kwargs):
78
+ was_training = model.training
79
+ model.eval()
80
+ out = fn(model, *args, **kwargs)
81
+ model.train(was_training)
82
+ return out
83
+
84
+ return inner
85
+
86
+
87
+ def log(t, eps=1e-20):
88
+ """
89
+ Compute the natural logarithm of a tensor
90
+
91
+ Args:
92
+ t (torch.Tensor): input tensor
93
+ eps (float): small value to add to prevent taking the log of 0 (default=1e-20)
94
+
95
+ Returns:
96
+ torch.Tensor: the natural logarithm of the input tensor
97
+ """
98
+ return torch.log(t + eps)
99
+
100
+
101
+ def gumbel_noise(t):
102
+ """
103
+ Generate Gumbel noise
104
+
105
+ Args:
106
+ t (torch.Tensor): input tensor
107
+
108
+ Returns:
109
+ torch.Tensor: a tensor of Gumbel noise with the same shape as the input tensor
110
+ """
111
+ noise = torch.zeros_like(t).uniform_(0, 1)
112
+ return -log(-log(noise))
113
+
114
+
115
+ def gumbel_sample(t, temperature=0.9, dim=-1):
116
+ """
117
+ Sample from a Gumbel-softmax distribution
118
+
119
+ Args:
120
+ t (torch.Tensor): input tensor of shape (batch_size, num_classes)
121
+ temperature (float): temperature for the Gumbel-softmax distribution (default=0.9)
122
+ dim (int): dimension along which to sample (default=-1)
123
+
124
+ Returns:
125
+ torch.Tensor: a tensor of samples from the Gumbel-softmax distribution with the same shape as the input tensor
126
+ """
127
+ return (t / max(temperature, 1e-10)) + gumbel_noise(t)
128
+
129
+
130
+ def top_k(logits, thres=0.5):
131
+ """
132
+ Return a tensor where all but the top k values are set to negative infinity
133
+
134
+ Args:
135
+ logits (torch.Tensor): input tensor of shape (batch_size, num_classes)
136
+ thres (float): threshold for the top k values (default=0.5)
137
+
138
+ Returns:
139
+ torch.Tensor: a tensor with the same shape as the input tensor, where all but the top k values are set to negative infinity
140
+ """
141
+ num_logits = logits.shape[-1]
142
+ k = max(int((1 - thres) * num_logits), 1)
143
+ val, ind = torch.topk(logits, k)
144
+ probs = torch.full_like(logits, float("-inf"))
145
+ probs.scatter_(-1, ind, val)
146
+ return probs
147
+
148
+
149
+ def gamma_func(mode="cosine", scale=0.15):
150
+ """Return a function that takes a single input r and returns a value based on the selected mode"""
151
+
152
+ # Define a different function based on the selected mode
153
+ if mode == "linear":
154
+ return lambda r: 1 - r
155
+ elif mode == "cosine":
156
+ return lambda r: torch.cos(r * pi / 2)
157
+ elif mode == "square":
158
+ return lambda r: 1 - r**2
159
+ elif mode == "cubic":
160
+ return lambda r: 1 - r**3
161
+ elif mode == "scaled-cosine":
162
+ return lambda r: scale * (torch.cos(r * pi / 2))
163
+ else:
164
+ # Raise an error if the selected mode is not implemented
165
+ raise NotImplementedError
166
+
167
+
168
+ class always:
169
+ """Helper class to always return a given value"""
170
+
171
+ def __init__(self, val):
172
+ self.val = val
173
+
174
+ def __call__(self, x, *args, **kwargs):
175
+ return self.val
176
+
177
+
178
+ class DivideMax(torch.nn.Module):
179
+ def __init__(self, dim):
180
+ super().__init__()
181
+ self.dim = dim
182
+
183
+ def forward(self, x):
184
+ maxes = x.amax(dim=self.dim, keepdim=True).detach()
185
+ return x / maxes
186
+
187
+ def replace_outliers(image, percentile=0.0001):
188
+
189
+ lower_bound, upper_bound = torch.quantile(image, percentile), torch.quantile(
190
+ image, 1 - percentile
191
+ )
192
+ mask = (image <= upper_bound) & (image >= lower_bound)
193
+
194
+ valid_pixels = image[mask]
195
+
196
+ image[~mask] = torch.clip(image[~mask], min(valid_pixels), max(valid_pixels))
197
+
198
+ return image
199
+
200
+
201
+ def process_image(image, dataset, image_type=None):
202
+ image = TF.to_tensor(image).unsqueeze(0)
203
+
204
+ if dataset == "HPA":
205
+ if image_type == 'nucleus':
206
+ normalize = (0.0655, 0.0650)
207
+
208
+ elif image_type == 'protein':
209
+ normalize = (0.1732, 0.1208)
210
+
211
+ elif dataset == "OpenCell":
212
+
213
+ if image_type == 'nucleus':
214
+ normalize = (0.0272, 0.0244)
215
+
216
+ elif image_type == 'protein':
217
+ normalize = (0.0486, 0.0671)
218
+
219
+ t_forms = []
220
+
221
+ t_forms.append(transforms.RandomCrop(256))
222
+
223
+ # t_forms.append(transforms.Normalize(normalize[0],normalize[1]))
224
+
225
+
226
+ image = transforms.Compose(t_forms)(image)
227
+
228
+ return image
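The sampling helpers above are used together downstream: `top_k` masks all but the highest-scoring logits and `gumbel_sample` adds temperature-scaled Gumbel noise. A short sketch, with the final `argmax` added here for illustration (the trainer below imports similarly named helpers from `celle.celle`):

```python
import torch
from celle.utils import gumbel_sample, top_k

logits = torch.randn(1, 16)          # toy logits over 16 codebook tokens
filtered = top_k(logits, thres=0.5)  # keep the top half, set the rest to -inf
noisy = gumbel_sample(filtered, temperature=0.9)
token = noisy.argmax(dim=-1)         # argmax of Gumbel-perturbed logits ~ a categorical sample
print(token.shape)                   # torch.Size([1])
```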
celle/vae.py ADDED
@@ -0,0 +1,112 @@
1
+ from math import sqrt, log
2
+ from omegaconf import OmegaConf
3
+ import importlib
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange
10
+
11
+ # helper methods
12
+
13
+
14
+ def load_model(path):
15
+ with open(path, "rb") as f:
16
+ return torch.load(f, map_location=torch.device("cpu"))
17
+
18
+
19
+ def map_pixels(x, eps=0.1):
20
+ return (1 - 2 * eps) * x + eps
21
+
22
+
23
+ def unmap_pixels(x, eps=0.1):
24
+ return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1)
25
+
26
+
27
+ def make_contiguous(module):
28
+ with torch.no_grad():
29
+ for param in module.parameters():
30
+ param.set_(param.contiguous())
31
+
32
+
33
+ # VQGAN from Taming Transformers paper
34
+ # https://arxiv.org/abs/2012.09841
35
+
36
+
37
+ def get_obj_from_str(string, reload=False):
38
+ module, cls = string.rsplit(".", 1)
39
+ if reload:
40
+ module_imp = importlib.import_module(module)
41
+ importlib.reload(module_imp)
42
+ return getattr(importlib.import_module(module, package=None), cls)
43
+
44
+
45
+ def instantiate_from_config(config):
46
+ if not "target" in config:
47
+ raise KeyError("Expected key `target` to instantiate.")
48
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
49
+
50
+
51
+ class VQGanVAE(nn.Module):
52
+ def __init__(self, vqgan_model_path=None, vqgan_config_path=None, channels=1):
53
+ super().__init__()
54
+
55
+ assert vqgan_config_path is not None
56
+
57
+ model_path = vqgan_model_path
58
+ config_path = vqgan_config_path
59
+
60
+ config = OmegaConf.load(config_path)
61
+
62
+ model = instantiate_from_config(config["model"])
63
+
64
+ if vqgan_model_path:
65
+
66
+ state = torch.load(model_path, map_location="cpu")["state_dict"]
67
+ model.load_state_dict(state, strict=True)
68
+
69
+ print(f"Loaded VQGAN from {model_path} and {config_path}")
70
+
71
+ self.model = model
72
+
73
+ # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
74
+ f = (
75
+ config.model.params.ddconfig.resolution
76
+ / config.model.params.ddconfig.attn_resolutions[0]
77
+ )
78
+ self.num_layers = int(log(f) / log(2))
79
+ self.image_size = config.model.params.ddconfig.resolution
80
+ self.num_tokens = config.model.params.n_embed
81
+ # self.is_gumbel = isinstance(self.model, GumbelVQ)
82
+ self.is_gumbel = False
83
+ self.channels = config.model.params.ddconfig.in_channels
84
+
85
+ def encode(self, img):
86
+ return self.model.encode(img)
87
+
88
+ def get_codebook_indices(self, img):
89
+ b = img.shape[0]
90
+ # img = (2 * img) - 1
91
+ _, _, [_, _, indices] = self.encode(img)
92
+ if self.is_gumbel:
93
+ return rearrange(indices, "b h w -> b (h w)", b=b)
94
+ return rearrange(indices, "(b n) -> b n", b=b)
95
+
96
+ def decode(self, img_seq):
97
+ b, n = img_seq.shape
98
+ one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).float()
99
+ z = (
100
+ one_hot_indices @ self.model.quantize.embed.weight
101
+ if self.is_gumbel
102
+ else (one_hot_indices @ self.model.quantize.embedding.weight)
103
+ )
104
+
105
+ z = rearrange(z, "b (h w) c -> b c h w", h=int(sqrt(n)))
106
+ img = self.model.decode(z)
107
+
108
+ # img = (img.clamp(-1.0, 1.0) + 1) * 0.5
109
+ return img
110
+
111
+ def forward(self, img, optimizer_idx=1):
112
+ return self.model.training_step(img, optimizer_idx=optimizer_idx)
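A sketch of the round trip this wrapper provides, assuming a taming-transformers style config such as the `nucleus_vqgan.yaml` fetched in `app.py`; with `vqgan_model_path=None` the weights stay randomly initialised, so this only checks shapes, and the config path below is a placeholder.

```python
import torch
from celle.vae import VQGanVAE

vae = VQGanVAE(
    vqgan_model_path=None,                   # no checkpoint: random weights, shape check only
    vqgan_config_path="nucleus_vqgan.yaml",  # placeholder path to a VQGAN config
)

img = torch.randn(1, vae.channels, vae.image_size, vae.image_size)
indices = vae.get_codebook_indices(img)  # (b, h*w) discrete codebook ids
recon = vae.decode(indices)              # codebook ids back to image space
print(indices.shape, recon.shape)
```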
celle_main.py ADDED
@@ -0,0 +1,619 @@
1
+ import os
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.random
6
+ from torch.optim import AdamW
7
+ from torch.utils.data import DataLoader
8
+ import pytorch_lightning as pl
9
+ from pytorch_lightning import seed_everything
10
+ from pytorch_lightning.trainer import Trainer
11
+
12
+ from dataloader import CellLoader
13
+ from celle import VQGanVAE, CELLE
14
+ from omegaconf import OmegaConf
15
+ import argparse, os, sys, datetime, glob
16
+
17
+ from celle.celle import gumbel_sample, top_k
18
+
19
+ torch.random.manual_seed(42)
20
+ np.random.seed(42)
21
+
22
+ from celle_taming_main import (
23
+ instantiate_from_config,
24
+ nondefault_trainer_args,
25
+ get_parser,
26
+ )
27
+
28
+
29
+ class CellDataModule(pl.LightningDataModule):
30
+ def __init__(
31
+ self,
32
+ data_csv,
33
+ dataset,
34
+ sequence_mode="standard",
35
+ vocab="bert",
36
+ crop_size=256,
37
+ resize=600,
38
+ batch_size=1,
39
+ threshold="median",
40
+ text_seq_len=1000,
41
+ num_workers=1,
42
+ **kwargs,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.data_csv = data_csv
47
+ self.dataset = dataset
48
+ self.protein_sequence_length = 0
49
+ self.image_folders = []
50
+ self.crop_size = crop_size
51
+ self.resize = resize
52
+ self.batch_size = batch_size
53
+ self.sequence_mode = sequence_mode
54
+ self.threshold = threshold
55
+ self.text_seq_len = int(text_seq_len)
56
+ self.vocab = vocab
57
+ self.num_workers = num_workers if num_workers is not None else batch_size * 2
58
+
59
+ def setup(self, stage=None):
60
+ # called on every GPU
61
+ self.cell_dataset_train = CellLoader(
62
+ data_csv=self.data_csv,
63
+ dataset=self.dataset,
64
+ crop_size=self.crop_size,
65
+ resize=self.resize,
66
+ split_key="train",
67
+ crop_method="random",
68
+ sequence_mode=self.sequence_mode,
69
+ vocab=self.vocab,
70
+ text_seq_len=self.text_seq_len,
71
+ threshold=self.threshold,
72
+ )
73
+
74
+ self.cell_dataset_val = CellLoader(
75
+ data_csv=self.data_csv,
76
+ dataset=self.dataset,
77
+ crop_size=self.crop_size,
78
+ resize=self.resize,
79
+ crop_method="center",
80
+ split_key="val",
81
+ sequence_mode=self.sequence_mode,
82
+ vocab=self.vocab,
83
+ text_seq_len=self.text_seq_len,
84
+ threshold=self.threshold,
85
+ )
86
+
87
+ def prepare_data(self):
88
+
89
+ pass
90
+
91
+ def train_dataloader(self):
92
+ return DataLoader(
93
+ self.cell_dataset_train,
94
+ num_workers=self.num_workers,
95
+ shuffle=True,
96
+ batch_size=self.batch_size,
97
+ )
98
+
99
+ def val_dataloader(self):
100
+ return DataLoader(
101
+ self.cell_dataset_val,
102
+ num_workers=self.num_workers,
103
+ batch_size=self.batch_size,
104
+ )
105
+
106
+ # def test_dataloader(self):
107
+ # transforms = ...
108
+ # return DataLoader(self.test, batch_size=64)
109
+
110
+
111
+ class CELLE_trainer(pl.LightningModule):
112
+ def __init__(
113
+ self,
114
+ vqgan_model_path,
115
+ vqgan_config_path,
116
+ ckpt_path=None,
117
+ image_key="threshold",
118
+ condition_model_path=None,
119
+ condition_config_path=None,
120
+ num_images=2,
121
+ dim=2,
122
+ num_text_tokens=30,
123
+ text_seq_len=1000,
124
+ depth=16,
125
+ heads=16,
126
+ dim_head=64,
127
+ attn_dropout=0.1,
128
+ ff_dropout=0.1,
129
+ attn_types="full",
130
+ loss_img_weight=7,
131
+ stable=False,
132
+ rotary_emb=True,
133
+ text_embedding="bert",
134
+ fixed_embedding=True,
135
+ loss_cond_weight=1,
136
+ learning_rate=3e-4,
137
+ monitor="val_loss",
138
+ ):
139
+ super().__init__()
140
+
141
+ vae = VQGanVAE(
142
+ vqgan_model_path=vqgan_model_path, vqgan_config_path=vqgan_config_path
143
+ )
144
+
145
+ self.image_key = image_key
146
+
147
+ if condition_config_path:
148
+ condition_vae = VQGanVAE(
149
+ vqgan_model_path=condition_model_path,
150
+ vqgan_config_path=condition_config_path,
151
+ )
152
+ else:
153
+ condition_vae = None
154
+
155
+ self.celle = CELLE(
156
+ dim=dim,
157
+ vae=vae, # automatically infer (1) image sequence length and (2) number of image tokens
158
+ condition_vae=condition_vae,
159
+ num_images=num_images,
160
+ num_text_tokens=num_text_tokens, # vocab size for text
161
+ text_seq_len=text_seq_len, # text sequence length
162
+ depth=depth, # should aim to be 64
163
+ heads=heads, # attention heads
164
+ dim_head=dim_head, # attention head dimension
165
+ attn_dropout=attn_dropout, # attention dropout
166
+ ff_dropout=ff_dropout, # feedforward dropout
167
+ loss_img_weight=loss_img_weight,
168
+ stable=stable,
169
+ rotary_emb=rotary_emb,
170
+ text_embedding=text_embedding,
171
+ fixed_embedding=fixed_embedding,
172
+ loss_cond_weight=loss_cond_weight,
173
+ )
174
+
175
+ self.learning_rate = learning_rate
176
+ self.num_text_tokens = num_text_tokens
177
+ self.num_images = num_images
178
+
179
+ if monitor is not None:
180
+ self.monitor = monitor
181
+
182
+ ignore_keys = []
183
+
184
+ if condition_model_path:
185
+ ignore_keys.append("celle.condition_vae")
186
+
187
+ if vqgan_model_path:
188
+ ignore_keys.append("celle.vae")
189
+
190
+ if ckpt_path is not None:
191
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
192
+
193
+ def init_from_ckpt(self, path, ignore_keys=list()):
194
+ sd = torch.load(path, map_location="cpu")["state_dict"]
195
+ ckpt = sd.copy()
196
+ for k in sd.keys():
197
+ for ik in ignore_keys:
198
+ if k.startswith(ik):
199
+ # print("Deleting key {} from state_dict.".format(k))
200
+ del ckpt[k]
201
+ self.load_state_dict(ckpt, strict=True)
202
+ print(f"Restored from {path}")
203
+
204
+ def forward(self, text, condition, target, return_loss=True):
205
+
206
+ return self.celle(
207
+ text=text, condition=condition, image=target, return_loss=return_loss
208
+ )
209
+
210
+ def get_input(self, batch):
211
+ text = batch["sequence"].squeeze(1)
212
+ condition = batch["nucleus"]
213
+ target = batch[self.image_key]
214
+
215
+ return text, condition, target
216
+
217
+ def get_image_from_logits(self, logits, temperature=0.9):
218
+
219
+ filtered_logits = top_k(logits, thres=0.5)
220
+ sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)
221
+
222
+ self.celle.vae.eval()
223
+ out = self.celle.vae.decode(
224
+ sample[:, self.celle.text_seq_len + self.celle.condition_seq_len :]
225
+ - (self.celle.num_text_tokens + self.celle.num_condition_tokens)
226
+ )
227
+
228
+ return out
229
+
230
+ def get_loss(self, text, condition, target):
231
+
232
+ loss_dict = {}
233
+
234
+ loss, loss_dict, logits = self(text, condition, target, return_loss=True)
235
+
236
+ return loss, loss_dict
237
+
238
+ def total_loss(
239
+ self,
240
+ loss,
241
+ loss_dict,
242
+ mode="train",
243
+ ):
244
+
245
+ loss_dict = {f"{mode}/{key}": value for key, value in loss_dict.items()}
246
+
247
+ for key, value in loss_dict.items():
248
+ self.log(
249
+ key,
250
+ value,
251
+ prog_bar=True,
252
+ logger=True,
253
+ on_step=True,
254
+ on_epoch=True,
255
+ sync_dist=True,
256
+ )
257
+
258
+ return loss
259
+
260
+ def training_step(self, batch, batch_idx):
261
+
262
+ text, condition, target = self.get_input(batch)
263
+ loss, log_dict = self.get_loss(text, condition, target)
264
+
265
+ loss = self.total_loss(loss, log_dict, mode="train")
266
+
267
+ return loss
268
+
269
+ def validation_step(self, batch, batch_idx):
270
+
271
+ with torch.no_grad():
272
+
273
+ text, condition, target = self.get_input(batch)
274
+ loss, log_dict = self.get_loss(text, condition, target)
275
+
276
+ loss = self.total_loss(loss, log_dict, mode="val")
277
+
278
+ return loss
279
+
280
+ def configure_optimizers(self):
281
+
282
+ optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
283
+
284
+ return optimizer
285
+
286
+ def scale_image(self, image):
287
+
288
+ for tensor in image:
289
+ if torch.min(tensor) < 0:
290
+ tensor += -torch.min(tensor)
291
+ else:
292
+ tensor -= torch.min(tensor)
293
+
294
+ tensor /= torch.max(tensor)
295
+
296
+ return image
297
+
298
+ @torch.no_grad()
299
+ def log_images(self, batch, **kwargs):
300
+
301
+ log = []
302
+
303
+ text, condition, target = self.get_input(batch)
304
+ text = text.squeeze(1).to(self.device)
305
+ condition = condition.to(self.device)
306
+
307
+ out = self.celle.generate_images(text=text, condition=condition)
308
+
309
+ log["condition"] = self.scale_image(condition)
310
+ log["output"] = self.scale_image(out)
311
+ if self.image_key == "threshold":
312
+ log["threshold"] = self.scale_image(target)
313
+ log["target"] = self.scale_image(batch["target"])
314
+ else:
315
+ log["target"] = self.scale_image(target)
316
+
317
+ return log
318
+
319
+
320
+ # from https://github.com/CompVis/taming-transformers/blob/master/celle_main.py
321
+
322
+ if __name__ == "__main__":
323
+ # custom parser to specify config files, train, test and debug mode,
324
+ # postfix, resume.
325
+ # `--key value` arguments are interpreted as arguments to the trainer.
326
+ # `nested.key=value` arguments are interpreted as config parameters.
327
+ # configs are merged from left-to-right followed by command line parameters.
328
+
329
+ # model:
330
+ # learning_rate: float
331
+ # target: path to lightning module
332
+ # params:
333
+ # key: value
334
+ # data:
335
+ # target: celle_main.DataModuleFromConfig
336
+ # params:
337
+ # batch_size: int
338
+ # wrap: bool
339
+ # train:
340
+ # target: path to train dataset
341
+ # params:
342
+ # key: value
343
+ # validation:
344
+ # target: path to validation dataset
345
+ # params:
346
+ # key: value
347
+ # test:
348
+ # target: path to test dataset
349
+ # params:
350
+ # key: value
351
+ # lightning: (optional, has sane defaults and can be specified on cmdline)
352
+ # trainer:
353
+ # additional arguments to trainer
354
+ # logger:
355
+ # logger to instantiate
356
+ # modelcheckpoint:
357
+ # modelcheckpoint to instantiate
358
+ # callbacks:
359
+ # callback1:
360
+ # target: importpath
361
+ # params:
362
+ # key: value
363
+
364
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
365
+
366
+ # add cwd for convenience and to make classes in this file available when
367
+ # running as `python celle_main.py`
368
+ # (in particular `celle_main.DataModuleFromConfig`)
369
+ sys.path.append(os.getcwd())
370
+
371
+ parser = get_parser()
372
+ parser = Trainer.add_argparse_args(parser)
373
+
374
+ opt, unknown = parser.parse_known_args()
375
+ if opt.name and opt.resume:
376
+ raise ValueError(
377
+ "-n/--name and -r/--resume cannot be specified both."
378
+ "If you want to resume training in a new log folder, "
379
+ "use -n/--name in combination with --resume_from_checkpoint"
380
+ )
381
+ if opt.resume:
382
+ if not os.path.exists(opt.resume):
383
+ raise ValueError("Cannot find {}".format(opt.resume))
384
+ if os.path.isfile(opt.resume):
385
+ paths = opt.resume.split("/")
386
+ idx = len(paths) - paths[::-1].index("logs") + 1
387
+ logdir = "/".join(paths[:idx])
388
+ ckpt = opt.resume
389
+ else:
390
+ assert os.path.isdir(opt.resume), opt.resume
391
+ logdir = opt.resume.rstrip("/")
392
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
393
+
394
+ opt.resume_from_checkpoint = ckpt
395
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
396
+ opt.base = base_configs + opt.base
397
+ _tmp = logdir.split("/")
398
+ nowname = _tmp[_tmp.index("logs") + 1]
399
+ else:
400
+ if opt.name:
401
+ name = "_" + opt.name
402
+ elif opt.base:
403
+ cfg_fname = os.path.split(opt.base[0])[-1]
404
+ cfg_name = os.path.splitext(cfg_fname)[0]
405
+ name = "_" + cfg_name
406
+ else:
407
+ name = ""
408
+ nowname = now + name + opt.postfix
409
+ logdir = os.path.join("logs", nowname)
410
+
411
+ ckptdir = os.path.join(logdir, "checkpoints")
412
+ cfgdir = os.path.join(logdir, "configs")
413
+ seed_everything(opt.seed)
414
+
415
+ try:
416
+ # init and save configs
417
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
418
+ cli = OmegaConf.from_dotlist(unknown)
419
+ config = OmegaConf.merge(*configs, cli)
420
+ lightning_config = config.pop("lightning", OmegaConf.create())
421
+ # merge trainer cli with config
422
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
423
+ # default to ddp
424
+ # trainer_config["distributed_backend"] = "ddp"
425
+ for k in nondefault_trainer_args(opt):
426
+ trainer_config[k] = getattr(opt, k)
427
+ if not "gpus" in trainer_config:
428
+ del trainer_config["distributed_backend"]
429
+ cpu = True
430
+ else:
431
+ gpuinfo = trainer_config["gpus"]
432
+ print(f"Running on GPUs {gpuinfo}")
433
+ cpu = False
434
+ trainer_opt = argparse.Namespace(**trainer_config)
435
+ lightning_config.trainer = trainer_config
436
+
437
+ # model
438
+ # model = instantiate_from_config(config.model)
439
+ model = instantiate_from_config(config.model)
440
+ # trainer and callbacks
441
+ trainer_kwargs = dict()
442
+
443
+ # default logger configs
444
+ # NOTE wandb < 0.10.0 interferes with shutdown
445
+ # wandb >= 0.10.0 seems to fix it but still interferes with pudb
446
+ # debugging (wrongly sized pudb ui)
447
+ # thus prefer testtube for now
448
+ default_logger_cfgs = {
449
+ "wandb": {
450
+ "target": "pytorch_lightning.loggers.WandbLogger",
451
+ "params": {
452
+ "name": nowname,
453
+ "save_dir": logdir,
454
+ "offline": opt.debug,
455
+ "id": nowname,
456
+ },
457
+ },
458
+ "testtube": {
459
+ # "target": "pytorch_lightning.loggers.TestTubeLogger",
460
+ "target": "pytorch_lightning.loggers.TensorBoardLogger",
461
+ "params": {
462
+ "name": "testtube",
463
+ "save_dir": logdir,
464
+ },
465
+ },
466
+ }
467
+ default_logger_cfg = default_logger_cfgs["testtube"]
468
+ # logger_cfg = lightning_config.logger or OmegaConf.create()
469
+ try:
470
+ logger_cfg = lightning_config.logger
471
+ except:
472
+ logger_cfg = OmegaConf.create()
473
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
474
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
475
+
476
+ # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
477
+ # specify which metric is used to determine best models
478
+ default_modelckpt_cfg = {
479
+ "checkpoint_callback": {
480
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
481
+ "params": {
482
+ "dirpath": ckptdir,
483
+ "filename": "{epoch:06}",
484
+ "verbose": True,
485
+ "save_last": True,
486
+ },
487
+ }
488
+ }
489
+ if hasattr(model, "monitor"):
490
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
491
+ default_modelckpt_cfg["checkpoint_callback"]["params"][
492
+ "monitor"
493
+ ] = model.monitor
494
+ default_modelckpt_cfg["checkpoint_callback"]["params"]["save_top_k"] = 3
495
+ try:
496
+ modelckpt_cfg = lightning_config.modelcheckpoint
497
+ except:
498
+ modelckpt_cfg = OmegaConf.create()
499
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
500
+ # trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
501
+
502
+ # add callback which sets up log directory
503
+ default_callbacks_cfg = {
504
+ "setup_callback": {
505
+ "target": "celle_taming_main.SetupCallback",
506
+ "params": {
507
+ "resume": opt.resume,
508
+ "now": now,
509
+ "logdir": logdir,
510
+ "ckptdir": ckptdir,
511
+ "cfgdir": cfgdir,
512
+ "config": config,
513
+ "lightning_config": lightning_config,
514
+ },
515
+ },
516
+ # "image_logger": {
517
+ # "target": "celle_taming_main.ImageLogger",
518
+ # "params": {
519
+ # "batch_frequency": 0,
520
+ # "max_images": 0,
521
+ # "clamp": False,
522
+ # "increase_log_steps": False,
523
+ # },
524
+ # },
525
+ # "learning_rate_logger": {
526
+ # "target": "celle_taming_main.LearningRateMonitor",
527
+ # "params": {
528
+ # "logging_interval": "step",
529
+ # # "log_momentum": True
530
+ # },
531
+ # },
532
+ }
533
+ try:
534
+ callbacks_cfg = lightning_config.callbacks
535
+ except:
536
+ callbacks_cfg = OmegaConf.create()
537
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
538
+ callbacks_cfg = OmegaConf.merge(modelckpt_cfg, callbacks_cfg)
539
+ trainer_kwargs["callbacks"] = [
540
+ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
541
+ ]
542
+
543
+ trainer = Trainer.from_argparse_args(
544
+ trainer_opt, **trainer_kwargs, profiler="simple"
545
+ )
546
+
547
+ # data
548
+ data = instantiate_from_config(config.data)
549
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
550
+ # calling these ourselves should not be necessary but it is.
551
+ # lightning still takes care of proper multiprocessing though
552
+ data.setup()
553
+ data.prepare_data()
554
+
555
+ # configure learning rate
556
+ bs, lr = config.data.params.batch_size, config.model.learning_rate
557
+
558
+ if not cpu:
559
+ ngpu = len(lightning_config.trainer.gpus.strip(",").split(","))
560
+ else:
561
+ ngpu = 1
562
+ try:
563
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
564
+ except:
565
+ accumulate_grad_batches = 1
566
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
567
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
568
+ model.learning_rate = accumulate_grad_batches * ngpu * bs * lr
569
+
570
+ print(
571
+ "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (lr)".format(
572
+ model.learning_rate, accumulate_grad_batches, ngpu, bs, lr
573
+ )
574
+ )
575
+
576
+ # allow checkpointing via USR1
577
+ def melk(*args, **kwargs):
578
+ # run all checkpoint hooks
579
+ if trainer.global_rank == 0:
580
+ print("Summoning checkpoint.")
581
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
582
+ trainer.save_checkpoint(ckpt_path)
583
+
584
+ def divein(*args, **kwargs):
585
+ if trainer.global_rank == 0:
586
+ import pudb
587
+
588
+ pudb.set_trace()
589
+
590
+ import signal
591
+
592
+ signal.signal(signal.SIGUSR1, melk)
593
+ signal.signal(signal.SIGUSR2, divein)
594
+
595
+ # run
596
+ if opt.train:
597
+ try:
598
+ # model = torch.compile(model, mode="reduce_overhead")
599
+ trainer.fit(model, data)  # torch.compile() applied to the (None) return of fit() is a no-op; compile the model before fitting if desired
600
+ except Exception:
601
+ melk()
602
+ raise
603
+ if not opt.no_test and not trainer.interrupted:
604
+ trainer.test(model, data)
605
+ except Exception:
606
+ if opt.debug and trainer.global_rank == 0:
607
+ try:
608
+ import pudb as debugger
609
+ except ImportError:
610
+ import pdb as debugger
611
+ debugger.post_mortem()
612
+ raise
613
+ finally:
614
+ # move newly created debug project to debug_runs
615
+ if opt.debug and not opt.resume and trainer.global_rank == 0:
616
+ dst, name = os.path.split(logdir)
617
+ dst = os.path.join(dst, "debug_runs", name)
618
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
619
+ os.rename(logdir, dst)
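Besides the CLI path in `__main__`, the two classes above can be wired together by hand. A rough sketch under stated assumptions: the CSV path, config paths, and model sizes below are placeholders, and a real run would take these values from the project configs.

```python
import pytorch_lightning as pl
from celle_main import CellDataModule, CELLE_trainer

data = CellDataModule(
    data_csv="data/HPA.csv",  # placeholder CSV with split, sequence and image path columns
    dataset="HPA",
    batch_size=2,
    text_seq_len=1000,
)

model = CELLE_trainer(
    vqgan_model_path=None,
    vqgan_config_path="threshold_vqgan.yaml",    # placeholder VQGAN configs
    condition_config_path="nucleus_vqgan.yaml",
    dim=64,
    depth=2,
    heads=4,  # toy sizes; the released checkpoints use much larger values
)

trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
trainer.fit(model, data)
```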
celle_taming_main.py ADDED
@@ -0,0 +1,695 @@
1
+ import argparse, os, sys, datetime, glob, importlib
2
+ from omegaconf import OmegaConf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ import torchvision
7
+ from torch.utils.data import DataLoader, Dataset
8
+ from dataloader import CellLoader
9
+ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning import seed_everything
12
+ from pytorch_lightning.trainer import Trainer
13
+ from pytorch_lightning.callbacks import Callback
14
+ from pytorch_lightning.utilities import rank_zero_only
15
+
16
+
17
+ def get_obj_from_str(string, reload=False):
18
+ module, cls = string.rsplit(".", 1)
19
+ if reload:
20
+ module_imp = importlib.import_module(module)
21
+ importlib.reload(module_imp)
22
+ return getattr(importlib.import_module(module, package=None), cls)
23
+
24
+
25
+ def get_parser(**parser_kwargs):
26
+ def str2bool(v):
27
+ if isinstance(v, bool):
28
+ return v
29
+ if v.lower() in ("yes", "true", "t", "y", "1"):
30
+ return True
31
+ elif v.lower() in ("no", "false", "f", "n", "0"):
32
+ return False
33
+ else:
34
+ raise argparse.ArgumentTypeError("Boolean value expected.")
35
+
36
+ parser = argparse.ArgumentParser(**parser_kwargs)
37
+ parser.add_argument(
38
+ "-n",
39
+ "--name",
40
+ type=str,
41
+ const=True,
42
+ default="",
43
+ nargs="?",
44
+ help="postfix for logdir",
45
+ )
46
+ parser.add_argument(
47
+ "-r",
48
+ "--resume",
49
+ type=str,
50
+ const=True,
51
+ default="",
52
+ nargs="?",
53
+ help="resume from logdir or checkpoint in logdir",
54
+ )
55
+ parser.add_argument(
56
+ "-b",
57
+ "--base",
58
+ nargs="*",
59
+ metavar="base_config.yaml",
60
+ help="paths to base configs. Loaded from left-to-right. "
61
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
62
+ default=list(),
63
+ )
64
+ parser.add_argument(
65
+ "-t",
66
+ "--train",
67
+ type=str2bool,
68
+ const=True,
69
+ default=False,
70
+ nargs="?",
71
+ help="train",
72
+ )
73
+ parser.add_argument(
74
+ "--no-test",
75
+ type=str2bool,
76
+ const=True,
77
+ default=False,
78
+ nargs="?",
79
+ help="disable test",
80
+ )
81
+ parser.add_argument(
82
+ "-p", "--project", help="name of new or path to existing project"
83
+ )
84
+ parser.add_argument(
85
+ "-d",
86
+ "--debug",
87
+ type=str2bool,
88
+ nargs="?",
89
+ const=True,
90
+ default=False,
91
+ help="enable post-mortem debugging",
92
+ )
93
+ parser.add_argument(
94
+ "-s",
95
+ "--seed",
96
+ type=int,
97
+ default=42,
98
+ help="seed for seed_everything",
99
+ )
100
+ parser.add_argument(
101
+ "-f",
102
+ "--postfix",
103
+ type=str,
104
+ default="",
105
+ help="post-postfix for default name",
106
+ )
107
+
108
+ return parser
109
+
110
+
111
+ def nondefault_trainer_args(opt):
112
+ parser = argparse.ArgumentParser()
113
+ parser = Trainer.add_argparse_args(parser)
114
+ args = parser.parse_args([])
115
+ return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
116
+
117
+
118
+ def instantiate_from_config(config):
119
+ if not "target" in config:
120
+ raise KeyError("Expected key `target` to instantiate.")
121
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
122
+
123
+
124
+ class WrappedDataset(Dataset):
125
+ """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
126
+
127
+ def __init__(self, dataset):
128
+ self.data = dataset
129
+
130
+ def __len__(self):
131
+ return len(self.data)
132
+
133
+ def __getitem__(self, idx):
134
+ return self.data[idx]
135
+
136
+
137
+ class DataModuleFromConfig(pl.LightningDataModule):
138
+ def __init__(
139
+ self,
140
+ data_csv,
141
+ dataset,
142
+ crop_size=256,
143
+ resize=600,
144
+ batch_size=1,
145
+ sequence_mode="latent",
146
+ vocab="bert",
147
+ text_seq_len=0,
148
+ num_workers=1,
149
+ threshold=False,
150
+ train=True,
151
+ validation=True,
152
+ test=None,
153
+ wrap=False,
154
+ **kwargs,
155
+ ):
156
+ super().__init__()
157
+ self.data_csv = data_csv
158
+ self.dataset = dataset
159
+ self.image_folders = []
160
+ self.crop_size = crop_size
161
+ self.resize = resize
162
+ self.batch_size = batch_size
163
+ self.sequence_mode = sequence_mode
164
+ self.threshold = threshold
165
+ self.text_seq_len = int(text_seq_len)
166
+ self.vocab = vocab
167
+ self.dataset_configs = dict()
168
+ self.num_workers = num_workers if num_workers is not None else batch_size * 2
169
+ if train is not None:
170
+ self.dataset_configs["train"] = train
171
+ self.train_dataloader = self._train_dataloader
172
+ if validation is not None:
173
+ self.dataset_configs["validation"] = validation
174
+ self.val_dataloader = self._val_dataloader
175
+ if test is not None:
176
+ self.dataset_configs["test"] = test
177
+ self.test_dataloader = self._test_dataloader
178
+ self.wrap = wrap
179
+
180
+ def prepare_data(self):
181
+ pass
182
+
183
+ def setup(self, stage=None):
184
+ # called on every GPU
185
+ self.cell_dataset_train = CellLoader(
186
+ data_csv=self.data_csv,
187
+ dataset=self.dataset,
188
+ crop_size=self.crop_size,
189
+ split_key="train",
190
+ crop_method="random",
191
+ sequence_mode=None,
192
+ vocab=self.vocab,
193
+ text_seq_len=self.text_seq_len,
194
+ threshold=self.threshold,
195
+ )
196
+
197
+ self.cell_dataset_val = CellLoader(
198
+ data_csv=self.data_csv,
199
+ dataset=self.dataset,
200
+ crop_size=self.crop_size,
201
+ split_key="val",
202
+ crop_method="center",
203
+ sequence_mode=None,
204
+ vocab=self.vocab,
205
+ text_seq_len=self.text_seq_len,
206
+ threshold=self.threshold,
207
+ )
208
+
209
+ def _train_dataloader(self):
210
+ return DataLoader(
211
+ self.cell_dataset_train,
212
+ num_workers=self.num_workers,
213
+ pin_memory=True,
214
+ shuffle=True,
215
+ batch_size=self.batch_size,
216
+ )
217
+
218
+ def _val_dataloader(self):
219
+ return DataLoader(
220
+ self.cell_dataset_val,
221
+ num_workers=self.num_workers,
222
+ pin_memory=True,
223
+ batch_size=self.batch_size,
224
+ )
225
+
226
+ # def _test_dataloader(self):
227
+ # return DataLoader(self.datasets["test"], batch_size=self.batch_size,
228
+ # num_workers=self.num_workers)
229
+
230
+
231
+ class SetupCallback(Callback):
232
+ def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
233
+ super().__init__()
234
+ self.resume = resume
235
+ self.now = now
236
+ self.logdir = logdir
237
+ self.ckptdir = ckptdir
238
+ self.cfgdir = cfgdir
239
+ self.config = config
240
+ self.lightning_config = lightning_config
241
+
242
+ def on_fit_start(self, trainer, pl_module):
243
+ if trainer.global_rank == 0:
244
+ # Create logdirs and save configs
245
+ os.makedirs(self.logdir, exist_ok=True)
246
+ os.makedirs(self.ckptdir, exist_ok=True)
247
+ os.makedirs(self.cfgdir, exist_ok=True)
248
+
249
+ print("Project config")
250
+ print(OmegaConf.to_yaml(self.config))
251
+ OmegaConf.save(
252
+ self.config,
253
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
254
+ )
255
+
256
+ print("Lightning config")
257
+ print(OmegaConf.to_yaml(self.lightning_config))
258
+ OmegaConf.save(
259
+ OmegaConf.create({"lightning": self.lightning_config}),
260
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
261
+ )
262
+
263
+ else:
264
+ # ModelCheckpoint callback created log directory --- remove it
265
+ if not self.resume and os.path.exists(self.logdir):
266
+ dst, name = os.path.split(self.logdir)
267
+ dst = os.path.join(dst, "child_runs", name)
268
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
269
+ try:
270
+ os.rename(self.logdir, dst)
271
+ except FileNotFoundError:
272
+ pass
273
+
274
+
275
+ class ImageLogger(Callback):
276
+ def __init__(
277
+ self, batch_frequency, max_images, clamp=True, increase_log_steps=True
278
+ ):
279
+ super().__init__()
280
+ self.batch_freq = batch_frequency
281
+ self.max_images = max_images
282
+ self.logger_log_images = {
283
+ pl.loggers.WandbLogger: self._wandb,
284
+ # pl.loggers.TestTubeLogger: self._testtube,
285
+ pl.loggers.TensorBoardLogger: self._testtube,
286
+ }
287
+ self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
288
+ if not increase_log_steps:
289
+ self.log_steps = [self.batch_freq]
290
+ self.clamp = clamp
291
+
292
+ @rank_zero_only
293
+ def _wandb(self, pl_module, images, batch_idx, split):
294
+ raise ValueError("No way wandb")
295
+ grids = dict()
296
+ for k in images:
297
+ grid = torchvision.utils.make_grid(images[k])
298
+ grids[f"{split}/{k}"] = wandb.Image(grid)
299
+ pl_module.logger.experiment.log(grids)
300
+
301
+ @rank_zero_only
302
+ def _testtube(self, pl_module, images, batch_idx, split):
303
+ for k in images:
304
+ images[k] -= torch.min(images[k])
305
+ images[k] /= torch.max(images[k])
306
+ grid = torchvision.utils.make_grid(images[k])
307
+ # grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
308
+
309
+ tag = f"{split}/{k}"
310
+ pl_module.logger.experiment.add_image(
311
+ tag, grid, global_step=pl_module.global_step
312
+ )
313
+
314
+ @rank_zero_only
315
+ def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
316
+ root = os.path.join(save_dir, "images", split)
317
+ for k in images:
318
+ images[k] -= torch.min(images[k])
319
+ images[k] /= torch.max(images[k])
320
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
321
+
322
+ # grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
323
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
324
+ grid = grid.numpy()
325
+ grid = (grid * 255).astype(np.uint8)
326
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
327
+ k, global_step, current_epoch, batch_idx
328
+ )
329
+ path = os.path.join(root, filename)
330
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
331
+ Image.fromarray(grid).save(path)
332
+
333
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
334
+ if (
335
+ self.check_frequency(batch_idx)
336
+ and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
337
+ and callable(pl_module.log_images)
338
+ and self.max_images > 0
339
+ ):
340
+ logger = type(pl_module.logger)
341
+
342
+ is_train = pl_module.training
343
+ if is_train:
344
+ pl_module.eval()
345
+
346
+ with torch.no_grad():
347
+ images = pl_module.log_images(batch, split=split)
348
+
349
+ for k in images:
350
+ N = min(images[k].shape[0], self.max_images)
351
+ images[k] = images[k][:N]
352
+ if isinstance(images[k], torch.Tensor):
353
+ images[k] = images[k].detach().cpu()
354
+ if self.clamp:
355
+ images[k] = torch.clamp(images[k], -1.0, 1.0)
356
+
357
+ self.log_local(
358
+ pl_module.logger.save_dir,
359
+ split,
360
+ images,
361
+ pl_module.global_step,
362
+ pl_module.current_epoch,
363
+ batch_idx,
364
+ )
365
+
366
+ logger_log_images = self.logger_log_images.get(
367
+ logger, lambda *args, **kwargs: None
368
+ )
369
+ logger_log_images(pl_module, images, pl_module.global_step, split)
370
+
371
+ if is_train:
372
+ pl_module.train()
373
+
374
+ def check_frequency(self, batch_idx):
375
+ if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
376
+ try:
377
+ self.log_steps.pop(0)
378
+ except IndexError:
379
+ pass
380
+ return True
381
+ return False
382
+
383
+ # def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
384
+ # def on_train_batch_end(self, *args, **kwargs):
385
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
386
+ self.log_img(pl_module, batch, batch_idx, split="train")
387
+
388
+ def on_validation_batch_end(
389
+ self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
390
+ ):
391
+ self.log_img(pl_module, batch, batch_idx, split="val")
392
+
393
+
394
+ if __name__ == "__main__":
395
+ # custom parser to specify config files, train, test and debug mode,
396
+ # postfix, resume.
397
+ # `--key value` arguments are interpreted as arguments to the trainer.
398
+ # `nested.key=value` arguments are interpreted as config parameters.
399
+ # configs are merged from left-to-right followed by command line parameters.
400
+
401
+ # model:
402
+ # base_learning_rate: float
403
+ # target: path to lightning module
404
+ # params:
405
+ # key: value
406
+ # data:
407
+ # target: main.DataModuleFromConfig
408
+ # params:
409
+ # batch_size: int
410
+ # wrap: bool
411
+ # train:
412
+ # target: path to train dataset
413
+ # params:
414
+ # key: value
415
+ # validation:
416
+ # target: path to validation dataset
417
+ # params:
418
+ # key: value
419
+ # test:
420
+ # target: path to test dataset
421
+ # params:
422
+ # key: value
423
+ # lightning: (optional, has sane defaults and can be specified on cmdline)
424
+ # trainer:
425
+ # additional arguments to trainer
426
+ # logger:
427
+ # logger to instantiate
428
+ # modelcheckpoint:
429
+ # modelcheckpoint to instantiate
430
+ # callbacks:
431
+ # callback1:
432
+ # target: importpath
433
+ # params:
434
+ # key: value
435
+
436
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
437
+
438
+ # add cwd for convenience and to make classes in this file available when
439
+ # running as `python main.py`
440
+ # (in particular `main.DataModuleFromConfig`)
441
+ sys.path.append(os.getcwd())
442
+
443
+ parser = get_parser()
444
+ parser = Trainer.add_argparse_args(parser)
445
+
446
+ opt, unknown = parser.parse_known_args()
447
+ if opt.name and opt.resume:
448
+ raise ValueError(
449
+ "-n/--name and -r/--resume cannot be specified both."
450
+ "If you want to resume training in a new log folder, "
451
+ "use -n/--name in combination with --resume_from_checkpoint"
452
+ )
453
+ if opt.resume:
454
+ if not os.path.exists(opt.resume):
455
+ raise ValueError("Cannot find {}".format(opt.resume))
456
+ if os.path.isfile(opt.resume):
457
+ paths = opt.resume.split("/")
458
+ idx = len(paths) - paths[::-1].index("logs") + 1
459
+ logdir = "/".join(paths[:idx])
460
+ ckpt = opt.resume
461
+ else:
462
+ assert os.path.isdir(opt.resume), opt.resume
463
+ logdir = opt.resume.rstrip("/")
464
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
465
+
466
+ opt.resume_from_checkpoint = ckpt
467
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
468
+ opt.base = base_configs + opt.base
469
+ _tmp = logdir.split("/")
470
+ nowname = _tmp[_tmp.index("logs") + 1]
471
+ else:
472
+ if opt.name:
473
+ name = "_" + opt.name
474
+ elif opt.base:
475
+ cfg_fname = os.path.split(opt.base[0])[-1]
476
+ cfg_name = os.path.splitext(cfg_fname)[0]
477
+ name = "_" + cfg_name
478
+ else:
479
+ name = ""
480
+ nowname = now + name + opt.postfix
481
+ logdir = os.path.join("logs", nowname)
482
+
483
+ ckptdir = os.path.join(logdir, "checkpoints")
484
+ cfgdir = os.path.join(logdir, "configs")
485
+ seed_everything(opt.seed)
486
+
487
+ try:
488
+ # init and save configs
489
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
490
+ cli = OmegaConf.from_dotlist(unknown)
491
+ config = OmegaConf.merge(*configs, cli)
492
+ lightning_config = config.pop("lightning", OmegaConf.create())
493
+ # merge trainer cli with config
494
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
495
+ # default to ddp
496
+ trainer_config["distributed_backend"] = "ddp"
497
+ trainer_config["replace_sampler_ddp"] = False
498
+ trainer_config["strategy"] = "ddp"
499
+ trainer_config["persistent_workers"] = True
500
+ for k in nondefault_trainer_args(opt):
501
+ trainer_config[k] = getattr(opt, k)
502
+ if not "gpus" in trainer_config:
503
+ del trainer_config["distributed_backend"]
504
+ cpu = True
505
+ else:
506
+ gpuinfo = trainer_config["gpus"]
507
+ print(f"Running on GPUs {gpuinfo}")
508
+ cpu = False
509
+ trainer_opt = argparse.Namespace(**trainer_config)
510
+ lightning_config.trainer = trainer_config
511
+
512
+ # model
513
+ model = instantiate_from_config(config.model)
514
+ # trainer and callbacks
515
+ trainer_kwargs = dict()
516
+
517
+ # default logger configs
518
+ # NOTE wandb < 0.10.0 interferes with shutdown
519
+ # wandb >= 0.10.0 seems to fix it but still interferes with pudb
520
+ # debugging (wrongly sized pudb ui)
521
+ # thus prefer testtube for now
522
+ default_logger_cfgs = {
523
+ "wandb": {
524
+ "target": "pytorch_lightning.loggers.WandbLogger",
525
+ "params": {
526
+ "name": nowname,
527
+ "save_dir": logdir,
528
+ "offline": opt.debug,
529
+ "id": nowname,
530
+ },
531
+ },
532
+ "testtube": {
533
+ # "target": "pytorch_lightning.loggers.TestTubeLogger",
534
+ "target": "pytorch_lightning.loggers.TensorBoardLogger",
535
+ "params": {
536
+ "name": "testtube",
537
+ "save_dir": logdir,
538
+ },
539
+ },
540
+ }
541
+ default_logger_cfg = default_logger_cfgs["testtube"]
542
+ try:
543
+ logger_cfg = lightning_config.logger
544
+ except:
545
+ logger_cfg = OmegaConf.create()
546
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
547
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
548
+
549
+ # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
550
+ # specify which metric is used to determine best models
551
+ default_modelckpt_cfg = {
552
+ "checkpoint_callback": {
553
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
554
+ "params": {
555
+ "dirpath": ckptdir,
556
+ "filename": "{epoch:06}",
557
+ "verbose": True,
558
+ "save_last": True,
559
+ },
560
+ }
561
+ }
562
+ if hasattr(model, "monitor"):
563
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
564
+ default_modelckpt_cfg["checkpoint_callback"]["params"][
565
+ "monitor"
566
+ ] = model.monitor
567
+ default_modelckpt_cfg["checkpoint_callback"]["params"]["save_top_k"] = 3
568
+ try:
569
+ modelckpt_cfg = lightning_config.modelcheckpoint
570
+ except:
571
+ modelckpt_cfg = OmegaConf.create()
572
+
573
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
574
+ # trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
575
+
576
+ # loaded_model_callbacks = instantiate_from_config(modelckpt_cfg)
577
+
578
+ # add callback which sets up log directory
579
+ default_callbacks_cfg = {
580
+ "setup_callback": {
581
+ "target": "celle_taming_main.SetupCallback",
582
+ "params": {
583
+ "resume": opt.resume,
584
+ "now": now,
585
+ "logdir": logdir,
586
+ "ckptdir": ckptdir,
587
+ "cfgdir": cfgdir,
588
+ "config": config,
589
+ "lightning_config": lightning_config,
590
+ },
591
+ },
592
+ "image_logger": {
593
+ "target": "celle_taming_main.ImageLogger",
594
+ "params": {
595
+ "batch_frequency": 2000,
596
+ "max_images": 10,
597
+ "clamp": True,
598
+ "increase_log_steps": False,
599
+ },
600
+ },
601
+ "learning_rate_logger": {
602
+ "target": "celle_taming_main.LearningRateMonitor",
603
+ "params": {
604
+ "logging_interval": "step",
605
+ # "log_momentum": True
606
+ },
607
+ },
608
+ }
609
+ try:
610
+ callbacks_cfg = lightning_config.callbacks
611
+ except:
612
+ callbacks_cfg = OmegaConf.create()
613
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
614
+ callbacks_cfg = OmegaConf.merge(modelckpt_cfg, callbacks_cfg)
615
+ trainer_kwargs["callbacks"] = [
616
+ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
617
+ ]
618
+ # loaded_callbacks = [
619
+ # instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
620
+ # ]
621
+
622
+ # trainer_kwargs["callbacks"] = loaded_callbacks.append(loaded_model_callbacks)
623
+
624
+ trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
625
+
626
+ # data
627
+ data = instantiate_from_config(config.data)
628
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
629
+ # calling these ourselves should not be necessary but it is.
630
+ # lightning still takes care of proper multiprocessing though
631
+ data.prepare_data()
632
+ data.setup()
633
+
634
+ # configure learning rate
635
+ bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
636
+ if not cpu:
637
+ ngpu = len(lightning_config.trainer.gpus.strip(",").split(","))
638
+ else:
639
+ ngpu = 1
640
+ try:
641
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
642
+ except:
643
+ accumulate_grad_batches = 1
644
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
645
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
646
+ model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
647
+ print(
648
+ "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
649
+ model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
650
+ )
651
+ )
652
+
653
+ # allow checkpointing via USR1
654
+ def melk(*args, **kwargs):
655
+ # run all checkpoint hooks
656
+ if trainer.global_rank == 0:
657
+ print("Summoning checkpoint.")
658
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
659
+ trainer.save_checkpoint(ckpt_path)
660
+
661
+ def divein(*args, **kwargs):
662
+ if trainer.global_rank == 0:
663
+ import pudb
664
+
665
+ pudb.set_trace()
666
+
667
+ import signal
668
+
669
+ signal.signal(signal.SIGUSR1, melk)
670
+ signal.signal(signal.SIGUSR2, divein)
671
+ # model = torch.compile(model)
672
+ # run
673
+ if opt.train:
674
+ try:
675
+ trainer.fit(model, data)  # torch.compile() applied to the (None) return of fit() is a no-op
676
+ except Exception:
677
+ melk()
678
+ raise
679
+ if not opt.no_test and not trainer.interrupted:
680
+ trainer.test(model, data)
681
+ except Exception:
682
+ if opt.debug and trainer.global_rank == 0:
683
+ try:
684
+ import pudb as debugger
685
+ except ImportError:
686
+ import pdb as debugger
687
+ debugger.post_mortem()
688
+ raise
689
+ finally:
690
+ # move newly created debug project to debug_runs
691
+ if opt.debug and not opt.resume and trainer.global_rank == 0:
692
+ dst, name = os.path.split(logdir)
693
+ dst = os.path.join(dst, "debug_runs", name)
694
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
695
+ os.rename(logdir, dst)
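For reference, a self-contained illustration of the `target`/`params` convention that `instantiate_from_config` resolves throughout the configs; the class path is arbitrary and only shows the mechanics.

```python
from omegaconf import OmegaConf
from celle_taming_main import instantiate_from_config

cfg = OmegaConf.create(
    {
        "target": "torch.nn.Linear",  # dotted import path to any class
        "params": {"in_features": 8, "out_features": 4},
    }
)

layer = instantiate_from_config(cfg)
print(layer)  # Linear(in_features=8, out_features=4, bias=True)
```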
dataloader.py ADDED
@@ -0,0 +1,308 @@
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image, ImageSequence
4
+ import json
5
+ import pandas as pd
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ from torchvision import transforms
10
+ import torchvision.transforms.functional as TF
11
+
12
+ from celle.utils import replace_outliers
13
+
14
+ def simple_conversion(seq):
15
+ """Create 26-dim embedding"""
16
+ chars = [
17
+ "-",
18
+ "M",
19
+ "R",
20
+ "H",
21
+ "K",
22
+ "D",
23
+ "E",
24
+ "S",
25
+ "T",
26
+ "N",
27
+ "Q",
28
+ "C",
29
+ "U",
30
+ "G",
31
+ "P",
32
+ "A",
33
+ "V",
34
+ "I",
35
+ "F",
36
+ "Y",
37
+ "W",
38
+ "L",
39
+ "O",
40
+ "X",
41
+ "Z",
42
+ "B",
43
+ "J",
44
+ ]
45
+
46
+ nums = range(len(chars))
47
+
48
+ seqs_x = np.zeros(len(seq))
49
+
50
+ for idx, char in enumerate(seq):
51
+
52
+ lui = chars.index(char)
53
+
54
+ seqs_x[idx] = nums[lui]
55
+
56
+ return torch.tensor([seqs_x]).long()
57
+
58
+
59
+ class CellLoader(Dataset):
60
+ """imports mined opencell images with protein sequence"""
61
+
62
+ def __init__(
63
+ self,
64
+ data_csv=None,
65
+ dataset=None,
66
+ split_key=None,
67
+ resize=600,
68
+ crop_size=600,
69
+ crop_method="random",
70
+ sequence_mode="simple",
71
+ vocab="bert",
72
+ threshold="median",
73
+ text_seq_len=0,
74
+ pad_mode="random",
75
+ ):
76
+ self.data_csv = data_csv
77
+ self.dataset = dataset
78
+ self.image_folders = []
79
+ self.crop_method = crop_method
80
+ self.resize = resize
81
+ self.crop_size = crop_size
82
+ self.sequence_mode = sequence_mode
83
+ self.threshold = threshold
84
+ self.text_seq_len = int(text_seq_len)
85
+ self.vocab = vocab
86
+ self.pad_mode = pad_mode
87
+
88
+ if self.sequence_mode == "embedding" or self.sequence_mode == "onehot":
89
+
90
+
91
+ if self.vocab == "esm1b" or self.vocab == "esm2":
92
+ from esm import Alphabet
93
+
94
+ self.tokenizer = Alphabet.from_architecture(
95
+ "ESM-1b"
96
+ ).get_batch_converter()
97
+ self.text_seq_len += 2
98
+
99
+ if data_csv:
100
+
101
+ data = pd.read_csv(data_csv)
102
+
103
+ self.parent_path = os.path.dirname(data_csv)
104
+
105
+ if split_key == "train":
106
+ self.data = data[data["split"] == "train"]
107
+ elif split_key == "val":
108
+ self.data = data[data["split"] == "val"]
109
+ else:
110
+ self.data = data
111
+
112
+ self.data = self.data.reset_index(drop=True)
113
+
114
+
115
+
116
+ def __len__(self):
117
+ return len(self.data)
118
+
119
+ def __getitem__(
120
+ self,
121
+ idx,
122
+ get_sequence=True,
123
+ get_images=True,
124
+ ):
125
+ if get_sequence and self.text_seq_len > 0:
126
+
127
+ protein_vector = self.get_protein_vector(idx)
128
+
129
+ else:
130
+ protein_vector = torch.zeros((1, 1))
131
+
132
+ if get_images:
133
+
134
+ nucleus, target, threshold = self.get_images(idx, self.dataset)
135
+ else:
136
+ nucleus, target, threshold = torch.zeros((3, 1))
137
+
138
+ data_dict = {
139
+ "nucleus": nucleus.float(),
140
+ "target": target.float(),
141
+ "threshold": threshold.float(),
142
+ "sequence": protein_vector.long(),
143
+ }
144
+
145
+ return data_dict
146
+
147
+ def get_protein_vector(self, idx):
148
+
149
+ if "protein_sequence" not in self.data.columns:
150
+
151
+ metadata = self.retrieve_metadata(idx)
152
+ protein_sequence = metadata["sequence"]
153
+ else:
154
+ protein_sequence = self.data.iloc[idx]["protein_sequence"]
155
+
156
+ protein_vector = self.tokenize_sequence(protein_sequence)
157
+
158
+ return protein_vector
159
+
160
+ def get_images(self, idx, dataset):
161
+
162
+ if dataset == "HPA":
163
+
164
+ nucleus = Image.open(
165
+ os.path.join(
166
+ self.parent_path, self.data.iloc[idx]["nucleus_image_path"]
167
+ )
168
+ )
169
+
170
+ target = Image.open(
171
+ os.path.join(self.parent_path, self.data.iloc[idx]["target_image_path"])
172
+ )
173
+
174
+ nucleus = TF.to_tensor(nucleus)[0]
175
+ target = TF.to_tensor(target)[0]
176
+
177
+ image = torch.stack([nucleus, target], axis=0)
178
+
179
+ normalize = (0.0655, 0.0650), (0.1732, 0.1208)
180
+
181
+ elif dataset == "OpenCell":
182
+ image = Image.open(
183
+ os.path.join(self.parent_path, self.data.iloc[idx]["image_path"])
184
+ )
185
+ nucleus, target = [page.copy() for page in ImageSequence.Iterator(image)]
186
+
187
+ nucleus = replace_outliers(torch.divide(TF.to_tensor(nucleus), 65536))[0]
188
+ target = replace_outliers(torch.divide(TF.to_tensor(target), 65536))[0]
189
+
190
+ image = torch.stack([nucleus, target], axis=0)
191
+
192
+ normalize = (
193
+ (0.0272, 0.0244),
194
+ (0.0486, 0.0671),
195
+ )
196
+
197
+ # # from https://discuss.pytorch.org/t/how-to-apply-same-transform-on-a-pair-of-picture/14914
198
+
199
+ t_forms = [transforms.Resize(self.resize, antialias=None)]
200
+
201
+ if self.crop_method == "random":
202
+
203
+ t_forms.append(transforms.RandomCrop(self.crop_size))
204
+ t_forms.append(transforms.RandomHorizontalFlip(p=0.5))
205
+ t_forms.append(transforms.RandomVerticalFlip(p=0.5))
206
+
207
+ elif self.crop_method == "center":
208
+
209
+ t_forms.append(transforms.CenterCrop(self.crop_size))
210
+
211
+ t_forms.append(transforms.Normalize(normalize[0], normalize[1]))
212
+
213
+ image = transforms.Compose(t_forms)(image)
214
+
215
+ nucleus, target = image
216
+
217
+ nucleus /= torch.abs(nucleus).max()
218
+ target -= target.min()
219
+ target /= target.max()
220
+
221
+ nucleus = nucleus.unsqueeze(0)
222
+ target = target.unsqueeze(0)
223
+
224
+ threshold = target
225
+
226
+ if self.threshold == "mean":
227
+
228
+ threshold = 1.0 * (threshold > (torch.mean(threshold)))
229
+
230
+ elif self.threshold == "median":
231
+
232
+ threshold = 1.0 * (threshold > (torch.median(threshold)))
233
+
234
+ elif self.threshold == "1090_IQR":
235
+
236
+ p10 = torch.quantile(threshold, 0.1, None)
237
+ p90 = torch.quantile(threshold, 0.9, None)
238
+ threshold = torch.clip(threshold, p10, p90)
239
+
240
+ nucleus = torch.nan_to_num(nucleus, 0.0, 1.0, 0.0)
241
+ target = torch.nan_to_num(target, 0.0, 1.0, 0.0)
242
+ threshold = torch.nan_to_num(threshold, 0.0, 1.0, 0.0)
243
+
244
+ return nucleus, target, threshold
245
+
246
+ def retrieve_metadata(self, idx):
247
+ with open(
248
+ os.path.join(self.parent_path, self.data.iloc[idx]["metadata_path"])
249
+ ) as f:
250
+ metadata = json.load(f)
251
+
252
+ return metadata
253
+
254
+ def tokenize_sequence(self, protein_sequence):
255
+
256
+ pad_token = 0
257
+
258
+ if self.sequence_mode == "simple":
259
+ protein_vector = simple_conversion(protein_sequence)
260
+
261
+ elif self.sequence_mode == "center":
262
+ protein_sequence = protein_sequence.center(self.text_seq_len, "-")
263
+ protein_vector = simple_conversion(protein_sequence)
264
+
265
+ elif self.sequence_mode == "alternating":
266
+ protein_sequence = protein_sequence.center(self.text_seq_len, "-")
267
+ protein_sequence = protein_sequence[::18]
268
+ protein_sequence = protein_sequence.center(
269
+ int(self.text_seq_len / 18) + 1, "-"
270
+ )
271
+ protein_vector = simple_conversion(protein_sequence)
272
+
273
+
274
+ elif self.sequence_mode == "embedding":
275
+
276
+ if self.vocab == "esm1b" or self.vocab == "esm2":
277
+ pad_token = 1
278
+ protein_vector = self.tokenizer([("", protein_sequence)])[-1]
279
+
280
+ if protein_vector.shape[-1] < self.text_seq_len:
281
+
282
+ diff = self.text_seq_len - protein_vector.shape[-1]
283
+
284
+ if self.pad_mode == "end":
285
+ protein_vector = torch.nn.functional.pad(
286
+ protein_vector, (0, diff), "constant", pad_token
287
+ )
288
+ elif self.pad_mode == "random":
289
+ split = diff - np.random.randint(0, diff + 1)
290
+
291
+ protein_vector = torch.cat(
292
+ [torch.ones(1, split) * 0, protein_vector], dim=1
293
+ )
294
+
295
+ protein_vector = torch.nn.functional.pad(
296
+ protein_vector, (0, diff - split), "constant", pad_token
297
+ )
298
+
299
+ elif protein_vector.shape[-1] > self.text_seq_len:
300
+ start_int = np.random.randint(
301
+ 0, protein_vector.shape[-1] - self.text_seq_len
302
+ )
303
+
304
+ protein_vector = protein_vector[
305
+ :, start_int : start_int + self.text_seq_len
306
+ ]
307
+
308
+ return protein_vector.long()
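
A hedged sketch of how the CellLoader above might be instantiated; the CSV path and its columns (split, nucleus_image_path, target_image_path, protein_sequence, metadata_path) are assumptions about the expected data layout, not files included in this commit:

```python
# Hypothetical usage of CellLoader; "data.csv" and its columns are assumed.
from dataloader import CellLoader

loader = CellLoader(
    data_csv="data.csv",          # assumed CSV with a "split" column
    dataset="HPA",
    split_key="train",
    resize=600,
    crop_size=256,
    crop_method="random",
    sequence_mode="embedding",
    vocab="esm2",
    text_seq_len=1000,
    threshold="median",
)

sample = loader[0]
print(sample["nucleus"].shape, sample["target"].shape, sample["sequence"].shape)
```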
images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg ADDED
images/Armadillo repeat-containing X-linked protein 5 protein.jpg ADDED
prediction.py ADDED
@@ -0,0 +1,82 @@
1
+ import os
2
+ os.chdir('..')
3
+ from dataloader import CellLoader
4
+ from celle_main import instantiate_from_config
5
+ from omegaconf import OmegaConf
6
+
7
+ def run_sequence_prediction(
8
+ sequence_input,
9
+ nucleus_image,
10
+ protein_image,
11
+ model_ckpt_path,
12
+ model_config_path,
13
+ device
14
+ ):
15
+ """
16
+ Run Celle model with provided inputs and display results.
17
+
18
+ :param sequence: Path to sequence file
19
+ :param nucleus_image_path: Path to nucleus image
20
+ :param protein_image_path: Path to protein image (optional)
21
+ :param model_ckpt_path: Path to model checkpoint
22
+ :param model_config_path: Path to model config
23
+ """
24
+
25
+ # Instantiate dataset object
26
+ dataset = CellLoader(
27
+ sequence_mode="embedding",
28
+ vocab="esm2",
29
+ split_key="val",
30
+ crop_method="center",
31
+ resize=600,
32
+ crop_size=256,
33
+ text_seq_len=1000,
34
+ pad_mode="end",
35
+ threshold="median",
36
+ )
37
+
38
+ # Check if sequence is provided and valid
39
+ if len(sequence_input) == 0:
40
+ raise ValueError("Sequence must be provided.")
41
+
42
+ if "<mask>" not in sequence_input:
43
+ print("Warning: Sequence does not contain any masked positions to predict.")
44
+
45
+ # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
46
+ sequence = dataset.tokenize_sequence(sequence_input)
47
+
48
+ # Load model config and set ckpt_path if not provided in config
49
+ config = OmegaConf.load(model_config_path)
50
+ if config["model"]["params"]["ckpt_path"] is None:
51
+ config["model"]["params"]["ckpt_path"] = model_ckpt_path
52
+
53
+ # Set condition_model_path and vqgan_model_path to None
54
+ config["model"]["params"]["condition_model_path"] = None
55
+ config["model"]["params"]["vqgan_model_path"] = None
56
+
57
+ # Instantiate model from config and move to device
58
+ model = instantiate_from_config(config).to(device)
59
+
60
+ # Sample from model using provided sequence and nucleus image
61
+ _, predicted_sequence, _ = model.celle.sample_text(
62
+ text=sequence,
63
+ condition=nucleus_image,
64
+ image=protein_image,
65
+ force_aas=True,
66
+ timesteps=1,
67
+ temperature=1,
68
+ progress=True,
69
+ )
70
+
71
+ formatted_predicted_sequence = ""
72
+
73
+ for i in range(min(len(predicted_sequence), len(sequence))):
74
+ if predicted_sequence[i] != sequence[i]:
75
+ formatted_predicted_sequence += f"**{predicted_sequence[i]}**"
76
+ else:
77
+ formatted_predicted_sequence += predicted_sequence[i]
78
+
79
+ if len(predicted_sequence) > len(sequence):
80
+ formatted_predicted_sequence += f"**{predicted_sequence[len(sequence):]}**"
81
+
82
+ return formatted_predicted_sequence
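
For reference, a hedged sketch of calling run_sequence_prediction directly; the tensors, the masked sequence, and the model paths below are placeholders:

```python
# Hypothetical call to run_sequence_prediction; all inputs are placeholders.
import torch
from prediction import run_sequence_prediction

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nucleus_image = torch.zeros(1, 1, 256, 256)   # stand-in for a processed nucleus image
protein_image = torch.zeros(1, 1, 256, 256)   # stand-in for a binarized protein image

formatted = run_sequence_prediction(
    sequence_input="M<mask><mask>KGEELFTGVVPILVELDGDVNGHK",  # truncated example
    nucleus_image=nucleus_image,
    protein_image=protein_image,
    model_ckpt_path="model.ckpt",     # placeholder path
    model_config_path="config.yaml",  # placeholder path
    device=device,
)
print(formatted)  # residues that differ from the input are wrapped in ** **
```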
taming/lr_scheduler.py ADDED
@@ -0,0 +1,34 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n):
33
+ return self.schedule(n)
34
+
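
The scheduler's docstring notes it should be driven with a base_lr of 1.0, since schedule(n) returns an absolute learning-rate value rather than a fraction. A small sketch of wiring it into PyTorch's LambdaLR; the model, step counts, and bounds are toy values:

```python
# Sketch: using LambdaWarmUpCosineScheduler as the lr_lambda of a LambdaLR.
# With the optimizer lr set to 1.0, the scheduled value is used directly.
import torch
from taming.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(4, 4)                      # toy model
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1_000, lr_min=1e-6, lr_max=1e-4, lr_start=1e-6, max_decay_steps=100_000
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(3):                              # toy training loop
    optimizer.step()
    scheduler.step()
```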
taming/models/cond_transformer.py ADDED
@@ -0,0 +1,349 @@
1
+ import os, math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import pytorch_lightning as pl
5
+
6
+ from main import instantiate_from_config
7
+ from taming.modules.util import SOSProvider
8
+
9
+
10
+ def disabled_train(self, mode=True):
11
+ """Overwrite model.train with this function to make sure train/eval mode
12
+ does not change anymore."""
13
+ return self
14
+
15
+
16
+ class Net2NetTransformer(pl.LightningModule):
17
+ def __init__(self,
18
+ transformer_config,
19
+ first_stage_config,
20
+ cond_stage_config,
21
+ permuter_config=None,
22
+ ckpt_path=None,
23
+ ignore_keys=[],
24
+ first_stage_key="image",
25
+ cond_stage_key="depth",
26
+ downsample_cond_size=-1,
27
+ pkeep=1.0,
28
+ sos_token=0,
29
+ unconditional=False,
30
+ ):
31
+ super().__init__()
32
+ self.be_unconditional = unconditional
33
+ self.sos_token = sos_token
34
+ self.first_stage_key = first_stage_key
35
+ self.cond_stage_key = cond_stage_key
36
+ self.init_first_stage_from_ckpt(first_stage_config)
37
+ self.init_cond_stage_from_ckpt(cond_stage_config)
38
+ if permuter_config is None:
39
+ permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
40
+ self.permuter = instantiate_from_config(config=permuter_config)
41
+ self.transformer = instantiate_from_config(config=transformer_config)
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+ self.downsample_cond_size = downsample_cond_size
46
+ self.pkeep = pkeep
47
+
48
+ def init_from_ckpt(self, path, ignore_keys=list()):
49
+ sd = torch.load(path, map_location="cpu")["state_dict"]
50
+ for k in sd.keys():
51
+ for ik in ignore_keys:
52
+ if k.startswith(ik):
53
+ self.print("Deleting key {} from state_dict.".format(k))
54
+ del sd[k]
55
+ self.load_state_dict(sd, strict=False)
56
+ print(f"Restored from {path}")
57
+
58
+ def init_first_stage_from_ckpt(self, config):
59
+ model = instantiate_from_config(config)
60
+ model = model.eval()
61
+ model.train = disabled_train
62
+ self.first_stage_model = model
63
+
64
+ def init_cond_stage_from_ckpt(self, config):
65
+ if config == "__is_first_stage__":
66
+ print("Using first stage also as cond stage.")
67
+ self.cond_stage_model = self.first_stage_model
68
+ elif config == "__is_unconditional__" or self.be_unconditional:
69
+ print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
70
+ f"Prepending {self.sos_token} as a sos token.")
71
+ self.be_unconditional = True
72
+ self.cond_stage_key = self.first_stage_key
73
+ self.cond_stage_model = SOSProvider(self.sos_token)
74
+ else:
75
+ model = instantiate_from_config(config)
76
+ model = model.eval()
77
+ model.train = disabled_train
78
+ self.cond_stage_model = model
79
+
80
+ def forward(self, x, c):
81
+ # one step to produce the logits
82
+ # x = target
83
+ # c = nucleus
84
+ _, z_indices = self.encode_to_z(x)
85
+ _, c_indices = self.encode_to_c(c)
86
+
87
+ if self.training and self.pkeep < 1.0:
88
+ mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
89
+ device=z_indices.device))
90
+ mask = mask.round().to(dtype=torch.int64)
91
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
92
+ a_indices = mask*z_indices+(1-mask)*r_indices
93
+ else:
94
+ a_indices = z_indices
95
+
96
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
97
+
98
+ # target includes all sequence elements (no need to handle first one
99
+ # differently because we are conditioning)
100
+ target = z_indices
101
+ # make the prediction
102
+ logits, _ = self.transformer(cz_indices[:, :-1])
103
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
104
+ logits = logits[:, c_indices.shape[1]-1:]
105
+
106
+ return logits, target
107
+
108
+ def top_k_logits(self, logits, k):
109
+ v, ix = torch.topk(logits, k)
110
+ out = logits.clone()
111
+ out[out < v[..., [-1]]] = -float('Inf')
112
+ return out
113
+
114
+ @torch.no_grad()
115
+ def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
116
+ callback=lambda k: None):
117
+ x = torch.cat((c,x),dim=1)
118
+ block_size = self.transformer.get_block_size()
119
+ assert not self.transformer.training
120
+ if self.pkeep <= 0.0:
121
+ # one pass suffices since input is pure noise anyway
122
+ assert len(x.shape)==2
123
+ noise_shape = (x.shape[0], steps-1)
124
+ #noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
125
+ noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
126
+ x = torch.cat((x,noise),dim=1)
127
+ logits, _ = self.transformer(x)
128
+ # take all logits for now and scale by temp
129
+ logits = logits / temperature
130
+ # optionally crop probabilities to only the top k options
131
+ if top_k is not None:
132
+ logits = self.top_k_logits(logits, top_k)
133
+ # apply softmax to convert to probabilities
134
+ probs = F.softmax(logits, dim=-1)
135
+ # sample from the distribution or take the most likely
136
+ if sample:
137
+ shape = probs.shape
138
+ probs = probs.reshape(shape[0]*shape[1],shape[2])
139
+ ix = torch.multinomial(probs, num_samples=1)
140
+ probs = probs.reshape(shape[0],shape[1],shape[2])
141
+ ix = ix.reshape(shape[0],shape[1])
142
+ else:
143
+ _, ix = torch.topk(probs, k=1, dim=-1)
144
+ # cut off conditioning
145
+ x = ix[:, c.shape[1]-1:]
146
+ else:
147
+ for k in range(steps):
148
+ callback(k)
149
+ assert x.size(1) <= block_size # make sure model can see conditioning
150
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
151
+ logits, _ = self.transformer(x_cond)
152
+ # pluck the logits at the final step and scale by temperature
153
+ logits = logits[:, -1, :] / temperature
154
+ # optionally crop probabilities to only the top k options
155
+ if top_k is not None:
156
+ logits = self.top_k_logits(logits, top_k)
157
+ # apply softmax to convert to probabilities
158
+ probs = F.softmax(logits, dim=-1)
159
+ # sample from the distribution or take the most likely
160
+ if sample:
161
+ ix = torch.multinomial(probs, num_samples=1)
162
+ else:
163
+ _, ix = torch.topk(probs, k=1, dim=-1)
164
+ # append to the sequence and continue
165
+ x = torch.cat((x, ix), dim=1)
166
+ # cut off conditioning
167
+ x = x[:, c.shape[1]:]
168
+ return x
169
+
170
+ @torch.no_grad()
171
+ def encode_to_z(self, x):
172
+ quant_z, _, info = self.first_stage_model.encode(x)
173
+ indices = info[2].view(quant_z.shape[0], -1)
174
+ indices = self.permuter(indices)
175
+ return quant_z, indices
176
+
177
+ @torch.no_grad()
178
+ def encode_to_c(self, c):
179
+ if self.downsample_cond_size > -1:
180
+ c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
181
+
182
+ #quant_c, _, info = self.cond_stage_model.encode(x)
183
+ #indices = info[2].view(quant_c.shape[0], -1)
184
+ #indices = self.permuter(indices)
185
+ quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
186
+ if len(indices.shape) != 2:
187
+ indices = indices.view(c.shape[0], -1)
188
+ return quant_c, indices
189
+
190
+ @torch.no_grad()
191
+ def decode_to_img(self, index, zshape):
192
+ index = self.permuter(index, reverse=True)
193
+ bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
194
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(
195
+ index.reshape(-1), shape=bhwc)
196
+ x = self.first_stage_model.decode(quant_z)
197
+ return x
198
+
199
+ @torch.no_grad()
200
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
201
+ log = dict()
202
+
203
+ N = 4
204
+ if lr_interface:
205
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
206
+ else:
207
+ x, c = self.get_xc(batch, N)
208
+ x = x.to(device=self.device)
209
+ c = c.to(device=self.device)
210
+
211
+ quant_z, z_indices = self.encode_to_z(x)
212
+ quant_c, c_indices = self.encode_to_c(c)
213
+
214
+ # create a "half"" sample
215
+ z_start_indices = z_indices[:,:z_indices.shape[1]//2]
216
+ index_sample = self.sample(z_start_indices, c_indices,
217
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
218
+ temperature=temperature if temperature is not None else 1.0,
219
+ sample=True,
220
+ top_k=top_k if top_k is not None else 100,
221
+ callback=callback if callback is not None else lambda k: None)
222
+ x_sample = self.decode_to_img(index_sample, quant_z.shape)
223
+
224
+ # sample
225
+ z_start_indices = z_indices[:, :0]
226
+ index_sample = self.sample(z_start_indices, c_indices,
227
+ steps=z_indices.shape[1],
228
+ temperature=temperature if temperature is not None else 1.0,
229
+ sample=True,
230
+ top_k=top_k if top_k is not None else 100,
231
+ callback=callback if callback is not None else lambda k: None)
232
+ x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
233
+
234
+ # det sample
235
+ z_start_indices = z_indices[:, :0]
236
+ index_sample = self.sample(z_start_indices, c_indices,
237
+ steps=z_indices.shape[1],
238
+ sample=False,
239
+ callback=callback if callback is not None else lambda k: None)
240
+ x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
241
+
242
+ # reconstruction
243
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
244
+
245
+ log["inputs"] = x
246
+ log["reconstructions"] = x_rec
247
+
248
+ if self.cond_stage_key != "image" or self.cond_stage_key != "nucleus" or self.cond_stage_key != "target":
249
+ cond_rec = self.cond_stage_model.decode(quant_c)
250
+ if self.cond_stage_key == "segmentation":
251
+ # get image from segmentation mask
252
+ num_classes = cond_rec.shape[1]
253
+
254
+ c = torch.argmax(c, dim=1, keepdim=True)
255
+ c = F.one_hot(c, num_classes=num_classes)
256
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
257
+ c = self.cond_stage_model.to_rgb(c)
258
+
259
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
260
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
261
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
262
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
263
+ log["conditioning_rec"] = cond_rec
264
+ log["conditioning"] = c
265
+
266
+ log["samples_half"] = x_sample
267
+ log["samples_nopix"] = x_sample_nopix
268
+ log["samples_det"] = x_sample_det
269
+ return log
270
+
271
+ def get_input(self, key, batch):
272
+ x = batch[key]
273
+ if len(x.shape) == 3:
274
+ x = x[..., None]
275
+ #if len(x.shape) == 4:
276
+ # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
277
+ if x.dtype == torch.double:
278
+ x = x.float()
279
+ return x
280
+
281
+ def get_xc(self, batch, N=None):
282
+ x = self.get_input(self.first_stage_key, batch)
283
+ c = self.get_input(self.cond_stage_key, batch)
284
+ if N is not None:
285
+ x = x[:N]
286
+ c = c[:N]
287
+ return x, c
288
+
289
+ def shared_step(self, batch):
290
+ x, c = self.get_xc(batch)
291
+ logits, target = self(x, c)
292
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
293
+ return loss
294
+
295
+ def training_step(self, batch, batch_idx):
296
+ loss = self.shared_step(batch)
297
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
298
+ return loss
299
+
300
+ def validation_step(self, batch, batch_idx):
301
+ loss = self.shared_step(batch)
302
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
303
+ return loss
304
+
305
+ def configure_optimizers(self):
306
+ """
307
+ Following minGPT:
308
+ This long function is unfortunately doing something very simple and is being very defensive:
309
+ We are separating out all parameters of the model into two buckets: those that will experience
310
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
311
+ We are then returning the PyTorch optimizer object.
312
+ """
313
+ # separate out all parameters to those that will and won't experience regularizing weight decay
314
+ decay = set()
315
+ no_decay = set()
316
+ whitelist_weight_modules = (torch.nn.Linear, )
317
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
318
+ for mn, m in self.transformer.named_modules():
319
+ for pn, p in m.named_parameters():
320
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
321
+
322
+ if pn.endswith('bias'):
323
+ # all biases will not be decayed
324
+ no_decay.add(fpn)
325
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
326
+ # weights of whitelist modules will be weight decayed
327
+ decay.add(fpn)
328
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
329
+ # weights of blacklist modules will NOT be weight decayed
330
+ no_decay.add(fpn)
331
+
332
+ # special case the position embedding parameter in the root GPT module as not decayed
333
+ no_decay.add('pos_emb')
334
+
335
+ # validate that we considered every parameter
336
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
337
+ inter_params = decay & no_decay
338
+ union_params = decay | no_decay
339
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
340
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
341
+ % (str(param_dict.keys() - union_params), )
342
+
343
+ # create the pytorch optimizer object
344
+ optim_groups = [
345
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
346
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
347
+ ]
348
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
349
+ return optimizer
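
The configure_optimizers docstring above describes the minGPT-style split into parameters that do and do not receive weight decay. A stripped-down sketch of the same bucketing on a toy module (not the exact code path used by Net2NetTransformer):

```python
# Minimal sketch of the decay / no-decay split described above, on a toy module.
import torch

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
decay, no_decay = [], []
for module in toy.modules():
    for name, param in module.named_parameters(recurse=False):
        # biases, LayerNorm weights, and Embedding weights are excluded from decay
        if name.endswith("bias") or isinstance(module, (torch.nn.LayerNorm, torch.nn.Embedding)):
            no_decay.append(param)
        else:
            decay.append(param)

optimizer = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.01},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=4.5e-6, betas=(0.9, 0.95),
)
```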
taming/models/dummy_cond_stage.py ADDED
@@ -0,0 +1,22 @@
1
+ from torch import Tensor
2
+
3
+
4
+ class DummyCondStage:
5
+ def __init__(self, conditional_key):
6
+ self.conditional_key = conditional_key
7
+ self.train = None
8
+
9
+ def eval(self):
10
+ return self
11
+
12
+ @staticmethod
13
+ def encode(c: Tensor):
14
+ return c, None, (None, None, c)
15
+
16
+ @staticmethod
17
+ def decode(c: Tensor):
18
+ return c
19
+
20
+ @staticmethod
21
+ def to_rgb(c: Tensor):
22
+ return c
taming/models/vqgan.py ADDED
@@ -0,0 +1,649 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import pytorch_lightning as pl
4
+
5
+ from celle_taming_main import instantiate_from_config
6
+
7
+ from taming.modules.diffusionmodules.model import Encoder, Decoder
8
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9
+ from taming.modules.vqvae.quantize import GumbelQuantize
10
+ from taming.modules.vqvae.quantize import EMAVectorQuantizer
11
+
12
+
13
+ class VQModel(pl.LightningModule):
14
+ def __init__(
15
+ self,
16
+ ddconfig,
17
+ lossconfig,
18
+ n_embed,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ remap=None,
26
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
27
+ ):
28
+ super().__init__()
29
+ self.image_key = image_key
30
+ self.encoder = Encoder(**ddconfig)
31
+ self.decoder = Decoder(**ddconfig)
32
+ self.loss = instantiate_from_config(lossconfig)
33
+ self.quantize = VectorQuantizer(
34
+ n_embed,
35
+ embed_dim,
36
+ beta=0.25,
37
+ remap=remap,
38
+ sane_index_shape=sane_index_shape,
39
+ )
40
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
41
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
42
+ if ckpt_path is not None:
43
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
44
+ self.image_key = image_key
45
+ if colorize_nlabels is not None:
46
+ assert type(colorize_nlabels) == int
47
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
48
+ if monitor is not None:
49
+ self.monitor = monitor
50
+
51
+ def init_from_ckpt(self, path, ignore_keys=list()):
52
+ sd = torch.load(path, map_location="cpu")["state_dict"]
53
+ keys = list(sd.keys())
54
+ for k in keys:
55
+ for ik in ignore_keys:
56
+ if k.startswith(ik):
57
+ print("Deleting key {} from state_dict.".format(k))
58
+ del sd[k]
59
+ self.load_state_dict(sd, strict=False)
60
+ print(f"Restored from {path}")
61
+
62
+ def encode(self, x):
63
+ h = self.encoder(x)
64
+ h = self.quant_conv(h)
65
+ quant, emb_loss, info = self.quantize(h)
66
+ return quant, emb_loss, info
67
+
68
+ def decode(self, quant):
69
+ quant = self.post_quant_conv(quant)
70
+ dec = self.decoder(quant)
71
+ return dec
72
+
73
+ def decode_code(self, code_b):
74
+ quant_b = self.quantize.embed_code(code_b)
75
+ dec = self.decode(quant_b)
76
+ return dec
77
+
78
+ def forward(self, input):
79
+ quant, diff, _ = self.encode(input)
80
+ dec = self.decode(quant)
81
+ return dec, diff
82
+
83
+ def get_input(self, batch, k):
84
+
85
+ if k == "mixed":
86
+ keys = ["nucleus", "target"]
87
+ index = torch.randint(low=0, high=2, size=(1,), dtype=int).item()
88
+ k = keys[index]
89
+
90
+ x = batch[k]
91
+ if len(x.shape) == 3:
92
+ x = x[..., None]
93
+
94
+ # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
95
+ return x
96
+
97
+ def training_step(self, batch, batch_idx=None, optimizer_idx=0):
98
+
99
+ if type(batch) == dict:
100
+
101
+ x = self.get_input(batch, self.image_key)
102
+
103
+ else:
104
+ x = batch
105
+
106
+ xrec, qloss = self(
107
+ x,
108
+ )
109
+
110
+ if optimizer_idx == 0:
111
+ # autoencode
112
+ aeloss, log_dict_ae = self.loss(
113
+ qloss,
114
+ x,
115
+ xrec,
116
+ optimizer_idx,
117
+ self.global_step,
118
+ last_layer=self.get_last_layer(),
119
+ split="train",
120
+ )
121
+
122
+ self.log(
123
+ "train/aeloss",
124
+ aeloss,
125
+ prog_bar=True,
126
+ logger=True,
127
+ on_step=True,
128
+ on_epoch=True,
129
+ sync_dist=True,
130
+ )
131
+ self.log_dict(
132
+ log_dict_ae,
133
+ prog_bar=False,
134
+ logger=True,
135
+ on_step=True,
136
+ on_epoch=True,
137
+ sync_dist=True,
138
+ )
139
+ return aeloss
140
+
141
+ if optimizer_idx == 1:
142
+ # discriminator
143
+ discloss, log_dict_disc = self.loss(
144
+ qloss,
145
+ x,
146
+ xrec,
147
+ optimizer_idx,
148
+ self.global_step,
149
+ last_layer=self.get_last_layer(),
150
+ split="train",
151
+ )
152
+ self.log(
153
+ "train/discloss",
154
+ discloss,
155
+ prog_bar=True,
156
+ logger=True,
157
+ on_step=True,
158
+ on_epoch=True,
159
+ sync_dist=True,
160
+ )
161
+ self.log_dict(
162
+ log_dict_disc,
163
+ prog_bar=False,
164
+ logger=True,
165
+ on_step=True,
166
+ on_epoch=True,
167
+ sync_dist=True,
168
+ )
169
+ return discloss
170
+
171
+ def validation_step(self, batch, batch_idx):
172
+
173
+ if type(batch) == dict:
174
+
175
+ x = self.get_input(batch, self.image_key)
176
+
177
+ else:
178
+ x = batch
179
+
180
+ xrec, qloss = self(x)
181
+ aeloss, log_dict_ae = self.loss(
182
+ qloss,
183
+ x,
184
+ xrec,
185
+ 0,
186
+ self.global_step,
187
+ last_layer=self.get_last_layer(),
188
+ split="val",
189
+ )
190
+
191
+ discloss, log_dict_disc = self.loss(
192
+ qloss,
193
+ x,
194
+ xrec,
195
+ 1,
196
+ self.global_step,
197
+ last_layer=self.get_last_layer(),
198
+ split="val",
199
+ )
200
+ # rec_loss = log_dict_ae["val/rec_loss"]
201
+ # self.log(
202
+ # "val/rec_loss",
203
+ # rec_loss,
204
+ # prog_bar=True,
205
+ # logger=True,
206
+ # on_step=True,
207
+ # on_epoch=True,
208
+ # sync_dist=True,
209
+ # )
210
+ # self.log(
211
+ # "val/aeloss",
212
+ # aeloss,
213
+ # prog_bar=True,
214
+ # logger=True,
215
+ # on_step=True,
216
+ # on_epoch=True,
217
+ # sync_dist=True,
218
+ # )
219
+
220
+ for key, value in log_dict_disc.items():
221
+ if key in log_dict_ae:
222
+ log_dict_ae[key].extend(value)
223
+ else:
224
+ log_dict_ae[key] = value
225
+
226
+ self.log_dict(log_dict_ae, sync_dist=True)
227
+ return self.log_dict
228
+
229
+ def configure_optimizers(self):
230
+ lr = self.learning_rate
231
+ opt_ae = torch.optim.Adam(
232
+ list(self.encoder.parameters())
233
+ + list(self.decoder.parameters())
234
+ + list(self.quantize.parameters())
235
+ + list(self.quant_conv.parameters())
236
+ + list(self.post_quant_conv.parameters()),
237
+ lr=lr,
238
+ betas=(0.5, 0.9),
239
+ )
240
+ opt_disc = torch.optim.Adam(
241
+ self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
242
+ )
243
+ return [opt_ae, opt_disc], []
244
+
245
+ def get_last_layer(self):
246
+ return self.decoder.conv_out.weight
247
+
248
+ def log_images(self, batch, **kwargs):
249
+ log = dict()
250
+ x = self.get_input(batch, self.image_key)
251
+ x = x.to(self.device)
252
+ xrec, _ = self(x)
253
+ if x.shape[1] > 3:
254
+ # colorize with random projection
255
+ assert xrec.shape[1] > 3
256
+ x = self.to_rgb(x)
257
+ xrec = self.to_rgb(xrec)
258
+ log["inputs"] = x
259
+ log["reconstructions"] = xrec
260
+ return log
261
+
262
+ def to_rgb(self, x):
263
+ assert self.image_key == "segmentation"
264
+ if not hasattr(self, "colorize"):
265
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
266
+ x = F.conv2d(x, weight=self.colorize)
267
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
268
+ return x
269
+
270
+
271
+ class VQSegmentationModel(VQModel):
272
+ def __init__(self, n_labels, *args, **kwargs):
273
+ super().__init__(*args, **kwargs)
274
+ self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
275
+
276
+ def configure_optimizers(self):
277
+ lr = self.learning_rate
278
+ opt_ae = torch.optim.Adam(
279
+ list(self.encoder.parameters())
280
+ + list(self.decoder.parameters())
281
+ + list(self.quantize.parameters())
282
+ + list(self.quant_conv.parameters())
283
+ + list(self.post_quant_conv.parameters()),
284
+ lr=lr,
285
+ betas=(0.5, 0.9),
286
+ )
287
+ return opt_ae
288
+
289
+ def training_step(self, batch, batch_idx):
290
+ x = self.get_input(batch, self.image_key)
291
+ xrec, qloss = self(x)
292
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
293
+ self.log_dict(
294
+ log_dict_ae,
295
+ prog_bar=False,
296
+ logger=True,
297
+ on_step=True,
298
+ on_epoch=True,
299
+ sync_dist=True,
300
+ )
301
+ return aeloss
302
+
303
+ def validation_step(self, batch, batch_idx):
304
+ x = self.get_input(batch, self.image_key)
305
+ xrec, qloss = self(x)
306
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
307
+ self.log_dict(
308
+ log_dict_ae,
309
+ prog_bar=False,
310
+ logger=True,
311
+ on_step=True,
312
+ on_epoch=True,
313
+ sync_dist=True,
314
+ )
315
+ total_loss = log_dict_ae["val/total_loss"]
316
+ self.log(
317
+ "val/total_loss",
318
+ total_loss,
319
+ prog_bar=True,
320
+ logger=True,
321
+ on_step=True,
322
+ on_epoch=True,
323
+ sync_dist=True,
324
+ )
325
+ return aeloss
326
+
327
+ @torch.no_grad()
328
+ def log_images(self, batch, **kwargs):
329
+ log = dict()
330
+ x = self.get_input(batch, self.image_key)
331
+ x = x.to(self.device)
332
+ xrec, _ = self(x)
333
+ if x.shape[1] > 3:
334
+ # colorize with random projection
335
+ assert xrec.shape[1] > 3
336
+ # convert logits to indices
337
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
338
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
339
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
340
+ x = self.to_rgb(x)
341
+ xrec = self.to_rgb(xrec)
342
+ log["inputs"] = x
343
+ log["reconstructions"] = xrec
344
+ return log
345
+
346
+
347
+ class VQNoDiscModel(VQModel):
348
+ def __init__(
349
+ self,
350
+ ddconfig,
351
+ lossconfig,
352
+ n_embed,
353
+ embed_dim,
354
+ ckpt_path=None,
355
+ ignore_keys=[],
356
+ image_key="image",
357
+ colorize_nlabels=None,
358
+ ):
359
+ super().__init__(
360
+ ddconfig=ddconfig,
361
+ lossconfig=lossconfig,
362
+ n_embed=n_embed,
363
+ embed_dim=embed_dim,
364
+ ckpt_path=ckpt_path,
365
+ ignore_keys=ignore_keys,
366
+ image_key=image_key,
367
+ colorize_nlabels=colorize_nlabels,
368
+ )
369
+
370
+ def training_step(self, batch, batch_idx):
371
+ x = self.get_input(batch, self.image_key)
372
+ xrec, qloss = self(x)
373
+ # autoencode
374
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
375
+ output = pl.TrainResult(minimize=aeloss)
376
+ output.log(
377
+ "train/aeloss",
378
+ aeloss,
379
+ prog_bar=True,
380
+ logger=True,
381
+ on_step=True,
382
+ on_epoch=True,
383
+ )
384
+ output.log_dict(
385
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
386
+ )
387
+ return output
388
+
389
+ def validation_step(self, batch, batch_idx):
390
+ x = self.get_input(batch, self.image_key)
391
+ xrec, qloss = self(x)
392
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
393
+ rec_loss = log_dict_ae["val/rec_loss"]
394
+ output = pl.EvalResult(checkpoint_on=rec_loss)
395
+ output.log(
396
+ "val/rec_loss",
397
+ rec_loss,
398
+ prog_bar=True,
399
+ logger=True,
400
+ on_step=True,
401
+ on_epoch=True,
402
+ )
403
+ output.log(
404
+ "val/aeloss",
405
+ aeloss,
406
+ prog_bar=True,
407
+ logger=True,
408
+ on_step=True,
409
+ on_epoch=True,
410
+ )
411
+ output.log_dict(log_dict_ae)
412
+
413
+ return output
414
+
415
+ def configure_optimizers(self):
416
+ optimizer = torch.optim.Adam(
417
+ list(self.encoder.parameters())
418
+ + list(self.decoder.parameters())
419
+ + list(self.quantize.parameters())
420
+ + list(self.quant_conv.parameters())
421
+ + list(self.post_quant_conv.parameters()),
422
+ lr=self.learning_rate,
423
+ betas=(0.5, 0.9),
424
+ )
425
+ return optimizer
426
+
427
+
428
+ class GumbelVQ(VQModel):
429
+ def __init__(
430
+ self,
431
+ ddconfig,
432
+ lossconfig,
433
+ n_embed,
434
+ embed_dim,
435
+ temperature_scheduler_config,
436
+ ckpt_path=None,
437
+ ignore_keys=[],
438
+ image_key="image",
439
+ colorize_nlabels=None,
440
+ monitor=None,
441
+ kl_weight=1e-8,
442
+ remap=None,
443
+ ):
444
+
445
+ z_channels = ddconfig["z_channels"]
446
+ super().__init__(
447
+ ddconfig,
448
+ lossconfig,
449
+ n_embed,
450
+ embed_dim,
451
+ ckpt_path=None,
452
+ ignore_keys=ignore_keys,
453
+ image_key=image_key,
454
+ colorize_nlabels=colorize_nlabels,
455
+ monitor=monitor,
456
+ )
457
+
458
+ self.loss.n_classes = n_embed
459
+ self.vocab_size = n_embed
460
+
461
+ self.quantize = GumbelQuantize(
462
+ z_channels,
463
+ embed_dim,
464
+ n_embed=n_embed,
465
+ kl_weight=kl_weight,
466
+ temp_init=1.0,
467
+ remap=remap,
468
+ )
469
+
470
+ self.temperature_scheduler = instantiate_from_config(
471
+ temperature_scheduler_config
472
+ ) # annealing of temp
473
+
474
+ if ckpt_path is not None:
475
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
476
+
477
+ def temperature_scheduling(self):
478
+ self.quantize.temperature = self.temperature_scheduler(self.global_step)
479
+
480
+ def encode_to_prequant(self, x):
481
+ h = self.encoder(x)
482
+ h = self.quant_conv(h)
483
+ return h
484
+
485
+ def decode_code(self, code_b):
486
+ raise NotImplementedError
487
+
488
+ def training_step(self, batch, batch_idx=None, optimizer_idx=0):
489
+ self.temperature_scheduling()
490
+ x = self.get_input(batch, self.image_key)
491
+ xrec, qloss = self(x)
492
+
493
+ if optimizer_idx == 0:
494
+ # autoencode
495
+ aeloss, log_dict_ae = self.loss(
496
+ qloss,
497
+ x,
498
+ xrec,
499
+ optimizer_idx,
500
+ self.global_step,
501
+ last_layer=self.get_last_layer(),
502
+ split="train",
503
+ )
504
+
505
+ self.log_dict(
506
+ log_dict_ae,
507
+ prog_bar=False,
508
+ logger=True,
509
+ on_step=True,
510
+ on_epoch=True,
511
+ sync_dist=True,
512
+ )
513
+ self.log(
514
+ "temperature",
515
+ self.quantize.temperature,
516
+ prog_bar=False,
517
+ logger=True,
518
+ on_step=True,
519
+ on_epoch=True,
520
+ sync_dist=True,
521
+ )
522
+ return aeloss
523
+
524
+ if optimizer_idx == 1:
525
+ # discriminator
526
+ discloss, log_dict_disc = self.loss(
527
+ qloss,
528
+ x,
529
+ xrec,
530
+ optimizer_idx,
531
+ self.global_step,
532
+ last_layer=self.get_last_layer(),
533
+ split="train",
534
+ )
535
+ self.log_dict(
536
+ log_dict_disc,
537
+ prog_bar=False,
538
+ logger=True,
539
+ on_step=True,
540
+ on_epoch=True,
541
+ sync_dist=True,
542
+ )
543
+ return discloss
544
+
545
+ def validation_step(self, batch, batch_idx):
546
+ x = self.get_input(batch, self.image_key)
547
+ xrec, qloss = self(x)
548
+ aeloss, log_dict_ae = self.loss(
549
+ qloss,
550
+ x,
551
+ xrec,
552
+ 0,
553
+ self.global_step,
554
+ last_layer=self.get_last_layer(),
555
+ split="val",
556
+ )
557
+
558
+ discloss, log_dict_disc = self.loss(
559
+ qloss,
560
+ x,
561
+ xrec,
562
+ 1,
563
+ self.global_step,
564
+ last_layer=self.get_last_layer(),
565
+ split="val",
566
+ )
567
+ rec_loss = log_dict_ae["val/rec_loss"]
568
+ self.log(
569
+ "val/rec_loss",
570
+ rec_loss,
571
+ prog_bar=True,
572
+ logger=True,
573
+ on_step=False,
574
+ on_epoch=True,
575
+ sync_dist=True,
576
+ )
577
+ self.log(
578
+ "val/aeloss",
579
+ aeloss,
580
+ prog_bar=True,
581
+ logger=True,
582
+ on_step=False,
583
+ on_epoch=True,
584
+ sync_dist=True,
585
+ )
586
+ self.log_dict(log_dict_ae, sync_dist=True)
587
+ self.log_dict(log_dict_disc, sync_dist=True)
588
+ return self.log_dict
589
+
590
+ def log_images(self, batch, **kwargs):
591
+ log = dict()
592
+ x = self.get_input(batch, self.image_key)
593
+ x = x.to(self.device)
594
+ # encode
595
+ h = self.encoder(x)
596
+ h = self.quant_conv(h)
597
+ quant, _, _ = self.quantize(h)
598
+ # decode
599
+ x_rec = self.decode(quant)
600
+ log["inputs"] = x
601
+ log["reconstructions"] = x_rec
602
+ return log
603
+
604
+
605
+ class EMAVQ(VQModel):
606
+ def __init__(
607
+ self,
608
+ ddconfig,
609
+ lossconfig,
610
+ n_embed,
611
+ embed_dim,
612
+ ckpt_path=None,
613
+ ignore_keys=[],
614
+ image_key="image",
615
+ colorize_nlabels=None,
616
+ monitor=None,
617
+ remap=None,
618
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
619
+ ):
620
+ super().__init__(
621
+ ddconfig,
622
+ lossconfig,
623
+ n_embed,
624
+ embed_dim,
625
+ ckpt_path=None,
626
+ ignore_keys=ignore_keys,
627
+ image_key=image_key,
628
+ colorize_nlabels=colorize_nlabels,
629
+ monitor=monitor,
630
+ )
631
+ self.quantize = EMAVectorQuantizer(
632
+ n_embed=n_embed, embedding_dim=embed_dim, beta=0.25, remap=remap
633
+ )
634
+
635
+ def configure_optimizers(self):
636
+ lr = self.learning_rate
637
+ # Remove self.quantize from parameter list since it is updated via EMA
638
+ opt_ae = torch.optim.Adam(
639
+ list(self.encoder.parameters())
640
+ + list(self.decoder.parameters())
641
+ + list(self.quant_conv.parameters())
642
+ + list(self.post_quant_conv.parameters()),
643
+ lr=lr,
644
+ betas=(0.5, 0.9),
645
+ )
646
+ opt_disc = torch.optim.Adam(
647
+ self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
648
+ )
649
+ return [opt_ae, opt_disc], []
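
As a reference for how the pieces of VQModel fit together, a hedged sketch of the encode, quantize, decode round trip built directly from its components; the ddconfig values are illustrative toy settings, not the configs shipped with this Space:

```python
# Illustrative round trip through Encoder -> VectorQuantizer -> Decoder, mirroring
# VQModel.encode / VQModel.decode above. All hyperparameters are toy values.
import torch
from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

ddconfig = dict(
    double_z=False, z_channels=64, resolution=64, in_channels=1, out_ch=1,
    ch=32, ch_mult=(1, 2, 4), num_res_blocks=1, attn_resolutions=[16], dropout=0.0,
)
embed_dim, n_embed = 8, 512

encoder, decoder = Encoder(**ddconfig), Decoder(**ddconfig)
quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)

x = torch.randn(1, 1, 64, 64)                      # toy single-channel image
quant, emb_loss, info = quantize(quant_conv(encoder(x)))
x_rec = decoder(post_quant_conv(quant))
print(x_rec.shape)                                 # torch.Size([1, 1, 64, 64])
```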
taming/modules/autoencoder/lpips/vgg.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
taming/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,776 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+
7
+
8
+ def get_timestep_embedding(timesteps, embedding_dim):
9
+ """
10
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
11
+ From Fairseq.
12
+ Build sinusoidal embeddings.
13
+ This matches the implementation in tensor2tensor, but differs slightly
14
+ from the description in Section 3.5 of "Attention Is All You Need".
15
+ """
16
+ assert len(timesteps.shape) == 1
17
+
18
+ half_dim = embedding_dim // 2
19
+ emb = math.log(10000) / (half_dim - 1)
20
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
21
+ emb = emb.to(device=timesteps.device)
22
+ emb = timesteps.float()[:, None] * emb[None, :]
23
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
24
+ if embedding_dim % 2 == 1: # zero pad
25
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
26
+ return emb
27
+
28
+
29
+ def nonlinearity(x):
30
+ # swish
31
+ return x*torch.sigmoid(x)
32
+
33
+
34
+ def Normalize(in_channels):
35
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
36
+
37
+
38
+ class Upsample(nn.Module):
39
+ def __init__(self, in_channels, with_conv):
40
+ super().__init__()
41
+ self.with_conv = with_conv
42
+ if self.with_conv:
43
+ self.conv = torch.nn.Conv2d(in_channels,
44
+ in_channels,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1)
48
+
49
+ def forward(self, x):
50
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
51
+ if self.with_conv:
52
+ x = self.conv(x)
53
+ return x
54
+
55
+
56
+ class Downsample(nn.Module):
57
+ def __init__(self, in_channels, with_conv):
58
+ super().__init__()
59
+ self.with_conv = with_conv
60
+ if self.with_conv:
61
+ # no asymmetric padding in torch conv, must do it ourselves
62
+ self.conv = torch.nn.Conv2d(in_channels,
63
+ in_channels,
64
+ kernel_size=3,
65
+ stride=2,
66
+ padding=0)
67
+
68
+ def forward(self, x):
69
+ if self.with_conv:
70
+ pad = (0,1,0,1)
71
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
72
+ x = self.conv(x)
73
+ else:
74
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
75
+ return x
76
+
77
+
78
+ class ResnetBlock(nn.Module):
79
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
80
+ dropout, temb_channels=512):
81
+ super().__init__()
82
+ self.in_channels = in_channels
83
+ out_channels = in_channels if out_channels is None else out_channels
84
+ self.out_channels = out_channels
85
+ self.use_conv_shortcut = conv_shortcut
86
+
87
+ self.norm1 = Normalize(in_channels)
88
+ self.conv1 = torch.nn.Conv2d(in_channels,
89
+ out_channels,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1)
93
+ if temb_channels > 0:
94
+ self.temb_proj = torch.nn.Linear(temb_channels,
95
+ out_channels)
96
+ self.norm2 = Normalize(out_channels)
97
+ self.dropout = torch.nn.Dropout(dropout)
98
+ self.conv2 = torch.nn.Conv2d(out_channels,
99
+ out_channels,
100
+ kernel_size=3,
101
+ stride=1,
102
+ padding=1)
103
+ if self.in_channels != self.out_channels:
104
+ if self.use_conv_shortcut:
105
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
106
+ out_channels,
107
+ kernel_size=3,
108
+ stride=1,
109
+ padding=1)
110
+ else:
111
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
112
+ out_channels,
113
+ kernel_size=1,
114
+ stride=1,
115
+ padding=0)
116
+
117
+ def forward(self, x, temb):
118
+ h = x
119
+ h = self.norm1(h)
120
+ h = nonlinearity(h)
121
+ h = self.conv1(h)
122
+
123
+ if temb is not None:
124
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
125
+
126
+ h = self.norm2(h)
127
+ h = nonlinearity(h)
128
+ h = self.dropout(h)
129
+ h = self.conv2(h)
130
+
131
+ if self.in_channels != self.out_channels:
132
+ if self.use_conv_shortcut:
133
+ x = self.conv_shortcut(x)
134
+ else:
135
+ x = self.nin_shortcut(x)
136
+
137
+ return x+h
138
+
139
+
140
+ class AttnBlock(nn.Module):
141
+ def __init__(self, in_channels):
142
+ super().__init__()
143
+ self.in_channels = in_channels
144
+
145
+ self.norm = Normalize(in_channels)
146
+ self.q = torch.nn.Conv2d(in_channels,
147
+ in_channels,
148
+ kernel_size=1,
149
+ stride=1,
150
+ padding=0)
151
+ self.k = torch.nn.Conv2d(in_channels,
152
+ in_channels,
153
+ kernel_size=1,
154
+ stride=1,
155
+ padding=0)
156
+ self.v = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.proj_out = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b,c,h,w = q.shape
177
+ q = q.reshape(b,c,h*w)
178
+ q = q.permute(0,2,1) # b,hw,c
179
+ k = k.reshape(b,c,h*w) # b,c,hw
180
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
181
+ w_ = w_ * (int(c)**(-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = v.reshape(b,c,h*w)
186
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
187
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
188
+ h_ = h_.reshape(b,c,h,w)
189
+
190
+ h_ = self.proj_out(h_)
191
+
192
+ return x+h_
193
+
194
+
195
+ class Model(nn.Module):
196
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
197
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
198
+ resolution, use_timestep=True):
199
+ super().__init__()
200
+ self.ch = ch
201
+ self.temb_ch = self.ch*4
202
+ self.num_resolutions = len(ch_mult)
203
+ self.num_res_blocks = num_res_blocks
204
+ self.resolution = resolution
205
+ self.in_channels = in_channels
206
+
207
+ self.use_timestep = use_timestep
208
+ if self.use_timestep:
209
+ # timestep embedding
210
+ self.temb = nn.Module()
211
+ self.temb.dense = nn.ModuleList([
212
+ torch.nn.Linear(self.ch,
213
+ self.temb_ch),
214
+ torch.nn.Linear(self.temb_ch,
215
+ self.temb_ch),
216
+ ])
217
+
218
+ # downsampling
219
+ self.conv_in = torch.nn.Conv2d(in_channels,
220
+ self.ch,
221
+ kernel_size=3,
222
+ stride=1,
223
+ padding=1)
224
+
225
+ curr_res = resolution
226
+ in_ch_mult = (1,)+tuple(ch_mult)
227
+ self.down = nn.ModuleList()
228
+ for i_level in range(self.num_resolutions):
229
+ block = nn.ModuleList()
230
+ attn = nn.ModuleList()
231
+ block_in = ch*in_ch_mult[i_level]
232
+ block_out = ch*ch_mult[i_level]
233
+ for i_block in range(self.num_res_blocks):
234
+ block.append(ResnetBlock(in_channels=block_in,
235
+ out_channels=block_out,
236
+ temb_channels=self.temb_ch,
237
+ dropout=dropout))
238
+ block_in = block_out
239
+ if curr_res in attn_resolutions:
240
+ attn.append(AttnBlock(block_in))
241
+ down = nn.Module()
242
+ down.block = block
243
+ down.attn = attn
244
+ if i_level != self.num_resolutions-1:
245
+ down.downsample = Downsample(block_in, resamp_with_conv)
246
+ curr_res = curr_res // 2
247
+ self.down.append(down)
248
+
249
+ # middle
250
+ self.mid = nn.Module()
251
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
252
+ out_channels=block_in,
253
+ temb_channels=self.temb_ch,
254
+ dropout=dropout)
255
+ self.mid.attn_1 = AttnBlock(block_in)
256
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
257
+ out_channels=block_in,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout)
260
+
261
+ # upsampling
262
+ self.up = nn.ModuleList()
263
+ for i_level in reversed(range(self.num_resolutions)):
264
+ block = nn.ModuleList()
265
+ attn = nn.ModuleList()
266
+ block_out = ch*ch_mult[i_level]
267
+ skip_in = ch*ch_mult[i_level]
268
+ for i_block in range(self.num_res_blocks+1):
269
+ if i_block == self.num_res_blocks:
270
+ skip_in = ch*in_ch_mult[i_level]
271
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
272
+ out_channels=block_out,
273
+ temb_channels=self.temb_ch,
274
+ dropout=dropout))
275
+ block_in = block_out
276
+ if curr_res in attn_resolutions:
277
+ attn.append(AttnBlock(block_in))
278
+ up = nn.Module()
279
+ up.block = block
280
+ up.attn = attn
281
+ if i_level != 0:
282
+ up.upsample = Upsample(block_in, resamp_with_conv)
283
+ curr_res = curr_res * 2
284
+ self.up.insert(0, up) # prepend to get consistent order
285
+
286
+ # end
287
+ self.norm_out = Normalize(block_in)
288
+ self.conv_out = torch.nn.Conv2d(block_in,
289
+ out_ch,
290
+ kernel_size=3,
291
+ stride=1,
292
+ padding=1)
293
+
294
+
295
+ def forward(self, x, t=None):
296
+ #assert x.shape[2] == x.shape[3] == self.resolution
297
+
298
+ if self.use_timestep:
299
+ # timestep embedding
300
+ assert t is not None
301
+ temb = get_timestep_embedding(t, self.ch)
302
+ temb = self.temb.dense[0](temb)
303
+ temb = nonlinearity(temb)
304
+ temb = self.temb.dense[1](temb)
305
+ else:
306
+ temb = None
307
+
308
+ # downsampling
309
+ hs = [self.conv_in(x)]
310
+ for i_level in range(self.num_resolutions):
311
+ for i_block in range(self.num_res_blocks):
312
+ h = self.down[i_level].block[i_block](hs[-1], temb)
313
+ if len(self.down[i_level].attn) > 0:
314
+ h = self.down[i_level].attn[i_block](h)
315
+ hs.append(h)
316
+ if i_level != self.num_resolutions-1:
317
+ hs.append(self.down[i_level].downsample(hs[-1]))
318
+
319
+ # middle
320
+ h = hs[-1]
321
+ h = self.mid.block_1(h, temb)
322
+ h = self.mid.attn_1(h)
323
+ h = self.mid.block_2(h, temb)
324
+
325
+ # upsampling
326
+ for i_level in reversed(range(self.num_resolutions)):
327
+ for i_block in range(self.num_res_blocks+1):
328
+ h = self.up[i_level].block[i_block](
329
+ torch.cat([h, hs.pop()], dim=1), temb)
330
+ if len(self.up[i_level].attn) > 0:
331
+ h = self.up[i_level].attn[i_block](h)
332
+ if i_level != 0:
333
+ h = self.up[i_level].upsample(h)
334
+
335
+ # end
336
+ h = self.norm_out(h)
337
+ h = nonlinearity(h)
338
+ h = self.conv_out(h)
339
+ return h
340
+
341
+
342
+ class Encoder(nn.Module):
343
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
344
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
345
+ resolution, z_channels, double_z=True, **ignore_kwargs):
346
+ super().__init__()
347
+ self.ch = ch
348
+ self.temb_ch = 0
349
+ self.num_resolutions = len(ch_mult)
350
+ self.num_res_blocks = num_res_blocks
351
+ self.resolution = resolution
352
+ self.in_channels = in_channels
353
+
354
+ # downsampling
355
+ self.conv_in = torch.nn.Conv2d(in_channels,
356
+ self.ch,
357
+ kernel_size=3,
358
+ stride=1,
359
+ padding=1)
360
+
361
+ curr_res = resolution
362
+ in_ch_mult = (1,)+tuple(ch_mult)
363
+ self.down = nn.ModuleList()
364
+ for i_level in range(self.num_resolutions):
365
+ block = nn.ModuleList()
366
+ attn = nn.ModuleList()
367
+ block_in = ch*in_ch_mult[i_level]
368
+ block_out = ch*ch_mult[i_level]
369
+ for i_block in range(self.num_res_blocks):
370
+ block.append(ResnetBlock(in_channels=block_in,
371
+ out_channels=block_out,
372
+ temb_channels=self.temb_ch,
373
+ dropout=dropout))
374
+ block_in = block_out
375
+ if curr_res in attn_resolutions:
376
+ attn.append(AttnBlock(block_in))
377
+ down = nn.Module()
378
+ down.block = block
379
+ down.attn = attn
380
+ if i_level != self.num_resolutions-1:
381
+ down.downsample = Downsample(block_in, resamp_with_conv)
382
+ curr_res = curr_res // 2
383
+ self.down.append(down)
384
+
385
+ # middle
386
+ self.mid = nn.Module()
387
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
388
+ out_channels=block_in,
389
+ temb_channels=self.temb_ch,
390
+ dropout=dropout)
391
+ self.mid.attn_1 = AttnBlock(block_in)
392
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
393
+ out_channels=block_in,
394
+ temb_channels=self.temb_ch,
395
+ dropout=dropout)
396
+
397
+ # end
398
+ self.norm_out = Normalize(block_in)
399
+ self.conv_out = torch.nn.Conv2d(block_in,
400
+ 2*z_channels if double_z else z_channels,
401
+ kernel_size=3,
402
+ stride=1,
403
+ padding=1)
404
+
405
+
406
+ def forward(self, x):
407
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
408
+
409
+ # timestep embedding
410
+ temb = None
411
+
412
+ # downsampling
413
+ hs = [self.conv_in(x)]
414
+ for i_level in range(self.num_resolutions):
415
+ for i_block in range(self.num_res_blocks):
416
+ h = self.down[i_level].block[i_block](hs[-1], temb)
417
+ if len(self.down[i_level].attn) > 0:
418
+ h = self.down[i_level].attn[i_block](h)
419
+ hs.append(h)
420
+ if i_level != self.num_resolutions-1:
421
+ hs.append(self.down[i_level].downsample(hs[-1]))
422
+
423
+ # middle
424
+ h = hs[-1]
425
+ h = self.mid.block_1(h, temb)
426
+ h = self.mid.attn_1(h)
427
+ h = self.mid.block_2(h, temb)
428
+
429
+ # end
430
+ h = self.norm_out(h)
431
+ h = nonlinearity(h)
432
+ h = self.conv_out(h)
433
+ return h
434
+
435
+
436
+ class Decoder(nn.Module):
437
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
438
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
439
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
440
+ super().__init__()
441
+ self.ch = ch
442
+ self.temb_ch = 0
443
+ self.num_resolutions = len(ch_mult)
444
+ self.num_res_blocks = num_res_blocks
445
+ self.resolution = resolution
446
+ self.in_channels = in_channels
447
+ self.give_pre_end = give_pre_end
448
+
449
+ # compute in_ch_mult, block_in and curr_res at lowest res
450
+ in_ch_mult = (1,)+tuple(ch_mult)
451
+ block_in = ch*ch_mult[self.num_resolutions-1]
452
+ curr_res = resolution // 2**(self.num_resolutions-1)
453
+ self.z_shape = (1,z_channels,curr_res,curr_res)
454
+ print("Working with z of shape {} = {} dimensions.".format(
455
+ self.z_shape, np.prod(self.z_shape)))
456
+
457
+ # z to block_in
458
+ self.conv_in = torch.nn.Conv2d(z_channels,
459
+ block_in,
460
+ kernel_size=3,
461
+ stride=1,
462
+ padding=1)
463
+
464
+ # middle
465
+ self.mid = nn.Module()
466
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
467
+ out_channels=block_in,
468
+ temb_channels=self.temb_ch,
469
+ dropout=dropout)
470
+ self.mid.attn_1 = AttnBlock(block_in)
471
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
472
+ out_channels=block_in,
473
+ temb_channels=self.temb_ch,
474
+ dropout=dropout)
475
+
476
+ # upsampling
477
+ self.up = nn.ModuleList()
478
+ for i_level in reversed(range(self.num_resolutions)):
479
+ block = nn.ModuleList()
480
+ attn = nn.ModuleList()
481
+ block_out = ch*ch_mult[i_level]
482
+ for i_block in range(self.num_res_blocks+1):
483
+ block.append(ResnetBlock(in_channels=block_in,
484
+ out_channels=block_out,
485
+ temb_channels=self.temb_ch,
486
+ dropout=dropout))
487
+ block_in = block_out
488
+ if curr_res in attn_resolutions:
489
+ attn.append(AttnBlock(block_in))
490
+ up = nn.Module()
491
+ up.block = block
492
+ up.attn = attn
493
+ if i_level != 0:
494
+ up.upsample = Upsample(block_in, resamp_with_conv)
495
+ curr_res = curr_res * 2
496
+ self.up.insert(0, up) # prepend to get consistent order
497
+
498
+ # end
499
+ self.norm_out = Normalize(block_in)
500
+ self.conv_out = torch.nn.Conv2d(block_in,
501
+ out_ch,
502
+ kernel_size=3,
503
+ stride=1,
504
+ padding=1)
505
+
506
+ def forward(self, z):
507
+ #assert z.shape[1:] == self.z_shape[1:]
508
+ self.last_z_shape = z.shape
509
+
510
+ # timestep embedding
511
+ temb = None
512
+
513
+ # z to block_in
514
+ h = self.conv_in(z)
515
+
516
+ # middle
517
+ h = self.mid.block_1(h, temb)
518
+ h = self.mid.attn_1(h)
519
+ h = self.mid.block_2(h, temb)
520
+
521
+ # upsampling
522
+ for i_level in reversed(range(self.num_resolutions)):
523
+ for i_block in range(self.num_res_blocks+1):
524
+ h = self.up[i_level].block[i_block](h, temb)
525
+ if len(self.up[i_level].attn) > 0:
526
+ h = self.up[i_level].attn[i_block](h)
527
+ if i_level != 0:
528
+ h = self.up[i_level].upsample(h)
529
+
530
+ # end
531
+ if self.give_pre_end:
532
+ return h
533
+
534
+ h = self.norm_out(h)
535
+ h = nonlinearity(h)
536
+ h = self.conv_out(h)
537
+ return h
538
+
539
+
540
+ class VUNet(nn.Module):
541
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
542
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
543
+ in_channels, c_channels,
544
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
545
+ super().__init__()
546
+ self.ch = ch
547
+ self.temb_ch = self.ch*4
548
+ self.num_resolutions = len(ch_mult)
549
+ self.num_res_blocks = num_res_blocks
550
+ self.resolution = resolution
551
+
552
+ self.use_timestep = use_timestep
553
+ if self.use_timestep:
554
+ # timestep embedding
555
+ self.temb = nn.Module()
556
+ self.temb.dense = nn.ModuleList([
557
+ torch.nn.Linear(self.ch,
558
+ self.temb_ch),
559
+ torch.nn.Linear(self.temb_ch,
560
+ self.temb_ch),
561
+ ])
562
+
563
+ # downsampling
564
+ self.conv_in = torch.nn.Conv2d(c_channels,
565
+ self.ch,
566
+ kernel_size=3,
567
+ stride=1,
568
+ padding=1)
569
+
570
+ curr_res = resolution
571
+ in_ch_mult = (1,)+tuple(ch_mult)
572
+ self.down = nn.ModuleList()
573
+ for i_level in range(self.num_resolutions):
574
+ block = nn.ModuleList()
575
+ attn = nn.ModuleList()
576
+ block_in = ch*in_ch_mult[i_level]
577
+ block_out = ch*ch_mult[i_level]
578
+ for i_block in range(self.num_res_blocks):
579
+ block.append(ResnetBlock(in_channels=block_in,
580
+ out_channels=block_out,
581
+ temb_channels=self.temb_ch,
582
+ dropout=dropout))
583
+ block_in = block_out
584
+ if curr_res in attn_resolutions:
585
+ attn.append(AttnBlock(block_in))
586
+ down = nn.Module()
587
+ down.block = block
588
+ down.attn = attn
589
+ if i_level != self.num_resolutions-1:
590
+ down.downsample = Downsample(block_in, resamp_with_conv)
591
+ curr_res = curr_res // 2
592
+ self.down.append(down)
593
+
594
+ self.z_in = torch.nn.Conv2d(z_channels,
595
+ block_in,
596
+ kernel_size=1,
597
+ stride=1,
598
+ padding=0)
599
+ # middle
600
+ self.mid = nn.Module()
601
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
602
+ out_channels=block_in,
603
+ temb_channels=self.temb_ch,
604
+ dropout=dropout)
605
+ self.mid.attn_1 = AttnBlock(block_in)
606
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
607
+ out_channels=block_in,
608
+ temb_channels=self.temb_ch,
609
+ dropout=dropout)
610
+
611
+ # upsampling
612
+ self.up = nn.ModuleList()
613
+ for i_level in reversed(range(self.num_resolutions)):
614
+ block = nn.ModuleList()
615
+ attn = nn.ModuleList()
616
+ block_out = ch*ch_mult[i_level]
617
+ skip_in = ch*ch_mult[i_level]
618
+ for i_block in range(self.num_res_blocks+1):
619
+ if i_block == self.num_res_blocks:
620
+ skip_in = ch*in_ch_mult[i_level]
621
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
622
+ out_channels=block_out,
623
+ temb_channels=self.temb_ch,
624
+ dropout=dropout))
625
+ block_in = block_out
626
+ if curr_res in attn_resolutions:
627
+ attn.append(AttnBlock(block_in))
628
+ up = nn.Module()
629
+ up.block = block
630
+ up.attn = attn
631
+ if i_level != 0:
632
+ up.upsample = Upsample(block_in, resamp_with_conv)
633
+ curr_res = curr_res * 2
634
+ self.up.insert(0, up) # prepend to get consistent order
635
+
636
+ # end
637
+ self.norm_out = Normalize(block_in)
638
+ self.conv_out = torch.nn.Conv2d(block_in,
639
+ out_ch,
640
+ kernel_size=3,
641
+ stride=1,
642
+ padding=1)
643
+
644
+
645
+ def forward(self, x, z, t=None):
646
+ #assert x.shape[2] == x.shape[3] == self.resolution
647
+
648
+ if self.use_timestep:
649
+ # timestep embedding
650
+ assert t is not None
651
+ temb = get_timestep_embedding(t, self.ch)
652
+ temb = self.temb.dense[0](temb)
653
+ temb = nonlinearity(temb)
654
+ temb = self.temb.dense[1](temb)
655
+ else:
656
+ temb = None
657
+
658
+ # downsampling
659
+ hs = [self.conv_in(x)]
660
+ for i_level in range(self.num_resolutions):
661
+ for i_block in range(self.num_res_blocks):
662
+ h = self.down[i_level].block[i_block](hs[-1], temb)
663
+ if len(self.down[i_level].attn) > 0:
664
+ h = self.down[i_level].attn[i_block](h)
665
+ hs.append(h)
666
+ if i_level != self.num_resolutions-1:
667
+ hs.append(self.down[i_level].downsample(hs[-1]))
668
+
669
+ # middle
670
+ h = hs[-1]
671
+ z = self.z_in(z)
672
+ h = torch.cat((h,z),dim=1)
673
+ h = self.mid.block_1(h, temb)
674
+ h = self.mid.attn_1(h)
675
+ h = self.mid.block_2(h, temb)
676
+
677
+ # upsampling
678
+ for i_level in reversed(range(self.num_resolutions)):
679
+ for i_block in range(self.num_res_blocks+1):
680
+ h = self.up[i_level].block[i_block](
681
+ torch.cat([h, hs.pop()], dim=1), temb)
682
+ if len(self.up[i_level].attn) > 0:
683
+ h = self.up[i_level].attn[i_block](h)
684
+ if i_level != 0:
685
+ h = self.up[i_level].upsample(h)
686
+
687
+ # end
688
+ h = self.norm_out(h)
689
+ h = nonlinearity(h)
690
+ h = self.conv_out(h)
691
+ return h
692
+
693
+
694
+ class SimpleDecoder(nn.Module):
695
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
696
+ super().__init__()
697
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
698
+ ResnetBlock(in_channels=in_channels,
699
+ out_channels=2 * in_channels,
700
+ temb_channels=0, dropout=0.0),
701
+ ResnetBlock(in_channels=2 * in_channels,
702
+ out_channels=4 * in_channels,
703
+ temb_channels=0, dropout=0.0),
704
+ ResnetBlock(in_channels=4 * in_channels,
705
+ out_channels=2 * in_channels,
706
+ temb_channels=0, dropout=0.0),
707
+ nn.Conv2d(2*in_channels, in_channels, 1),
708
+ Upsample(in_channels, with_conv=True)])
709
+ # end
710
+ self.norm_out = Normalize(in_channels)
711
+ self.conv_out = torch.nn.Conv2d(in_channels,
712
+ out_channels,
713
+ kernel_size=3,
714
+ stride=1,
715
+ padding=1)
716
+
717
+ def forward(self, x):
718
+ for i, layer in enumerate(self.model):
719
+ if i in [1,2,3]:
720
+ x = layer(x, None)
721
+ else:
722
+ x = layer(x)
723
+
724
+ h = self.norm_out(x)
725
+ h = nonlinearity(h)
726
+ x = self.conv_out(h)
727
+ return x
728
+
729
+
730
+ class UpsampleDecoder(nn.Module):
731
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
732
+ ch_mult=(2,2), dropout=0.0):
733
+ super().__init__()
734
+ # upsampling
735
+ self.temb_ch = 0
736
+ self.num_resolutions = len(ch_mult)
737
+ self.num_res_blocks = num_res_blocks
738
+ block_in = in_channels
739
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
740
+ self.res_blocks = nn.ModuleList()
741
+ self.upsample_blocks = nn.ModuleList()
742
+ for i_level in range(self.num_resolutions):
743
+ res_block = []
744
+ block_out = ch * ch_mult[i_level]
745
+ for i_block in range(self.num_res_blocks + 1):
746
+ res_block.append(ResnetBlock(in_channels=block_in,
747
+ out_channels=block_out,
748
+ temb_channels=self.temb_ch,
749
+ dropout=dropout))
750
+ block_in = block_out
751
+ self.res_blocks.append(nn.ModuleList(res_block))
752
+ if i_level != self.num_resolutions - 1:
753
+ self.upsample_blocks.append(Upsample(block_in, True))
754
+ curr_res = curr_res * 2
755
+
756
+ # end
757
+ self.norm_out = Normalize(block_in)
758
+ self.conv_out = torch.nn.Conv2d(block_in,
759
+ out_channels,
760
+ kernel_size=3,
761
+ stride=1,
762
+ padding=1)
763
+
764
+ def forward(self, x):
765
+ # upsampling
766
+ h = x
767
+ for k, i_level in enumerate(range(self.num_resolutions)):
768
+ for i_block in range(self.num_res_blocks + 1):
769
+ h = self.res_blocks[i_level][i_block](h, None)
770
+ if i_level != self.num_resolutions - 1:
771
+ h = self.upsample_blocks[k](h)
772
+ h = self.norm_out(h)
773
+ h = nonlinearity(h)
774
+ h = self.conv_out(h)
775
+ return h
776
+
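The Encoder/Decoder pair added above follows the usual VQGAN convolutional autoencoder layout. As a rough orientation, a minimal sketch of wiring them together is shown below; it assumes the classes live at taming/modules/diffusionmodules/model.py as in upstream taming-transformers, and the channel/resolution values are illustrative rather than taken from any config in this commit.

import torch
from taming.modules.diffusionmodules.model import Encoder, Decoder

# Illustrative hyperparameters (not from a shipped config).
enc = Encoder(ch=64, out_ch=1, ch_mult=(1, 2, 4), num_res_blocks=2,
              attn_resolutions=[16], in_channels=1, resolution=128,
              z_channels=64, double_z=False)
dec = Decoder(ch=64, out_ch=1, ch_mult=(1, 2, 4), num_res_blocks=2,
              attn_resolutions=[16], in_channels=1, resolution=128,
              z_channels=64)

x = torch.randn(1, 1, 128, 128)   # e.g. a single-channel image crop
z = enc(x)                        # latent of shape (1, 64, 32, 32)
x_rec = dec(z)                    # reconstruction of shape (1, 1, 128, 128)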
taming/modules/discriminator/model.py ADDED
@@ -0,0 +1,67 @@
1
+ import functools
2
+ import torch.nn as nn
3
+
4
+
5
+ from taming.modules.util import ActNorm
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+ if classname.find('Conv') != -1:
11
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
12
+ elif classname.find('BatchNorm') != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22
+ """Construct a PatchGAN discriminator
23
+ Parameters:
24
+ input_nc (int) -- the number of channels in input images
25
+ ndf (int) -- the number of filters in the last conv layer
26
+ n_layers (int) -- the number of conv layers in the discriminator
27
+ norm_layer -- normalization layer
28
+ """
29
+ super(NLayerDiscriminator, self).__init__()
30
+ if not use_actnorm:
31
+ norm_layer = nn.BatchNorm2d
32
+ else:
33
+ norm_layer = ActNorm
34
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35
+ use_bias = norm_layer.func != nn.BatchNorm2d
36
+ else:
37
+ use_bias = norm_layer != nn.BatchNorm2d
38
+
39
+ kw = 4
40
+ padw = 1
41
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42
+ nf_mult = 1
43
+ nf_mult_prev = 1
44
+ for n in range(1, n_layers): # gradually increase the number of filters
45
+ nf_mult_prev = nf_mult
46
+ nf_mult = min(2 ** n, 8)
47
+ sequence += [
48
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49
+ norm_layer(ndf * nf_mult),
50
+ nn.LeakyReLU(0.2, True)
51
+ ]
52
+
53
+ nf_mult_prev = nf_mult
54
+ nf_mult = min(2 ** n_layers, 8)
55
+ sequence += [
56
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57
+ norm_layer(ndf * nf_mult),
58
+ nn.LeakyReLU(0.2, True)
59
+ ]
60
+
61
+ sequence += [
62
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63
+ self.main = nn.Sequential(*sequence)
64
+
65
+ def forward(self, input):
66
+ """Standard forward."""
67
+ return self.main(input)
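A minimal usage sketch of the PatchGAN discriminator defined above, assuming the taming package from this commit is on the import path; the input size is arbitrary.

import torch
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

# 3-layer PatchGAN over RGB input; set use_actnorm=True to swap BatchNorm for ActNorm.
disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False).apply(weights_init)

images = torch.randn(2, 3, 256, 256)
logits = disc(images)   # patch-wise real/fake logits, here of shape (2, 1, 30, 30)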
taming/modules/losses/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from taming.modules.losses.vqperceptual import DummyLoss
2
+
taming/modules/losses/lpips.py ADDED
@@ -0,0 +1,123 @@
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import models
6
+ from collections import namedtuple
7
+
8
+ from taming.util import get_ckpt_path
9
+
10
+
11
+ class LPIPS(nn.Module):
12
+ # Learned perceptual metric
13
+ def __init__(self, use_dropout=True):
14
+ super().__init__()
15
+ self.scaling_layer = ScalingLayer()
16
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
17
+ self.net = vgg16(pretrained=True, requires_grad=False)
18
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23
+ self.load_from_pretrained()
24
+ for param in self.parameters():
25
+ param.requires_grad = False
26
+
27
+ def load_from_pretrained(self, name="vgg_lpips"):
28
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
31
+
32
+ @classmethod
33
+ def from_pretrained(cls, name="vgg_lpips"):
34
+ if name != "vgg_lpips":
35
+ raise NotImplementedError
36
+ model = cls()
37
+ ckpt = get_ckpt_path(name)
38
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39
+ return model
40
+
41
+ def forward(self, input, target):
42
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
44
+ feats0, feats1, diffs = {}, {}, {}
45
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46
+ for kk in range(len(self.chns)):
47
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49
+
50
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51
+ val = res[0]
52
+ for l in range(1, len(self.chns)):
53
+ val += res[l]
54
+ return val
55
+
56
+
57
+ class ScalingLayer(nn.Module):
58
+ def __init__(self):
59
+ super(ScalingLayer, self).__init__()
60
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62
+
63
+ def forward(self, inp):
64
+ return (inp - self.shift) / self.scale
65
+
66
+
67
+ class NetLinLayer(nn.Module):
68
+ """ A single linear layer which does a 1x1 conv """
69
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
70
+ super(NetLinLayer, self).__init__()
71
+ layers = [nn.Dropout(), ] if (use_dropout) else []
72
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73
+ self.model = nn.Sequential(*layers)
74
+
75
+
76
+ class vgg16(torch.nn.Module):
77
+ def __init__(self, requires_grad=False, pretrained=True):
78
+ super(vgg16, self).__init__()
79
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80
+ self.slice1 = torch.nn.Sequential()
81
+ self.slice2 = torch.nn.Sequential()
82
+ self.slice3 = torch.nn.Sequential()
83
+ self.slice4 = torch.nn.Sequential()
84
+ self.slice5 = torch.nn.Sequential()
85
+ self.N_slices = 5
86
+ for x in range(4):
87
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
88
+ for x in range(4, 9):
89
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
90
+ for x in range(9, 16):
91
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
92
+ for x in range(16, 23):
93
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
94
+ for x in range(23, 30):
95
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
96
+ if not requires_grad:
97
+ for param in self.parameters():
98
+ param.requires_grad = False
99
+
100
+ def forward(self, X):
101
+ h = self.slice1(X)
102
+ h_relu1_2 = h
103
+ h = self.slice2(h)
104
+ h_relu2_2 = h
105
+ h = self.slice3(h)
106
+ h_relu3_3 = h
107
+ h = self.slice4(h)
108
+ h_relu4_3 = h
109
+ h = self.slice5(h)
110
+ h_relu5_3 = h
111
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113
+ return out
114
+
115
+
116
+ def normalize_tensor(x,eps=1e-10):
117
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118
+ return x/(norm_factor+eps)
119
+
120
+
121
+ def spatial_average(x, keepdim=True):
122
+ return x.mean([2,3],keepdim=keepdim)
123
+
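A minimal sketch of calling the LPIPS metric above, assuming the taming package is importable. Note that constructing it downloads both the torchvision VGG16 weights and the vgg_lpips checkpoint via get_ckpt_path, so the first run needs network access (or cached files).

import torch
from taming.modules.losses.lpips import LPIPS

perceptual = LPIPS().eval()

img_a = torch.rand(1, 3, 224, 224) * 2 - 1   # inputs roughly in [-1, 1]
img_b = torch.rand(1, 3, 224, 224) * 2 - 1
dist = perceptual(img_a, img_b)              # perceptual distance, shape (1, 1, 1, 1)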
taming/modules/losses/segmentation.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class BCELoss(nn.Module):
6
+ def forward(self, prediction, target):
7
+ loss = F.binary_cross_entropy_with_logits(prediction,target)
8
+ return loss, {}
9
+
10
+
11
+ class BCELossWithQuant(nn.Module):
12
+ def __init__(self, codebook_weight=1.):
13
+ super().__init__()
14
+ self.codebook_weight = codebook_weight
15
+
16
+ def forward(self, qloss, target, prediction, split):
17
+ bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18
+ loss = bce_loss + self.codebook_weight*qloss
19
+ return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
21
+ "{}/quant_loss".format(split): qloss.detach().mean()
22
+ }
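A small sketch of how BCELossWithQuant above could be called, assuming the taming package is importable; the qloss value is a stand-in for a codebook/commitment loss produced elsewhere.

import torch
from taming.modules.losses.segmentation import BCELossWithQuant

criterion = BCELossWithQuant(codebook_weight=1.0)
logits = torch.randn(2, 1, 64, 64)                      # raw segmentation logits
target = torch.randint(0, 2, (2, 1, 64, 64)).float()    # binary mask
qloss = torch.tensor(0.1)                               # stand-in quantization loss
loss, log = criterion(qloss, target, logits, split="train")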
taming/modules/losses/vqperceptual.py ADDED
@@ -0,0 +1,182 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from taming.modules.losses.lpips import LPIPS
6
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+
8
+
9
+ class DummyLoss(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+
14
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
15
+ if global_step < threshold:
16
+ weight = value
17
+ return weight
18
+
19
+
20
+ def hinge_d_loss(logits_real, logits_fake):
21
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
22
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
23
+ d_loss = 0.5 * (loss_real + loss_fake)
24
+ return d_loss
25
+
26
+
27
+ def vanilla_d_loss(logits_real, logits_fake):
28
+ d_loss = 0.5 * (
29
+ torch.mean(torch.nn.functional.softplus(-logits_real))
30
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
31
+ )
32
+ return d_loss
33
+
34
+
35
+ class VQLPIPSWithDiscriminator(nn.Module):
36
+ def __init__(
37
+ self,
38
+ disc_start,
39
+ codebook_weight=1.0,
40
+ pixelloss_weight=1.0,
41
+ disc_num_layers=3,
42
+ disc_in_channels=3,
43
+ disc_factor=1.0,
44
+ disc_weight=1.0,
45
+ perceptual_weight=1.0,
46
+ use_actnorm=False,
47
+ disc_conditional=False,
48
+ disc_ndf=64,
49
+ disc_loss="hinge",
50
+ ):
51
+ super().__init__()
52
+ assert disc_loss in ["hinge", "vanilla"]
53
+ self.codebook_weight = codebook_weight
54
+ self.pixel_weight = pixelloss_weight
55
+ self.perceptual_loss = LPIPS().eval()
56
+ self.perceptual_weight = perceptual_weight
57
+
58
+ self.discriminator = NLayerDiscriminator(
59
+ input_nc=disc_in_channels,
60
+ n_layers=disc_num_layers,
61
+ use_actnorm=use_actnorm,
62
+ ndf=disc_ndf,
63
+ ).apply(weights_init)
64
+ self.discriminator_iter_start = disc_start
65
+ if disc_loss == "hinge":
66
+ self.disc_loss = hinge_d_loss
67
+ elif disc_loss == "vanilla":
68
+ self.disc_loss = vanilla_d_loss
69
+ else:
70
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
71
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
72
+ self.disc_factor = disc_factor
73
+ self.discriminator_weight = disc_weight
74
+ self.disc_conditional = disc_conditional
75
+
76
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
77
+ if last_layer is not None:
78
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
79
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
80
+ else:
81
+ nll_grads = torch.autograd.grad(
82
+ nll_loss, self.last_layer[0], retain_graph=True
83
+ )[0]
84
+ g_grads = torch.autograd.grad(
85
+ g_loss, self.last_layer[0], retain_graph=True
86
+ )[0]
87
+
88
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
89
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
90
+ d_weight = d_weight * self.discriminator_weight
91
+ return d_weight
92
+
93
+ def forward(
94
+ self,
95
+ codebook_loss,
96
+ inputs,
97
+ reconstructions,
98
+ optimizer_idx,
99
+ global_step,
100
+ last_layer=None,
101
+ cond=None,
102
+ split="train",
103
+ ):
104
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
105
+ if self.perceptual_weight > 0:
106
+ p_loss = self.perceptual_loss(
107
+ inputs.contiguous(), reconstructions.contiguous()
108
+ )
109
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
110
+ else:
111
+ p_loss = torch.tensor([0.0])
112
+
113
+ nll_loss = rec_loss
114
+ # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
115
+ nll_loss = torch.mean(nll_loss)
116
+
117
+ # now the GAN part
118
+ if optimizer_idx == 0:
119
+ # generator update
120
+ if cond is None:
121
+ assert not self.disc_conditional
122
+ logits_fake = self.discriminator(reconstructions.contiguous())
123
+ else:
124
+ assert self.disc_conditional
125
+ logits_fake = self.discriminator(
126
+ torch.cat((reconstructions.contiguous(), cond), dim=1)
127
+ )
128
+ g_loss = -torch.mean(logits_fake)
129
+
130
+ try:
131
+ d_weight = self.calculate_adaptive_weight(
132
+ nll_loss, g_loss, last_layer=last_layer
133
+ )
134
+ except RuntimeError:
135
+ assert not self.training
136
+ d_weight = torch.tensor(0.0)
137
+
138
+ disc_factor = adopt_weight(
139
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
140
+ )
141
+ loss = (
142
+ nll_loss
143
+ + d_weight * disc_factor * g_loss
144
+ + self.codebook_weight * codebook_loss.mean()
145
+ )
146
+
147
+ log = {
148
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
149
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
150
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
151
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
152
+ "{}/p_loss".format(split): p_loss.detach().mean(),
153
+ "{}/d_weight".format(split): d_weight.detach(),
154
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
155
+ "{}/g_loss".format(split): g_loss.detach().mean(),
156
+ }
157
+ return loss, log
158
+
159
+ if optimizer_idx == 1:
160
+ # second pass for discriminator update
161
+ if cond is None:
162
+ logits_real = self.discriminator(inputs.contiguous().detach())
163
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
164
+ else:
165
+ logits_real = self.discriminator(
166
+ torch.cat((inputs.contiguous().detach(), cond), dim=1)
167
+ )
168
+ logits_fake = self.discriminator(
169
+ torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
170
+ )
171
+
172
+ disc_factor = adopt_weight(
173
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
174
+ )
175
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
176
+
177
+ log = {
178
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
179
+ "{}/logits_real".format(split): logits_real.detach().mean(),
180
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
181
+ }
182
+ return d_loss, log
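The helper functions above are small and pure, so they are easy to sanity-check in isolation. A minimal sketch, assuming the taming package is importable: adopt_weight keeps the GAN term at zero until disc_start steps have passed, and hinge_d_loss is the default discriminator objective.

import torch
from taming.modules.losses.vqperceptual import hinge_d_loss, adopt_weight

logits_real = torch.randn(4, 1, 30, 30) + 1.0
logits_fake = torch.randn(4, 1, 30, 30) - 1.0
d_loss = hinge_d_loss(logits_real, logits_fake)   # 0.5 * (mean relu(1-real) + mean relu(1+fake))

factor_early = adopt_weight(1.0, global_step=100, threshold=250)   # 0.0: discriminator not yet active
factor_late = adopt_weight(1.0, global_step=500, threshold=250)    # 1.0: discriminator term enabled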
taming/modules/misc/coord.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+
3
+ class CoordStage(object):
4
+ def __init__(self, n_embed, down_factor):
5
+ self.n_embed = n_embed
6
+ self.down_factor = down_factor
7
+
8
+ def eval(self):
9
+ return self
10
+
11
+ def encode(self, c):
12
+ """fake vqmodel interface"""
13
+ assert 0.0 <= c.min() and c.max() <= 1.0
14
+ b,ch,h,w = c.shape
15
+ assert ch == 1
16
+
17
+ c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18
+ mode="area")
19
+ c = c.clamp(0.0, 1.0)
20
+ c = self.n_embed*c
21
+ c_quant = c.round()
22
+ c_ind = c_quant.to(dtype=torch.long)
23
+
24
+ info = None, None, c_ind
25
+ return c_quant, None, info
26
+
27
+ def decode(self, c):
28
+ c = c/self.n_embed
29
+ c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30
+ mode="nearest")
31
+ return c
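A quick sketch of the CoordStage round trip above (quantize a [0, 1] coordinate map to indices, then upsample it back), assuming the taming package is importable; the sizes are arbitrary.

import torch
from taming.modules.misc.coord import CoordStage

stage = CoordStage(n_embed=1024, down_factor=16)
coords = torch.rand(1, 1, 256, 256)                 # coordinate map scaled to [0, 1]
c_quant, _, (_, _, c_ind) = stage.encode(coords)    # (1, 1, 16, 16) quantized map plus long indices
coarse = stage.decode(c_quant)                      # nearest-neighbour upsample back to 256x256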
taming/modules/transformer/mingpt.py ADDED
@@ -0,0 +1,415 @@
1
+ """
2
+ taken from: https://github.com/karpathy/minGPT/
3
+ GPT model:
4
+ - the initial stem consists of a combination of token encoding and a positional encoding
5
+ - the meat of it is a uniform sequence of Transformer blocks
6
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
7
+ - all blocks feed into a central residual pathway similar to resnets
8
+ - the final decoder is a linear projection into a vanilla Softmax classifier
9
+ """
10
+
11
+ import math
12
+ import logging
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+ from transformers import top_k_top_p_filtering
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class GPTConfig:
23
+ """ base GPT config, params common to all GPT versions """
24
+ embd_pdrop = 0.1
25
+ resid_pdrop = 0.1
26
+ attn_pdrop = 0.1
27
+
28
+ def __init__(self, vocab_size, block_size, **kwargs):
29
+ self.vocab_size = vocab_size
30
+ self.block_size = block_size
31
+ for k,v in kwargs.items():
32
+ setattr(self, k, v)
33
+
34
+
35
+ class GPT1Config(GPTConfig):
36
+ """ GPT-1 like network roughly 125M params """
37
+ n_layer = 12
38
+ n_head = 12
39
+ n_embd = 768
40
+
41
+
42
+ class CausalSelfAttention(nn.Module):
43
+ """
44
+ A vanilla multi-head masked self-attention layer with a projection at the end.
45
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
46
+ explicit implementation here to show that there is nothing too scary here.
47
+ """
48
+
49
+ def __init__(self, config):
50
+ super().__init__()
51
+ assert config.n_embd % config.n_head == 0
52
+ # key, query, value projections for all heads
53
+ self.key = nn.Linear(config.n_embd, config.n_embd)
54
+ self.query = nn.Linear(config.n_embd, config.n_embd)
55
+ self.value = nn.Linear(config.n_embd, config.n_embd)
56
+ # regularization
57
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
58
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
59
+ # output projection
60
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
61
+ # causal mask to ensure that attention is only applied to the left in the input sequence
62
+ mask = torch.tril(torch.ones(config.block_size,
63
+ config.block_size))
64
+ if hasattr(config, "n_unmasked"):
65
+ mask[:config.n_unmasked, :config.n_unmasked] = 1
66
+ self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
67
+ self.n_head = config.n_head
68
+
69
+ def forward(self, x, layer_past=None):
70
+ B, T, C = x.size()
71
+
72
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
73
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
74
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
75
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
76
+
77
+ present = torch.stack((k, v))
78
+ if layer_past is not None:
79
+ past_key, past_value = layer_past
80
+ k = torch.cat((past_key, k), dim=-2)
81
+ v = torch.cat((past_value, v), dim=-2)
82
+
83
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
84
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
85
+ if layer_past is None:
86
+ att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
87
+
88
+ att = F.softmax(att, dim=-1)
89
+ att = self.attn_drop(att)
90
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
91
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
92
+
93
+ # output projection
94
+ y = self.resid_drop(self.proj(y))
95
+ return y, present # TODO: check that this does not break anything
96
+
97
+
98
+ class Block(nn.Module):
99
+ """ an unassuming Transformer block """
100
+ def __init__(self, config):
101
+ super().__init__()
102
+ self.ln1 = nn.LayerNorm(config.n_embd)
103
+ self.ln2 = nn.LayerNorm(config.n_embd)
104
+ self.attn = CausalSelfAttention(config)
105
+ self.mlp = nn.Sequential(
106
+ nn.Linear(config.n_embd, 4 * config.n_embd),
107
+ nn.GELU(), # nice
108
+ nn.Linear(4 * config.n_embd, config.n_embd),
109
+ nn.Dropout(config.resid_pdrop),
110
+ )
111
+
112
+ def forward(self, x, layer_past=None, return_present=False):
113
+ # TODO: check that training still works
114
+ if return_present: assert not self.training
115
+ # layer past: tuple of length two with B, nh, T, hs
116
+ attn, present = self.attn(self.ln1(x), layer_past=layer_past)
117
+
118
+ x = x + attn
119
+ x = x + self.mlp(self.ln2(x))
120
+ if layer_past is not None or return_present:
121
+ return x, present
122
+ return x
123
+
124
+
125
+ class GPT(nn.Module):
126
+ """ the full GPT language model, with a context size of block_size """
127
+ def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
128
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
129
+ super().__init__()
130
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
131
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
132
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
133
+ n_unmasked=n_unmasked)
134
+ # input embedding stem
135
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
136
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
137
+ self.drop = nn.Dropout(config.embd_pdrop)
138
+ # transformer
139
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
140
+ # decoder head
141
+ self.ln_f = nn.LayerNorm(config.n_embd)
142
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
143
+ self.block_size = config.block_size
144
+ self.apply(self._init_weights)
145
+ self.config = config
146
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
147
+
148
+ def get_block_size(self):
149
+ return self.block_size
150
+
151
+ def _init_weights(self, module):
152
+ if isinstance(module, (nn.Linear, nn.Embedding)):
153
+ module.weight.data.normal_(mean=0.0, std=0.02)
154
+ if isinstance(module, nn.Linear) and module.bias is not None:
155
+ module.bias.data.zero_()
156
+ elif isinstance(module, nn.LayerNorm):
157
+ module.bias.data.zero_()
158
+ module.weight.data.fill_(1.0)
159
+
160
+ def forward(self, idx, embeddings=None, targets=None):
161
+ # forward the GPT model
162
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
163
+
164
+ if embeddings is not None: # prepend explicit embeddings
165
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
166
+
167
+ t = token_embeddings.shape[1]
168
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
169
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
170
+ x = self.drop(token_embeddings + position_embeddings)
171
+ x = self.blocks(x)
172
+ x = self.ln_f(x)
173
+ logits = self.head(x)
174
+
175
+ # if we are given some desired targets also calculate the loss
176
+ loss = None
177
+ if targets is not None:
178
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
179
+
180
+ return logits, loss
181
+
182
+ def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
183
+ # inference only
184
+ assert not self.training
185
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
186
+ if embeddings is not None: # prepend explicit embeddings
187
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
188
+
189
+ if past is not None:
190
+ assert past_length is not None
191
+ past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
192
+ past_shape = list(past.shape)
193
+ expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
194
+ assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
195
+ position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
196
+ else:
197
+ position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
198
+
199
+ x = self.drop(token_embeddings + position_embeddings)
200
+ presents = [] # accumulate over layers
201
+ for i, block in enumerate(self.blocks):
202
+ x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
203
+ presents.append(present)
204
+
205
+ x = self.ln_f(x)
206
+ logits = self.head(x)
207
+ # if we are given some desired targets also calculate the loss
208
+ loss = None
209
+ if targets is not None:
210
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
211
+
212
+ return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
213
+
214
+
215
+ class DummyGPT(nn.Module):
216
+ # for debugging
217
+ def __init__(self, add_value=1):
218
+ super().__init__()
219
+ self.add_value = add_value
220
+
221
+ def forward(self, idx):
222
+ return idx + self.add_value, None
223
+
224
+
225
+ class CodeGPT(nn.Module):
226
+ """Takes in semi-embeddings"""
227
+ def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
228
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
229
+ super().__init__()
230
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
231
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
232
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
233
+ n_unmasked=n_unmasked)
234
+ # input embedding stem
235
+ self.tok_emb = nn.Linear(in_channels, config.n_embd)
236
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
237
+ self.drop = nn.Dropout(config.embd_pdrop)
238
+ # transformer
239
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
240
+ # decoder head
241
+ self.ln_f = nn.LayerNorm(config.n_embd)
242
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
243
+ self.block_size = config.block_size
244
+ self.apply(self._init_weights)
245
+ self.config = config
246
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
247
+
248
+ def get_block_size(self):
249
+ return self.block_size
250
+
251
+ def _init_weights(self, module):
252
+ if isinstance(module, (nn.Linear, nn.Embedding)):
253
+ module.weight.data.normal_(mean=0.0, std=0.02)
254
+ if isinstance(module, nn.Linear) and module.bias is not None:
255
+ module.bias.data.zero_()
256
+ elif isinstance(module, nn.LayerNorm):
257
+ module.bias.data.zero_()
258
+ module.weight.data.fill_(1.0)
259
+
260
+ def forward(self, idx, embeddings=None, targets=None):
261
+ # forward the GPT model
262
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
263
+
264
+ if embeddings is not None: # prepend explicit embeddings
265
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
266
+
267
+ t = token_embeddings.shape[1]
268
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
269
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
270
+ x = self.drop(token_embeddings + position_embeddings)
271
+ x = self.blocks(x)
272
+ x = self.ln_f(x)
273
+ logits = self.head(x)
274
+
275
+ # if we are given some desired targets also calculate the loss
276
+ loss = None
277
+ if targets is not None:
278
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
279
+
280
+ return logits, loss
281
+
282
+
283
+
284
+ #### sampling utils
285
+
286
+ def top_k_logits(logits, k):
287
+ v, ix = torch.topk(logits, k)
288
+ out = logits.clone()
289
+ out[out < v[:, [-1]]] = -float('Inf')
290
+ return out
291
+
292
+ @torch.no_grad()
293
+ def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
294
+ """
295
+ take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
296
+ the sequence, feeding the predictions back into the model each time. Clearly the sampling
297
+ has quadratic complexity unlike an RNN that is only linear, and has a finite context window
298
+ of block_size, unlike an RNN that has an infinite context window.
299
+ """
300
+ block_size = model.get_block_size()
301
+ model.eval()
302
+ for k in range(steps):
303
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
304
+ logits, _ = model(x_cond)
305
+ # pluck the logits at the final step and scale by temperature
306
+ logits = logits[:, -1, :] / temperature
307
+ # optionally crop probabilities to only the top k options
308
+ if top_k is not None:
309
+ logits = top_k_logits(logits, top_k)
310
+ # apply softmax to convert to probabilities
311
+ probs = F.softmax(logits, dim=-1)
312
+ # sample from the distribution or take the most likely
313
+ if sample:
314
+ ix = torch.multinomial(probs, num_samples=1)
315
+ else:
316
+ _, ix = torch.topk(probs, k=1, dim=-1)
317
+ # append to the sequence and continue
318
+ x = torch.cat((x, ix), dim=1)
319
+
320
+ return x
321
+
322
+
323
+ @torch.no_grad()
324
+ def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
325
+ top_k=None, top_p=None, callback=None):
326
+ # x is conditioning
327
+ sample = x
328
+ cond_len = x.shape[1]
329
+ past = None
330
+ for n in range(steps):
331
+ if callback is not None:
332
+ callback(n)
333
+ logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
334
+ if past is None:
335
+ past = [present]
336
+ else:
337
+ past.append(present)
338
+ logits = logits[:, -1, :] / temperature
339
+ if top_k is not None:
340
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
341
+
342
+ probs = F.softmax(logits, dim=-1)
343
+ if not sample_logits:
344
+ _, x = torch.topk(probs, k=1, dim=-1)
345
+ else:
346
+ x = torch.multinomial(probs, num_samples=1)
347
+ # append to the sequence and continue
348
+ sample = torch.cat((sample, x), dim=1)
349
+ del past
350
+ sample = sample[:, cond_len:] # cut conditioning off
351
+ return sample
352
+
353
+
354
+ #### clustering utils
355
+
356
+ class KMeans(nn.Module):
357
+ def __init__(self, ncluster=512, nc=3, niter=10):
358
+ super().__init__()
359
+ self.ncluster = ncluster
360
+ self.nc = nc
361
+ self.niter = niter
362
+ self.shape = (3,32,32)
363
+ self.register_buffer("C", torch.zeros(self.ncluster,nc))
364
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
365
+
366
+ def is_initialized(self):
367
+ return self.initialized.item() == 1
368
+
369
+ @torch.no_grad()
370
+ def initialize(self, x):
371
+ N, D = x.shape
372
+ assert D == self.nc, D
373
+ c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
374
+ for i in range(self.niter):
375
+ # assign all pixels to the closest codebook element
376
+ a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
377
+ # move each codebook element to be the mean of the pixels that assigned to it
378
+ c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
379
+ # re-assign any poorly positioned codebook elements
380
+ nanix = torch.any(torch.isnan(c), dim=1)
381
+ ndead = nanix.sum().item()
382
+ print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
383
+ c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
384
+
385
+ self.C.copy_(c)
386
+ self.initialized.fill_(1)
387
+
388
+
389
+ def forward(self, x, reverse=False, shape=None):
390
+ if not reverse:
391
+ # flatten
392
+ bs,c,h,w = x.shape
393
+ assert c == self.nc
394
+ x = x.reshape(bs,c,h*w,1)
395
+ C = self.C.permute(1,0)
396
+ C = C.reshape(1,c,1,self.ncluster)
397
+ a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
398
+ return a
399
+ else:
400
+ # flatten
401
+ bs, HW = x.shape
402
+ """
403
+ c = self.C.reshape( 1, self.nc, 1, self.ncluster)
404
+ c = c[bs*[0],:,:,:]
405
+ c = c[:,:,HW*[0],:]
406
+ x = x.reshape(bs, 1, HW, 1)
407
+ x = x[:,3*[0],:,:]
408
+ x = torch.gather(c, dim=3, index=x)
409
+ """
410
+ x = self.C[x]
411
+ x = x.permute(0,2,1)
412
+ shape = shape if shape is not None else self.shape
413
+ x = x.reshape(bs, *shape)
414
+
415
+ return x
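A minimal sketch of driving the GPT class and the sample helper above, assuming the taming package is importable and that the installed transformers version still exposes top_k_top_p_filtering (imported at the top of this module). The sizes are toy values, not the configuration used by any checkpoint.

import torch
from taming.modules.transformer.mingpt import GPT, sample

model = GPT(vocab_size=1024, block_size=256, n_layer=4, n_head=4, n_embd=128)

idx = torch.randint(0, 1024, (1, 16))       # conditioning token indices
logits, loss = model(idx)                   # loss is None when no targets are given
out = sample(model, idx, steps=32, temperature=1.0, sample=True, top_k=50)
print(out.shape)                            # torch.Size([1, 48]): conditioning plus 32 new tokens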
taming/modules/transformer/permuter.py ADDED
@@ -0,0 +1,248 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
+ class AbstractPermuter(nn.Module):
7
+ def __init__(self, *args, **kwargs):
8
+ super().__init__()
9
+ def forward(self, x, reverse=False):
10
+ raise NotImplementedError
11
+
12
+
13
+ class Identity(AbstractPermuter):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def forward(self, x, reverse=False):
18
+ return x
19
+
20
+
21
+ class Subsample(AbstractPermuter):
22
+ def __init__(self, H, W):
23
+ super().__init__()
24
+ C = 1
25
+ indices = np.arange(H*W).reshape(C,H,W)
26
+ while min(H, W) > 1:
27
+ indices = indices.reshape(C,H//2,2,W//2,2)
28
+ indices = indices.transpose(0,2,4,1,3)
29
+ indices = indices.reshape(C*4,H//2, W//2)
30
+ H = H//2
31
+ W = W//2
32
+ C = C*4
33
+ assert H == W == 1
34
+ idx = torch.tensor(indices.ravel())
35
+ self.register_buffer('forward_shuffle_idx',
36
+ nn.Parameter(idx, requires_grad=False))
37
+ self.register_buffer('backward_shuffle_idx',
38
+ nn.Parameter(torch.argsort(idx), requires_grad=False))
39
+
40
+ def forward(self, x, reverse=False):
41
+ if not reverse:
42
+ return x[:, self.forward_shuffle_idx]
43
+ else:
44
+ return x[:, self.backward_shuffle_idx]
45
+
46
+
47
+ def mortonify(i, j):
48
+ """(i,j) index to linear morton code"""
49
+ i = np.uint64(i)
50
+ j = np.uint64(j)
51
+
52
+ z = np.uint(0)
53
+
54
+ for pos in range(32):
55
+ z = (z |
56
+ ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
57
+ ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
58
+ )
59
+ return z
60
+
61
+
62
+ class ZCurve(AbstractPermuter):
63
+ def __init__(self, H, W):
64
+ super().__init__()
65
+ reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
66
+ idx = np.argsort(reverseidx)
67
+ idx = torch.tensor(idx)
68
+ reverseidx = torch.tensor(reverseidx)
69
+ self.register_buffer('forward_shuffle_idx',
70
+ idx)
71
+ self.register_buffer('backward_shuffle_idx',
72
+ reverseidx)
73
+
74
+ def forward(self, x, reverse=False):
75
+ if not reverse:
76
+ return x[:, self.forward_shuffle_idx]
77
+ else:
78
+ return x[:, self.backward_shuffle_idx]
79
+
80
+
81
+ class SpiralOut(AbstractPermuter):
82
+ def __init__(self, H, W):
83
+ super().__init__()
84
+ assert H == W
85
+ size = W
86
+ indices = np.arange(size*size).reshape(size,size)
87
+
88
+ i0 = size//2
89
+ j0 = size//2-1
90
+
91
+ i = i0
92
+ j = j0
93
+
94
+ idx = [indices[i0, j0]]
95
+ step_mult = 0
96
+ for c in range(1, size//2+1):
97
+ step_mult += 1
98
+ # steps left
99
+ for k in range(step_mult):
100
+ i = i - 1
101
+ j = j
102
+ idx.append(indices[i, j])
103
+
104
+ # step down
105
+ for k in range(step_mult):
106
+ i = i
107
+ j = j + 1
108
+ idx.append(indices[i, j])
109
+
110
+ step_mult += 1
111
+ if c < size//2:
112
+ # step right
113
+ for k in range(step_mult):
114
+ i = i + 1
115
+ j = j
116
+ idx.append(indices[i, j])
117
+
118
+ # step up
119
+ for k in range(step_mult):
120
+ i = i
121
+ j = j - 1
122
+ idx.append(indices[i, j])
123
+ else:
124
+ # end reached
125
+ for k in range(step_mult-1):
126
+ i = i + 1
127
+ idx.append(indices[i, j])
128
+
129
+ assert len(idx) == size*size
130
+ idx = torch.tensor(idx)
131
+ self.register_buffer('forward_shuffle_idx', idx)
132
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
133
+
134
+ def forward(self, x, reverse=False):
135
+ if not reverse:
136
+ return x[:, self.forward_shuffle_idx]
137
+ else:
138
+ return x[:, self.backward_shuffle_idx]
139
+
140
+
141
+ class SpiralIn(AbstractPermuter):
142
+ def __init__(self, H, W):
143
+ super().__init__()
144
+ assert H == W
145
+ size = W
146
+ indices = np.arange(size*size).reshape(size,size)
147
+
148
+ i0 = size//2
149
+ j0 = size//2-1
150
+
151
+ i = i0
152
+ j = j0
153
+
154
+ idx = [indices[i0, j0]]
155
+ step_mult = 0
156
+ for c in range(1, size//2+1):
157
+ step_mult += 1
158
+ # steps left
159
+ for k in range(step_mult):
160
+ i = i - 1
161
+ j = j
162
+ idx.append(indices[i, j])
163
+
164
+ # step down
165
+ for k in range(step_mult):
166
+ i = i
167
+ j = j + 1
168
+ idx.append(indices[i, j])
169
+
170
+ step_mult += 1
171
+ if c < size//2:
172
+ # step right
173
+ for k in range(step_mult):
174
+ i = i + 1
175
+ j = j
176
+ idx.append(indices[i, j])
177
+
178
+ # step up
179
+ for k in range(step_mult):
180
+ i = i
181
+ j = j - 1
182
+ idx.append(indices[i, j])
183
+ else:
184
+ # end reached
185
+ for k in range(step_mult-1):
186
+ i = i + 1
187
+ idx.append(indices[i, j])
188
+
189
+ assert len(idx) == size*size
190
+ idx = idx[::-1]
191
+ idx = torch.tensor(idx)
192
+ self.register_buffer('forward_shuffle_idx', idx)
193
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
194
+
195
+ def forward(self, x, reverse=False):
196
+ if not reverse:
197
+ return x[:, self.forward_shuffle_idx]
198
+ else:
199
+ return x[:, self.backward_shuffle_idx]
200
+
201
+
202
+ class Random(nn.Module):
203
+ def __init__(self, H, W):
204
+ super().__init__()
205
+ indices = np.random.RandomState(1).permutation(H*W)
206
+ idx = torch.tensor(indices.ravel())
207
+ self.register_buffer('forward_shuffle_idx', idx)
208
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
209
+
210
+ def forward(self, x, reverse=False):
211
+ if not reverse:
212
+ return x[:, self.forward_shuffle_idx]
213
+ else:
214
+ return x[:, self.backward_shuffle_idx]
215
+
216
+
217
+ class AlternateParsing(AbstractPermuter):
218
+ def __init__(self, H, W):
219
+ super().__init__()
220
+ indices = np.arange(W*H).reshape(H,W)
221
+ for i in range(1, H, 2):
222
+ indices[i, :] = indices[i, ::-1]
223
+ idx = indices.flatten()
224
+ assert len(idx) == H*W
225
+ idx = torch.tensor(idx)
226
+ self.register_buffer('forward_shuffle_idx', idx)
227
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
228
+
229
+ def forward(self, x, reverse=False):
230
+ if not reverse:
231
+ return x[:, self.forward_shuffle_idx]
232
+ else:
233
+ return x[:, self.backward_shuffle_idx]
234
+
235
+
236
+ if __name__ == "__main__":
237
+ p0 = AlternateParsing(16, 16)
238
+ print(p0.forward_shuffle_idx)
239
+ print(p0.backward_shuffle_idx)
240
+
241
+ x = torch.randint(0, 768, size=(11, 256))
242
+ y = p0(x)
243
+ xre = p0(y, reverse=True)
244
+ assert torch.equal(x, xre)
245
+
246
+ p1 = SpiralOut(2, 2)
247
+ print(p1.forward_shuffle_idx)
248
+ print(p1.backward_shuffle_idx)
taming/modules/util.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def count_params(model):
6
+ total_params = sum(p.numel() for p in model.parameters())
7
+ return total_params
8
+
9
+
10
+ class ActNorm(nn.Module):
11
+ def __init__(self, num_features, logdet=False, affine=True,
12
+ allow_reverse_init=False):
13
+ assert affine
14
+ super().__init__()
15
+ self.logdet = logdet
16
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18
+ self.allow_reverse_init = allow_reverse_init
19
+
20
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21
+
22
+ def initialize(self, input):
23
+ with torch.no_grad():
24
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25
+ mean = (
26
+ flatten.mean(1)
27
+ .unsqueeze(1)
28
+ .unsqueeze(2)
29
+ .unsqueeze(3)
30
+ .permute(1, 0, 2, 3)
31
+ )
32
+ std = (
33
+ flatten.std(1)
34
+ .unsqueeze(1)
35
+ .unsqueeze(2)
36
+ .unsqueeze(3)
37
+ .permute(1, 0, 2, 3)
38
+ )
39
+
40
+ self.loc.data.copy_(-mean)
41
+ self.scale.data.copy_(1 / (std + 1e-6))
42
+
43
+ def forward(self, input, reverse=False):
44
+ if reverse:
45
+ return self.reverse(input)
46
+ if len(input.shape) == 2:
47
+ input = input[:,:,None,None]
48
+ squeeze = True
49
+ else:
50
+ squeeze = False
51
+
52
+ _, _, height, width = input.shape
53
+
54
+ if self.training and self.initialized.item() == 0:
55
+ self.initialize(input)
56
+ self.initialized.fill_(1)
57
+
58
+ h = self.scale * (input + self.loc)
59
+
60
+ if squeeze:
61
+ h = h.squeeze(-1).squeeze(-1)
62
+
63
+ if self.logdet:
64
+ log_abs = torch.log(torch.abs(self.scale))
65
+ logdet = height*width*torch.sum(log_abs)
66
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
67
+ return h, logdet
68
+
69
+ return h
70
+
71
+ def reverse(self, output):
72
+ if self.training and self.initialized.item() == 0:
73
+ if not self.allow_reverse_init:
74
+ raise RuntimeError(
75
+ "Initializing ActNorm in reverse direction is "
76
+ "disabled by default. Use allow_reverse_init=True to enable."
77
+ )
78
+ else:
79
+ self.initialize(output)
80
+ self.initialized.fill_(1)
81
+
82
+ if len(output.shape) == 2:
83
+ output = output[:,:,None,None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ h = output / self.scale - self.loc
89
+
90
+ if squeeze:
91
+ h = h.squeeze(-1).squeeze(-1)
92
+ return h
93
+
94
+
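ActNorm above is a per-channel affine layer whose `loc` and `scale` are filled in from the statistics of the first training batch it sees (data-dependent initialization), optionally returning the log-determinant of the transform. A small usage sketch, assuming the class imports from `taming.modules.util`:

import torch
from taming.modules.util import ActNorm

norm = ActNorm(num_features=8, logdet=True)
x = torch.randn(4, 8, 16, 16)

norm.train()
y, logdet = norm(x)            # first call initializes loc/scale from this batch
print(y.shape, logdet.shape)   # torch.Size([4, 8, 16, 16]) torch.Size([4])
print(y.mean(dim=(0, 2, 3)))   # roughly zero per channel after initialization

x_rec = norm(y, reverse=True)  # invert the affine map
print(torch.allclose(x, x_rec, atol=1e-4))   # True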
95
+ class AbstractEncoder(nn.Module):
96
+ def __init__(self):
97
+ super().__init__()
98
+
99
+ def encode(self, *args, **kwargs):
100
+ raise NotImplementedError
101
+
102
+
103
+ class Labelator(AbstractEncoder):
104
+ """Net2Net Interface for Class-Conditional Model"""
105
+ def __init__(self, n_classes, quantize_interface=True):
106
+ super().__init__()
107
+ self.n_classes = n_classes
108
+ self.quantize_interface = quantize_interface
109
+
110
+ def encode(self, c):
111
+ c = c[:,None]
112
+ if self.quantize_interface:
113
+ return c, None, [None, None, c.long()]
114
+ return c
115
+
116
+
117
+ class SOSProvider(AbstractEncoder):
118
+ # for unconditional training
119
+ def __init__(self, sos_token, quantize_interface=True):
120
+ super().__init__()
121
+ self.sos_token = sos_token
122
+ self.quantize_interface = quantize_interface
123
+
124
+ def encode(self, x):
125
+ # get batch size from data and replicate sos_token
126
+ c = torch.ones(x.shape[0], 1)*self.sos_token
127
+ c = c.long().to(x.device)
128
+ if self.quantize_interface:
129
+ return c, None, [None, None, c]
130
+ return c
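Labelator and SOSProvider mimic the `encode` interface of the quantized autoencoders: the last element of the returned tuple carries discrete indices, so the downstream transformer can condition on class labels or a start-of-sequence token exactly as it does on image codes. A brief shape check, under the same import assumption as above:

import torch
from taming.modules.util import Labelator, SOSProvider

labels = torch.tensor([3, 7])                 # class ids for a batch of two
lab = Labelator(n_classes=10)
c, _, (_, _, c_idx) = lab.encode(labels)
print(c.shape, c_idx.shape)                   # torch.Size([2, 1]) torch.Size([2, 1])

x = torch.randint(0, 1024, (2, 256))          # dummy image token sequences
sos = SOSProvider(sos_token=0)
c, _, (_, _, c_idx) = sos.encode(x)
print(c_idx)                                  # tensor([[0], [0]]) - one SOS token per sample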
taming/modules/vqvae/quantize.py ADDED
@@ -0,0 +1,445 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from torch import einsum
6
+ from einops import rearrange
7
+
8
+
9
+ class VectorQuantizer(nn.Module):
10
+ """
11
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
12
+ ____________________________________________
13
+ Discretization bottleneck part of the VQ-VAE.
14
+ Inputs:
15
+ - n_e : number of embeddings
16
+ - e_dim : dimension of embedding
17
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
18
+ _____________________________________________
19
+ """
20
+
21
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
22
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
23
+ # used wherever VectorQuantizer has been used before and is additionally
24
+ # more efficient.
25
+ def __init__(self, n_e, e_dim, beta):
26
+ super(VectorQuantizer, self).__init__()
27
+ self.n_e = n_e
28
+ self.e_dim = e_dim
29
+ self.beta = beta
30
+
31
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
32
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
33
+
34
+ def forward(self, z):
35
+ """
36
+ Inputs the output of the encoder network z and maps it to a discrete
37
+ one-hot vector that is the index of the closest embedding vector e_j
38
+ z (continuous) -> z_q (discrete)
39
+ z.shape = (batch, channel, height, width)
40
+ quantization pipeline:
41
+ 1. get encoder input (B,C,H,W)
42
+ 2. flatten input to (B*H*W,C)
43
+ """
44
+ # reshape z -> (batch, height, width, channel) and flatten
45
+ z = z.permute(0, 2, 3, 1).contiguous()
46
+ z_flattened = z.view(-1, self.e_dim)
47
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
48
+
49
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
50
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
51
+ torch.matmul(z_flattened, self.embedding.weight.t())
52
+
53
+ ## could possibly replace this here
54
+ # #\start...
55
+ # find closest encodings
56
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
57
+
58
+ min_encodings = torch.zeros(
59
+ min_encoding_indices.shape[0], self.n_e).to(z)
60
+ min_encodings.scatter_(1, min_encoding_indices, 1)
61
+
62
+ # dtype min encodings: torch.float32
63
+ # min_encodings shape: torch.Size([2048, 512])
64
+ # min_encoding_indices.shape: torch.Size([2048, 1])
65
+
66
+ # get quantized latent vectors
67
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
68
+ #.........\end
69
+
70
+ # with:
71
+ # .........\start
72
+ #min_encoding_indices = torch.argmin(d, dim=1)
73
+ #z_q = self.embedding(min_encoding_indices)
74
+ # ......\end......... (TODO)
75
+
76
+ # compute loss for embedding
77
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
78
+ torch.mean((z_q - z.detach()) ** 2)
79
+
80
+ # preserve gradients
81
+ z_q = z + (z_q - z).detach()
82
+
83
+ # perplexity
84
+ e_mean = torch.mean(min_encodings, dim=0)
85
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
86
+
87
+ # reshape back to match original input shape
88
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
89
+
90
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
91
+
92
+ def get_codebook_entry(self, indices, shape):
93
+ # shape specifying (batch, height, width, channel)
94
+ # TODO: check for more easy handling with nn.Embedding
95
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
96
+ min_encodings.scatter_(1, indices[:,None], 1)
97
+
98
+ # get quantized latent vectors
99
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
100
+
101
+ if shape is not None:
102
+ z_q = z_q.view(shape)
103
+
104
+ # reshape back to match original input shape
105
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
106
+
107
+ return z_q
108
+
109
+
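As the docstring above describes, VectorQuantizer snaps every spatial feature vector to its nearest codebook entry and wires the result with a straight-through estimator so gradients can reach the encoder. A minimal round trip, assuming the module imports as `taming.modules.vqvae.quantize`:

import torch
from taming.modules.vqvae.quantize import VectorQuantizer

quantizer = VectorQuantizer(n_e=1024, e_dim=64, beta=0.25)
z = torch.randn(2, 64, 16, 16, requires_grad=True)   # encoder output (B, C, H, W)

z_q, loss, (perplexity, min_encodings, indices) = quantizer(z)
print(z_q.shape)       # torch.Size([2, 64, 16, 16]) - same shape as the input
print(indices.shape)   # torch.Size([512, 1]) - one codebook index per spatial position (2*16*16)

# `loss` is the codebook/commitment term; z_q itself carries straight-through
# gradients, so a reconstruction loss on z_q would also reach z.
(loss + z_q.mean()).backward()
print(z.grad is not None)   # True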
110
+ class GumbelQuantize(nn.Module):
111
+ """
112
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
113
+ Gumbel Softmax trick quantizer
114
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
115
+ https://arxiv.org/abs/1611.01144
116
+ """
117
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
118
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
119
+ remap=None, unknown_index="random"):
120
+ super().__init__()
121
+
122
+ self.embedding_dim = embedding_dim
123
+ self.n_embed = n_embed
124
+
125
+ self.straight_through = straight_through
126
+ self.temperature = temp_init
127
+ self.kl_weight = kl_weight
128
+
129
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
130
+ self.embed = nn.Embedding(n_embed, embedding_dim)
131
+
132
+ self.use_vqinterface = use_vqinterface
133
+
134
+ self.remap = remap
135
+ if self.remap is not None:
136
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
137
+ self.re_embed = self.used.shape[0]
138
+ self.unknown_index = unknown_index # "random" or "extra" or integer
139
+ if self.unknown_index == "extra":
140
+ self.unknown_index = self.re_embed
141
+ self.re_embed = self.re_embed+1
142
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
143
+ f"Using {self.unknown_index} for unknown indices.")
144
+ else:
145
+ self.re_embed = n_embed
146
+
147
+ def remap_to_used(self, inds):
148
+ ishape = inds.shape
149
+ assert len(ishape)>1
150
+ inds = inds.reshape(ishape[0],-1)
151
+ used = self.used.to(inds)
152
+ match = (inds[:,:,None]==used[None,None,...]).long()
153
+ new = match.argmax(-1)
154
+ unknown = match.sum(2)<1
155
+ if self.unknown_index == "random":
156
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
157
+ else:
158
+ new[unknown] = self.unknown_index
159
+ return new.reshape(ishape)
160
+
161
+ def unmap_to_all(self, inds):
162
+ ishape = inds.shape
163
+ assert len(ishape)>1
164
+ inds = inds.reshape(ishape[0],-1)
165
+ used = self.used.to(inds)
166
+ if self.re_embed > self.used.shape[0]: # extra token
167
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
168
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
169
+ return back.reshape(ishape)
170
+
171
+ def forward(self, z, temp=None, return_logits=False):
172
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
173
+ hard = self.straight_through if self.training else True
174
+ temp = self.temperature if temp is None else temp
175
+
176
+ logits = self.proj(z)
177
+ if self.remap is not None:
178
+ # continue only with used logits
179
+ full_zeros = torch.zeros_like(logits)
180
+ logits = logits[:,self.used,...]
181
+
182
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
183
+ if self.remap is not None:
184
+ # go back to all entries but unused set to zero
185
+ full_zeros[:,self.used,...] = soft_one_hot
186
+ soft_one_hot = full_zeros
187
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
188
+
189
+ # + kl divergence to the prior loss
190
+ qy = F.softmax(logits, dim=1)
191
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
192
+
193
+ ind = soft_one_hot.argmax(dim=1)
194
+ if self.remap is not None:
195
+ ind = self.remap_to_used(ind)
196
+ if self.use_vqinterface:
197
+ if return_logits:
198
+ return z_q, diff, (None, None, ind), logits
199
+ return z_q, diff, (None, None, ind)
200
+ return z_q, diff, ind
201
+
202
+ def get_codebook_entry(self, indices, shape):
203
+ b, h, w, c = shape
204
+ assert b*h*w == indices.shape[0]
205
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
206
+ if self.remap is not None:
207
+ indices = self.unmap_to_all(indices)
208
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
209
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
210
+ return z_q
211
+
212
+
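GumbelQuantize replaces the hard nearest-neighbour assignment with a Gumbel-softmax relaxation, keeping the codebook choice differentiable during training; the temperature and the KL weight are the knobs that are usually annealed. A short sketch of driving it, with the same import assumption:

import torch
from taming.modules.vqvae.quantize import GumbelQuantize

quantizer = GumbelQuantize(num_hiddens=128, embedding_dim=64, n_embed=1024,
                           kl_weight=5e-4, temp_init=1.0)
h = torch.randn(2, 128, 16, 16)   # encoder features before the 1x1 projection

quantizer.train()
z_q, kl_loss, (_, _, indices) = quantizer(h, temp=0.9)   # straight-through Gumbel sample
print(z_q.shape, indices.shape)   # torch.Size([2, 64, 16, 16]) torch.Size([2, 16, 16])

quantizer.eval()
z_q_eval, _, _ = quantizer(h)     # hard one-hot assignment at inference time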
213
+ class VectorQuantizer2(nn.Module):
214
+ """
215
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
216
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
217
+ """
218
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
219
+ # backwards compatibility we use the buggy version by default, but you can
220
+ # specify legacy=False to fix it.
221
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
222
+ sane_index_shape=False, legacy=True):
223
+ super().__init__()
224
+ self.n_e = n_e
225
+ self.e_dim = e_dim
226
+ self.beta = beta
227
+ self.legacy = legacy
228
+
229
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
230
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
231
+
232
+ self.remap = remap
233
+ if self.remap is not None:
234
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
235
+ self.re_embed = self.used.shape[0]
236
+ self.unknown_index = unknown_index # "random" or "extra" or integer
237
+ if self.unknown_index == "extra":
238
+ self.unknown_index = self.re_embed
239
+ self.re_embed = self.re_embed+1
240
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
241
+ f"Using {self.unknown_index} for unknown indices.")
242
+ else:
243
+ self.re_embed = n_e
244
+
245
+ self.sane_index_shape = sane_index_shape
246
+
247
+ def remap_to_used(self, inds):
248
+ ishape = inds.shape
249
+ assert len(ishape)>1
250
+ inds = inds.reshape(ishape[0],-1)
251
+ used = self.used.to(inds)
252
+ match = (inds[:,:,None]==used[None,None,...]).long()
253
+ new = match.argmax(-1)
254
+ unknown = match.sum(2)<1
255
+ if self.unknown_index == "random":
256
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
257
+ else:
258
+ new[unknown] = self.unknown_index
259
+ return new.reshape(ishape)
260
+
261
+ def unmap_to_all(self, inds):
262
+ ishape = inds.shape
263
+ assert len(ishape)>1
264
+ inds = inds.reshape(ishape[0],-1)
265
+ used = self.used.to(inds)
266
+ if self.re_embed > self.used.shape[0]: # extra token
267
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
268
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
269
+ return back.reshape(ishape)
270
+
271
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
272
+ assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
273
+ assert rescale_logits==False, "Only for interface compatible with Gumbel"
274
+ assert return_logits==False, "Only for interface compatible with Gumbel"
275
+ # reshape z -> (batch, height, width, channel) and flatten
276
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
277
+ z_flattened = z.view(-1, self.e_dim)
278
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
279
+
280
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
281
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
282
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
283
+
284
+ min_encoding_indices = torch.argmin(d, dim=1)
285
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
286
+ perplexity = None
287
+ min_encodings = None
288
+
289
+ # compute loss for embedding
290
+ if not self.legacy:
291
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
292
+ torch.mean((z_q - z.detach()) ** 2)
293
+ else:
294
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
295
+ torch.mean((z_q - z.detach()) ** 2)
296
+
297
+ # preserve gradients
298
+ z_q = z + (z_q - z).detach()
299
+
300
+ # reshape back to match original input shape
301
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
302
+
303
+ if self.remap is not None:
304
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
305
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
306
+ min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
307
+
308
+ if self.sane_index_shape:
309
+ min_encoding_indices = min_encoding_indices.reshape(
310
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
311
+
312
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
313
+
314
+ def get_codebook_entry(self, indices, shape):
315
+ # shape specifying (batch, height, width, channel)
316
+ if self.remap is not None:
317
+ indices = indices.reshape(shape[0],-1) # add batch axis
318
+ indices = self.unmap_to_all(indices)
319
+ indices = indices.reshape(-1) # flatten again
320
+
321
+ # get quantized latent vectors
322
+ z_q = self.embedding(indices)
323
+
324
+ if shape is not None:
325
+ z_q = z_q.view(shape)
326
+ # reshape back to match original input shape
327
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
328
+
329
+ return z_q
330
+
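The NOTE above is worth spelling out: with `legacy=True`, `beta` weights the term that pulls the codebook entries towards the (detached) encoder output, while the corrected `legacy=False` form weights the encoder commitment term instead, i.e. loss = beta * mean((sg[z_q] - z)**2) + mean((z_q - sg[z])**2). A small comparison sketch, with the same import assumption:

import torch
from taming.modules.vqvae.quantize import VectorQuantizer2

z = torch.randn(2, 64, 16, 16)

torch.manual_seed(0)
vq_legacy = VectorQuantizer2(n_e=1024, e_dim=64, beta=0.25, legacy=True)
torch.manual_seed(0)
vq_fixed = VectorQuantizer2(n_e=1024, e_dim=64, beta=0.25, legacy=False,
                            sane_index_shape=True)

_, loss_legacy, _ = vq_legacy(z)
z_q, loss_fixed, (_, _, indices) = vq_fixed(z)
print(loss_legacy.item(), loss_fixed.item())   # beta lands on different terms
print(indices.shape)                           # torch.Size([2, 16, 16]) with sane_index_shape=True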
331
+ class EmbeddingEMA(nn.Module):
332
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
333
+ super().__init__()
334
+ self.decay = decay
335
+ self.eps = eps
336
+ weight = torch.randn(num_tokens, codebook_dim)
337
+ self.weight = nn.Parameter(weight, requires_grad = False)
338
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
339
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
340
+ self.update = True
341
+
342
+ def forward(self, embed_id):
343
+ return F.embedding(embed_id, self.weight)
344
+
345
+ def cluster_size_ema_update(self, new_cluster_size):
346
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
347
+
348
+ def embed_avg_ema_update(self, new_embed_avg):
349
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
350
+
351
+ def weight_update(self, num_tokens):
352
+ n = self.cluster_size.sum()
353
+ smoothed_cluster_size = (
354
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
355
+ )
356
+ #normalize embedding average with smoothed cluster size
357
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
358
+ self.weight.data.copy_(embed_normalized)
359
+
360
+
361
+ class EMAVectorQuantizer(nn.Module):
362
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
363
+ remap=None, unknown_index="random"):
364
+ super().__init__()
365
+ self.codebook_dim = embedding_dim
366
+ self.num_tokens = n_embed
367
+ self.beta = beta
368
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
369
+
370
+ self.remap = remap
371
+ if self.remap is not None:
372
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
373
+ self.re_embed = self.used.shape[0]
374
+ self.unknown_index = unknown_index # "random" or "extra" or integer
375
+ if self.unknown_index == "extra":
376
+ self.unknown_index = self.re_embed
377
+ self.re_embed = self.re_embed+1
378
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
379
+ f"Using {self.unknown_index} for unknown indices.")
380
+ else:
381
+ self.re_embed = n_embed
382
+
383
+ def remap_to_used(self, inds):
384
+ ishape = inds.shape
385
+ assert len(ishape)>1
386
+ inds = inds.reshape(ishape[0],-1)
387
+ used = self.used.to(inds)
388
+ match = (inds[:,:,None]==used[None,None,...]).long()
389
+ new = match.argmax(-1)
390
+ unknown = match.sum(2)<1
391
+ if self.unknown_index == "random":
392
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
393
+ else:
394
+ new[unknown] = self.unknown_index
395
+ return new.reshape(ishape)
396
+
397
+ def unmap_to_all(self, inds):
398
+ ishape = inds.shape
399
+ assert len(ishape)>1
400
+ inds = inds.reshape(ishape[0],-1)
401
+ used = self.used.to(inds)
402
+ if self.re_embed > self.used.shape[0]: # extra token
403
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
404
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
405
+ return back.reshape(ishape)
406
+
407
+ def forward(self, z):
408
+ # reshape z -> (batch, height, width, channel) and flatten
409
+ #z, 'b c h w -> b h w c'
410
+ z = rearrange(z, 'b c h w -> b h w c')
411
+ z_flattened = z.reshape(-1, self.codebook_dim)
412
+
413
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
414
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
415
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
416
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
417
+
418
+
419
+ encoding_indices = torch.argmin(d, dim=1)
420
+
421
+ z_q = self.embedding(encoding_indices).view(z.shape)
422
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
423
+ avg_probs = torch.mean(encodings, dim=0)
424
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
425
+
426
+ if self.training and self.embedding.update:
427
+ #EMA cluster size
428
+ encodings_sum = encodings.sum(0)
429
+ self.embedding.cluster_size_ema_update(encodings_sum)
430
+ #EMA embedding average
431
+ embed_sum = encodings.transpose(0,1) @ z_flattened
432
+ self.embedding.embed_avg_ema_update(embed_sum)
433
+ #normalize embed_avg and update weight
434
+ self.embedding.weight_update(self.num_tokens)
435
+
436
+ # compute loss for embedding
437
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
438
+
439
+ # preserve gradients
440
+ z_q = z + (z_q - z).detach()
441
+
442
+ # reshape back to match original input shape
443
+ #z_q, 'b h w c -> b c h w'
444
+ z_q = rearrange(z_q, 'b h w c -> b c h w')
445
+ return z_q, loss, (perplexity, encodings, encoding_indices)
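EMAVectorQuantizer keeps the codebook out of the optimizer altogether: EmbeddingEMA stores the weights with `requires_grad=False`, and the `*_ema_update` calls inside `forward` refresh them from exponential moving averages of the assigned encoder features, so only the commitment loss `beta * F.mse_loss(z_q.detach(), z)` back-propagates. A sketch of one training step, assuming `n_embed` and `embedding_dim` populate the token count and codebook dimension as intended by the constructor:

import torch
from taming.modules.vqvae.quantize import EMAVectorQuantizer

quantizer = EMAVectorQuantizer(n_embed=1024, embedding_dim=64, beta=0.25, decay=0.99)
z = torch.randn(2, 64, 16, 16, requires_grad=True)

quantizer.train()
z_q, commit_loss, (perplexity, _, indices) = quantizer(z)   # EMA codebook update happens inside forward
commit_loss.backward()                                      # trains the encoder only

print(z_q.shape)          # torch.Size([2, 64, 16, 16])
print(perplexity.item())  # effective codebook usage for this batch
print(quantizer.embedding.weight.requires_grad)   # False - the codebook is EMA-updated, not optimized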
taming/util.py ADDED
@@ -0,0 +1,157 @@
1
+ import os, hashlib
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+ URL_MAP = {
6
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7
+ }
8
+
9
+ CKPT_MAP = {
10
+ "vgg_lpips": "vgg.pth"
11
+ }
12
+
13
+ MD5_MAP = {
14
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15
+ }
16
+
17
+
18
+ def download(url, local_path, chunk_size=1024):
19
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20
+ with requests.get(url, stream=True) as r:
21
+ total_size = int(r.headers.get("content-length", 0))
22
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23
+ with open(local_path, "wb") as f:
24
+ for data in r.iter_content(chunk_size=chunk_size):
25
+ if data:
26
+ f.write(data)
27
+ pbar.update(len(data))
28
+
29
+
30
+ def md5_hash(path):
31
+ with open(path, "rb") as f:
32
+ content = f.read()
33
+ return hashlib.md5(content).hexdigest()
34
+
35
+
36
+ def get_ckpt_path(name, root, check=False):
37
+ assert name in URL_MAP
38
+ path = os.path.join(root, CKPT_MAP[name])
39
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41
+ download(URL_MAP[name], path)
42
+ md5 = md5_hash(path)
43
+ assert md5 == MD5_MAP[name], md5
44
+ return path
45
+
46
+
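get_ckpt_path lazily downloads a named checkpoint into the given directory and, with `check=True`, verifies its MD5 before returning the local path; helpers like this are typically used to fetch the VGG weights for the LPIPS perceptual loss. A minimal sketch (note that the first call performs a real download):

from taming.util import get_ckpt_path

path = get_ckpt_path("vgg_lpips", root="checkpoints", check=True)
print(path)   # checkpoints/vgg.pth, with its MD5 checked against MD5_MAP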
47
+ class KeyNotFoundError(Exception):
48
+ def __init__(self, cause, keys=None, visited=None):
49
+ self.cause = cause
50
+ self.keys = keys
51
+ self.visited = visited
52
+ messages = list()
53
+ if keys is not None:
54
+ messages.append("Key not found: {}".format(keys))
55
+ if visited is not None:
56
+ messages.append("Visited: {}".format(visited))
57
+ messages.append("Cause:\n{}".format(cause))
58
+ message = "\n".join(messages)
59
+ super().__init__(message)
60
+
61
+
62
+ def retrieve(
63
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
64
+ ):
65
+ """Given a nested list or dict return the desired value at key expanding
66
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
67
+ is done in-place.
68
+
69
+ Parameters
70
+ ----------
71
+ list_or_dict : list or dict
72
+ Possibly nested list or dictionary.
73
+ key : str
74
+ key/to/value, path like string describing all keys necessary to
75
+ consider to get to the desired value. List indices can also be
76
+ passed here.
77
+ splitval : str
78
+ String that defines the delimiter between keys of the
79
+ different depth levels in `key`.
80
+ default : obj
81
+ Value returned if :attr:`key` is not found.
82
+ expand : bool
83
+ Whether to expand callable nodes on the path or not.
84
+
85
+ Returns
86
+ -------
87
+ The desired value or if :attr:`default` is not ``None`` and the
88
+ :attr:`key` is not found returns ``default``.
89
+
90
+ Raises
91
+ ------
92
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
93
+ ``None``.
94
+ """
95
+
96
+ keys = key.split(splitval)
97
+
98
+ success = True
99
+ try:
100
+ visited = []
101
+ parent = None
102
+ last_key = None
103
+ for key in keys:
104
+ if callable(list_or_dict):
105
+ if not expand:
106
+ raise KeyNotFoundError(
107
+ ValueError(
108
+ "Trying to get past callable node with expand=False."
109
+ ),
110
+ keys=keys,
111
+ visited=visited,
112
+ )
113
+ list_or_dict = list_or_dict()
114
+ parent[last_key] = list_or_dict
115
+
116
+ last_key = key
117
+ parent = list_or_dict
118
+
119
+ try:
120
+ if isinstance(list_or_dict, dict):
121
+ list_or_dict = list_or_dict[key]
122
+ else:
123
+ list_or_dict = list_or_dict[int(key)]
124
+ except (KeyError, IndexError, ValueError) as e:
125
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
126
+
127
+ visited += [key]
128
+ # final expansion of retrieved value
129
+ if expand and callable(list_or_dict):
130
+ list_or_dict = list_or_dict()
131
+ parent[last_key] = list_or_dict
132
+ except KeyNotFoundError as e:
133
+ if default is None:
134
+ raise e
135
+ else:
136
+ list_or_dict = default
137
+ success = False
138
+
139
+ if not pass_success:
140
+ return list_or_dict
141
+ else:
142
+ return list_or_dict, success
143
+
144
+
145
+ if __name__ == "__main__":
146
+ config = {"keya": "a",
147
+ "keyb": "b",
148
+ "keyc":
149
+ {"cc1": 1,
150
+ "cc2": 2,
151
+ }
152
+ }
153
+ from omegaconf import OmegaConf
154
+ config = OmegaConf.create(config)
155
+ print(config)
156
+ retrieve(config, "keya")
157
+
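The docstring of `retrieve` describes path-like access into nested configs, expansion of callable nodes, and a default fallback, but the check above only exercises a flat key. A slightly fuller sketch on a plain dict (the lambda is only there to illustrate callable-node expansion):

from taming.util import retrieve

config = {"keya": "a",
          "keyc": {"cc1": 1, "cc2": 2},
          "keyd": lambda: {"nested": 42}}

print(retrieve(config, "keyc/cc1"))                  # 1 - nested access via '/'
print(retrieve(config, "keyc/missing", default=13))  # 13 - default instead of an exception
print(retrieve(config, "keyd/nested"))               # 42 - callable node expanded in place
value, ok = retrieve(config, "nope", default="fallback", pass_success=True)
print(value, ok)                                     # fallback False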