Haoming02 committed on
Commit
f97b5f4
1 Parent(s): 1fae714
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50):
  1. .gitattributes +35 -35
  2. .gitignore +1 -0
  3. LICENSE +72 -0
  4. README.md +5 -13
  5. SUPIR/__init__.py +0 -0
  6. SUPIR/models/SUPIR_model.py +176 -0
  7. SUPIR/models/__init__.py +0 -0
  8. SUPIR/modules/SUPIR_v0.py +718 -0
  9. SUPIR/modules/__init__.py +11 -0
  10. SUPIR/util.py +190 -0
  11. SUPIR/utils/__init__.py +0 -0
  12. SUPIR/utils/colorfix.py +120 -0
  13. SUPIR/utils/devices.py +138 -0
  14. SUPIR/utils/face_restoration_helper.py +514 -0
  15. SUPIR/utils/file.py +79 -0
  16. SUPIR/utils/tilevae.py +971 -0
  17. app.py +911 -0
  18. configs/clip_vit_config.json +25 -0
  19. configs/tokenizer/config.json +171 -0
  20. configs/tokenizer/merges.txt +0 -0
  21. configs/tokenizer/preprocessor_config.json +19 -0
  22. configs/tokenizer/special_tokens_map.json +24 -0
  23. configs/tokenizer/tokenizer.json +0 -0
  24. configs/tokenizer/tokenizer_config.json +34 -0
  25. configs/tokenizer/vocab.json +0 -0
  26. options/SUPIR_v0.yaml +140 -0
  27. sgm/__init__.py +4 -0
  28. sgm/lr_scheduler.py +135 -0
  29. sgm/models/__init__.py +2 -0
  30. sgm/models/autoencoder.py +335 -0
  31. sgm/models/diffusion.py +320 -0
  32. sgm/modules/__init__.py +8 -0
  33. sgm/modules/attention.py +637 -0
  34. sgm/modules/autoencoding/__init__.py +0 -0
  35. sgm/modules/autoencoding/losses/__init__.py +246 -0
  36. sgm/modules/autoencoding/lpips/__init__.py +0 -0
  37. sgm/modules/autoencoding/lpips/loss/LICENSE +23 -0
  38. sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
  39. sgm/modules/autoencoding/lpips/loss/lpips.py +147 -0
  40. sgm/modules/autoencoding/lpips/model/LICENSE +58 -0
  41. sgm/modules/autoencoding/lpips/model/__init__.py +0 -0
  42. sgm/modules/autoencoding/lpips/model/model.py +88 -0
  43. sgm/modules/autoencoding/lpips/util.py +128 -0
  44. sgm/modules/autoencoding/lpips/vqperceptual.py +17 -0
  45. sgm/modules/autoencoding/regularizers/__init__.py +53 -0
  46. sgm/modules/diffusionmodules/__init__.py +7 -0
  47. sgm/modules/diffusionmodules/denoiser.py +73 -0
  48. sgm/modules/diffusionmodules/denoiser_scaling.py +31 -0
  49. sgm/modules/diffusionmodules/denoiser_weighting.py +24 -0
  50. sgm/modules/diffusionmodules/discretizer.py +69 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
LICENSE ADDED
@@ -0,0 +1,72 @@
+ SUPIR Software License Agreement
+
+ Copyright (c) 2024 by SupPixel Pty Ltd. All rights reserved.
+
+ This work is jointly owned by SupPixel Pty Ltd, an Australian company (ACN 676 560 320). All rights reserved. The work is created by the author team of SUPIR at SupPixel Pty Ltd.
+
+ This License Agreement ("Agreement") is made and entered into by and between SupPixel Pty Ltd (collectively, the "Licensor") and any person or entity who accesses, uses, or distributes the SUPIR software (the "Licensee").
+
+ 1. Definitions
+
+ a. Licensed Software: Refers to the SUPIR software, including its source code, associated documentation, and any updates, upgrades, or modifications thereof. The Licensed Software includes both open-source components and proprietary components, as described herein.
+
+ b. Neural Network Components: Includes the weights, biases, and architecture of the proprietary neural network(s) developed and trained by the Licensor. This excludes any pre-existing open-source neural networks used in conjunction with the Licensor's proprietary networks, for which the Licensor does not hold copyright.
+
+ c. Commercial Use: Refers to any use of the Licensed Software or its components, in whole or in part, for the purpose of generating revenue, conducting business operations, or creating products or services for sale or profit. This includes, but is not limited to:
+
+ - Deploying the software on a server to provide a service to users for a fee.
+ - Using the software to process images that are intended for sale.
+ - Integrating the software or its components into a commercial product that is sold to end-users.
+ - Using the software to provide commercial consulting or contracting services, such as image analysis, enhancement, or modification for clients.
+ - Using the software to process images that are subsequently used in training datasets for machine learning models or other types of data analysis. This includes datasets that are used internally, sold, or provided to third parties under a commercial agreement.
+ - Licensing the software or its components to other businesses for commercial purposes.
+
+ 2. License Grant
+
+ a. Open-Source Components: The Licensor grants the Licensee a non-exclusive, worldwide, royalty-free license to use, copy, modify, and distribute the source code of the open-source components of the Licensed Software, subject to the terms and conditions of their respective open-source licenses. The specific open-source licenses for each component are indicated in the source code files and documentation of the Licensed Software.
+
+ b. Proprietary Neural Network Components: Notwithstanding the open-source license granted in Section 2(a) for open-source components, the Licensor retains all rights, title, and interest in and to the proprietary Neural Network Components developed and trained by the Licensor. The license granted herein does not include rights to use, copy, modify, distribute, or create derivative works of these proprietary Neural Network Components without obtaining express written permission from the Licensor.
+
+ c. The source code of the Licensed Software is available as open-source on this GitHub repository. However, the open-source license does not grant any rights to the weights, biases, or architecture of the proprietary neural network(s) used in the Licensed Software, which remain the exclusive property of the Licensor.
+
+ 3. Commercial Use and Licensing
+
+ a. The Licensor retains all ownership and commercial usage rights in the Licensed Software, including its neural network(s), regardless of the open-source availability of portions of the Licensed Software's source code. Under the open-source license, anyone may freely use, copy, modify, and distribute the source code of the Licensed Software for non-commercial purposes only. Similarly, the proprietary neural network weights, biases, and architecture may only be used for non-commercial purposes. Any commercial use, copying, modification, distribution, or creation of derivative works of the Licensed Software's source code or the proprietary neural network weights, biases, and architecture is strictly prohibited without express written permission from the Licensor.
+
+ b. Any Commercial Use of the Licensed Software, including its source code and neural network(s), requires a separate commercial license agreement with the Licensor. The terms and conditions of such commercial licensing, including the scope of use, fees, and delivery of the relevant components, shall be negotiated and agreed upon between the Licensor and the Licensee in writing. For commercial licensing inquiries, please contact the Licensor at [email protected].
+
+ 4. Intellectual Property Rights
+
+ a. The Licensor retains all intellectual property rights in the Licensed Software, including copyrights, patents, trademarks, trade secrets, and any other proprietary rights. The Licensee acknowledges that the Licensor is the sole owner of the Licensed Software and its components, including the proprietary Neural Network Components.
+
+ b. The Licensee agrees not to challenge or contest the Licensor's ownership or intellectual property rights in the Licensed Software, nor to assist or encourage any third party in doing so.
+
+ c. The Licensee shall promptly notify the Licensor of any actual or suspected infringement of the Licensor's intellectual property rights in the Licensed Software by third parties, and shall provide reasonable assistance to the Licensor in enforcing its rights against such infringement.
+
+ 5. Warranty Disclaimer
+
+ The Licensed Software is provided "as is", without warranty of any kind, either express or implied, including, but not limited to, the implied warranties of merchantability, fitness for a particular purpose, and non-infringement. The Licensor disclaims all warranties, express or implied, regarding the accuracy, reliability, completeness, or usefulness of the Licensed Software or its components. The Licensor does not warrant that the Licensed Software will be error-free, uninterrupted, or free from defects or vulnerabilities.
+
+ 6. Limitation of Liability
+
+ In no event shall the Licensor be liable for any claim, damages, or other liability, whether in an action of contract, tort, or otherwise, arising from, out of, or in connection with the Licensed Software or the use or other dealings in the Licensed Software. The Licensor shall not be liable for any direct, indirect, incidental, special, consequential, or exemplary damages, including but not limited to damages for loss of profits, goodwill, use, data, or other intangible losses, even if the Licensor has been advised of the possibility of such damages.
+
+ 7. Indemnification
+
+ The Licensee agrees to indemnify, defend, and hold harmless the Licensor and its affiliates, officers, directors, employees, and agents from and against any and all claims, liabilities, damages, losses, costs, expenses, or fees (including reasonable attorneys' fees) arising from or relating to the Licensee's use of the Licensed Software or any violation of this Agreement by the Licensee.
+
+ 8. Governing Law and Jurisdiction
+
+ This Agreement shall be governed by and construed in accordance with the laws of the State of New South Wales, Australia, without giving effect to any choice or conflict of law provision or rule. Any legal action or proceeding arising out of or relating to this Agreement shall be brought exclusively in the courts of New South Wales, Australia, and each party irrevocably submits to the jurisdiction of such courts.
+
+ 9. Entire Agreement
+
+ This Agreement constitutes the entire agreement between the parties concerning the subject matter hereof and supersedes all prior or contemporaneous communications, understandings, and agreements, whether written or oral, between the parties regarding such subject matter.
+
+ By accessing, using, or distributing the Licensed Software, the Licensee agrees to be bound by the terms and conditions of this Agreement. If the Licensee does not agree to the terms of this Agreement, the Licensee must not access, use, or distribute the Licensed Software.
+
+ For inquiries or to obtain permission for Commercial Use, please contact:
+
+ Dr. Jinjin Gu
+ SupPixel Pty Ltd
README.md CHANGED
@@ -1,13 +1,5 @@
- ---
- title: SUPIR Forge
- emoji: 📈
- colorFrom: red
- colorTo: red
- sdk: gradio
- sdk_version: 4.44.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SUPIR Backend
+
+ The backend code for [forge-space-SUPIR](https://github.com/Haoming02/forge-space-SUPIR), which is synced to [Haoming02/SUPIR-Forge](https://huggingface.co/spaces/Haoming02/SUPIR-Forge)
+
+ > There is no point in downloading this unless you are making a PR
SUPIR/__init__.py ADDED
File without changes
SUPIR/models/SUPIR_model.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ from sgm.models.diffusion import DiffusionEngine
3
+ from sgm.util import instantiate_from_config
4
+ import copy
5
+ from sgm.modules.distributions.distributions import DiagonalGaussianDistribution
6
+ import random
7
+ from SUPIR.utils.colorfix import wavelet_reconstruction, adaptive_instance_normalization
8
+ from pytorch_lightning import seed_everything
9
+ from torch.nn.functional import interpolate
10
+ from SUPIR.utils.tilevae import VAEHook
11
+
12
+ class SUPIRModel(DiffusionEngine):
13
+ def __init__(self, control_stage_config, ae_dtype='fp32', diffusion_dtype='fp32', p_p='', n_p='', *args, **kwargs):
14
+ super().__init__(*args, **kwargs)
15
+ control_model = instantiate_from_config(control_stage_config)
16
+ self.model.load_control_model(control_model)
17
+ self.first_stage_model.denoise_encoder = copy.deepcopy(self.first_stage_model.encoder)
18
+ self.sampler_config = kwargs['sampler_config']
19
+
20
+ assert (ae_dtype in ['fp32', 'fp16', 'bf16']) and (diffusion_dtype in ['fp32', 'fp16', 'bf16'])
21
+ if ae_dtype == 'fp32':
22
+ ae_dtype = torch.float32
23
+ elif ae_dtype == 'fp16':
24
+ raise RuntimeError('fp16 cause NaN in AE')
25
+ elif ae_dtype == 'bf16':
26
+ ae_dtype = torch.bfloat16
27
+
28
+ if diffusion_dtype == 'fp32':
29
+ diffusion_dtype = torch.float32
30
+ elif diffusion_dtype == 'fp16':
31
+ diffusion_dtype = torch.float16
32
+ elif diffusion_dtype == 'bf16':
33
+ diffusion_dtype = torch.bfloat16
34
+
35
+ self.ae_dtype = ae_dtype
36
+ self.model.dtype = diffusion_dtype
37
+
38
+ self.p_p = p_p
39
+ self.n_p = n_p
40
+
41
+ @torch.no_grad()
42
+ def encode_first_stage(self, x):
43
+ with torch.autocast("cuda", dtype=self.ae_dtype):
44
+ z = self.first_stage_model.encode(x)
45
+ z = self.scale_factor * z
46
+ return z
47
+
48
+ @torch.no_grad()
49
+ def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
50
+ with torch.autocast("cuda", dtype=self.ae_dtype):
51
+ h = self.first_stage_model.denoise_encoder(x)
52
+ moments = self.first_stage_model.quant_conv(h)
53
+ posterior = DiagonalGaussianDistribution(moments)
54
+ if use_sample:
55
+ z = posterior.sample()
56
+ else:
57
+ z = posterior.mode()
58
+ z = self.scale_factor * z
59
+ return z
60
+
61
+ @torch.no_grad()
62
+ def decode_first_stage(self, z):
63
+ z = 1.0 / self.scale_factor * z
64
+ with torch.autocast("cuda", dtype=self.ae_dtype):
65
+ out = self.first_stage_model.decode(z)
66
+ return out.float()
67
+
68
+ @torch.no_grad()
69
+ def batchify_denoise(self, x, is_stage1=False):
70
+ '''
71
+ [N, C, H, W], [-1, 1], RGB
72
+ '''
73
+ x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
74
+ return self.decode_first_stage(x)
75
+
76
+ @torch.no_grad()
77
+ def batchify_sample(self, x, p, p_p='default', n_p='default', num_steps=100, restoration_scale=4.0, s_churn=0, s_noise=1.003, cfg_scale=4.0, seed=-1,
78
+ num_samples=1, control_scale=1, color_fix_type='None', use_linear_CFG=False, use_linear_control_scale=False,
79
+ cfg_scale_start=1.0, control_scale_start=0.0, **kwargs):
80
+ '''
81
+ [N, C], [-1, 1], RGB
82
+ '''
83
+ assert len(x) == len(p)
84
+ assert color_fix_type in ['Wavelet', 'AdaIn', 'None']
85
+
86
+ N = len(x)
87
+ if num_samples > 1:
88
+ assert N == 1
89
+ N = num_samples
90
+ x = x.repeat(N, 1, 1, 1)
91
+ p = p * N
92
+
93
+ if p_p == 'default':
94
+ p_p = self.p_p
95
+ if n_p == 'default':
96
+ n_p = self.n_p
97
+
98
+ self.sampler_config.params.num_steps = num_steps
99
+ if use_linear_CFG:
100
+ self.sampler_config.params.guider_config.params.scale_min = cfg_scale
101
+ self.sampler_config.params.guider_config.params.scale = cfg_scale_start
102
+ else:
103
+ self.sampler_config.params.guider_config.params.scale_min = cfg_scale
104
+ self.sampler_config.params.guider_config.params.scale = cfg_scale
105
+ self.sampler_config.params.restore_cfg = restoration_scale
106
+ self.sampler_config.params.s_churn = s_churn
107
+ self.sampler_config.params.s_noise = s_noise
108
+ self.sampler = instantiate_from_config(self.sampler_config)
109
+
110
+ if seed == -1:
111
+ seed = random.randint(0, 65535)
112
+ seed_everything(seed)
113
+
114
+ _z = self.encode_first_stage_with_denoise(x, use_sample=False)
115
+ x_stage1 = self.decode_first_stage(_z)
116
+ z_stage1 = self.encode_first_stage(x_stage1)
117
+
118
+ c, uc = self.prepare_condition(_z, p, p_p, n_p, N)
119
+
120
+ denoiser = lambda input, sigma, c, control_scale: self.denoiser(
121
+ self.model, input, sigma, c, control_scale, **kwargs
122
+ )
123
+
124
+ noised_z = torch.randn_like(_z).to(_z.device)
125
+
126
+ _samples = self.sampler(denoiser, noised_z, cond=c, uc=uc, x_center=z_stage1, control_scale=control_scale,
127
+ use_linear_control_scale=use_linear_control_scale, control_scale_start=control_scale_start)
128
+ samples = self.decode_first_stage(_samples)
129
+ if color_fix_type == 'Wavelet':
130
+ samples = wavelet_reconstruction(samples, x_stage1)
131
+ elif color_fix_type == 'AdaIn':
132
+ samples = adaptive_instance_normalization(samples, x_stage1)
133
+ return samples
134
+
135
+ def init_tile_vae(self, encoder_tile_size=512, decoder_tile_size=64):
136
+ self.first_stage_model.denoise_encoder.original_forward = self.first_stage_model.denoise_encoder.forward
137
+ self.first_stage_model.encoder.original_forward = self.first_stage_model.encoder.forward
138
+ self.first_stage_model.decoder.original_forward = self.first_stage_model.decoder.forward
139
+ self.first_stage_model.denoise_encoder.forward = VAEHook(
140
+ self.first_stage_model.denoise_encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
141
+ fast_encoder=False, color_fix=False, to_gpu=True)
142
+ self.first_stage_model.encoder.forward = VAEHook(
143
+ self.first_stage_model.encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
144
+ fast_encoder=False, color_fix=False, to_gpu=True)
145
+ self.first_stage_model.decoder.forward = VAEHook(
146
+ self.first_stage_model.decoder, decoder_tile_size, is_decoder=True, fast_decoder=False,
147
+ fast_encoder=False, color_fix=False, to_gpu=True)
148
+
149
+ def prepare_condition(self, _z, p, p_p, n_p, N):
150
+ batch = {}
151
+ batch['original_size_as_tuple'] = torch.tensor([1024, 1024]).repeat(N, 1).to(_z.device)
152
+ batch['crop_coords_top_left'] = torch.tensor([0, 0]).repeat(N, 1).to(_z.device)
153
+ batch['target_size_as_tuple'] = torch.tensor([1024, 1024]).repeat(N, 1).to(_z.device)
154
+ batch['aesthetic_score'] = torch.tensor([9.0]).repeat(N, 1).to(_z.device)
155
+ batch['control'] = _z
156
+
157
+ batch_uc = copy.deepcopy(batch)
158
+ batch_uc['txt'] = [n_p for _ in p]
159
+
160
+ if not isinstance(p[0], list):
161
+ batch['txt'] = [''.join([_p, p_p]) for _p in p]
162
+ with torch.cuda.amp.autocast(dtype=self.ae_dtype):
163
+ c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
164
+ else:
165
+ assert len(p) == 1, 'Support bs=1 only for local prompt conditioning.'
166
+ p_tiles = p[0]
167
+ c = []
168
+ for i, p_tile in enumerate(p_tiles):
169
+ batch['txt'] = [''.join([p_tile, p_p])]
170
+ with torch.cuda.amp.autocast(dtype=self.ae_dtype):
171
+ if i == 0:
172
+ _c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
173
+ else:
174
+ _c, _ = self.conditioner.get_unconditional_conditioning(batch, None)
175
+ c.append(_c)
176
+ return c, uc
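For orientation, here is a minimal sketch of how this class is typically driven end to end. It relies only on the signatures visible in this commit (`create_SUPIR_model`, `PIL2Tensor` and `Tensor2PIL` appear in `SUPIR/util.py` further down, and `options/SUPIR_v0.yaml` is part of this diff); the checkpoint paths, input image and prompt are placeholders, not part of the repository.

```python
# Hedged usage sketch: only the signatures above and in SUPIR/util.py come from this
# commit; checkpoint paths, the image file and the prompt are placeholders.
from PIL import Image

from SUPIR.util import create_SUPIR_model, PIL2Tensor, Tensor2PIL

model, _sdxl_sd = create_SUPIR_model(
    "options/SUPIR_v0.yaml",
    SDXL_CKPT="checkpoints/sd_xl_base_1.0.safetensors",  # placeholder path
    SUPIR_CKPT="checkpoints/SUPIR-v0Q.ckpt",             # placeholder path
)
model = model.cuda().eval()
model.init_tile_vae(encoder_tile_size=512, decoder_tile_size=64)  # optional tiled VAE

lq = Image.open("input.png").convert("RGB")
x, h0, w0 = PIL2Tensor(lq, upsacle=2)      # the repo spells this keyword "upsacle"
x = x.unsqueeze(0).cuda()                  # [1, C, H, W] in [-1, 1]

samples = model.batchify_sample(
    x, ["a photo"],                        # one prompt per image in the batch
    num_steps=50, cfg_scale=4.0, color_fix_type="Wavelet", seed=-1,
)
Tensor2PIL(samples[0], h0, w0).save("output.png")
```

The Gradio front end `app.py` (file 17 in the list above) presumably drives this same interface behind the web UI.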
SUPIR/models/__init__.py ADDED
File without changes
SUPIR/modules/SUPIR_v0.py ADDED
@@ -0,0 +1,718 @@
1
+ # from einops._torch_specific import allow_ops_in_compiled_graph
2
+ # allow_ops_in_compiled_graph()
3
+ import einops
4
+ import torch
5
+ import torch as th
6
+ import torch.nn as nn
7
+ from einops import rearrange, repeat
8
+
9
+ from sgm.modules.diffusionmodules.util import (
10
+ avg_pool_nd,
11
+ checkpoint,
12
+ conv_nd,
13
+ linear,
14
+ normalization,
15
+ timestep_embedding,
16
+ zero_module,
17
+ )
18
+
19
+ from sgm.modules.diffusionmodules.openaimodel import Downsample, Upsample, UNetModel, Timestep, \
20
+ TimestepEmbedSequential, ResBlock, AttentionBlock, TimestepBlock
21
+ from sgm.modules.attention import SpatialTransformer, MemoryEfficientCrossAttention, CrossAttention
22
+ from sgm.util import default, log_txt_as_img, exists, instantiate_from_config
23
+ import re
24
+ import torch
25
+ from functools import partial
26
+
27
+
28
+ try:
29
+ import xformers
30
+ import xformers.ops
31
+ XFORMERS_IS_AVAILBLE = True
32
+ except:
33
+ XFORMERS_IS_AVAILBLE = False
34
+
35
+
36
+ # dummy replace
37
+ def convert_module_to_f16(x):
38
+ pass
39
+
40
+
41
+ def convert_module_to_f32(x):
42
+ pass
43
+
44
+
45
+ class ZeroConv(nn.Module):
46
+ def __init__(self, label_nc, norm_nc, mask=False):
47
+ super().__init__()
48
+ self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
49
+ self.mask = mask
50
+
51
+ def forward(self, c, h, h_ori=None):
52
+ # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
53
+ if not self.mask:
54
+ h = h + self.zero_conv(c)
55
+ else:
56
+ h = h + self.zero_conv(c) * torch.zeros_like(h)
57
+ if h_ori is not None:
58
+ h = th.cat([h_ori, h], dim=1)
59
+ return h
60
+
61
+
62
+ class ZeroSFT(nn.Module):
63
+ def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False):
64
+ super().__init__()
65
+
66
+ # param_free_norm_type = str(parsed.group(1))
67
+ ks = 3
68
+ pw = ks // 2
69
+
70
+ self.norm = norm
71
+ if self.norm:
72
+ self.param_free_norm = normalization(norm_nc + concat_channels)
73
+ else:
74
+ self.param_free_norm = nn.Identity()
75
+
76
+ nhidden = 128
77
+
78
+ self.mlp_shared = nn.Sequential(
79
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
80
+ nn.SiLU()
81
+ )
82
+ self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
83
+ self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
84
+ # self.zero_mul = nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)
85
+ # self.zero_add = nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)
86
+
87
+ self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
88
+ self.pre_concat = bool(concat_channels != 0)
89
+ self.mask = mask
90
+
91
+ def forward(self, c, h, h_ori=None, control_scale=1):
92
+ assert self.mask is False
93
+ if h_ori is not None and self.pre_concat:
94
+ h_raw = th.cat([h_ori, h], dim=1)
95
+ else:
96
+ h_raw = h
97
+
98
+ if self.mask:
99
+ h = h + self.zero_conv(c) * torch.zeros_like(h)
100
+ else:
101
+ h = h + self.zero_conv(c)
102
+ if h_ori is not None and self.pre_concat:
103
+ h = th.cat([h_ori, h], dim=1)
104
+ actv = self.mlp_shared(c)
105
+ gamma = self.zero_mul(actv)
106
+ beta = self.zero_add(actv)
107
+ if self.mask:
108
+ gamma = gamma * torch.zeros_like(gamma)
109
+ beta = beta * torch.zeros_like(beta)
110
+ h = self.param_free_norm(h) * (gamma + 1) + beta
111
+ if h_ori is not None and not self.pre_concat:
112
+ h = th.cat([h_ori, h], dim=1)
113
+ return h * control_scale + h_raw * (1 - control_scale)
114
+
115
+
116
+ class ZeroCrossAttn(nn.Module):
117
+ ATTENTION_MODES = {
118
+ "softmax": CrossAttention, # vanilla attention
119
+ "softmax-xformers": MemoryEfficientCrossAttention
120
+ }
121
+
122
+ def __init__(self, context_dim, query_dim, zero_out=True, mask=False):
123
+ super().__init__()
124
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
125
+ assert attn_mode in self.ATTENTION_MODES
126
+ attn_cls = self.ATTENTION_MODES[attn_mode]
127
+ self.attn = attn_cls(query_dim=query_dim, context_dim=context_dim, heads=query_dim//64, dim_head=64)
128
+ self.norm1 = normalization(query_dim)
129
+ self.norm2 = normalization(context_dim)
130
+
131
+ self.mask = mask
132
+
133
+ # if zero_out:
134
+ # # for p in self.attn.to_out.parameters():
135
+ # # p.detach().zero_()
136
+ # self.attn.to_out = zero_module(self.attn.to_out)
137
+
138
+ def forward(self, context, x, control_scale=1):
139
+ assert self.mask is False
140
+ x_in = x
141
+ x = self.norm1(x)
142
+ context = self.norm2(context)
143
+ b, c, h, w = x.shape
144
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
145
+ context = rearrange(context, 'b c h w -> b (h w) c').contiguous()
146
+ x = self.attn(x, context)
147
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
148
+ if self.mask:
149
+ x = x * torch.zeros_like(x)
150
+ x = x_in + x * control_scale
151
+
152
+ return x
153
+
154
+
155
+ class GLVControl(nn.Module):
156
+ def __init__(
157
+ self,
158
+ in_channels,
159
+ model_channels,
160
+ out_channels,
161
+ num_res_blocks,
162
+ attention_resolutions,
163
+ dropout=0,
164
+ channel_mult=(1, 2, 4, 8),
165
+ conv_resample=True,
166
+ dims=2,
167
+ num_classes=None,
168
+ use_checkpoint=False,
169
+ use_fp16=False,
170
+ num_heads=-1,
171
+ num_head_channels=-1,
172
+ num_heads_upsample=-1,
173
+ use_scale_shift_norm=False,
174
+ resblock_updown=False,
175
+ use_new_attention_order=False,
176
+ use_spatial_transformer=False, # custom transformer support
177
+ transformer_depth=1, # custom transformer support
178
+ context_dim=None, # custom transformer support
179
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
180
+ legacy=True,
181
+ disable_self_attentions=None,
182
+ num_attention_blocks=None,
183
+ disable_middle_self_attn=False,
184
+ use_linear_in_transformer=False,
185
+ spatial_transformer_attn_type="softmax",
186
+ adm_in_channels=None,
187
+ use_fairscale_checkpoint=False,
188
+ offload_to_cpu=False,
189
+ transformer_depth_middle=None,
190
+ input_upscale=1,
191
+ ):
192
+ super().__init__()
193
+ from omegaconf.listconfig import ListConfig
194
+
195
+ if use_spatial_transformer:
196
+ assert (
197
+ context_dim is not None
198
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
199
+
200
+ if context_dim is not None:
201
+ assert (
202
+ use_spatial_transformer
203
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
204
+ if type(context_dim) == ListConfig:
205
+ context_dim = list(context_dim)
206
+
207
+ if num_heads_upsample == -1:
208
+ num_heads_upsample = num_heads
209
+
210
+ if num_heads == -1:
211
+ assert (
212
+ num_head_channels != -1
213
+ ), "Either num_heads or num_head_channels has to be set"
214
+
215
+ if num_head_channels == -1:
216
+ assert (
217
+ num_heads != -1
218
+ ), "Either num_heads or num_head_channels has to be set"
219
+
220
+ self.in_channels = in_channels
221
+ self.model_channels = model_channels
222
+ self.out_channels = out_channels
223
+ if isinstance(transformer_depth, int):
224
+ transformer_depth = len(channel_mult) * [transformer_depth]
225
+ elif isinstance(transformer_depth, ListConfig):
226
+ transformer_depth = list(transformer_depth)
227
+ transformer_depth_middle = default(
228
+ transformer_depth_middle, transformer_depth[-1]
229
+ )
230
+
231
+ if isinstance(num_res_blocks, int):
232
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
233
+ else:
234
+ if len(num_res_blocks) != len(channel_mult):
235
+ raise ValueError(
236
+ "provide num_res_blocks either as an int (globally constant) or "
237
+ "as a list/tuple (per-level) with the same length as channel_mult"
238
+ )
239
+ self.num_res_blocks = num_res_blocks
240
+ # self.num_res_blocks = num_res_blocks
241
+ if disable_self_attentions is not None:
242
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
243
+ assert len(disable_self_attentions) == len(channel_mult)
244
+ if num_attention_blocks is not None:
245
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
246
+ assert all(
247
+ map(
248
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
249
+ range(len(num_attention_blocks)),
250
+ )
251
+ )
252
+ print(
253
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
254
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
255
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
256
+ f"attention will still not be set."
257
+ ) # todo: convert to warning
258
+
259
+ self.attention_resolutions = attention_resolutions
260
+ self.dropout = dropout
261
+ self.channel_mult = channel_mult
262
+ self.conv_resample = conv_resample
263
+ self.num_classes = num_classes
264
+ self.use_checkpoint = use_checkpoint
265
+ if use_fp16:
266
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
267
+ # self.dtype = th.float16 if use_fp16 else th.float32
268
+ self.num_heads = num_heads
269
+ self.num_head_channels = num_head_channels
270
+ self.num_heads_upsample = num_heads_upsample
271
+ self.predict_codebook_ids = n_embed is not None
272
+
273
+ assert use_fairscale_checkpoint != use_checkpoint or not (
274
+ use_checkpoint or use_fairscale_checkpoint
275
+ )
276
+
277
+ self.use_fairscale_checkpoint = False
278
+ checkpoint_wrapper_fn = (
279
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
280
+ if self.use_fairscale_checkpoint
281
+ else lambda x: x
282
+ )
283
+
284
+ time_embed_dim = model_channels * 4
285
+ self.time_embed = checkpoint_wrapper_fn(
286
+ nn.Sequential(
287
+ linear(model_channels, time_embed_dim),
288
+ nn.SiLU(),
289
+ linear(time_embed_dim, time_embed_dim),
290
+ )
291
+ )
292
+
293
+ if self.num_classes is not None:
294
+ if isinstance(self.num_classes, int):
295
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
296
+ elif self.num_classes == "continuous":
297
+ print("setting up linear c_adm embedding layer")
298
+ self.label_emb = nn.Linear(1, time_embed_dim)
299
+ elif self.num_classes == "timestep":
300
+ self.label_emb = checkpoint_wrapper_fn(
301
+ nn.Sequential(
302
+ Timestep(model_channels),
303
+ nn.Sequential(
304
+ linear(model_channels, time_embed_dim),
305
+ nn.SiLU(),
306
+ linear(time_embed_dim, time_embed_dim),
307
+ ),
308
+ )
309
+ )
310
+ elif self.num_classes == "sequential":
311
+ assert adm_in_channels is not None
312
+ self.label_emb = nn.Sequential(
313
+ nn.Sequential(
314
+ linear(adm_in_channels, time_embed_dim),
315
+ nn.SiLU(),
316
+ linear(time_embed_dim, time_embed_dim),
317
+ )
318
+ )
319
+ else:
320
+ raise ValueError()
321
+
322
+ self.input_blocks = nn.ModuleList(
323
+ [
324
+ TimestepEmbedSequential(
325
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
326
+ )
327
+ ]
328
+ )
329
+ self._feature_size = model_channels
330
+ input_block_chans = [model_channels]
331
+ ch = model_channels
332
+ ds = 1
333
+ for level, mult in enumerate(channel_mult):
334
+ for nr in range(self.num_res_blocks[level]):
335
+ layers = [
336
+ checkpoint_wrapper_fn(
337
+ ResBlock(
338
+ ch,
339
+ time_embed_dim,
340
+ dropout,
341
+ out_channels=mult * model_channels,
342
+ dims=dims,
343
+ use_checkpoint=use_checkpoint,
344
+ use_scale_shift_norm=use_scale_shift_norm,
345
+ )
346
+ )
347
+ ]
348
+ ch = mult * model_channels
349
+ if ds in attention_resolutions:
350
+ if num_head_channels == -1:
351
+ dim_head = ch // num_heads
352
+ else:
353
+ num_heads = ch // num_head_channels
354
+ dim_head = num_head_channels
355
+ if legacy:
356
+ # num_heads = 1
357
+ dim_head = (
358
+ ch // num_heads
359
+ if use_spatial_transformer
360
+ else num_head_channels
361
+ )
362
+ if exists(disable_self_attentions):
363
+ disabled_sa = disable_self_attentions[level]
364
+ else:
365
+ disabled_sa = False
366
+
367
+ if (
368
+ not exists(num_attention_blocks)
369
+ or nr < num_attention_blocks[level]
370
+ ):
371
+ layers.append(
372
+ checkpoint_wrapper_fn(
373
+ AttentionBlock(
374
+ ch,
375
+ use_checkpoint=use_checkpoint,
376
+ num_heads=num_heads,
377
+ num_head_channels=dim_head,
378
+ use_new_attention_order=use_new_attention_order,
379
+ )
380
+ )
381
+ if not use_spatial_transformer
382
+ else checkpoint_wrapper_fn(
383
+ SpatialTransformer(
384
+ ch,
385
+ num_heads,
386
+ dim_head,
387
+ depth=transformer_depth[level],
388
+ context_dim=context_dim,
389
+ disable_self_attn=disabled_sa,
390
+ use_linear=use_linear_in_transformer,
391
+ attn_type=spatial_transformer_attn_type,
392
+ use_checkpoint=use_checkpoint,
393
+ )
394
+ )
395
+ )
396
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
397
+ self._feature_size += ch
398
+ input_block_chans.append(ch)
399
+ if level != len(channel_mult) - 1:
400
+ out_ch = ch
401
+ self.input_blocks.append(
402
+ TimestepEmbedSequential(
403
+ checkpoint_wrapper_fn(
404
+ ResBlock(
405
+ ch,
406
+ time_embed_dim,
407
+ dropout,
408
+ out_channels=out_ch,
409
+ dims=dims,
410
+ use_checkpoint=use_checkpoint,
411
+ use_scale_shift_norm=use_scale_shift_norm,
412
+ down=True,
413
+ )
414
+ )
415
+ if resblock_updown
416
+ else Downsample(
417
+ ch, conv_resample, dims=dims, out_channels=out_ch
418
+ )
419
+ )
420
+ )
421
+ ch = out_ch
422
+ input_block_chans.append(ch)
423
+ ds *= 2
424
+ self._feature_size += ch
425
+
426
+ if num_head_channels == -1:
427
+ dim_head = ch // num_heads
428
+ else:
429
+ num_heads = ch // num_head_channels
430
+ dim_head = num_head_channels
431
+ if legacy:
432
+ # num_heads = 1
433
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
434
+ self.middle_block = TimestepEmbedSequential(
435
+ checkpoint_wrapper_fn(
436
+ ResBlock(
437
+ ch,
438
+ time_embed_dim,
439
+ dropout,
440
+ dims=dims,
441
+ use_checkpoint=use_checkpoint,
442
+ use_scale_shift_norm=use_scale_shift_norm,
443
+ )
444
+ ),
445
+ checkpoint_wrapper_fn(
446
+ AttentionBlock(
447
+ ch,
448
+ use_checkpoint=use_checkpoint,
449
+ num_heads=num_heads,
450
+ num_head_channels=dim_head,
451
+ use_new_attention_order=use_new_attention_order,
452
+ )
453
+ )
454
+ if not use_spatial_transformer
455
+ else checkpoint_wrapper_fn(
456
+ SpatialTransformer( # always uses a self-attn
457
+ ch,
458
+ num_heads,
459
+ dim_head,
460
+ depth=transformer_depth_middle,
461
+ context_dim=context_dim,
462
+ disable_self_attn=disable_middle_self_attn,
463
+ use_linear=use_linear_in_transformer,
464
+ attn_type=spatial_transformer_attn_type,
465
+ use_checkpoint=use_checkpoint,
466
+ )
467
+ ),
468
+ checkpoint_wrapper_fn(
469
+ ResBlock(
470
+ ch,
471
+ time_embed_dim,
472
+ dropout,
473
+ dims=dims,
474
+ use_checkpoint=use_checkpoint,
475
+ use_scale_shift_norm=use_scale_shift_norm,
476
+ )
477
+ ),
478
+ )
479
+
480
+ self.input_upscale = input_upscale
481
+ self.input_hint_block = TimestepEmbedSequential(
482
+ zero_module(conv_nd(dims, in_channels, model_channels, 3, padding=1))
483
+ )
484
+
485
+ def convert_to_fp16(self):
486
+ """
487
+ Convert the torso of the model to float16.
488
+ """
489
+ self.input_blocks.apply(convert_module_to_f16)
490
+ self.middle_block.apply(convert_module_to_f16)
491
+
492
+ def convert_to_fp32(self):
493
+ """
494
+ Convert the torso of the model to float32.
495
+ """
496
+ self.input_blocks.apply(convert_module_to_f32)
497
+ self.middle_block.apply(convert_module_to_f32)
498
+
499
+ def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
500
+ # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
501
+ # x = x.to(torch.float32)
502
+ # timesteps = timesteps.to(torch.float32)
503
+ # xt = xt.to(torch.float32)
504
+ # context = context.to(torch.float32)
505
+ # y = y.to(torch.float32)
506
+ # print(x.dtype)
507
+ xt, context, y = xt.to(x.dtype), context.to(x.dtype), y.to(x.dtype)
508
+
509
+ if self.input_upscale != 1:
510
+ x = nn.functional.interpolate(x, scale_factor=self.input_upscale, mode='bilinear', antialias=True)
511
+ assert (y is not None) == (
512
+ self.num_classes is not None
513
+ ), "must specify y if and only if the model is class-conditional"
514
+ hs = []
515
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
516
+ # import pdb
517
+ # pdb.set_trace()
518
+ emb = self.time_embed(t_emb)
519
+
520
+ if self.num_classes is not None:
521
+ assert y.shape[0] == xt.shape[0]
522
+ emb = emb + self.label_emb(y)
523
+
524
+ guided_hint = self.input_hint_block(x, emb, context)
525
+
526
+ # h = x.type(self.dtype)
527
+ h = xt
528
+ for module in self.input_blocks:
529
+ if guided_hint is not None:
530
+ h = module(h, emb, context)
531
+ h += guided_hint
532
+ guided_hint = None
533
+ else:
534
+ h = module(h, emb, context)
535
+ hs.append(h)
536
+ # print(module)
537
+ # print(h.shape)
538
+ h = self.middle_block(h, emb, context)
539
+ hs.append(h)
540
+ return hs
541
+
542
+
543
+ class LightGLVUNet(UNetModel):
544
+ def __init__(self, mode='', project_type='ZeroSFT', project_channel_scale=1,
545
+ *args, **kwargs):
546
+ super().__init__(*args, **kwargs)
547
+ if mode == 'XL-base':
548
+ cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
549
+ project_channels = [160] * 4 + [320] * 3 + [640] * 3
550
+ concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
551
+ cross_attn_insert_idx = [6, 3]
552
+ self.progressive_mask_nums = [0, 3, 7, 11]
553
+ elif mode == 'XL-refine':
554
+ cond_output_channels = [384] * 4 + [768] * 3 + [1536] * 6
555
+ project_channels = [192] * 4 + [384] * 3 + [768] * 6
556
+ concat_channels = [384] * 2 + [768] * 3 + [1536] * 7 + [0]
557
+ cross_attn_insert_idx = [9, 6, 3]
558
+ self.progressive_mask_nums = [0, 3, 6, 10, 14]
559
+ else:
560
+ raise NotImplementedError
561
+
562
+ project_channels = [int(c * project_channel_scale) for c in project_channels]
563
+
564
+ self.project_modules = nn.ModuleList()
565
+ for i in range(len(cond_output_channels)):
566
+ # if i == len(cond_output_channels) - 1:
567
+ # _project_type = 'ZeroCrossAttn'
568
+ # else:
569
+ # _project_type = project_type
570
+ _project_type = project_type
571
+ if _project_type == 'ZeroSFT':
572
+ self.project_modules.append(ZeroSFT(project_channels[i], cond_output_channels[i],
573
+ concat_channels=concat_channels[i]))
574
+ elif _project_type == 'ZeroCrossAttn':
575
+ self.project_modules.append(ZeroCrossAttn(cond_output_channels[i], project_channels[i]))
576
+ else:
577
+ raise NotImplementedError
578
+
579
+ for i in cross_attn_insert_idx:
580
+ self.project_modules.insert(i, ZeroCrossAttn(cond_output_channels[i], concat_channels[i]))
581
+ # print(self.project_modules[i])
582
+
583
+ def step_progressive_mask(self):
584
+ if len(self.progressive_mask_nums) > 0:
585
+ mask_num = self.progressive_mask_nums.pop()
586
+ for i in range(len(self.project_modules)):
587
+ if i < mask_num:
588
+ self.project_modules[i].mask = True
589
+ else:
590
+ self.project_modules[i].mask = False
591
+ return
592
+ # print(f'step_progressive_mask, current masked layers: {mask_num}')
593
+ else:
594
+ return
595
+ # print('step_progressive_mask, no more masked layers')
596
+ # for i in range(len(self.project_modules)):
597
+ # print(self.project_modules[i].mask)
598
+
599
+
600
+ def forward(self, x, timesteps=None, context=None, y=None, control=None, control_scale=1, **kwargs):
601
+ """
602
+ Apply the model to an input batch.
603
+ :param x: an [N x C x ...] Tensor of inputs.
604
+ :param timesteps: a 1-D batch of timesteps.
605
+ :param context: conditioning plugged in via crossattn
606
+ :param y: an [N] Tensor of labels, if class-conditional.
607
+ :return: an [N x C x ...] Tensor of outputs.
608
+ """
609
+ assert (y is not None) == (
610
+ self.num_classes is not None
611
+ ), "must specify y if and only if the model is class-conditional"
612
+ hs = []
613
+
614
+ _dtype = control[0].dtype
615
+ x, context, y = x.to(_dtype), context.to(_dtype), y.to(_dtype)
616
+
617
+ with torch.no_grad():
618
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
619
+ emb = self.time_embed(t_emb)
620
+
621
+ if self.num_classes is not None:
622
+ assert y.shape[0] == x.shape[0]
623
+ emb = emb + self.label_emb(y)
624
+
625
+ # h = x.type(self.dtype)
626
+ h = x
627
+ for module in self.input_blocks:
628
+ h = module(h, emb, context)
629
+ hs.append(h)
630
+
631
+ adapter_idx = len(self.project_modules) - 1
632
+ control_idx = len(control) - 1
633
+ h = self.middle_block(h, emb, context)
634
+ h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
635
+ adapter_idx -= 1
636
+ control_idx -= 1
637
+
638
+ for i, module in enumerate(self.output_blocks):
639
+ _h = hs.pop()
640
+ h = self.project_modules[adapter_idx](control[control_idx], _h, h, control_scale=control_scale)
641
+ adapter_idx -= 1
642
+ # h = th.cat([h, _h], dim=1)
643
+ if len(module) == 3:
644
+ assert isinstance(module[2], Upsample)
645
+ for layer in module[:2]:
646
+ if isinstance(layer, TimestepBlock):
647
+ h = layer(h, emb)
648
+ elif isinstance(layer, SpatialTransformer):
649
+ h = layer(h, context)
650
+ else:
651
+ h = layer(h)
652
+ # print('cross_attn_here')
653
+ h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
654
+ adapter_idx -= 1
655
+ h = module[2](h)
656
+ else:
657
+ h = module(h, emb, context)
658
+ control_idx -= 1
659
+ # print(module)
660
+ # print(h.shape)
661
+
662
+ h = h.type(x.dtype)
663
+ if self.predict_codebook_ids:
664
+ assert False, "not supported anymore. what the f*** are you doing?"
665
+ else:
666
+ return self.out(h)
667
+
668
+ if __name__ == '__main__':
669
+ from omegaconf import OmegaConf
670
+
671
+ # refiner
672
+ # opt = OmegaConf.load('../../options/train/debug_p2_xl.yaml')
673
+ #
674
+ # model = instantiate_from_config(opt.model.params.control_stage_config)
675
+ # hint = model(torch.randn([1, 4, 64, 64]), torch.randn([1]), torch.randn([1, 4, 64, 64]))
676
+ # hint = [h.cuda() for h in hint]
677
+ # print(sum(map(lambda hint: hint.numel(), model.parameters())))
678
+ #
679
+ # unet = instantiate_from_config(opt.model.params.network_config)
680
+ # unet = unet.cuda()
681
+ #
682
+ # _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 1280]).cuda(),
683
+ # torch.randn([1, 2560]).cuda(), hint)
684
+ # print(sum(map(lambda _output: _output.numel(), unet.parameters())))
685
+
686
+ # base
687
+ with torch.no_grad():
688
+ opt = OmegaConf.load('../../options/dev/SUPIR_tmp.yaml')
689
+
690
+ model = instantiate_from_config(opt.model.params.control_stage_config)
691
+ model = model.cuda()
692
+
693
+ hint = model(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1, 77, 2048]).cuda(),
694
+ torch.randn([1, 2816]).cuda())
695
+
696
+ # for h in hint:
697
+ # print(h.shape)
698
+
699
+ unet = instantiate_from_config(opt.model.params.network_config)
700
+ unet = unet.cuda()
701
+ _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 2048]).cuda(),
702
+ torch.randn([1, 2816]).cuda(), hint)
703
+
704
+
705
+ # model = instantiate_from_config(opt.model.params.control_stage_config)
706
+ # model = model.cuda()
707
+ # # hint = model(torch.randn([1, 4, 64, 64]), torch.randn([1]), torch.randn([1, 4, 64, 64]))
708
+ # hint = model(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1, 77, 1280]).cuda(),
709
+ # torch.randn([1, 2560]).cuda())
710
+ # # hint = [h.cuda() for h in hint]
711
+ #
712
+ # for h in hint:
713
+ # print(h.shape)
714
+ #
715
+ # unet = instantiate_from_config(opt.model.params.network_config)
716
+ # unet = unet.cuda()
717
+ # _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 1280]).cuda(),
718
+ # torch.randn([1, 2560]).cuda(), hint)
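To make the control-injection path easier to follow, below is a simplified, standalone restatement of the `ZeroSFT` forward pass defined above. It is an illustration only: `nn.GroupNorm` stands in for sgm's `normalization()` helper, and the optional `h_ori` concatenation and masking branches are omitted.

```python
# Simplified restatement of ZeroSFT (assumption: GroupNorm(32, ...) approximates sgm's
# normalization(); concat and mask branches dropped). Not the module the repo imports.
import torch
import torch.nn as nn

class TinyZeroSFT(nn.Module):
    def __init__(self, label_nc: int, norm_nc: int, nhidden: int = 128):
        super().__init__()
        self.param_free_norm = nn.GroupNorm(32, norm_nc)
        self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, 3, padding=1), nn.SiLU())
        # zero-initialised convolutions, mirroring the zero_module() calls above
        self.zero_mul = nn.Conv2d(nhidden, norm_nc, 3, padding=1)
        self.zero_add = nn.Conv2d(nhidden, norm_nc, 3, padding=1)
        self.zero_conv = nn.Conv2d(label_nc, norm_nc, 1)
        for m in (self.zero_mul, self.zero_add, self.zero_conv):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, c, h, control_scale: float = 1.0):
        h_raw = h
        h = h + self.zero_conv(c)                          # additive zero-conv injection
        actv = self.mlp_shared(c)
        gamma, beta = self.zero_mul(actv), self.zero_add(actv)
        h = self.param_free_norm(h) * (gamma + 1) + beta   # SFT-style modulation
        return h * control_scale + h_raw * (1 - control_scale)

sft = TinyZeroSFT(label_nc=64, norm_nc=64)
out = sft(torch.randn(1, 64, 32, 32), torch.randn(1, 64, 32, 32), control_scale=0.7)
print(out.shape)  # torch.Size([1, 64, 32, 32])
```

At `control_scale=0` the block returns the UNet features unchanged; at 1 it returns the fully modulated features, which is the knob `LightGLVUNet.forward` passes through as `control_scale`.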
SUPIR/modules/__init__.py ADDED
@@ -0,0 +1,11 @@
+ SDXL_BASE_CHANNEL_DICT = {
+     'cond_output_channels': [320] * 4 + [640] * 3 + [1280] * 3,
+     'project_channels': [160] * 4 + [320] * 3 + [640] * 3,
+     'concat_channels': [320] * 2 + [640] * 3 + [1280] * 4 + [0]
+ }
+
+ SDXL_REFINE_CHANNEL_DICT = {
+     'cond_output_channels': [384] * 4 + [768] * 3 + [1536] * 6,
+     'project_channels': [192] * 4 + [384] * 3 + [768] * 6,
+     'concat_channels': [384] * 2 + [768] * 3 + [1536] * 7 + [0]
+ }
SUPIR/util.py ADDED
@@ -0,0 +1,190 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ from torch.nn.functional import interpolate
7
+ from omegaconf import OmegaConf
8
+ from sgm.util import instantiate_from_config
9
+
10
+
11
+ def get_state_dict(d):
12
+ return d.get("state_dict", d)
13
+
14
+
15
+ def load_state_dict(ckpt_path, location="cpu"):
16
+ _, extension = os.path.splitext(ckpt_path)
17
+ if extension.lower() == ".safetensors":
18
+ import safetensors.torch
19
+
20
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
21
+ else:
22
+ state_dict = get_state_dict(
23
+ torch.load(ckpt_path, map_location=torch.device(location))
24
+ )
25
+ state_dict = get_state_dict(state_dict)
26
+ print(f'Loaded state_dict from:\n"{ckpt_path}"')
27
+ return state_dict
28
+
29
+
30
+ def create_model(config_path):
31
+ config = OmegaConf.load(config_path)
32
+ model = instantiate_from_config(config.model).cpu()
33
+ print(f'Loaded model config from:\n"{config_path}"')
34
+ return model
35
+
36
+
37
+ def create_SUPIR_model(config_path, SDXL_CKPT, SUPIR_CKPT):
38
+ config = OmegaConf.load(config_path)
39
+ model = instantiate_from_config(config.model).cpu()
40
+ print(f'Loaded model config from:\n"{config_path}"')
41
+
42
+ sdxl_sd = load_state_dict(SDXL_CKPT)
43
+ model.load_state_dict(sdxl_sd, strict=False)
44
+ model.load_state_dict(load_state_dict(SUPIR_CKPT), strict=False)
45
+
46
+ return model, sdxl_sd
47
+
48
+
49
+ def load_QF_ckpt(config_path):
50
+ config = OmegaConf.load(config_path)
51
+ ckpt_F = torch.load(config.SUPIR_CKPT_F, map_location="cpu")
52
+ ckpt_Q = torch.load(config.SUPIR_CKPT_Q, map_location="cpu")
53
+ return ckpt_Q, ckpt_F
54
+
55
+
56
+ def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
57
+ """
58
+ PIL.Image -> Tensor[C, H, W], RGB, [-1, 1]
59
+ """
60
+ # size
61
+ w, h = img.size
62
+ w *= upsacle
63
+ h *= upsacle
64
+ w0, h0 = round(w), round(h)
65
+ if min(w, h) < min_size:
66
+ _upsacle = min_size / min(w, h)
67
+ w *= _upsacle
68
+ h *= _upsacle
69
+ if fix_resize is not None:
70
+ _upsacle = fix_resize / min(w, h)
71
+ w *= _upsacle
72
+ h *= _upsacle
73
+ w0, h0 = round(w), round(h)
74
+ w = int(np.round(w / 64.0)) * 64
75
+ h = int(np.round(h / 64.0)) * 64
76
+ x = img.resize((w, h), Image.BICUBIC)
77
+ x = np.array(x).round().clip(0, 255).astype(np.uint8)
78
+ x = x / 255 * 2 - 1
79
+ x = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1)
80
+ return x, h0, w0
81
+
82
+
83
+ def Tensor2PIL(x, h0, w0):
84
+ """
85
+ Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
86
+ """
87
+ x = x.unsqueeze(0)
88
+ x = interpolate(x, size=(h0, w0), mode="bicubic")
89
+ x = (
90
+ (x.squeeze(0).permute(1, 2, 0) * 127.5 + 127.5)
91
+ .cpu()
92
+ .numpy()
93
+ .clip(0, 255)
94
+ .astype(np.uint8)
95
+ )
96
+ return Image.fromarray(x)
97
+
98
+
99
+ def HWC3(x):
100
+ assert x.dtype == np.uint8
101
+ if x.ndim == 2:
102
+ x = x[:, :, None]
103
+ assert x.ndim == 3
104
+ H, W, C = x.shape
105
+ assert C == 1 or C == 3 or C == 4
106
+ if C == 3:
107
+ return x
108
+ if C == 1:
109
+ return np.concatenate([x, x, x], axis=2)
110
+ if C == 4:
111
+ color = x[:, :, 0:3].astype(np.float32)
112
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
113
+ y = color * alpha + 255.0 * (1.0 - alpha)
114
+ y = y.clip(0, 255).astype(np.uint8)
115
+ return y
116
+
117
+
118
+ def upscale_image(input_image, upscale, min_size=None, unit_resolution=64):
119
+ H, W, C = input_image.shape
120
+ H = float(H)
121
+ W = float(W)
122
+ H *= upscale
123
+ W *= upscale
124
+ if min_size is not None:
125
+ if min(H, W) < min_size:
126
+ _upsacle = min_size / min(W, H)
127
+ W *= _upsacle
128
+ H *= _upsacle
129
+ H = int(np.round(H / unit_resolution)) * unit_resolution
130
+ W = int(np.round(W / unit_resolution)) * unit_resolution
131
+ img = cv2.resize(
132
+ input_image,
133
+ (W, H),
134
+ interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA,
135
+ )
136
+ img = img.round().clip(0, 255).astype(np.uint8)
137
+ return img
138
+
139
+
140
+ def fix_resize(input_image, size=512, unit_resolution=64):
141
+ H, W, C = input_image.shape
142
+ H = float(H)
143
+ W = float(W)
144
+ upscale = size / min(H, W)
145
+ H *= upscale
146
+ W *= upscale
147
+ H = int(np.round(H / unit_resolution)) * unit_resolution
148
+ W = int(np.round(W / unit_resolution)) * unit_resolution
149
+ img = cv2.resize(
150
+ input_image,
151
+ (W, H),
152
+ interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA,
153
+ )
154
+ img = img.round().clip(0, 255).astype(np.uint8)
155
+ return img
156
+
157
+
158
+ def Numpy2Tensor(img):
159
+ """
160
+ np.array[H, w, C] [0, 255] -> Tensor[C, H, W], RGB, [-1, 1]
161
+ """
162
+ # size
163
+ img = np.array(img) / 255 * 2 - 1
164
+ img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
165
+ return img
166
+
167
+
168
+ def Tensor2Numpy(x, h0=None, w0=None):
169
+ """
170
+ Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
171
+ """
172
+ if h0 is not None and w0 is not None:
173
+ x = x.unsqueeze(0)
174
+ x = interpolate(x, size=(h0, w0), mode="bicubic")
175
+ x = x.squeeze(0)
176
+ x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
177
+ return x
178
+
179
+
180
+ def convert_dtype(dtype_str):
181
+ if dtype_str == "fp32":
182
+ return torch.float32
183
+ elif dtype_str == "fp16":
184
+ return torch.float16
185
+ elif dtype_str == "bf16":
186
+ return torch.bfloat16
187
+ elif dtype_str == "fp8":
188
+ return torch.float8_e4m3fn
189
+ else:
190
+ raise NotImplementedError
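A short round-trip sketch for the NumPy/PIL helpers above; the file name and the upscale factor are placeholders, and only the function signatures are taken from this file.

```python
# Round-trip sketch for SUPIR/util.py helpers. "input.png" and upscale=2 are placeholders.
import numpy as np
from PIL import Image

from SUPIR.util import HWC3, Numpy2Tensor, Tensor2Numpy, convert_dtype, upscale_image

img = np.asarray(Image.open("input.png"))  # uint8; may be H x W (grayscale) or H x W x 4 (RGBA)
img = HWC3(img)                            # normalise to 3-channel uint8
img = upscale_image(img, upscale=2, min_size=1024, unit_resolution=64)

x = Numpy2Tensor(img)                      # float32 tensor [C, H, W] in [-1, 1]
back = Tensor2Numpy(x)                     # uint8 [H, W, 3] again (no resize without h0/w0)
assert back.shape == img.shape

print(convert_dtype("bf16"))               # torch.bfloat16
```

Note that both `upscale_image` and `PIL2Tensor` snap the working resolution to multiples of `unit_resolution` (64) before resizing.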
SUPIR/utils/__init__.py ADDED
File without changes
SUPIR/utils/colorfix.py ADDED
@@ -0,0 +1,120 @@
1
+ '''
2
+ # --------------------------------------------------------------------------------
3
+ # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
4
+ # --------------------------------------------------------------------------------
5
+ '''
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torch import Tensor
10
+ from torch.nn import functional as F
11
+
12
+ from torchvision.transforms import ToTensor, ToPILImage
13
+
14
+ def adain_color_fix(target: Image, source: Image):
15
+ # Convert images to tensors
16
+ to_tensor = ToTensor()
17
+ target_tensor = to_tensor(target).unsqueeze(0)
18
+ source_tensor = to_tensor(source).unsqueeze(0)
19
+
20
+ # Apply adaptive instance normalization
21
+ result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
22
+
23
+ # Convert tensor back to image
24
+ to_image = ToPILImage()
25
+ result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
26
+
27
+ return result_image
28
+
29
+ def wavelet_color_fix(target: Image, source: Image):
30
+ # Convert images to tensors
31
+ to_tensor = ToTensor()
32
+ target_tensor = to_tensor(target).unsqueeze(0)
33
+ source_tensor = to_tensor(source).unsqueeze(0)
34
+
35
+ # Apply wavelet reconstruction
36
+ result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
37
+
38
+ # Convert tensor back to image
39
+ to_image = ToPILImage()
40
+ result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
41
+
42
+ return result_image
43
+
44
+ def calc_mean_std(feat: Tensor, eps=1e-5):
45
+ """Calculate mean and std for adaptive_instance_normalization.
46
+ Args:
47
+ feat (Tensor): 4D tensor.
48
+ eps (float): A small value added to the variance to avoid
49
+ divide-by-zero. Default: 1e-5.
50
+ """
51
+ size = feat.size()
52
+ assert len(size) == 4, 'The input feature should be a 4D tensor.'
53
+ b, c = size[:2]
54
+ feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
55
+ feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
56
+ feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
57
+ return feat_mean, feat_std
58
+
59
+ def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
60
+ """Adaptive instance normalization.
61
+ Adjust the reference features to have similar color and illumination
62
+ to those of the degraded features.
63
+ Args:
64
+ content_feat (Tensor): The reference feature.
65
+ style_feat (Tensor): The degraded features.
66
+ """
67
+ size = content_feat.size()
68
+ style_mean, style_std = calc_mean_std(style_feat)
69
+ content_mean, content_std = calc_mean_std(content_feat)
70
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
71
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
72
+
73
+ def wavelet_blur(image: Tensor, radius: int):
74
+ """
75
+ Apply wavelet blur to the input tensor.
76
+ """
77
+ # input shape: (1, 3, H, W)
78
+ # convolution kernel
79
+ kernel_vals = [
80
+ [0.0625, 0.125, 0.0625],
81
+ [0.125, 0.25, 0.125],
82
+ [0.0625, 0.125, 0.0625],
83
+ ]
84
+ kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
85
+ # add channel dimensions to the kernel to make it a 4D tensor
86
+ kernel = kernel[None, None]
87
+ # repeat the kernel across all input channels
88
+ kernel = kernel.repeat(3, 1, 1, 1)
89
+ image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
90
+ # apply convolution
91
+ output = F.conv2d(image, kernel, groups=3, dilation=radius)
92
+ return output
93
+
94
+ def wavelet_decomposition(image: Tensor, levels=5):
95
+ """
96
+ Apply wavelet decomposition to the input tensor.
97
+ This function only returns the low frequency & the high frequency.
98
+ """
99
+ high_freq = torch.zeros_like(image)
100
+ for i in range(levels):
101
+ radius = 2 ** i
102
+ low_freq = wavelet_blur(image, radius)
103
+ high_freq += (image - low_freq)
104
+ image = low_freq
105
+
106
+ return high_freq, low_freq
107
+
108
+ def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
109
+ """
110
+ Apply wavelet decomposition, so that the content will have the same color as the style.
111
+ """
112
+ # calculate the wavelet decomposition of the content feature
113
+ content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
114
+ del content_low_freq
115
+ # calculate the wavelet decomposition of the style feature
116
+ style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
117
+ del style_high_freq
118
+ # reconstruct the content feature with the style's high frequency
119
+ return content_high_freq + style_low_freq
120
+
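A short usage sketch of the two color-fix entry points added above. The file names are placeholders, and `wavelet_color_fix` assumes both images share the same spatial size:

```python
# Illustrative usage of the color-fix helpers above (hypothetical file names).
from PIL import Image

from SUPIR.utils.colorfix import adain_color_fix, wavelet_color_fix

target = Image.open("restored.png").convert("RGB")   # e.g. the upscaled/restored output
source = Image.open("input.png").convert("RGB")      # the original input, resized to match

fixed = wavelet_color_fix(target, source)            # keep target detail, borrow source low-frequency color
# fixed = adain_color_fix(target, source)            # alternative: match per-channel mean/std instead

fixed.save("restored_colorfix.png")
```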
SUPIR/utils/devices.py ADDED
@@ -0,0 +1,138 @@
1
+ import sys
2
+ import contextlib
3
+ from functools import lru_cache
4
+
5
+ import torch
6
+ #from modules import errors
7
+
8
+ if sys.platform == "darwin":
9
+ from modules import mac_specific
10
+
11
+
12
+ def has_mps() -> bool:
13
+ if sys.platform != "darwin":
14
+ return False
15
+ else:
16
+ return mac_specific.has_mps
17
+
18
+
19
+ def get_cuda_device_string():
20
+ return "cuda"
21
+
22
+
23
+ def get_optimal_device_name():
24
+ if torch.cuda.is_available():
25
+ return get_cuda_device_string()
26
+
27
+ if has_mps():
28
+ return "mps"
29
+
30
+ return "cpu"
31
+
32
+
33
+ def get_optimal_device():
34
+ return torch.device(get_optimal_device_name())
35
+
36
+
37
+ def get_device_for(task):
38
+ return get_optimal_device()
39
+
40
+
41
+ def torch_gc():
42
+
43
+ if torch.cuda.is_available():
44
+ with torch.cuda.device(get_cuda_device_string()):
45
+ torch.cuda.empty_cache()
46
+ torch.cuda.ipc_collect()
47
+
48
+ if has_mps():
49
+ mac_specific.torch_mps_gc()
50
+
51
+
52
+ def enable_tf32():
53
+ if torch.cuda.is_available():
54
+
55
+ # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
56
+ # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
57
+ if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
58
+ torch.backends.cudnn.benchmark = True
59
+
60
+ torch.backends.cuda.matmul.allow_tf32 = True
61
+ torch.backends.cudnn.allow_tf32 = True
62
+
63
+
64
+ enable_tf32()
65
+ #errors.run(enable_tf32, "Enabling TF32")
66
+
67
+ cpu = torch.device("cpu")
68
+ device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
69
+ dtype = torch.float16
70
+ dtype_vae = torch.float16
71
+ dtype_unet = torch.float16
72
+ unet_needs_upcast = False
73
+
74
+
75
+ def cond_cast_unet(input):
76
+ return input.to(dtype_unet) if unet_needs_upcast else input
77
+
78
+
79
+ def cond_cast_float(input):
80
+ return input.float() if unet_needs_upcast else input
81
+
82
+
83
+ def randn(seed, shape):
84
+ torch.manual_seed(seed)
85
+ return torch.randn(shape, device=device)
86
+
87
+
88
+ def randn_without_seed(shape):
89
+ return torch.randn(shape, device=device)
90
+
91
+
92
+ def autocast(disable=False):
93
+ if disable:
94
+ return contextlib.nullcontext()
95
+
96
+ return torch.autocast("cuda")
97
+
98
+
99
+ def without_autocast(disable=False):
100
+ return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
101
+
102
+
103
+ class NansException(Exception):
104
+ pass
105
+
106
+
107
+ def test_for_nans(x, where):
108
+ if not torch.all(torch.isnan(x)).item():
109
+ return
110
+
111
+ if where == "unet":
112
+ message = "A tensor with all NaNs was produced in Unet."
113
+
114
+ elif where == "vae":
115
+ message = "A tensor with all NaNs was produced in VAE."
116
+
117
+ else:
118
+ message = "A tensor with all NaNs was produced."
119
+
120
+ message += " Use --disable-nan-check commandline argument to disable this check."
121
+
122
+ raise NansException(message)
123
+
124
+
125
+ @lru_cache
126
+ def first_time_calculation():
127
+ """
128
+ just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
129
+ spends about 2.7 seconds doing that, at least with NVIDIA.
130
+ """
131
+
132
+ x = torch.zeros((1, 1)).to(device, dtype)
133
+ linear = torch.nn.Linear(1, 1).to(device, dtype)
134
+ linear(x)
135
+
136
+ x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
137
+ conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
138
+ conv2d(x)
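A sketch of how this module is typically consumed (it mirrors the `devices` helpers from stable-diffusion-webui; note that on macOS it expects a webui-style `modules.mac_specific` to be importable, and `enable_tf32()` runs at import time). The tiny model below is only for illustration:

```python
# Illustrative usage of the device helpers above (toy model, not part of SUPIR).
import torch

import SUPIR.utils.devices as devices

dev = devices.get_optimal_device()          # "cuda", "mps", or "cpu"
model = torch.nn.Linear(8, 8).to(dev)
x = torch.randn(1, 8, device=dev)

with devices.autocast():                    # cuda autocast; PyTorch disables it with a warning on CPU-only builds
    y = model(x)

devices.test_for_nans(y, "unet")            # raises NansException only if the tensor is entirely NaN
devices.torch_gc()                          # release cached CUDA/MPS memory between large jobs
```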
SUPIR/utils/face_restoration_helper.py ADDED
@@ -0,0 +1,514 @@
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import torch
5
+ from torchvision.transforms.functional import normalize
6
+
7
+ from facexlib.detection import init_detection_model
8
+ from facexlib.parsing import init_parsing_model
9
+ from facexlib.utils.misc import img2tensor, imwrite
10
+
11
+ from .file import load_file_from_url
12
+
13
+
14
+ def get_largest_face(det_faces, h, w):
15
+ def get_location(val, length):
16
+ if val < 0:
17
+ return 0
18
+ elif val > length:
19
+ return length
20
+ else:
21
+ return val
22
+
23
+ face_areas = []
24
+ for det_face in det_faces:
25
+ left = get_location(det_face[0], w)
26
+ right = get_location(det_face[2], w)
27
+ top = get_location(det_face[1], h)
28
+ bottom = get_location(det_face[3], h)
29
+ face_area = (right - left) * (bottom - top)
30
+ face_areas.append(face_area)
31
+ largest_idx = face_areas.index(max(face_areas))
32
+ return det_faces[largest_idx], largest_idx
33
+
34
+
35
+ def get_center_face(det_faces, h=0, w=0, center=None):
36
+ if center is not None:
37
+ center = np.array(center)
38
+ else:
39
+ center = np.array([w / 2, h / 2])
40
+ center_dist = []
41
+ for det_face in det_faces:
42
+ face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
43
+ dist = np.linalg.norm(face_center - center)
44
+ center_dist.append(dist)
45
+ center_idx = center_dist.index(min(center_dist))
46
+ return det_faces[center_idx], center_idx
47
+
48
+
49
+ class FaceRestoreHelper(object):
50
+ """Helper for the face restoration pipeline (base class)."""
51
+
52
+ def __init__(self,
53
+ upscale_factor,
54
+ face_size=512,
55
+ crop_ratio=(1, 1),
56
+ det_model='retinaface_resnet50',
57
+ save_ext='png',
58
+ template_3points=False,
59
+ pad_blur=False,
60
+ use_parse=False,
61
+ device=None):
62
+ self.template_3points = template_3points # improve robustness
63
+ self.upscale_factor = int(upscale_factor)
64
+ # the cropped face ratio based on the square face
65
+ self.crop_ratio = crop_ratio # (h, w)
66
+ assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >=1'
67
+ self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
68
+ self.det_model = det_model
69
+
70
+ if self.det_model == 'dlib':
71
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
72
+ self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
73
+ [337.91089109, 488.38613861], [437.95049505, 493.51485149],
74
+ [513.58415842, 678.5049505]])
75
+ self.face_template = self.face_template / (1024 // face_size)
76
+ elif self.template_3points:
77
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
78
+ else:
79
+ # standard 5 landmarks for FFHQ faces with 512 x 512
80
+ # facexlib
81
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
82
+ [201.26117, 371.41043], [313.08905, 371.15118]])
83
+
84
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
85
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
86
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
87
+
88
+ self.face_template = self.face_template * (face_size / 512.0)
89
+ if self.crop_ratio[0] > 1:
90
+ self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
91
+ if self.crop_ratio[1] > 1:
92
+ self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
93
+ self.save_ext = save_ext
94
+ self.pad_blur = pad_blur
95
+ if self.pad_blur is True:
96
+ self.template_3points = False
97
+
98
+ self.all_landmarks_5 = []
99
+ self.det_faces = []
100
+ self.affine_matrices = []
101
+ self.inverse_affine_matrices = []
102
+ self.cropped_faces = []
103
+ self.restored_faces = []
104
+ self.pad_input_imgs = []
105
+
106
+ if device is None:
107
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
108
+ # self.device = get_device()
109
+ else:
110
+ self.device = device
111
+
112
+ # init face detection model
113
+ self.face_detector = init_detection_model(det_model, half=False, device=self.device)
114
+
115
+ # init face parsing model
116
+ self.use_parse = use_parse
117
+ self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
118
+
119
+ def set_upscale_factor(self, upscale_factor):
120
+ self.upscale_factor = upscale_factor
121
+
122
+ def read_image(self, img):
123
+ """img can be image path or cv2 loaded image."""
124
+ # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
125
+ if isinstance(img, str):
126
+ img = cv2.imread(img)
127
+
128
+ if np.max(img) > 256: # 16-bit image
129
+ img = img / 65535 * 255
130
+ if len(img.shape) == 2: # gray image
131
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
132
+ elif img.shape[2] == 4: # BGRA image with alpha channel
133
+ img = img[:, :, 0:3]
134
+
135
+ self.input_img = img
136
+ # self.is_gray = is_gray(img, threshold=10)
137
+ # if self.is_gray:
138
+ # print('Grayscale input: True')
139
+
140
+ if min(self.input_img.shape[:2]) < 512:
141
+ f = 512.0 / min(self.input_img.shape[:2])
142
+ self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
143
+
144
+ def init_dlib(self, detection_path, landmark5_path):
145
+ """Initialize the dlib detectors and predictors."""
146
+ try:
147
+ import dlib
148
+ except ImportError:
149
+ print('Please install dlib by running: conda install -c conda-forge dlib')
150
+ detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
151
+ landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
152
+ face_detector = dlib.cnn_face_detection_model_v1(detection_path)
153
+ shape_predictor_5 = dlib.shape_predictor(landmark5_path)
154
+ return face_detector, shape_predictor_5
155
+
156
+ def get_face_landmarks_5_dlib(self,
157
+ only_keep_largest=False,
158
+ scale=1):
159
+ det_faces = self.face_detector(self.input_img, scale)
160
+
161
+ if len(det_faces) == 0:
162
+ print('No face detected. Try to increase upsample_num_times.')
163
+ return 0
164
+ else:
165
+ if only_keep_largest:
166
+ print('Detect several faces and only keep the largest.')
167
+ face_areas = []
168
+ for i in range(len(det_faces)):
169
+ face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
170
+ det_faces[i].rect.bottom() - det_faces[i].rect.top())
171
+ face_areas.append(face_area)
172
+ largest_idx = face_areas.index(max(face_areas))
173
+ self.det_faces = [det_faces[largest_idx]]
174
+ else:
175
+ self.det_faces = det_faces
176
+
177
+ if len(self.det_faces) == 0:
178
+ return 0
179
+
180
+ for face in self.det_faces:
181
+ shape = self.shape_predictor_5(self.input_img, face.rect)
182
+ landmark = np.array([[part.x, part.y] for part in shape.parts()])
183
+ self.all_landmarks_5.append(landmark)
184
+
185
+ return len(self.all_landmarks_5)
186
+
187
+ def get_face_landmarks_5(self,
188
+ only_keep_largest=False,
189
+ only_center_face=False,
190
+ resize=None,
191
+ blur_ratio=0.01,
192
+ eye_dist_threshold=None):
193
+ if self.det_model == 'dlib':
194
+ return self.get_face_landmarks_5_dlib(only_keep_largest)
195
+
196
+ if resize is None:
197
+ scale = 1
198
+ input_img = self.input_img
199
+ else:
200
+ h, w = self.input_img.shape[0:2]
201
+ scale = resize / min(h, w)
202
+ scale = max(1, scale) # always scale up
203
+ h, w = int(h * scale), int(w * scale)
204
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
205
+ input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
206
+
207
+ with torch.no_grad():
208
+ bboxes = self.face_detector.detect_faces(input_img)
209
+
210
+ if bboxes is None or bboxes.shape[0] == 0:
211
+ return 0
212
+ else:
213
+ bboxes = bboxes / scale
214
+
215
+ for bbox in bboxes:
216
+ # remove faces with too small eye distance: side faces or too small faces
217
+ eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
218
+ if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
219
+ continue
220
+
221
+ if self.template_3points:
222
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
223
+ else:
224
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
225
+ self.all_landmarks_5.append(landmark)
226
+ self.det_faces.append(bbox[0:5])
227
+
228
+ if len(self.det_faces) == 0:
229
+ return 0
230
+ if only_keep_largest:
231
+ h, w, _ = self.input_img.shape
232
+ self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
233
+ self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
234
+ elif only_center_face:
235
+ h, w, _ = self.input_img.shape
236
+ self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
237
+ self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
238
+
239
+ # pad blurry images
240
+ if self.pad_blur:
241
+ self.pad_input_imgs = []
242
+ for landmarks in self.all_landmarks_5:
243
+ # get landmarks
244
+ eye_left = landmarks[0, :]
245
+ eye_right = landmarks[1, :]
246
+ eye_avg = (eye_left + eye_right) * 0.5
247
+ mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
248
+ eye_to_eye = eye_right - eye_left
249
+ eye_to_mouth = mouth_avg - eye_avg
250
+
251
+ # Get the oriented crop rectangle
252
+ # x: half width of the oriented crop rectangle
253
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
254
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
255
+ # norm with the hypotenuse: get the direction
256
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
257
+ rect_scale = 1.5
258
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
259
+ # y: half height of the oriented crop rectangle
260
+ y = np.flipud(x) * [-1, 1]
261
+
262
+ # c: center
263
+ c = eye_avg + eye_to_mouth * 0.1
264
+ # quad: (left_top, left_bottom, right_bottom, right_top)
265
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
266
+ # qsize: side length of the square
267
+ qsize = np.hypot(*x) * 2
268
+ border = max(int(np.rint(qsize * 0.1)), 3)
269
+
270
+ # get pad
271
+ # pad: (width_left, height_top, width_right, height_bottom)
272
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
273
+ int(np.ceil(max(quad[:, 1]))))
274
+ pad = [
275
+ max(-pad[0] + border, 1),
276
+ max(-pad[1] + border, 1),
277
+ max(pad[2] - self.input_img.shape[0] + border, 1),
278
+ max(pad[3] - self.input_img.shape[1] + border, 1)
279
+ ]
280
+
281
+ if max(pad) > 1:
282
+ # pad image
283
+ pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
284
+ # modify landmark coords
285
+ landmarks[:, 0] += pad[0]
286
+ landmarks[:, 1] += pad[1]
287
+ # blur pad images
288
+ h, w, _ = pad_img.shape
289
+ y, x, _ = np.ogrid[:h, :w, :1]
290
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
291
+ np.float32(w - 1 - x) / pad[2]),
292
+ 1.0 - np.minimum(np.float32(y) / pad[1],
293
+ np.float32(h - 1 - y) / pad[3]))
294
+ blur = int(qsize * blur_ratio)
295
+ if blur % 2 == 0:
296
+ blur += 1
297
+ blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
298
+ # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
299
+
300
+ pad_img = pad_img.astype('float32')
301
+ pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
302
+ pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
303
+ pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
304
+ self.pad_input_imgs.append(pad_img)
305
+ else:
306
+ self.pad_input_imgs.append(np.copy(self.input_img))
307
+
308
+ return len(self.all_landmarks_5)
309
+
310
+ def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
311
+ """Align and warp faces with face template.
312
+ """
313
+ if self.pad_blur:
314
+ assert len(self.pad_input_imgs) == len(
315
+ self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
316
+ for idx, landmark in enumerate(self.all_landmarks_5):
317
+ # use 5 landmarks to get affine matrix
318
+ # use cv2.LMEDS method for the equivalence to skimage transform
319
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
320
+ affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
321
+ self.affine_matrices.append(affine_matrix)
322
+ # warp and crop faces
323
+ if border_mode == 'constant':
324
+ border_mode = cv2.BORDER_CONSTANT
325
+ elif border_mode == 'reflect101':
326
+ border_mode = cv2.BORDER_REFLECT101
327
+ elif border_mode == 'reflect':
328
+ border_mode = cv2.BORDER_REFLECT
329
+ if self.pad_blur:
330
+ input_img = self.pad_input_imgs[idx]
331
+ else:
332
+ input_img = self.input_img
333
+ cropped_face = cv2.warpAffine(
334
+ input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
335
+ self.cropped_faces.append(cropped_face)
336
+ # save the cropped face
337
+ if save_cropped_path is not None:
338
+ path = os.path.splitext(save_cropped_path)[0]
339
+ save_path = f'{path}_{idx:02d}.{self.save_ext}'
340
+ imwrite(cropped_face, save_path)
341
+
342
+ def get_inverse_affine(self, save_inverse_affine_path=None):
343
+ """Get inverse affine matrix."""
344
+ for idx, affine_matrix in enumerate(self.affine_matrices):
345
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
346
+ inverse_affine *= self.upscale_factor
347
+ self.inverse_affine_matrices.append(inverse_affine)
348
+ # save inverse affine matrices
349
+ if save_inverse_affine_path is not None:
350
+ path, _ = os.path.splitext(save_inverse_affine_path)
351
+ save_path = f'{path}_{idx:02d}.pth'
352
+ torch.save(inverse_affine, save_path)
353
+
354
+ def add_restored_face(self, restored_face, input_face=None):
355
+ # if self.is_gray:
356
+ # restored_face = bgr2gray(restored_face) # convert img into grayscale
357
+ # if input_face is not None:
358
+ # restored_face = adain_npy(restored_face, input_face) # transfer the color
359
+ self.restored_faces.append(restored_face)
360
+
361
+ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
362
+ h, w, _ = self.input_img.shape
363
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
364
+
365
+ if upsample_img is None:
366
+ # simply resize the background
367
+ # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
368
+ upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
369
+ else:
370
+ upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
371
+
372
+ assert len(self.restored_faces) == len(
373
+ self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
374
+
375
+ inv_mask_borders = []
376
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
377
+ if face_upsampler is not None:
378
+ restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
379
+ inverse_affine /= self.upscale_factor
380
+ inverse_affine[:, 2] *= self.upscale_factor
381
+ face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
382
+ else:
383
+ # Add an offset to inverse affine matrix, for more precise back alignment
384
+ if self.upscale_factor > 1:
385
+ extra_offset = 0.5 * self.upscale_factor
386
+ else:
387
+ extra_offset = 0
388
+ inverse_affine[:, 2] += extra_offset
389
+ face_size = self.face_size
390
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
391
+
392
+ # if draw_box or not self.use_parse: # use square parse maps
393
+ # mask = np.ones(face_size, dtype=np.float32)
394
+ # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
395
+ # # remove the black borders
396
+ # inv_mask_erosion = cv2.erode(
397
+ # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
398
+ # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
399
+ # total_face_area = np.sum(inv_mask_erosion) # // 3
400
+ # # add border
401
+ # if draw_box:
402
+ # h, w = face_size
403
+ # mask_border = np.ones((h, w, 3), dtype=np.float32)
404
+ # border = int(1400/np.sqrt(total_face_area))
405
+ # mask_border[border:h-border, border:w-border,:] = 0
406
+ # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
407
+ # inv_mask_borders.append(inv_mask_border)
408
+ # if not self.use_parse:
409
+ # # compute the fusion edge based on the area of face
410
+ # w_edge = int(total_face_area**0.5) // 20
411
+ # erosion_radius = w_edge * 2
412
+ # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
413
+ # blur_size = w_edge * 2
414
+ # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
415
+ # if len(upsample_img.shape) == 2: # upsample_img is gray image
416
+ # upsample_img = upsample_img[:, :, None]
417
+ # inv_soft_mask = inv_soft_mask[:, :, None]
418
+
419
+ # always use square mask
420
+ mask = np.ones(face_size, dtype=np.float32)
421
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
422
+ # remove the black borders
423
+ inv_mask_erosion = cv2.erode(
424
+ inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
425
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
426
+ total_face_area = np.sum(inv_mask_erosion) # // 3
427
+ # add border
428
+ if draw_box:
429
+ h, w = face_size
430
+ mask_border = np.ones((h, w, 3), dtype=np.float32)
431
+ border = int(1400 / np.sqrt(total_face_area))
432
+ mask_border[border:h - border, border:w - border, :] = 0
433
+ inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
434
+ inv_mask_borders.append(inv_mask_border)
435
+ # compute the fusion edge based on the area of face
436
+ w_edge = int(total_face_area ** 0.5) // 20
437
+ erosion_radius = w_edge * 2
438
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
439
+ blur_size = w_edge * 2
440
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
441
+ if len(upsample_img.shape) == 2: # upsample_img is gray image
442
+ upsample_img = upsample_img[:, :, None]
443
+ inv_soft_mask = inv_soft_mask[:, :, None]
444
+
445
+ # parse mask
446
+ if self.use_parse:
447
+ # inference
448
+ face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
449
+ face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
450
+ normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
451
+ face_input = torch.unsqueeze(face_input, 0).to(self.device)
452
+ with torch.no_grad():
453
+ out = self.face_parse(face_input)[0]
454
+ out = out.argmax(dim=1).squeeze().cpu().numpy()
455
+
456
+ parse_mask = np.zeros(out.shape)
457
+ MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
458
+ for idx, color in enumerate(MASK_COLORMAP):
459
+ parse_mask[out == idx] = color
460
+ # blur the mask
461
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
462
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
463
+ # remove the black borders
464
+ thres = 10
465
+ parse_mask[:thres, :] = 0
466
+ parse_mask[-thres:, :] = 0
467
+ parse_mask[:, :thres] = 0
468
+ parse_mask[:, -thres:] = 0
469
+ parse_mask = parse_mask / 255.
470
+
471
+ parse_mask = cv2.resize(parse_mask, face_size)
472
+ parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
473
+ inv_soft_parse_mask = parse_mask[:, :, None]
474
+ # pasted_face = inv_restored
475
+ fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
476
+ inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
477
+
478
+ if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
479
+ alpha = upsample_img[:, :, 3:]
480
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
481
+ upsample_img = np.concatenate((upsample_img, alpha), axis=2)
482
+ else:
483
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
484
+
485
+ if np.max(upsample_img) > 256: # 16-bit image
486
+ upsample_img = upsample_img.astype(np.uint16)
487
+ else:
488
+ upsample_img = upsample_img.astype(np.uint8)
489
+
490
+ # draw bounding box
491
+ if draw_box:
492
+ # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
493
+ img_color = np.ones([*upsample_img.shape], dtype=np.float32)
494
+ img_color[:, :, 0] = 0
495
+ img_color[:, :, 1] = 255
496
+ img_color[:, :, 2] = 0
497
+ for inv_mask_border in inv_mask_borders:
498
+ upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
499
+ # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
500
+
501
+ if save_path is not None:
502
+ path = os.path.splitext(save_path)[0]
503
+ save_path = f'{path}.{self.save_ext}'
504
+ imwrite(upsample_img, save_path)
505
+ return upsample_img
506
+
507
+ def clean_all(self):
508
+ self.all_landmarks_5 = []
509
+ self.restored_faces = []
510
+ self.affine_matrices = []
511
+ self.cropped_faces = []
512
+ self.inverse_affine_matrices = []
513
+ self.det_faces = []
514
+ self.pad_input_imgs = []
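Putting the class above together, the intended flow is detect → align/crop → restore each crop → paste back. A sketch of that loop, where `restore_face_crop` is a hypothetical stand-in for whatever model restores a face crop (the constructor also downloads facexlib detection/parsing weights on first use):

```python
# Illustrative detect -> crop -> restore -> paste-back loop using FaceRestoreHelper.
import cv2

from SUPIR.utils.face_restoration_helper import FaceRestoreHelper


def restore_face_crop(face_bgr):
    # hypothetical stand-in for a real face-restoration model: identity pass-through
    return face_bgr


face_helper = FaceRestoreHelper(upscale_factor=2, face_size=512, use_parse=True)

face_helper.read_image("portrait.jpg")                 # path or BGR uint8 array
n_faces = face_helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
face_helper.align_warp_face()                          # fills face_helper.cropped_faces

for cropped_face in face_helper.cropped_faces:         # BGR uint8 crops of size face_size
    restored = restore_face_crop(cropped_face)
    face_helper.add_restored_face(restored, cropped_face)

face_helper.get_inverse_affine()                       # required before pasting back
result = face_helper.paste_faces_to_input_image()      # BGR uint8, input size * upscale_factor
cv2.imwrite("portrait_restored.png", result)
face_helper.clean_all()
```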
SUPIR/utils/file.py ADDED
@@ -0,0 +1,79 @@
1
+ import os
2
+ from typing import List, Tuple
3
+
4
+ from urllib.parse import urlparse
5
+ from torch.hub import download_url_to_file, get_dir
6
+
7
+
8
+ def load_file_list(file_list_path: str) -> List[str]:
9
+ files = []
10
+ # each line in file list contains a path of an image
11
+ with open(file_list_path, "r") as fin:
12
+ for line in fin:
13
+ path = line.strip()
14
+ if path:
15
+ files.append(path)
16
+ return files
17
+
18
+
19
+ def list_image_files(
20
+ img_dir: str,
21
+ exts: Tuple[str]=(".jpg", ".png", ".jpeg"),
22
+ follow_links: bool=False,
23
+ log_progress: bool=False,
24
+ log_every_n_files: int=10000,
25
+ max_size: int=-1
26
+ ) -> List[str]:
27
+ files = []
28
+ for dir_path, _, file_names in os.walk(img_dir, followlinks=follow_links):
29
+ early_stop = False
30
+ for file_name in file_names:
31
+ if os.path.splitext(file_name)[1].lower() in exts:
32
+ if max_size >= 0 and len(files) >= max_size:
33
+ early_stop = True
34
+ break
35
+ files.append(os.path.join(dir_path, file_name))
36
+ if log_progress and len(files) % log_every_n_files == 0:
37
+ print(f"find {len(files)} images in {img_dir}")
38
+ if early_stop:
39
+ break
40
+ return files
41
+
42
+
43
+ def get_file_name_parts(file_path: str) -> Tuple[str, str, str]:
44
+ parent_path, file_name = os.path.split(file_path)
45
+ stem, ext = os.path.splitext(file_name)
46
+ return parent_path, stem, ext
47
+
48
+
49
+ # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
50
+ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
51
+ """Load file from http url, will download models if necessary.
52
+
53
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
54
+
55
+ Args:
56
+ url (str): URL to be downloaded.
57
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
58
+ Default: None.
59
+ progress (bool): Whether to show the download progress. Default: True.
60
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
61
+
62
+ Returns:
63
+ str: The path to the downloaded file.
64
+ """
65
+ if model_dir is None: # use the pytorch hub_dir
66
+ hub_dir = get_dir()
67
+ model_dir = os.path.join(hub_dir, 'checkpoints')
68
+
69
+ os.makedirs(model_dir, exist_ok=True)
70
+
71
+ parts = urlparse(url)
72
+ filename = os.path.basename(parts.path)
73
+ if file_name is not None:
74
+ filename = file_name
75
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
76
+ if not os.path.exists(cached_file):
77
+ print(f'Downloading: "{url}" to {cached_file}\n')
78
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
79
+ return cached_file
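A brief usage sketch of the two helpers most of the codebase cares about; the URL and directory names below are placeholders, not assets that ship with this repository:

```python
# Illustrative usage of the download/scan helpers above (placeholder URL and paths).
from SUPIR.utils.file import list_image_files, load_file_from_url

ckpt_path = load_file_from_url(
    url="https://example.com/models/some_model.pth",   # hypothetical URL
    model_dir="weights",                                # downloaded once, then served from cache
)

images = list_image_files("datasets/my_images", exts=(".jpg", ".png"), log_progress=True)
print(ckpt_path, len(images))
```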
SUPIR/utils/tilevae.py ADDED
@@ -0,0 +1,971 @@
1
+ # ------------------------------------------------------------------------
2
+ #
3
+ # Ultimate VAE Tile Optimization
4
+ #
5
+ # Introducing a revolutionary new optimization designed to make
6
+ # the VAE work with giant images on limited VRAM!
7
+ # Say goodbye to the frustration of OOM and hello to seamless output!
8
+ #
9
+ # ------------------------------------------------------------------------
10
+ #
11
+ # This script is a wild hack that splits the image into tiles,
12
+ # encodes each tile separately, and merges the result back together.
13
+ #
14
+ # Advantages:
15
+ # - The VAE can now work with giant images on limited VRAM
16
+ # (~10 GB for 8K images!)
17
+ # - The merged output is completely seamless without any post-processing.
18
+ #
19
+ # Drawbacks:
20
+ # - Giant RAM needed. To store the intermediate results for a 4096x4096
21
+ # image, you need a 32 GB RAM machine (it consumes ~20 GB); for 8192x8192
22
+ # you need a 128 GB RAM machine (it consumes ~100 GB).
23
+ # - NaNs always appear for 8k images when you use the fp16 (half) VAE.
24
+ # You must use --no-half-vae to disable half VAE for that giant image.
25
+ # - Slow speed. With default tile size, it takes around 50/200 seconds
26
+ # to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode
27
+ # a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
28
+ # - The gradient calculation is not compatible with this hack. It
29
+ # will break any backward() or torch.autograd.grad() that passes VAE.
30
+ # (But you can still use the VAE to generate training data.)
31
+ #
32
+ # How it works:
33
+ # 1) The image is split into tiles.
34
+ # - To ensure perfect results, each tile is padded with 32 pixels
35
+ # on each side.
36
+ # - Then the conv2d/silu/upsample/downsample can produce identical
37
+ # results to the original image without splitting.
38
+ # 2) The original forward is decomposed into a task queue and a task worker.
39
+ # - The task queue is a list of functions that will be executed in order.
40
+ # - The task worker is a loop that executes the tasks in the queue.
41
+ # 3) The task queue is executed for each tile.
42
+ # - Current tile is sent to GPU.
43
+ # - local operations are directly executed.
44
+ # - Group norm calculation is temporarily suspended until the mean
45
+ # and var of all tiles are calculated.
46
+ # - The residual is pre-calculated, stored, and added back later.
47
+ # - When moving on to the next tile, the current tile is sent to the CPU.
48
+ # 4) After all tiles are processed, they are merged on the CPU and returned.
49
+ #
50
+ # Enjoy!
51
+ #
52
+ # @author: LI YI @ Nanyang Technological University - Singapore
53
+ # @date: 2023-03-02
54
+ # @license: MIT License
55
+ #
56
+ # Please give me a star if you like this project!
57
+ #
58
+ # -------------------------------------------------------------------------
59
+
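The "How it works" notes above boil down to replaying a per-layer task queue on one tile at a time. The snippet below is an editor-added toy illustration of that queue/worker split, far simpler than the real `build_task_queue` / `GroupNormParam` machinery later in this file:

```python
# Toy sketch of the task-queue idea described in the header comment (not part of this file).
import torch
import torch.nn.functional as F


def build_toy_queue(convs):
    # each task is (name, callable); the residual is remembered via 'store_res' and re-added later
    return [
        ("store_res", lambda x: x),
        ("conv1", convs[0]),
        ("silu", F.silu),
        ("conv2", convs[1]),
        ["add_res", None],
    ]


def run_queue(task_queue, tile):
    res = None
    for task in task_queue:
        name, fn = task[0], task[1]
        if name == "store_res":
            res = fn(tile)
        elif name == "add_res":
            tile = tile + res          # residual connection applied after the convs
        else:
            tile = fn(tile)
    return tile


convs = (torch.nn.Conv2d(3, 3, 3, padding=1), torch.nn.Conv2d(3, 3, 3, padding=1))
queue = build_toy_queue(convs)

tiles = torch.randn(4, 3, 64, 64).split(1)                # stand-ins for padded image tiles
merged = torch.cat([run_queue(queue, t) for t in tiles])  # same queue replayed per tile
print(merged.shape)                                       # torch.Size([4, 3, 64, 64])
```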
60
+ import gc
61
+ from time import time
62
+ import math
63
+ from tqdm import tqdm
64
+
65
+ import torch
66
+ import torch.version
67
+ import torch.nn.functional as F
68
+ from einops import rearrange
69
+ from diffusers.utils.import_utils import is_xformers_available
70
+
71
+ import SUPIR.utils.devices as devices
72
+
73
+ try:
74
+ import xformers
75
+ import xformers.ops
76
+ except ImportError:
77
+ pass
78
+
79
+ sd_flag = True
80
+
81
+ def get_recommend_encoder_tile_size():
82
+ if torch.cuda.is_available():
83
+ total_memory = torch.cuda.get_device_properties(
84
+ devices.device).total_memory // 2**20
85
+ if total_memory > 16*1000:
86
+ ENCODER_TILE_SIZE = 3072
87
+ elif total_memory > 12*1000:
88
+ ENCODER_TILE_SIZE = 2048
89
+ elif total_memory > 8*1000:
90
+ ENCODER_TILE_SIZE = 1536
91
+ else:
92
+ ENCODER_TILE_SIZE = 960
93
+ else:
94
+ ENCODER_TILE_SIZE = 512
95
+ return ENCODER_TILE_SIZE
96
+
97
+
98
+ def get_recommend_decoder_tile_size():
99
+ if torch.cuda.is_available():
100
+ total_memory = torch.cuda.get_device_properties(
101
+ devices.device).total_memory // 2**20
102
+ if total_memory > 30*1000:
103
+ DECODER_TILE_SIZE = 256
104
+ elif total_memory > 16*1000:
105
+ DECODER_TILE_SIZE = 192
106
+ elif total_memory > 12*1000:
107
+ DECODER_TILE_SIZE = 128
108
+ elif total_memory > 8*1000:
109
+ DECODER_TILE_SIZE = 96
110
+ else:
111
+ DECODER_TILE_SIZE = 64
112
+ else:
113
+ DECODER_TILE_SIZE = 64
114
+ return DECODER_TILE_SIZE
115
+
116
+
117
+ if 'global const':
118
+ DEFAULT_ENABLED = False
119
+ DEFAULT_MOVE_TO_GPU = False
120
+ DEFAULT_FAST_ENCODER = True
121
+ DEFAULT_FAST_DECODER = True
122
+ DEFAULT_COLOR_FIX = 0
123
+ DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
124
+ DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
125
+
126
+
127
+ # inplace version of silu
128
+ def inplace_nonlinearity(x):
129
+ # Test: fix for Nans
130
+ return F.silu(x, inplace=True)
131
+
132
+ # extracted from ldm.modules.diffusionmodules.model
133
+
134
+ # from diffusers lib
135
+ def attn_forward_new(self, h_):
136
+ batch_size, channel, height, width = h_.shape
137
+ hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)
138
+
139
+ attention_mask = None
140
+ encoder_hidden_states = None
141
+ batch_size, sequence_length, _ = hidden_states.shape
142
+ attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
143
+
144
+ query = self.to_q(hidden_states)
145
+
146
+ if encoder_hidden_states is None:
147
+ encoder_hidden_states = hidden_states
148
+ elif self.norm_cross:
149
+ encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
150
+
151
+ key = self.to_k(encoder_hidden_states)
152
+ value = self.to_v(encoder_hidden_states)
153
+
154
+ query = self.head_to_batch_dim(query)
155
+ key = self.head_to_batch_dim(key)
156
+ value = self.head_to_batch_dim(value)
157
+
158
+ attention_probs = self.get_attention_scores(query, key, attention_mask)
159
+ hidden_states = torch.bmm(attention_probs, value)
160
+ hidden_states = self.batch_to_head_dim(hidden_states)
161
+
162
+ # linear proj
163
+ hidden_states = self.to_out[0](hidden_states)
164
+ # dropout
165
+ hidden_states = self.to_out[1](hidden_states)
166
+
167
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
168
+
169
+ return hidden_states
170
+
171
+ def attn_forward_new_pt2_0(self, hidden_states,):
172
+ scale = 1
173
+ attention_mask = None
174
+ encoder_hidden_states = None
175
+
176
+ input_ndim = hidden_states.ndim
177
+
178
+ if input_ndim == 4:
179
+ batch_size, channel, height, width = hidden_states.shape
180
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
181
+
182
+ batch_size, sequence_length, _ = (
183
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
184
+ )
185
+
186
+ if attention_mask is not None:
187
+ attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
188
+ # scaled_dot_product_attention expects attention_mask shape to be
189
+ # (batch, heads, source_length, target_length)
190
+ attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])
191
+
192
+ if self.group_norm is not None:
193
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
194
+
195
+ query = self.to_q(hidden_states, scale=scale)
196
+
197
+ if encoder_hidden_states is None:
198
+ encoder_hidden_states = hidden_states
199
+ elif self.norm_cross:
200
+ encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
201
+
202
+ key = self.to_k(encoder_hidden_states, scale=scale)
203
+ value = self.to_v(encoder_hidden_states, scale=scale)
204
+
205
+ inner_dim = key.shape[-1]
206
+ head_dim = inner_dim // self.heads
207
+
208
+ query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
209
+
210
+ key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
211
+ value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
212
+
213
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
214
+ # TODO: add support for attn.scale when we move to Torch 2.1
215
+ hidden_states = F.scaled_dot_product_attention(
216
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
217
+ )
218
+
219
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
220
+ hidden_states = hidden_states.to(query.dtype)
221
+
222
+ # linear proj
223
+ hidden_states = self.to_out[0](hidden_states, scale=scale)
224
+ # dropout
225
+ hidden_states = self.to_out[1](hidden_states)
226
+
227
+ if input_ndim == 4:
228
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
229
+
230
+ return hidden_states
231
+
232
+ def attn_forward_new_xformers(self, hidden_states):
233
+ scale = 1
234
+ attention_op = None
235
+ attention_mask = None
236
+ encoder_hidden_states = None
237
+
238
+ input_ndim = hidden_states.ndim
239
+
240
+ if input_ndim == 4:
241
+ batch_size, channel, height, width = hidden_states.shape
242
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
243
+
244
+ batch_size, key_tokens, _ = (
245
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
246
+ )
247
+
248
+ attention_mask = self.prepare_attention_mask(attention_mask, key_tokens, batch_size)
249
+ if attention_mask is not None:
250
+ # expand our mask's singleton query_tokens dimension:
251
+ # [batch*heads, 1, key_tokens] ->
252
+ # [batch*heads, query_tokens, key_tokens]
253
+ # so that it can be added as a bias onto the attention scores that xformers computes:
254
+ # [batch*heads, query_tokens, key_tokens]
255
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
256
+ _, query_tokens, _ = hidden_states.shape
257
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
258
+
259
+ if self.group_norm is not None:
260
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
261
+
262
+ query = self.to_q(hidden_states, scale=scale)
263
+
264
+ if encoder_hidden_states is None:
265
+ encoder_hidden_states = hidden_states
266
+ elif self.norm_cross:
267
+ encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
268
+
269
+ key = self.to_k(encoder_hidden_states, scale=scale)
270
+ value = self.to_v(encoder_hidden_states, scale=scale)
271
+
272
+ query = self.head_to_batch_dim(query).contiguous()
273
+ key = self.head_to_batch_dim(key).contiguous()
274
+ value = self.head_to_batch_dim(value).contiguous()
275
+
276
+ hidden_states = xformers.ops.memory_efficient_attention(
277
+ query, key, value, attn_bias=attention_mask, op=attention_op#, scale=scale
278
+ )
279
+ hidden_states = hidden_states.to(query.dtype)
280
+ hidden_states = self.batch_to_head_dim(hidden_states)
281
+
282
+ # linear proj
283
+ hidden_states = self.to_out[0](hidden_states, scale=scale)
284
+ # dropout
285
+ hidden_states = self.to_out[1](hidden_states)
286
+
287
+ if input_ndim == 4:
288
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
289
+
290
+ return hidden_states
291
+
292
+ def attn_forward(self, h_):
293
+ q = self.q(h_)
294
+ k = self.k(h_)
295
+ v = self.v(h_)
296
+
297
+ # compute attention
298
+ b, c, h, w = q.shape
299
+ q = q.reshape(b, c, h*w)
300
+ q = q.permute(0, 2, 1) # b,hw,c
301
+ k = k.reshape(b, c, h*w) # b,c,hw
302
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
303
+ w_ = w_ * (int(c)**(-0.5))
304
+ w_ = torch.nn.functional.softmax(w_, dim=2)
305
+
306
+ # attend to values
307
+ v = v.reshape(b, c, h*w)
308
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
309
+ # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
310
+ h_ = torch.bmm(v, w_)
311
+ h_ = h_.reshape(b, c, h, w)
312
+
313
+ h_ = self.proj_out(h_)
314
+
315
+ return h_
316
+
317
+
318
+ def xformer_attn_forward(self, h_):
319
+ q = self.q(h_)
320
+ k = self.k(h_)
321
+ v = self.v(h_)
322
+
323
+ # compute attention
324
+ B, C, H, W = q.shape
325
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
326
+
327
+ q, k, v = map(
328
+ lambda t: t.unsqueeze(3)
329
+ .reshape(B, t.shape[1], 1, C)
330
+ .permute(0, 2, 1, 3)
331
+ .reshape(B * 1, t.shape[1], C)
332
+ .contiguous(),
333
+ (q, k, v),
334
+ )
335
+ out = xformers.ops.memory_efficient_attention(
336
+ q, k, v, attn_bias=None, op=self.attention_op)
337
+
338
+ out = (
339
+ out.unsqueeze(0)
340
+ .reshape(B, 1, out.shape[1], C)
341
+ .permute(0, 2, 1, 3)
342
+ .reshape(B, out.shape[1], C)
343
+ )
344
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
345
+ out = self.proj_out(out)
346
+ return out
347
+
348
+
349
+ def attn2task(task_queue, net):
350
+ if False: #isinstance(net, AttnBlock):
351
+ task_queue.append(('store_res', lambda x: x))
352
+ task_queue.append(('pre_norm', net.norm))
353
+ task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
354
+ task_queue.append(['add_res', None])
355
+ elif False: #isinstance(net, MemoryEfficientAttnBlock):
356
+ task_queue.append(('store_res', lambda x: x))
357
+ task_queue.append(('pre_norm', net.norm))
358
+ task_queue.append(
359
+ ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
360
+ task_queue.append(['add_res', None])
361
+ else:
362
+ task_queue.append(('store_res', lambda x: x))
363
+ task_queue.append(('pre_norm', net.norm))
364
+ if is_xformers_available:
365
+ # task_queue.append(('attn', lambda x, net=net: attn_forward_new_xformers(net, x)))
366
+ task_queue.append(
367
+ ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
368
+ elif hasattr(F, "scaled_dot_product_attention"):
369
+ task_queue.append(('attn', lambda x, net=net: attn_forward_new_pt2_0(net, x)))
370
+ else:
371
+ task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
372
+ task_queue.append(['add_res', None])
373
+
374
+ def resblock2task(queue, block):
375
+ """
376
+ Turn a ResNetBlock into a sequence of tasks and append to the task queue
377
+
378
+ @param queue: the target task queue
379
+ @param block: ResNetBlock
380
+
381
+ """
382
+ if block.in_channels != block.out_channels:
383
+ if sd_flag:
384
+ if block.use_conv_shortcut:
385
+ queue.append(('store_res', block.conv_shortcut))
386
+ else:
387
+ queue.append(('store_res', block.nin_shortcut))
388
+ else:
389
+ if block.use_in_shortcut:
390
+ queue.append(('store_res', block.conv_shortcut))
391
+ else:
392
+ queue.append(('store_res', block.nin_shortcut))
393
+
394
+ else:
395
+ queue.append(('store_res', lambda x: x))
396
+ queue.append(('pre_norm', block.norm1))
397
+ queue.append(('silu', inplace_nonlinearity))
398
+ queue.append(('conv1', block.conv1))
399
+ queue.append(('pre_norm', block.norm2))
400
+ queue.append(('silu', inplace_nonlinearity))
401
+ queue.append(('conv2', block.conv2))
402
+ queue.append(['add_res', None])
403
+
404
+
405
+ def build_sampling(task_queue, net, is_decoder):
406
+ """
407
+ Build the sampling part of a task queue
408
+ @param task_queue: the target task queue
409
+ @param net: the network
410
+ @param is_decoder: currently building decoder or encoder
411
+ """
412
+ if is_decoder:
413
+ if sd_flag:
414
+ resblock2task(task_queue, net.mid.block_1)
415
+ attn2task(task_queue, net.mid.attn_1)
416
+ # print(task_queue)
417
+ resblock2task(task_queue, net.mid.block_2)
418
+ resolution_iter = reversed(range(net.num_resolutions))
419
+ block_ids = net.num_res_blocks + 1
420
+ condition = 0
421
+ module = net.up
422
+ func_name = 'upsample'
423
+ else:
424
+ resblock2task(task_queue, net.mid_block.resnets[0])
425
+ attn2task(task_queue, net.mid_block.attentions[0])
426
+ resblock2task(task_queue, net.mid_block.resnets[1])
427
+ resolution_iter = (range(len(net.up_blocks))) # net.num_resolutions = 3
428
+ block_ids = 2 + 1
429
+ condition = len(net.up_blocks) - 1
430
+ module = net.up_blocks
431
+ func_name = 'upsamplers'
432
+ else:
433
+ if sd_flag:
434
+ resolution_iter = range(net.num_resolutions)
435
+ block_ids = net.num_res_blocks
436
+ condition = net.num_resolutions - 1
437
+ module = net.down
438
+ func_name = 'downsample'
439
+ else:
440
+ resolution_iter = range(len(net.down_blocks))
441
+ block_ids = 2
442
+ condition = len(net.down_blocks) - 1
443
+ module = net.down_blocks
444
+ func_name = 'downsamplers'
445
+
446
+ for i_level in resolution_iter:
447
+ for i_block in range(block_ids):
448
+ if sd_flag:
449
+ resblock2task(task_queue, module[i_level].block[i_block])
450
+ else:
451
+ resblock2task(task_queue, module[i_level].resnets[i_block])
452
+ if i_level != condition:
453
+ if sd_flag:
454
+ task_queue.append((func_name, getattr(module[i_level], func_name)))
455
+ else:
456
+ if is_decoder:
457
+ task_queue.append((func_name, module[i_level].upsamplers[0]))
458
+ else:
459
+ task_queue.append((func_name, module[i_level].downsamplers[0]))
460
+
461
+ if not is_decoder:
462
+ if sd_flag:
463
+ resblock2task(task_queue, net.mid.block_1)
464
+ attn2task(task_queue, net.mid.attn_1)
465
+ resblock2task(task_queue, net.mid.block_2)
466
+ else:
467
+ resblock2task(task_queue, net.mid_block.resnets[0])
468
+ attn2task(task_queue, net.mid_block.attentions[0])
469
+ resblock2task(task_queue, net.mid_block.resnets[1])
470
+
471
+
472
+ def build_task_queue(net, is_decoder):
473
+ """
474
+ Build a single task queue for the encoder or decoder
475
+ @param net: the VAE decoder or encoder network
476
+ @param is_decoder: currently building decoder or encoder
477
+ @return: the task queue
478
+ """
479
+ task_queue = []
480
+ task_queue.append(('conv_in', net.conv_in))
481
+
482
+ # construct the sampling part of the task queue
483
+ # because encoder and decoder share the same architecture, we extract the sampling part
484
+ build_sampling(task_queue, net, is_decoder)
485
+ if is_decoder and not sd_flag:
486
+ net.give_pre_end = False
487
+ net.tanh_out = False
488
+
489
+ if not is_decoder or not net.give_pre_end:
490
+ if sd_flag:
491
+ task_queue.append(('pre_norm', net.norm_out))
492
+ else:
493
+ task_queue.append(('pre_norm', net.conv_norm_out))
494
+ task_queue.append(('silu', inplace_nonlinearity))
495
+ task_queue.append(('conv_out', net.conv_out))
496
+ if is_decoder and net.tanh_out:
497
+ task_queue.append(('tanh', torch.tanh))
498
+
499
+ return task_queue
500
+
501
+
502
+ def clone_task_queue(task_queue):
503
+ """
504
+ Clone a task queue
505
+ @param task_queue: the task queue to be cloned
506
+ @return: the cloned task queue
507
+ """
508
+ return [[item for item in task] for task in task_queue]
509
+
510
+
511
+ def get_var_mean(input, num_groups, eps=1e-6):
512
+ """
513
+ Get mean and var for group norm
514
+ """
515
+ b, c = input.size(0), input.size(1)
516
+ channel_in_group = int(c/num_groups)
517
+ input_reshaped = input.contiguous().view(
518
+ 1, int(b * num_groups), channel_in_group, *input.size()[2:])
519
+ var, mean = torch.var_mean(
520
+ input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
521
+ return var, mean
522
+
523
+
524
+ def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
525
+ """
526
+ Custom group norm with fixed mean and var
527
+
528
+ @param input: input tensor
529
+ @param num_groups: number of groups. by default, num_groups = 32
530
+ @param mean: mean, must be pre-calculated by get_var_mean
531
+ @param var: var, must be pre-calculated by get_var_mean
532
+ @param weight: weight, should be fetched from the original group norm
533
+ @param bias: bias, should be fetched from the original group norm
534
+ @param eps: epsilon, by default, eps = 1e-6 to match the original group norm
535
+
536
+ @return: normalized tensor
537
+ """
538
+ b, c = input.size(0), input.size(1)
539
+ channel_in_group = int(c/num_groups)
540
+ input_reshaped = input.contiguous().view(
541
+ 1, int(b * num_groups), channel_in_group, *input.size()[2:])
542
+
543
+ out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,
544
+ training=False, momentum=0, eps=eps)
545
+
546
+ out = out.view(b, c, *input.size()[2:])
547
+
548
+ # post affine transform
549
+ if weight is not None:
550
+ out *= weight.view(1, -1, 1, 1)
551
+ if bias is not None:
552
+ out += bias.view(1, -1, 1, 1)
553
+ return out
554
+
555
+
556
+ def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
557
+ """
558
+ Crop the valid region from the tile
559
+ @param x: input tile
560
+ @param input_bbox: original input bounding box
561
+ @param target_bbox: output bounding box
562
+ @param is_decoder: whether the tile comes from the decoder (latent coordinates are scaled by 8)
563
+ @return: cropped tile
564
+ """
565
+ padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]
566
+ margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
567
+ return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]
568
+
569
+ # ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
570
+
571
+
572
+ def perfcount(fn):
573
+ def wrapper(*args, **kwargs):
574
+ ts = time()
575
+
576
+ if torch.cuda.is_available():
577
+ torch.cuda.reset_peak_memory_stats(devices.device)
578
+ devices.torch_gc()
579
+ gc.collect()
580
+
581
+ ret = fn(*args, **kwargs)
582
+
583
+ devices.torch_gc()
584
+ gc.collect()
585
+ if torch.cuda.is_available():
586
+ vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
587
+ torch.cuda.reset_peak_memory_stats(devices.device)
588
+ print(
589
+ f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
590
+ else:
591
+ print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')
592
+
593
+ return ret
594
+ return wrapper
595
+
596
+ # copy end :)
597
+
598
+
599
+ class GroupNormParam:
600
+ def __init__(self):
601
+ self.var_list = []
602
+ self.mean_list = []
603
+ self.pixel_list = []
604
+ self.weight = None
605
+ self.bias = None
606
+
607
+ def add_tile(self, tile, layer):
608
+ var, mean = get_var_mean(tile, 32)
609
+ # For giant images, the variance can be larger than max float16
610
+ # In this case we fall back to a float32 copy
611
+ if var.dtype == torch.float16 and var.isinf().any():
612
+ fp32_tile = tile.float()
613
+ var, mean = get_var_mean(fp32_tile, 32)
614
+ # ============= DEBUG: test for infinite =============
615
+ # if torch.isinf(var).any():
616
+ # print('var: ', var)
617
+ # ====================================================
618
+ self.var_list.append(var)
619
+ self.mean_list.append(mean)
620
+ self.pixel_list.append(
621
+ tile.shape[2]*tile.shape[3])
622
+ if hasattr(layer, 'weight'):
623
+ self.weight = layer.weight
624
+ self.bias = layer.bias
625
+ else:
626
+ self.weight = None
627
+ self.bias = None
628
+
629
+ def summary(self):
630
+ """
631
+ summarize the mean and var and return a function
632
+ that applies group norm to each tile
633
+ """
634
+ if len(self.var_list) == 0:
635
+ return None
636
+ var = torch.vstack(self.var_list)
637
+ mean = torch.vstack(self.mean_list)
638
+ max_value = max(self.pixel_list)
639
+ pixels = torch.tensor(
640
+ self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
641
+ sum_pixels = torch.sum(pixels)
642
+ pixels = pixels.unsqueeze(
643
+ 1) / sum_pixels
644
+ var = torch.sum(
645
+ var * pixels, dim=0)
646
+ mean = torch.sum(
647
+ mean * pixels, dim=0)
648
+ return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)
649
+
650
+ @staticmethod
651
+ def from_tile(tile, norm):
652
+ """
653
+ create a function from a single tile without summary
654
+ """
655
+ var, mean = get_var_mean(tile, 32)
656
+ if var.dtype == torch.float16 and var.isinf().any():
657
+ fp32_tile = tile.float()
658
+ var, mean = get_var_mean(fp32_tile, 32)
659
+ # on Apple Silicon (MPS) we need to convert back to float16
660
+ if var.device.type == 'mps':
661
+ # clamp to avoid overflow
662
+ var = torch.clamp(var, 0, 60000)
663
+ var = var.half()
664
+ mean = mean.half()
665
+ if hasattr(norm, 'weight'):
666
+ weight = norm.weight
667
+ bias = norm.bias
668
+ else:
669
+ weight = None
670
+ bias = None
671
+
672
+ def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
673
+ return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
674
+ return group_norm_func
675
+
676
+
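Minimal usage sketch for GroupNormParam (illustrative, not part of the original file): collect the statistics of every tile in a first pass, then apply the pixel-weighted merged statistics to each tile, which is what vae_tile_forward does between passes. It assumes the module's devices helper is available, as used elsewhere in this file.

import torch

norm_layer = torch.nn.GroupNorm(32, 64, eps=1e-6).to(devices.device)   # stand-in for a pre_norm layer
tiles = [torch.randn(1, 64, 24, 24, device=devices.device),
         torch.randn(1, 64, 24, 16, device=devices.device)]

param = GroupNormParam()
for t in tiles:
    param.add_tile(t, norm_layer)        # accumulate per-tile mean/var
apply_norm = param.summary()             # pixel-weighted merge across tiles
normed = [apply_norm(t) for t in tiles]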
677
+ class VAEHook:
678
+ def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
679
+ self.net = net # encoder | decoder
680
+ self.tile_size = tile_size
681
+ self.is_decoder = is_decoder
682
+ self.fast_mode = (fast_encoder and not is_decoder) or (
683
+ fast_decoder and is_decoder)
684
+ self.color_fix = color_fix and not is_decoder
685
+ self.to_gpu = to_gpu
686
+ self.pad = 11 if is_decoder else 32
687
+
688
+ def __call__(self, x):
689
+ B, C, H, W = x.shape
690
+ original_device = next(self.net.parameters()).device
691
+ try:
692
+ if self.to_gpu:
693
+ self.net.to(devices.get_optimal_device())
694
+ if max(H, W) <= self.pad * 2 + self.tile_size:
695
+ print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
696
+ return self.net.original_forward(x)
697
+ else:
698
+ return self.vae_tile_forward(x)
699
+ finally:
700
+ self.net.to(original_device)
701
+
702
+ def get_best_tile_size(self, lowerbound, upperbound):
703
+ """
704
+ Get the best tile size: round the edge up to a multiple of 32, 16, 8, ... without exceeding the upper bound
705
+ """
706
+ divider = 32
707
+ while divider >= 2:
708
+ remainder = lowerbound % divider
709
+ if remainder == 0:
710
+ return lowerbound
711
+ candidate = lowerbound - remainder + divider
712
+ if candidate <= upperbound:
713
+ return candidate
714
+ divider //= 2
715
+ return lowerbound
716
+
717
+ def split_tiles(self, h, w):
718
+ """
719
+ Tool function to split the image into tiles
720
+ @param h: height of the image
721
+ @param w: width of the image
722
+ @return: tile_input_bboxes, tile_output_bboxes
723
+ """
724
+ tile_input_bboxes, tile_output_bboxes = [], []
725
+ tile_size = self.tile_size
726
+ pad = self.pad
727
+ num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
728
+ num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
729
+ # If any of the numbers are 0, we let it be 1
730
+ # This is to deal with long and thin images
731
+ num_height_tiles = max(num_height_tiles, 1)
732
+ num_width_tiles = max(num_width_tiles, 1)
733
+
734
+ # Suggestions from https://github.com/Kahsolt: auto shrink the tile size
735
+ real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
736
+ real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
737
+ real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
738
+ real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)
739
+
740
+ print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
741
+ f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')
742
+
743
+ for i in range(num_height_tiles):
744
+ for j in range(num_width_tiles):
745
+ # bbox: [x1, x2, y1, y2]
746
+ # padding is unnecessary at the image borders, so we start directly from (pad, pad)
747
+ input_bbox = [
748
+ pad + j * real_tile_width,
749
+ min(pad + (j + 1) * real_tile_width, w),
750
+ pad + i * real_tile_height,
751
+ min(pad + (i + 1) * real_tile_height, h),
752
+ ]
753
+
754
+ # if the output bbox is close to the image boundary, we extend it to the image boundary
755
+ output_bbox = [
756
+ input_bbox[0] if input_bbox[0] > pad else 0,
757
+ input_bbox[1] if input_bbox[1] < w - pad else w,
758
+ input_bbox[2] if input_bbox[2] > pad else 0,
759
+ input_bbox[3] if input_bbox[3] < h - pad else h,
760
+ ]
761
+
762
+ # scale to get the final output bbox
763
+ output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
764
+ tile_output_bboxes.append(output_bbox)
765
+
766
+ # unconditionally expand the input bbox by pad pixels (clamped to the image size)
767
+ tile_input_bboxes.append([
768
+ max(0, input_bbox[0] - pad),
769
+ min(w, input_bbox[1] + pad),
770
+ max(0, input_bbox[2] - pad),
771
+ min(h, input_bbox[3] + pad),
772
+ ])
773
+
774
+ return tile_input_bboxes, tile_output_bboxes
775
+
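Worked example for split_tiles (a sketch assuming the class below is fully defined; not part of the diff): a 160x160 latent decoded with tile_size=96 and the decoder pad of 11 splits into a 2x2 grid, with input bboxes in latent coordinates and output bboxes already scaled by 8 to pixel coordinates.

hook = VAEHook(net=None, tile_size=96, is_decoder=True,
               fast_decoder=False, fast_encoder=False, color_fix=False)
in_bboxes, out_bboxes = hook.split_tiles(160, 160)
print(len(in_bboxes))      # 4 tiles
print(in_bboxes[0])        # [0, 118, 0, 118]  (latent coordinates, padded)
print(out_bboxes[0])       # [0, 856, 0, 856]  (pixel coordinates, x8)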
776
+ @torch.no_grad()
777
+ def estimate_group_norm(self, z, task_queue, color_fix):
778
+ device = z.device
779
+ tile = z
780
+ last_id = len(task_queue) - 1
781
+ while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
782
+ last_id -= 1
783
+ if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
784
+ raise ValueError('No group norm found in the task queue')
785
+ # estimate until the last group norm
786
+ for i in range(last_id + 1):
787
+ task = task_queue[i]
788
+ if task[0] == 'pre_norm':
789
+ group_norm_func = GroupNormParam.from_tile(tile, task[1])
790
+ task_queue[i] = ('apply_norm', group_norm_func)
791
+ if i == last_id:
792
+ return True
793
+ tile = group_norm_func(tile)
794
+ elif task[0] == 'store_res':
795
+ task_id = i + 1
796
+ while task_id < last_id and task_queue[task_id][0] != 'add_res':
797
+ task_id += 1
798
+ if task_id >= last_id:
799
+ continue
800
+ task_queue[task_id][1] = task[1](tile)
801
+ elif task[0] == 'add_res':
802
+ tile += task[1].to(device)
803
+ task[1] = None
804
+ elif color_fix and task[0] == 'downsample':
805
+ for j in range(i, last_id + 1):
806
+ if task_queue[j][0] == 'store_res':
807
+ task_queue[j] = ('store_res_cpu', task_queue[j][1])
808
+ return True
809
+ else:
810
+ tile = task[1](tile)
811
+ try:
812
+ devices.test_for_nans(tile, "vae")
813
+ except Exception:
814
+ print('NaN detected in fast mode estimation. Fast mode disabled.')
815
+ return False
816
+
817
+ raise IndexError('Should not reach here')
818
+
819
+ @perfcount
820
+ @torch.no_grad()
821
+ def vae_tile_forward(self, z):
822
+ """
823
+ Run the encoder or decoder over the input tensor in a tiled manner.
824
+ @param z: latent tensor (decoder) or image tensor (encoder)
825
+ @return: decoded image (decoder) or encoded latent (encoder)
826
+ """
827
+ device = next(self.net.parameters()).device
828
+ dtype = z.dtype
829
+ net = self.net
830
+ tile_size = self.tile_size
831
+ is_decoder = self.is_decoder
832
+
833
+ z = z.detach() # detach the input to avoid backprop
834
+
835
+ N, height, width = z.shape[0], z.shape[2], z.shape[3]
836
+ net.last_z_shape = z.shape
837
+
838
+ # Split the input into tiles and build a task queue for each tile
839
+ print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')
840
+
841
+ in_bboxes, out_bboxes = self.split_tiles(height, width)
842
+
843
+ # Prepare tiles by split the input latents
844
+ tiles = []
845
+ for input_bbox in in_bboxes:
846
+ tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
847
+ tiles.append(tile)
848
+
849
+ num_tiles = len(tiles)
850
+ num_completed = 0
851
+
852
+ # Build task queues
853
+ single_task_queue = build_task_queue(net, is_decoder)
854
+ #print(single_task_queue)
855
+ if self.fast_mode:
856
+ # Fast mode: downsample the input image to the tile size,
857
+ # then estimate the group norm parameters on the downsampled image
858
+ scale_factor = tile_size / max(height, width)
859
+ z = z.to(device)
860
+ downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
861
+ # use nearest-exact to keep statistics as close as possible
862
+ print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')
863
+
864
+ # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
865
+ # The downsampling will heavily distort its mean and std, so we need to recover it.
866
+ std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
867
+ std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
868
+ downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
869
+ del std_old, mean_old, std_new, mean_new
870
+ # occasionally the std_new is too small or too large, which exceeds the range of float16
871
+ # so we clamp it to z's value range.
872
+ downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
873
+ estimate_task_queue = clone_task_queue(single_task_queue)
874
+ if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
875
+ single_task_queue = estimate_task_queue
876
+ del downsampled_z
877
+
878
+ task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]
879
+
880
+ # Dummy result
881
+ result = None
882
+ result_approx = None
883
+ #try:
884
+ # with devices.autocast():
885
+ # result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
886
+ #except: pass
887
+ # Free memory of input latent tensor
888
+ del z
889
+
890
+ # Task queue execution
891
+ pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")
892
+
893
+ # execute the tasks back and forth when switching tiles so that we always
894
+ # keep one tile on the GPU to reduce unnecessary data transfer
895
+ forward = True
896
+ interrupted = False
897
+ #state.interrupted = interrupted
898
+ while True:
899
+ #if state.interrupted: interrupted = True ; break
900
+
901
+ group_norm_param = GroupNormParam()
902
+ for i in range(num_tiles) if forward else reversed(range(num_tiles)):
903
+ #if state.interrupted: interrupted = True ; break
904
+
905
+ tile = tiles[i].to(device)
906
+ input_bbox = in_bboxes[i]
907
+ task_queue = task_queues[i]
908
+
909
+ interrupted = False
910
+ while len(task_queue) > 0:
911
+ #if state.interrupted: interrupted = True ; break
912
+
913
+ # DEBUG: current task
914
+ # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
915
+ task = task_queue.pop(0)
916
+ if task[0] == 'pre_norm':
917
+ group_norm_param.add_tile(tile, task[1])
918
+ break
919
+ elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
920
+ task_id = 0
921
+ res = task[1](tile)
922
+ if not self.fast_mode or task[0] == 'store_res_cpu':
923
+ res = res.cpu()
924
+ while task_queue[task_id][0] != 'add_res':
925
+ task_id += 1
926
+ task_queue[task_id][1] = res
927
+ elif task[0] == 'add_res':
928
+ tile += task[1].to(device)
929
+ task[1] = None
930
+ else:
931
+ tile = task[1](tile)
932
+ #print(tiles[i].shape, tile.shape, task)
933
+ pbar.update(1)
934
+
935
+ if interrupted: break
936
+
937
+ # check for NaNs in the tile.
938
+ # If there are NaNs, we abort the process to save user's time
939
+ #devices.test_for_nans(tile, "vae")
940
+
941
+ #print(tiles[i].shape, tile.shape, i, num_tiles)
942
+ if len(task_queue) == 0:
943
+ tiles[i] = None
944
+ num_completed += 1
945
+ if result is None: # NOTE: dim C varies across cases, so it can only be initialized dynamically
946
+ result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
947
+ result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
948
+ del tile
949
+ elif i == num_tiles - 1 and forward:
950
+ forward = False
951
+ tiles[i] = tile
952
+ elif i == 0 and not forward:
953
+ forward = True
954
+ tiles[i] = tile
955
+ else:
956
+ tiles[i] = tile.cpu()
957
+ del tile
958
+
959
+ if interrupted: break
960
+ if num_completed == num_tiles: break
961
+
962
+ # insert the group norm task to the head of each task queue
963
+ group_norm_func = group_norm_param.summary()
964
+ if group_norm_func is not None:
965
+ for i in range(num_tiles):
966
+ task_queue = task_queues[i]
967
+ task_queue.insert(0, ('apply_norm', group_norm_func))
968
+
969
+ # Done!
970
+ pbar.close()
971
+ return result.to(dtype) if result is not None else result_approx.to(device)
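Hedged sketch of how this hook is usually wired up (the actual wiring lives elsewhere in the repo, e.g. behind model.init_tile_vae called from app.py; the helper name below is illustrative only): the original forward is saved first, because VAEHook falls back to net.original_forward for inputs small enough to skip tiling.

def attach_tiled_vae(encoder, decoder, encoder_tile_size=512, decoder_tile_size=64):
    # Illustrative only: route each network's forward through a VAEHook instance.
    for net, tile_size, is_decoder in ((encoder, encoder_tile_size, False),
                                       (decoder, decoder_tile_size, True)):
        if not hasattr(net, 'original_forward'):
            net.original_forward = net.forward   # needed by VAEHook for tiny inputs
        net.forward = VAEHook(net, tile_size, is_decoder=is_decoder,
                              fast_decoder=False, fast_encoder=False, color_fix=False)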
app.py ADDED
@@ -0,0 +1,911 @@
1
+ raise NotImplementedError(
2
+ "This Space is meant for SD-Webui-Forge\nhttps://github.com/lllyasviel/stable-diffusion-webui-forge"
3
+ )
4
+
5
+ import os
6
+ import gradio as gr
7
+ import argparse
8
+ import numpy as np
9
+ import torch
10
+ import einops
11
+ import copy
12
+ import math
13
+ import time
14
+ import random
15
+ import spaces
16
+ import re
17
+ import uuid
18
+
19
+ from gradio_imageslider import ImageSlider
20
+ from PIL import Image
21
+ from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype, create_SUPIR_model, load_QF_ckpt
22
+ from huggingface_hub import hf_hub_download
23
+ from pillow_heif import register_heif_opener
24
+
25
+ register_heif_opener()
26
+
27
+ max_64_bit_int = np.iinfo(np.int32).max
28
+
29
+ hf_hub_download(repo_id="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", filename="open_clip_pytorch_model.bin", local_dir="laion_CLIP-ViT-bigG-14-laion2B-39B-b160k")
30
+ hf_hub_download(repo_id="camenduru/SUPIR", filename="sd_xl_base_1.0_0.9vae.safetensors", local_dir="yushan777_SUPIR")
31
+ hf_hub_download(repo_id="camenduru/SUPIR", filename="SUPIR-v0F.ckpt", local_dir="yushan777_SUPIR")
32
+ hf_hub_download(repo_id="camenduru/SUPIR", filename="SUPIR-v0Q.ckpt", local_dir="yushan777_SUPIR")
33
+ hf_hub_download(repo_id="RunDiffusion/Juggernaut-XL-Lightning", filename="Juggernaut_RunDiffusionPhoto2_Lightning_4Steps.safetensors", local_dir="RunDiffusion_Juggernaut-XL-Lightning")
34
+
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument("--opt", type=str, default='options/SUPIR_v0.yaml')
37
+ parser.add_argument("--ip", type=str, default='127.0.0.1')
38
+ parser.add_argument("--port", type=int, default='6688')
39
+ parser.add_argument("--no_llava", action='store_true', default=True)#False
40
+ parser.add_argument("--use_image_slider", action='store_true', default=False)#False
41
+ parser.add_argument("--log_history", action='store_true', default=False)
42
+ parser.add_argument("--loading_half_params", action='store_true', default=False)#False
43
+ parser.add_argument("--use_tile_vae", action='store_true', default=True)#False
44
+ parser.add_argument("--encoder_tile_size", type=int, default=512)
45
+ parser.add_argument("--decoder_tile_size", type=int, default=64)
46
+ parser.add_argument("--load_8bit_llava", action='store_true', default=False)
47
+ args = parser.parse_args()
48
+
49
+ if torch.cuda.device_count() > 0:
50
+ SUPIR_device = 'cuda:0'
51
+
52
+ # Load SUPIR
53
+ model, default_setting = create_SUPIR_model(args.opt, SUPIR_sign='Q', load_default_setting=True)
54
+ if args.loading_half_params:
55
+ model = model.half()
56
+ if args.use_tile_vae:
57
+ model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
58
+ model = model.to(SUPIR_device)
59
+ model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
60
+ model.current_model = 'v0-Q'
61
+ ckpt_Q, ckpt_F = load_QF_ckpt(args.opt)
62
+
63
+ def check_upload(input_image):
64
+ if input_image is None:
65
+ raise gr.Error("Please provide an image to restore.")
66
+ return gr.update(visible = True)
67
+
68
+ def update_seed(is_randomize_seed, seed):
69
+ if is_randomize_seed:
70
+ return random.randint(0, max_64_bit_int)
71
+ return seed
72
+
73
+ def reset():
74
+ return [
75
+ None,
76
+ 0,
77
+ None,
78
+ None,
79
+ "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
80
+ "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
81
+ 1,
82
+ 1024,
83
+ 1,
84
+ 2,
85
+ 50,
86
+ -1.0,
87
+ 1.,
88
+ default_setting.s_cfg_Quality if torch.cuda.device_count() > 0 else 1.0,
89
+ True,
90
+ random.randint(0, max_64_bit_int),
91
+ 5,
92
+ 1.003,
93
+ "Wavelet",
94
+ "fp32",
95
+ "fp32",
96
+ 1.0,
97
+ True,
98
+ False,
99
+ default_setting.spt_linear_CFG_Quality if torch.cuda.device_count() > 0 else 1.0,
100
+ 0.,
101
+ "v0-Q",
102
+ "input",
103
+ 6
104
+ ]
105
+
106
+ def check(input_image):
107
+ if input_image is None:
108
+ raise gr.Error("Please provide an image to restore.")
109
+
110
+ @spaces.GPU(duration=420)
111
+ def stage1_process(
112
+ input_image,
113
+ gamma_correction,
114
+ diff_dtype,
115
+ ae_dtype
116
+ ):
117
+ print('stage1_process ==>>')
118
+ if torch.cuda.device_count() == 0:
119
+ gr.Warning('Set this space to GPU config to make it work.')
120
+ return None, None
121
+ torch.cuda.set_device(SUPIR_device)
122
+ LQ = HWC3(np.array(Image.open(input_image)))
123
+ LQ = fix_resize(LQ, 512)
124
+ # stage1
125
+ LQ = np.array(LQ) / 255 * 2 - 1
126
+ LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
127
+
128
+ model.ae_dtype = convert_dtype(ae_dtype)
129
+ model.model.dtype = convert_dtype(diff_dtype)
130
+
131
+ LQ = model.batchify_denoise(LQ, is_stage1=True)
132
+ LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8)
133
+ # gamma correction
134
+ LQ = LQ / 255.0
135
+ LQ = np.power(LQ, gamma_correction)
136
+ LQ *= 255.0
137
+ LQ = LQ.round().clip(0, 255).astype(np.uint8)
138
+ print('<<== stage1_process')
139
+ return LQ, gr.update(visible = True)
140
+
141
+ def stage2_process(*args, **kwargs):
142
+ try:
143
+ return restore_in_Xmin(*args, **kwargs)
144
+ except Exception as e:
145
+ # NO_GPU_MESSAGE_INQUEUE
146
+ print("gradio.exceptions.Error 'No GPU is currently available for you after 60s'")
147
+ print('str(type(e)): ' + str(type(e))) # <class 'gradio.exceptions.Error'>
148
+ print('str(e): ' + str(e)) # You have exceeded your GPU quota...
149
+ try:
150
+ print('e.message: ' + e.message) # No GPU is currently available for you after 60s
151
+ except Exception as e2:
152
+ print('Failure')
153
+ if str(e).startswith("No GPU is currently available for you after 60s"):
154
+ print('Exception identified!!!')
155
+ #if str(type(e)) == "<class 'gradio.exceptions.Error'>":
156
+ #print('Exception of name ' + type(e).__name__)
157
+ raise e
158
+
159
+ def restore_in_Xmin(
160
+ noisy_image,
161
+ rotation,
162
+ denoise_image,
163
+ prompt,
164
+ a_prompt,
165
+ n_prompt,
166
+ num_samples,
167
+ min_size,
168
+ downscale,
169
+ upscale,
170
+ edm_steps,
171
+ s_stage1,
172
+ s_stage2,
173
+ s_cfg,
174
+ randomize_seed,
175
+ seed,
176
+ s_churn,
177
+ s_noise,
178
+ color_fix_type,
179
+ diff_dtype,
180
+ ae_dtype,
181
+ gamma_correction,
182
+ linear_CFG,
183
+ linear_s_stage2,
184
+ spt_linear_CFG,
185
+ spt_linear_s_stage2,
186
+ model_select,
187
+ output_format,
188
+ allocation
189
+ ):
190
+ print("noisy_image:\n" + str(noisy_image))
191
+ print("denoise_image:\n" + str(denoise_image))
192
+ print("rotation: " + str(rotation))
193
+ print("prompt: " + str(prompt))
194
+ print("a_prompt: " + str(a_prompt))
195
+ print("n_prompt: " + str(n_prompt))
196
+ print("num_samples: " + str(num_samples))
197
+ print("min_size: " + str(min_size))
198
+ print("downscale: " + str(downscale))
199
+ print("upscale: " + str(upscale))
200
+ print("edm_steps: " + str(edm_steps))
201
+ print("s_stage1: " + str(s_stage1))
202
+ print("s_stage2: " + str(s_stage2))
203
+ print("s_cfg: " + str(s_cfg))
204
+ print("randomize_seed: " + str(randomize_seed))
205
+ print("seed: " + str(seed))
206
+ print("s_churn: " + str(s_churn))
207
+ print("s_noise: " + str(s_noise))
208
+ print("color_fix_type: " + str(color_fix_type))
209
+ print("diff_dtype: " + str(diff_dtype))
210
+ print("ae_dtype: " + str(ae_dtype))
211
+ print("gamma_correction: " + str(gamma_correction))
212
+ print("linear_CFG: " + str(linear_CFG))
213
+ print("linear_s_stage2: " + str(linear_s_stage2))
214
+ print("spt_linear_CFG: " + str(spt_linear_CFG))
215
+ print("spt_linear_s_stage2: " + str(spt_linear_s_stage2))
216
+ print("model_select: " + str(model_select))
217
+ print("GPU time allocation: " + str(allocation) + " min")
218
+ print("output_format: " + str(output_format))
219
+
220
+ input_format = re.sub(r"^.*\.([^\.]+)$", r"\1", noisy_image)
221
+
222
+ if input_format not in ['png', 'webp', 'jpg', 'jpeg', 'gif', 'bmp', 'heic']:
223
+ gr.Warning('Invalid image format. Please first convert into *.png, *.webp, *.jpg, *.jpeg, *.gif, *.bmp or *.heic.')
224
+ return None, None, None, None
225
+
226
+ if output_format == "input":
227
+ if noisy_image is None:
228
+ output_format = "png"
229
+ else:
230
+ output_format = input_format
231
+ print("final output_format: " + str(output_format))
232
+
233
+ if prompt is None:
234
+ prompt = ""
235
+
236
+ if a_prompt is None:
237
+ a_prompt = ""
238
+
239
+ if n_prompt is None:
240
+ n_prompt = ""
241
+
242
+ if prompt != "" and a_prompt != "":
243
+ a_prompt = prompt + ", " + a_prompt
244
+ else:
245
+ a_prompt = prompt + a_prompt
246
+ print("Final prompt: " + str(a_prompt))
247
+
248
+ denoise_image = np.array(Image.open(noisy_image if denoise_image is None else denoise_image))
249
+
250
+ if rotation == 90:
251
+ denoise_image = np.array(list(zip(*denoise_image[::-1])))
252
+ elif rotation == 180:
253
+ denoise_image = np.array(list(zip(*denoise_image[::-1])))
254
+ denoise_image = np.array(list(zip(*denoise_image[::-1])))
255
+ elif rotation == -90:
256
+ denoise_image = np.array(list(zip(*denoise_image))[::-1])
257
+
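For reference (a sketch, not part of the app): the zip-based rotations above are equivalent to numpy's rot90 on the first two axes, which also handles H x W x C images.

import numpy as np

img = np.arange(24).reshape(2, 3, 4)                      # toy H x W x C array
cw = np.array(list(zip(*img[::-1])))                      # the "+90" branch above
assert np.array_equal(cw, np.rot90(img, k=-1))            # clockwise rotation
ccw = np.array(list(zip(*img))[::-1])                     # the "-90" branch above
assert np.array_equal(ccw, np.rot90(img, k=1))            # counter-clockwise rotation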
258
+ if 1 < downscale:
259
+ input_height, input_width, input_channel = denoise_image.shape
260
+ denoise_image = np.array(Image.fromarray(denoise_image).resize((input_width // downscale, input_height // downscale), Image.LANCZOS))
261
+
262
+ denoise_image = HWC3(denoise_image)
263
+
264
+ if torch.cuda.device_count() == 0:
265
+ gr.Warning('Set this space to GPU config to make it work.')
266
+ return [noisy_image, denoise_image], gr.update(label="Downloadable results in *." + output_format + " format", format = output_format, value = [denoise_image]), None, gr.update(visible=True)
267
+
268
+ if model_select != model.current_model:
269
+ print('load ' + model_select)
270
+ if model_select == 'v0-Q':
271
+ model.load_state_dict(ckpt_Q, strict=False)
272
+ elif model_select == 'v0-F':
273
+ model.load_state_dict(ckpt_F, strict=False)
274
+ model.current_model = model_select
275
+
276
+ model.ae_dtype = convert_dtype(ae_dtype)
277
+ model.model.dtype = convert_dtype(diff_dtype)
278
+
279
+ # Allocation
280
+ if allocation == 1:
281
+ return restore_in_1min(
282
+ noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
283
+ )
284
+ if allocation == 2:
285
+ return restore_in_2min(
286
+ noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
287
+ )
288
+ if allocation == 3:
289
+ return restore_in_3min(
290
+ noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
291
+ )
292
+ if allocation == 4:
293
+ return restore_in_4min(
294
+ noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
295
+ )
296
+ if allocation == 5:
297
+ return restore_in_5min(
298
+ noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
299
+ )
300
+ if allocation == 7:
301
+ return restore_in_7min(
302
+ noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
303
+ )
304
+ if allocation == 8:
305
+ return restore_in_8min(
306
+ noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
307
+ )
308
+ if allocation == 9:
309
+ return restore_in_9min(
310
+ noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
311
+ )
312
+ if allocation == 10:
313
+ return restore_in_10min(
314
+ noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
315
+ )
316
+ else:
317
+ return restore_in_6min(
318
+ noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
319
+ )
320
+
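The if-chain above can also be written as a lookup table. This is only a sketch of an equivalent alternative (the names RESTORE_HANDLERS and dispatch_restore are illustrative, not from the original), keeping the per-duration restore_in_*min wrappers defined below exactly as the original does, with restore_in_6min as the fallback.

RESTORE_HANDLERS = {
    1: restore_in_1min, 2: restore_in_2min, 3: restore_in_3min, 4: restore_in_4min,
    5: restore_in_5min, 6: restore_in_6min, 7: restore_in_7min, 8: restore_in_8min,
    9: restore_in_9min, 10: restore_in_10min,
}

def dispatch_restore(allocation, *args, **kwargs):
    # Same behaviour as the if/elif chain: unknown values fall back to 6 minutes.
    return RESTORE_HANDLERS.get(allocation, restore_in_6min)(*args, **kwargs)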
321
+ @spaces.GPU(duration=59)
322
+ def restore_in_1min(*args, **kwargs):
323
+ return restore_on_gpu(*args, **kwargs)
324
+
325
+ @spaces.GPU(duration=119)
326
+ def restore_in_2min(*args, **kwargs):
327
+ return restore_on_gpu(*args, **kwargs)
328
+
329
+ @spaces.GPU(duration=179)
330
+ def restore_in_3min(*args, **kwargs):
331
+ return restore_on_gpu(*args, **kwargs)
332
+
333
+ @spaces.GPU(duration=239)
334
+ def restore_in_4min(*args, **kwargs):
335
+ return restore_on_gpu(*args, **kwargs)
336
+
337
+ @spaces.GPU(duration=299)
338
+ def restore_in_5min(*args, **kwargs):
339
+ return restore_on_gpu(*args, **kwargs)
340
+
341
+ @spaces.GPU(duration=359)
342
+ def restore_in_6min(*args, **kwargs):
343
+ return restore_on_gpu(*args, **kwargs)
344
+
345
+ @spaces.GPU(duration=419)
346
+ def restore_in_7min(*args, **kwargs):
347
+ return restore_on_gpu(*args, **kwargs)
348
+
349
+ @spaces.GPU(duration=479)
350
+ def restore_in_8min(*args, **kwargs):
351
+ return restore_on_gpu(*args, **kwargs)
352
+
353
+ @spaces.GPU(duration=539)
354
+ def restore_in_9min(*args, **kwargs):
355
+ return restore_on_gpu(*args, **kwargs)
356
+
357
+ @spaces.GPU(duration=599)
358
+ def restore_in_10min(*args, **kwargs):
359
+ return restore_on_gpu(*args, **kwargs)
360
+
361
+ def restore_on_gpu(
362
+ noisy_image,
363
+ input_image,
364
+ prompt,
365
+ a_prompt,
366
+ n_prompt,
367
+ num_samples,
368
+ min_size,
369
+ downscale,
370
+ upscale,
371
+ edm_steps,
372
+ s_stage1,
373
+ s_stage2,
374
+ s_cfg,
375
+ randomize_seed,
376
+ seed,
377
+ s_churn,
378
+ s_noise,
379
+ color_fix_type,
380
+ diff_dtype,
381
+ ae_dtype,
382
+ gamma_correction,
383
+ linear_CFG,
384
+ linear_s_stage2,
385
+ spt_linear_CFG,
386
+ spt_linear_s_stage2,
387
+ model_select,
388
+ output_format,
389
+ allocation
390
+ ):
391
+ start = time.time()
392
+ print('restore ==>>')
393
+
394
+ torch.cuda.set_device(SUPIR_device)
395
+
396
+ with torch.no_grad():
397
+ input_image = upscale_image(input_image, upscale, unit_resolution=32, min_size=min_size)
398
+ LQ = np.array(input_image) / 255.0
399
+ LQ = np.power(LQ, gamma_correction)
400
+ LQ *= 255.0
401
+ LQ = LQ.round().clip(0, 255).astype(np.uint8)
402
+ LQ = LQ / 255 * 2 - 1
403
+ LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
404
+ captions = ['']
405
+
406
+ samples = model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
407
+ s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
408
+ num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
409
+ use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
410
+ cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)
411
+
412
+ x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
413
+ 0, 255).astype(np.uint8)
414
+ results = [x_samples[i] for i in range(num_samples)]
415
+ torch.cuda.empty_cache()
416
+
417
+ # All the results have the same size
418
+ input_height, input_width, input_channel = np.array(input_image).shape
419
+ result_height, result_width, result_channel = np.array(results[0]).shape
420
+
421
+ print('<<== restore')
422
+ end = time.time()
423
+ secondes = int(end - start)
424
+ minutes = math.floor(secondes / 60)
425
+ secondes = secondes - (minutes * 60)
426
+ hours = math.floor(minutes / 60)
427
+ minutes = minutes - (hours * 60)
428
+ information = ("Start the process again if you want a different result. " if randomize_seed else "") + \
429
+ "If you don't get the image you wanted, add more details in the « Image description ». " + \
430
+ "Wait " + str(allocation) + " min before a new run to avoid quota penalty or use another computer. " + \
431
+ "The image" + (" has" if len(results) == 1 else "s have") + " been generated in " + \
432
+ ((str(hours) + " h, ") if hours != 0 else "") + \
433
+ ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + \
434
+ str(secondes) + " sec. " + \
435
+ "The new image resolution is " + str(result_width) + \
436
+ " pixels large and " + str(result_height) + \
437
+ " pixels high, so a resolution of " + f'{result_width * result_height:,}' + " pixels."
438
+ print(information)
439
+ try:
440
+ print("Initial resolution: " + f'{input_width * input_height:,}')
441
+ print("Final resolution: " + f'{result_width * result_height:,}')
442
+ print("edm_steps: " + str(edm_steps))
443
+ print("num_samples: " + str(num_samples))
444
+ print("downscale: " + str(downscale))
445
+ print("Estimated minutes: " + f'{(((result_width * result_height**(1/1.75)) * input_width * input_height * (edm_steps**(1/2)) * (num_samples**(1/2.5)))**(1/2.5)) / 25000:,}')
446
+ except Exception as e:
447
+ print('Exception of Estimation')
448
+
449
+ # Only one image can be shown in the slider
450
+ return [noisy_image] + [results[0]], gr.update(label="Downloadable results in *." + output_format + " format", format = output_format, value = results), gr.update(value = information, visible = True), gr.update(visible=True)
451
+
452
+ def load_and_reset(param_setting):
453
+ print('load_and_reset ==>>')
454
+ if torch.cuda.device_count() == 0:
455
+ gr.Warning('Set this space to GPU config to make it work.')
456
+ return None, None, None, None, None, None, None, None, None, None, None, None, None, None
457
+ edm_steps = default_setting.edm_steps
458
+ s_stage2 = 1.0
459
+ s_stage1 = -1.0
460
+ s_churn = 5
461
+ s_noise = 1.003
462
+ a_prompt = 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - ' \
463
+ 'realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore ' \
464
+ 'detailing, hyper sharpness, perfect without deformations.'
465
+ n_prompt = 'painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, ' \
466
+ '3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, ' \
467
+ 'signature, jpeg artifacts, deformed, lowres, over-smooth'
468
+ color_fix_type = 'Wavelet'
469
+ spt_linear_s_stage2 = 0.0
470
+ linear_s_stage2 = False
471
+ linear_CFG = True
472
+ if param_setting == "Quality":
473
+ s_cfg = default_setting.s_cfg_Quality
474
+ spt_linear_CFG = default_setting.spt_linear_CFG_Quality
475
+ model_select = "v0-Q"
476
+ elif param_setting == "Fidelity":
477
+ s_cfg = default_setting.s_cfg_Fidelity
478
+ spt_linear_CFG = default_setting.spt_linear_CFG_Fidelity
479
+ model_select = "v0-F"
480
+ else:
481
+ raise NotImplementedError
482
+ gr.Info('The parameters are reset.')
483
+ print('<<== load_and_reset')
484
+ return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
485
+ linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select
486
+
487
+ def log_information(result_gallery):
488
+ print('log_information')
489
+ if result_gallery is not None:
490
+ for i, result in enumerate(result_gallery):
491
+ print(result[0])
492
+
493
+ def on_select_result(result_slider, result_gallery, evt: gr.SelectData):
494
+ print('on_select_result')
495
+ if result_gallery is not None:
496
+ for i, result in enumerate(result_gallery):
497
+ print(result[0])
498
+ return [result_slider[0], result_gallery[evt.index][0]]
499
+
500
+ title_html = """
501
+ <h1><center>SUPIR</center></h1>
502
+ <big><center>Upscale your images up to x10 for free, without an account or watermark, and download the result</center></big>
503
+ <center><big><big>🤸<big><big><big><big><big><big>🤸</big></big></big></big></big></big></big></big></center>
504
+
505
+ <p>This is an online demo of SUPIR, which practices model scaling for photo-realistic image restoration.
506
+ The content added by SUPIR is <b><u>imagination, not real-world information</u></b>.
507
+ SUPIR is for beauty and illustration only.
508
+ Most processes take a few minutes.
509
+ If you want to upscale AI-generated images, note that the <i>PixArt Sigma</i> space can directly generate 5984x5984 images.
510
+ Due to Gradio issues, the generated image is slightly less saturated than the original.
511
+ Please leave a <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR/discussions/new">message in discussion</a> if you encounter issues.
512
+ You can also use <a href="https://huggingface.co/spaces/gokaygokay/AuraSR">AuraSR</a> to upscale x4.
513
+
514
+ <p><center><a href="https://arxiv.org/abs/2401.13627">Paper</a> &emsp; <a href="http://supir.xpixel.group/">Project Page</a> &emsp; <a href="https://huggingface.co/blog/MonsterMMORPG/supir-sota-image-upscale-better-than-magnific-ai">Local Install Guide</a></center></p>
515
+ <p><center><a style="display:inline-block" href='https://github.com/Fanghua-Yu/SUPIR'><img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/Fanghua-Yu/SUPIR?style=social"></a></center></p>
516
+ """
517
+
518
+
519
+ claim_md = """
520
+ ## **Privacy**
521
+ The images are not stored, but the logs are kept for one month.
522
+ ## **How to get SUPIR**
523
+ You can get SUPIR on HuggingFace by [duplicating this space](https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR?duplicate=true) and setting a GPU.
524
+ You can also install SUPIR on your computer following [this tutorial](https://huggingface.co/blog/MonsterMMORPG/supir-sota-image-upscale-better-than-magnific-ai).
525
+ You can install _Pinokio_ on your computer and then install _SUPIR_ into it. It should be quite easy if you have an Nvidia GPU.
526
+ ## **Terms of use**
527
+ By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please submit feedback to us if you get any inappropriate output! We will collect it to keep improving our models. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
528
+ ## **License**
529
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
530
+ """
531
+
532
+ # Gradio interface
533
+ with gr.Blocks() as interface:
534
+ if torch.cuda.device_count() == 0:
535
+ with gr.Row():
536
+ gr.HTML("""
537
+ <p style="background-color: red;"><big><big><big><b>⚠️To use SUPIR, <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR?duplicate=true">duplicate this space</a> and set a GPU with 30 GB VRAM.</b>
538
+
539
+ You can't use SUPIR directly here because this space runs on a CPU, which is not enough for SUPIR. Please provide <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR/discussions/new">feedback</a> if you have issues.
540
+ </big></big></big></p>
541
+ """)
542
+ gr.HTML(title_html)
543
+
544
+ input_image = gr.Image(label="Input (*.png, *.webp, *.jpeg, *.jpg, *.gif, *.bmp, *.heic)", show_label=True, type="filepath", height=600, elem_id="image-input")
545
+ rotation = gr.Radio([["No rotation", 0], ["⤵ Rotate +90°", 90], ["↩ Rotate 180°", 180], ["⤴ Rotate -90°", -90]], label="Orientation correction", info="Will apply the following rotation before restoring the image; the AI needs a good orientation to understand the content", value=0, interactive=True, visible=False)
546
+ with gr.Group():
547
+ prompt = gr.Textbox(label="Image description", info="Help the AI understand what the image represents; describe as much as possible, especially the details we can't see on the original image; you can write in any language", value="", placeholder="A 33 years old man, walking, in the street, Santiago, morning, Summer, photorealistic", lines=3)
548
+ prompt_hint = gr.HTML("You can use a <a href='"'https://huggingface.co/spaces/MaziyarPanahi/llava-llama-3-8b'"'>LlaVa space</a> to auto-generate the description of your image.")
549
+ upscale = gr.Radio([["x1", 1], ["x2", 2], ["x3", 3], ["x4", 4], ["x5", 5], ["x6", 6], ["x7", 7], ["x8", 8], ["x9", 9], ["x10", 10]], label="Upscale factor", info="Resolution x1 to x10", value=2, interactive=True)
550
+ output_format = gr.Radio([["As input", "input"], ["*.png", "png"], ["*.webp", "webp"], ["*.jpeg", "jpeg"], ["*.gif", "gif"], ["*.bmp", "bmp"]], label="Image format for result", info="File extension", value="input", interactive=True)
551
+ allocation = gr.Radio([["1 min", 1], ["2 min", 2], ["3 min", 3], ["4 min", 4], ["5 min", 5], ["6 min", 6], ["7 min (discouraged)", 7], ["8 min (discouraged)", 8], ["9 min (discouraged)", 9], ["10 min (discouraged)", 10]], label="GPU allocation time", info="lower=May abort run, higher=Quota penalty for next runs", value=5, interactive=True)
552
+
553
+ with gr.Accordion("Pre-denoising (optional)", open=False):
554
+ gamma_correction = gr.Slider(label="Gamma Correction", info = "lower=lighter, higher=darker", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
555
+ denoise_button = gr.Button(value="Pre-denoise")
556
+ denoise_image = gr.Image(label="Denoised image", show_label=True, type="filepath", sources=[], interactive = False, height=600, elem_id="image-s1")
557
+ denoise_information = gr.HTML(value="If present, the denoised image will be used for the restoration instead of the input image.", visible=False)
558
+
559
+ with gr.Accordion("Advanced options", open=False):
560
+ a_prompt = gr.Textbox(label="Additional image description",
561
+ info="Completes the main image description",
562
+ value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
563
+ 'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
564
+ 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, '
565
+ 'hyper sharpness, perfect without deformations.',
566
+ lines=3)
567
+ n_prompt = gr.Textbox(label="Negative image description",
568
+ info="Disambiguate by listing what the image does NOT represent",
569
+ value='painting, oil painting, illustration, drawing, art, sketch, anime, '
570
+ 'cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, '
571
+ 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
572
+ 'deformed, lowres, over-smooth',
573
+ lines=3)
574
+ edm_steps = gr.Slider(label="Steps", info="lower=faster, higher=more details", minimum=1, maximum=200, value=default_setting.edm_steps if torch.cuda.device_count() > 0 else 1, step=1)
575
+ num_samples = gr.Slider(label="Num Samples", info="Number of generated results", minimum=1, maximum=4 if not args.use_image_slider else 1
576
+ , value=1, step=1)
577
+ min_size = gr.Slider(label="Minimum size", info="Minimum height, minimum width of the result", minimum=32, maximum=4096, value=1024, step=32)
578
+ downscale = gr.Radio([["/1", 1], ["/2", 2], ["/3", 3], ["/4", 4], ["/5", 5], ["/6", 6], ["/7", 7], ["/8", 8], ["/9", 9], ["/10", 10]], label="Pre-downscale factor", info="Downscaling a blurry input reduces the processing time", value=1, interactive=True)
579
+ with gr.Row():
580
+ with gr.Column():
581
+ model_select = gr.Radio([["💃 Quality (v0-Q)", "v0-Q"], ["🎯 Fidelity (v0-F)", "v0-F"]], label="Model Selection", info="Pretrained model", value="v0-Q",
582
+ interactive=True)
583
+ with gr.Column():
584
+ color_fix_type = gr.Radio([["None", "None"], ["AdaIn (improve as a photo)", "AdaIn"], ["Wavelet (for JPEG artifacts)", "Wavelet"]], label="Color-Fix Type", info="AdaIn=Improve following a style, Wavelet=For JPEG artifacts", value="AdaIn",
585
+ interactive=True)
586
+ s_cfg = gr.Slider(label="Text Guidance Scale", info="lower=follow the image, higher=follow the prompt", minimum=1.0, maximum=15.0,
587
+ value=default_setting.s_cfg_Quality if torch.cuda.device_count() > 0 else 1.0, step=0.1)
588
+ s_stage2 = gr.Slider(label="Restoring Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
589
+ s_stage1 = gr.Slider(label="Pre-denoising Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0, step=1.0)
590
+ s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
591
+ s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
592
+ with gr.Row():
593
+ with gr.Column():
594
+ linear_CFG = gr.Checkbox(label="Linear CFG", value=True)
595
+ spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
596
+ maximum=9.0, value=default_setting.spt_linear_CFG_Quality if torch.cuda.device_count() > 0 else 1.0, step=0.5)
597
+ with gr.Column():
598
+ linear_s_stage2 = gr.Checkbox(label="Linear Restoring Guidance", value=False)
599
+ spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
600
+ maximum=1., value=0., step=0.05)
601
+ with gr.Column():
602
+ diff_dtype = gr.Radio([["fp32 (precision)", "fp32"], ["fp16 (medium)", "fp16"], ["bf16 (speed)", "bf16"]], label="Diffusion Data Type", value="fp32",
603
+ interactive=True)
604
+ with gr.Column():
605
+ ae_dtype = gr.Radio([["fp32 (precision)", "fp32"], ["bf16 (speed)", "bf16"]], label="Auto-Encoder Data Type", value="fp32",
606
+ interactive=True)
607
+ randomize_seed = gr.Checkbox(label = "\U0001F3B2 Randomize seed", value = True, info = "If checked, result is always different")
608
+ seed = gr.Slider(label="Seed", minimum=0, maximum=max_64_bit_int, step=1, randomize=True)
609
+ with gr.Group():
610
+ param_setting = gr.Radio(["Quality", "Fidelity"], interactive=True, label="Presetting", value = "Quality")
611
+ restart_button = gr.Button(value="Apply presetting")
612
+
613
+ with gr.Column():
614
+ diffusion_button = gr.Button(value="🚀 Upscale/Restore", variant = "primary", elem_id = "process_button")
615
+ reset_btn = gr.Button(value="🧹 Reinit page", variant="stop", elem_id="reset_button", visible = False)
616
+
617
+ restore_information = gr.HTML(value = "Restart the process to get another result.", visible = False)
618
+ result_slider = ImageSlider(label = 'Comparator', show_label = False, interactive = False, elem_id = "slider1", show_download_button = False)
619
+ result_gallery = gr.Gallery(label = 'Downloadable results', show_label = True, interactive = False, elem_id = "gallery1")
620
+
621
+ gr.Examples(
622
+ examples = [
623
+ [
624
+ "./Examples/Example1.png",
625
+ 0,
626
+ None,
627
+ "Group of people, walking, happy, in the street, photorealistic, 8k, extremely detailled",
628
+ "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
629
+ "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
630
+ 2,
631
+ 1024,
632
+ 1,
633
+ 8,
634
+ 200,
635
+ -1,
636
+ 1,
637
+ 7.5,
638
+ False,
639
+ 42,
640
+ 5,
641
+ 1.003,
642
+ "AdaIn",
643
+ "fp16",
644
+ "bf16",
645
+ 1.0,
646
+ True,
647
+ 4,
648
+ False,
649
+ 0.,
650
+ "v0-Q",
651
+ "input",
652
+ 5
653
+ ],
654
+ [
655
+ "./Examples/Example2.jpeg",
656
+ 0,
657
+ None,
658
+ "La cabeza de un gato atigrado, en una casa, fotorrealista, 8k, extremadamente detallada",
659
+ "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
660
+ "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
661
+ 1,
662
+ 1024,
663
+ 1,
664
+ 1,
665
+ 200,
666
+ -1,
667
+ 1,
668
+ 7.5,
669
+ False,
670
+ 42,
671
+ 5,
672
+ 1.003,
673
+ "Wavelet",
674
+ "fp16",
675
+ "bf16",
676
+ 1.0,
677
+ True,
678
+ 4,
679
+ False,
680
+ 0.,
681
+ "v0-Q",
682
+ "input",
683
+ 4
684
+ ],
685
+ [
686
+ "./Examples/Example3.webp",
687
+ 0,
688
+ None,
689
+ "A red apple",
690
+ "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
691
+ "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
692
+ 1,
693
+ 1024,
694
+ 1,
695
+ 1,
696
+ 200,
697
+ -1,
698
+ 1,
699
+ 7.5,
700
+ False,
701
+ 42,
702
+ 5,
703
+ 1.003,
704
+ "Wavelet",
705
+ "fp16",
706
+ "bf16",
707
+ 1.0,
708
+ True,
709
+ 4,
710
+ False,
711
+ 0.,
712
+ "v0-Q",
713
+ "input",
714
+ 4
715
+ ],
716
+ [
717
+ "./Examples/Example3.webp",
718
+ 0,
719
+ None,
720
+ "A red marble",
721
+ "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
722
+ "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
723
+ 1,
724
+ 1024,
725
+ 1,
726
+ 1,
727
+ 200,
728
+ -1,
729
+ 1,
730
+ 7.5,
731
+ False,
732
+ 42,
733
+ 5,
734
+ 1.003,
735
+ "Wavelet",
736
+ "fp16",
737
+ "bf16",
738
+ 1.0,
739
+ True,
740
+ 4,
741
+ False,
742
+ 0.,
743
+ "v0-Q",
744
+ "input",
745
+ 4
746
+ ],
747
+ ],
748
+ run_on_click = True,
749
+ fn = stage2_process,
750
+ inputs = [
751
+ input_image,
752
+ rotation,
753
+ denoise_image,
754
+ prompt,
755
+ a_prompt,
756
+ n_prompt,
757
+ num_samples,
758
+ min_size,
759
+ downscale,
760
+ upscale,
761
+ edm_steps,
762
+ s_stage1,
763
+ s_stage2,
764
+ s_cfg,
765
+ randomize_seed,
766
+ seed,
767
+ s_churn,
768
+ s_noise,
769
+ color_fix_type,
770
+ diff_dtype,
771
+ ae_dtype,
772
+ gamma_correction,
773
+ linear_CFG,
774
+ linear_s_stage2,
775
+ spt_linear_CFG,
776
+ spt_linear_s_stage2,
777
+ model_select,
778
+ output_format,
779
+ allocation
780
+ ],
781
+ outputs = [
782
+ result_slider,
783
+ result_gallery,
784
+ restore_information,
785
+ reset_btn
786
+ ],
787
+ cache_examples = False,
788
+ )
789
+
790
+ with gr.Row():
791
+ gr.Markdown(claim_md)
792
+
793
+ input_image.upload(fn = check_upload, inputs = [
794
+ input_image
795
+ ], outputs = [
796
+ rotation
797
+ ], queue = False, show_progress = False)
798
+
799
+ denoise_button.click(fn = check, inputs = [
800
+ input_image
801
+ ], outputs = [], queue = False, show_progress = False).success(fn = stage1_process, inputs = [
802
+ input_image,
803
+ gamma_correction,
804
+ diff_dtype,
805
+ ae_dtype
806
+ ], outputs=[
807
+ denoise_image,
808
+ denoise_information
809
+ ])
810
+
811
+ diffusion_button.click(fn = update_seed, inputs = [
812
+ randomize_seed,
813
+ seed
814
+ ], outputs = [
815
+ seed
816
+ ], queue = False, show_progress = False).then(fn = check, inputs = [
817
+ input_image
818
+ ], outputs = [], queue = False, show_progress = False).success(fn=stage2_process, inputs = [
819
+ input_image,
820
+ rotation,
821
+ denoise_image,
822
+ prompt,
823
+ a_prompt,
824
+ n_prompt,
825
+ num_samples,
826
+ min_size,
827
+ downscale,
828
+ upscale,
829
+ edm_steps,
830
+ s_stage1,
831
+ s_stage2,
832
+ s_cfg,
833
+ randomize_seed,
834
+ seed,
835
+ s_churn,
836
+ s_noise,
837
+ color_fix_type,
838
+ diff_dtype,
839
+ ae_dtype,
840
+ gamma_correction,
841
+ linear_CFG,
842
+ linear_s_stage2,
843
+ spt_linear_CFG,
844
+ spt_linear_s_stage2,
845
+ model_select,
846
+ output_format,
847
+ allocation
848
+ ], outputs = [
849
+ result_slider,
850
+ result_gallery,
851
+ restore_information,
852
+ reset_btn
853
+ ]).success(fn = log_information, inputs = [
854
+ result_gallery
855
+ ], outputs = [], queue = False, show_progress = False)
856
+
857
+ result_gallery.change(on_select_result, [result_slider, result_gallery], result_slider)
858
+ result_gallery.select(on_select_result, [result_slider, result_gallery], result_slider)
859
+
860
+ restart_button.click(fn = load_and_reset, inputs = [
861
+ param_setting
862
+ ], outputs = [
863
+ edm_steps,
864
+ s_cfg,
865
+ s_stage2,
866
+ s_stage1,
867
+ s_churn,
868
+ s_noise,
869
+ a_prompt,
870
+ n_prompt,
871
+ color_fix_type,
872
+ linear_CFG,
873
+ linear_s_stage2,
874
+ spt_linear_CFG,
875
+ spt_linear_s_stage2,
876
+ model_select
877
+ ])
878
+
879
+ reset_btn.click(fn = reset, inputs = [], outputs = [
880
+ input_image,
881
+ rotation,
882
+ denoise_image,
883
+ prompt,
884
+ a_prompt,
885
+ n_prompt,
886
+ num_samples,
887
+ min_size,
888
+ downscale,
889
+ upscale,
890
+ edm_steps,
891
+ s_stage1,
892
+ s_stage2,
893
+ s_cfg,
894
+ randomize_seed,
895
+ seed,
896
+ s_churn,
897
+ s_noise,
898
+ color_fix_type,
899
+ diff_dtype,
900
+ ae_dtype,
901
+ gamma_correction,
902
+ linear_CFG,
903
+ linear_s_stage2,
904
+ spt_linear_CFG,
905
+ spt_linear_s_stage2,
906
+ model_select,
907
+ output_format,
908
+ allocation
909
+ ], queue = False, show_progress = False)
910
+
911
+ interface.queue(10).launch()
configs/clip_vit_config.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "_name_or_path": "openai/clip-vit-large-patch14",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.22.0.dev0",
24
+ "vocab_size": 49408
25
+ }
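The JSON above reproduces the text-encoder settings of openai/clip-vit-large-patch14 (hidden size 768, 77 position embeddings, 12 layers), presumably bundled so the text encoder can be built without fetching the remote config. A minimal sketch, assuming the transformers package, of how such a file is typically consumed; it is not part of this commit:

from transformers import CLIPTextConfig, CLIPTextModel

# Build the CLIP text encoder from the bundled config; weights remain randomly
# initialized until a checkpoint is loaded on top of the module.
config = CLIPTextConfig.from_json_file("configs/clip_vit_config.json")
text_encoder = CLIPTextModel(config)
print(config.hidden_size, config.max_position_embeddings)  # 768 77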
configs/tokenizer/config.json ADDED
@@ -0,0 +1,171 @@
1
+ {
2
+ "_name_or_path": "clip-vit-large-patch14/",
3
+ "architectures": [
4
+ "CLIPModel"
5
+ ],
6
+ "initializer_factor": 1.0,
7
+ "logit_scale_init_value": 2.6592,
8
+ "model_type": "clip",
9
+ "projection_dim": 768,
10
+ "text_config": {
11
+ "_name_or_path": "",
12
+ "add_cross_attention": false,
13
+ "architectures": null,
14
+ "attention_dropout": 0.0,
15
+ "bad_words_ids": null,
16
+ "bos_token_id": 0,
17
+ "chunk_size_feed_forward": 0,
18
+ "cross_attention_hidden_size": null,
19
+ "decoder_start_token_id": null,
20
+ "diversity_penalty": 0.0,
21
+ "do_sample": false,
22
+ "dropout": 0.0,
23
+ "early_stopping": false,
24
+ "encoder_no_repeat_ngram_size": 0,
25
+ "eos_token_id": 2,
26
+ "finetuning_task": null,
27
+ "forced_bos_token_id": null,
28
+ "forced_eos_token_id": null,
29
+ "hidden_act": "quick_gelu",
30
+ "hidden_size": 768,
31
+ "id2label": {
32
+ "0": "LABEL_0",
33
+ "1": "LABEL_1"
34
+ },
35
+ "initializer_factor": 1.0,
36
+ "initializer_range": 0.02,
37
+ "intermediate_size": 3072,
38
+ "is_decoder": false,
39
+ "is_encoder_decoder": false,
40
+ "label2id": {
41
+ "LABEL_0": 0,
42
+ "LABEL_1": 1
43
+ },
44
+ "layer_norm_eps": 1e-05,
45
+ "length_penalty": 1.0,
46
+ "max_length": 20,
47
+ "max_position_embeddings": 77,
48
+ "min_length": 0,
49
+ "model_type": "clip_text_model",
50
+ "no_repeat_ngram_size": 0,
51
+ "num_attention_heads": 12,
52
+ "num_beam_groups": 1,
53
+ "num_beams": 1,
54
+ "num_hidden_layers": 12,
55
+ "num_return_sequences": 1,
56
+ "output_attentions": false,
57
+ "output_hidden_states": false,
58
+ "output_scores": false,
59
+ "pad_token_id": 1,
60
+ "prefix": null,
61
+ "problem_type": null,
62
+ "projection_dim": 768,
63
+ "pruned_heads": {},
64
+ "remove_invalid_values": false,
65
+ "repetition_penalty": 1.0,
66
+ "return_dict": true,
67
+ "return_dict_in_generate": false,
68
+ "sep_token_id": null,
69
+ "task_specific_params": null,
70
+ "temperature": 1.0,
71
+ "tie_encoder_decoder": false,
72
+ "tie_word_embeddings": true,
73
+ "tokenizer_class": null,
74
+ "top_k": 50,
75
+ "top_p": 1.0,
76
+ "torch_dtype": null,
77
+ "torchscript": false,
78
+ "transformers_version": "4.16.0.dev0",
79
+ "use_bfloat16": false,
80
+ "vocab_size": 49408
81
+ },
82
+ "text_config_dict": {
83
+ "hidden_size": 768,
84
+ "intermediate_size": 3072,
85
+ "num_attention_heads": 12,
86
+ "num_hidden_layers": 12,
87
+ "projection_dim": 768
88
+ },
89
+ "torch_dtype": "float32",
90
+ "transformers_version": null,
91
+ "vision_config": {
92
+ "_name_or_path": "",
93
+ "add_cross_attention": false,
94
+ "architectures": null,
95
+ "attention_dropout": 0.0,
96
+ "bad_words_ids": null,
97
+ "bos_token_id": null,
98
+ "chunk_size_feed_forward": 0,
99
+ "cross_attention_hidden_size": null,
100
+ "decoder_start_token_id": null,
101
+ "diversity_penalty": 0.0,
102
+ "do_sample": false,
103
+ "dropout": 0.0,
104
+ "early_stopping": false,
105
+ "encoder_no_repeat_ngram_size": 0,
106
+ "eos_token_id": null,
107
+ "finetuning_task": null,
108
+ "forced_bos_token_id": null,
109
+ "forced_eos_token_id": null,
110
+ "hidden_act": "quick_gelu",
111
+ "hidden_size": 1024,
112
+ "id2label": {
113
+ "0": "LABEL_0",
114
+ "1": "LABEL_1"
115
+ },
116
+ "image_size": 224,
117
+ "initializer_factor": 1.0,
118
+ "initializer_range": 0.02,
119
+ "intermediate_size": 4096,
120
+ "is_decoder": false,
121
+ "is_encoder_decoder": false,
122
+ "label2id": {
123
+ "LABEL_0": 0,
124
+ "LABEL_1": 1
125
+ },
126
+ "layer_norm_eps": 1e-05,
127
+ "length_penalty": 1.0,
128
+ "max_length": 20,
129
+ "min_length": 0,
130
+ "model_type": "clip_vision_model",
131
+ "no_repeat_ngram_size": 0,
132
+ "num_attention_heads": 16,
133
+ "num_beam_groups": 1,
134
+ "num_beams": 1,
135
+ "num_hidden_layers": 24,
136
+ "num_return_sequences": 1,
137
+ "output_attentions": false,
138
+ "output_hidden_states": false,
139
+ "output_scores": false,
140
+ "pad_token_id": null,
141
+ "patch_size": 14,
142
+ "prefix": null,
143
+ "problem_type": null,
144
+ "projection_dim": 768,
145
+ "pruned_heads": {},
146
+ "remove_invalid_values": false,
147
+ "repetition_penalty": 1.0,
148
+ "return_dict": true,
149
+ "return_dict_in_generate": false,
150
+ "sep_token_id": null,
151
+ "task_specific_params": null,
152
+ "temperature": 1.0,
153
+ "tie_encoder_decoder": false,
154
+ "tie_word_embeddings": true,
155
+ "tokenizer_class": null,
156
+ "top_k": 50,
157
+ "top_p": 1.0,
158
+ "torch_dtype": null,
159
+ "torchscript": false,
160
+ "transformers_version": "4.16.0.dev0",
161
+ "use_bfloat16": false
162
+ },
163
+ "vision_config_dict": {
164
+ "hidden_size": 1024,
165
+ "intermediate_size": 4096,
166
+ "num_attention_heads": 16,
167
+ "num_hidden_layers": 24,
168
+ "patch_size": 14,
169
+ "projection_dim": 768
170
+ }
171
+ }
configs/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
configs/tokenizer/preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "crop_size": 224,
3
+ "do_center_crop": true,
4
+ "do_normalize": true,
5
+ "do_resize": true,
6
+ "feature_extractor_type": "CLIPFeatureExtractor",
7
+ "image_mean": [
8
+ 0.48145466,
9
+ 0.4578275,
10
+ 0.40821073
11
+ ],
12
+ "image_std": [
13
+ 0.26862954,
14
+ 0.26130258,
15
+ 0.27577711
16
+ ],
17
+ "resample": 3,
18
+ "size": 224
19
+ }
configs/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "single_word": false,
5
+ "lstrip": false,
6
+ "rstrip": false,
7
+ "normalized": true
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "single_word": false,
12
+ "lstrip": false,
13
+ "rstrip": false,
14
+ "normalized": true
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": true
22
+ },
23
+ "pad_token": "<|endoftext|>"
24
+ }
configs/tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
configs/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "unk_token": {
3
+ "content": "<|endoftext|>",
4
+ "single_word": false,
5
+ "lstrip": false,
6
+ "rstrip": false,
7
+ "normalized": true,
8
+ "__type": "AddedToken"
9
+ },
10
+ "bos_token": {
11
+ "content": "<|startoftext|>",
12
+ "single_word": false,
13
+ "lstrip": false,
14
+ "rstrip": false,
15
+ "normalized": true,
16
+ "__type": "AddedToken"
17
+ },
18
+ "eos_token": {
19
+ "content": "<|endoftext|>",
20
+ "single_word": false,
21
+ "lstrip": false,
22
+ "rstrip": false,
23
+ "normalized": true,
24
+ "__type": "AddedToken"
25
+ },
26
+ "pad_token": "<|endoftext|>",
27
+ "add_prefix_space": false,
28
+ "errors": "replace",
29
+ "do_lower_case": true,
30
+ "name_or_path": "openai/clip-vit-base-patch32",
31
+ "model_max_length": 77,
32
+ "special_tokens_map_file": "./special_tokens_map.json",
33
+ "tokenizer_class": "CLIPTokenizer"
34
+ }
configs/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
options/SUPIR_v0.yaml ADDED
@@ -0,0 +1,140 @@
1
+ model:
2
+ target: SUPIR.models.SUPIR_model.SUPIRModel
3
+ params:
4
+ ae_dtype: bf16
5
+ diffusion_dtype: fp16
6
+ scale_factor: 0.13025
7
+ disable_first_stage_autocast: True
8
+ network_wrapper: sgm.modules.diffusionmodules.wrappers.ControlWrapper
9
+
10
+ denoiser_config:
11
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiserWithControl
12
+ params:
13
+ num_idx: 1000
14
+ weighting_config:
15
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
16
+ scaling_config:
17
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
18
+ discretization_config:
19
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
20
+
21
+ control_stage_config:
22
+ target: SUPIR.modules.SUPIR_v0.GLVControl
23
+ params:
24
+ adm_in_channels: 2816
25
+ num_classes: sequential
26
+ use_checkpoint: True
27
+ in_channels: 4
28
+ out_channels: 4
29
+ model_channels: 320
30
+ attention_resolutions: [4, 2]
31
+ num_res_blocks: 2
32
+ channel_mult: [1, 2, 4]
33
+ num_head_channels: 64
34
+ use_spatial_transformer: True
35
+ use_linear_in_transformer: True
36
+ transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
37
+ context_dim: 2048
38
+ spatial_transformer_attn_type: softmax-xformers
39
+ legacy: False
40
+ input_upscale: 1
41
+
42
+ network_config:
43
+ target: SUPIR.modules.SUPIR_v0.LightGLVUNet
44
+ params:
45
+ mode: XL-base
46
+ project_type: ZeroSFT
47
+ project_channel_scale: 2
48
+ adm_in_channels: 2816
49
+ num_classes: sequential
50
+ use_checkpoint: True
51
+ in_channels: 4
52
+ out_channels: 4
53
+ model_channels: 320
54
+ attention_resolutions: [4, 2]
55
+ num_res_blocks: 2
56
+ channel_mult: [1, 2, 4]
57
+ num_head_channels: 64
58
+ use_spatial_transformer: True
59
+ use_linear_in_transformer: True
60
+ transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
61
+ context_dim: 2048
62
+ spatial_transformer_attn_type: softmax-xformers
63
+ legacy: False
64
+
65
+ conditioner_config:
66
+ target: sgm.modules.GeneralConditionerWithControl
67
+ params:
68
+ emb_models:
69
+ # crossattn cond
70
+ - is_trainable: False
71
+ input_key: txt
72
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
73
+ params:
74
+ layer: hidden
75
+ layer_idx: 11
76
+ # crossattn and vector cond
77
+ - is_trainable: False
78
+ input_key: txt
79
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
80
+ params:
81
+ arch: ViT-bigG-14
82
+ version: laion2b_s39b_b160k
83
+ freeze: True
84
+ layer: penultimate
85
+ always_return_pooled: True
86
+ legacy: False
87
+ # vector cond
88
+ - is_trainable: False
89
+ input_key: original_size_as_tuple
90
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
91
+ params:
92
+ outdim: 256 # multiplied by two
93
+ # vector cond
94
+ - is_trainable: False
95
+ input_key: crop_coords_top_left
96
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
97
+ params:
98
+ outdim: 256 # multiplied by two
99
+ # vector cond
100
+ - is_trainable: False
101
+ input_key: target_size_as_tuple
102
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
103
+ params:
104
+ outdim: 256 # multiplied by two
105
+
106
+ first_stage_config:
107
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
108
+ params:
109
+ ckpt_path: ~
110
+ embed_dim: 4
111
+ monitor: val/rec_loss
112
+ ddconfig:
113
+ attn_type: vanilla-xformers
114
+ double_z: true
115
+ z_channels: 4
116
+ resolution: 256
117
+ in_channels: 3
118
+ out_ch: 3
119
+ ch: 128
120
+ ch_mult: [1, 2, 4, 4]
121
+ num_res_blocks: 2
122
+ attn_resolutions: []
123
+ dropout: 0.0
124
+ lossconfig:
125
+ target: torch.nn.Identity
126
+
127
+ sampler_config:
128
+ target: sgm.modules.diffusionmodules.sampling.RestoreEDMSampler
129
+ params:
130
+ num_steps: 100
131
+ restore_cfg: 4.0
132
+ s_churn: 0
133
+ s_noise: 1.003
134
+ discretization_config:
135
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
136
+ guider_config:
137
+ target: sgm.modules.diffusionmodules.guiders.LinearCFG
138
+ params:
139
+ scale: 7.5
140
+ scale_min: 4.0
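Every block in this YAML follows the target/params convention, so the whole tree is resolved recursively by sgm.util.instantiate_from_config (exported by sgm/__init__.py below). A minimal sketch of building the model from this file; the device handling is an assumption, not something this commit prescribes:

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

# Load the OmegaConf tree and instantiate SUPIR.models.SUPIR_model.SUPIRModel
# with the params above; nested *_config blocks are resolved the same way.
config = OmegaConf.load("options/SUPIR_v0.yaml")
model = instantiate_from_config(config.model)
model = model.to("cuda").eval()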
sgm/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .models import AutoencodingEngine, DiffusionEngine
2
+ from .util import get_configs_path, instantiate_from_config
3
+
4
+ __version__ = "0.1.0"
sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (
32
+ self.lr_max - self.lr_start
33
+ ) / self.lr_warm_up_steps * n + self.lr_start
34
+ self.last_lr = lr
35
+ return lr
36
+ else:
37
+ t = (n - self.lr_warm_up_steps) / (
38
+ self.lr_max_decay_steps - self.lr_warm_up_steps
39
+ )
40
+ t = min(t, 1.0)
41
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
+ 1 + np.cos(t * np.pi)
43
+ )
44
+ self.last_lr = lr
45
+ return lr
46
+
47
+ def __call__(self, n, **kwargs):
48
+ return self.schedule(n, **kwargs)
49
+
50
+
51
+ class LambdaWarmUpCosineScheduler2:
52
+ """
53
+ supports repeated iterations, configurable via lists
54
+ note: use with a base_lr of 1.0.
55
+ """
56
+
57
+ def __init__(
58
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
+ ):
60
+ assert (
61
+ len(warm_up_steps)
62
+ == len(f_min)
63
+ == len(f_max)
64
+ == len(f_start)
65
+ == len(cycle_lengths)
66
+ )
67
+ self.lr_warm_up_steps = warm_up_steps
68
+ self.f_start = f_start
69
+ self.f_min = f_min
70
+ self.f_max = f_max
71
+ self.cycle_lengths = cycle_lengths
72
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
+ self.last_f = 0.0
74
+ self.verbosity_interval = verbosity_interval
75
+
76
+ def find_in_interval(self, n):
77
+ interval = 0
78
+ for cl in self.cum_cycles[1:]:
79
+ if n <= cl:
80
+ return interval
81
+ interval += 1
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0:
88
+ print(
89
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}"
91
+ )
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
+ cycle
95
+ ] * n + self.f_start[cycle]
96
+ self.last_f = f
97
+ return f
98
+ else:
99
+ t = (n - self.lr_warm_up_steps[cycle]) / (
100
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
+ )
102
+ t = min(t, 1.0)
103
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
+ 1 + np.cos(t * np.pi)
105
+ )
106
+ self.last_f = f
107
+ return f
108
+
109
+ def __call__(self, n, **kwargs):
110
+ return self.schedule(n, **kwargs)
111
+
112
+
113
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
+ def schedule(self, n, **kwargs):
115
+ cycle = self.find_in_interval(n)
116
+ n = n - self.cum_cycles[cycle]
117
+ if self.verbosity_interval > 0:
118
+ if n % self.verbosity_interval == 0:
119
+ print(
120
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
+ f"current cycle {cycle}"
122
+ )
123
+
124
+ if n < self.lr_warm_up_steps[cycle]:
125
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
+ cycle
127
+ ] * n + self.f_start[cycle]
128
+ self.last_f = f
129
+ return f
130
+ else:
131
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
+ self.cycle_lengths[cycle] - n
133
+ ) / (self.cycle_lengths[cycle])
134
+ self.last_f = f
135
+ return f
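These schedulers return a multiplier on the optimizer's base learning rate (hence the "use with a base_lr of 1.0" note), which is why DiffusionEngine.configure_optimizers further down wraps them in torch.optim.lr_scheduler.LambdaLR. A minimal, self-contained sketch of that usage with made-up hyperparameters:

import torch
from torch.optim.lr_scheduler import LambdaLR
from sgm.lr_scheduler import LambdaWarmUpCosineScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=1.0)  # base_lr of 1.0, per the docstring
sched = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=0.01, lr_max=1.0, lr_start=0.0,
    max_decay_steps=100_000,
)
# LambdaLR multiplies the base lr by sched.schedule(step) at every step.
lr_sched = LambdaLR(opt, lr_lambda=sched.schedule)
for _ in range(5):
    opt.step()
    lr_sched.step()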
sgm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .autoencoder import AutoencodingEngine
2
+ from .diffusion import DiffusionEngine
sgm/models/autoencoder.py ADDED
@@ -0,0 +1,335 @@
1
+ import re
2
+ from abc import abstractmethod
3
+ from contextlib import contextmanager
4
+ from typing import Any, Dict, Tuple, Union
5
+
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from omegaconf import ListConfig
9
+ from packaging import version
10
+ from safetensors.torch import load_file as load_safetensors
11
+
12
+ from ..modules.diffusionmodules.model import Decoder, Encoder
13
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
14
+ from ..modules.ema import LitEma
15
+ from ..util import default, get_obj_from_str, instantiate_from_config
16
+
17
+
18
+ class AbstractAutoencoder(pl.LightningModule):
19
+ """
20
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
21
+ unCLIP models, etc. Hence, it is fairly general, and specific features
22
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ ema_decay: Union[None, float] = None,
28
+ monitor: Union[None, str] = None,
29
+ input_key: str = "jpg",
30
+ ckpt_path: Union[None, str] = None,
31
+ ignore_keys: Union[Tuple, list, ListConfig] = (),
32
+ ):
33
+ super().__init__()
34
+ self.input_key = input_key
35
+ self.use_ema = ema_decay is not None
36
+ if monitor is not None:
37
+ self.monitor = monitor
38
+
39
+ if self.use_ema:
40
+ self.model_ema = LitEma(self, decay=ema_decay)
41
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+
46
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
+ self.automatic_optimization = False
48
+
49
+ def init_from_ckpt(
50
+ self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
51
+ ) -> None:
52
+ if path.endswith("ckpt"):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ elif path.endswith("safetensors"):
55
+ sd = load_safetensors(path)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ keys = list(sd.keys())
60
+ for k in keys:
61
+ for ik in ignore_keys:
62
+ if re.match(ik, k):
63
+ print("Deleting key {} from state_dict.".format(k))
64
+ del sd[k]
65
+ missing, unexpected = self.load_state_dict(sd, strict=False)
66
+ print(
67
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
68
+ )
69
+ if len(missing) > 0:
70
+ print(f"Missing Keys: {missing}")
71
+ if len(unexpected) > 0:
72
+ print(f"Unexpected Keys: {unexpected}")
73
+
74
+ @abstractmethod
75
+ def get_input(self, batch) -> Any:
76
+ raise NotImplementedError()
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ # for EMA computation
80
+ if self.use_ema:
81
+ self.model_ema(self)
82
+
83
+ @contextmanager
84
+ def ema_scope(self, context=None):
85
+ if self.use_ema:
86
+ self.model_ema.store(self.parameters())
87
+ self.model_ema.copy_to(self)
88
+ if context is not None:
89
+ print(f"{context}: Switched to EMA weights")
90
+ try:
91
+ yield None
92
+ finally:
93
+ if self.use_ema:
94
+ self.model_ema.restore(self.parameters())
95
+ if context is not None:
96
+ print(f"{context}: Restored training weights")
97
+
98
+ @abstractmethod
99
+ def encode(self, *args, **kwargs) -> torch.Tensor:
100
+ raise NotImplementedError("encode()-method of abstract base class called")
101
+
102
+ @abstractmethod
103
+ def decode(self, *args, **kwargs) -> torch.Tensor:
104
+ raise NotImplementedError("decode()-method of abstract base class called")
105
+
106
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
107
+ print(f"loading >>> {cfg['target']} <<< optimizer from config")
108
+ return get_obj_from_str(cfg["target"])(
109
+ params, lr=lr, **cfg.get("params", dict())
110
+ )
111
+
112
+ def configure_optimizers(self) -> Any:
113
+ raise NotImplementedError()
114
+
115
+
116
+ class AutoencodingEngine(AbstractAutoencoder):
117
+ """
118
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
119
+ (we also restore them explicitly as special cases for legacy reasons).
120
+ Regularizations such as KL or VQ are moved to the regularizer class.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ *args,
126
+ encoder_config: Dict,
127
+ decoder_config: Dict,
128
+ loss_config: Dict,
129
+ regularizer_config: Dict,
130
+ optimizer_config: Union[Dict, None] = None,
131
+ lr_g_factor: float = 1.0,
132
+ **kwargs,
133
+ ):
134
+ super().__init__(*args, **kwargs)
135
+ # todo: add options to freeze encoder/decoder
136
+ self.encoder = instantiate_from_config(encoder_config)
137
+ self.decoder = instantiate_from_config(decoder_config)
138
+ self.loss = instantiate_from_config(loss_config)
139
+ self.regularization = instantiate_from_config(regularizer_config)
140
+ self.optimizer_config = default(
141
+ optimizer_config, {"target": "torch.optim.Adam"}
142
+ )
143
+ self.lr_g_factor = lr_g_factor
144
+
145
+ def get_input(self, batch: Dict) -> torch.Tensor:
146
+ # assuming unified data format, dataloader returns a dict.
147
+ # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead of bhwc)
148
+ return batch[self.input_key]
149
+
150
+ def get_autoencoder_params(self) -> list:
151
+ params = (
152
+ list(self.encoder.parameters())
153
+ + list(self.decoder.parameters())
154
+ + list(self.regularization.get_trainable_parameters())
155
+ + list(self.loss.get_trainable_autoencoder_parameters())
156
+ )
157
+ return params
158
+
159
+ def get_discriminator_params(self) -> list:
160
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
161
+ return params
162
+
163
+ def get_last_layer(self):
164
+ return self.decoder.get_last_layer()
165
+
166
+ def encode(self, x: Any, return_reg_log: bool = False) -> Any:
167
+ z = self.encoder(x)
168
+ z, reg_log = self.regularization(z)
169
+ if return_reg_log:
170
+ return z, reg_log
171
+ return z
172
+
173
+ def decode(self, z: Any) -> torch.Tensor:
174
+ x = self.decoder(z)
175
+ return x
176
+
177
+ def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178
+ z, reg_log = self.encode(x, return_reg_log=True)
179
+ dec = self.decode(z)
180
+ return z, dec, reg_log
181
+
182
+ def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
183
+ x = self.get_input(batch)
184
+ z, xrec, regularization_log = self(x)
185
+
186
+ if optimizer_idx == 0:
187
+ # autoencode
188
+ aeloss, log_dict_ae = self.loss(
189
+ regularization_log,
190
+ x,
191
+ xrec,
192
+ optimizer_idx,
193
+ self.global_step,
194
+ last_layer=self.get_last_layer(),
195
+ split="train",
196
+ )
197
+
198
+ self.log_dict(
199
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
200
+ )
201
+ return aeloss
202
+
203
+ if optimizer_idx == 1:
204
+ # discriminator
205
+ discloss, log_dict_disc = self.loss(
206
+ regularization_log,
207
+ x,
208
+ xrec,
209
+ optimizer_idx,
210
+ self.global_step,
211
+ last_layer=self.get_last_layer(),
212
+ split="train",
213
+ )
214
+ self.log_dict(
215
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
216
+ )
217
+ return discloss
218
+
219
+ def validation_step(self, batch, batch_idx) -> Dict:
220
+ log_dict = self._validation_step(batch, batch_idx)
221
+ with self.ema_scope():
222
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
223
+ log_dict.update(log_dict_ema)
224
+ return log_dict
225
+
226
+ def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
227
+ x = self.get_input(batch)
228
+
229
+ z, xrec, regularization_log = self(x)
230
+ aeloss, log_dict_ae = self.loss(
231
+ regularization_log,
232
+ x,
233
+ xrec,
234
+ 0,
235
+ self.global_step,
236
+ last_layer=self.get_last_layer(),
237
+ split="val" + postfix,
238
+ )
239
+
240
+ discloss, log_dict_disc = self.loss(
241
+ regularization_log,
242
+ x,
243
+ xrec,
244
+ 1,
245
+ self.global_step,
246
+ last_layer=self.get_last_layer(),
247
+ split="val" + postfix,
248
+ )
249
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
250
+ log_dict_ae.update(log_dict_disc)
251
+ self.log_dict(log_dict_ae)
252
+ return log_dict_ae
253
+
254
+ def configure_optimizers(self) -> Any:
255
+ ae_params = self.get_autoencoder_params()
256
+ disc_params = self.get_discriminator_params()
257
+
258
+ opt_ae = self.instantiate_optimizer_from_config(
259
+ ae_params,
260
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
261
+ self.optimizer_config,
262
+ )
263
+ opt_disc = self.instantiate_optimizer_from_config(
264
+ disc_params, self.learning_rate, self.optimizer_config
265
+ )
266
+
267
+ return [opt_ae, opt_disc], []
268
+
269
+ @torch.no_grad()
270
+ def log_images(self, batch: Dict, **kwargs) -> Dict:
271
+ log = dict()
272
+ x = self.get_input(batch)
273
+ _, xrec, _ = self(x)
274
+ log["inputs"] = x
275
+ log["reconstructions"] = xrec
276
+ with self.ema_scope():
277
+ _, xrec_ema, _ = self(x)
278
+ log["reconstructions_ema"] = xrec_ema
279
+ return log
280
+
281
+
282
+ class AutoencoderKL(AutoencodingEngine):
283
+ def __init__(self, embed_dim: int, **kwargs):
284
+ ddconfig = kwargs.pop("ddconfig")
285
+ ckpt_path = kwargs.pop("ckpt_path", None)
286
+ ignore_keys = kwargs.pop("ignore_keys", ())
287
+ super().__init__(
288
+ encoder_config={"target": "torch.nn.Identity"},
289
+ decoder_config={"target": "torch.nn.Identity"},
290
+ regularizer_config={"target": "torch.nn.Identity"},
291
+ loss_config=kwargs.pop("lossconfig"),
292
+ **kwargs,
293
+ )
294
+ assert ddconfig["double_z"]
295
+ self.encoder = Encoder(**ddconfig)
296
+ self.decoder = Decoder(**ddconfig)
297
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
298
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
299
+ self.embed_dim = embed_dim
300
+
301
+ if ckpt_path is not None:
302
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
303
+
304
+ def encode(self, x):
305
+ assert (
306
+ not self.training
307
+ ), f"{self.__class__.__name__} only supports inference currently"
308
+ h = self.encoder(x)
309
+ moments = self.quant_conv(h)
310
+ posterior = DiagonalGaussianDistribution(moments)
311
+ return posterior
312
+
313
+ def decode(self, z, **decoder_kwargs):
314
+ z = self.post_quant_conv(z)
315
+ dec = self.decoder(z, **decoder_kwargs)
316
+ return dec
317
+
318
+
319
+ class AutoencoderKLInferenceWrapper(AutoencoderKL):
320
+ def encode(self, x):
321
+ return super().encode(x).sample()
322
+
323
+
324
+ class IdentityFirstStage(AbstractAutoencoder):
325
+ def __init__(self, *args, **kwargs):
326
+ super().__init__(*args, **kwargs)
327
+
328
+ def get_input(self, x: Any) -> Any:
329
+ return x
330
+
331
+ def encode(self, x: Any, *args, **kwargs) -> Any:
332
+ return x
333
+
334
+ def decode(self, x: Any, *args, **kwargs) -> Any:
335
+ return x
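init_from_ckpt above performs a non-strict load after dropping every state-dict key that matches one of the ignore_keys regexes via re.match. A self-contained sketch of just that filtering step, with hypothetical key names and patterns:

import re

ignore_keys = (r"loss\..*", r"model_ema\..*")  # hypothetical patterns
sd = {"encoder.conv_in.weight": 0.0, "loss.logvar": 1.0, "model_ema.decay": 2.0}

# Delete matching keys before load_state_dict(sd, strict=False).
for k in list(sd.keys()):
    if any(re.match(ik, k) for ik in ignore_keys):
        print(f"Deleting key {k} from state_dict.")
        del sd[k]

print(sorted(sd))  # ['encoder.conv_in.weight']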
sgm/models/diffusion.py ADDED
@@ -0,0 +1,320 @@
1
+ from contextlib import contextmanager
2
+ from typing import Any, Dict, List, Tuple, Union
3
+
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ from omegaconf import ListConfig, OmegaConf
7
+ from safetensors.torch import load_file as load_safetensors
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+
10
+ from ..modules import UNCONDITIONAL_CONFIG
11
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
12
+ from ..modules.ema import LitEma
13
+ from ..util import (
14
+ default,
15
+ disabled_train,
16
+ get_obj_from_str,
17
+ instantiate_from_config,
18
+ log_txt_as_img,
19
+ )
20
+
21
+
22
+ class DiffusionEngine(pl.LightningModule):
23
+ def __init__(
24
+ self,
25
+ network_config,
26
+ denoiser_config,
27
+ first_stage_config,
28
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
29
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
30
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
31
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
32
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
33
+ network_wrapper: Union[None, str] = None,
34
+ ckpt_path: Union[None, str] = None,
35
+ use_ema: bool = False,
36
+ ema_decay_rate: float = 0.9999,
37
+ scale_factor: float = 1.0,
38
+ disable_first_stage_autocast=False,
39
+ input_key: str = "jpg",
40
+ log_keys: Union[List, None] = None,
41
+ no_cond_log: bool = False,
42
+ compile_model: bool = False,
43
+ ):
44
+ super().__init__()
45
+ self.log_keys = log_keys
46
+ self.input_key = input_key
47
+ self.optimizer_config = default(
48
+ optimizer_config, {"target": "torch.optim.AdamW"}
49
+ )
50
+ model = instantiate_from_config(network_config)
51
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
52
+ model, compile_model=compile_model
53
+ )
54
+
55
+ self.denoiser = instantiate_from_config(denoiser_config)
56
+ self.sampler = (
57
+ instantiate_from_config(sampler_config)
58
+ if sampler_config is not None
59
+ else None
60
+ )
61
+ self.conditioner = instantiate_from_config(
62
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
63
+ )
64
+ self.scheduler_config = scheduler_config
65
+ self._init_first_stage(first_stage_config)
66
+
67
+ self.loss_fn = (
68
+ instantiate_from_config(loss_fn_config)
69
+ if loss_fn_config is not None
70
+ else None
71
+ )
72
+
73
+ self.use_ema = use_ema
74
+ if self.use_ema:
75
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
76
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
77
+
78
+ self.scale_factor = scale_factor
79
+ self.disable_first_stage_autocast = disable_first_stage_autocast
80
+ self.no_cond_log = no_cond_log
81
+
82
+ if ckpt_path is not None:
83
+ self.init_from_ckpt(ckpt_path)
84
+
85
+ def init_from_ckpt(
86
+ self,
87
+ path: str,
88
+ ) -> None:
89
+ if path.endswith("ckpt"):
90
+ sd = torch.load(path, map_location="cpu")["state_dict"]
91
+ elif path.endswith("safetensors"):
92
+ sd = load_safetensors(path)
93
+ else:
94
+ raise NotImplementedError
95
+
96
+ missing, unexpected = self.load_state_dict(sd, strict=False)
97
+ print(
98
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
99
+ )
100
+ if len(missing) > 0:
101
+ print(f"Missing Keys: {missing}")
102
+ if len(unexpected) > 0:
103
+ print(f"Unexpected Keys: {unexpected}")
104
+
105
+ def _init_first_stage(self, config):
106
+ model = instantiate_from_config(config).eval()
107
+ model.train = disabled_train
108
+ for param in model.parameters():
109
+ param.requires_grad = False
110
+ self.first_stage_model = model
111
+
112
+ def get_input(self, batch):
113
+ # assuming unified data format, dataloader returns a dict.
114
+ # image tensors should be scaled to -1 ... 1 and in bchw format
115
+ return batch[self.input_key]
116
+
117
+ @torch.no_grad()
118
+ def decode_first_stage(self, z):
119
+ z = 1.0 / self.scale_factor * z
120
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
121
+ out = self.first_stage_model.decode(z)
122
+ return out
123
+
124
+ @torch.no_grad()
125
+ def encode_first_stage(self, x):
126
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
127
+ z = self.first_stage_model.encode(x)
128
+ z = self.scale_factor * z
129
+ return z
130
+
131
+ def forward(self, x, batch):
132
+ loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
133
+ loss_mean = loss.mean()
134
+ loss_dict = {"loss": loss_mean}
135
+ return loss_mean, loss_dict
136
+
137
+ def shared_step(self, batch: Dict) -> Any:
138
+ x = self.get_input(batch)
139
+ x = self.encode_first_stage(x)
140
+ batch["global_step"] = self.global_step
141
+ loss, loss_dict = self(x, batch)
142
+ return loss, loss_dict
143
+
144
+ def training_step(self, batch, batch_idx):
145
+ loss, loss_dict = self.shared_step(batch)
146
+
147
+ self.log_dict(
148
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
149
+ )
150
+
151
+ self.log(
152
+ "global_step",
153
+ self.global_step,
154
+ prog_bar=True,
155
+ logger=True,
156
+ on_step=True,
157
+ on_epoch=False,
158
+ )
159
+
160
+ # if self.scheduler_config is not None:
161
+ lr = self.optimizers().param_groups[0]["lr"]
162
+ self.log(
163
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
164
+ )
165
+
166
+ return loss
167
+
168
+ def on_train_start(self, *args, **kwargs):
169
+ if self.sampler is None or self.loss_fn is None:
170
+ raise ValueError("Sampler and loss function need to be set for training.")
171
+
172
+ def on_train_batch_end(self, *args, **kwargs):
173
+ if self.use_ema:
174
+ self.model_ema(self.model)
175
+
176
+ @contextmanager
177
+ def ema_scope(self, context=None):
178
+ if self.use_ema:
179
+ self.model_ema.store(self.model.parameters())
180
+ self.model_ema.copy_to(self.model)
181
+ if context is not None:
182
+ print(f"{context}: Switched to EMA weights")
183
+ try:
184
+ yield None
185
+ finally:
186
+ if self.use_ema:
187
+ self.model_ema.restore(self.model.parameters())
188
+ if context is not None:
189
+ print(f"{context}: Restored training weights")
190
+
191
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
192
+ return get_obj_from_str(cfg["target"])(
193
+ params, lr=lr, **cfg.get("params", dict())
194
+ )
195
+
196
+ def configure_optimizers(self):
197
+ lr = self.learning_rate
198
+ params = list(self.model.parameters())
199
+ for embedder in self.conditioner.embedders:
200
+ if embedder.is_trainable:
201
+ params = params + list(embedder.parameters())
202
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
203
+ if self.scheduler_config is not None:
204
+ scheduler = instantiate_from_config(self.scheduler_config)
205
+ print("Setting up LambdaLR scheduler...")
206
+ scheduler = [
207
+ {
208
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
209
+ "interval": "step",
210
+ "frequency": 1,
211
+ }
212
+ ]
213
+ return [opt], scheduler
214
+ return opt
215
+
216
+ @torch.no_grad()
217
+ def sample(
218
+ self,
219
+ cond: Dict,
220
+ uc: Union[Dict, None] = None,
221
+ batch_size: int = 16,
222
+ shape: Union[None, Tuple, List] = None,
223
+ **kwargs,
224
+ ):
225
+ randn = torch.randn(batch_size, *shape).to(self.device)
226
+
227
+ denoiser = lambda input, sigma, c: self.denoiser(
228
+ self.model, input, sigma, c, **kwargs
229
+ )
230
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
231
+ return samples
232
+
233
+ @torch.no_grad()
234
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
235
+ """
236
+ Defines heuristics to log different conditionings.
237
+ These can be lists of strings (text-to-image), tensors, ints, ...
238
+ """
239
+ image_h, image_w = batch[self.input_key].shape[2:]
240
+ log = dict()
241
+
242
+ for embedder in self.conditioner.embedders:
243
+ if (
244
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
245
+ ) and not self.no_cond_log:
246
+ x = batch[embedder.input_key][:n]
247
+ if isinstance(x, torch.Tensor):
248
+ if x.dim() == 1:
249
+ # class-conditional, convert integer to string
250
+ x = [str(x[i].item()) for i in range(x.shape[0])]
251
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
252
+ elif x.dim() == 2:
253
+ # size and crop cond and the like
254
+ x = [
255
+ "x".join([str(xx) for xx in x[i].tolist()])
256
+ for i in range(x.shape[0])
257
+ ]
258
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
259
+ else:
260
+ raise NotImplementedError()
261
+ elif isinstance(x, (List, ListConfig)):
262
+ if isinstance(x[0], str):
263
+ # strings
264
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
265
+ else:
266
+ raise NotImplementedError()
267
+ else:
268
+ raise NotImplementedError()
269
+ log[embedder.input_key] = xc
270
+ return log
271
+
272
+ @torch.no_grad()
273
+ def log_images(
274
+ self,
275
+ batch: Dict,
276
+ N: int = 8,
277
+ sample: bool = True,
278
+ ucg_keys: List[str] = None,
279
+ **kwargs,
280
+ ) -> Dict:
281
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
282
+ if ucg_keys:
283
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
284
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
285
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
286
+ )
287
+ else:
288
+ ucg_keys = conditioner_input_keys
289
+ log = dict()
290
+
291
+ x = self.get_input(batch)
292
+
293
+ c, uc = self.conditioner.get_unconditional_conditioning(
294
+ batch,
295
+ force_uc_zero_embeddings=ucg_keys
296
+ if len(self.conditioner.embedders) > 0
297
+ else [],
298
+ )
299
+
300
+ sampling_kwargs = {}
301
+
302
+ N = min(x.shape[0], N)
303
+ x = x.to(self.device)[:N]
304
+ log["inputs"] = x
305
+ z = self.encode_first_stage(x)
306
+ log["reconstructions"] = self.decode_first_stage(z)
307
+ log.update(self.log_conditionings(batch, N))
308
+
309
+ for k in c:
310
+ if isinstance(c[k], torch.Tensor):
311
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
312
+
313
+ if sample:
314
+ with self.ema_scope("Plotting"):
315
+ samples = self.sample(
316
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
317
+ )
318
+ samples = self.decode_first_stage(samples)
319
+ log["samples"] = samples
320
+ return log
sgm/modules/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .encoders.modules import GeneralConditioner
2
+ from .encoders.modules import GeneralConditionerWithControl
3
+ from .encoders.modules import PreparedConditioner
4
+
5
+ UNCONDITIONAL_CONFIG = {
6
+ "target": "sgm.modules.GeneralConditioner",
7
+ "params": {"emb_models": []},
8
+ }
sgm/modules/attention.py ADDED
@@ -0,0 +1,637 @@
1
+ import math
2
+ from inspect import isfunction
3
+ from typing import Any, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ # from einops._torch_specific import allow_ops_in_compiled_graph
8
+ # allow_ops_in_compiled_graph()
9
+ from einops import rearrange, repeat
10
+ from packaging import version
11
+ from torch import nn
12
+
13
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
14
+ SDP_IS_AVAILABLE = True
15
+ from torch.backends.cuda import SDPBackend, sdp_kernel
16
+
17
+ BACKEND_MAP = {
18
+ SDPBackend.MATH: {
19
+ "enable_math": True,
20
+ "enable_flash": False,
21
+ "enable_mem_efficient": False,
22
+ },
23
+ SDPBackend.FLASH_ATTENTION: {
24
+ "enable_math": False,
25
+ "enable_flash": True,
26
+ "enable_mem_efficient": False,
27
+ },
28
+ SDPBackend.EFFICIENT_ATTENTION: {
29
+ "enable_math": False,
30
+ "enable_flash": False,
31
+ "enable_mem_efficient": True,
32
+ },
33
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
34
+ }
35
+ else:
36
+ from contextlib import nullcontext
37
+
38
+ SDP_IS_AVAILABLE = False
39
+ sdp_kernel = nullcontext
40
+ BACKEND_MAP = {}
41
+ print(
42
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
43
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
44
+ )
45
+
46
+ try:
47
+ import xformers
48
+ import xformers.ops
49
+
50
+ XFORMERS_IS_AVAILABLE = True
51
+ except:
52
+ XFORMERS_IS_AVAILABLE = False
53
+ print("no module 'xformers'. Processing without...")
54
+
55
+ from .diffusionmodules.util import checkpoint
56
+
57
+
58
+ def exists(val):
59
+ return val is not None
60
+
61
+
62
+ def uniq(arr):
63
+ return {el: True for el in arr}.keys()
64
+
65
+
66
+ def default(val, d):
67
+ if exists(val):
68
+ return val
69
+ return d() if isfunction(d) else d
70
+
71
+
72
+ def max_neg_value(t):
73
+ return -torch.finfo(t.dtype).max
74
+
75
+
76
+ def init_(tensor):
77
+ dim = tensor.shape[-1]
78
+ std = 1 / math.sqrt(dim)
79
+ tensor.uniform_(-std, std)
80
+ return tensor
81
+
82
+
83
+ # feedforward
84
+ class GEGLU(nn.Module):
85
+ def __init__(self, dim_in, dim_out):
86
+ super().__init__()
87
+ self.proj = nn.Linear(dim_in, dim_out * 2)
88
+
89
+ def forward(self, x):
90
+ x, gate = self.proj(x).chunk(2, dim=-1)
91
+ return x * F.gelu(gate)
92
+
93
+
94
+ class FeedForward(nn.Module):
95
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
96
+ super().__init__()
97
+ inner_dim = int(dim * mult)
98
+ dim_out = default(dim_out, dim)
99
+ project_in = (
100
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
101
+ if not glu
102
+ else GEGLU(dim, inner_dim)
103
+ )
104
+
105
+ self.net = nn.Sequential(
106
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
107
+ )
108
+
109
+ def forward(self, x):
110
+ return self.net(x)
111
+
112
+
113
+ def zero_module(module):
114
+ """
115
+ Zero out the parameters of a module and return it.
116
+ """
117
+ for p in module.parameters():
118
+ p.detach().zero_()
119
+ return module
120
+
121
+
122
+ def Normalize(in_channels):
123
+ return torch.nn.GroupNorm(
124
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
125
+ )
126
+
127
+
128
+ class LinearAttention(nn.Module):
129
+ def __init__(self, dim, heads=4, dim_head=32):
130
+ super().__init__()
131
+ self.heads = heads
132
+ hidden_dim = dim_head * heads
133
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
134
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
135
+
136
+ def forward(self, x):
137
+ b, c, h, w = x.shape
138
+ qkv = self.to_qkv(x)
139
+ q, k, v = rearrange(
140
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
141
+ )
142
+ k = k.softmax(dim=-1)
143
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
144
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
145
+ out = rearrange(
146
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
147
+ )
148
+ return self.to_out(out)
149
+
150
+
151
+ class SpatialSelfAttention(nn.Module):
152
+ def __init__(self, in_channels):
153
+ super().__init__()
154
+ self.in_channels = in_channels
155
+
156
+ self.norm = Normalize(in_channels)
157
+ self.q = torch.nn.Conv2d(
158
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
159
+ )
160
+ self.k = torch.nn.Conv2d(
161
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
162
+ )
163
+ self.v = torch.nn.Conv2d(
164
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
165
+ )
166
+ self.proj_out = torch.nn.Conv2d(
167
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
168
+ )
169
+
170
+ def forward(self, x):
171
+ h_ = x
172
+ h_ = self.norm(h_)
173
+ q = self.q(h_)
174
+ k = self.k(h_)
175
+ v = self.v(h_)
176
+
177
+ # compute attention
178
+ b, c, h, w = q.shape
179
+ q = rearrange(q, "b c h w -> b (h w) c")
180
+ k = rearrange(k, "b c h w -> b c (h w)")
181
+ w_ = torch.einsum("bij,bjk->bik", q, k)
182
+
183
+ w_ = w_ * (int(c) ** (-0.5))
184
+ w_ = torch.nn.functional.softmax(w_, dim=2)
185
+
186
+ # attend to values
187
+ v = rearrange(v, "b c h w -> b c (h w)")
188
+ w_ = rearrange(w_, "b i j -> b j i")
189
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
190
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
191
+ h_ = self.proj_out(h_)
192
+
193
+ return x + h_
194
+
195
+
196
+ class CrossAttention(nn.Module):
197
+ def __init__(
198
+ self,
199
+ query_dim,
200
+ context_dim=None,
201
+ heads=8,
202
+ dim_head=64,
203
+ dropout=0.0,
204
+ backend=None,
205
+ ):
206
+ super().__init__()
207
+ inner_dim = dim_head * heads
208
+ context_dim = default(context_dim, query_dim)
209
+
210
+ self.scale = dim_head**-0.5
211
+ self.heads = heads
212
+
213
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
214
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
215
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
216
+
217
+ self.to_out = nn.Sequential(
218
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
219
+ )
220
+ self.backend = backend
221
+
222
+ def forward(
223
+ self,
224
+ x,
225
+ context=None,
226
+ mask=None,
227
+ additional_tokens=None,
228
+ n_times_crossframe_attn_in_self=0,
229
+ ):
230
+ h = self.heads
231
+
232
+ if additional_tokens is not None:
233
+ # get the number of masked tokens at the beginning of the output sequence
234
+ n_tokens_to_mask = additional_tokens.shape[1]
235
+ # add additional token
236
+ x = torch.cat([additional_tokens, x], dim=1)
237
+
238
+ q = self.to_q(x)
239
+ context = default(context, x)
240
+ k = self.to_k(context)
241
+ v = self.to_v(context)
242
+
243
+ if n_times_crossframe_attn_in_self:
244
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
245
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
246
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
247
+ k = repeat(
248
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
249
+ )
250
+ v = repeat(
251
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
252
+ )
253
+
254
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
255
+
256
+ ## old
257
+ """
258
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
259
+ del q, k
260
+
261
+ if exists(mask):
262
+ mask = rearrange(mask, 'b ... -> b (...)')
263
+ max_neg_value = -torch.finfo(sim.dtype).max
264
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
265
+ sim.masked_fill_(~mask, max_neg_value)
266
+
267
+ # attention, what we cannot get enough of
268
+ sim = sim.softmax(dim=-1)
269
+
270
+ out = einsum('b i j, b j d -> b i d', sim, v)
271
+ """
272
+ ## new
273
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
274
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
275
+ out = F.scaled_dot_product_attention(
276
+ q, k, v, attn_mask=mask
277
+ ) # scale is dim_head ** -0.5 per default
278
+
279
+ del q, k, v
280
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
281
+
282
+ if additional_tokens is not None:
283
+ # remove additional token
284
+ out = out[:, n_tokens_to_mask:]
285
+ return self.to_out(out)
286
+
287
+
288
+ class MemoryEfficientCrossAttention(nn.Module):
289
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
290
+ def __init__(
291
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
292
+ ):
293
+ super().__init__()
294
+ # print(
295
+ # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
296
+ # f"{heads} heads with a dimension of {dim_head}."
297
+ # )
298
+ inner_dim = dim_head * heads
299
+ context_dim = default(context_dim, query_dim)
300
+
301
+ self.heads = heads
302
+ self.dim_head = dim_head
303
+
304
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
305
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
306
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
307
+
308
+ self.to_out = nn.Sequential(
309
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
310
+ )
311
+ self.attention_op: Optional[Any] = None
312
+
313
+ def forward(
314
+ self,
315
+ x,
316
+ context=None,
317
+ mask=None,
318
+ additional_tokens=None,
319
+ n_times_crossframe_attn_in_self=0,
320
+ ):
321
+ if additional_tokens is not None:
322
+ # get the number of masked tokens at the beginning of the output sequence
323
+ n_tokens_to_mask = additional_tokens.shape[1]
324
+ # add additional token
325
+ x = torch.cat([additional_tokens, x], dim=1)
326
+ q = self.to_q(x)
327
+ context = default(context, x)
328
+ k = self.to_k(context)
329
+ v = self.to_v(context)
330
+
331
+ if n_times_crossframe_attn_in_self:
332
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
333
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
334
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
335
+ k = repeat(
336
+ k[::n_times_crossframe_attn_in_self],
337
+ "b ... -> (b n) ...",
338
+ n=n_times_crossframe_attn_in_self,
339
+ )
340
+ v = repeat(
341
+ v[::n_times_crossframe_attn_in_self],
342
+ "b ... -> (b n) ...",
343
+ n=n_times_crossframe_attn_in_self,
344
+ )
345
+
346
+ b, _, _ = q.shape
347
+ q, k, v = map(
348
+ lambda t: t.unsqueeze(3)
349
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
350
+ .permute(0, 2, 1, 3)
351
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
352
+ .contiguous(),
353
+ (q, k, v),
354
+ )
355
+
356
+ # actually compute the attention, what we cannot get enough of
357
+ out = xformers.ops.memory_efficient_attention(
358
+ q, k, v, attn_bias=None, op=self.attention_op
359
+ )
360
+
361
+ # TODO: Use this directly in the attention operation, as a bias
362
+ if exists(mask):
363
+ raise NotImplementedError
364
+ out = (
365
+ out.unsqueeze(0)
366
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
367
+ .permute(0, 2, 1, 3)
368
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
369
+ )
370
+ if additional_tokens is not None:
371
+ # remove additional token
372
+ out = out[:, n_tokens_to_mask:]
373
+ return self.to_out(out)
374
+
375
+
376
+ class BasicTransformerBlock(nn.Module):
377
+ ATTENTION_MODES = {
378
+ "softmax": CrossAttention, # vanilla attention
379
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
380
+ }
381
+
382
+ def __init__(
383
+ self,
384
+ dim,
385
+ n_heads,
386
+ d_head,
387
+ dropout=0.0,
388
+ context_dim=None,
389
+ gated_ff=True,
390
+ checkpoint=False,
391
+ disable_self_attn=False,
392
+ attn_mode="softmax",
393
+ sdp_backend=None,
394
+ ):
395
+ super().__init__()
396
+ checkpoint = False
397
+
398
+ assert attn_mode in self.ATTENTION_MODES
399
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
400
+ print(
401
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
402
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
403
+ )
404
+ attn_mode = "softmax"
405
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
406
+ print(
407
+ "We do not support vanilla attention anymore, as it is too expensive. Sorry."
408
+ )
409
+ if not XFORMERS_IS_AVAILABLE:
410
+ assert (
411
+ False
412
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
413
+ else:
414
+ print("Falling back to xformers efficient attention.")
415
+ attn_mode = "softmax-xformers"
416
+ attn_cls = self.ATTENTION_MODES[attn_mode]
417
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
418
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
419
+ else:
420
+ assert sdp_backend is None
421
+ self.disable_self_attn = disable_self_attn
422
+ self.attn1 = attn_cls(
423
+ query_dim=dim,
424
+ heads=n_heads,
425
+ dim_head=d_head,
426
+ dropout=dropout,
427
+ context_dim=context_dim if self.disable_self_attn else None,
428
+ backend=sdp_backend,
429
+ ) # is a self-attention if not self.disable_self_attn
430
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
431
+ self.attn2 = attn_cls(
432
+ query_dim=dim,
433
+ context_dim=context_dim,
434
+ heads=n_heads,
435
+ dim_head=d_head,
436
+ dropout=dropout,
437
+ backend=sdp_backend,
438
+ ) # is self-attn if context is none
439
+ self.norm1 = nn.LayerNorm(dim)
440
+ self.norm2 = nn.LayerNorm(dim)
441
+ self.norm3 = nn.LayerNorm(dim)
442
+ self.checkpoint = checkpoint
443
+ if self.checkpoint:
444
+ print(f"{self.__class__.__name__} is using checkpointing")
445
+
446
+ def forward(
447
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
448
+ ):
449
+ kwargs = {"x": x}
450
+
451
+ if context is not None:
452
+ kwargs.update({"context": context})
453
+
454
+ if additional_tokens is not None:
455
+ kwargs.update({"additional_tokens": additional_tokens})
456
+
457
+ if n_times_crossframe_attn_in_self:
458
+ kwargs.update(
459
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
460
+ )
461
+
462
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
463
+ return checkpoint(
464
+ self._forward, (x, context), self.parameters(), self.checkpoint
465
+ )
466
+
467
+ def _forward(
468
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
469
+ ):
470
+ x = (
471
+ self.attn1(
472
+ self.norm1(x),
473
+ context=context if self.disable_self_attn else None,
474
+ additional_tokens=additional_tokens,
475
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
476
+ if not self.disable_self_attn
477
+ else 0,
478
+ )
479
+ + x
480
+ )
481
+ x = (
482
+ self.attn2(
483
+ self.norm2(x), context=context, additional_tokens=additional_tokens
484
+ )
485
+ + x
486
+ )
487
+ x = self.ff(self.norm3(x)) + x
488
+ return x
489
+
490
+
491
+ class BasicTransformerSingleLayerBlock(nn.Module):
492
+ ATTENTION_MODES = {
493
+ "softmax": CrossAttention, # vanilla attention
494
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
495
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
496
+ }
497
+
498
+ def __init__(
499
+ self,
500
+ dim,
501
+ n_heads,
502
+ d_head,
503
+ dropout=0.0,
504
+ context_dim=None,
505
+ gated_ff=True,
506
+ checkpoint=True,
507
+ attn_mode="softmax",
508
+ ):
509
+ super().__init__()
510
+ assert attn_mode in self.ATTENTION_MODES
511
+ attn_cls = self.ATTENTION_MODES[attn_mode]
512
+ self.attn1 = attn_cls(
513
+ query_dim=dim,
514
+ heads=n_heads,
515
+ dim_head=d_head,
516
+ dropout=dropout,
517
+ context_dim=context_dim,
518
+ )
519
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
520
+ self.norm1 = nn.LayerNorm(dim)
521
+ self.norm2 = nn.LayerNorm(dim)
522
+ self.checkpoint = checkpoint
523
+
524
+ def forward(self, x, context=None):
525
+ return checkpoint(
526
+ self._forward, (x, context), self.parameters(), self.checkpoint
527
+ )
528
+
529
+ def _forward(self, x, context=None):
530
+ x = self.attn1(self.norm1(x), context=context) + x
531
+ x = self.ff(self.norm2(x)) + x
532
+ return x
533
+
534
+
535
+ class SpatialTransformer(nn.Module):
536
+ """
537
+ Transformer block for image-like data.
538
+ First, project the input (aka embedding)
539
+ and reshape to b, t, d.
540
+ Then apply standard transformer action.
541
+ Finally, reshape to image
542
+ NEW: use_linear for more efficiency instead of the 1x1 convs
543
+ """
544
+
545
+ def __init__(
546
+ self,
547
+ in_channels,
548
+ n_heads,
549
+ d_head,
550
+ depth=1,
551
+ dropout=0.0,
552
+ context_dim=None,
553
+ disable_self_attn=False,
554
+ use_linear=False,
555
+ attn_type="softmax",
556
+ use_checkpoint=True,
557
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
558
+ sdp_backend=None,
559
+ ):
560
+ super().__init__()
561
+ print(
562
+ f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
563
+ )
564
+ from omegaconf import ListConfig
565
+
566
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
567
+ context_dim = [context_dim]
568
+ if exists(context_dim) and isinstance(context_dim, list):
569
+ if depth != len(context_dim):
570
+ # print(
571
+ # f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
572
+ # f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
573
+ # )
574
+ # depth does not match context dims.
575
+ assert all(
576
+ map(lambda x: x == context_dim[0], context_dim)
577
+ ), "need homogenous context_dim to match depth automatically"
578
+ context_dim = depth * [context_dim[0]]
579
+ elif context_dim is None:
580
+ context_dim = [None] * depth
581
+ self.in_channels = in_channels
582
+ inner_dim = n_heads * d_head
583
+ self.norm = Normalize(in_channels)
584
+ if not use_linear:
585
+ self.proj_in = nn.Conv2d(
586
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
587
+ )
588
+ else:
589
+ self.proj_in = nn.Linear(in_channels, inner_dim)
590
+
591
+ self.transformer_blocks = nn.ModuleList(
592
+ [
593
+ BasicTransformerBlock(
594
+ inner_dim,
595
+ n_heads,
596
+ d_head,
597
+ dropout=dropout,
598
+ context_dim=context_dim[d],
599
+ disable_self_attn=disable_self_attn,
600
+ attn_mode=attn_type,
601
+ checkpoint=use_checkpoint,
602
+ sdp_backend=sdp_backend,
603
+ )
604
+ for d in range(depth)
605
+ ]
606
+ )
607
+ if not use_linear:
608
+ self.proj_out = zero_module(
609
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
610
+ )
611
+ else:
612
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
613
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
614
+ self.use_linear = use_linear
615
+
616
+ def forward(self, x, context=None):
617
+ # note: if no context is given, cross-attention defaults to self-attention
618
+ if not isinstance(context, list):
619
+ context = [context]
620
+ b, c, h, w = x.shape
621
+ x_in = x
622
+ x = self.norm(x)
623
+ if not self.use_linear:
624
+ x = self.proj_in(x)
625
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
626
+ if self.use_linear:
627
+ x = self.proj_in(x)
628
+ for i, block in enumerate(self.transformer_blocks):
629
+ if i > 0 and len(context) == 1:
630
+ i = 0 # use same context for each block
631
+ x = block(x, context=context[i])
632
+ if self.use_linear:
633
+ x = self.proj_out(x)
634
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
635
+ if not self.use_linear:
636
+ x = self.proj_out(x)
637
+ return x + x_in
sgm/modules/autoencoding/__init__.py ADDED
File without changes
sgm/modules/autoencoding/losses/__init__.py ADDED
@@ -0,0 +1,246 @@
1
+ from typing import Any, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+
7
+ from ....util import default, instantiate_from_config
8
+ from ..lpips.loss.lpips import LPIPS
9
+ from ..lpips.model.model import NLayerDiscriminator, weights_init
10
+ from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
11
+
12
+
13
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
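+ # Returns `value` (default 0) until `global_step` reaches `threshold`; afterwards returns `weight` unchanged.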
14
+ if global_step < threshold:
15
+ weight = value
16
+ return weight
17
+
18
+
19
+ class LatentLPIPS(nn.Module):
20
+ def __init__(
21
+ self,
22
+ decoder_config,
23
+ perceptual_weight=1.0,
24
+ latent_weight=1.0,
25
+ scale_input_to_tgt_size=False,
26
+ scale_tgt_to_input_size=False,
27
+ perceptual_weight_on_inputs=0.0,
28
+ ):
29
+ super().__init__()
30
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
31
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
32
+ self.init_decoder(decoder_config)
33
+ self.perceptual_loss = LPIPS().eval()
34
+ self.perceptual_weight = perceptual_weight
35
+ self.latent_weight = latent_weight
36
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
37
+
38
+ def init_decoder(self, config):
39
+ self.decoder = instantiate_from_config(config)
40
+ if hasattr(self.decoder, "encoder"):
41
+ del self.decoder.encoder
42
+
43
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
44
+ log = dict()
45
+ loss = (latent_inputs - latent_predictions) ** 2
46
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
47
+ image_reconstructions = None
48
+ if self.perceptual_weight > 0.0:
49
+ image_reconstructions = self.decoder.decode(latent_predictions)
50
+ image_targets = self.decoder.decode(latent_inputs)
51
+ perceptual_loss = self.perceptual_loss(
52
+ image_targets.contiguous(), image_reconstructions.contiguous()
53
+ )
54
+ loss = (
55
+ self.latent_weight * loss.mean()
56
+ + self.perceptual_weight * perceptual_loss.mean()
57
+ )
58
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
59
+
60
+ if self.perceptual_weight_on_inputs > 0.0:
61
+ image_reconstructions = default(
62
+ image_reconstructions, self.decoder.decode(latent_predictions)
63
+ )
64
+ if self.scale_input_to_tgt_size:
65
+ image_inputs = torch.nn.functional.interpolate(
66
+ image_inputs,
67
+ image_reconstructions.shape[2:],
68
+ mode="bicubic",
69
+ antialias=True,
70
+ )
71
+ elif self.scale_tgt_to_input_size:
72
+ image_reconstructions = torch.nn.functional.interpolate(
73
+ image_reconstructions,
74
+ image_inputs.shape[2:],
75
+ mode="bicubic",
76
+ antialias=True,
77
+ )
78
+
79
+ perceptual_loss2 = self.perceptual_loss(
80
+ image_inputs.contiguous(), image_reconstructions.contiguous()
81
+ )
82
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
83
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
84
+ return loss, log
85
+
86
+
87
+ class GeneralLPIPSWithDiscriminator(nn.Module):
88
+ def __init__(
89
+ self,
90
+ disc_start: int,
91
+ logvar_init: float = 0.0,
92
+ pixelloss_weight=1.0,
93
+ disc_num_layers: int = 3,
94
+ disc_in_channels: int = 3,
95
+ disc_factor: float = 1.0,
96
+ disc_weight: float = 1.0,
97
+ perceptual_weight: float = 1.0,
98
+ disc_loss: str = "hinge",
99
+ scale_input_to_tgt_size: bool = False,
100
+ dims: int = 2,
101
+ learn_logvar: bool = False,
102
+ regularization_weights: Union[None, dict] = None,
103
+ ):
104
+ super().__init__()
105
+ self.dims = dims
106
+ if self.dims > 2:
107
+ print(
108
+ f"running with dims={dims}. This means that for perceptual loss calculation, "
109
+ f"the LPIPS loss will be applied to each frame independently. "
110
+ )
111
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
112
+ assert disc_loss in ["hinge", "vanilla"]
113
+ self.pixel_weight = pixelloss_weight
114
+ self.perceptual_loss = LPIPS().eval()
115
+ self.perceptual_weight = perceptual_weight
116
+ # output log variance
117
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
118
+ self.learn_logvar = learn_logvar
119
+
120
+ self.discriminator = NLayerDiscriminator(
121
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
122
+ ).apply(weights_init)
123
+ self.discriminator_iter_start = disc_start
124
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
125
+ self.disc_factor = disc_factor
126
+ self.discriminator_weight = disc_weight
127
+ self.regularization_weights = default(regularization_weights, {})
128
+
129
+ def get_trainable_parameters(self) -> Any:
130
+ return self.discriminator.parameters()
131
+
132
+ def get_trainable_autoencoder_parameters(self) -> Any:
133
+ if self.learn_logvar:
134
+ yield self.logvar
135
+ yield from ()
136
+
137
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
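+ # Scale the adversarial loss so its gradient norm at the last (decoder) layer matches that of the reconstruction loss.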
138
+ if last_layer is not None:
139
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
140
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
141
+ else:
142
+ nll_grads = torch.autograd.grad(
143
+ nll_loss, self.last_layer[0], retain_graph=True
144
+ )[0]
145
+ g_grads = torch.autograd.grad(
146
+ g_loss, self.last_layer[0], retain_graph=True
147
+ )[0]
148
+
149
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
150
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
151
+ d_weight = d_weight * self.discriminator_weight
152
+ return d_weight
153
+
154
+ def forward(
155
+ self,
156
+ regularization_log,
157
+ inputs,
158
+ reconstructions,
159
+ optimizer_idx,
160
+ global_step,
161
+ last_layer=None,
162
+ split="train",
163
+ weights=None,
164
+ ):
165
+ if self.scale_input_to_tgt_size:
166
+ inputs = torch.nn.functional.interpolate(
167
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
168
+ )
169
+
170
+ if self.dims > 2:
171
+ inputs, reconstructions = map(
172
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
173
+ (inputs, reconstructions),
174
+ )
175
+
176
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
177
+ if self.perceptual_weight > 0:
178
+ p_loss = self.perceptual_loss(
179
+ inputs.contiguous(), reconstructions.contiguous()
180
+ )
181
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
182
+
183
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
184
+ weighted_nll_loss = nll_loss
185
+ if weights is not None:
186
+ weighted_nll_loss = weights * nll_loss
187
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
188
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
189
+
190
+ # now the GAN part
191
+ if optimizer_idx == 0:
192
+ # generator update
193
+ logits_fake = self.discriminator(reconstructions.contiguous())
194
+ g_loss = -torch.mean(logits_fake)
195
+
196
+ if self.disc_factor > 0.0:
197
+ try:
198
+ d_weight = self.calculate_adaptive_weight(
199
+ nll_loss, g_loss, last_layer=last_layer
200
+ )
201
+ except RuntimeError:
202
+ assert not self.training
203
+ d_weight = torch.tensor(0.0)
204
+ else:
205
+ d_weight = torch.tensor(0.0)
206
+
207
+ disc_factor = adopt_weight(
208
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
209
+ )
210
+ loss = weighted_nll_loss + d_weight * disc_factor * g_loss
211
+ log = dict()
212
+ for k in regularization_log:
213
+ if k in self.regularization_weights:
214
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
215
+ log[f"{split}/{k}"] = regularization_log[k].detach().mean()
216
+
217
+ log.update(
218
+ {
219
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
220
+ "{}/logvar".format(split): self.logvar.detach(),
221
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
222
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
223
+ "{}/d_weight".format(split): d_weight.detach(),
224
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
225
+ "{}/g_loss".format(split): g_loss.detach().mean(),
226
+ }
227
+ )
228
+
229
+ return loss, log
230
+
231
+ if optimizer_idx == 1:
232
+ # second pass for discriminator update
233
+ logits_real = self.discriminator(inputs.contiguous().detach())
234
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
235
+
236
+ disc_factor = adopt_weight(
237
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
238
+ )
239
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
240
+
241
+ log = {
242
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
243
+ "{}/logits_real".format(split): logits_real.detach().mean(),
244
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
245
+ }
246
+ return d_loss, log
sgm/modules/autoencoding/lpips/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/LICENSE ADDED
@@ -0,0 +1,23 @@
1
+ Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/loss/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/lpips.py ADDED
@@ -0,0 +1,147 @@
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ from collections import namedtuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from ..util import get_ckpt_path
10
+
11
+
12
+ class LPIPS(nn.Module):
13
+ # Learned perceptual metric
14
+ def __init__(self, use_dropout=True):
15
+ super().__init__()
16
+ self.scaling_layer = ScalingLayer()
17
+ self.chns = [64, 128, 256, 512, 512] # vgg16 feature channel counts
18
+ self.net = vgg16(pretrained=True, requires_grad=False)
19
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24
+ self.load_from_pretrained()
25
+ for param in self.parameters():
26
+ param.requires_grad = False
27
+
28
+ def load_from_pretrained(self, name="vgg_lpips"):
29
+ ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
30
+ self.load_state_dict(
31
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
32
+ )
33
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
34
+
35
+ @classmethod
36
+ def from_pretrained(cls, name="vgg_lpips"):
37
+ if name != "vgg_lpips":
38
+ raise NotImplementedError
39
+ model = cls()
40
+ ckpt = get_ckpt_path(name)
41
+ model.load_state_dict(
42
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
43
+ )
44
+ return model
45
+
46
+ def forward(self, input, target):
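+ # Rescale both images, extract VGG16 features, take per-layer squared feature differences, weight them with 1x1 convs, spatially average, and sum over layers.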
47
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
48
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
49
+ feats0, feats1, diffs = {}, {}, {}
50
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
51
+ for kk in range(len(self.chns)):
52
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
53
+ outs1[kk]
54
+ )
55
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
56
+
57
+ res = [
58
+ spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
59
+ for kk in range(len(self.chns))
60
+ ]
61
+ val = res[0]
62
+ for l in range(1, len(self.chns)):
63
+ val += res[l]
64
+ return val
65
+
66
+
67
+ class ScalingLayer(nn.Module):
68
+ def __init__(self):
69
+ super(ScalingLayer, self).__init__()
70
+ self.register_buffer(
71
+ "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
72
+ )
73
+ self.register_buffer(
74
+ "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
75
+ )
76
+
77
+ def forward(self, inp):
78
+ return (inp - self.shift) / self.scale
79
+
80
+
81
+ class NetLinLayer(nn.Module):
82
+ """A single linear layer which does a 1x1 conv"""
83
+
84
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
85
+ super(NetLinLayer, self).__init__()
86
+ layers = (
87
+ [
88
+ nn.Dropout(),
89
+ ]
90
+ if (use_dropout)
91
+ else []
92
+ )
93
+ layers += [
94
+ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
95
+ ]
96
+ self.model = nn.Sequential(*layers)
97
+
98
+
99
+ class vgg16(torch.nn.Module):
100
+ def __init__(self, requires_grad=False, pretrained=True):
101
+ super(vgg16, self).__init__()
102
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
103
+ self.slice1 = torch.nn.Sequential()
104
+ self.slice2 = torch.nn.Sequential()
105
+ self.slice3 = torch.nn.Sequential()
106
+ self.slice4 = torch.nn.Sequential()
107
+ self.slice5 = torch.nn.Sequential()
108
+ self.N_slices = 5
109
+ for x in range(4):
110
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(4, 9):
112
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(9, 16):
114
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(16, 23):
116
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
117
+ for x in range(23, 30):
118
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
119
+ if not requires_grad:
120
+ for param in self.parameters():
121
+ param.requires_grad = False
122
+
123
+ def forward(self, X):
124
+ h = self.slice1(X)
125
+ h_relu1_2 = h
126
+ h = self.slice2(h)
127
+ h_relu2_2 = h
128
+ h = self.slice3(h)
129
+ h_relu3_3 = h
130
+ h = self.slice4(h)
131
+ h_relu4_3 = h
132
+ h = self.slice5(h)
133
+ h_relu5_3 = h
134
+ vgg_outputs = namedtuple(
135
+ "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
136
+ )
137
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
138
+ return out
139
+
140
+
141
+ def normalize_tensor(x, eps=1e-10):
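+ # L2-normalize feature maps along the channel dimension.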
142
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
143
+ return x / (norm_factor + eps)
144
+
145
+
146
+ def spatial_average(x, keepdim=True):
147
+ return x.mean([2, 3], keepdim=keepdim)
sgm/modules/autoencoding/lpips/model/LICENSE ADDED
@@ -0,0 +1,58 @@
1
+ Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
+
25
+
26
+ --------------------------- LICENSE FOR pix2pix --------------------------------
27
+ BSD License
28
+
29
+ For pix2pix software
30
+ Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31
+ All rights reserved.
32
+
33
+ Redistribution and use in source and binary forms, with or without
34
+ modification, are permitted provided that the following conditions are met:
35
+
36
+ * Redistributions of source code must retain the above copyright notice, this
37
+ list of conditions and the following disclaimer.
38
+
39
+ * Redistributions in binary form must reproduce the above copyright notice,
40
+ this list of conditions and the following disclaimer in the documentation
41
+ and/or other materials provided with the distribution.
42
+
43
+ ----------------------------- LICENSE FOR DCGAN --------------------------------
44
+ BSD License
45
+
46
+ For dcgan.torch software
47
+
48
+ Copyright (c) 2015, Facebook, Inc. All rights reserved.
49
+
50
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51
+
52
+ Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53
+
54
+ Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55
+
56
+ Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57
+
58
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/model/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/model/model.py ADDED
@@ -0,0 +1,88 @@
1
+ import functools
2
+
3
+ import torch.nn as nn
4
+
5
+ from ..util import ActNorm
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
12
+ elif classname.find("BatchNorm") != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+
22
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
23
+ """Construct a PatchGAN discriminator
24
+ Parameters:
25
+ input_nc (int) -- the number of channels in input images
26
+ ndf (int) -- the number of filters in the last conv layer
27
+ n_layers (int) -- the number of conv layers in the discriminator
28
+ norm_layer -- normalization layer
29
+ """
30
+ super(NLayerDiscriminator, self).__init__()
31
+ if not use_actnorm:
32
+ norm_layer = nn.BatchNorm2d
33
+ else:
34
+ norm_layer = ActNorm
35
+ if (
36
+ type(norm_layer) == functools.partial
37
+ ): # no need to use bias as BatchNorm2d has affine parameters
38
+ use_bias = norm_layer.func != nn.BatchNorm2d
39
+ else:
40
+ use_bias = norm_layer != nn.BatchNorm2d
41
+
42
+ kw = 4
43
+ padw = 1
44
+ sequence = [
45
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
46
+ nn.LeakyReLU(0.2, True),
47
+ ]
48
+ nf_mult = 1
49
+ nf_mult_prev = 1
50
+ for n in range(1, n_layers): # gradually increase the number of filters
51
+ nf_mult_prev = nf_mult
52
+ nf_mult = min(2**n, 8)
53
+ sequence += [
54
+ nn.Conv2d(
55
+ ndf * nf_mult_prev,
56
+ ndf * nf_mult,
57
+ kernel_size=kw,
58
+ stride=2,
59
+ padding=padw,
60
+ bias=use_bias,
61
+ ),
62
+ norm_layer(ndf * nf_mult),
63
+ nn.LeakyReLU(0.2, True),
64
+ ]
65
+
66
+ nf_mult_prev = nf_mult
67
+ nf_mult = min(2**n_layers, 8)
68
+ sequence += [
69
+ nn.Conv2d(
70
+ ndf * nf_mult_prev,
71
+ ndf * nf_mult,
72
+ kernel_size=kw,
73
+ stride=1,
74
+ padding=padw,
75
+ bias=use_bias,
76
+ ),
77
+ norm_layer(ndf * nf_mult),
78
+ nn.LeakyReLU(0.2, True),
79
+ ]
80
+
81
+ sequence += [
82
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
83
+ ] # output 1 channel prediction map
84
+ self.main = nn.Sequential(*sequence)
85
+
86
+ def forward(self, input):
87
+ """Standard forward."""
88
+ return self.main(input)
sgm/modules/autoencoding/lpips/util.py ADDED
@@ -0,0 +1,128 @@
1
+ import hashlib
2
+ import os
3
+
4
+ import requests
5
+ import torch
6
+ import torch.nn as nn
7
+ from tqdm import tqdm
8
+
9
+ URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
10
+
11
+ CKPT_MAP = {"vgg_lpips": "vgg.pth"}
12
+
13
+ MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
14
+
15
+
16
+ def download(url, local_path, chunk_size=1024):
17
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
18
+ with requests.get(url, stream=True) as r:
19
+ total_size = int(r.headers.get("content-length", 0))
20
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
21
+ with open(local_path, "wb") as f:
22
+ for data in r.iter_content(chunk_size=chunk_size):
23
+ if data:
24
+ f.write(data)
25
+ pbar.update(chunk_size)
26
+
27
+
28
+ def md5_hash(path):
29
+ with open(path, "rb") as f:
30
+ content = f.read()
31
+ return hashlib.md5(content).hexdigest()
32
+
33
+
34
+ def get_ckpt_path(name, root, check=False):
35
+ assert name in URL_MAP
36
+ path = os.path.join(root, CKPT_MAP[name])
37
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
38
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
39
+ download(URL_MAP[name], path)
40
+ md5 = md5_hash(path)
41
+ assert md5 == MD5_MAP[name], md5
42
+ return path
43
+
44
+
45
+ class ActNorm(nn.Module):
46
+ def __init__(
47
+ self, num_features, logdet=False, affine=True, allow_reverse_init=False
48
+ ):
49
+ assert affine
50
+ super().__init__()
51
+ self.logdet = logdet
52
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
53
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
54
+ self.allow_reverse_init = allow_reverse_init
55
+
56
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
57
+
58
+ def initialize(self, input):
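+ # Data-dependent initialization: set loc and scale so the first batch has zero mean and unit variance per channel.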
59
+ with torch.no_grad():
60
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
61
+ mean = (
62
+ flatten.mean(1)
63
+ .unsqueeze(1)
64
+ .unsqueeze(2)
65
+ .unsqueeze(3)
66
+ .permute(1, 0, 2, 3)
67
+ )
68
+ std = (
69
+ flatten.std(1)
70
+ .unsqueeze(1)
71
+ .unsqueeze(2)
72
+ .unsqueeze(3)
73
+ .permute(1, 0, 2, 3)
74
+ )
75
+
76
+ self.loc.data.copy_(-mean)
77
+ self.scale.data.copy_(1 / (std + 1e-6))
78
+
79
+ def forward(self, input, reverse=False):
80
+ if reverse:
81
+ return self.reverse(input)
82
+ if len(input.shape) == 2:
83
+ input = input[:, :, None, None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ _, _, height, width = input.shape
89
+
90
+ if self.training and self.initialized.item() == 0:
91
+ self.initialize(input)
92
+ self.initialized.fill_(1)
93
+
94
+ h = self.scale * (input + self.loc)
95
+
96
+ if squeeze:
97
+ h = h.squeeze(-1).squeeze(-1)
98
+
99
+ if self.logdet:
100
+ log_abs = torch.log(torch.abs(self.scale))
101
+ logdet = height * width * torch.sum(log_abs)
102
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
103
+ return h, logdet
104
+
105
+ return h
106
+
107
+ def reverse(self, output):
108
+ if self.training and self.initialized.item() == 0:
109
+ if not self.allow_reverse_init:
110
+ raise RuntimeError(
111
+ "Initializing ActNorm in reverse direction is "
112
+ "disabled by default. Use allow_reverse_init=True to enable."
113
+ )
114
+ else:
115
+ self.initialize(output)
116
+ self.initialized.fill_(1)
117
+
118
+ if len(output.shape) == 2:
119
+ output = output[:, :, None, None]
120
+ squeeze = True
121
+ else:
122
+ squeeze = False
123
+
124
+ h = output / self.scale - self.loc
125
+
126
+ if squeeze:
127
+ h = h.squeeze(-1).squeeze(-1)
128
+ return h
sgm/modules/autoencoding/lpips/vqperceptual.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def hinge_d_loss(logits_real, logits_fake):
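+ # Hinge GAN discriminator loss: penalize real logits below +1 and fake logits above -1.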
6
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
7
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8
+ d_loss = 0.5 * (loss_real + loss_fake)
9
+ return d_loss
10
+
11
+
12
+ def vanilla_d_loss(logits_real, logits_fake):
13
+ d_loss = 0.5 * (
14
+ torch.mean(torch.nn.functional.softplus(-logits_real))
15
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
16
+ )
17
+ return d_loss
sgm/modules/autoencoding/regularizers/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ....modules.distributions.distributions import DiagonalGaussianDistribution
9
+
10
+
11
+ class AbstractRegularizer(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
16
+ raise NotImplementedError()
17
+
18
+ @abstractmethod
19
+ def get_trainable_parameters(self) -> Any:
20
+ raise NotImplementedError()
21
+
22
+
23
+ class DiagonalGaussianRegularizer(AbstractRegularizer):
24
+ def __init__(self, sample: bool = True):
25
+ super().__init__()
26
+ self.sample = sample
27
+
28
+ def get_trainable_parameters(self) -> Any:
29
+ yield from ()
30
+
31
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
32
+ log = dict()
33
+ posterior = DiagonalGaussianDistribution(z)
34
+ if self.sample:
35
+ z = posterior.sample()
36
+ else:
37
+ z = posterior.mode()
38
+ kl_loss = posterior.kl()
39
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
40
+ log["kl_loss"] = kl_loss
41
+ return z, log
42
+
43
+
44
+ def measure_perplexity(predicted_indices, num_centroids):
45
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
46
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
47
+ encodings = (
48
+ F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
49
+ )
50
+ avg_probs = encodings.mean(0)
51
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
52
+ cluster_use = torch.sum(avg_probs > 0)
53
+ return perplexity, cluster_use
sgm/modules/diffusionmodules/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .denoiser import Denoiser
2
+ from .discretizer import Discretization
3
+ from .loss import StandardDiffusionLoss
4
+ from .model import Decoder, Encoder, Model
5
+ from .openaimodel import UNetModel
6
+ from .sampling import BaseDiffusionSampler
7
+ from .wrappers import OpenAIWrapper
sgm/modules/diffusionmodules/denoiser.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch.nn as nn
2
+
3
+ from ...util import append_dims, instantiate_from_config
4
+
5
+
6
+ class Denoiser(nn.Module):
7
+ def __init__(self, weighting_config, scaling_config):
8
+ super().__init__()
9
+
10
+ self.weighting = instantiate_from_config(weighting_config)
11
+ self.scaling = instantiate_from_config(scaling_config)
12
+
13
+ def possibly_quantize_sigma(self, sigma):
14
+ return sigma
15
+
16
+ def possibly_quantize_c_noise(self, c_noise):
17
+ return c_noise
18
+
19
+ def w(self, sigma):
20
+ return self.weighting(sigma)
21
+
22
+ def __call__(self, network, input, sigma, cond):
23
+ sigma = self.possibly_quantize_sigma(sigma)
24
+ sigma_shape = sigma.shape
25
+ sigma = append_dims(sigma, input.ndim)
26
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
27
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
28
+ return network(input * c_in, c_noise, cond) * c_out + input * c_skip
29
+
30
+
31
+ class DiscreteDenoiser(Denoiser):
32
+ def __init__(
33
+ self,
34
+ weighting_config,
35
+ scaling_config,
36
+ num_idx,
37
+ discretization_config,
38
+ do_append_zero=False,
39
+ quantize_c_noise=True,
40
+ flip=True,
41
+ ):
42
+ super().__init__(weighting_config, scaling_config)
43
+ sigmas = instantiate_from_config(discretization_config)(
44
+ num_idx, do_append_zero=do_append_zero, flip=flip
45
+ )
46
+ self.register_buffer("sigmas", sigmas)
47
+ self.quantize_c_noise = quantize_c_noise
48
+
49
+ def sigma_to_idx(self, sigma):
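+ # Map each continuous sigma to the index of the closest entry in the registered sigma table.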
50
+ dists = sigma - self.sigmas[:, None]
51
+ return dists.abs().argmin(dim=0).view(sigma.shape)
52
+
53
+ def idx_to_sigma(self, idx):
54
+ return self.sigmas[idx]
55
+
56
+ def possibly_quantize_sigma(self, sigma):
57
+ return self.idx_to_sigma(self.sigma_to_idx(sigma))
58
+
59
+ def possibly_quantize_c_noise(self, c_noise):
60
+ if self.quantize_c_noise:
61
+ return self.sigma_to_idx(c_noise)
62
+ else:
63
+ return c_noise
64
+
65
+
66
+ class DiscreteDenoiserWithControl(DiscreteDenoiser):
67
+ def __call__(self, network, input, sigma, cond, control_scale):
68
+ sigma = self.possibly_quantize_sigma(sigma)
69
+ sigma_shape = sigma.shape
70
+ sigma = append_dims(sigma, input.ndim)
71
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
72
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
73
+ return network(input * c_in, c_noise, cond, control_scale) * c_out + input * c_skip
sgm/modules/diffusionmodules/denoiser_scaling.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+
3
+
4
+ class EDMScaling:
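+ # EDM preconditioning (Karras et al., 2022): c_skip, c_out, c_in, c_noise as functions of sigma and sigma_data.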
5
+ def __init__(self, sigma_data=0.5):
6
+ self.sigma_data = sigma_data
7
+
8
+ def __call__(self, sigma):
9
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
10
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
11
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
12
+ c_noise = 0.25 * sigma.log()
13
+ return c_skip, c_out, c_in, c_noise
14
+
15
+
16
+ class EpsScaling:
17
+ def __call__(self, sigma):
18
+ c_skip = torch.ones_like(sigma, device=sigma.device)
19
+ c_out = -sigma
20
+ c_in = 1 / (sigma**2 + 1.0) ** 0.5
21
+ c_noise = sigma.clone()
22
+ return c_skip, c_out, c_in, c_noise
23
+
24
+
25
+ class VScaling:
26
+ def __call__(self, sigma):
27
+ c_skip = 1.0 / (sigma**2 + 1.0)
28
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
29
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
30
+ c_noise = sigma.clone()
31
+ return c_skip, c_out, c_in, c_noise
sgm/modules/diffusionmodules/denoiser_weighting.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+
3
+ class UnitWeighting:
4
+ def __call__(self, sigma):
5
+ return torch.ones_like(sigma, device=sigma.device)
6
+
7
+
8
+ class EDMWeighting:
9
+ def __init__(self, sigma_data=0.5):
10
+ self.sigma_data = sigma_data
11
+
12
+ def __call__(self, sigma):
13
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
14
+
15
+
16
+ class VWeighting(EDMWeighting):
17
+ def __init__(self):
18
+ super().__init__(sigma_data=1.0)
19
+
20
+
21
+ class EpsWeighting:
22
+ def __call__(self, sigma):
23
+ return sigma**-2
24
+
sgm/modules/diffusionmodules/discretizer.py ADDED
@@ -0,0 +1,69 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from ...modules.diffusionmodules.util import make_beta_schedule
8
+ from ...util import append_zero
9
+
10
+
11
+ def generate_roughly_equally_spaced_steps(
12
+ num_substeps: int, max_step: int
13
+ ) -> np.ndarray:
14
+ return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
15
+
16
+
17
+ class Discretization:
18
+ def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
19
+ sigmas = self.get_sigmas(n, device=device)
20
+ sigmas = append_zero(sigmas) if do_append_zero else sigmas
21
+ return sigmas if not flip else torch.flip(sigmas, (0,))
22
+
23
+ @abstractmethod
24
+ def get_sigmas(self, n, device):
25
+ pass
26
+
27
+
28
+ class EDMDiscretization(Discretization):
29
+ def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
30
+ self.sigma_min = sigma_min
31
+ self.sigma_max = sigma_max
32
+ self.rho = rho
33
+
34
+ def get_sigmas(self, n, device="cpu"):
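+ # Karras rho-schedule: interpolate linearly in sigma^(1/rho) from sigma_max down to sigma_min, then raise to the rho-th power.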
35
+ ramp = torch.linspace(0, 1, n, device=device)
36
+ min_inv_rho = self.sigma_min ** (1 / self.rho)
37
+ max_inv_rho = self.sigma_max ** (1 / self.rho)
38
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
39
+ return sigmas
40
+
41
+
42
+ class LegacyDDPMDiscretization(Discretization):
43
+ def __init__(
44
+ self,
45
+ linear_start=0.00085,
46
+ linear_end=0.0120,
47
+ num_timesteps=1000,
48
+ ):
49
+ super().__init__()
50
+ self.num_timesteps = num_timesteps
51
+ betas = make_beta_schedule(
52
+ "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
53
+ )
54
+ alphas = 1.0 - betas
55
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
56
+ self.to_torch = partial(torch.tensor, dtype=torch.float32)
57
+
58
+ def get_sigmas(self, n, device="cpu"):
59
+ if n < self.num_timesteps:
60
+ timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
61
+ alphas_cumprod = self.alphas_cumprod[timesteps]
62
+ elif n == self.num_timesteps:
63
+ alphas_cumprod = self.alphas_cumprod
64
+ else:
65
+ raise ValueError
66
+
67
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
68
+ sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
69
+ return torch.flip(sigmas, (0,))