Dan Jacobellis committed on
Commit 60adc6f · 1 Parent(s): 7db5410

reduce size

Files changed (1): README.md (+458 −3)
README.md CHANGED

```diff
@@ -1,3 +1,458 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:79f1ed2b4bceef1bdb4c5496813cb38c5e4050a9332bf820fb6e073fe2db4ab3
- size 5604740
```
---
datasets:
- danjacobellis/LSDIR_540
- danjacobellis/musdb_segments
---
# Wavelet Learned Lossy Compression

- [Project page and documentation](https://danjacobellis.net/walloc)
- [Paper: "Learned Compression for Compressed Learning"](https://danjacobellis.net/_static/walloc.pdf)
- [Additional code accompanying the paper](https://github.com/danjacobellis/lccl)

![](https://danjacobellis.net/walloc/_images/radar.svg)
Comparison of WaLLoC with other autoencoder designs for RGB images and stereo audio.

![](https://danjacobellis.net/walloc/_images/wpt.svg)
Example of forward and inverse WPT with $J=2$ levels. Each level applies the filters $\text{L}_{\text{A}}$ and $\text{H}_{\text{A}}$ independently to each of the signal channels, followed by downsampling by two $(\downarrow 2)$. An inverse level consists of upsampling $(\uparrow 2)$ followed by $\text{L}_{\text{S}}$ and $\text{H}_{\text{S}}$, then summing the two channels. The full WPT $\tilde{\textbf{X}}$ consists of $J$ such levels.

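To make the figure concrete, here is a small $J=2$ wavelet packet round trip using PyWavelets' `bior4.4` filters (its 9/7-tap biorthogonal pair, commonly identified with CDF 9/7). This is an illustration of the transform only; it is not code from the walloc package.

```python
# Illustrative only: J=2 wavelet packet transform (WPT) and its inverse with PyWavelets.
import numpy as np
import pywt

x = np.random.randn(256)                      # one signal channel
wp = pywt.WaveletPacket(data=x, wavelet='bior4.4',
                        mode='periodization', maxlevel=2)

# After J=2 levels there are 2**J = 4 subbands, each downsampled by 4.
subbands = [node.data for node in wp.get_level(2, order='freq')]
print([s.shape for s in subbands])            # four arrays of length 64

# Inverse WPT: upsample, filter with the synthesis pair, and sum.
x_rec = wp.reconstruct(update=False)
print(np.allclose(x, x_rec[:len(x)]))         # near-perfect reconstruction
```
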
![](https://danjacobellis.net/walloc/_images/walloc.svg)
WaLLoC's encode-decode pipeline. The entropy bottleneck and entropy coding steps are only required to achieve high compression ratios for storage and transmission. For compressed-domain learning, where dimensionality reduction is the primary goal, these steps can be skipped to reduce overhead and completely eliminate CPU-GPU transfers.

# Wavelet Learned Lossy Compression (WaLLoC)

WaLLoC sandwiches a convolutional autoencoder between time-frequency analysis and synthesis transforms built from CDF 9/7 wavelet filters. The time-frequency transform increases the number of signal channels but reduces the temporal or spatial resolution, resulting in lower GPU memory consumption and higher throughput. WaLLoC's training procedure is highly simplified compared to other $\beta$-VAEs, VQ-VAEs, and neural codecs, yet it still offers significant dimensionality reduction and compression. This makes it suitable for dataset storage and compressed-domain learning. It currently supports 1D and 2D signals, including mono, stereo, or multi-channel audio and grayscale, RGB, or hyperspectral images.

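To make the bookkeeping concrete (a sketch that assumes a full packet decomposition, i.e. every subband is split at every level): a 2D input with $C$ channels at resolution $H \times W$ becomes $\tilde{\textbf{X}}$ with $4^J C$ channels at resolution $\frac{H}{2^J} \times \frac{W}{2^J}$, so the WPT itself preserves the total number of elements. The learned encoder then projects the $4^J C$ channels down to $D$ channels (the `latent_dim` setting), giving an overall dimensionality reduction of $4^J C / D$ for images, or $2^J C / D$ for 1D signals.
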
## Installation

1. Follow the installation instructions for [torch](https://pytorch.org/get-started/locally/)
2. Install WaLLoC and other dependencies via pip

```pip install walloc PyWavelets pytorch-wavelets```

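Optionally, verify that the packages import (a quick sanity check, not part of the official instructions):

```python
# Confirm that the installed packages resolve.
import torch, pywt, pytorch_wavelets, walloc
print(torch.__version__, pywt.__version__)
```
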
## Image compression

```python
import os
import torch
import json
import matplotlib.pyplot as plt
import numpy as np
from types import SimpleNamespace
from PIL import Image, ImageEnhance
from IPython.display import display
from torchvision.transforms import ToPILImage, PILToTensor
from walloc import walloc
from walloc.walloc import latent_to_pil, pil_to_latent
```

### Load the model from a pre-trained checkpoint

```wget https://hf.co/danjacobellis/walloc/resolve/main/RGB_16x.pth```

```wget https://hf.co/danjacobellis/walloc/resolve/main/RGB_16x.json```

```python
device = "cpu"
codec_config = SimpleNamespace(**json.load(open("RGB_16x.json")))
checkpoint = torch.load("RGB_16x.pth",map_location="cpu",weights_only=False)
codec = walloc.Codec2D(
    channels = codec_config.channels,
    J = codec_config.J,
    Ne = codec_config.Ne,
    Nd = codec_config.Nd,
    latent_dim = codec_config.latent_dim,
    latent_bits = codec_config.latent_bits,
    lightweight_encode = codec_config.lightweight_encode
)
codec.load_state_dict(checkpoint['model_state_dict'])
codec = codec.to(device)
codec.eval();
```

### Load an example image

```wget "https://r0k.us/graphics/kodak/kodak/kodim05.png"```

```python
img = Image.open("kodim05.png")
img
```

![png](https://huggingface.co/danjacobellis/walloc/resolve/main/README_files/README_11_0.png)

### Full encoding and decoding pipeline with .forward()

* If `codec.eval()` is called, the latent is rounded to the nearest integer.
* If `codec.train()` is called, uniform noise is added instead of rounding (a standalone sketch of the difference follows below).

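A standalone illustration of that difference (generic quantization-surrogate behavior, not WaLLoC internals):

```python
# In eval mode the latent is hard-rounded; in train mode additive uniform
# noise in [-0.5, 0.5) stands in for rounding so that gradients can flow.
import torch

z = 3 * torch.randn(5)
z_eval  = torch.round(z)                   # inference-time quantization
z_train = z + (torch.rand_like(z) - 0.5)   # training-time surrogate
print(z, z_eval, z_train, sep="\n")
```
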

```python
with torch.no_grad():
    codec.eval()
    x = PILToTensor()(img).to(torch.float)
    x = (x/255 - 0.5).unsqueeze(0).to(device)
    x_hat, _, _ = codec(x)
ToPILImage()(x_hat[0]+0.5)
```

![png](https://huggingface.co/danjacobellis/walloc/resolve/main/README_files/README_13_0.png)

### Accessing latents

```python
with torch.no_grad():
    X = codec.wavelet_analysis(x,J=codec.J)
    z = codec.encoder[0:2](X)
    z_hat = codec.encoder[2](z)
    X_hat = codec.decoder(z_hat)
    x_rec = codec.wavelet_synthesis(X_hat,J=codec.J)
    print(f"dimensionality reduction: {x.numel()/z.numel()}×")
```

    dimensionality reduction: 16.0×

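As the pipeline figure notes, the entropy bottleneck and entropy coding can be skipped when dimensionality reduction is the goal. A hypothetical sketch of compressed-domain learning on top of the latents from the cell above (the classifier head is invented for illustration and is not part of walloc):

```python
# Train a small, hypothetical classifier head directly on WaLLoC latents,
# skipping quantization and entropy coding entirely.
head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(z.shape[1], 10),   # 10 classes, chosen arbitrarily
)
with torch.no_grad():                  # keep the lightweight encoder frozen
    features = codec.encoder[0:2](codec.wavelet_analysis(x, J=codec.J))
logits = head(features)                # ~16x fewer input elements than the RGB pixels
print(logits.shape)
```
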

```python
plt.figure(figsize=(5,3),dpi=150)
plt.hist(
    z.flatten().numpy(),
    range=(-25,25),
    bins=151,
    density=True,
);
plt.title("Histogram of latents")
plt.xlim([-25,25]);
```

![png](https://huggingface.co/danjacobellis/walloc/resolve/main/README_files/README_16_0.png)

# Lossless compression of latents

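The quantized latents occupy only a few bits per value, so they can be packed into standard image containers and compressed losslessly. A minimal illustration of the idea, independent of the `latent_to_pil`/`pil_to_latent` helpers used below (array size and bit depth are invented for the example):

```python
# Low-bit integer latents survive a lossless PNG round trip exactly.
import io
import numpy as np
from PIL import Image

latent = np.random.randint(-16, 16, size=(64, 64), dtype=np.int16)  # pretend 5-bit latents
img = Image.fromarray((latent + 16).astype(np.uint8))               # shift to unsigned 8-bit
buff = io.BytesIO()
img.save(buff, format='PNG')                                         # lossless container
buff.seek(0)
rec = np.asarray(Image.open(buff)).astype(np.int16) - 16
print((rec == latent).all())
```
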

```python
def scale_for_display(img, n_bits):
    # Stretch n_bit pixel values to the full 8-bit range so the latent
    # images are visible when displayed.
    scale_factor = (2**8 - 1) / (2**n_bits - 1)
    lut = [int(i * scale_factor) for i in range(2**n_bits)]
    channels = img.split()
    scaled_channels = [ch.point(lut * 2**(8-n_bits)) for ch in channels]
    return Image.merge(img.mode, scaled_channels)
```

### Single channel PNG (L)

```python
z_padded = torch.nn.functional.pad(z_hat, (0, 0, 0, 0, 0, 4))  # pad the channel dimension by 4 before packing
z_pil = latent_to_pil(z_padded,codec.latent_bits,1)
display(scale_for_display(z_pil[0], codec.latent_bits))
z_pil[0].save('latent.png')
png = [Image.open("latent.png")]
z_rec = pil_to_latent(png,16,codec.latent_bits,1)
assert(z_rec.equal(z_padded))
print("compression_ratio: ", x.numel()/os.path.getsize("latent.png"))
```

![png](https://huggingface.co/danjacobellis/walloc/resolve/main/README_files/README_20_0.png)

    compression_ratio:  26.729991842653856

### Three channel WebP (RGB)

```python
z_pil = latent_to_pil(z_hat,codec.latent_bits,3)
display(scale_for_display(z_pil[0], codec.latent_bits))
z_pil[0].save('latent.webp',lossless=True)
webp = [Image.open("latent.webp")]
z_rec = pil_to_latent(webp,12,codec.latent_bits,3)
assert(z_rec.equal(z_hat))
print("compression_ratio: ", x.numel()/os.path.getsize("latent.webp"))
```

![png](https://huggingface.co/danjacobellis/walloc/resolve/main/README_files/README_22_0.png)

    compression_ratio:  28.811254396248536

### Four channel TIF (CMYK)

```python
z_padded = torch.nn.functional.pad(z_hat, (0, 0, 0, 0, 0, 4))
z_pil = latent_to_pil(z_padded,codec.latent_bits,4)
display(scale_for_display(z_pil[0], codec.latent_bits))
z_pil[0].save('latent.tif',compression="tiff_adobe_deflate")
tif = [Image.open("latent.tif")]
z_rec = pil_to_latent(tif,16,codec.latent_bits,4)
assert(z_rec.equal(z_padded))
print("compression_ratio: ", x.numel()/os.path.getsize("latent.tif"))
```

![jpeg](https://huggingface.co/danjacobellis/walloc/resolve/main/README_files/README_24_0.jpg)

    compression_ratio:  21.04034530731638

# Audio Compression

```python
import io
import os
import torch
import torchaudio
import json
import matplotlib.pyplot as plt
from types import SimpleNamespace
from PIL import Image
from datasets import load_dataset
from einops import rearrange
from IPython.display import Audio
from walloc import walloc
```

### Load the model from a pre-trained checkpoint

```wget https://hf.co/danjacobellis/walloc/resolve/main/stereo_5x.pth```

```wget https://hf.co/danjacobellis/walloc/resolve/main/stereo_5x.json```

```python
codec_config = SimpleNamespace(**json.load(open("stereo_5x.json")))
checkpoint = torch.load("stereo_5x.pth",map_location="cpu",weights_only=False)
codec = walloc.Codec1D(
    channels = codec_config.channels,
    J = codec_config.J,
    Ne = codec_config.Ne,
    Nd = codec_config.Nd,
    latent_dim = codec_config.latent_dim,
    latent_bits = codec_config.latent_bits,
    lightweight_encode = codec_config.lightweight_encode,
    post_filter = codec_config.post_filter
)
codec.load_state_dict(checkpoint['model_state_dict'])
codec.eval();
```

    /home/dan/g/lib/python3.12/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
      WeightNorm.apply(module, name, dim)

### Load example audio track

```python
MUSDB = load_dataset("danjacobellis/musdb_segments_val",split='validation')
audio_buff = io.BytesIO(MUSDB[40]['audio_mix']['bytes'])
x, fs = torchaudio.load(audio_buff,normalize=False)
x = x.to(torch.float)
x = x - x.mean()          # remove DC offset
max_abs = x.abs().max()
x = x / (max_abs + 1e-8)  # normalize to [-1, 1]
x = x/2                   # scale to half of full range
```

```python
Audio(x[:,:2**20],rate=44100)
```

<audio controls>
    <source src="README_files/README_0.wav" type="audio/wav">
</audio>

### Full encoding and decoding pipeline with .forward()

* If `codec.eval()` is called, the latent is rounded to the nearest integer.
* If `codec.train()` is called, uniform noise is added instead of rounding.

```python
with torch.no_grad():
    codec.eval()
    x_hat, _, _ = codec(x.unsqueeze(0))
```

```python
Audio(x_hat[0,:,:2**20],rate=44100)
```

<audio controls>
    <source src="README_files/README_1.wav" type="audio/wav">
</audio>

### Accessing latents

```python
with torch.no_grad():
    X = codec.wavelet_analysis(x.unsqueeze(0),J=codec.J)
    z = codec.encoder[0:2](X)
    z_hat = codec.encoder[2](z)
    X_hat = codec.decoder(z_hat)
    x_rec = codec.wavelet_synthesis(X_hat,J=codec.J)
    print(f"dimensionality reduction: {x.numel()/z.numel():.4g}×")
```

    dimensionality reduction: 4.74×

```python
plt.figure(figsize=(5,3),dpi=150)
plt.hist(
    z.flatten().numpy(),
    range=(-25,25),
    bins=151,
    density=True,
);
plt.title("Histogram of latents")
plt.xlim([-25,25]);
```

![png](https://huggingface.co/danjacobellis/walloc/resolve/main/README_files/README_39_0.png)

# Lossless compression of latents

```python
def pad(audio, p=2**16):
    # Zero-pad the last (time) dimension to a multiple of p samples.
    B,C,L = audio.shape
    padding_size = (p - (L % p)) % p
    if padding_size > 0:
        audio = torch.nn.functional.pad(audio, (0, padding_size), mode='constant', value=0)
    return audio
with torch.no_grad():
    L = x.shape[-1]
    x_padded = pad(x.unsqueeze(0), 2**16)
    X = codec.wavelet_analysis(x_padded,codec.J)
    z = codec.encoder(X)
    ℓ = z.shape[-1]      # latent length before padding
    z = pad(z,128)
    z = rearrange(z, 'b c (w h) -> b c w h', h=128).to("cpu")  # pack the 1D latents into a 2D grid
    webp = walloc.latent_to_pil(z,codec.latent_bits,3)[0]
    buff = io.BytesIO()
    webp.save(buff, format='WEBP', lossless=True)
    webp_bytes = buff.getbuffer()
```

```python
print("compression_ratio: ", x.numel()/len(webp_bytes))
webp
```

    compression_ratio:  9.83650170496386

![png](https://huggingface.co/danjacobellis/walloc/resolve/main/README_files/README_42_1.png)

# Decoding

```python
with torch.no_grad():
    z_hat = walloc.pil_to_latent(
        [Image.open(buff)],
        codec.latent_dim,
        codec.latent_bits,
        3)
    X_hat = codec.decoder(rearrange(z_hat, 'b c h w -> b c (h w)')[:,:,:ℓ])  # drop the packing pad
    x_hat = codec.wavelet_synthesis(X_hat,codec.J)
    x_hat = codec.post(x_hat)            # apply the post-filter
    x_hat = codec.clamp(x_hat[0,:,:L])   # crop to the original length
```

```python
start, end = 0, 1000
plt.figure(figsize=(8, 3), dpi=180)
plt.plot(x[0, start:end], alpha=0.5, c='b', label='Ch.1 (Uncompressed)')
plt.plot(x_hat[0, start:end], alpha=0.5, c='g', label='Ch.1 (WaLLoC)')
plt.plot(x[1, start:end], alpha=0.5, c='r', label='Ch.2 (Uncompressed)')
plt.plot(x_hat[1, start:end], alpha=0.5, c='purple', label='Ch.2 (WaLLoC)')

plt.xlim([400,1000])
plt.ylim([-0.6,0.3])
plt.legend(loc='lower center')
plt.box(False)
plt.xticks([])
plt.yticks([]);
```

![png](https://huggingface.co/danjacobellis/walloc/resolve/main/README_files/README_45_0.png)

```python
!jupyter nbconvert --to markdown README.ipynb
```

    [NbConvertApp] Converting notebook README.ipynb to markdown
    [NbConvertApp] Support files will be in README_files/
    [NbConvertApp] Writing 1409744 bytes to README.md

```python
!sed -i 's|!\[png](README_files/\(README_[0-9]*_[0-9]*\.png\))|![png](https://huggingface.co/danjacobellis/walloc/resolve/main/README_files/\1)|g' README.md
```