jeduardogruiz committed
Commit a4e236c
Parent: a7056e6

Create core_vq.py

Files changed (1)
  1. core_vq.py +367 -0
core_vq.py ADDED
@@ -0,0 +1,367 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ #
+ # This implementation is inspired from
+ # https://github.com/lucidrains/vector-quantize-pytorch
+ # which is released under MIT License. Hereafter, the original license:
+ # MIT License
+ #
+ # Copyright (c) 2020 Phil Wang
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ """Core vector quantization implementation."""
+
+ import typing as tp
+ import warnings
+
+ from einops import rearrange, repeat
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from .. import distrib
+
+
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
+     return val if val is not None else d
+
+
+ def ema_inplace(moving_avg, new, decay: float):
+     moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+     return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+ def uniform_init(*shape: int):
+     t = torch.empty(shape)
+     nn.init.kaiming_uniform_(t)
+     return t
+
+
+ def sample_vectors(samples, num: int):
+     num_samples, device = samples.shape[0], samples.device
+
+     if num_samples >= num:
+         indices = torch.randperm(num_samples, device=device)[:num]
+     else:
+         indices = torch.randint(0, num_samples, (num,), device=device)
+
+     return samples[indices]
+
+
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
+     dim, dtype = samples.shape[-1], samples.dtype
+
+     means = sample_vectors(samples, num_clusters)
+
+     for _ in range(num_iters):
+         diffs = rearrange(samples, "n d -> n () d") - rearrange(
+             means, "c d -> () c d"
+         )
+         dists = -(diffs ** 2).sum(dim=-1)
+
+         buckets = dists.max(dim=-1).indices
+         bins = torch.bincount(buckets, minlength=num_clusters)
+         zero_mask = bins == 0
+         bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+         new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+         new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+         new_means = new_means / bins_min_clamped[..., None]
+
+         means = torch.where(zero_mask[..., None], means, new_means)
+
+     return means, bins
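
For illustration, a minimal sketch of how the k-means and sampling helpers above behave, assuming the module is importable as encodec.quantization.core_vq (its upstream location); the sizes below are arbitrary:

import torch
from encodec.quantization.core_vq import kmeans, sample_vectors  # assumed import path

feats = torch.randn(1024, 128)      # 1024 flattened feature vectors of dimension 128
means, bins = kmeans(feats, num_clusters=64, num_iters=10)
print(means.shape)                  # torch.Size([64, 128]) -> candidate codebook entries
print(bins.shape)                   # torch.Size([64])      -> samples assigned per cluster
subset = sample_vectors(feats, 16)  # random draw, used for init and dead-code replacement
print(subset.shape)                 # torch.Size([16, 128])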
+
+
+ class EuclideanCodebook(nn.Module):
+     """Codebook with Euclidean distance.
+     Args:
+         dim (int): Dimension.
+         codebook_size (int): Codebook size.
+         kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+             If set to true, run the k-means algorithm on the first training batch and use
+             the learned centroids as initialization.
+         kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+         decay (float): Decay for exponential moving average over the codebooks.
+         epsilon (float): Epsilon value for numerical stability.
+         threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+             that have an exponential moving average cluster size less than the specified threshold with
+             a randomly selected vector from the current batch.
+     """
+     def __init__(
+         self,
+         dim: int,
+         codebook_size: int,
+         kmeans_init: bool = False,
+         kmeans_iters: int = 10,
+         decay: float = 0.99,
+         epsilon: float = 1e-5,
+         threshold_ema_dead_code: int = 2,
+     ):
+         super().__init__()
+         self.decay = decay
+         init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
+         embed = init_fn(codebook_size, dim)
+
+         self.codebook_size = codebook_size
+
+         self.kmeans_iters = kmeans_iters
+         self.epsilon = epsilon
+         self.threshold_ema_dead_code = threshold_ema_dead_code
+
+         self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+         self.register_buffer("cluster_size", torch.zeros(codebook_size))
+         self.register_buffer("embed", embed)
+         self.register_buffer("embed_avg", embed.clone())
+
+     @torch.jit.ignore
+     def init_embed_(self, data):
+         if self.inited:
+             return
+
+         embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+         self.embed.data.copy_(embed)
+         self.embed_avg.data.copy_(embed.clone())
+         self.cluster_size.data.copy_(cluster_size)
+         self.inited.data.copy_(torch.Tensor([True]))
+         # Make sure all buffers across workers are in sync after initialization
+         distrib.broadcast_tensors(self.buffers())
+
+     def replace_(self, samples, mask):
+         modified_codebook = torch.where(
+             mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+         )
+         self.embed.data.copy_(modified_codebook)
+
+     def expire_codes_(self, batch_samples):
+         if self.threshold_ema_dead_code == 0:
+             return
+
+         expired_codes = self.cluster_size < self.threshold_ema_dead_code
+         if not torch.any(expired_codes):
+             return
+
+         batch_samples = rearrange(batch_samples, "... d -> (...) d")
+         self.replace_(batch_samples, mask=expired_codes)
+         distrib.broadcast_tensors(self.buffers())
+
+     def preprocess(self, x):
+         x = rearrange(x, "... d -> (...) d")
+         return x
+
+     def quantize(self, x):
+         embed = self.embed.t()
+         dist = -(
+             x.pow(2).sum(1, keepdim=True)
+             - 2 * x @ embed
+             + embed.pow(2).sum(0, keepdim=True)
+         )
+         embed_ind = dist.max(dim=-1).indices
+         return embed_ind
+
+     def postprocess_emb(self, embed_ind, shape):
+         return embed_ind.view(*shape[:-1])
+
+     def dequantize(self, embed_ind):
+         quantize = F.embedding(embed_ind, self.embed)
+         return quantize
+
+     def encode(self, x):
+         shape = x.shape
+         # pre-process
+         x = self.preprocess(x)
+         # quantize
+         embed_ind = self.quantize(x)
+         # post-process
+         embed_ind = self.postprocess_emb(embed_ind, shape)
+         return embed_ind
+
+     def decode(self, embed_ind):
+         quantize = self.dequantize(embed_ind)
+         return quantize
+
+     def forward(self, x):
+         shape, dtype = x.shape, x.dtype
+         x = self.preprocess(x)
+
+         self.init_embed_(x)
+
+         embed_ind = self.quantize(x)
+         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+         embed_ind = self.postprocess_emb(embed_ind, shape)
+         quantize = self.dequantize(embed_ind)
+
+         if self.training:
+             # We do the expiry of code at that point as buffers are in sync
+             # and all the workers will take the same decision.
+             self.expire_codes_(x)
+             ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+             embed_sum = x.t() @ embed_onehot
+             ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+             cluster_size = (
+                 laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+                 * self.cluster_size.sum()
+             )
+             embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+             self.embed.data.copy_(embed_normalized)
+
+         return quantize, embed_ind
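
As a rough usage sketch of EuclideanCodebook (assumed import path, illustrative sizes): the module flattens any leading dimensions, looks up the nearest codebook entry per vector, and in training mode also applies the EMA codebook update and dead-code expiry.

import torch
from encodec.quantization.core_vq import EuclideanCodebook  # assumed import path

codebook = EuclideanCodebook(dim=128, codebook_size=1024, kmeans_init=False)
codebook.eval()                   # skip the EMA update / expiry branch for this demo
x = torch.randn(8, 50, 128)       # (batch, frames, dim)
quantized, indices = codebook(x)  # nearest-centroid lookup per vector
print(quantized.shape)            # torch.Size([8, 50, 128])
print(indices.shape)              # torch.Size([8, 50])
codes = codebook.encode(x)        # indices only
recon = codebook.decode(codes)    # embedding lookup, same shape as x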
+
+
+ class VectorQuantization(nn.Module):
+     """Vector quantization implementation.
+     Currently supports only euclidean distance.
+     Args:
+         dim (int): Dimension.
+         codebook_size (int): Codebook size.
+         codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+         decay (float): Decay for exponential moving average over the codebooks.
+         epsilon (float): Epsilon value for numerical stability.
+         kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+         kmeans_iters (int): Number of iterations used for kmeans initialization.
+         threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+             that have an exponential moving average cluster size less than the specified threshold with
+             a randomly selected vector from the current batch.
+         commitment_weight (float): Weight for commitment loss.
+     """
+     def __init__(
+         self,
+         dim: int,
+         codebook_size: int,
+         codebook_dim: tp.Optional[int] = None,
+         decay: float = 0.99,
+         epsilon: float = 1e-5,
+         kmeans_init: bool = True,
+         kmeans_iters: int = 50,
+         threshold_ema_dead_code: int = 2,
+         commitment_weight: float = 1.,
+     ):
+         super().__init__()
+         _codebook_dim: int = default(codebook_dim, dim)
+
+         requires_projection = _codebook_dim != dim
+         self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+         self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+
+         self.epsilon = epsilon
+         self.commitment_weight = commitment_weight
+
+         self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+                                            kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+                                            decay=decay, epsilon=epsilon,
+                                            threshold_ema_dead_code=threshold_ema_dead_code)
+         self.codebook_size = codebook_size
+
+     @property
+     def codebook(self):
+         return self._codebook.embed
+
+     def encode(self, x):
+         x = rearrange(x, "b d n -> b n d")
+         x = self.project_in(x)
+         embed_in = self._codebook.encode(x)
+         return embed_in
+
+     def decode(self, embed_ind):
+         quantize = self._codebook.decode(embed_ind)
+         quantize = self.project_out(quantize)
+         quantize = rearrange(quantize, "b n d -> b d n")
+         return quantize
+
+     def forward(self, x):
+         device = x.device
+         x = rearrange(x, "b d n -> b n d")
+         x = self.project_in(x)
+
+         quantize, embed_ind = self._codebook(x)
+
+         if self.training:
+             quantize = x + (quantize - x).detach()
+
+         loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+         if self.training:
+             warnings.warn('When using RVQ in training mode, first check '
+                           'https://github.com/facebookresearch/encodec/issues/25 . '
+                           'The bug wasn\'t fixed here for reproducibility.')
+             if self.commitment_weight > 0:
+                 commit_loss = F.mse_loss(quantize.detach(), x)
+                 loss = loss + commit_loss * self.commitment_weight
+
+         quantize = self.project_out(quantize)
+         quantize = rearrange(quantize, "b n d -> b d n")
+         return quantize, embed_ind, loss
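
Similarly, a hedged sketch of VectorQuantization (assumed import path, arbitrary sizes): note the channels-first input layout (B, D, T), which forward transposes to (B, T, D) before the codebook and back afterwards.

import torch
from encodec.quantization.core_vq import VectorQuantization  # assumed import path

vq = VectorQuantization(dim=128, codebook_size=1024, kmeans_init=False)
vq.eval()                                   # no straight-through / commitment loss in eval
x = torch.randn(4, 128, 75)                 # (batch, dim, time)
quantized, codes, loss = vq(x)
print(quantized.shape, codes.shape, loss)   # torch.Size([4, 128, 75]) torch.Size([4, 75]) tensor([0.])
recon = vq.decode(vq.encode(x))             # indices -> vectors round trip
print(torch.equal(recon, quantized))        # True: same codebook lookup in eval mode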
+
+
+ class ResidualVectorQuantization(nn.Module):
+     """Residual vector quantization implementation.
+     Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf
+     """
+     def __init__(self, *, num_quantizers, **kwargs):
+         super().__init__()
+         self.layers = nn.ModuleList(
+             [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+         )
+
+     def forward(self, x, n_q: tp.Optional[int] = None):
+         quantized_out = 0.0
+         residual = x
+
+         all_losses = []
+         all_indices = []
+
+         n_q = n_q or len(self.layers)
+
+         for layer in self.layers[:n_q]:
+             quantized, indices, loss = layer(residual)
+             residual = residual - quantized
+             quantized_out = quantized_out + quantized
+
+             all_indices.append(indices)
+             all_losses.append(loss)
+
+         out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+         return quantized_out, out_indices, out_losses
+
+     def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+         residual = x
+         all_indices = []
+         n_q = n_q or len(self.layers)
+         for layer in self.layers[:n_q]:
+             indices = layer.encode(residual)
+             quantized = layer.decode(indices)
+             residual = residual - quantized
+             all_indices.append(indices)
+         out_indices = torch.stack(all_indices)
+         return out_indices
+
+     def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+         quantized_out = torch.tensor(0.0, device=q_indices.device)
+         for i, indices in enumerate(q_indices):
+             layer = self.layers[i]
+             quantized = layer.decode(indices)
+             quantized_out = quantized_out + quantized
+         return quantized_out
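
Finally, an end-to-end sketch of ResidualVectorQuantization (assumed import path, arbitrary sizes): each layer quantizes the residual left by the previous one, so decoding sums the per-layer lookups, and n_q lets callers use only a prefix of the quantizers.

import torch
from encodec.quantization.core_vq import ResidualVectorQuantization  # assumed import path

rvq = ResidualVectorQuantization(num_quantizers=8, dim=128, codebook_size=1024,
                                 kmeans_init=False)
rvq.eval()
x = torch.randn(2, 128, 150)              # (batch, dim, time)
quantized, codes, losses = rvq(x, n_q=4)  # use only the first 4 quantizers
print(quantized.shape)                    # torch.Size([2, 128, 150])
print(codes.shape)                        # torch.Size([4, 2, 150]) -> (n_q, batch, time)
print(losses.shape)                       # torch.Size([4, 1])
recon = rvq.decode(rvq.encode(x, n_q=4))  # sum of per-layer codebook lookups
print(recon.shape)                        # torch.Size([2, 128, 150])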