This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- .gitattributes +35 -35
- .gitignore +1 -0
- LICENSE +72 -0
- README.md +5 -13
- SUPIR/__init__.py +0 -0
- SUPIR/models/SUPIR_model.py +176 -0
- SUPIR/models/__init__.py +0 -0
- SUPIR/modules/SUPIR_v0.py +718 -0
- SUPIR/modules/__init__.py +11 -0
- SUPIR/util.py +190 -0
- SUPIR/utils/__init__.py +0 -0
- SUPIR/utils/colorfix.py +120 -0
- SUPIR/utils/devices.py +138 -0
- SUPIR/utils/face_restoration_helper.py +514 -0
- SUPIR/utils/file.py +79 -0
- SUPIR/utils/tilevae.py +971 -0
- app.py +911 -0
- configs/clip_vit_config.json +25 -0
- configs/tokenizer/config.json +171 -0
- configs/tokenizer/merges.txt +0 -0
- configs/tokenizer/preprocessor_config.json +19 -0
- configs/tokenizer/special_tokens_map.json +24 -0
- configs/tokenizer/tokenizer.json +0 -0
- configs/tokenizer/tokenizer_config.json +34 -0
- configs/tokenizer/vocab.json +0 -0
- options/SUPIR_v0.yaml +140 -0
- sgm/__init__.py +4 -0
- sgm/lr_scheduler.py +135 -0
- sgm/models/__init__.py +2 -0
- sgm/models/autoencoder.py +335 -0
- sgm/models/diffusion.py +320 -0
- sgm/modules/__init__.py +8 -0
- sgm/modules/attention.py +637 -0
- sgm/modules/autoencoding/__init__.py +0 -0
- sgm/modules/autoencoding/losses/__init__.py +246 -0
- sgm/modules/autoencoding/lpips/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/loss/LICENSE +23 -0
- sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/loss/lpips.py +147 -0
- sgm/modules/autoencoding/lpips/model/LICENSE +58 -0
- sgm/modules/autoencoding/lpips/model/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/model/model.py +88 -0
- sgm/modules/autoencoding/lpips/util.py +128 -0
- sgm/modules/autoencoding/lpips/vqperceptual.py +17 -0
- sgm/modules/autoencoding/regularizers/__init__.py +53 -0
- sgm/modules/diffusionmodules/__init__.py +7 -0
- sgm/modules/diffusionmodules/denoiser.py +73 -0
- sgm/modules/diffusionmodules/denoiser_scaling.py +31 -0
- sgm/modules/diffusionmodules/denoiser_weighting.py +24 -0
- sgm/modules/diffusionmodules/discretizer.py +69 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
All 35 lines were removed and re-added with identical visible content (most likely a whitespace or line-ending normalization). The resulting .gitattributes:

*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__
LICENSE
ADDED
@@ -0,0 +1,72 @@
SUPIR Software License Agreement

Copyright (c) 2024 by SupPixel Pty Ltd. All rights reserved.

This work is jointly owned by SupPixel Pty Ltd, an Australian company (ACN 676 560 320). All rights reserved. The work is created by the author team of SUPIR at SupPixel Pty Ltd.

This License Agreement ("Agreement") is made and entered into by and between SupPixel Pty Ltd (collectively, the "Licensor") and any person or entity who accesses, uses, or distributes the SUPIR software (the "Licensee").

1. Definitions

a. Licensed Software: Refers to the SUPIR software, including its source code, associated documentation, and any updates, upgrades, or modifications thereof. The Licensed Software includes both open-source components and proprietary components, as described herein.

b. Neural Network Components: Includes the weights, biases, and architecture of the proprietary neural network(s) developed and trained by the Licensor. This excludes any pre-existing open-source neural networks used in conjunction with the Licensor's proprietary networks, for which the Licensor does not hold copyright.

c. Commercial Use: Refers to any use of the Licensed Software or its components, in whole or in part, for the purpose of generating revenue, conducting business operations, or creating products or services for sale or profit. This includes, but is not limited to:

- Deploying the software on a server to provide a service to users for a fee.
- Using the software to process images that are intended for sale.
- Integrating the software or its components into a commercial product that is sold to end-users.
- Using the software to provide commercial consulting or contracting services, such as image analysis, enhancement, or modification for clients.
- Using the software to process images that are subsequently used in training datasets for machine learning models or other types of data analysis. This includes datasets that are used internally, sold, or provided to third parties under a commercial agreement.
- Licensing the software or its components to other businesses for commercial purposes.

2. License Grant

a. Open-Source Components: The Licensor grants the Licensee a non-exclusive, worldwide, royalty-free license to use, copy, modify, and distribute the source code of the open-source components of the Licensed Software, subject to the terms and conditions of their respective open-source licenses. The specific open-source licenses for each component are indicated in the source code files and documentation of the Licensed Software.

b. Proprietary Neural Network Components: Notwithstanding the open-source license granted in Section 2(a) for open-source components, the Licensor retains all rights, title, and interest in and to the proprietary Neural Network Components developed and trained by the Licensor. The license granted herein does not include rights to use, copy, modify, distribute, or create derivative works of these proprietary Neural Network Components without obtaining express written permission from the Licensor.

c. The source code of the Licensed Software is available as open-source on this GitHub repository. However, the open-source license does not grant any rights to the weights, biases, or architecture of the proprietary neural network(s) used in the Licensed Software, which remain the exclusive property of the Licensor.

3. Commercial Use and Licensing

a. The Licensor retains all ownership and commercial usage rights in the Licensed Software, including its neural network(s), regardless of the open-source availability of portions of the Licensed Software's source code. Under the open-source license, anyone may freely use, copy, modify, and distribute the source code of the Licensed Software for non-commercial purposes only. Similarly, the proprietary neural network weights, biases, and architecture may only be used for non-commercial purposes. Any commercial use, copying, modification, distribution, or creation of derivative works of the Licensed Software's source code or the proprietary neural network weights, biases, and architecture is strictly prohibited without express written permission from the Licensor.

b. Any Commercial Use of the Licensed Software, including its source code and neural network(s), requires a separate commercial license agreement with the Licensor. The terms and conditions of such commercial licensing, including the scope of use, fees, and delivery of the relevant components, shall be negotiated and agreed upon between the Licensor and the Licensee in writing. For commercial licensing inquiries, please contact the Licensor at [email protected].

4. Intellectual Property Rights

a. The Licensor retains all intellectual property rights in the Licensed Software, including copyrights, patents, trademarks, trade secrets, and any other proprietary rights. The Licensee acknowledges that the Licensor is the sole owner of the Licensed Software and its components, including the proprietary Neural Network Components.

b. The Licensee agrees not to challenge or contest the Licensor's ownership or intellectual property rights in the Licensed Software, nor to assist or encourage any third party in doing so.

c. The Licensee shall promptly notify the Licensor of any actual or suspected infringement of the Licensor's intellectual property rights in the Licensed Software by third parties, and shall provide reasonable assistance to the Licensor in enforcing its rights against such infringement.

5. Warranty Disclaimer

The Licensed Software is provided "as is", without warranty of any kind, either express or implied, including, but not limited to, the implied warranties of merchantability, fitness for a particular purpose, and non-infringement. The Licensor disclaims all warranties, express or implied, regarding the accuracy, reliability, completeness, or usefulness of the Licensed Software or its components. The Licensor does not warrant that the Licensed Software will be error-free, uninterrupted, or free from defects or vulnerabilities.

6. Limitation of Liability

In no event shall the Licensor be liable for any claim, damages, or other liability, whether in an action of contract, tort, or otherwise, arising from, out of, or in connection with the Licensed Software or the use or other dealings in the Licensed Software. The Licensor shall not be liable for any direct, indirect, incidental, special, consequential, or exemplary damages, including but not limited to damages for loss of profits, goodwill, use, data, or other intangible losses, even if the Licensor has been advised of the possibility of such damages.

7. Indemnification

The Licensee agrees to indemnify, defend, and hold harmless the Licensor and its affiliates, officers, directors, employees, and agents from and against any and all claims, liabilities, damages, losses, costs, expenses, or fees (including reasonable attorneys' fees) arising from or relating to the Licensee's use of the Licensed Software or any violation of this Agreement by the Licensee.

8. Governing Law and Jurisdiction

This Agreement shall be governed by and construed in accordance with the laws of the State of New South Wales, Australia, without giving effect to any choice or conflict of law provision or rule. Any legal action or proceeding arising out of or relating to this Agreement shall be brought exclusively in the courts of New South Wales, Australia, and each party irrevocably submits to the jurisdiction of such courts.

9. Entire Agreement

This Agreement constitutes the entire agreement between the parties concerning the subject matter hereof and supersedes all prior or contemporaneous communications, understandings, and agreements, whether written or oral, between the parties regarding such subject matter.

By accessing, using, or distributing the Licensed Software, the Licensee agrees to be bound by the terms and conditions of this Agreement. If the Licensee does not agree to the terms of this Agreement, the Licensee must not access, use, or distribute the Licensed Software.

For inquiries or to obtain permission for Commercial Use, please contact:

Dr. Jinjin Gu
SupPixel Pty Ltd
Email: [email protected]
README.md
CHANGED
@@ -1,13 +1,5 @@
-(removed lines 1-5: Space front-matter fields not captured in this view)
-sdk: gradio
-sdk_version: 4.44.0
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

+# SUPIR Backend
+
+The backend code for [forge-space-SUPIR](https://github.com/Haoming02/forge-space-SUPIR), that gets synced to [Haoming02/SUPIR-Forge](https://huggingface.co/spaces/Haoming02/SUPIR-Forge)
+
+> There is no point in downloading this, unless you are making a PR
SUPIR/__init__.py
ADDED
File without changes
SUPIR/models/SUPIR_model.py
ADDED
@@ -0,0 +1,176 @@
import torch
from sgm.models.diffusion import DiffusionEngine
from sgm.util import instantiate_from_config
import copy
from sgm.modules.distributions.distributions import DiagonalGaussianDistribution
import random
from SUPIR.utils.colorfix import wavelet_reconstruction, adaptive_instance_normalization
from pytorch_lightning import seed_everything
from torch.nn.functional import interpolate
from SUPIR.utils.tilevae import VAEHook

class SUPIRModel(DiffusionEngine):
    def __init__(self, control_stage_config, ae_dtype='fp32', diffusion_dtype='fp32', p_p='', n_p='', *args, **kwargs):
        super().__init__(*args, **kwargs)
        control_model = instantiate_from_config(control_stage_config)
        self.model.load_control_model(control_model)
        self.first_stage_model.denoise_encoder = copy.deepcopy(self.first_stage_model.encoder)
        self.sampler_config = kwargs['sampler_config']

        assert (ae_dtype in ['fp32', 'fp16', 'bf16']) and (diffusion_dtype in ['fp32', 'fp16', 'bf16'])
        if ae_dtype == 'fp32':
            ae_dtype = torch.float32
        elif ae_dtype == 'fp16':
            raise RuntimeError('fp16 cause NaN in AE')
        elif ae_dtype == 'bf16':
            ae_dtype = torch.bfloat16

        if diffusion_dtype == 'fp32':
            diffusion_dtype = torch.float32
        elif diffusion_dtype == 'fp16':
            diffusion_dtype = torch.float16
        elif diffusion_dtype == 'bf16':
            diffusion_dtype = torch.bfloat16

        self.ae_dtype = ae_dtype
        self.model.dtype = diffusion_dtype

        self.p_p = p_p
        self.n_p = n_p

    @torch.no_grad()
    def encode_first_stage(self, x):
        with torch.autocast("cuda", dtype=self.ae_dtype):
            z = self.first_stage_model.encode(x)
        z = self.scale_factor * z
        return z

    @torch.no_grad()
    def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
        with torch.autocast("cuda", dtype=self.ae_dtype):
            h = self.first_stage_model.denoise_encoder(x)
            moments = self.first_stage_model.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            if use_sample:
                z = posterior.sample()
            else:
                z = posterior.mode()
        z = self.scale_factor * z
        return z

    @torch.no_grad()
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        with torch.autocast("cuda", dtype=self.ae_dtype):
            out = self.first_stage_model.decode(z)
        return out.float()

    @torch.no_grad()
    def batchify_denoise(self, x, is_stage1=False):
        '''
        [N, C, H, W], [-1, 1], RGB
        '''
        x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
        return self.decode_first_stage(x)

    @torch.no_grad()
    def batchify_sample(self, x, p, p_p='default', n_p='default', num_steps=100, restoration_scale=4.0, s_churn=0, s_noise=1.003, cfg_scale=4.0, seed=-1,
                        num_samples=1, control_scale=1, color_fix_type='None', use_linear_CFG=False, use_linear_control_scale=False,
                        cfg_scale_start=1.0, control_scale_start=0.0, **kwargs):
        '''
        [N, C], [-1, 1], RGB
        '''
        assert len(x) == len(p)
        assert color_fix_type in ['Wavelet', 'AdaIn', 'None']

        N = len(x)
        if num_samples > 1:
            assert N == 1
            N = num_samples
            x = x.repeat(N, 1, 1, 1)
            p = p * N

        if p_p == 'default':
            p_p = self.p_p
        if n_p == 'default':
            n_p = self.n_p

        self.sampler_config.params.num_steps = num_steps
        if use_linear_CFG:
            self.sampler_config.params.guider_config.params.scale_min = cfg_scale
            self.sampler_config.params.guider_config.params.scale = cfg_scale_start
        else:
            self.sampler_config.params.guider_config.params.scale_min = cfg_scale
            self.sampler_config.params.guider_config.params.scale = cfg_scale
        self.sampler_config.params.restore_cfg = restoration_scale
        self.sampler_config.params.s_churn = s_churn
        self.sampler_config.params.s_noise = s_noise
        self.sampler = instantiate_from_config(self.sampler_config)

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        _z = self.encode_first_stage_with_denoise(x, use_sample=False)
        x_stage1 = self.decode_first_stage(_z)
        z_stage1 = self.encode_first_stage(x_stage1)

        c, uc = self.prepare_condition(_z, p, p_p, n_p, N)

        denoiser = lambda input, sigma, c, control_scale: self.denoiser(
            self.model, input, sigma, c, control_scale, **kwargs
        )

        noised_z = torch.randn_like(_z).to(_z.device)

        _samples = self.sampler(denoiser, noised_z, cond=c, uc=uc, x_center=z_stage1, control_scale=control_scale,
                                use_linear_control_scale=use_linear_control_scale, control_scale_start=control_scale_start)
        samples = self.decode_first_stage(_samples)
        if color_fix_type == 'Wavelet':
            samples = wavelet_reconstruction(samples, x_stage1)
        elif color_fix_type == 'AdaIn':
            samples = adaptive_instance_normalization(samples, x_stage1)
        return samples

    def init_tile_vae(self, encoder_tile_size=512, decoder_tile_size=64):
        self.first_stage_model.denoise_encoder.original_forward = self.first_stage_model.denoise_encoder.forward
        self.first_stage_model.encoder.original_forward = self.first_stage_model.encoder.forward
        self.first_stage_model.decoder.original_forward = self.first_stage_model.decoder.forward
        self.first_stage_model.denoise_encoder.forward = VAEHook(
            self.first_stage_model.denoise_encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
            fast_encoder=False, color_fix=False, to_gpu=True)
        self.first_stage_model.encoder.forward = VAEHook(
            self.first_stage_model.encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
            fast_encoder=False, color_fix=False, to_gpu=True)
        self.first_stage_model.decoder.forward = VAEHook(
            self.first_stage_model.decoder, decoder_tile_size, is_decoder=True, fast_decoder=False,
            fast_encoder=False, color_fix=False, to_gpu=True)

    def prepare_condition(self, _z, p, p_p, n_p, N):
        batch = {}
        batch['original_size_as_tuple'] = torch.tensor([1024, 1024]).repeat(N, 1).to(_z.device)
        batch['crop_coords_top_left'] = torch.tensor([0, 0]).repeat(N, 1).to(_z.device)
        batch['target_size_as_tuple'] = torch.tensor([1024, 1024]).repeat(N, 1).to(_z.device)
        batch['aesthetic_score'] = torch.tensor([9.0]).repeat(N, 1).to(_z.device)
        batch['control'] = _z

        batch_uc = copy.deepcopy(batch)
        batch_uc['txt'] = [n_p for _ in p]

        if not isinstance(p[0], list):
            batch['txt'] = [''.join([_p, p_p]) for _p in p]
            with torch.cuda.amp.autocast(dtype=self.ae_dtype):
                c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
        else:
            assert len(p) == 1, 'Support bs=1 only for local prompt conditioning.'
            p_tiles = p[0]
            c = []
            for i, p_tile in enumerate(p_tiles):
                batch['txt'] = [''.join([p_tile, p_p])]
                with torch.cuda.amp.autocast(dtype=self.ae_dtype):
                    if i == 0:
                        _c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
                    else:
                        _c, _ = self.conditioner.get_unconditional_conditioning(batch, None)
                c.append(_c)
        return c, uc
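For orientation, here is a minimal sketch of how this class is typically driven end to end. It relies only on helpers that appear elsewhere in this change set (create_SUPIR_model, PIL2Tensor, and Tensor2PIL from SUPIR/util.py, plus options/SUPIR_v0.yaml); the checkpoint paths, input image, and prompt are placeholders, not part of the diff.

# Usage sketch (not part of the diff); checkpoint/image paths and the prompt are placeholders.
from PIL import Image
from SUPIR.util import create_SUPIR_model, PIL2Tensor, Tensor2PIL

# create_SUPIR_model returns (model, sdxl_state_dict); the SDXL and SUPIR weights are assumed
# to have been downloaded separately.
model, _ = create_SUPIR_model('options/SUPIR_v0.yaml',
                              SDXL_CKPT='path/to/sd_xl_base_1.0.safetensors',  # placeholder
                              SUPIR_CKPT='path/to/SUPIR-v0Q.ckpt')             # placeholder
model = model.cuda()
model.init_tile_vae(encoder_tile_size=512, decoder_tile_size=64)  # optional tiled VAE to save VRAM

img = Image.open('input.png').convert('RGB')                       # placeholder input
x, h0, w0 = PIL2Tensor(img, upsacle=2, min_size=1024)              # [-1, 1] tensor plus target size
x = x.unsqueeze(0).cuda()[:, :3, :, :]

# One prompt per batch element; the sampler config comes from the YAML above.
samples = model.batchify_sample(x, p=['a high quality photo'], num_steps=50,
                                cfg_scale=4.0, color_fix_type='Wavelet', seed=-1)
Tensor2PIL(samples[0], h0, w0).save('output.png')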
SUPIR/models/__init__.py
ADDED
File without changes
SUPIR/modules/SUPIR_v0.py
ADDED
@@ -0,0 +1,718 @@
# from einops._torch_specific import allow_ops_in_compiled_graph
# allow_ops_in_compiled_graph()
import einops
import torch
import torch as th
import torch.nn as nn
from einops import rearrange, repeat

from sgm.modules.diffusionmodules.util import (
    avg_pool_nd,
    checkpoint,
    conv_nd,
    linear,
    normalization,
    timestep_embedding,
    zero_module,
)

from sgm.modules.diffusionmodules.openaimodel import Downsample, Upsample, UNetModel, Timestep, \
    TimestepEmbedSequential, ResBlock, AttentionBlock, TimestepBlock
from sgm.modules.attention import SpatialTransformer, MemoryEfficientCrossAttention, CrossAttention
from sgm.util import default, log_txt_as_img, exists, instantiate_from_config
import re
import torch
from functools import partial


try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except:
    XFORMERS_IS_AVAILBLE = False


# dummy replace
def convert_module_to_f16(x):
    pass


def convert_module_to_f32(x):
    pass


class ZeroConv(nn.Module):
    def __init__(self, label_nc, norm_nc, mask=False):
        super().__init__()
        self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
        self.mask = mask

    def forward(self, c, h, h_ori=None):
        # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
        if not self.mask:
            h = h + self.zero_conv(c)
        else:
            h = h + self.zero_conv(c) * torch.zeros_like(h)
        if h_ori is not None:
            h = th.cat([h_ori, h], dim=1)
        return h


class ZeroSFT(nn.Module):
    def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False):
        super().__init__()

        # param_free_norm_type = str(parsed.group(1))
        ks = 3
        pw = ks // 2

        self.norm = norm
        if self.norm:
            self.param_free_norm = normalization(norm_nc + concat_channels)
        else:
            self.param_free_norm = nn.Identity()

        nhidden = 128

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.SiLU()
        )
        self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
        self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
        # self.zero_mul = nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)
        # self.zero_add = nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)

        self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
        self.pre_concat = bool(concat_channels != 0)
        self.mask = mask

    def forward(self, c, h, h_ori=None, control_scale=1):
        assert self.mask is False
        if h_ori is not None and self.pre_concat:
            h_raw = th.cat([h_ori, h], dim=1)
        else:
            h_raw = h

        if self.mask:
            h = h + self.zero_conv(c) * torch.zeros_like(h)
        else:
            h = h + self.zero_conv(c)
        if h_ori is not None and self.pre_concat:
            h = th.cat([h_ori, h], dim=1)
        actv = self.mlp_shared(c)
        gamma = self.zero_mul(actv)
        beta = self.zero_add(actv)
        if self.mask:
            gamma = gamma * torch.zeros_like(gamma)
            beta = beta * torch.zeros_like(beta)
        h = self.param_free_norm(h) * (gamma + 1) + beta
        if h_ori is not None and not self.pre_concat:
            h = th.cat([h_ori, h], dim=1)
        return h * control_scale + h_raw * (1 - control_scale)


class ZeroCrossAttn(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention
    }

    def __init__(self, context_dim, query_dim, zero_out=True, mask=False):
        super().__init__()
        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.attn = attn_cls(query_dim=query_dim, context_dim=context_dim, heads=query_dim//64, dim_head=64)
        self.norm1 = normalization(query_dim)
        self.norm2 = normalization(context_dim)

        self.mask = mask

        # if zero_out:
        #     # for p in self.attn.to_out.parameters():
        #     #     p.detach().zero_()
        #     self.attn.to_out = zero_module(self.attn.to_out)

    def forward(self, context, x, control_scale=1):
        assert self.mask is False
        x_in = x
        x = self.norm1(x)
        context = self.norm2(context)
        b, c, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        context = rearrange(context, 'b c h w -> b (h w) c').contiguous()
        x = self.attn(x, context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if self.mask:
            x = x * torch.zeros_like(x)
        x = x_in + x * control_scale

        return x


class GLVControl(nn.Module):
    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        spatial_transformer_attn_type="softmax",
        adm_in_channels=None,
        use_fairscale_checkpoint=False,
        offload_to_cpu=False,
        transformer_depth_middle=None,
        input_upscale=1,
    ):
        super().__init__()
        from omegaconf.listconfig import ListConfig

        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        elif isinstance(transformer_depth, ListConfig):
            transformer_depth = list(transformer_depth)
        transformer_depth_middle = default(
            transformer_depth_middle, transformer_depth[-1]
        )

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks
        # self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )  # todo: convert to warning

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        if use_fp16:
            print("WARNING: use_fp16 was dropped and has no effect anymore.")
        # self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        assert use_fairscale_checkpoint != use_checkpoint or not (
            use_checkpoint or use_fairscale_checkpoint
        )

        self.use_fairscale_checkpoint = False
        checkpoint_wrapper_fn = (
            partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
            if self.use_fairscale_checkpoint
            else lambda x: x
        )

        time_embed_dim = model_channels * 4
        self.time_embed = checkpoint_wrapper_fn(
            nn.Sequential(
                linear(model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "timestep":
                self.label_emb = checkpoint_wrapper_fn(
                    nn.Sequential(
                        Timestep(model_channels),
                        nn.Sequential(
                            linear(model_channels, time_embed_dim),
                            nn.SiLU(),
                            linear(time_embed_dim, time_embed_dim),
                        ),
                    )
                )
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    checkpoint_wrapper_fn(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=mult * model_channels,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                        )
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if (
                        not exists(num_attention_blocks)
                        or nr < num_attention_blocks[level]
                    ):
                        layers.append(
                            checkpoint_wrapper_fn(
                                AttentionBlock(
                                    ch,
                                    use_checkpoint=use_checkpoint,
                                    num_heads=num_heads,
                                    num_head_channels=dim_head,
                                    use_new_attention_order=use_new_attention_order,
                                )
                            )
                            if not use_spatial_transformer
                            else checkpoint_wrapper_fn(
                                SpatialTransformer(
                                    ch,
                                    num_heads,
                                    dim_head,
                                    depth=transformer_depth[level],
                                    context_dim=context_dim,
                                    disable_self_attn=disabled_sa,
                                    use_linear=use_linear_in_transformer,
                                    attn_type=spatial_transformer_attn_type,
                                    use_checkpoint=use_checkpoint,
                                )
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        checkpoint_wrapper_fn(
                            ResBlock(
                                ch,
                                time_embed_dim,
                                dropout,
                                out_channels=out_ch,
                                dims=dims,
                                use_checkpoint=use_checkpoint,
                                use_scale_shift_norm=use_scale_shift_norm,
                                down=True,
                            )
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            checkpoint_wrapper_fn(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ),
            checkpoint_wrapper_fn(
                AttentionBlock(
                    ch,
                    use_checkpoint=use_checkpoint,
                    num_heads=num_heads,
                    num_head_channels=dim_head,
                    use_new_attention_order=use_new_attention_order,
                )
            )
            if not use_spatial_transformer
            else checkpoint_wrapper_fn(
                SpatialTransformer(  # always uses a self-attn
                    ch,
                    num_heads,
                    dim_head,
                    depth=transformer_depth_middle,
                    context_dim=context_dim,
                    disable_self_attn=disable_middle_self_attn,
                    use_linear=use_linear_in_transformer,
                    attn_type=spatial_transformer_attn_type,
                    use_checkpoint=use_checkpoint,
                )
            ),
            checkpoint_wrapper_fn(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ),
        )

        self.input_upscale = input_upscale
        self.input_hint_block = TimestepEmbedSequential(
            zero_module(conv_nd(dims, in_channels, model_channels, 3, padding=1))
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
        # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
        #     x = x.to(torch.float32)
        #     timesteps = timesteps.to(torch.float32)
        #     xt = xt.to(torch.float32)
        #     context = context.to(torch.float32)
        #     y = y.to(torch.float32)
        # print(x.dtype)
        xt, context, y = xt.to(x.dtype), context.to(x.dtype), y.to(x.dtype)

        if self.input_upscale != 1:
            x = nn.functional.interpolate(x, scale_factor=self.input_upscale, mode='bilinear', antialias=True)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        # import pdb
        # pdb.set_trace()
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == xt.shape[0]
            emb = emb + self.label_emb(y)

        guided_hint = self.input_hint_block(x, emb, context)

        # h = x.type(self.dtype)
        h = xt
        for module in self.input_blocks:
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            hs.append(h)
            # print(module)
            # print(h.shape)
        h = self.middle_block(h, emb, context)
        hs.append(h)
        return hs


class LightGLVUNet(UNetModel):
    def __init__(self, mode='', project_type='ZeroSFT', project_channel_scale=1,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        if mode == 'XL-base':
            cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
            project_channels = [160] * 4 + [320] * 3 + [640] * 3
            concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
            cross_attn_insert_idx = [6, 3]
            self.progressive_mask_nums = [0, 3, 7, 11]
        elif mode == 'XL-refine':
            cond_output_channels = [384] * 4 + [768] * 3 + [1536] * 6
            project_channels = [192] * 4 + [384] * 3 + [768] * 6
            concat_channels = [384] * 2 + [768] * 3 + [1536] * 7 + [0]
            cross_attn_insert_idx = [9, 6, 3]
            self.progressive_mask_nums = [0, 3, 6, 10, 14]
        else:
            raise NotImplementedError

        project_channels = [int(c * project_channel_scale) for c in project_channels]

        self.project_modules = nn.ModuleList()
        for i in range(len(cond_output_channels)):
            # if i == len(cond_output_channels) - 1:
            #     _project_type = 'ZeroCrossAttn'
            # else:
            #     _project_type = project_type
            _project_type = project_type
            if _project_type == 'ZeroSFT':
                self.project_modules.append(ZeroSFT(project_channels[i], cond_output_channels[i],
                                                    concat_channels=concat_channels[i]))
            elif _project_type == 'ZeroCrossAttn':
                self.project_modules.append(ZeroCrossAttn(cond_output_channels[i], project_channels[i]))
            else:
                raise NotImplementedError

        for i in cross_attn_insert_idx:
            self.project_modules.insert(i, ZeroCrossAttn(cond_output_channels[i], concat_channels[i]))
            # print(self.project_modules[i])

    def step_progressive_mask(self):
        if len(self.progressive_mask_nums) > 0:
            mask_num = self.progressive_mask_nums.pop()
            for i in range(len(self.project_modules)):
                if i < mask_num:
                    self.project_modules[i].mask = True
                else:
                    self.project_modules[i].mask = False
            return
            # print(f'step_progressive_mask, current masked layers: {mask_num}')
        else:
            return
            # print('step_progressive_mask, no more masked layers')
        # for i in range(len(self.project_modules)):
        #     print(self.project_modules[i].mask)


    def forward(self, x, timesteps=None, context=None, y=None, control=None, control_scale=1, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []

        _dtype = control[0].dtype
        x, context, y = x.to(_dtype), context.to(_dtype), y.to(_dtype)

        with torch.no_grad():
            t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
            emb = self.time_embed(t_emb)

            if self.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + self.label_emb(y)

            # h = x.type(self.dtype)
            h = x
            for module in self.input_blocks:
                h = module(h, emb, context)
                hs.append(h)

        adapter_idx = len(self.project_modules) - 1
        control_idx = len(control) - 1
        h = self.middle_block(h, emb, context)
        h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
        adapter_idx -= 1
        control_idx -= 1

        for i, module in enumerate(self.output_blocks):
            _h = hs.pop()
            h = self.project_modules[adapter_idx](control[control_idx], _h, h, control_scale=control_scale)
            adapter_idx -= 1
            # h = th.cat([h, _h], dim=1)
            if len(module) == 3:
                assert isinstance(module[2], Upsample)
                for layer in module[:2]:
                    if isinstance(layer, TimestepBlock):
                        h = layer(h, emb)
                    elif isinstance(layer, SpatialTransformer):
                        h = layer(h, context)
                    else:
                        h = layer(h)
                # print('cross_attn_here')
                h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
                adapter_idx -= 1
                h = module[2](h)
            else:
                h = module(h, emb, context)
            control_idx -= 1
            # print(module)
            # print(h.shape)

        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            assert False, "not supported anymore. what the f*** are you doing?"
        else:
            return self.out(h)

if __name__ == '__main__':
    from omegaconf import OmegaConf

    # refiner
    # opt = OmegaConf.load('../../options/train/debug_p2_xl.yaml')
    #
    # model = instantiate_from_config(opt.model.params.control_stage_config)
    # hint = model(torch.randn([1, 4, 64, 64]), torch.randn([1]), torch.randn([1, 4, 64, 64]))
    # hint = [h.cuda() for h in hint]
    # print(sum(map(lambda hint: hint.numel(), model.parameters())))
    #
    # unet = instantiate_from_config(opt.model.params.network_config)
    # unet = unet.cuda()
    #
    # _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 1280]).cuda(),
    #                torch.randn([1, 2560]).cuda(), hint)
    # print(sum(map(lambda _output: _output.numel(), unet.parameters())))

    # base
    with torch.no_grad():
        opt = OmegaConf.load('../../options/dev/SUPIR_tmp.yaml')

        model = instantiate_from_config(opt.model.params.control_stage_config)
        model = model.cuda()

        hint = model(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1, 77, 2048]).cuda(),
                     torch.randn([1, 2816]).cuda())

        # for h in hint:
        #     print(h.shape)

        unet = instantiate_from_config(opt.model.params.network_config)
        unet = unet.cuda()
        _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 2048]).cuda(),
                       torch.randn([1, 2816]).cuda(), hint)


    # model = instantiate_from_config(opt.model.params.control_stage_config)
    # model = model.cuda()
    # # hint = model(torch.randn([1, 4, 64, 64]), torch.randn([1]), torch.randn([1, 4, 64, 64]))
    # hint = model(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1, 77, 1280]).cuda(),
    #              torch.randn([1, 2560]).cuda())
    # # hint = [h.cuda() for h in hint]
    #
    # for h in hint:
    #     print(h.shape)
    #
    # unet = instantiate_from_config(opt.model.params.network_config)
    # unet = unet.cuda()
    # _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 1280]).cuda(),
    #                torch.randn([1, 2560]).cuda(), hint)
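A note on the ZeroSFT/ZeroCrossAttn projections above: because zero_mul, zero_add, and zero_conv are built with zero_module, the control tensor c contributes nothing at the start of training (gamma, beta, and the 1x1 injection are all zero), so the projected output initially depends only on the trunk features. The toy sketch below illustrates just that gamma/beta part of the modulation; it is a standalone illustration with made-up shapes, not part of the diff.

# Toy illustration of the zero-initialized SFT modulation (not part of the diff).
import torch
import torch.nn as nn

def zero_module(m):
    # Same idea as sgm's zero_module: start the layer as an exact zero map.
    for p in m.parameters():
        nn.init.zeros_(p)
    return m

norm = nn.GroupNorm(8, 64)
zero_mul = zero_module(nn.Conv2d(16, 64, 3, padding=1))
zero_add = zero_module(nn.Conv2d(16, 64, 3, padding=1))

h = torch.randn(1, 64, 32, 32)   # trunk (UNet) features
c = torch.randn(1, 16, 32, 32)   # control features
gamma, beta = zero_mul(c), zero_add(c)   # both zero at initialization
out = norm(h) * (gamma + 1) + beta       # equals norm(h): the control branch starts as a no-op
print(torch.allclose(out, norm(h)))      # True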
SUPIR/modules/__init__.py
ADDED
@@ -0,0 +1,11 @@
SDXL_BASE_CHANNEL_DICT = {
    'cond_output_channels': [320] * 4 + [640] * 3 + [1280] * 3,
    'project_channels': [160] * 4 + [320] * 3 + [640] * 3,
    'concat_channels': [320] * 2 + [640] * 3 + [1280] * 4 + [0]
}

SDXL_REFINE_CHANNEL_DICT = {
    'cond_output_channels': [384] * 4 + [768] * 3 + [1536] * 6,
    'project_channels': [192] * 4 + [384] * 3 + [768] * 6,
    'concat_channels': [384] * 2 + [768] * 3 + [1536] * 7 + [0]
}
SUPIR/util.py
ADDED
@@ -0,0 +1,190 @@
import os
import torch
import numpy as np
import cv2
from PIL import Image
from torch.nn.functional import interpolate
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config


def get_state_dict(d):
    return d.get("state_dict", d)


def load_state_dict(ckpt_path, location="cpu"):
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        import safetensors.torch

        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = get_state_dict(
            torch.load(ckpt_path, map_location=torch.device(location))
        )
    state_dict = get_state_dict(state_dict)
    print(f'Loaded state_dict from:\n"{ckpt_path}"')
    return state_dict


def create_model(config_path):
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    print(f'Loaded model config from:\n"{config_path}"')
    return model


def create_SUPIR_model(config_path, SDXL_CKPT, SUPIR_CKPT):
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    print(f'Loaded model config from:\n"{config_path}"')

    sdxl_sd = load_state_dict(SDXL_CKPT)
    model.load_state_dict(sdxl_sd, strict=False)
    model.load_state_dict(load_state_dict(SUPIR_CKPT), strict=False)

    return model, sdxl_sd


def load_QF_ckpt(config_path):
    config = OmegaConf.load(config_path)
    ckpt_F = torch.load(config.SUPIR_CKPT_F, map_location="cpu")
    ckpt_Q = torch.load(config.SUPIR_CKPT_Q, map_location="cpu")
    return ckpt_Q, ckpt_F


def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
    """
    PIL.Image -> Tensor[C, H, W], RGB, [-1, 1]
    """
    # size
    w, h = img.size
    w *= upsacle
    h *= upsacle
    w0, h0 = round(w), round(h)
    if min(w, h) < min_size:
        _upsacle = min_size / min(w, h)
        w *= _upsacle
        h *= _upsacle
    if fix_resize is not None:
        _upsacle = fix_resize / min(w, h)
        w *= _upsacle
        h *= _upsacle
        w0, h0 = round(w), round(h)
    w = int(np.round(w / 64.0)) * 64
    h = int(np.round(h / 64.0)) * 64
    x = img.resize((w, h), Image.BICUBIC)
    x = np.array(x).round().clip(0, 255).astype(np.uint8)
    x = x / 255 * 2 - 1
    x = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1)
    return x, h0, w0


def Tensor2PIL(x, h0, w0):
    """
    Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
    """
    x = x.unsqueeze(0)
    x = interpolate(x, size=(h0, w0), mode="bicubic")
    x = (
        (x.squeeze(0).permute(1, 2, 0) * 127.5 + 127.5)
        .cpu()
        .numpy()
        .clip(0, 255)
        .astype(np.uint8)
    )
    return Image.fromarray(x)


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def upscale_image(input_image, upscale, min_size=None, unit_resolution=64):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    H *= upscale
    W *= upscale
    if min_size is not None:
        if min(H, W) < min_size:
            _upsacle = min_size / min(W, H)
            W *= _upsacle
            H *= _upsacle
    H = int(np.round(H / unit_resolution)) * unit_resolution
    W = int(np.round(W / unit_resolution)) * unit_resolution
    img = cv2.resize(
        input_image,
        (W, H),
        interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA,
|
135 |
+
)
|
136 |
+
img = img.round().clip(0, 255).astype(np.uint8)
|
137 |
+
return img
|
138 |
+
|
139 |
+
|
140 |
+
def fix_resize(input_image, size=512, unit_resolution=64):
|
141 |
+
H, W, C = input_image.shape
|
142 |
+
H = float(H)
|
143 |
+
W = float(W)
|
144 |
+
upscale = size / min(H, W)
|
145 |
+
H *= upscale
|
146 |
+
W *= upscale
|
147 |
+
H = int(np.round(H / unit_resolution)) * unit_resolution
|
148 |
+
W = int(np.round(W / unit_resolution)) * unit_resolution
|
149 |
+
img = cv2.resize(
|
150 |
+
input_image,
|
151 |
+
(W, H),
|
152 |
+
interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA,
|
153 |
+
)
|
154 |
+
img = img.round().clip(0, 255).astype(np.uint8)
|
155 |
+
return img
|
156 |
+
|
157 |
+
|
158 |
+
def Numpy2Tensor(img):
|
159 |
+
"""
|
160 |
+
np.array[H, w, C] [0, 255] -> Tensor[C, H, W], RGB, [-1, 1]
|
161 |
+
"""
|
162 |
+
# size
|
163 |
+
img = np.array(img) / 255 * 2 - 1
|
164 |
+
img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
|
165 |
+
return img
|
166 |
+
|
167 |
+
|
168 |
+
def Tensor2Numpy(x, h0=None, w0=None):
|
169 |
+
"""
|
170 |
+
Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
|
171 |
+
"""
|
172 |
+
if h0 is not None and w0 is not None:
|
173 |
+
x = x.unsqueeze(0)
|
174 |
+
x = interpolate(x, size=(h0, w0), mode="bicubic")
|
175 |
+
x = x.squeeze(0)
|
176 |
+
x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
177 |
+
return x
|
178 |
+
|
179 |
+
|
180 |
+
def convert_dtype(dtype_str):
|
181 |
+
if dtype_str == "fp32":
|
182 |
+
return torch.float32
|
183 |
+
elif dtype_str == "fp16":
|
184 |
+
return torch.float16
|
185 |
+
elif dtype_str == "bf16":
|
186 |
+
return torch.bfloat16
|
187 |
+
elif dtype_str == "fp8":
|
188 |
+
return torch.float8_e4m3fn
|
189 |
+
else:
|
190 |
+
raise NotImplementedError
|
SUPIR/utils/__init__.py
ADDED
File without changes
|
SUPIR/utils/colorfix.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
# --------------------------------------------------------------------------------
|
3 |
+
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
|
4 |
+
# --------------------------------------------------------------------------------
|
5 |
+
'''
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from torch import Tensor
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
from torchvision.transforms import ToTensor, ToPILImage
|
13 |
+
|
14 |
+
def adain_color_fix(target: Image, source: Image):
|
15 |
+
# Convert images to tensors
|
16 |
+
to_tensor = ToTensor()
|
17 |
+
target_tensor = to_tensor(target).unsqueeze(0)
|
18 |
+
source_tensor = to_tensor(source).unsqueeze(0)
|
19 |
+
|
20 |
+
# Apply adaptive instance normalization
|
21 |
+
result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
|
22 |
+
|
23 |
+
# Convert tensor back to image
|
24 |
+
to_image = ToPILImage()
|
25 |
+
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
26 |
+
|
27 |
+
return result_image
|
28 |
+
|
29 |
+
def wavelet_color_fix(target: Image, source: Image):
|
30 |
+
# Convert images to tensors
|
31 |
+
to_tensor = ToTensor()
|
32 |
+
target_tensor = to_tensor(target).unsqueeze(0)
|
33 |
+
source_tensor = to_tensor(source).unsqueeze(0)
|
34 |
+
|
35 |
+
# Apply wavelet reconstruction
|
36 |
+
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
|
37 |
+
|
38 |
+
# Convert tensor back to image
|
39 |
+
to_image = ToPILImage()
|
40 |
+
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
41 |
+
|
42 |
+
return result_image
|
43 |
+
|
44 |
+
def calc_mean_std(feat: Tensor, eps=1e-5):
|
45 |
+
"""Calculate mean and std for adaptive_instance_normalization.
|
46 |
+
Args:
|
47 |
+
feat (Tensor): 4D tensor.
|
48 |
+
eps (float): A small value added to the variance to avoid
|
49 |
+
divide-by-zero. Default: 1e-5.
|
50 |
+
"""
|
51 |
+
size = feat.size()
|
52 |
+
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
53 |
+
b, c = size[:2]
|
54 |
+
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
|
55 |
+
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
|
56 |
+
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
|
57 |
+
return feat_mean, feat_std
|
58 |
+
|
59 |
+
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
|
60 |
+
"""Adaptive instance normalization.
|
61 |
+
Adjust the reference features to have the similar color and illuminations
|
62 |
+
as those in the degradate features.
|
63 |
+
Args:
|
64 |
+
content_feat (Tensor): The reference feature.
|
65 |
+
style_feat (Tensor): The degradate features.
|
66 |
+
"""
|
67 |
+
size = content_feat.size()
|
68 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
69 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
70 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
71 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
72 |
+
|
73 |
+
def wavelet_blur(image: Tensor, radius: int):
|
74 |
+
"""
|
75 |
+
Apply wavelet blur to the input tensor.
|
76 |
+
"""
|
77 |
+
# input shape: (1, 3, H, W)
|
78 |
+
# convolution kernel
|
79 |
+
kernel_vals = [
|
80 |
+
[0.0625, 0.125, 0.0625],
|
81 |
+
[0.125, 0.25, 0.125],
|
82 |
+
[0.0625, 0.125, 0.0625],
|
83 |
+
]
|
84 |
+
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
85 |
+
# add channel dimensions to the kernel to make it a 4D tensor
|
86 |
+
kernel = kernel[None, None]
|
87 |
+
# repeat the kernel across all input channels
|
88 |
+
kernel = kernel.repeat(3, 1, 1, 1)
|
89 |
+
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
90 |
+
# apply convolution
|
91 |
+
output = F.conv2d(image, kernel, groups=3, dilation=radius)
|
92 |
+
return output
|
93 |
+
|
94 |
+
def wavelet_decomposition(image: Tensor, levels=5):
|
95 |
+
"""
|
96 |
+
Apply wavelet decomposition to the input tensor.
|
97 |
+
This function only returns the low frequency & the high frequency.
|
98 |
+
"""
|
99 |
+
high_freq = torch.zeros_like(image)
|
100 |
+
for i in range(levels):
|
101 |
+
radius = 2 ** i
|
102 |
+
low_freq = wavelet_blur(image, radius)
|
103 |
+
high_freq += (image - low_freq)
|
104 |
+
image = low_freq
|
105 |
+
|
106 |
+
return high_freq, low_freq
|
107 |
+
|
108 |
+
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
|
109 |
+
"""
|
110 |
+
Apply wavelet decomposition, so that the content will have the same color as the style.
|
111 |
+
"""
|
112 |
+
# calculate the wavelet decomposition of the content feature
|
113 |
+
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
114 |
+
del content_low_freq
|
115 |
+
# calculate the wavelet decomposition of the style feature
|
116 |
+
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
117 |
+
del style_high_freq
|
118 |
+
# reconstruct the content feature with the style's high frequency
|
119 |
+
return content_high_freq + style_low_freq
|
120 |
+
|
SUPIR/utils/devices.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import contextlib
|
3 |
+
from functools import lru_cache
|
4 |
+
|
5 |
+
import torch
|
6 |
+
#from modules import errors
|
7 |
+
|
8 |
+
if sys.platform == "darwin":
|
9 |
+
from modules import mac_specific
|
10 |
+
|
11 |
+
|
12 |
+
def has_mps() -> bool:
|
13 |
+
if sys.platform != "darwin":
|
14 |
+
return False
|
15 |
+
else:
|
16 |
+
return mac_specific.has_mps
|
17 |
+
|
18 |
+
|
19 |
+
def get_cuda_device_string():
|
20 |
+
return "cuda"
|
21 |
+
|
22 |
+
|
23 |
+
def get_optimal_device_name():
|
24 |
+
if torch.cuda.is_available():
|
25 |
+
return get_cuda_device_string()
|
26 |
+
|
27 |
+
if has_mps():
|
28 |
+
return "mps"
|
29 |
+
|
30 |
+
return "cpu"
|
31 |
+
|
32 |
+
|
33 |
+
def get_optimal_device():
|
34 |
+
return torch.device(get_optimal_device_name())
|
35 |
+
|
36 |
+
|
37 |
+
def get_device_for(task):
|
38 |
+
return get_optimal_device()
|
39 |
+
|
40 |
+
|
41 |
+
def torch_gc():
|
42 |
+
|
43 |
+
if torch.cuda.is_available():
|
44 |
+
with torch.cuda.device(get_cuda_device_string()):
|
45 |
+
torch.cuda.empty_cache()
|
46 |
+
torch.cuda.ipc_collect()
|
47 |
+
|
48 |
+
if has_mps():
|
49 |
+
mac_specific.torch_mps_gc()
|
50 |
+
|
51 |
+
|
52 |
+
def enable_tf32():
|
53 |
+
if torch.cuda.is_available():
|
54 |
+
|
55 |
+
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
56 |
+
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
57 |
+
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
|
58 |
+
torch.backends.cudnn.benchmark = True
|
59 |
+
|
60 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
61 |
+
torch.backends.cudnn.allow_tf32 = True
|
62 |
+
|
63 |
+
|
64 |
+
enable_tf32()
|
65 |
+
#errors.run(enable_tf32, "Enabling TF32")
|
66 |
+
|
67 |
+
cpu = torch.device("cpu")
|
68 |
+
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
|
69 |
+
dtype = torch.float16
|
70 |
+
dtype_vae = torch.float16
|
71 |
+
dtype_unet = torch.float16
|
72 |
+
unet_needs_upcast = False
|
73 |
+
|
74 |
+
|
75 |
+
def cond_cast_unet(input):
|
76 |
+
return input.to(dtype_unet) if unet_needs_upcast else input
|
77 |
+
|
78 |
+
|
79 |
+
def cond_cast_float(input):
|
80 |
+
return input.float() if unet_needs_upcast else input
|
81 |
+
|
82 |
+
|
83 |
+
def randn(seed, shape):
|
84 |
+
torch.manual_seed(seed)
|
85 |
+
return torch.randn(shape, device=device)
|
86 |
+
|
87 |
+
|
88 |
+
def randn_without_seed(shape):
|
89 |
+
return torch.randn(shape, device=device)
|
90 |
+
|
91 |
+
|
92 |
+
def autocast(disable=False):
|
93 |
+
if disable:
|
94 |
+
return contextlib.nullcontext()
|
95 |
+
|
96 |
+
return torch.autocast("cuda")
|
97 |
+
|
98 |
+
|
99 |
+
def without_autocast(disable=False):
|
100 |
+
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
101 |
+
|
102 |
+
|
103 |
+
class NansException(Exception):
|
104 |
+
pass
|
105 |
+
|
106 |
+
|
107 |
+
def test_for_nans(x, where):
|
108 |
+
if not torch.all(torch.isnan(x)).item():
|
109 |
+
return
|
110 |
+
|
111 |
+
if where == "unet":
|
112 |
+
message = "A tensor with all NaNs was produced in Unet."
|
113 |
+
|
114 |
+
elif where == "vae":
|
115 |
+
message = "A tensor with all NaNs was produced in VAE."
|
116 |
+
|
117 |
+
else:
|
118 |
+
message = "A tensor with all NaNs was produced."
|
119 |
+
|
120 |
+
message += " Use --disable-nan-check commandline argument to disable this check."
|
121 |
+
|
122 |
+
raise NansException(message)
|
123 |
+
|
124 |
+
|
125 |
+
@lru_cache
|
126 |
+
def first_time_calculation():
|
127 |
+
"""
|
128 |
+
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
|
129 |
+
spends about 2.7 seconds doing that, at least wih NVidia.
|
130 |
+
"""
|
131 |
+
|
132 |
+
x = torch.zeros((1, 1)).to(device, dtype)
|
133 |
+
linear = torch.nn.Linear(1, 1).to(device, dtype)
|
134 |
+
linear(x)
|
135 |
+
|
136 |
+
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
137 |
+
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
138 |
+
conv2d(x)
|
SUPIR/utils/face_restoration_helper.py
ADDED
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
from torchvision.transforms.functional import normalize
|
6 |
+
|
7 |
+
from facexlib.detection import init_detection_model
|
8 |
+
from facexlib.parsing import init_parsing_model
|
9 |
+
from facexlib.utils.misc import img2tensor, imwrite
|
10 |
+
|
11 |
+
from .file import load_file_from_url
|
12 |
+
|
13 |
+
|
14 |
+
def get_largest_face(det_faces, h, w):
|
15 |
+
def get_location(val, length):
|
16 |
+
if val < 0:
|
17 |
+
return 0
|
18 |
+
elif val > length:
|
19 |
+
return length
|
20 |
+
else:
|
21 |
+
return val
|
22 |
+
|
23 |
+
face_areas = []
|
24 |
+
for det_face in det_faces:
|
25 |
+
left = get_location(det_face[0], w)
|
26 |
+
right = get_location(det_face[2], w)
|
27 |
+
top = get_location(det_face[1], h)
|
28 |
+
bottom = get_location(det_face[3], h)
|
29 |
+
face_area = (right - left) * (bottom - top)
|
30 |
+
face_areas.append(face_area)
|
31 |
+
largest_idx = face_areas.index(max(face_areas))
|
32 |
+
return det_faces[largest_idx], largest_idx
|
33 |
+
|
34 |
+
|
35 |
+
def get_center_face(det_faces, h=0, w=0, center=None):
|
36 |
+
if center is not None:
|
37 |
+
center = np.array(center)
|
38 |
+
else:
|
39 |
+
center = np.array([w / 2, h / 2])
|
40 |
+
center_dist = []
|
41 |
+
for det_face in det_faces:
|
42 |
+
face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
|
43 |
+
dist = np.linalg.norm(face_center - center)
|
44 |
+
center_dist.append(dist)
|
45 |
+
center_idx = center_dist.index(min(center_dist))
|
46 |
+
return det_faces[center_idx], center_idx
|
47 |
+
|
48 |
+
|
49 |
+
class FaceRestoreHelper(object):
|
50 |
+
"""Helper for the face restoration pipeline (base class)."""
|
51 |
+
|
52 |
+
def __init__(self,
|
53 |
+
upscale_factor,
|
54 |
+
face_size=512,
|
55 |
+
crop_ratio=(1, 1),
|
56 |
+
det_model='retinaface_resnet50',
|
57 |
+
save_ext='png',
|
58 |
+
template_3points=False,
|
59 |
+
pad_blur=False,
|
60 |
+
use_parse=False,
|
61 |
+
device=None):
|
62 |
+
self.template_3points = template_3points # improve robustness
|
63 |
+
self.upscale_factor = int(upscale_factor)
|
64 |
+
# the cropped face ratio based on the square face
|
65 |
+
self.crop_ratio = crop_ratio # (h, w)
|
66 |
+
assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
|
67 |
+
self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
|
68 |
+
self.det_model = det_model
|
69 |
+
|
70 |
+
if self.det_model == 'dlib':
|
71 |
+
# standard 5 landmarks for FFHQ faces with 1024 x 1024
|
72 |
+
self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
|
73 |
+
[337.91089109, 488.38613861], [437.95049505, 493.51485149],
|
74 |
+
[513.58415842, 678.5049505]])
|
75 |
+
self.face_template = self.face_template / (1024 // face_size)
|
76 |
+
elif self.template_3points:
|
77 |
+
self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
|
78 |
+
else:
|
79 |
+
# standard 5 landmarks for FFHQ faces with 512 x 512
|
80 |
+
# facexlib
|
81 |
+
self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
|
82 |
+
[201.26117, 371.41043], [313.08905, 371.15118]])
|
83 |
+
|
84 |
+
# dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
|
85 |
+
# self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
|
86 |
+
# [198.22603, 372.82502], [313.91018, 372.75659]])
|
87 |
+
|
88 |
+
self.face_template = self.face_template * (face_size / 512.0)
|
89 |
+
if self.crop_ratio[0] > 1:
|
90 |
+
self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
|
91 |
+
if self.crop_ratio[1] > 1:
|
92 |
+
self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
|
93 |
+
self.save_ext = save_ext
|
94 |
+
self.pad_blur = pad_blur
|
95 |
+
if self.pad_blur is True:
|
96 |
+
self.template_3points = False
|
97 |
+
|
98 |
+
self.all_landmarks_5 = []
|
99 |
+
self.det_faces = []
|
100 |
+
self.affine_matrices = []
|
101 |
+
self.inverse_affine_matrices = []
|
102 |
+
self.cropped_faces = []
|
103 |
+
self.restored_faces = []
|
104 |
+
self.pad_input_imgs = []
|
105 |
+
|
106 |
+
if device is None:
|
107 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
108 |
+
# self.device = get_device()
|
109 |
+
else:
|
110 |
+
self.device = device
|
111 |
+
|
112 |
+
# init face detection model
|
113 |
+
self.face_detector = init_detection_model(det_model, half=False, device=self.device)
|
114 |
+
|
115 |
+
# init face parsing model
|
116 |
+
self.use_parse = use_parse
|
117 |
+
self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
|
118 |
+
|
119 |
+
def set_upscale_factor(self, upscale_factor):
|
120 |
+
self.upscale_factor = upscale_factor
|
121 |
+
|
122 |
+
def read_image(self, img):
|
123 |
+
"""img can be image path or cv2 loaded image."""
|
124 |
+
# self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
|
125 |
+
if isinstance(img, str):
|
126 |
+
img = cv2.imread(img)
|
127 |
+
|
128 |
+
if np.max(img) > 256: # 16-bit image
|
129 |
+
img = img / 65535 * 255
|
130 |
+
if len(img.shape) == 2: # gray image
|
131 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
132 |
+
elif img.shape[2] == 4: # BGRA image with alpha channel
|
133 |
+
img = img[:, :, 0:3]
|
134 |
+
|
135 |
+
self.input_img = img
|
136 |
+
# self.is_gray = is_gray(img, threshold=10)
|
137 |
+
# if self.is_gray:
|
138 |
+
# print('Grayscale input: True')
|
139 |
+
|
140 |
+
if min(self.input_img.shape[:2]) < 512:
|
141 |
+
f = 512.0 / min(self.input_img.shape[:2])
|
142 |
+
self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
|
143 |
+
|
144 |
+
def init_dlib(self, detection_path, landmark5_path):
|
145 |
+
"""Initialize the dlib detectors and predictors."""
|
146 |
+
try:
|
147 |
+
import dlib
|
148 |
+
except ImportError:
|
149 |
+
print('Please install dlib by running:' 'conda install -c conda-forge dlib')
|
150 |
+
detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
|
151 |
+
landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
|
152 |
+
face_detector = dlib.cnn_face_detection_model_v1(detection_path)
|
153 |
+
shape_predictor_5 = dlib.shape_predictor(landmark5_path)
|
154 |
+
return face_detector, shape_predictor_5
|
155 |
+
|
156 |
+
def get_face_landmarks_5_dlib(self,
|
157 |
+
only_keep_largest=False,
|
158 |
+
scale=1):
|
159 |
+
det_faces = self.face_detector(self.input_img, scale)
|
160 |
+
|
161 |
+
if len(det_faces) == 0:
|
162 |
+
print('No face detected. Try to increase upsample_num_times.')
|
163 |
+
return 0
|
164 |
+
else:
|
165 |
+
if only_keep_largest:
|
166 |
+
print('Detect several faces and only keep the largest.')
|
167 |
+
face_areas = []
|
168 |
+
for i in range(len(det_faces)):
|
169 |
+
face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
|
170 |
+
det_faces[i].rect.bottom() - det_faces[i].rect.top())
|
171 |
+
face_areas.append(face_area)
|
172 |
+
largest_idx = face_areas.index(max(face_areas))
|
173 |
+
self.det_faces = [det_faces[largest_idx]]
|
174 |
+
else:
|
175 |
+
self.det_faces = det_faces
|
176 |
+
|
177 |
+
if len(self.det_faces) == 0:
|
178 |
+
return 0
|
179 |
+
|
180 |
+
for face in self.det_faces:
|
181 |
+
shape = self.shape_predictor_5(self.input_img, face.rect)
|
182 |
+
landmark = np.array([[part.x, part.y] for part in shape.parts()])
|
183 |
+
self.all_landmarks_5.append(landmark)
|
184 |
+
|
185 |
+
return len(self.all_landmarks_5)
|
186 |
+
|
187 |
+
def get_face_landmarks_5(self,
|
188 |
+
only_keep_largest=False,
|
189 |
+
only_center_face=False,
|
190 |
+
resize=None,
|
191 |
+
blur_ratio=0.01,
|
192 |
+
eye_dist_threshold=None):
|
193 |
+
if self.det_model == 'dlib':
|
194 |
+
return self.get_face_landmarks_5_dlib(only_keep_largest)
|
195 |
+
|
196 |
+
if resize is None:
|
197 |
+
scale = 1
|
198 |
+
input_img = self.input_img
|
199 |
+
else:
|
200 |
+
h, w = self.input_img.shape[0:2]
|
201 |
+
scale = resize / min(h, w)
|
202 |
+
scale = max(1, scale) # always scale up
|
203 |
+
h, w = int(h * scale), int(w * scale)
|
204 |
+
interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
|
205 |
+
input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
|
206 |
+
|
207 |
+
with torch.no_grad():
|
208 |
+
bboxes = self.face_detector.detect_faces(input_img)
|
209 |
+
|
210 |
+
if bboxes is None or bboxes.shape[0] == 0:
|
211 |
+
return 0
|
212 |
+
else:
|
213 |
+
bboxes = bboxes / scale
|
214 |
+
|
215 |
+
for bbox in bboxes:
|
216 |
+
# remove faces with too small eye distance: side faces or too small faces
|
217 |
+
eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
|
218 |
+
if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
|
219 |
+
continue
|
220 |
+
|
221 |
+
if self.template_3points:
|
222 |
+
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
|
223 |
+
else:
|
224 |
+
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
|
225 |
+
self.all_landmarks_5.append(landmark)
|
226 |
+
self.det_faces.append(bbox[0:5])
|
227 |
+
|
228 |
+
if len(self.det_faces) == 0:
|
229 |
+
return 0
|
230 |
+
if only_keep_largest:
|
231 |
+
h, w, _ = self.input_img.shape
|
232 |
+
self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
|
233 |
+
self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
|
234 |
+
elif only_center_face:
|
235 |
+
h, w, _ = self.input_img.shape
|
236 |
+
self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
|
237 |
+
self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
|
238 |
+
|
239 |
+
# pad blurry images
|
240 |
+
if self.pad_blur:
|
241 |
+
self.pad_input_imgs = []
|
242 |
+
for landmarks in self.all_landmarks_5:
|
243 |
+
# get landmarks
|
244 |
+
eye_left = landmarks[0, :]
|
245 |
+
eye_right = landmarks[1, :]
|
246 |
+
eye_avg = (eye_left + eye_right) * 0.5
|
247 |
+
mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
|
248 |
+
eye_to_eye = eye_right - eye_left
|
249 |
+
eye_to_mouth = mouth_avg - eye_avg
|
250 |
+
|
251 |
+
# Get the oriented crop rectangle
|
252 |
+
# x: half width of the oriented crop rectangle
|
253 |
+
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
|
254 |
+
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
|
255 |
+
# norm with the hypotenuse: get the direction
|
256 |
+
x /= np.hypot(*x) # get the hypotenuse of a right triangle
|
257 |
+
rect_scale = 1.5
|
258 |
+
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
|
259 |
+
# y: half height of the oriented crop rectangle
|
260 |
+
y = np.flipud(x) * [-1, 1]
|
261 |
+
|
262 |
+
# c: center
|
263 |
+
c = eye_avg + eye_to_mouth * 0.1
|
264 |
+
# quad: (left_top, left_bottom, right_bottom, right_top)
|
265 |
+
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
|
266 |
+
# qsize: side length of the square
|
267 |
+
qsize = np.hypot(*x) * 2
|
268 |
+
border = max(int(np.rint(qsize * 0.1)), 3)
|
269 |
+
|
270 |
+
# get pad
|
271 |
+
# pad: (width_left, height_top, width_right, height_bottom)
|
272 |
+
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
273 |
+
int(np.ceil(max(quad[:, 1]))))
|
274 |
+
pad = [
|
275 |
+
max(-pad[0] + border, 1),
|
276 |
+
max(-pad[1] + border, 1),
|
277 |
+
max(pad[2] - self.input_img.shape[0] + border, 1),
|
278 |
+
max(pad[3] - self.input_img.shape[1] + border, 1)
|
279 |
+
]
|
280 |
+
|
281 |
+
if max(pad) > 1:
|
282 |
+
# pad image
|
283 |
+
pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
|
284 |
+
# modify landmark coords
|
285 |
+
landmarks[:, 0] += pad[0]
|
286 |
+
landmarks[:, 1] += pad[1]
|
287 |
+
# blur pad images
|
288 |
+
h, w, _ = pad_img.shape
|
289 |
+
y, x, _ = np.ogrid[:h, :w, :1]
|
290 |
+
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
|
291 |
+
np.float32(w - 1 - x) / pad[2]),
|
292 |
+
1.0 - np.minimum(np.float32(y) / pad[1],
|
293 |
+
np.float32(h - 1 - y) / pad[3]))
|
294 |
+
blur = int(qsize * blur_ratio)
|
295 |
+
if blur % 2 == 0:
|
296 |
+
blur += 1
|
297 |
+
blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
|
298 |
+
# blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
|
299 |
+
|
300 |
+
pad_img = pad_img.astype('float32')
|
301 |
+
pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
|
302 |
+
pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
|
303 |
+
pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
|
304 |
+
self.pad_input_imgs.append(pad_img)
|
305 |
+
else:
|
306 |
+
self.pad_input_imgs.append(np.copy(self.input_img))
|
307 |
+
|
308 |
+
return len(self.all_landmarks_5)
|
309 |
+
|
310 |
+
def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
|
311 |
+
"""Align and warp faces with face template.
|
312 |
+
"""
|
313 |
+
if self.pad_blur:
|
314 |
+
assert len(self.pad_input_imgs) == len(
|
315 |
+
self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
|
316 |
+
for idx, landmark in enumerate(self.all_landmarks_5):
|
317 |
+
# use 5 landmarks to get affine matrix
|
318 |
+
# use cv2.LMEDS method for the equivalence to skimage transform
|
319 |
+
# ref: https://blog.csdn.net/yichxi/article/details/115827338
|
320 |
+
affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
|
321 |
+
self.affine_matrices.append(affine_matrix)
|
322 |
+
# warp and crop faces
|
323 |
+
if border_mode == 'constant':
|
324 |
+
border_mode = cv2.BORDER_CONSTANT
|
325 |
+
elif border_mode == 'reflect101':
|
326 |
+
border_mode = cv2.BORDER_REFLECT101
|
327 |
+
elif border_mode == 'reflect':
|
328 |
+
border_mode = cv2.BORDER_REFLECT
|
329 |
+
if self.pad_blur:
|
330 |
+
input_img = self.pad_input_imgs[idx]
|
331 |
+
else:
|
332 |
+
input_img = self.input_img
|
333 |
+
cropped_face = cv2.warpAffine(
|
334 |
+
input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
|
335 |
+
self.cropped_faces.append(cropped_face)
|
336 |
+
# save the cropped face
|
337 |
+
if save_cropped_path is not None:
|
338 |
+
path = os.path.splitext(save_cropped_path)[0]
|
339 |
+
save_path = f'{path}_{idx:02d}.{self.save_ext}'
|
340 |
+
imwrite(cropped_face, save_path)
|
341 |
+
|
342 |
+
def get_inverse_affine(self, save_inverse_affine_path=None):
|
343 |
+
"""Get inverse affine matrix."""
|
344 |
+
for idx, affine_matrix in enumerate(self.affine_matrices):
|
345 |
+
inverse_affine = cv2.invertAffineTransform(affine_matrix)
|
346 |
+
inverse_affine *= self.upscale_factor
|
347 |
+
self.inverse_affine_matrices.append(inverse_affine)
|
348 |
+
# save inverse affine matrices
|
349 |
+
if save_inverse_affine_path is not None:
|
350 |
+
path, _ = os.path.splitext(save_inverse_affine_path)
|
351 |
+
save_path = f'{path}_{idx:02d}.pth'
|
352 |
+
torch.save(inverse_affine, save_path)
|
353 |
+
|
354 |
+
def add_restored_face(self, restored_face, input_face=None):
|
355 |
+
# if self.is_gray:
|
356 |
+
# restored_face = bgr2gray(restored_face) # convert img into grayscale
|
357 |
+
# if input_face is not None:
|
358 |
+
# restored_face = adain_npy(restored_face, input_face) # transfer the color
|
359 |
+
self.restored_faces.append(restored_face)
|
360 |
+
|
361 |
+
def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
|
362 |
+
h, w, _ = self.input_img.shape
|
363 |
+
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
|
364 |
+
|
365 |
+
if upsample_img is None:
|
366 |
+
# simply resize the background
|
367 |
+
# upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
|
368 |
+
upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
|
369 |
+
else:
|
370 |
+
upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
|
371 |
+
|
372 |
+
assert len(self.restored_faces) == len(
|
373 |
+
self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
|
374 |
+
|
375 |
+
inv_mask_borders = []
|
376 |
+
for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
|
377 |
+
if face_upsampler is not None:
|
378 |
+
restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
|
379 |
+
inverse_affine /= self.upscale_factor
|
380 |
+
inverse_affine[:, 2] *= self.upscale_factor
|
381 |
+
face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
|
382 |
+
else:
|
383 |
+
# Add an offset to inverse affine matrix, for more precise back alignment
|
384 |
+
if self.upscale_factor > 1:
|
385 |
+
extra_offset = 0.5 * self.upscale_factor
|
386 |
+
else:
|
387 |
+
extra_offset = 0
|
388 |
+
inverse_affine[:, 2] += extra_offset
|
389 |
+
face_size = self.face_size
|
390 |
+
inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
|
391 |
+
|
392 |
+
# if draw_box or not self.use_parse: # use square parse maps
|
393 |
+
# mask = np.ones(face_size, dtype=np.float32)
|
394 |
+
# inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
395 |
+
# # remove the black borders
|
396 |
+
# inv_mask_erosion = cv2.erode(
|
397 |
+
# inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
|
398 |
+
# pasted_face = inv_mask_erosion[:, :, None] * inv_restored
|
399 |
+
# total_face_area = np.sum(inv_mask_erosion) # // 3
|
400 |
+
# # add border
|
401 |
+
# if draw_box:
|
402 |
+
# h, w = face_size
|
403 |
+
# mask_border = np.ones((h, w, 3), dtype=np.float32)
|
404 |
+
# border = int(1400/np.sqrt(total_face_area))
|
405 |
+
# mask_border[border:h-border, border:w-border,:] = 0
|
406 |
+
# inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
|
407 |
+
# inv_mask_borders.append(inv_mask_border)
|
408 |
+
# if not self.use_parse:
|
409 |
+
# # compute the fusion edge based on the area of face
|
410 |
+
# w_edge = int(total_face_area**0.5) // 20
|
411 |
+
# erosion_radius = w_edge * 2
|
412 |
+
# inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
413 |
+
# blur_size = w_edge * 2
|
414 |
+
# inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
415 |
+
# if len(upsample_img.shape) == 2: # upsample_img is gray image
|
416 |
+
# upsample_img = upsample_img[:, :, None]
|
417 |
+
# inv_soft_mask = inv_soft_mask[:, :, None]
|
418 |
+
|
419 |
+
# always use square mask
|
420 |
+
mask = np.ones(face_size, dtype=np.float32)
|
421 |
+
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
422 |
+
# remove the black borders
|
423 |
+
inv_mask_erosion = cv2.erode(
|
424 |
+
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
|
425 |
+
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
|
426 |
+
total_face_area = np.sum(inv_mask_erosion) # // 3
|
427 |
+
# add border
|
428 |
+
if draw_box:
|
429 |
+
h, w = face_size
|
430 |
+
mask_border = np.ones((h, w, 3), dtype=np.float32)
|
431 |
+
border = int(1400 / np.sqrt(total_face_area))
|
432 |
+
mask_border[border:h - border, border:w - border, :] = 0
|
433 |
+
inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
|
434 |
+
inv_mask_borders.append(inv_mask_border)
|
435 |
+
# compute the fusion edge based on the area of face
|
436 |
+
w_edge = int(total_face_area ** 0.5) // 20
|
437 |
+
erosion_radius = w_edge * 2
|
438 |
+
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
439 |
+
blur_size = w_edge * 2
|
440 |
+
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
441 |
+
if len(upsample_img.shape) == 2: # upsample_img is gray image
|
442 |
+
upsample_img = upsample_img[:, :, None]
|
443 |
+
inv_soft_mask = inv_soft_mask[:, :, None]
|
444 |
+
|
445 |
+
# parse mask
|
446 |
+
if self.use_parse:
|
447 |
+
# inference
|
448 |
+
face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
|
449 |
+
face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
|
450 |
+
normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
451 |
+
face_input = torch.unsqueeze(face_input, 0).to(self.device)
|
452 |
+
with torch.no_grad():
|
453 |
+
out = self.face_parse(face_input)[0]
|
454 |
+
out = out.argmax(dim=1).squeeze().cpu().numpy()
|
455 |
+
|
456 |
+
parse_mask = np.zeros(out.shape)
|
457 |
+
MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
|
458 |
+
for idx, color in enumerate(MASK_COLORMAP):
|
459 |
+
parse_mask[out == idx] = color
|
460 |
+
# blur the mask
|
461 |
+
parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
|
462 |
+
parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
|
463 |
+
# remove the black borders
|
464 |
+
thres = 10
|
465 |
+
parse_mask[:thres, :] = 0
|
466 |
+
parse_mask[-thres:, :] = 0
|
467 |
+
parse_mask[:, :thres] = 0
|
468 |
+
parse_mask[:, -thres:] = 0
|
469 |
+
parse_mask = parse_mask / 255.
|
470 |
+
|
471 |
+
parse_mask = cv2.resize(parse_mask, face_size)
|
472 |
+
parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
|
473 |
+
inv_soft_parse_mask = parse_mask[:, :, None]
|
474 |
+
# pasted_face = inv_restored
|
475 |
+
fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
|
476 |
+
inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
|
477 |
+
|
478 |
+
if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
|
479 |
+
alpha = upsample_img[:, :, 3:]
|
480 |
+
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
|
481 |
+
upsample_img = np.concatenate((upsample_img, alpha), axis=2)
|
482 |
+
else:
|
483 |
+
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
|
484 |
+
|
485 |
+
if np.max(upsample_img) > 256: # 16-bit image
|
486 |
+
upsample_img = upsample_img.astype(np.uint16)
|
487 |
+
else:
|
488 |
+
upsample_img = upsample_img.astype(np.uint8)
|
489 |
+
|
490 |
+
# draw bounding box
|
491 |
+
if draw_box:
|
492 |
+
# upsample_input_img = cv2.resize(input_img, (w_up, h_up))
|
493 |
+
img_color = np.ones([*upsample_img.shape], dtype=np.float32)
|
494 |
+
img_color[:, :, 0] = 0
|
495 |
+
img_color[:, :, 1] = 255
|
496 |
+
img_color[:, :, 2] = 0
|
497 |
+
for inv_mask_border in inv_mask_borders:
|
498 |
+
upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
|
499 |
+
# upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
|
500 |
+
|
501 |
+
if save_path is not None:
|
502 |
+
path = os.path.splitext(save_path)[0]
|
503 |
+
save_path = f'{path}.{self.save_ext}'
|
504 |
+
imwrite(upsample_img, save_path)
|
505 |
+
return upsample_img
|
506 |
+
|
507 |
+
def clean_all(self):
|
508 |
+
self.all_landmarks_5 = []
|
509 |
+
self.restored_faces = []
|
510 |
+
self.affine_matrices = []
|
511 |
+
self.cropped_faces = []
|
512 |
+
self.inverse_affine_matrices = []
|
513 |
+
self.det_faces = []
|
514 |
+
self.pad_input_imgs = []
|
SUPIR/utils/file.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
from urllib.parse import urlparse
|
5 |
+
from torch.hub import download_url_to_file, get_dir
|
6 |
+
|
7 |
+
|
8 |
+
def load_file_list(file_list_path: str) -> List[str]:
|
9 |
+
files = []
|
10 |
+
# each line in file list contains a path of an image
|
11 |
+
with open(file_list_path, "r") as fin:
|
12 |
+
for line in fin:
|
13 |
+
path = line.strip()
|
14 |
+
if path:
|
15 |
+
files.append(path)
|
16 |
+
return files
|
17 |
+
|
18 |
+
|
19 |
+
def list_image_files(
|
20 |
+
img_dir: str,
|
21 |
+
exts: Tuple[str]=(".jpg", ".png", ".jpeg"),
|
22 |
+
follow_links: bool=False,
|
23 |
+
log_progress: bool=False,
|
24 |
+
log_every_n_files: int=10000,
|
25 |
+
max_size: int=-1
|
26 |
+
) -> List[str]:
|
27 |
+
files = []
|
28 |
+
for dir_path, _, file_names in os.walk(img_dir, followlinks=follow_links):
|
29 |
+
early_stop = False
|
30 |
+
for file_name in file_names:
|
31 |
+
if os.path.splitext(file_name)[1].lower() in exts:
|
32 |
+
if max_size >= 0 and len(files) >= max_size:
|
33 |
+
early_stop = True
|
34 |
+
break
|
35 |
+
files.append(os.path.join(dir_path, file_name))
|
36 |
+
if log_progress and len(files) % log_every_n_files == 0:
|
37 |
+
print(f"find {len(files)} images in {img_dir}")
|
38 |
+
if early_stop:
|
39 |
+
break
|
40 |
+
return files
|
41 |
+
|
42 |
+
|
43 |
+
def get_file_name_parts(file_path: str) -> Tuple[str, str, str]:
|
44 |
+
parent_path, file_name = os.path.split(file_path)
|
45 |
+
stem, ext = os.path.splitext(file_name)
|
46 |
+
return parent_path, stem, ext
|
47 |
+
|
48 |
+
|
49 |
+
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
|
50 |
+
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
51 |
+
"""Load file form http url, will download models if necessary.
|
52 |
+
|
53 |
+
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
54 |
+
|
55 |
+
Args:
|
56 |
+
url (str): URL to be downloaded.
|
57 |
+
model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
|
58 |
+
Default: None.
|
59 |
+
progress (bool): Whether to show the download progress. Default: True.
|
60 |
+
file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
str: The path to the downloaded file.
|
64 |
+
"""
|
65 |
+
if model_dir is None: # use the pytorch hub_dir
|
66 |
+
hub_dir = get_dir()
|
67 |
+
model_dir = os.path.join(hub_dir, 'checkpoints')
|
68 |
+
|
69 |
+
os.makedirs(model_dir, exist_ok=True)
|
70 |
+
|
71 |
+
parts = urlparse(url)
|
72 |
+
filename = os.path.basename(parts.path)
|
73 |
+
if file_name is not None:
|
74 |
+
filename = file_name
|
75 |
+
cached_file = os.path.abspath(os.path.join(model_dir, filename))
|
76 |
+
if not os.path.exists(cached_file):
|
77 |
+
print(f'Downloading: "{url}" to {cached_file}\n')
|
78 |
+
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
79 |
+
return cached_file
|
SUPIR/utils/tilevae.py
ADDED
@@ -0,0 +1,971 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
#
|
3 |
+
# Ultimate VAE Tile Optimization
|
4 |
+
#
|
5 |
+
# Introducing a revolutionary new optimization designed to make
|
6 |
+
# the VAE work with giant images on limited VRAM!
|
7 |
+
# Say goodbye to the frustration of OOM and hello to seamless output!
|
8 |
+
#
|
9 |
+
# ------------------------------------------------------------------------
|
10 |
+
#
|
11 |
+
# This script is a wild hack that splits the image into tiles,
|
12 |
+
# encodes each tile separately, and merges the result back together.
|
13 |
+
#
|
14 |
+
# Advantages:
|
15 |
+
# - The VAE can now work with giant images on limited VRAM
|
16 |
+
# (~10 GB for 8K images!)
|
17 |
+
# - The merged output is completely seamless without any post-processing.
|
18 |
+
#
|
19 |
+
# Drawbacks:
|
20 |
+
# - Giant RAM needed. To store the intermediate results for a 4096x4096
|
21 |
+
# images, you need 32 GB RAM it consumes ~20GB); for 8192x8192
|
22 |
+
# you need 128 GB RAM machine (it consumes ~100 GB)
|
23 |
+
# - NaNs always appear in for 8k images when you use fp16 (half) VAE
|
24 |
+
# You must use --no-half-vae to disable half VAE for that giant image.
|
25 |
+
# - Slow speed. With default tile size, it takes around 50/200 seconds
|
26 |
+
# to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode
|
27 |
+
# a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
|
28 |
+
# - The gradient calculation is not compatible with this hack. It
|
29 |
+
# will break any backward() or torch.autograd.grad() that passes VAE.
|
30 |
+
# (But you can still use the VAE to generate training data.)
|
31 |
+
#
|
32 |
+
# How it works:
|
33 |
+
# 1) The image is split into tiles.
|
34 |
+
# - To ensure perfect results, each tile is padded with 32 pixels
|
35 |
+
# on each side.
|
36 |
+
# - Then the conv2d/silu/upsample/downsample can produce identical
|
37 |
+
# results to the original image without splitting.
|
38 |
+
# 2) The original forward is decomposed into a task queue and a task worker.
|
39 |
+
# - The task queue is a list of functions that will be executed in order.
|
40 |
+
# - The task worker is a loop that executes the tasks in the queue.
|
41 |
+
# 3) The task queue is executed for each tile.
|
42 |
+
# - Current tile is sent to GPU.
|
43 |
+
# - local operations are directly executed.
|
44 |
+
# - Group norm calculation is temporarily suspended until the mean
|
45 |
+
# and var of all tiles are calculated.
|
46 |
+
# - The residual is pre-calculated and stored and addded back later.
|
47 |
+
# - When need to go to the next tile, the current tile is send to cpu.
|
48 |
+
# 4) After all tiles are processed, tiles are merged on cpu and return.
|
49 |
+
#
|
50 |
+
# Enjoy!
|
51 |
+
#
|
52 |
+
# @author: LI YI @ Nanyang Technological University - Singapore
|
53 |
+
# @date: 2023-03-02
|
54 |
+
# @license: MIT License
|
55 |
+
#
|
56 |
+
# Please give me a star if you like this project!
|
57 |
+
#
|
58 |
+
# -------------------------------------------------------------------------
|
59 |
+
|
60 |
+
import gc
|
61 |
+
from time import time
|
62 |
+
import math
|
63 |
+
from tqdm import tqdm
|
64 |
+
|
65 |
+
import torch
|
66 |
+
import torch.version
|
67 |
+
import torch.nn.functional as F
|
68 |
+
from einops import rearrange
|
69 |
+
from diffusers.utils.import_utils import is_xformers_available
|
70 |
+
|
71 |
+
import SUPIR.utils.devices as devices
|
72 |
+
|
73 |
+
try:
|
74 |
+
import xformers
|
75 |
+
import xformers.ops
|
76 |
+
except ImportError:
|
77 |
+
pass
|
78 |
+
|
79 |
+
sd_flag = True
|
80 |
+
|
81 |
+
def get_recommend_encoder_tile_size():
|
82 |
+
if torch.cuda.is_available():
|
83 |
+
total_memory = torch.cuda.get_device_properties(
|
84 |
+
devices.device).total_memory // 2**20
|
85 |
+
if total_memory > 16*1000:
|
86 |
+
ENCODER_TILE_SIZE = 3072
|
87 |
+
elif total_memory > 12*1000:
|
88 |
+
ENCODER_TILE_SIZE = 2048
|
89 |
+
elif total_memory > 8*1000:
|
90 |
+
ENCODER_TILE_SIZE = 1536
|
91 |
+
else:
|
92 |
+
ENCODER_TILE_SIZE = 960
|
93 |
+
else:
|
94 |
+
ENCODER_TILE_SIZE = 512
|
95 |
+
return ENCODER_TILE_SIZE
|
96 |
+
|
97 |
+
|
98 |
+
def get_recommend_decoder_tile_size():
|
99 |
+
if torch.cuda.is_available():
|
100 |
+
total_memory = torch.cuda.get_device_properties(
|
101 |
+
devices.device).total_memory // 2**20
|
102 |
+
if total_memory > 30*1000:
|
103 |
+
DECODER_TILE_SIZE = 256
|
104 |
+
elif total_memory > 16*1000:
|
105 |
+
DECODER_TILE_SIZE = 192
|
106 |
+
elif total_memory > 12*1000:
|
107 |
+
DECODER_TILE_SIZE = 128
|
108 |
+
elif total_memory > 8*1000:
|
109 |
+
DECODER_TILE_SIZE = 96
|
110 |
+
else:
|
111 |
+
DECODER_TILE_SIZE = 64
|
112 |
+
else:
|
113 |
+
DECODER_TILE_SIZE = 64
|
114 |
+
return DECODER_TILE_SIZE
|
115 |
+
|
116 |
+
|
117 |
+
if 'global const':
|
118 |
+
DEFAULT_ENABLED = False
|
119 |
+
DEFAULT_MOVE_TO_GPU = False
|
120 |
+
DEFAULT_FAST_ENCODER = True
|
121 |
+
DEFAULT_FAST_DECODER = True
|
122 |
+
DEFAULT_COLOR_FIX = 0
|
123 |
+
DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
|
124 |
+
DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
|
125 |
+
|
126 |
+
|
# inplace version of silu
def inplace_nonlinearity(x):
    # Test: fix for Nans
    return F.silu(x, inplace=True)

# extracted from ldm.modules.diffusionmodules.model

# from diffusers lib
def attn_forward_new(self, h_):
    batch_size, channel, height, width = h_.shape
    hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)

    attention_mask = None
    encoder_hidden_states = None
    batch_size, sequence_length, _ = hidden_states.shape
    attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)

    query = self.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif self.norm_cross:
        encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

    key = self.to_k(encoder_hidden_states)
    value = self.to_v(encoder_hidden_states)

    query = self.head_to_batch_dim(query)
    key = self.head_to_batch_dim(key)
    value = self.head_to_batch_dim(value)

    attention_probs = self.get_attention_scores(query, key, attention_mask)
    hidden_states = torch.bmm(attention_probs, value)
    hidden_states = self.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)

    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return hidden_states

def attn_forward_new_pt2_0(self, hidden_states,):
    scale = 1
    attention_mask = None
    encoder_hidden_states = None

    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, sequence_length, _ = (
        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    )

    if attention_mask is not None:
        attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = self.to_q(hidden_states, scale=scale)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif self.norm_cross:
        encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

    key = self.to_k(encoder_hidden_states, scale=scale)
    value = self.to_v(encoder_hidden_states, scale=scale)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // self.heads

    query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

    key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states, scale=scale)
    # dropout
    hidden_states = self.to_out[1](hidden_states)

    if input_ndim == 4:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return hidden_states

def attn_forward_new_xformers(self, hidden_states):
    scale = 1
    attention_op = None
    attention_mask = None
    encoder_hidden_states = None

    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, key_tokens, _ = (
        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    )

    attention_mask = self.prepare_attention_mask(attention_mask, key_tokens, batch_size)
    if attention_mask is not None:
        # expand our mask's singleton query_tokens dimension:
        #   [batch*heads,            1, key_tokens] ->
        #   [batch*heads, query_tokens, key_tokens]
        # so that it can be added as a bias onto the attention scores that xformers computes:
        #   [batch*heads, query_tokens, key_tokens]
        # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
        _, query_tokens, _ = hidden_states.shape
        attention_mask = attention_mask.expand(-1, query_tokens, -1)

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = self.to_q(hidden_states, scale=scale)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif self.norm_cross:
        encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

    key = self.to_k(encoder_hidden_states, scale=scale)
    value = self.to_v(encoder_hidden_states, scale=scale)

    query = self.head_to_batch_dim(query).contiguous()
    key = self.head_to_batch_dim(key).contiguous()
    value = self.head_to_batch_dim(value).contiguous()

    hidden_states = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=attention_mask, op=attention_op#, scale=scale
    )
    hidden_states = hidden_states.to(query.dtype)
    hidden_states = self.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = self.to_out[0](hidden_states, scale=scale)
    # dropout
    hidden_states = self.to_out[1](hidden_states)

    if input_ndim == 4:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return hidden_states

def attn_forward(self, h_):
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q.shape
    q = q.reshape(b, c, h*w)
    q = q.permute(0, 2, 1)    # b,hw,c
    k = k.reshape(b, c, h*w)  # b,c,hw
    w_ = torch.bmm(q, k)      # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
    w_ = w_ * (int(c)**(-0.5))
    w_ = torch.nn.functional.softmax(w_, dim=2)

    # attend to values
    v = v.reshape(b, c, h*w)
    w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
    # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
    h_ = torch.bmm(v, w_)
    h_ = h_.reshape(b, c, h, w)

    h_ = self.proj_out(h_)

    return h_


def xformer_attn_forward(self, h_):
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(B, t.shape[1], 1, C)
        .permute(0, 2, 1, 3)
        .reshape(B * 1, t.shape[1], C)
        .contiguous(),
        (q, k, v),
    )
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=self.attention_op)

    out = (
        out.unsqueeze(0)
        .reshape(B, 1, out.shape[1], C)
        .permute(0, 2, 1, 3)
        .reshape(B, out.shape[1], C)
    )
    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    out = self.proj_out(out)
    return out


def attn2task(task_queue, net):
    if False: #isinstance(net, AttnBlock):
        task_queue.append(('store_res', lambda x: x))
        task_queue.append(('pre_norm', net.norm))
        task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
        task_queue.append(['add_res', None])
    elif False: #isinstance(net, MemoryEfficientAttnBlock):
        task_queue.append(('store_res', lambda x: x))
        task_queue.append(('pre_norm', net.norm))
        task_queue.append(
            ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
        task_queue.append(['add_res', None])
    else:
        task_queue.append(('store_res', lambda x: x))
        task_queue.append(('pre_norm', net.norm))
        if is_xformers_available:
            # task_queue.append(('attn', lambda x, net=net: attn_forward_new_xformers(net, x)))
            task_queue.append(
                ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
        elif hasattr(F, "scaled_dot_product_attention"):
            task_queue.append(('attn', lambda x, net=net: attn_forward_new_pt2_0(net, x)))
        else:
            task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
        task_queue.append(['add_res', None])

def resblock2task(queue, block):
    """
    Turn a ResNetBlock into a sequence of tasks and append to the task queue

    @param queue: the target task queue
    @param block: ResNetBlock

    """
    if block.in_channels != block.out_channels:
        if sd_flag:
            if block.use_conv_shortcut:
                queue.append(('store_res', block.conv_shortcut))
            else:
                queue.append(('store_res', block.nin_shortcut))
        else:
            if block.use_in_shortcut:
                queue.append(('store_res', block.conv_shortcut))
            else:
                queue.append(('store_res', block.nin_shortcut))

    else:
        queue.append(('store_res', lambda x: x))
    queue.append(('pre_norm', block.norm1))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv1', block.conv1))
    queue.append(('pre_norm', block.norm2))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv2', block.conv2))
    queue.append(['add_res', None])


def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue
    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: currently building decoder or encoder
    """
    if is_decoder:
        if sd_flag:
            resblock2task(task_queue, net.mid.block_1)
            attn2task(task_queue, net.mid.attn_1)
            # print(task_queue)
            resblock2task(task_queue, net.mid.block_2)
            resolution_iter = reversed(range(net.num_resolutions))
            block_ids = net.num_res_blocks + 1
            condition = 0
            module = net.up
            func_name = 'upsample'
        else:
            resblock2task(task_queue, net.mid_block.resnets[0])
            attn2task(task_queue, net.mid_block.attentions[0])
            resblock2task(task_queue, net.mid_block.resnets[1])
            resolution_iter = (range(len(net.up_blocks))) # net.num_resolutions = 3
            block_ids = 2 + 1
            condition = len(net.up_blocks) - 1
            module = net.up_blocks
            func_name = 'upsamplers'
    else:
        if sd_flag:
            resolution_iter = range(net.num_resolutions)
            block_ids = net.num_res_blocks
            condition = net.num_resolutions - 1
            module = net.down
            func_name = 'downsample'
        else:
            resolution_iter = range(len(net.down_blocks))
            block_ids = 2
            condition = len(net.down_blocks) - 1
            module = net.down_blocks
            func_name = 'downsamplers'

    for i_level in resolution_iter:
        for i_block in range(block_ids):
            if sd_flag:
                resblock2task(task_queue, module[i_level].block[i_block])
            else:
                resblock2task(task_queue, module[i_level].resnets[i_block])
        if i_level != condition:
            if sd_flag:
                task_queue.append((func_name, getattr(module[i_level], func_name)))
            else:
                if is_decoder:
                    task_queue.append((func_name, module[i_level].upsamplers[0]))
                else:
                    task_queue.append((func_name, module[i_level].downsamplers[0]))

    if not is_decoder:
        if sd_flag:
            resblock2task(task_queue, net.mid.block_1)
            attn2task(task_queue, net.mid.attn_1)
            resblock2task(task_queue, net.mid.block_2)
        else:
            resblock2task(task_queue, net.mid_block.resnets[0])
            attn2task(task_queue, net.mid_block.attentions[0])
            resblock2task(task_queue, net.mid_block.resnets[1])


def build_task_queue(net, is_decoder):
    """
    Build a single task queue for the encoder or decoder
    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue
    """
    task_queue = []
    task_queue.append(('conv_in', net.conv_in))

    # construct the sampling part of the task queue
    # because encoder and decoder share the same architecture, we extract the sampling part
    build_sampling(task_queue, net, is_decoder)
    if is_decoder and not sd_flag:
        net.give_pre_end = False
        net.tanh_out = False

    if not is_decoder or not net.give_pre_end:
        if sd_flag:
            task_queue.append(('pre_norm', net.norm_out))
        else:
            task_queue.append(('pre_norm', net.conv_norm_out))
        task_queue.append(('silu', inplace_nonlinearity))
        task_queue.append(('conv_out', net.conv_out))
        if is_decoder and net.tanh_out:
            task_queue.append(('tanh', torch.tanh))

    return task_queue


def clone_task_queue(task_queue):
    """
    Clone a task queue
    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [[item for item in task] for task in task_queue]
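
# A minimal sketch (hypothetical helper): a task queue built above is just an
# ordered list of [name, payload] pairs, so inspecting one amounts to walking
# that list. `net` is assumed to be a VAE encoder or decoder module.
def _task_queue_names(net, is_decoder):
    # e.g. ['conv_in', 'store_res', 'pre_norm', 'silu', 'conv1', ...]
    return [task[0] for task in build_task_queue(net, is_decoder)]
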
def get_var_mean(input, num_groups, eps=1e-6):
    """
    Get mean and var for group norm
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c/num_groups)
    input_reshaped = input.contiguous().view(
        1, int(b * num_groups), channel_in_group, *input.size()[2:])
    var, mean = torch.var_mean(
        input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
    return var, mean


def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Custom group norm with fixed mean and var

    @param input: input tensor
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm

    @return: normalized tensor
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c/num_groups)
    input_reshaped = input.contiguous().view(
        1, int(b * num_groups), channel_in_group, *input.size()[2:])

    out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,
                       training=False, momentum=0, eps=eps)

    out = out.view(b, c, *input.size()[2:])

    # post affine transform
    if weight is not None:
        out *= weight.view(1, -1, 1, 1)
    if bias is not None:
        out += bias.view(1, -1, 1, 1)
    return out
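
# A minimal sketch (illustrative check, not used by the module): when mean/var
# come from the same tensor, custom_group_norm reproduces a plain group norm,
# which is the property the tiled estimation relies on. The group count (32)
# and tolerance here are assumptions.
def _check_custom_group_norm(x, norm):
    var, mean = get_var_mean(x, 32)
    out_tiled = custom_group_norm(x, 32, mean, var, norm.weight, norm.bias)
    out_ref = F.group_norm(x, 32, norm.weight, norm.bias, eps=1e-6)
    return torch.allclose(out_tiled, out_ref, atol=1e-4)
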
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid region from the tile
    @param x: input tile
    @param input_bbox: original input bounding box
    @param target_bbox: output bounding box
    @param is_decoder: whether the tile comes from the decoder (bboxes scale by 8)
    @return: cropped tile
    """
    padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]
    margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
    return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]

# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓


def perfcount(fn):
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            torch.cuda.reset_peak_memory_stats(devices.device)
            print(
                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper

# copy end :)


class GroupNormParam:
    def __init__(self):
        self.var_list = []
        self.mean_list = []
        self.pixel_list = []
        self.weight = None
        self.bias = None

    def add_tile(self, tile, layer):
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16
        # In this case we create a copy to float32
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
        # ============= DEBUG: test for infinite =============
        # if torch.isinf(var).any():
        #     print('var: ', var)
        # ====================================================
        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(tile.shape[2]*tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def summary(self):
        """
        summarize the mean and var and return a function
        that applies group norm on each tile
        """
        if len(self.var_list) == 0:
            return None
        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)
        max_value = max(self.pixel_list)
        pixels = torch.tensor(
            self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
        sum_pixels = torch.sum(pixels)
        pixels = pixels.unsqueeze(1) / sum_pixels
        var = torch.sum(var * pixels, dim=0)
        mean = torch.sum(mean * pixels, dim=0)
        return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)

    @staticmethod
    def from_tile(tile, norm):
        """
        create a function from a single tile without summary
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
            # if it is a macbook, we need to convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()
        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func


class VAEHook:
    def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
        self.net = net  # encoder | decoder
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (
            fast_decoder and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        self.pad = 11 if is_decoder else 32

    def __call__(self, x):
        B, C, H, W = x.shape
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net.to(devices.get_optimal_device())
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory
        """
        divider = 32
        while divider >= 2:
            remainer = lowerbound % divider
            if remainer == 0:
                return lowerbound
            candidate = lowerbound - remainer + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles
        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # If any of the numbers are 0, we let it be 1
        # This is to deal with long and thin images
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Suggestions from https://github.com/Kahsolt: auto shrink the tile size
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # bbox: [x1, x2, y1, y2]
                # the padding is unnecessary for image borders. So we directly start from (32, 32)
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # if the output bbox is close to the image boundary, we extend it to the image boundary
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # scale to get the final output bbox
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)

                # indistinguishably expand the input bbox by pad pixels
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes
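
    # A brief note (added sketch, values assumed): split_tiles works in the
    # hook's native resolution, i.e. latent pixels for the decoder and image
    # pixels for the encoder. For a hypothetical decoder hook,
    #
    #   hook = VAEHook(decoder, 64, is_decoder=True, fast_decoder=True,
    #                  fast_encoder=False, color_fix=False)
    #   in_bboxes, out_bboxes = hook.split_tiles(192, 256)
    #
    # each input bbox is padded by self.pad (11 for the decoder) on every inner
    # edge and clamped to the image, while each output bbox is the unpadded bbox
    # scaled by 8 (the VAE's spatial factor) and snapped to the border, exactly
    # as computed above.
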
    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        # estimate until the last group norm
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] == 'downsample':
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
            try:
                devices.test_for_nans(tile, "vae")
            except:
                print(f'Nan detected in fast mode estimation. Fast mode disabled.')
                return False

        raise IndexError('Should not reach here')

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Decode a latent vector z into an image (or encode an image, when hooked
        on the encoder) in a tiled manner.
        @param z: latent vector
        @return: image
        """
        device = next(self.net.parameters()).device
        dtype = z.dtype
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()  # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by splitting the input latents
        tiles = []
        for input_bbox in in_bboxes:
            tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
            tiles.append(tile)

        num_tiles = len(tiles)
        num_completed = 0

        # Build task queues
        single_task_queue = build_task_queue(net, is_decoder)
        #print(single_task_queue)
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
            # use nearest-exact to keep statistics as close as possible
            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
            # The downsampling will heavily distort its mean and std, so we need to recover it.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            # occasionally the std_new is too small or too large, which exceeds the range of float16
            # so we need to clamp it to max z's range.
            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
            del downsampled_z

        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

        # Dummy result
        result = None
        result_approx = None
        #try:
        #    with devices.autocast():
        #        result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
        #except: pass
        # Free memory of input latent tensor
        del z

        # Task queue execution
        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")

        # execute the task back and forth when switching tiles so that we always
        # keep one tile on the GPU to reduce unnecessary data transfer
        forward = True
        interrupted = False
        #state.interrupted = interrupted
        while True:
            #if state.interrupted: interrupted = True ; break

            group_norm_param = GroupNormParam()
            for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                #if state.interrupted: interrupted = True ; break

                tile = tiles[i].to(device)
                input_bbox = in_bboxes[i]
                task_queue = task_queues[i]

                interrupted = False
                while len(task_queue) > 0:
                    #if state.interrupted: interrupted = True ; break

                    # DEBUG: current task
                    # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
                    task = task_queue.pop(0)
                    if task[0] == 'pre_norm':
                        group_norm_param.add_tile(tile, task[1])
                        break
                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                        task_id = 0
                        res = task[1](tile)
                        if not self.fast_mode or task[0] == 'store_res_cpu':
                            res = res.cpu()
                        while task_queue[task_id][0] != 'add_res':
                            task_id += 1
                        task_queue[task_id][1] = res
                    elif task[0] == 'add_res':
                        tile += task[1].to(device)
                        task[1] = None
                    else:
                        tile = task[1](tile)
                    #print(tiles[i].shape, tile.shape, task)
                    pbar.update(1)

                if interrupted: break

                # check for NaNs in the tile.
                # If there are NaNs, we abort the process to save user's time
                #devices.test_for_nans(tile, "vae")

                #print(tiles[i].shape, tile.shape, i, num_tiles)
                if len(task_queue) == 0:
                    tiles[i] = None
                    num_completed += 1
                    if result is None:  # NOTE: dim C varies from different cases, can only be inited dynamically
                        result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                    del tile
                elif i == num_tiles - 1 and forward:
                    forward = False
                    tiles[i] = tile
                elif i == 0 and not forward:
                    forward = True
                    tiles[i] = tile
                else:
                    tiles[i] = tile.cpu()
                    del tile

            if interrupted: break
            if num_completed == num_tiles: break

            # insert the group norm task to the head of each task queue
            group_norm_func = group_norm_param.summary()
            if group_norm_func is not None:
                for i in range(num_tiles):
                    task_queue = task_queues[i]
                    task_queue.insert(0, ('apply_norm', group_norm_func))

        # Done!
        pbar.close()
        return result.to(dtype) if result is not None else result_approx.to(device)
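
# A minimal sketch (assumptions: `vae` exposes .encoder/.decoder modules and the
# original forwards are stashed on `original_forward`, which is what the hook's
# small-input path expects). SUPIR wires this up through its own model setup;
# this is only an illustration of the hook's intended use.
def _install_tile_hooks(vae, encoder_tile_size=DEFAULT_ENCODER_TILE_SIZE,
                        decoder_tile_size=DEFAULT_DECODER_TILE_SIZE):
    for net, tile_size, is_decoder in ((vae.encoder, encoder_tile_size, False),
                                       (vae.decoder, decoder_tile_size, True)):
        net.original_forward = net.forward
        net.forward = VAEHook(net, tile_size, is_decoder=is_decoder,
                              fast_decoder=DEFAULT_FAST_DECODER,
                              fast_encoder=DEFAULT_FAST_ENCODER,
                              color_fix=DEFAULT_COLOR_FIX)
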
app.py
ADDED
@@ -0,0 +1,911 @@
raise NotImplementedError(
    "This Space is meant for SD-Webui-Forge\nhttps://github.com/lllyasviel/stable-diffusion-webui-forge"
)

import os
import gradio as gr
import argparse
import numpy as np
import torch
import einops
import copy
import math
import time
import random
import spaces
import re
import uuid

from gradio_imageslider import ImageSlider
from PIL import Image
from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype, create_SUPIR_model, load_QF_ckpt
from huggingface_hub import hf_hub_download
from pillow_heif import register_heif_opener

register_heif_opener()

max_64_bit_int = np.iinfo(np.int32).max

hf_hub_download(repo_id="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", filename="open_clip_pytorch_model.bin", local_dir="laion_CLIP-ViT-bigG-14-laion2B-39B-b160k")
hf_hub_download(repo_id="camenduru/SUPIR", filename="sd_xl_base_1.0_0.9vae.safetensors", local_dir="yushan777_SUPIR")
hf_hub_download(repo_id="camenduru/SUPIR", filename="SUPIR-v0F.ckpt", local_dir="yushan777_SUPIR")
hf_hub_download(repo_id="camenduru/SUPIR", filename="SUPIR-v0Q.ckpt", local_dir="yushan777_SUPIR")
hf_hub_download(repo_id="RunDiffusion/Juggernaut-XL-Lightning", filename="Juggernaut_RunDiffusionPhoto2_Lightning_4Steps.safetensors", local_dir="RunDiffusion_Juggernaut-XL-Lightning")

parser = argparse.ArgumentParser()
parser.add_argument("--opt", type=str, default='options/SUPIR_v0.yaml')
parser.add_argument("--ip", type=str, default='127.0.0.1')
parser.add_argument("--port", type=int, default='6688')
parser.add_argument("--no_llava", action='store_true', default=True)#False
parser.add_argument("--use_image_slider", action='store_true', default=False)#False
parser.add_argument("--log_history", action='store_true', default=False)
parser.add_argument("--loading_half_params", action='store_true', default=False)#False
parser.add_argument("--use_tile_vae", action='store_true', default=True)#False
parser.add_argument("--encoder_tile_size", type=int, default=512)
parser.add_argument("--decoder_tile_size", type=int, default=64)
parser.add_argument("--load_8bit_llava", action='store_true', default=False)
args = parser.parse_args()
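
# Example invocation (illustrative only; these flags are the ones defined above):
#   python app.py --opt options/SUPIR_v0.yaml --use_tile_vae \
#       --encoder_tile_size 512 --decoder_tile_size 64 --loading_half_params
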
if torch.cuda.device_count() > 0:
    SUPIR_device = 'cuda:0'

    # Load SUPIR
    model, default_setting = create_SUPIR_model(args.opt, SUPIR_sign='Q', load_default_setting=True)
    if args.loading_half_params:
        model = model.half()
    if args.use_tile_vae:
        model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
    model = model.to(SUPIR_device)
    model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
    model.current_model = 'v0-Q'
    ckpt_Q, ckpt_F = load_QF_ckpt(args.opt)

def check_upload(input_image):
    if input_image is None:
        raise gr.Error("Please provide an image to restore.")
    return gr.update(visible = True)

def update_seed(is_randomize_seed, seed):
    if is_randomize_seed:
        return random.randint(0, max_64_bit_int)
    return seed

def reset():
    return [
        None,
        0,
        None,
        None,
        "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
        "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
        1,
        1024,
        1,
        2,
        50,
        -1.0,
        1.,
        default_setting.s_cfg_Quality if torch.cuda.device_count() > 0 else 1.0,
        True,
        random.randint(0, max_64_bit_int),
        5,
        1.003,
        "Wavelet",
        "fp32",
        "fp32",
        1.0,
        True,
        False,
        default_setting.spt_linear_CFG_Quality if torch.cuda.device_count() > 0 else 1.0,
        0.,
        "v0-Q",
        "input",
        6
    ]

def check(input_image):
    if input_image is None:
        raise gr.Error("Please provide an image to restore.")

@spaces.GPU(duration=420)
def stage1_process(
    input_image,
    gamma_correction,
    diff_dtype,
    ae_dtype
):
    print('stage1_process ==>>')
    if torch.cuda.device_count() == 0:
        gr.Warning('Set this space to GPU config to make it work.')
        return None, None
    torch.cuda.set_device(SUPIR_device)
    LQ = HWC3(np.array(Image.open(input_image)))
    LQ = fix_resize(LQ, 512)
    # stage1
    LQ = np.array(LQ) / 255 * 2 - 1
    LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]

    model.ae_dtype = convert_dtype(ae_dtype)
    model.model.dtype = convert_dtype(diff_dtype)

    LQ = model.batchify_denoise(LQ, is_stage1=True)
    LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8)
    # gamma correction
    LQ = LQ / 255.0
    LQ = np.power(LQ, gamma_correction)
    LQ *= 255.0
    LQ = LQ.round().clip(0, 255).astype(np.uint8)
    print('<<== stage1_process')
    return LQ, gr.update(visible = True)
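
# A minimal sketch (hypothetical helper, mirroring the gamma step inside
# stage1_process above): pixels are mapped to [0, 1], raised to
# `gamma_correction`, and mapped back, so gamma = 1.0 is a no-op and
# gamma < 1.0 brightens mid-tones.
def _apply_gamma(image_uint8, gamma_correction):
    out = image_uint8 / 255.0
    out = np.power(out, gamma_correction)
    out *= 255.0
    return out.round().clip(0, 255).astype(np.uint8)
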
def stage2_process(*args, **kwargs):
    try:
        return restore_in_Xmin(*args, **kwargs)
    except Exception as e:
        # NO_GPU_MESSAGE_INQUEUE
        print("gradio.exceptions.Error 'No GPU is currently available for you after 60s'")
        print('str(type(e)): ' + str(type(e))) # <class 'gradio.exceptions.Error'>
        print('str(e): ' + str(e)) # You have exceeded your GPU quota...
        try:
            print('e.message: ' + e.message) # No GPU is currently available for you after 60s
        except Exception as e2:
            print('Failure')
        if str(e).startswith("No GPU is currently available for you after 60s"):
            print('Exception identified!!!')
        #if str(type(e)) == "<class 'gradio.exceptions.Error'>":
            #print('Exception of name ' + type(e).__name__)
        raise e

def restore_in_Xmin(
    noisy_image,
    rotation,
    denoise_image,
    prompt,
    a_prompt,
    n_prompt,
    num_samples,
    min_size,
    downscale,
    upscale,
    edm_steps,
    s_stage1,
    s_stage2,
    s_cfg,
    randomize_seed,
    seed,
    s_churn,
    s_noise,
    color_fix_type,
    diff_dtype,
    ae_dtype,
    gamma_correction,
    linear_CFG,
    linear_s_stage2,
    spt_linear_CFG,
    spt_linear_s_stage2,
    model_select,
    output_format,
    allocation
):
    print("noisy_image:\n" + str(noisy_image))
    print("denoise_image:\n" + str(denoise_image))
    print("rotation: " + str(rotation))
    print("prompt: " + str(prompt))
    print("a_prompt: " + str(a_prompt))
    print("n_prompt: " + str(n_prompt))
    print("num_samples: " + str(num_samples))
    print("min_size: " + str(min_size))
    print("downscale: " + str(downscale))
    print("upscale: " + str(upscale))
    print("edm_steps: " + str(edm_steps))
    print("s_stage1: " + str(s_stage1))
    print("s_stage2: " + str(s_stage2))
    print("s_cfg: " + str(s_cfg))
    print("randomize_seed: " + str(randomize_seed))
    print("seed: " + str(seed))
    print("s_churn: " + str(s_churn))
    print("s_noise: " + str(s_noise))
    print("color_fix_type: " + str(color_fix_type))
    print("diff_dtype: " + str(diff_dtype))
    print("ae_dtype: " + str(ae_dtype))
    print("gamma_correction: " + str(gamma_correction))
    print("linear_CFG: " + str(linear_CFG))
    print("linear_s_stage2: " + str(linear_s_stage2))
    print("spt_linear_CFG: " + str(spt_linear_CFG))
    print("spt_linear_s_stage2: " + str(spt_linear_s_stage2))
    print("model_select: " + str(model_select))
    print("GPU time allocation: " + str(allocation) + " min")
    print("output_format: " + str(output_format))

    input_format = re.sub(r"^.*\.([^\.]+)$", r"\1", noisy_image)

    if input_format not in ['png', 'webp', 'jpg', 'jpeg', 'gif', 'bmp', 'heic']:
        gr.Warning('Invalid image format. Please first convert into *.png, *.webp, *.jpg, *.jpeg, *.gif, *.bmp or *.heic.')
        return None, None, None, None

    if output_format == "input":
        if noisy_image is None:
            output_format = "png"
        else:
            output_format = input_format
    print("final output_format: " + str(output_format))

    if prompt is None:
        prompt = ""

    if a_prompt is None:
        a_prompt = ""

    if n_prompt is None:
        n_prompt = ""

    if prompt != "" and a_prompt != "":
        a_prompt = prompt + ", " + a_prompt
    else:
        a_prompt = prompt + a_prompt
    print("Final prompt: " + str(a_prompt))

    denoise_image = np.array(Image.open(noisy_image if denoise_image is None else denoise_image))

    if rotation == 90:
        denoise_image = np.array(list(zip(*denoise_image[::-1])))
    elif rotation == 180:
        denoise_image = np.array(list(zip(*denoise_image[::-1])))
        denoise_image = np.array(list(zip(*denoise_image[::-1])))
    elif rotation == -90:
        denoise_image = np.array(list(zip(*denoise_image))[::-1])

    if 1 < downscale:
        input_height, input_width, input_channel = denoise_image.shape
        denoise_image = np.array(Image.fromarray(denoise_image).resize((input_width // downscale, input_height // downscale), Image.LANCZOS))

    denoise_image = HWC3(denoise_image)

    if torch.cuda.device_count() == 0:
        gr.Warning('Set this space to GPU config to make it work.')
        return [noisy_image, denoise_image], gr.update(label="Downloadable results in *." + output_format + " format", format = output_format, value = [denoise_image]), None, gr.update(visible=True)

    if model_select != model.current_model:
        print('load ' + model_select)
        if model_select == 'v0-Q':
            model.load_state_dict(ckpt_Q, strict=False)
        elif model_select == 'v0-F':
            model.load_state_dict(ckpt_F, strict=False)
        model.current_model = model_select

    model.ae_dtype = convert_dtype(ae_dtype)
    model.model.dtype = convert_dtype(diff_dtype)

    # Allocation
    if allocation == 1:
        return restore_in_1min(
            noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
        )
    if allocation == 2:
        return restore_in_2min(
            noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
        )
    if allocation == 3:
        return restore_in_3min(
            noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
        )
    if allocation == 4:
        return restore_in_4min(
            noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
        )
    if allocation == 5:
        return restore_in_5min(
            noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
        )
    if allocation == 7:
        return restore_in_7min(
            noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
        )
    if allocation == 8:
        return restore_in_8min(
            noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
        )
    if allocation == 9:
        return restore_in_9min(
            noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
        )
    if allocation == 10:
        return restore_in_10min(
            noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
        )
    else:
        return restore_in_6min(
            noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
        )

@spaces.GPU(duration=59)
def restore_in_1min(*args, **kwargs):
    return restore_on_gpu(*args, **kwargs)

@spaces.GPU(duration=119)
def restore_in_2min(*args, **kwargs):
    return restore_on_gpu(*args, **kwargs)

@spaces.GPU(duration=179)
def restore_in_3min(*args, **kwargs):
    return restore_on_gpu(*args, **kwargs)

@spaces.GPU(duration=239)
def restore_in_4min(*args, **kwargs):
    return restore_on_gpu(*args, **kwargs)

@spaces.GPU(duration=299)
def restore_in_5min(*args, **kwargs):
    return restore_on_gpu(*args, **kwargs)

@spaces.GPU(duration=359)
def restore_in_6min(*args, **kwargs):
    return restore_on_gpu(*args, **kwargs)

@spaces.GPU(duration=419)
def restore_in_7min(*args, **kwargs):
    return restore_on_gpu(*args, **kwargs)

@spaces.GPU(duration=479)
def restore_in_8min(*args, **kwargs):
    return restore_on_gpu(*args, **kwargs)

@spaces.GPU(duration=539)
def restore_in_9min(*args, **kwargs):
    return restore_on_gpu(*args, **kwargs)

@spaces.GPU(duration=599)
def restore_in_10min(*args, **kwargs):
    return restore_on_gpu(*args, **kwargs)
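
# Note on the wrappers above (observation, not new behavior): each one maps the
# user-facing "allocation" in minutes to a ZeroGPU duration of
# allocation * 60 - 1 seconds (59, 119, ..., 599), staying just under the
# requested whole-minute budget.
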
361 |
+
def restore_on_gpu(
|
362 |
+
noisy_image,
|
363 |
+
input_image,
|
364 |
+
prompt,
|
365 |
+
a_prompt,
|
366 |
+
n_prompt,
|
367 |
+
num_samples,
|
368 |
+
min_size,
|
369 |
+
downscale,
|
370 |
+
upscale,
|
371 |
+
edm_steps,
|
372 |
+
s_stage1,
|
373 |
+
s_stage2,
|
374 |
+
        s_cfg,
        randomize_seed,
        seed,
        s_churn,
        s_noise,
        color_fix_type,
        diff_dtype,
        ae_dtype,
        gamma_correction,
        linear_CFG,
        linear_s_stage2,
        spt_linear_CFG,
        spt_linear_s_stage2,
        model_select,
        output_format,
        allocation
):
    start = time.time()
    print('restore ==>>')

    torch.cuda.set_device(SUPIR_device)

    with torch.no_grad():
        input_image = upscale_image(input_image, upscale, unit_resolution=32, min_size=min_size)
        LQ = np.array(input_image) / 255.0
        LQ = np.power(LQ, gamma_correction)
        LQ *= 255.0
        LQ = LQ.round().clip(0, 255).astype(np.uint8)
        LQ = LQ / 255 * 2 - 1
        LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
        captions = ['']

        samples = model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
                                        s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
                                        num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
                                        use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
                                        cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)

        x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
            0, 255).astype(np.uint8)
        results = [x_samples[i] for i in range(num_samples)]
        torch.cuda.empty_cache()

    # All the results have the same size
    input_height, input_width, input_channel = np.array(input_image).shape
    result_height, result_width, result_channel = np.array(results[0]).shape

    print('<<== restore')
    end = time.time()
    secondes = int(end - start)
    minutes = math.floor(secondes / 60)
    secondes = secondes - (minutes * 60)
    hours = math.floor(minutes / 60)
    minutes = minutes - (hours * 60)
    information = ("Start the process again if you want a different result. " if randomize_seed else "") + \
        "If you don't get the image you wanted, add more details in the « Image description ». " + \
        "Wait " + str(allocation) + " min before a new run to avoid quota penalty or use another computer. " + \
        "The image" + (" has" if len(results) == 1 else "s have") + " been generated in " + \
        ((str(hours) + " h, ") if hours != 0 else "") + \
        ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + \
        str(secondes) + " sec. " + \
        "The new image resolution is " + str(result_width) + \
        " pixels wide and " + str(result_height) + \
        " pixels high, so a resolution of " + f'{result_width * result_height:,}' + " pixels."
    print(information)
    try:
        print("Initial resolution: " + f'{input_width * input_height:,}')
        print("Final resolution: " + f'{result_width * result_height:,}')
        print("edm_steps: " + str(edm_steps))
        print("num_samples: " + str(num_samples))
        print("downscale: " + str(downscale))
        print("Estimated minutes: " + f'{(((result_width * result_height**(1/1.75)) * input_width * input_height * (edm_steps**(1/2)) * (num_samples**(1/2.5)))**(1/2.5)) / 25000:,}')
    except Exception as e:
        print('Exception of Estimation')

    # Only one image can be shown in the slider
    return [noisy_image] + [results[0]], gr.update(label="Downloadable results in *." + output_format + " format", format = output_format, value = results), gr.update(value = information, visible = True), gr.update(visible=True)

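# --- Illustrative aside (not part of the original app.py diff) ---
# A minimal sketch that isolates the empirical runtime heuristic logged above as
# "Estimated minutes". The exponents and the 25000 divisor are copied from that
# logging line; they are an undocumented empirical fit, not a guaranteed formula.
def estimate_minutes(input_width, input_height, result_width, result_height, edm_steps, num_samples):
    # Work grows with output size, input area, step count and sample count, each damped
    # by a fractional exponent, then compressed again by the outer 1/2.5 root.
    load = (result_width * result_height ** (1 / 1.75)) \
        * input_width * input_height \
        * edm_steps ** (1 / 2) \
        * num_samples ** (1 / 2.5)
    return load ** (1 / 2.5) / 25000
# Example: estimate_minutes(1024, 1024, 2048, 2048, 200, 1) gives a rough wall-clock guess in minutes.
# --- End of aside ---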
def load_and_reset(param_setting):
    print('load_and_reset ==>>')
    if torch.cuda.device_count() == 0:
        gr.Warning('Set this space to GPU config to make it work.')
        return None, None, None, None, None, None, None, None, None, None, None, None, None, None
    edm_steps = default_setting.edm_steps
    s_stage2 = 1.0
    s_stage1 = -1.0
    s_churn = 5
    s_noise = 1.003
    a_prompt = 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - ' \
               'realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore ' \
               'detailing, hyper sharpness, perfect without deformations.'
    n_prompt = 'painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, ' \
               '3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, ' \
               'signature, jpeg artifacts, deformed, lowres, over-smooth'
    color_fix_type = 'Wavelet'
    spt_linear_s_stage2 = 0.0
    linear_s_stage2 = False
    linear_CFG = True
    if param_setting == "Quality":
        s_cfg = default_setting.s_cfg_Quality
        spt_linear_CFG = default_setting.spt_linear_CFG_Quality
        model_select = "v0-Q"
    elif param_setting == "Fidelity":
        s_cfg = default_setting.s_cfg_Fidelity
        spt_linear_CFG = default_setting.spt_linear_CFG_Fidelity
        model_select = "v0-F"
    else:
        raise NotImplementedError
    gr.Info('The parameters are reset.')
    print('<<== load_and_reset')
    return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
           linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select

def log_information(result_gallery):
    print('log_information')
    if result_gallery is not None:
        for i, result in enumerate(result_gallery):
            print(result[0])

def on_select_result(result_slider, result_gallery, evt: gr.SelectData):
    print('on_select_result')
    if result_gallery is not None:
        for i, result in enumerate(result_gallery):
            print(result[0])
    return [result_slider[0], result_gallery[evt.index][0]]

title_html = """
<h1><center>SUPIR</center></h1>
<big><center>Upscale your images up to x10 for free, without an account or watermark, and download the result</center></big>
<center><big><big>🤸<big><big><big><big><big><big>🤸</big></big></big></big></big></big></big></big></center>

<p>This is an online demo of SUPIR, which practices model scaling for photo-realistic image restoration.
The content added by SUPIR is <b><u>imagination, not real-world information</u></b>.
SUPIR is for beauty and illustration only.
Most runs take a few minutes.
If you want to upscale AI-generated images, note that the <i>PixArt Sigma</i> space can directly generate 5984x5984 images.
Due to Gradio issues, the generated image is slightly less saturated than the original.
Please leave a <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR/discussions/new">message in the discussion</a> if you encounter issues.
You can also use <a href="https://huggingface.co/spaces/gokaygokay/AuraSR">AuraSR</a> to upscale x4.

<p><center><a href="https://arxiv.org/abs/2401.13627">Paper</a>   <a href="http://supir.xpixel.group/">Project Page</a>   <a href="https://huggingface.co/blog/MonsterMMORPG/supir-sota-image-upscale-better-than-magnific-ai">Local Install Guide</a></center></p>
<p><center><a style="display:inline-block" href='https://github.com/Fanghua-Yu/SUPIR'><img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/Fanghua-Yu/SUPIR?style=social"></a></center></p>
"""


claim_md = """
## **Privacy**
The images are not stored, but the logs are kept for one month.
## **How to get SUPIR**
You can get SUPIR on HuggingFace by [duplicating this space](https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR?duplicate=true) and setting a GPU.
You can also install SUPIR on your computer following [this tutorial](https://huggingface.co/blog/MonsterMMORPG/supir-sota-image-upscale-better-than-magnific-ai).
You can install _Pinokio_ on your computer and then install _SUPIR_ into it. It should be quite easy if you have an Nvidia GPU.
## **Terms of use**
By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please send us feedback if you get any inappropriate result; we will use it to keep improving our models. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
## **License**
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
"""

# Gradio interface
with gr.Blocks() as interface:
    if torch.cuda.device_count() == 0:
        with gr.Row():
            gr.HTML("""
            <p style="background-color: red;"><big><big><big><b>⚠️ To use SUPIR, <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR?duplicate=true">duplicate this space</a> and set a GPU with 30 GB VRAM.</b>

            You can't use SUPIR directly here because this space runs on a CPU, which is not enough for SUPIR. Please provide <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR/discussions/new">feedback</a> if you have issues.
            </big></big></big></p>
            """)
    gr.HTML(title_html)

    input_image = gr.Image(label="Input (*.png, *.webp, *.jpeg, *.jpg, *.gif, *.bmp, *.heic)", show_label=True, type="filepath", height=600, elem_id="image-input")
    rotation = gr.Radio([["No rotation", 0], ["⤵ Rotate +90°", 90], ["↩ Return 180°", 180], ["⤴ Rotate -90°", -90]], label="Orientation correction", info="Will apply the following rotation before restoring the image; the AI needs a good orientation to understand the content", value=0, interactive=True, visible=False)
    with gr.Group():
        prompt = gr.Textbox(label="Image description", info="Help the AI understand what the image represents; describe as much as possible, especially the details we can't see on the original image; you can write in any language", value="", placeholder="A 33-year-old man, walking, in the street, Santiago, morning, Summer, photorealistic", lines=3)
        prompt_hint = gr.HTML("You can use a <a href='"'https://huggingface.co/spaces/MaziyarPanahi/llava-llama-3-8b'"'>LLaVA space</a> to auto-generate the description of your image.")
        upscale = gr.Radio([["x1", 1], ["x2", 2], ["x3", 3], ["x4", 4], ["x5", 5], ["x6", 6], ["x7", 7], ["x8", 8], ["x9", 9], ["x10", 10]], label="Upscale factor", info="Resolution x1 to x10", value=2, interactive=True)
        output_format = gr.Radio([["As input", "input"], ["*.png", "png"], ["*.webp", "webp"], ["*.jpeg", "jpeg"], ["*.gif", "gif"], ["*.bmp", "bmp"]], label="Image format for result", info="File extension", value="input", interactive=True)
        allocation = gr.Radio([["1 min", 1], ["2 min", 2], ["3 min", 3], ["4 min", 4], ["5 min", 5], ["6 min", 6], ["7 min (discouraged)", 7], ["8 min (discouraged)", 8], ["9 min (discouraged)", 9], ["10 min (discouraged)", 10]], label="GPU allocation time", info="lower=May abort run, higher=Quota penalty for next runs", value=5, interactive=True)

    with gr.Accordion("Pre-denoising (optional)", open=False):
        gamma_correction = gr.Slider(label="Gamma Correction", info = "lower=lighter, higher=darker", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
        denoise_button = gr.Button(value="Pre-denoise")
        denoise_image = gr.Image(label="Denoised image", show_label=True, type="filepath", sources=[], interactive = False, height=600, elem_id="image-s1")
        denoise_information = gr.HTML(value="If present, the denoised image will be used for the restoration instead of the input image.", visible=False)

    with gr.Accordion("Advanced options", open=False):
        a_prompt = gr.Textbox(label="Additional image description",
                              info="Completes the main image description",
                              value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
                                    'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
                                    'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, '
                                    'hyper sharpness, perfect without deformations.',
                              lines=3)
        n_prompt = gr.Textbox(label="Negative image description",
                              info="Disambiguate by listing what the image does NOT represent",
                              value='painting, oil painting, illustration, drawing, art, sketch, anime, '
                                    'cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, '
                                    'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
                                    'deformed, lowres, over-smooth',
                              lines=3)
        edm_steps = gr.Slider(label="Steps", info="lower=faster, higher=more details", minimum=1, maximum=200, value=default_setting.edm_steps if torch.cuda.device_count() > 0 else 1, step=1)
        num_samples = gr.Slider(label="Num Samples", info="Number of generated results", minimum=1, maximum=4 if not args.use_image_slider else 1,
                                value=1, step=1)
        min_size = gr.Slider(label="Minimum size", info="Minimum height, minimum width of the result", minimum=32, maximum=4096, value=1024, step=32)
        downscale = gr.Radio([["/1", 1], ["/2", 2], ["/3", 3], ["/4", 4], ["/5", 5], ["/6", 6], ["/7", 7], ["/8", 8], ["/9", 9], ["/10", 10]], label="Pre-downscale factor", info="Downscaling a blurred image reduces the process time", value=1, interactive=True)
        with gr.Row():
            with gr.Column():
                model_select = gr.Radio([["💃 Quality (v0-Q)", "v0-Q"], ["🎯 Fidelity (v0-F)", "v0-F"]], label="Model Selection", info="Pretrained model", value="v0-Q",
                                        interactive=True)
            with gr.Column():
                color_fix_type = gr.Radio([["None", "None"], ["AdaIn (improve as a photo)", "AdaIn"], ["Wavelet (for JPEG artifacts)", "Wavelet"]], label="Color-Fix Type", info="AdaIn=Improve following a style, Wavelet=For JPEG artifacts", value="AdaIn",
                                          interactive=True)
        s_cfg = gr.Slider(label="Text Guidance Scale", info="lower=follow the image, higher=follow the prompt", minimum=1.0, maximum=15.0,
                          value=default_setting.s_cfg_Quality if torch.cuda.device_count() > 0 else 1.0, step=0.1)
        s_stage2 = gr.Slider(label="Restoring Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
        s_stage1 = gr.Slider(label="Pre-denoising Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0, step=1.0)
        s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
        s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
        with gr.Row():
            with gr.Column():
                linear_CFG = gr.Checkbox(label="Linear CFG", value=True)
                spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
                                           maximum=9.0, value=default_setting.spt_linear_CFG_Quality if torch.cuda.device_count() > 0 else 1.0, step=0.5)
            with gr.Column():
                linear_s_stage2 = gr.Checkbox(label="Linear Restoring Guidance", value=False)
                spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
                                                maximum=1., value=0., step=0.05)
            with gr.Column():
                diff_dtype = gr.Radio([["fp32 (precision)", "fp32"], ["fp16 (medium)", "fp16"], ["bf16 (speed)", "bf16"]], label="Diffusion Data Type", value="fp32",
                                      interactive=True)
            with gr.Column():
                ae_dtype = gr.Radio([["fp32 (precision)", "fp32"], ["bf16 (speed)", "bf16"]], label="Auto-Encoder Data Type", value="fp32",
                                    interactive=True)
        randomize_seed = gr.Checkbox(label = "\U0001F3B2 Randomize seed", value = True, info = "If checked, result is always different")
        seed = gr.Slider(label="Seed", minimum=0, maximum=max_64_bit_int, step=1, randomize=True)
        with gr.Group():
            param_setting = gr.Radio(["Quality", "Fidelity"], interactive=True, label="Presetting", value = "Quality")
            restart_button = gr.Button(value="Apply presetting")

    with gr.Column():
        diffusion_button = gr.Button(value="🚀 Upscale/Restore", variant = "primary", elem_id = "process_button")
        reset_btn = gr.Button(value="🧹 Reinit page", variant="stop", elem_id="reset_button", visible = False)

    restore_information = gr.HTML(value = "Restart the process to get another result.", visible = False)
    result_slider = ImageSlider(label = 'Comparator', show_label = False, interactive = False, elem_id = "slider1", show_download_button = False)
    result_gallery = gr.Gallery(label = 'Downloadable results', show_label = True, interactive = False, elem_id = "gallery1")

    gr.Examples(
        examples = [
            [
                "./Examples/Example1.png",
                0,
                None,
                "Group of people, walking, happy, in the street, photorealistic, 8k, extremely detailed",
                "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
                "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
                2,
                1024,
                1,
                8,
                200,
                -1,
                1,
                7.5,
                False,
                42,
                5,
                1.003,
                "AdaIn",
                "fp16",
                "bf16",
                1.0,
                True,
                4,
                False,
                0.,
                "v0-Q",
                "input",
                5
            ],
            [
                "./Examples/Example2.jpeg",
                0,
                None,
                "La cabeza de un gato atigrado, en una casa, fotorrealista, 8k, extremadamente detallada",
                "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
                "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
                1,
                1024,
                1,
                1,
                200,
                -1,
                1,
                7.5,
                False,
                42,
                5,
                1.003,
                "Wavelet",
                "fp16",
                "bf16",
                1.0,
                True,
                4,
                False,
                0.,
                "v0-Q",
                "input",
                4
            ],
            [
                "./Examples/Example3.webp",
                0,
                None,
                "A red apple",
                "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
                "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
                1,
                1024,
                1,
                1,
                200,
                -1,
                1,
                7.5,
                False,
                42,
                5,
                1.003,
                "Wavelet",
                "fp16",
                "bf16",
                1.0,
                True,
                4,
                False,
                0.,
                "v0-Q",
                "input",
                4
            ],
            [
                "./Examples/Example3.webp",
                0,
                None,
                "A red marble",
                "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
                "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
                1,
                1024,
                1,
                1,
                200,
                -1,
                1,
                7.5,
                False,
                42,
                5,
                1.003,
                "Wavelet",
                "fp16",
                "bf16",
                1.0,
                True,
                4,
                False,
                0.,
                "v0-Q",
                "input",
                4
            ],
        ],
        run_on_click = True,
        fn = stage2_process,
        inputs = [
            input_image,
            rotation,
            denoise_image,
            prompt,
            a_prompt,
            n_prompt,
            num_samples,
            min_size,
            downscale,
            upscale,
            edm_steps,
            s_stage1,
            s_stage2,
            s_cfg,
            randomize_seed,
            seed,
            s_churn,
            s_noise,
            color_fix_type,
            diff_dtype,
            ae_dtype,
            gamma_correction,
            linear_CFG,
            linear_s_stage2,
            spt_linear_CFG,
            spt_linear_s_stage2,
            model_select,
            output_format,
            allocation
        ],
        outputs = [
            result_slider,
            result_gallery,
            restore_information,
            reset_btn
        ],
        cache_examples = False,
    )

    with gr.Row():
        gr.Markdown(claim_md)

    input_image.upload(fn = check_upload, inputs = [
        input_image
    ], outputs = [
        rotation
    ], queue = False, show_progress = False)

    denoise_button.click(fn = check, inputs = [
        input_image
    ], outputs = [], queue = False, show_progress = False).success(fn = stage1_process, inputs = [
        input_image,
        gamma_correction,
        diff_dtype,
        ae_dtype
    ], outputs=[
        denoise_image,
        denoise_information
    ])

    diffusion_button.click(fn = update_seed, inputs = [
        randomize_seed,
        seed
    ], outputs = [
        seed
    ], queue = False, show_progress = False).then(fn = check, inputs = [
        input_image
    ], outputs = [], queue = False, show_progress = False).success(fn=stage2_process, inputs = [
        input_image,
        rotation,
        denoise_image,
        prompt,
        a_prompt,
        n_prompt,
        num_samples,
        min_size,
        downscale,
        upscale,
        edm_steps,
        s_stage1,
        s_stage2,
        s_cfg,
        randomize_seed,
        seed,
        s_churn,
        s_noise,
        color_fix_type,
        diff_dtype,
        ae_dtype,
        gamma_correction,
        linear_CFG,
        linear_s_stage2,
        spt_linear_CFG,
        spt_linear_s_stage2,
        model_select,
        output_format,
        allocation
    ], outputs = [
        result_slider,
        result_gallery,
        restore_information,
        reset_btn
    ]).success(fn = log_information, inputs = [
        result_gallery
    ], outputs = [], queue = False, show_progress = False)

    result_gallery.change(on_select_result, [result_slider, result_gallery], result_slider)
    result_gallery.select(on_select_result, [result_slider, result_gallery], result_slider)

    restart_button.click(fn = load_and_reset, inputs = [
        param_setting
    ], outputs = [
        edm_steps,
        s_cfg,
        s_stage2,
        s_stage1,
        s_churn,
        s_noise,
        a_prompt,
        n_prompt,
        color_fix_type,
        linear_CFG,
        linear_s_stage2,
        spt_linear_CFG,
        spt_linear_s_stage2,
        model_select
    ])

    reset_btn.click(fn = reset, inputs = [], outputs = [
        input_image,
        rotation,
        denoise_image,
        prompt,
        a_prompt,
        n_prompt,
        num_samples,
        min_size,
        downscale,
        upscale,
        edm_steps,
        s_stage1,
        s_stage2,
        s_cfg,
        randomize_seed,
        seed,
        s_churn,
        s_noise,
        color_fix_type,
        diff_dtype,
        ae_dtype,
        gamma_correction,
        linear_CFG,
        linear_s_stage2,
        spt_linear_CFG,
        spt_linear_s_stage2,
        model_select,
        output_format,
        allocation
    ], queue = False, show_progress = False)

interface.queue(10).launch()
configs/clip_vit_config.json
ADDED
@@ -0,0 +1,25 @@
{
  "_name_or_path": "openai/clip-vit-large-patch14",
  "architectures": [
    "CLIPTextModel"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 2,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 77,
  "model_type": "clip_text_model",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "projection_dim": 768,
  "torch_dtype": "float32",
  "transformers_version": "4.22.0.dev0",
  "vocab_size": 49408
}
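As a hypothetical illustration (the path and usage below are assumptions about how this repository consumes the file, not taken from it), a local CLIP text-encoder config like the one above can be turned into a transformers model without downloading anything:

import torch
from transformers import CLIPTextConfig, CLIPTextModel

# Build the architecture from the bundled JSON; weights still have to be loaded separately.
config = CLIPTextConfig.from_json_file("configs/clip_vit_config.json")
text_encoder = CLIPTextModel(config)  # randomly initialised CLIPTextModel matching the config
print(sum(p.numel() for p in text_encoder.parameters()))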
configs/tokenizer/config.json
ADDED
@@ -0,0 +1,171 @@
{
  "_name_or_path": "clip-vit-large-patch14/",
  "architectures": [
    "CLIPModel"
  ],
  "initializer_factor": 1.0,
  "logit_scale_init_value": 2.6592,
  "model_type": "clip",
  "projection_dim": 768,
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 77,
    "min_length": 0,
    "model_type": "clip_text_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 1,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 768,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.16.0.dev0",
    "use_bfloat16": false,
    "vocab_size": 49408
  },
  "text_config_dict": {
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "projection_dim": 768
  },
  "torch_dtype": "float32",
  "transformers_version": null,
  "vision_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 768,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.16.0.dev0",
    "use_bfloat16": false
  },
  "vision_config_dict": {
    "hidden_size": 1024,
    "intermediate_size": 4096,
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "patch_size": 14,
    "projection_dim": 768
  }
}
configs/tokenizer/merges.txt
ADDED
The diff for this file is too large to render. See raw diff.
configs/tokenizer/preprocessor_config.json
ADDED
@@ -0,0 +1,19 @@
{
  "crop_size": 224,
  "do_center_crop": true,
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "size": 224
}
configs/tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
{
  "bos_token": {
    "content": "<|startoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true
  },
  "unk_token": {
    "content": "<|endoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true
  },
  "pad_token": "<|endoftext|>"
}
configs/tokenizer/tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
configs/tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,34 @@
{
  "unk_token": {
    "content": "<|endoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true,
    "__type": "AddedToken"
  },
  "bos_token": {
    "content": "<|startoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true,
    "__type": "AddedToken"
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "single_word": false,
    "lstrip": false,
    "rstrip": false,
    "normalized": true,
    "__type": "AddedToken"
  },
  "pad_token": "<|endoftext|>",
  "add_prefix_space": false,
  "errors": "replace",
  "do_lower_case": true,
  "name_or_path": "openai/clip-vit-base-patch32",
  "model_max_length": 77,
  "special_tokens_map_file": "./special_tokens_map.json",
  "tokenizer_class": "CLIPTokenizer"
}
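A hypothetical usage sketch for the bundled tokenizer folder (the local path is an assumption based on this repository layout; CLIPTokenizer is the class named in tokenizer_config.json):

from transformers import CLIPTokenizer

# Load merges.txt, vocab.json and the token maps shipped above from disk.
tokenizer = CLIPTokenizer.from_pretrained("configs/tokenizer")
tokens = tokenizer(["a photo of a cat"], padding="max_length",
                   max_length=tokenizer.model_max_length, return_tensors="pt")
print(tokens["input_ids"].shape)  # (1, 77), matching max_position_embeddings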
configs/tokenizer/vocab.json
ADDED
The diff for this file is too large to render. See raw diff.
options/SUPIR_v0.yaml
ADDED
@@ -0,0 +1,140 @@
model:
  target: SUPIR.models.SUPIR_model.SUPIRModel
  params:
    ae_dtype: bf16
    diffusion_dtype: fp16
    scale_factor: 0.13025
    disable_first_stage_autocast: True
    network_wrapper: sgm.modules.diffusionmodules.wrappers.ControlWrapper

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiserWithControl
      params:
        num_idx: 1000
        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    control_stage_config:
      target: SUPIR.modules.SUPIR_v0.GLVControl
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False
        input_upscale: 1

    network_config:
      target: SUPIR.modules.SUPIR_v0.LightGLVUNet
      params:
        mode: XL-base
        project_type: ZeroSFT
        project_channel_scale: 2
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditionerWithControl
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: ~
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.RestoreEDMSampler
      params:
        num_steps: 100
        restore_cfg: 4.0
        s_churn: 0
        s_noise: 1.003
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.LinearCFG
          params:
            scale: 7.5
            scale_min: 4.0
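A minimal sketch of how a config like this is typically instantiated with the sgm utilities bundled in this repository; the exact loader in SUPIR.util may differ, and checkpoint paths (ckpt_path and the SDXL/SUPIR weights) still have to be filled in before the model is usable:

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

# OmegaConf parses the YAML; instantiate_from_config builds the class named in "target"
# with the nested "params" (here SUPIR.models.SUPIR_model.SUPIRModel).
config = OmegaConf.load("options/SUPIR_v0.yaml")
model = instantiate_from_config(config.model)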
sgm/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .models import AutoencodingEngine, DiffusionEngine
from .util import get_configs_path, instantiate_from_config

__version__ = "0.1.0"
sgm/lr_scheduler.py
ADDED
@@ -0,0 +1,135 @@
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """

    def __init__(
        self,
        warm_up_steps,
        lr_min,
        lr_max,
        lr_start,
        max_decay_steps,
        verbosity_interval=0,
    ):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.0
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (
                self.lr_max - self.lr_start
            ) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (
                self.lr_max_decay_steps - self.lr_warm_up_steps
            )
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                1 + np.cos(t * np.pi)
            )
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """

    def __init__(
        self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
    ):
        assert (
            len(warm_up_steps)
            == len(f_min)
            == len(f_max)
            == len(f_start)
            == len(cycle_lengths)
        )
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.0
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (
                self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
            )
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                1 + np.cos(t * np.pi)
            )
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
                self.cycle_lengths[cycle] - n
            ) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f
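A minimal usage sketch, assuming a plain PyTorch training loop (not taken from this repository): these schedulers return a learning-rate multiplier, which is why the docstrings say "use with a base_lr of 1.0"; the multiplier is applied by torch.optim.lr_scheduler.LambdaLR on top of the optimizer's real learning rate:

import torch
from torch.optim.lr_scheduler import LambdaLR
from sgm.lr_scheduler import LambdaWarmUpCosineScheduler

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for real model parameters
optimizer = torch.optim.AdamW(params, lr=1e-4)
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=0.01, lr_max=1.0, lr_start=0.0, max_decay_steps=10000
)
scheduler = LambdaLR(optimizer, lr_lambda=schedule)  # schedule(n) is the multiplier at step n
for step in range(5):
    optimizer.step()
    scheduler.step()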
sgm/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .autoencoder import AutoencodingEngine
from .diffusion import DiffusionEngine
sgm/models/autoencoder.py
ADDED
@@ -0,0 +1,335 @@
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union

import pytorch_lightning as pl
import torch
from omegaconf import ListConfig
from packaging import version
from safetensors.torch import load_file as load_safetensors

from ..modules.diffusionmodules.model import Decoder, Encoder
from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.ema import LitEma
from ..util import default, get_obj_from_str, instantiate_from_config


class AbstractAutoencoder(pl.LightningModule):
    """
    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
    unCLIP models, etc. Hence, it is fairly general, and specific features
    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
        ckpt_path: Union[None, str] = None,
        ignore_keys: Union[Tuple, list, ListConfig] = (),
    ):
        super().__init__()
        self.input_key = input_key
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            self.automatic_optimization = False

    def init_from_ckpt(
        self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
    ) -> None:
        if path.endswith("ckpt"):
            sd = torch.load(path, map_location="cpu")["state_dict"]
        elif path.endswith("safetensors"):
            sd = load_safetensors(path)
        else:
            raise NotImplementedError

        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if re.match(ik, k):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    @abstractmethod
    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    @abstractmethod
    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    @abstractmethod
    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        print(f"loading >>> {cfg['target']} <<< optimizer from config")
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()


class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        loss_config: Dict,
        regularizer_config: Dict,
        optimizer_config: Union[Dict, None] = None,
        lr_g_factor: float = 1.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # todo: add options to freeze encoder/decoder
        self.encoder = instantiate_from_config(encoder_config)
        self.decoder = instantiate_from_config(decoder_config)
        self.loss = instantiate_from_config(loss_config)
        self.regularization = instantiate_from_config(regularizer_config)
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.Adam"}
        )
        self.lr_g_factor = lr_g_factor

    def get_input(self, batch: Dict) -> torch.Tensor:
        # assuming unified data format, dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
        return batch[self.input_key]

    def get_autoencoder_params(self) -> list:
        params = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.regularization.get_trainable_parameters())
            + list(self.loss.get_trainable_autoencoder_parameters())
        )
        return params

    def get_discriminator_params(self) -> list:
        params = list(self.loss.get_trainable_parameters())  # e.g., discriminator
        return params

    def get_last_layer(self):
        return self.decoder.get_last_layer()

    def encode(self, x: Any, return_reg_log: bool = False) -> Any:
        z = self.encoder(x)
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: Any) -> torch.Tensor:
        x = self.decoder(z)
        return x

    def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z)
        return z, dec, reg_log

    def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
        x = self.get_input(batch)
        z, xrec, regularization_log = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(
                regularization_log,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )

            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(
                regularization_log,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return discloss

    def validation_step(self, batch, batch_idx) -> Dict:
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
            log_dict.update(log_dict_ema)
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
        x = self.get_input(batch)

        z, xrec, regularization_log = self(x)
        aeloss, log_dict_ae = self.loss(
            regularization_log,
            x,
            xrec,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        discloss, log_dict_disc = self.loss(
            regularization_log,
            x,
            xrec,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )
        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        log_dict_ae.update(log_dict_disc)
        self.log_dict(log_dict_ae)
        return log_dict_ae

    def configure_optimizers(self) -> Any:
        ae_params = self.get_autoencoder_params()
        disc_params = self.get_discriminator_params()

        opt_ae = self.instantiate_optimizer_from_config(
            ae_params,
            default(self.lr_g_factor, 1.0) * self.learning_rate,
            self.optimizer_config,
        )
        opt_disc = self.instantiate_optimizer_from_config(
            disc_params, self.learning_rate, self.optimizer_config
        )

        return [opt_ae, opt_disc], []

    @torch.no_grad()
    def log_images(self, batch: Dict, **kwargs) -> Dict:
        log = dict()
        x = self.get_input(batch)
        _, xrec, _ = self(x)
        log["inputs"] = x
        log["reconstructions"] = xrec
        with self.ema_scope():
            _, xrec_ema, _ = self(x)
            log["reconstructions_ema"] = xrec_ema
        return log


class AutoencoderKL(AutoencodingEngine):
    def __init__(self, embed_dim: int, **kwargs):
        ddconfig = kwargs.pop("ddconfig")
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", ())
        super().__init__(
            encoder_config={"target": "torch.nn.Identity"},
            decoder_config={"target": "torch.nn.Identity"},
            regularizer_config={"target": "torch.nn.Identity"},
            loss_config=kwargs.pop("lossconfig"),
            **kwargs,
        )
        assert ddconfig["double_z"]
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def encode(self, x):
        assert (
            not self.training
        ), f"{self.__class__.__name__} only supports inference currently"
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, **decoder_kwargs):
        z = self.post_quant_conv(z)
        dec = self.decoder(z, **decoder_kwargs)
        return dec


class AutoencoderKLInferenceWrapper(AutoencoderKL):
    def encode(self, x):
        return super().encode(x).sample()


class IdentityFirstStage(AbstractAutoencoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_input(self, x: Any) -> Any:
        return x

    def encode(self, x: Any, *args, **kwargs) -> Any:
        return x

    def decode(self, x: Any, *args, **kwargs) -> Any:
        return x
sgm/models/diffusion.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import contextmanager
|
2 |
+
from typing import Any, Dict, List, Tuple, Union
|
3 |
+
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
import torch
|
6 |
+
from omegaconf import ListConfig, OmegaConf
|
7 |
+
from safetensors.torch import load_file as load_safetensors
|
8 |
+
from torch.optim.lr_scheduler import LambdaLR
|
9 |
+
|
10 |
+
from ..modules import UNCONDITIONAL_CONFIG
|
11 |
+
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
12 |
+
from ..modules.ema import LitEma
|
13 |
+
from ..util import (
|
14 |
+
default,
|
15 |
+
disabled_train,
|
16 |
+
get_obj_from_str,
|
17 |
+
instantiate_from_config,
|
18 |
+
log_txt_as_img,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
class DiffusionEngine(pl.LightningModule):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
network_config,
|
26 |
+
denoiser_config,
|
27 |
+
first_stage_config,
|
28 |
+
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
29 |
+
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
30 |
+
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
31 |
+
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
32 |
+
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
33 |
+
network_wrapper: Union[None, str] = None,
|
34 |
+
ckpt_path: Union[None, str] = None,
|
35 |
+
use_ema: bool = False,
|
36 |
+
ema_decay_rate: float = 0.9999,
|
37 |
+
scale_factor: float = 1.0,
|
38 |
+
disable_first_stage_autocast=False,
|
39 |
+
input_key: str = "jpg",
|
40 |
+
log_keys: Union[List, None] = None,
|
41 |
+
no_cond_log: bool = False,
|
42 |
+
compile_model: bool = False,
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
self.log_keys = log_keys
|
46 |
+
self.input_key = input_key
|
47 |
+
self.optimizer_config = default(
|
48 |
+
optimizer_config, {"target": "torch.optim.AdamW"}
|
49 |
+
)
|
50 |
+
model = instantiate_from_config(network_config)
|
51 |
+
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
52 |
+
model, compile_model=compile_model
|
53 |
+
)
|
54 |
+
|
55 |
+
self.denoiser = instantiate_from_config(denoiser_config)
|
56 |
+
self.sampler = (
|
57 |
+
instantiate_from_config(sampler_config)
|
58 |
+
if sampler_config is not None
|
59 |
+
else None
|
60 |
+
)
|
61 |
+
self.conditioner = instantiate_from_config(
|
62 |
+
default(conditioner_config, UNCONDITIONAL_CONFIG)
|
63 |
+
)
|
64 |
+
self.scheduler_config = scheduler_config
|
65 |
+
self._init_first_stage(first_stage_config)
|
66 |
+
|
67 |
+
self.loss_fn = (
|
68 |
+
instantiate_from_config(loss_fn_config)
|
69 |
+
if loss_fn_config is not None
|
70 |
+
else None
|
71 |
+
)
|
72 |
+
|
73 |
+
self.use_ema = use_ema
|
74 |
+
if self.use_ema:
|
75 |
+
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
76 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
77 |
+
|
78 |
+
self.scale_factor = scale_factor
|
79 |
+
self.disable_first_stage_autocast = disable_first_stage_autocast
|
80 |
+
self.no_cond_log = no_cond_log
|
81 |
+
|
82 |
+
if ckpt_path is not None:
|
83 |
+
self.init_from_ckpt(ckpt_path)
|
84 |
+
|
85 |
+
def init_from_ckpt(
|
86 |
+
self,
|
87 |
+
path: str,
|
88 |
+
) -> None:
|
89 |
+
if path.endswith("ckpt"):
|
90 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
91 |
+
elif path.endswith("safetensors"):
|
92 |
+
sd = load_safetensors(path)
|
93 |
+
else:
|
94 |
+
raise NotImplementedError
|
95 |
+
|
96 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
97 |
+
print(
|
98 |
+
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
99 |
+
)
|
100 |
+
if len(missing) > 0:
|
101 |
+
print(f"Missing Keys: {missing}")
|
102 |
+
if len(unexpected) > 0:
|
103 |
+
print(f"Unexpected Keys: {unexpected}")
|
104 |
+
|
105 |
+
def _init_first_stage(self, config):
|
106 |
+
model = instantiate_from_config(config).eval()
|
107 |
+
model.train = disabled_train
|
108 |
+
for param in model.parameters():
|
109 |
+
param.requires_grad = False
|
110 |
+
self.first_stage_model = model
|
111 |
+
|
112 |
+
def get_input(self, batch):
|
113 |
+
# assuming unified data format, dataloader returns a dict.
|
114 |
+
# image tensors should be scaled to -1 ... 1 and in bchw format
|
115 |
+
return batch[self.input_key]
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
def decode_first_stage(self, z):
|
119 |
+
z = 1.0 / self.scale_factor * z
|
120 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
121 |
+
out = self.first_stage_model.decode(z)
|
122 |
+
return out
|
123 |
+
|
124 |
+
@torch.no_grad()
|
125 |
+
def encode_first_stage(self, x):
|
126 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
127 |
+
z = self.first_stage_model.encode(x)
|
128 |
+
z = self.scale_factor * z
|
129 |
+
return z
|
130 |
+
|
131 |
+
def forward(self, x, batch):
|
132 |
+
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
|
133 |
+
loss_mean = loss.mean()
|
134 |
+
loss_dict = {"loss": loss_mean}
|
135 |
+
return loss_mean, loss_dict
|
136 |
+
|
137 |
+
def shared_step(self, batch: Dict) -> Any:
|
138 |
+
x = self.get_input(batch)
|
139 |
+
x = self.encode_first_stage(x)
|
140 |
+
batch["global_step"] = self.global_step
|
141 |
+
loss, loss_dict = self(x, batch)
|
142 |
+
return loss, loss_dict
|
143 |
+
|
144 |
+
def training_step(self, batch, batch_idx):
|
145 |
+
loss, loss_dict = self.shared_step(batch)
|
146 |
+
|
147 |
+
self.log_dict(
|
148 |
+
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
149 |
+
)
|
150 |
+
|
151 |
+
self.log(
|
152 |
+
"global_step",
|
153 |
+
self.global_step,
|
154 |
+
prog_bar=True,
|
155 |
+
logger=True,
|
156 |
+
on_step=True,
|
157 |
+
on_epoch=False,
|
158 |
+
)
|
159 |
+
|
160 |
+
# if self.scheduler_config is not None:
|
161 |
+
lr = self.optimizers().param_groups[0]["lr"]
|
162 |
+
self.log(
|
163 |
+
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
164 |
+
)
|
165 |
+
|
166 |
+
return loss
|
167 |
+
|
168 |
+
def on_train_start(self, *args, **kwargs):
|
169 |
+
if self.sampler is None or self.loss_fn is None:
|
170 |
+
raise ValueError("Sampler and loss function need to be set for training.")
|
171 |
+
|
172 |
+
def on_train_batch_end(self, *args, **kwargs):
|
173 |
+
if self.use_ema:
|
174 |
+
self.model_ema(self.model)
|
175 |
+
|
176 |
+
@contextmanager
|
177 |
+
def ema_scope(self, context=None):
|
178 |
+
if self.use_ema:
|
179 |
+
self.model_ema.store(self.model.parameters())
|
180 |
+
self.model_ema.copy_to(self.model)
|
181 |
+
if context is not None:
|
182 |
+
print(f"{context}: Switched to EMA weights")
|
183 |
+
try:
|
184 |
+
yield None
|
185 |
+
finally:
|
186 |
+
if self.use_ema:
|
187 |
+
self.model_ema.restore(self.model.parameters())
|
188 |
+
if context is not None:
|
189 |
+
print(f"{context}: Restored training weights")
|
190 |
+
|
191 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
192 |
+
return get_obj_from_str(cfg["target"])(
|
193 |
+
params, lr=lr, **cfg.get("params", dict())
|
194 |
+
)
|
195 |
+
|
196 |
+
def configure_optimizers(self):
|
197 |
+
lr = self.learning_rate
|
198 |
+
params = list(self.model.parameters())
|
199 |
+
for embedder in self.conditioner.embedders:
|
200 |
+
if embedder.is_trainable:
|
201 |
+
params = params + list(embedder.parameters())
|
202 |
+
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
203 |
+
if self.scheduler_config is not None:
|
204 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
205 |
+
print("Setting up LambdaLR scheduler...")
|
206 |
+
scheduler = [
|
207 |
+
{
|
208 |
+
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
209 |
+
"interval": "step",
|
210 |
+
"frequency": 1,
|
211 |
+
}
|
212 |
+
]
|
213 |
+
return [opt], scheduler
|
214 |
+
return opt
|
215 |
+
|
216 |
+
@torch.no_grad()
|
217 |
+
def sample(
|
218 |
+
self,
|
219 |
+
cond: Dict,
|
220 |
+
uc: Union[Dict, None] = None,
|
221 |
+
batch_size: int = 16,
|
222 |
+
shape: Union[None, Tuple, List] = None,
|
223 |
+
**kwargs,
|
224 |
+
):
|
225 |
+
randn = torch.randn(batch_size, *shape).to(self.device)
|
226 |
+
|
227 |
+
denoiser = lambda input, sigma, c: self.denoiser(
|
228 |
+
self.model, input, sigma, c, **kwargs
|
229 |
+
)
|
230 |
+
samples = self.sampler(denoiser, randn, cond, uc=uc)
|
231 |
+
return samples
|
232 |
+
|
233 |
+
@torch.no_grad()
|
234 |
+
def log_conditionings(self, batch: Dict, n: int) -> Dict:
|
235 |
+
"""
|
236 |
+
Defines heuristics to log different conditionings.
|
237 |
+
These can be lists of strings (text-to-image), tensors, ints, ...
|
238 |
+
"""
|
239 |
+
image_h, image_w = batch[self.input_key].shape[2:]
|
240 |
+
log = dict()
|
241 |
+
|
242 |
+
for embedder in self.conditioner.embedders:
|
243 |
+
if (
|
244 |
+
(self.log_keys is None) or (embedder.input_key in self.log_keys)
|
245 |
+
) and not self.no_cond_log:
|
246 |
+
x = batch[embedder.input_key][:n]
|
247 |
+
if isinstance(x, torch.Tensor):
|
248 |
+
if x.dim() == 1:
|
249 |
+
# class-conditional, convert integer to string
|
250 |
+
x = [str(x[i].item()) for i in range(x.shape[0])]
|
251 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
|
252 |
+
elif x.dim() == 2:
|
253 |
+
# size and crop cond and the like
|
254 |
+
x = [
|
255 |
+
"x".join([str(xx) for xx in x[i].tolist()])
|
256 |
+
for i in range(x.shape[0])
|
257 |
+
]
|
258 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
259 |
+
else:
|
260 |
+
raise NotImplementedError()
|
261 |
+
elif isinstance(x, (List, ListConfig)):
|
262 |
+
if isinstance(x[0], str):
|
263 |
+
# strings
|
264 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
265 |
+
else:
|
266 |
+
raise NotImplementedError()
|
267 |
+
else:
|
268 |
+
raise NotImplementedError()
|
269 |
+
log[embedder.input_key] = xc
|
270 |
+
return log
|
271 |
+
|
272 |
+
@torch.no_grad()
|
273 |
+
def log_images(
|
274 |
+
self,
|
275 |
+
batch: Dict,
|
276 |
+
N: int = 8,
|
277 |
+
sample: bool = True,
|
278 |
+
ucg_keys: List[str] = None,
|
279 |
+
**kwargs,
|
280 |
+
) -> Dict:
|
281 |
+
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
282 |
+
if ucg_keys:
|
283 |
+
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
284 |
+
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
285 |
+
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
286 |
+
)
|
287 |
+
else:
|
288 |
+
ucg_keys = conditioner_input_keys
|
289 |
+
log = dict()
|
290 |
+
|
291 |
+
x = self.get_input(batch)
|
292 |
+
|
293 |
+
c, uc = self.conditioner.get_unconditional_conditioning(
|
294 |
+
batch,
|
295 |
+
force_uc_zero_embeddings=ucg_keys
|
296 |
+
if len(self.conditioner.embedders) > 0
|
297 |
+
else [],
|
298 |
+
)
|
299 |
+
|
300 |
+
sampling_kwargs = {}
|
301 |
+
|
302 |
+
N = min(x.shape[0], N)
|
303 |
+
x = x.to(self.device)[:N]
|
304 |
+
log["inputs"] = x
|
305 |
+
z = self.encode_first_stage(x)
|
306 |
+
log["reconstructions"] = self.decode_first_stage(z)
|
307 |
+
log.update(self.log_conditionings(batch, N))
|
308 |
+
|
309 |
+
for k in c:
|
310 |
+
if isinstance(c[k], torch.Tensor):
|
311 |
+
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
312 |
+
|
313 |
+
if sample:
|
314 |
+
with self.ema_scope("Plotting"):
|
315 |
+
samples = self.sample(
|
316 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
317 |
+
)
|
318 |
+
samples = self.decode_first_stage(samples)
|
319 |
+
log["samples"] = samples
|
320 |
+
return log
|
sgm/modules/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .encoders.modules import GeneralConditioner
|
2 |
+
from .encoders.modules import GeneralConditionerWithControl
|
3 |
+
from .encoders.modules import PreparedConditioner
|
4 |
+
|
5 |
+
UNCONDITIONAL_CONFIG = {
|
6 |
+
"target": "sgm.modules.GeneralConditioner",
|
7 |
+
"params": {"emb_models": []},
|
8 |
+
}
|
sgm/modules/attention.py
ADDED
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from inspect import isfunction
|
3 |
+
from typing import Any, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
# from einops._torch_specific import allow_ops_in_compiled_graph
|
8 |
+
# allow_ops_in_compiled_graph()
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from packaging import version
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
14 |
+
SDP_IS_AVAILABLE = True
|
15 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
16 |
+
|
17 |
+
BACKEND_MAP = {
|
18 |
+
SDPBackend.MATH: {
|
19 |
+
"enable_math": True,
|
20 |
+
"enable_flash": False,
|
21 |
+
"enable_mem_efficient": False,
|
22 |
+
},
|
23 |
+
SDPBackend.FLASH_ATTENTION: {
|
24 |
+
"enable_math": False,
|
25 |
+
"enable_flash": True,
|
26 |
+
"enable_mem_efficient": False,
|
27 |
+
},
|
28 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
29 |
+
"enable_math": False,
|
30 |
+
"enable_flash": False,
|
31 |
+
"enable_mem_efficient": True,
|
32 |
+
},
|
33 |
+
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
34 |
+
}
|
35 |
+
else:
|
36 |
+
from contextlib import nullcontext
|
37 |
+
|
38 |
+
SDP_IS_AVAILABLE = False
|
39 |
+
sdp_kernel = nullcontext
|
40 |
+
BACKEND_MAP = {}
|
41 |
+
print(
|
42 |
+
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
|
43 |
+
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
|
44 |
+
)
|
45 |
+
|
46 |
+
try:
|
47 |
+
import xformers
|
48 |
+
import xformers.ops
|
49 |
+
|
50 |
+
XFORMERS_IS_AVAILABLE = True
|
51 |
+
except:
|
52 |
+
XFORMERS_IS_AVAILABLE = False
|
53 |
+
print("no module 'xformers'. Processing without...")
|
54 |
+
|
55 |
+
from .diffusionmodules.util import checkpoint
|
56 |
+
|
57 |
+
|
58 |
+
def exists(val):
|
59 |
+
return val is not None
|
60 |
+
|
61 |
+
|
62 |
+
def uniq(arr):
|
63 |
+
return {el: True for el in arr}.keys()
|
64 |
+
|
65 |
+
|
66 |
+
def default(val, d):
|
67 |
+
if exists(val):
|
68 |
+
return val
|
69 |
+
return d() if isfunction(d) else d
|
70 |
+
|
71 |
+
|
72 |
+
def max_neg_value(t):
|
73 |
+
return -torch.finfo(t.dtype).max
|
74 |
+
|
75 |
+
|
76 |
+
def init_(tensor):
|
77 |
+
dim = tensor.shape[-1]
|
78 |
+
std = 1 / math.sqrt(dim)
|
79 |
+
tensor.uniform_(-std, std)
|
80 |
+
return tensor
|
81 |
+
|
82 |
+
|
83 |
+
# feedforward
|
84 |
+
class GEGLU(nn.Module):
|
85 |
+
def __init__(self, dim_in, dim_out):
|
86 |
+
super().__init__()
|
87 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
91 |
+
return x * F.gelu(gate)
|
92 |
+
|
93 |
+
|
94 |
+
class FeedForward(nn.Module):
|
95 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
96 |
+
super().__init__()
|
97 |
+
inner_dim = int(dim * mult)
|
98 |
+
dim_out = default(dim_out, dim)
|
99 |
+
project_in = (
|
100 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
101 |
+
if not glu
|
102 |
+
else GEGLU(dim, inner_dim)
|
103 |
+
)
|
104 |
+
|
105 |
+
self.net = nn.Sequential(
|
106 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
107 |
+
)
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
return self.net(x)
|
111 |
+
|
112 |
+
|
113 |
+
def zero_module(module):
|
114 |
+
"""
|
115 |
+
Zero out the parameters of a module and return it.
|
116 |
+
"""
|
117 |
+
for p in module.parameters():
|
118 |
+
p.detach().zero_()
|
119 |
+
return module
|
120 |
+
|
121 |
+
|
122 |
+
def Normalize(in_channels):
|
123 |
+
return torch.nn.GroupNorm(
|
124 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
125 |
+
)
|
126 |
+
|
127 |
+
|
128 |
+
class LinearAttention(nn.Module):
|
129 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
130 |
+
super().__init__()
|
131 |
+
self.heads = heads
|
132 |
+
hidden_dim = dim_head * heads
|
133 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
134 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
b, c, h, w = x.shape
|
138 |
+
qkv = self.to_qkv(x)
|
139 |
+
q, k, v = rearrange(
|
140 |
+
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
141 |
+
)
|
142 |
+
k = k.softmax(dim=-1)
|
143 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
144 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
145 |
+
out = rearrange(
|
146 |
+
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
147 |
+
)
|
148 |
+
return self.to_out(out)
|
149 |
+
|
150 |
+
|
151 |
+
class SpatialSelfAttention(nn.Module):
|
152 |
+
def __init__(self, in_channels):
|
153 |
+
super().__init__()
|
154 |
+
self.in_channels = in_channels
|
155 |
+
|
156 |
+
self.norm = Normalize(in_channels)
|
157 |
+
self.q = torch.nn.Conv2d(
|
158 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
159 |
+
)
|
160 |
+
self.k = torch.nn.Conv2d(
|
161 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
162 |
+
)
|
163 |
+
self.v = torch.nn.Conv2d(
|
164 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
165 |
+
)
|
166 |
+
self.proj_out = torch.nn.Conv2d(
|
167 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
168 |
+
)
|
169 |
+
|
170 |
+
def forward(self, x):
|
171 |
+
h_ = x
|
172 |
+
h_ = self.norm(h_)
|
173 |
+
q = self.q(h_)
|
174 |
+
k = self.k(h_)
|
175 |
+
v = self.v(h_)
|
176 |
+
|
177 |
+
# compute attention
|
178 |
+
b, c, h, w = q.shape
|
179 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
180 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
181 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
182 |
+
|
183 |
+
w_ = w_ * (int(c) ** (-0.5))
|
184 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
185 |
+
|
186 |
+
# attend to values
|
187 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
188 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
189 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
190 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
191 |
+
h_ = self.proj_out(h_)
|
192 |
+
|
193 |
+
return x + h_
|
194 |
+
|
195 |
+
|
196 |
+
class CrossAttention(nn.Module):
|
197 |
+
def __init__(
|
198 |
+
self,
|
199 |
+
query_dim,
|
200 |
+
context_dim=None,
|
201 |
+
heads=8,
|
202 |
+
dim_head=64,
|
203 |
+
dropout=0.0,
|
204 |
+
backend=None,
|
205 |
+
):
|
206 |
+
super().__init__()
|
207 |
+
inner_dim = dim_head * heads
|
208 |
+
context_dim = default(context_dim, query_dim)
|
209 |
+
|
210 |
+
self.scale = dim_head**-0.5
|
211 |
+
self.heads = heads
|
212 |
+
|
213 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
214 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
215 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
216 |
+
|
217 |
+
self.to_out = nn.Sequential(
|
218 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
219 |
+
)
|
220 |
+
self.backend = backend
|
221 |
+
|
222 |
+
def forward(
|
223 |
+
self,
|
224 |
+
x,
|
225 |
+
context=None,
|
226 |
+
mask=None,
|
227 |
+
additional_tokens=None,
|
228 |
+
n_times_crossframe_attn_in_self=0,
|
229 |
+
):
|
230 |
+
h = self.heads
|
231 |
+
|
232 |
+
if additional_tokens is not None:
|
233 |
+
# get the number of masked tokens at the beginning of the output sequence
|
234 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
235 |
+
# add additional token
|
236 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
237 |
+
|
238 |
+
q = self.to_q(x)
|
239 |
+
context = default(context, x)
|
240 |
+
k = self.to_k(context)
|
241 |
+
v = self.to_v(context)
|
242 |
+
|
243 |
+
if n_times_crossframe_attn_in_self:
|
244 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
245 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
246 |
+
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
|
247 |
+
k = repeat(
|
248 |
+
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
249 |
+
)
|
250 |
+
v = repeat(
|
251 |
+
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
252 |
+
)
|
253 |
+
|
254 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
255 |
+
|
256 |
+
## old
|
257 |
+
"""
|
258 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
259 |
+
del q, k
|
260 |
+
|
261 |
+
if exists(mask):
|
262 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
263 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
264 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
265 |
+
sim.masked_fill_(~mask, max_neg_value)
|
266 |
+
|
267 |
+
# attention, what we cannot get enough of
|
268 |
+
sim = sim.softmax(dim=-1)
|
269 |
+
|
270 |
+
out = einsum('b i j, b j d -> b i d', sim, v)
|
271 |
+
"""
|
272 |
+
## new
|
273 |
+
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
274 |
+
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
275 |
+
out = F.scaled_dot_product_attention(
|
276 |
+
q, k, v, attn_mask=mask
|
277 |
+
) # scale is dim_head ** -0.5 per default
|
278 |
+
|
279 |
+
del q, k, v
|
280 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
281 |
+
|
282 |
+
if additional_tokens is not None:
|
283 |
+
# remove additional token
|
284 |
+
out = out[:, n_tokens_to_mask:]
|
285 |
+
return self.to_out(out)
|
286 |
+
|
287 |
+
|
288 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
289 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
290 |
+
def __init__(
|
291 |
+
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
292 |
+
):
|
293 |
+
super().__init__()
|
294 |
+
# print(
|
295 |
+
# f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
296 |
+
# f"{heads} heads with a dimension of {dim_head}."
|
297 |
+
# )
|
298 |
+
inner_dim = dim_head * heads
|
299 |
+
context_dim = default(context_dim, query_dim)
|
300 |
+
|
301 |
+
self.heads = heads
|
302 |
+
self.dim_head = dim_head
|
303 |
+
|
304 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
305 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
306 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
307 |
+
|
308 |
+
self.to_out = nn.Sequential(
|
309 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
310 |
+
)
|
311 |
+
self.attention_op: Optional[Any] = None
|
312 |
+
|
313 |
+
def forward(
|
314 |
+
self,
|
315 |
+
x,
|
316 |
+
context=None,
|
317 |
+
mask=None,
|
318 |
+
additional_tokens=None,
|
319 |
+
n_times_crossframe_attn_in_self=0,
|
320 |
+
):
|
321 |
+
if additional_tokens is not None:
|
322 |
+
# get the number of masked tokens at the beginning of the output sequence
|
323 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
324 |
+
# add additional token
|
325 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
326 |
+
q = self.to_q(x)
|
327 |
+
context = default(context, x)
|
328 |
+
k = self.to_k(context)
|
329 |
+
v = self.to_v(context)
|
330 |
+
|
331 |
+
if n_times_crossframe_attn_in_self:
|
332 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
333 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
334 |
+
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
|
335 |
+
k = repeat(
|
336 |
+
k[::n_times_crossframe_attn_in_self],
|
337 |
+
"b ... -> (b n) ...",
|
338 |
+
n=n_times_crossframe_attn_in_self,
|
339 |
+
)
|
340 |
+
v = repeat(
|
341 |
+
v[::n_times_crossframe_attn_in_self],
|
342 |
+
"b ... -> (b n) ...",
|
343 |
+
n=n_times_crossframe_attn_in_self,
|
344 |
+
)
|
345 |
+
|
346 |
+
b, _, _ = q.shape
|
347 |
+
q, k, v = map(
|
348 |
+
lambda t: t.unsqueeze(3)
|
349 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
350 |
+
.permute(0, 2, 1, 3)
|
351 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
352 |
+
.contiguous(),
|
353 |
+
(q, k, v),
|
354 |
+
)
|
355 |
+
|
356 |
+
# actually compute the attention, what we cannot get enough of
|
357 |
+
out = xformers.ops.memory_efficient_attention(
|
358 |
+
q, k, v, attn_bias=None, op=self.attention_op
|
359 |
+
)
|
360 |
+
|
361 |
+
# TODO: Use this directly in the attention operation, as a bias
|
362 |
+
if exists(mask):
|
363 |
+
raise NotImplementedError
|
364 |
+
out = (
|
365 |
+
out.unsqueeze(0)
|
366 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
367 |
+
.permute(0, 2, 1, 3)
|
368 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
369 |
+
)
|
370 |
+
if additional_tokens is not None:
|
371 |
+
# remove additional token
|
372 |
+
out = out[:, n_tokens_to_mask:]
|
373 |
+
return self.to_out(out)
|
374 |
+
|
375 |
+
|
376 |
+
class BasicTransformerBlock(nn.Module):
|
377 |
+
ATTENTION_MODES = {
|
378 |
+
"softmax": CrossAttention, # vanilla attention
|
379 |
+
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
|
380 |
+
}
|
381 |
+
|
382 |
+
def __init__(
|
383 |
+
self,
|
384 |
+
dim,
|
385 |
+
n_heads,
|
386 |
+
d_head,
|
387 |
+
dropout=0.0,
|
388 |
+
context_dim=None,
|
389 |
+
gated_ff=True,
|
390 |
+
checkpoint=False,
|
391 |
+
disable_self_attn=False,
|
392 |
+
attn_mode="softmax",
|
393 |
+
sdp_backend=None,
|
394 |
+
):
|
395 |
+
super().__init__()
|
396 |
+
checkpoint = False
|
397 |
+
|
398 |
+
assert attn_mode in self.ATTENTION_MODES
|
399 |
+
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
400 |
+
print(
|
401 |
+
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
|
402 |
+
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
|
403 |
+
)
|
404 |
+
attn_mode = "softmax"
|
405 |
+
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
406 |
+
print(
|
407 |
+
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
|
408 |
+
)
|
409 |
+
if not XFORMERS_IS_AVAILABLE:
|
410 |
+
assert (
|
411 |
+
False
|
412 |
+
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
413 |
+
else:
|
414 |
+
print("Falling back to xformers efficient attention.")
|
415 |
+
attn_mode = "softmax-xformers"
|
416 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
417 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
418 |
+
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
|
419 |
+
else:
|
420 |
+
assert sdp_backend is None
|
421 |
+
self.disable_self_attn = disable_self_attn
|
422 |
+
self.attn1 = attn_cls(
|
423 |
+
query_dim=dim,
|
424 |
+
heads=n_heads,
|
425 |
+
dim_head=d_head,
|
426 |
+
dropout=dropout,
|
427 |
+
context_dim=context_dim if self.disable_self_attn else None,
|
428 |
+
backend=sdp_backend,
|
429 |
+
) # is a self-attention if not self.disable_self_attn
|
430 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
431 |
+
self.attn2 = attn_cls(
|
432 |
+
query_dim=dim,
|
433 |
+
context_dim=context_dim,
|
434 |
+
heads=n_heads,
|
435 |
+
dim_head=d_head,
|
436 |
+
dropout=dropout,
|
437 |
+
backend=sdp_backend,
|
438 |
+
) # is self-attn if context is none
|
439 |
+
self.norm1 = nn.LayerNorm(dim)
|
440 |
+
self.norm2 = nn.LayerNorm(dim)
|
441 |
+
self.norm3 = nn.LayerNorm(dim)
|
442 |
+
self.checkpoint = checkpoint
|
443 |
+
if self.checkpoint:
|
444 |
+
print(f"{self.__class__.__name__} is using checkpointing")
|
445 |
+
|
446 |
+
def forward(
|
447 |
+
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
448 |
+
):
|
449 |
+
kwargs = {"x": x}
|
450 |
+
|
451 |
+
if context is not None:
|
452 |
+
kwargs.update({"context": context})
|
453 |
+
|
454 |
+
if additional_tokens is not None:
|
455 |
+
kwargs.update({"additional_tokens": additional_tokens})
|
456 |
+
|
457 |
+
if n_times_crossframe_attn_in_self:
|
458 |
+
kwargs.update(
|
459 |
+
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
|
460 |
+
)
|
461 |
+
|
462 |
+
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
463 |
+
return checkpoint(
|
464 |
+
self._forward, (x, context), self.parameters(), self.checkpoint
|
465 |
+
)
|
466 |
+
|
467 |
+
def _forward(
|
468 |
+
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
469 |
+
):
|
470 |
+
x = (
|
471 |
+
self.attn1(
|
472 |
+
self.norm1(x),
|
473 |
+
context=context if self.disable_self_attn else None,
|
474 |
+
additional_tokens=additional_tokens,
|
475 |
+
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
476 |
+
if not self.disable_self_attn
|
477 |
+
else 0,
|
478 |
+
)
|
479 |
+
+ x
|
480 |
+
)
|
481 |
+
x = (
|
482 |
+
self.attn2(
|
483 |
+
self.norm2(x), context=context, additional_tokens=additional_tokens
|
484 |
+
)
|
485 |
+
+ x
|
486 |
+
)
|
487 |
+
x = self.ff(self.norm3(x)) + x
|
488 |
+
return x
|
489 |
+
|
490 |
+
|
491 |
+
class BasicTransformerSingleLayerBlock(nn.Module):
|
492 |
+
ATTENTION_MODES = {
|
493 |
+
"softmax": CrossAttention, # vanilla attention
|
494 |
+
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
|
495 |
+
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
496 |
+
}
|
497 |
+
|
498 |
+
def __init__(
|
499 |
+
self,
|
500 |
+
dim,
|
501 |
+
n_heads,
|
502 |
+
d_head,
|
503 |
+
dropout=0.0,
|
504 |
+
context_dim=None,
|
505 |
+
gated_ff=True,
|
506 |
+
checkpoint=True,
|
507 |
+
attn_mode="softmax",
|
508 |
+
):
|
509 |
+
super().__init__()
|
510 |
+
assert attn_mode in self.ATTENTION_MODES
|
511 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
512 |
+
self.attn1 = attn_cls(
|
513 |
+
query_dim=dim,
|
514 |
+
heads=n_heads,
|
515 |
+
dim_head=d_head,
|
516 |
+
dropout=dropout,
|
517 |
+
context_dim=context_dim,
|
518 |
+
)
|
519 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
520 |
+
self.norm1 = nn.LayerNorm(dim)
|
521 |
+
self.norm2 = nn.LayerNorm(dim)
|
522 |
+
self.checkpoint = checkpoint
|
523 |
+
|
524 |
+
def forward(self, x, context=None):
|
525 |
+
return checkpoint(
|
526 |
+
self._forward, (x, context), self.parameters(), self.checkpoint
|
527 |
+
)
|
528 |
+
|
529 |
+
def _forward(self, x, context=None):
|
530 |
+
x = self.attn1(self.norm1(x), context=context) + x
|
531 |
+
x = self.ff(self.norm2(x)) + x
|
532 |
+
return x
|
533 |
+
|
534 |
+
|
535 |
+
class SpatialTransformer(nn.Module):
|
536 |
+
"""
|
537 |
+
Transformer block for image-like data.
|
538 |
+
First, project the input (aka embedding)
|
539 |
+
and reshape to b, t, d.
|
540 |
+
Then apply standard transformer action.
|
541 |
+
Finally, reshape to image
|
542 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
543 |
+
"""
|
544 |
+
|
545 |
+
def __init__(
|
546 |
+
self,
|
547 |
+
in_channels,
|
548 |
+
n_heads,
|
549 |
+
d_head,
|
550 |
+
depth=1,
|
551 |
+
dropout=0.0,
|
552 |
+
context_dim=None,
|
553 |
+
disable_self_attn=False,
|
554 |
+
use_linear=False,
|
555 |
+
attn_type="softmax",
|
556 |
+
use_checkpoint=True,
|
557 |
+
# sdp_backend=SDPBackend.FLASH_ATTENTION
|
558 |
+
sdp_backend=None,
|
559 |
+
):
|
560 |
+
super().__init__()
|
561 |
+
print(
|
562 |
+
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
|
563 |
+
)
|
564 |
+
from omegaconf import ListConfig
|
565 |
+
|
566 |
+
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
|
567 |
+
context_dim = [context_dim]
|
568 |
+
if exists(context_dim) and isinstance(context_dim, list):
|
569 |
+
if depth != len(context_dim):
|
570 |
+
# print(
|
571 |
+
# f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
|
572 |
+
# f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
|
573 |
+
# )
|
574 |
+
# depth does not match context dims.
|
575 |
+
assert all(
|
576 |
+
map(lambda x: x == context_dim[0], context_dim)
|
577 |
+
), "need homogenous context_dim to match depth automatically"
|
578 |
+
context_dim = depth * [context_dim[0]]
|
579 |
+
elif context_dim is None:
|
580 |
+
context_dim = [None] * depth
|
581 |
+
self.in_channels = in_channels
|
582 |
+
inner_dim = n_heads * d_head
|
583 |
+
self.norm = Normalize(in_channels)
|
584 |
+
if not use_linear:
|
585 |
+
self.proj_in = nn.Conv2d(
|
586 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
587 |
+
)
|
588 |
+
else:
|
589 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
590 |
+
|
591 |
+
self.transformer_blocks = nn.ModuleList(
|
592 |
+
[
|
593 |
+
BasicTransformerBlock(
|
594 |
+
inner_dim,
|
595 |
+
n_heads,
|
596 |
+
d_head,
|
597 |
+
dropout=dropout,
|
598 |
+
context_dim=context_dim[d],
|
599 |
+
disable_self_attn=disable_self_attn,
|
600 |
+
attn_mode=attn_type,
|
601 |
+
checkpoint=use_checkpoint,
|
602 |
+
sdp_backend=sdp_backend,
|
603 |
+
)
|
604 |
+
for d in range(depth)
|
605 |
+
]
|
606 |
+
)
|
607 |
+
if not use_linear:
|
608 |
+
self.proj_out = zero_module(
|
609 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
610 |
+
)
|
611 |
+
else:
|
612 |
+
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
613 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
614 |
+
self.use_linear = use_linear
|
615 |
+
|
616 |
+
def forward(self, x, context=None):
|
617 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
618 |
+
if not isinstance(context, list):
|
619 |
+
context = [context]
|
620 |
+
b, c, h, w = x.shape
|
621 |
+
x_in = x
|
622 |
+
x = self.norm(x)
|
623 |
+
if not self.use_linear:
|
624 |
+
x = self.proj_in(x)
|
625 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
626 |
+
if self.use_linear:
|
627 |
+
x = self.proj_in(x)
|
628 |
+
for i, block in enumerate(self.transformer_blocks):
|
629 |
+
if i > 0 and len(context) == 1:
|
630 |
+
i = 0 # use same context for each block
|
631 |
+
x = block(x, context=context[i])
|
632 |
+
if self.use_linear:
|
633 |
+
x = self.proj_out(x)
|
634 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
635 |
+
if not self.use_linear:
|
636 |
+
x = self.proj_out(x)
|
637 |
+
return x + x_in
|
sgm/modules/autoencoding/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/losses/__init__.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
from ....util import default, instantiate_from_config
|
8 |
+
from ..lpips.loss.lpips import LPIPS
|
9 |
+
from ..lpips.model.model import NLayerDiscriminator, weights_init
|
10 |
+
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
|
11 |
+
|
12 |
+
|
13 |
+
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
14 |
+
if global_step < threshold:
|
15 |
+
weight = value
|
16 |
+
return weight
|
17 |
+
|
18 |
+
|
19 |
+
class LatentLPIPS(nn.Module):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
decoder_config,
|
23 |
+
perceptual_weight=1.0,
|
24 |
+
latent_weight=1.0,
|
25 |
+
scale_input_to_tgt_size=False,
|
26 |
+
scale_tgt_to_input_size=False,
|
27 |
+
perceptual_weight_on_inputs=0.0,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
31 |
+
self.scale_tgt_to_input_size = scale_tgt_to_input_size
|
32 |
+
self.init_decoder(decoder_config)
|
33 |
+
self.perceptual_loss = LPIPS().eval()
|
34 |
+
self.perceptual_weight = perceptual_weight
|
35 |
+
self.latent_weight = latent_weight
|
36 |
+
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
|
37 |
+
|
38 |
+
def init_decoder(self, config):
|
39 |
+
self.decoder = instantiate_from_config(config)
|
40 |
+
if hasattr(self.decoder, "encoder"):
|
41 |
+
del self.decoder.encoder
|
42 |
+
|
43 |
+
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
|
44 |
+
log = dict()
|
45 |
+
loss = (latent_inputs - latent_predictions) ** 2
|
46 |
+
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
|
47 |
+
image_reconstructions = None
|
48 |
+
if self.perceptual_weight > 0.0:
|
49 |
+
image_reconstructions = self.decoder.decode(latent_predictions)
|
50 |
+
image_targets = self.decoder.decode(latent_inputs)
|
51 |
+
perceptual_loss = self.perceptual_loss(
|
52 |
+
image_targets.contiguous(), image_reconstructions.contiguous()
|
53 |
+
)
|
54 |
+
loss = (
|
55 |
+
self.latent_weight * loss.mean()
|
56 |
+
+ self.perceptual_weight * perceptual_loss.mean()
|
57 |
+
)
|
58 |
+
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
59 |
+
|
60 |
+
if self.perceptual_weight_on_inputs > 0.0:
|
61 |
+
image_reconstructions = default(
|
62 |
+
image_reconstructions, self.decoder.decode(latent_predictions)
|
63 |
+
)
|
64 |
+
if self.scale_input_to_tgt_size:
|
65 |
+
image_inputs = torch.nn.functional.interpolate(
|
66 |
+
image_inputs,
|
67 |
+
image_reconstructions.shape[2:],
|
68 |
+
mode="bicubic",
|
69 |
+
antialias=True,
|
70 |
+
)
|
71 |
+
elif self.scale_tgt_to_input_size:
|
72 |
+
image_reconstructions = torch.nn.functional.interpolate(
|
73 |
+
image_reconstructions,
|
74 |
+
image_inputs.shape[2:],
|
75 |
+
mode="bicubic",
|
76 |
+
antialias=True,
|
77 |
+
)
|
78 |
+
|
79 |
+
perceptual_loss2 = self.perceptual_loss(
|
80 |
+
image_inputs.contiguous(), image_reconstructions.contiguous()
|
81 |
+
)
|
82 |
+
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
83 |
+
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
84 |
+
return loss, log
|
85 |
+
|
86 |
+
|
87 |
+
class GeneralLPIPSWithDiscriminator(nn.Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
disc_start: int,
|
91 |
+
logvar_init: float = 0.0,
|
92 |
+
pixelloss_weight=1.0,
|
93 |
+
disc_num_layers: int = 3,
|
94 |
+
disc_in_channels: int = 3,
|
95 |
+
disc_factor: float = 1.0,
|
96 |
+
disc_weight: float = 1.0,
|
97 |
+
perceptual_weight: float = 1.0,
|
98 |
+
disc_loss: str = "hinge",
|
99 |
+
scale_input_to_tgt_size: bool = False,
|
100 |
+
dims: int = 2,
|
101 |
+
learn_logvar: bool = False,
|
102 |
+
regularization_weights: Union[None, dict] = None,
|
103 |
+
):
|
104 |
+
super().__init__()
|
105 |
+
self.dims = dims
|
106 |
+
if self.dims > 2:
|
107 |
+
print(
|
108 |
+
f"running with dims={dims}. This means that for perceptual loss calculation, "
|
109 |
+
f"the LPIPS loss will be applied to each frame independently. "
|
110 |
+
)
|
111 |
+
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
112 |
+
assert disc_loss in ["hinge", "vanilla"]
|
113 |
+
self.pixel_weight = pixelloss_weight
|
114 |
+
self.perceptual_loss = LPIPS().eval()
|
115 |
+
self.perceptual_weight = perceptual_weight
|
116 |
+
# output log variance
|
117 |
+
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
118 |
+
self.learn_logvar = learn_logvar
|
119 |
+
|
120 |
+
self.discriminator = NLayerDiscriminator(
|
121 |
+
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
|
122 |
+
).apply(weights_init)
|
123 |
+
self.discriminator_iter_start = disc_start
|
124 |
+
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
125 |
+
self.disc_factor = disc_factor
|
126 |
+
self.discriminator_weight = disc_weight
|
127 |
+
self.regularization_weights = default(regularization_weights, {})
|
128 |
+
|
129 |
+
def get_trainable_parameters(self) -> Any:
|
130 |
+
return self.discriminator.parameters()
|
131 |
+
|
132 |
+
def get_trainable_autoencoder_parameters(self) -> Any:
|
133 |
+
if self.learn_logvar:
|
134 |
+
yield self.logvar
|
135 |
+
yield from ()
|
136 |
+
|
137 |
+
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
138 |
+
if last_layer is not None:
|
139 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
140 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
141 |
+
else:
|
142 |
+
nll_grads = torch.autograd.grad(
|
143 |
+
nll_loss, self.last_layer[0], retain_graph=True
|
144 |
+
)[0]
|
145 |
+
g_grads = torch.autograd.grad(
|
146 |
+
g_loss, self.last_layer[0], retain_graph=True
|
147 |
+
)[0]
|
148 |
+
|
149 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
150 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
151 |
+
d_weight = d_weight * self.discriminator_weight
|
152 |
+
return d_weight
|
153 |
+
|
154 |
+
def forward(
|
155 |
+
self,
|
156 |
+
regularization_log,
|
157 |
+
inputs,
|
158 |
+
reconstructions,
|
159 |
+
optimizer_idx,
|
160 |
+
global_step,
|
161 |
+
last_layer=None,
|
162 |
+
split="train",
|
163 |
+
weights=None,
|
164 |
+
):
|
165 |
+
if self.scale_input_to_tgt_size:
|
166 |
+
inputs = torch.nn.functional.interpolate(
|
167 |
+
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
|
168 |
+
)
|
169 |
+
|
170 |
+
if self.dims > 2:
|
171 |
+
inputs, reconstructions = map(
|
172 |
+
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
|
173 |
+
(inputs, reconstructions),
|
174 |
+
)
|
175 |
+
|
176 |
+
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
177 |
+
if self.perceptual_weight > 0:
|
178 |
+
p_loss = self.perceptual_loss(
|
179 |
+
inputs.contiguous(), reconstructions.contiguous()
|
180 |
+
)
|
181 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
182 |
+
|
183 |
+
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
184 |
+
weighted_nll_loss = nll_loss
|
185 |
+
if weights is not None:
|
186 |
+
weighted_nll_loss = weights * nll_loss
|
187 |
+
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
188 |
+
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
189 |
+
|
190 |
+
# now the GAN part
|
191 |
+
if optimizer_idx == 0:
|
192 |
+
# generator update
|
193 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
194 |
+
g_loss = -torch.mean(logits_fake)
|
195 |
+
|
196 |
+
if self.disc_factor > 0.0:
|
197 |
+
try:
|
198 |
+
d_weight = self.calculate_adaptive_weight(
|
199 |
+
nll_loss, g_loss, last_layer=last_layer
|
200 |
+
)
|
201 |
+
except RuntimeError:
|
202 |
+
assert not self.training
|
203 |
+
d_weight = torch.tensor(0.0)
|
204 |
+
else:
|
205 |
+
d_weight = torch.tensor(0.0)
|
206 |
+
|
207 |
+
disc_factor = adopt_weight(
|
208 |
+
self.disc_factor, global_step, threshold=self.discriminator_iter_start
|
209 |
+
)
|
210 |
+
loss = weighted_nll_loss + d_weight * disc_factor * g_loss
|
211 |
+
log = dict()
|
212 |
+
for k in regularization_log:
|
213 |
+
if k in self.regularization_weights:
|
214 |
+
loss = loss + self.regularization_weights[k] * regularization_log[k]
|
215 |
+
log[f"{split}/{k}"] = regularization_log[k].detach().mean()
|
216 |
+
|
217 |
+
log.update(
|
218 |
+
{
|
219 |
+
"{}/total_loss".format(split): loss.clone().detach().mean(),
|
220 |
+
"{}/logvar".format(split): self.logvar.detach(),
|
221 |
+
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
222 |
+
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
223 |
+
"{}/d_weight".format(split): d_weight.detach(),
|
224 |
+
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
225 |
+
"{}/g_loss".format(split): g_loss.detach().mean(),
|
226 |
+
}
|
227 |
+
)
|
228 |
+
|
229 |
+
return loss, log
|
230 |
+
|
231 |
+
if optimizer_idx == 1:
|
232 |
+
# second pass for discriminator update
|
233 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
234 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
235 |
+
|
236 |
+
disc_factor = adopt_weight(
|
237 |
+
self.disc_factor, global_step, threshold=self.discriminator_iter_start
|
238 |
+
)
|
239 |
+
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
240 |
+
|
241 |
+
log = {
|
242 |
+
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
243 |
+
"{}/logits_real".format(split): logits_real.detach().mean(),
|
244 |
+
"{}/logits_fake".format(split): logits_fake.detach().mean(),
|
245 |
+
}
|
246 |
+
return d_loss, log
|
sgm/modules/autoencoding/lpips/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/lpips/loss/LICENSE
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
|
2 |
+
All rights reserved.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without
|
5 |
+
modification, are permitted provided that the following conditions are met:
|
6 |
+
|
7 |
+
* Redistributions of source code must retain the above copyright notice, this
|
8 |
+
list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
11 |
+
this list of conditions and the following disclaimer in the documentation
|
12 |
+
and/or other materials provided with the distribution.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
17 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
18 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
19 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
20 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
21 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
22 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
23 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
sgm/modules/autoencoding/lpips/loss/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/lpips/loss/lpips.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
2 |
+
|
3 |
+
from collections import namedtuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision import models
|
8 |
+
|
9 |
+
from ..util import get_ckpt_path
|
10 |
+
|
11 |
+
|
12 |
+
class LPIPS(nn.Module):
|
13 |
+
# Learned perceptual metric
|
14 |
+
def __init__(self, use_dropout=True):
|
15 |
+
super().__init__()
|
16 |
+
self.scaling_layer = ScalingLayer()
|
17 |
+
self.chns = [64, 128, 256, 512, 512] # vg16 features
|
18 |
+
self.net = vgg16(pretrained=True, requires_grad=False)
|
19 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
20 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
21 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
22 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
23 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
24 |
+
self.load_from_pretrained()
|
25 |
+
for param in self.parameters():
|
26 |
+
param.requires_grad = False
|
27 |
+
|
28 |
+
def load_from_pretrained(self, name="vgg_lpips"):
|
29 |
+
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
|
30 |
+
self.load_state_dict(
|
31 |
+
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
32 |
+
)
|
33 |
+
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
34 |
+
|
    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name)
        model.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
                outs1[kk]
            )
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [
            spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
            for kk in range(len(self.chns))
        ]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = (
            [
                nn.Dropout(),
            ]
            if (use_dropout)
            else []
        )
        layers += [
            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
        ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
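A minimal usage sketch (illustration, not part of the diff): assuming the LPIPS class defined earlier in this file keeps its upstream name and loads the pretrained VGG linear heads in its constructor, the module returns one perceptual distance per image pair, with inputs scaled to [-1, 1].

import torch
from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

lpips = LPIPS().eval()                   # assumed to fetch the vgg_lpips weights on first use
x = torch.rand(1, 3, 256, 256) * 2 - 1   # reconstruction in [-1, 1]
y = torch.rand(1, 3, 256, 256) * 2 - 1   # reference in [-1, 1]
with torch.no_grad():
    dist = lpips(x, y)                   # shape (1, 1, 1, 1); lower means perceptually closer
print(dist.squeeze().item())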
sgm/modules/autoencoding/lpips/model/LICENSE
ADDED
@@ -0,0 +1,58 @@
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


--------------------------- LICENSE FOR pix2pix --------------------------------
BSD License

For pix2pix software
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License

For dcgan.torch software

Copyright (c) 2015, Facebook, Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/model/__init__.py
ADDED
File without changes
sgm/modules/autoencoding/lpips/model/model.py
ADDED
@@ -0,0 +1,88 @@
import functools

import torch.nn as nn

from ..util import ActNorm


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if (
            type(norm_layer) == functools.partial
        ):  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(
                    ndf * nf_mult_prev,
                    ndf * nf_mult,
                    kernel_size=kw,
                    stride=2,
                    padding=padw,
                    bias=use_bias,
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(
                ndf * nf_mult_prev,
                ndf * nf_mult,
                kernel_size=kw,
                stride=1,
                padding=padw,
                bias=use_bias,
            ),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
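A brief wiring sketch (illustration only): the discriminator is usually initialized with weights_init right after construction, and its output is a patch-wise logit map rather than a single scalar.

import torch
from sgm.modules.autoencoding.lpips.model.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False).apply(weights_init)
fake = torch.randn(2, 3, 256, 256)
logits = disc(fake)          # per-patch real/fake logits, (2, 1, 30, 30) for 256x256 inputs
print(logits.shape)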
sgm/modules/autoencoding/lpips/util.py
ADDED
@@ -0,0 +1,128 @@
import hashlib
import os

import requests
import torch
import torch.nn as nn
from tqdm import tqdm

URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

CKPT_MAP = {"vgg_lpips": "vgg.pth"}

MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class ActNorm(nn.Module):
    def __init__(
        self, num_features, logdet=False, affine=True, allow_reverse_init=False
    ):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
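A small sketch of the ActNorm contract (illustration only): the first forward call in training mode initializes loc/scale from the batch statistics, after which forward and reverse are exact inverse per-channel affine maps.

import torch
from sgm.modules.autoencoding.lpips.util import ActNorm

act = ActNorm(num_features=8, logdet=True)
x = torch.randn(4, 8, 16, 16)
h, logdet = act(x)              # data-dependent initialization happens on this call
x_rec = act.reverse(h)          # undoes the affine map
print(torch.allclose(x, x_rec, atol=1e-5), logdet.shape)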
sgm/modules/autoencoding/lpips/vqperceptual.py
ADDED
@@ -0,0 +1,17 @@
import torch
import torch.nn.functional as F


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real))
        + torch.mean(torch.nn.functional.softplus(logits_fake))
    )
    return d_loss
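A tiny numeric sketch (illustration only): the hinge loss pushes real logits above +1 and fake logits below -1, so a well-separated pair already gives zero loss, while the vanilla (softplus) variant stays non-zero.

import torch
from sgm.modules.autoencoding.lpips.vqperceptual import hinge_d_loss, vanilla_d_loss

logits_real = torch.tensor([2.0, 1.5])    # discriminator outputs on real samples
logits_fake = torch.tensor([-2.0, -1.2])  # discriminator outputs on generated samples
print(hinge_d_loss(logits_real, logits_fake))    # tensor(0.) -- both margins satisfied
print(vanilla_d_loss(logits_real, logits_fake))  # small positive value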
sgm/modules/autoencoding/regularizers/__init__.py
ADDED
@@ -0,0 +1,53 @@
from abc import abstractmethod
from typing import Any, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ....modules.distributions.distributions import DiagonalGaussianDistribution


class AbstractRegularizer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        raise NotImplementedError()

    @abstractmethod
    def get_trainable_parameters(self) -> Any:
        raise NotImplementedError()


class DiagonalGaussianRegularizer(AbstractRegularizer):
    def __init__(self, sample: bool = True):
        super().__init__()
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        log = dict()
        posterior = DiagonalGaussianDistribution(z)
        if self.sample:
            z = posterior.sample()
        else:
            z = posterior.mode()
        kl_loss = posterior.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        log["kl_loss"] = kl_loss
        return z, log


def measure_perplexity(predicted_indices, num_centroids):
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
    encodings = (
        F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
    )
    avg_probs = encodings.mean(0)
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use
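A short sketch of the regularizer's contract (illustration, assuming the standard sgm DiagonalGaussianDistribution, which splits channels into mean and log-variance halves): an encoder output with 2*C channels comes back as a C-channel latent plus a KL term for the loss.

import torch
from sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer

reg = DiagonalGaussianRegularizer(sample=True)
moments = torch.randn(2, 8, 32, 32)   # 4 mean channels + 4 log-variance channels
z, log = reg(moments)
print(z.shape)             # torch.Size([2, 4, 32, 32])
print(log["kl_loss"])      # scalar KL, averaged over the batch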
sgm/modules/diffusionmodules/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .denoiser import Denoiser
from .discretizer import Discretization
from .loss import StandardDiffusionLoss
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper
sgm/modules/diffusionmodules/denoiser.py
ADDED
@@ -0,0 +1,73 @@
import torch.nn as nn

from ...util import append_dims, instantiate_from_config


class Denoiser(nn.Module):
    def __init__(self, weighting_config, scaling_config):
        super().__init__()

        self.weighting = instantiate_from_config(weighting_config)
        self.scaling = instantiate_from_config(scaling_config)

    def possibly_quantize_sigma(self, sigma):
        return sigma

    def possibly_quantize_c_noise(self, c_noise):
        return c_noise

    def w(self, sigma):
        return self.weighting(sigma)

    def __call__(self, network, input, sigma, cond):
        sigma = self.possibly_quantize_sigma(sigma)
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        return network(input * c_in, c_noise, cond) * c_out + input * c_skip


class DiscreteDenoiser(Denoiser):
    def __init__(
        self,
        weighting_config,
        scaling_config,
        num_idx,
        discretization_config,
        do_append_zero=False,
        quantize_c_noise=True,
        flip=True,
    ):
        super().__init__(weighting_config, scaling_config)
        sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip
        )
        self.register_buffer("sigmas", sigmas)
        self.quantize_c_noise = quantize_c_noise

    def sigma_to_idx(self, sigma):
        dists = sigma - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx):
        return self.sigmas[idx]

    def possibly_quantize_sigma(self, sigma):
        return self.idx_to_sigma(self.sigma_to_idx(sigma))

    def possibly_quantize_c_noise(self, c_noise):
        if self.quantize_c_noise:
            return self.sigma_to_idx(c_noise)
        else:
            return c_noise


class DiscreteDenoiserWithControl(DiscreteDenoiser):
    def __call__(self, network, input, sigma, cond, control_scale):
        sigma = self.possibly_quantize_sigma(sigma)
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        return network(input * c_in, c_noise, cond, control_scale) * c_out + input * c_skip
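The __call__ above implements the usual preconditioning D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise, cond). A toy sketch of that plumbing (illustration only, with a stand-in identity network instead of the real UNet and the scaling applied by hand rather than via config):

import torch
from sgm.modules.diffusionmodules.denoiser_scaling import EDMScaling

network = lambda x, c_noise, cond: x          # stand-in for the wrapped diffusion model
scaling = EDMScaling(sigma_data=0.5)

x = torch.randn(2, 4, 64, 64)
sigma = torch.full((2, 1, 1, 1), 1.0)
c_skip, c_out, c_in, c_noise = scaling(sigma)
denoised = network(x * c_in, c_noise.reshape(2), cond=None) * c_out + x * c_skip
print(denoised.shape)                         # torch.Size([2, 4, 64, 64])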
sgm/modules/diffusionmodules/denoiser_scaling.py
ADDED
@@ -0,0 +1,31 @@
import torch


class EDMScaling:
    def __init__(self, sigma_data=0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma):
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        c_noise = 0.25 * sigma.log()
        return c_skip, c_out, c_in, c_noise


class EpsScaling:
    def __call__(self, sigma):
        c_skip = torch.ones_like(sigma, device=sigma.device)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise


class VScaling:
    def __call__(self, sigma):
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise
sgm/modules/diffusionmodules/denoiser_weighting.py
ADDED
@@ -0,0 +1,24 @@
import torch


class UnitWeighting:
    def __call__(self, sigma):
        return torch.ones_like(sigma, device=sigma.device)


class EDMWeighting:
    def __init__(self, sigma_data=0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma):
        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2


class VWeighting(EDMWeighting):
    def __init__(self):
        super().__init__(sigma_data=1.0)


class EpsWeighting:
    def __call__(self, sigma):
        return sigma**-2
sgm/modules/diffusionmodules/discretizer.py
ADDED
@@ -0,0 +1,69 @@
from abc import abstractmethod
from functools import partial

import numpy as np
import torch

from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero


def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]


class Discretization:
    def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
        sigmas = self.get_sigmas(n, device=device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))

    @abstractmethod
    def get_sigmas(self, n, device):
        pass


class EDMDiscretization(Discretization):
    def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def get_sigmas(self, n, device="cpu"):
        ramp = torch.linspace(0, 1, n, device=device)
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
        return sigmas


class LegacyDDPMDiscretization(Discretization):
    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
    ):
        super().__init__()
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule(
            "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
        )
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

    def get_sigmas(self, n, device="cpu"):
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError

        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        return torch.flip(sigmas, (0,))
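A quick sketch of the sigma schedules these classes produce (illustration only): EDMDiscretization follows the Karras-style rho schedule from sigma_max down to sigma_min, and the base class optionally appends a trailing zero and/or flips the order.

from sgm.modules.diffusionmodules.discretizer import EDMDiscretization

disc = EDMDiscretization(sigma_min=0.02, sigma_max=80.0, rho=7.0)
sigmas = disc(10)                                          # 11 values: 10 noise levels descending, then 0.0
print(sigmas[0], sigmas[-2], sigmas[-1])                   # ~80.0 ... ~0.02, 0.0
train_sigmas = disc(10, do_append_zero=False, flip=True)   # ascending, no trailing zero
print(train_sigmas.shape)                                  # torch.Size([10])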