Kohaku-Blueleaf committed on
Commit 26d4aa7 · 1 Parent(s): 17c7547
.gitignore ADDED
@@ -0,0 +1,169 @@
1
+ hf_token.txt
2
+ hf_download/
3
+ results/
4
+ *.csv
5
+ *.onnx
6
+
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ .pybuilder/
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ # For a library or package, you might want to ignore these files since the code is
93
+ # intended to run in multiple environments; otherwise, check them in:
94
+ # .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ venv*/
136
+ ENV/
137
+ env.bak/
138
+ venv.bak/
139
+
140
+ # Spyder project settings
141
+ .spyderproject
142
+ .spyproject
143
+
144
+ # Rope project settings
145
+ .ropeproject
146
+
147
+ # mkdocs documentation
148
+ /site
149
+
150
+ # mypy
151
+ .mypy_cache/
152
+ .dmypy.json
153
+ dmypy.json
154
+
155
+ # Pyre type checker
156
+ .pyre/
157
+
158
+ # pytype static type analyzer
159
+ .pytype/
160
+
161
+ # Cython debug symbols
162
+ cython_debug/
163
+
164
+ # PyCharm
165
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
168
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169
+ .idea/
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
NOTICE ADDED
@@ -0,0 +1,15 @@
1
+ NOTICE
2
+
3
+ This repository contains files that have been copied or modified from other sources:
4
+
5
+ 1. diffusers_helper/*.py # except diffusers_helper/attention.py
6
+ Source: https://github.com/lllyasviel/Paints-UNDO/tree/main/diffusers_helper
7
+ License: Apache-2.0
8
+ Modifications: None
9
+
10
+ 2. app.py
11
+ Source: https://github.com/lllyasviel/Paints-UNDO/tree/main/gradio_app.py
12
+ License: Apache-2.0
13
+ Modifications: Modified the UI for sketch generation only. Removed the video-generation utilities. Added CLIP long-prompt concatenation utilities.
14
+
15
+ Please see the individual files for their specific licensing terms.
app.py ADDED
@@ -0,0 +1,392 @@
1
+ '''
2
+ Modified from https://github.com/lllyasviel/Paints-UNDO/blob/main/gradio_app.py
3
+ '''
4
+ import functools
5
+
6
+ import spaces
7
+ import gradio as gr
8
+ import numpy as np
9
+ import cv2
10
+ import torch
11
+
12
+ from PIL import Image
13
+ from diffusers import AutoencoderKL, UNet2DConditionModel
14
+ from diffusers.models.attention_processor import AttnProcessor2_0
15
+ from transformers import CLIPTextModel, CLIPTokenizer
16
+ from imgutils.metrics import lpips_difference
17
+ from imgutils.tagging import get_wd14_tags
18
+
19
+ from diffusers_helper.code_cond import unet_add_coded_conds
20
+ from diffusers_helper.cat_cond import unet_add_concat_conds
21
+ from diffusers_helper.k_diffusion import KDiffusionSampler
22
+ from diffusers_helper.attention import AttnProcessor2_0_xformers, XFORMERS_AVAIL
23
+
24
+ from lineart_models import MangaLineExtraction, LineartAnimeDetector, LineartDetector
25
+
26
+
27
+ def resize_and_center_crop(
28
+ image, target_width, target_height=None, interpolation=cv2.INTER_AREA
29
+ ):
30
+ original_height, original_width = image.shape[:2]
31
+ if target_height is None:
32
+ aspect_ratio = original_width / original_height
33
+ target_pixel_count = target_width * target_width
34
+ target_height = (target_pixel_count / aspect_ratio) ** 0.5
35
+ target_width = target_height * aspect_ratio
36
+ target_height = int(target_height)
37
+ target_width = int(target_width)
38
+ print(
39
+ f"original_height={original_height}, "
40
+ f"original_width={original_width}, "
41
+ f"target_height={target_height}, "
42
+ f"target_width={target_width}"
43
+ )
44
+ k = max(target_height / original_height, target_width / original_width)
45
+ new_width = int(round(original_width * k))
46
+ new_height = int(round(original_height * k))
47
+ resized_image = cv2.resize(
48
+ image, (new_width, new_height), interpolation=interpolation
49
+ )
50
+ x_start = (new_width - target_width) // 2
51
+ y_start = (new_height - target_height) // 2
52
+ cropped_image = resized_image[
53
+ y_start : y_start + target_height, x_start : x_start + target_width
54
+ ]
55
+ return cropped_image
56
+
57
+
58
+ class ModifiedUNet(UNet2DConditionModel):
59
+ @classmethod
60
+ def from_config(cls, *args, **kwargs):
61
+ m = super().from_config(*args, **kwargs)
62
+ unet_add_concat_conds(unet=m, new_channels=4)
63
+ unet_add_coded_conds(unet=m, added_number_count=1)
64
+ return m
65
+
66
+
67
+ DEVICE = "cuda"
68
+ torch._dynamo.config.cache_size_limit = 256
69
+
70
+
71
+ lineart_models = []
72
+
73
+ lineart_model = MangaLineExtraction("cuda", "./hf_download")
74
+ lineart_model.load_model()
75
+ lineart_model.model.to(device=DEVICE).eval()
76
+ lineart_models.append(lineart_model)
77
+
78
+ lineart_model = LineartAnimeDetector()
79
+ lineart_model.model.to(device=DEVICE).eval()
80
+ lineart_models.append(lineart_model)
81
+
82
+ lineart_model = LineartDetector()
83
+ lineart_model.model.to(device=DEVICE).eval()
84
+ lineart_models.append(lineart_model)
85
+
86
+
87
+ model_name = "lllyasviel/paints_undo_single_frame"
88
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
89
+ model_name, subfolder="tokenizer"
90
+ )
91
+ text_encoder: CLIPTextModel = (
92
+ CLIPTextModel.from_pretrained(
93
+ model_name,
94
+ subfolder="text_encoder",
95
+ )
96
+ .to(dtype=torch.float16, device=DEVICE)
97
+ .eval()
98
+ )
99
+ vae: AutoencoderKL = (
100
+ AutoencoderKL.from_pretrained(
101
+ model_name,
102
+ subfolder="vae",
103
+ )
104
+ .to(dtype=torch.bfloat16, device=DEVICE)
105
+ .eval()
106
+ )
107
+ unet: ModifiedUNet = (
108
+ ModifiedUNet.from_pretrained(
109
+ model_name,
110
+ subfolder="unet",
111
+ )
112
+ .to(dtype=torch.float16, device=DEVICE)
113
+ .eval()
114
+ )
115
+
116
+ if XFORMERS_AVAIL:
117
+ unet.set_attn_processor(AttnProcessor2_0_xformers())
118
+ vae.set_attn_processor(AttnProcessor2_0_xformers())
119
+ else:
120
+ unet.set_attn_processor(AttnProcessor2_0())
121
+ vae.set_attn_processor(AttnProcessor2_0())
122
+
123
+ # text_encoder = torch.compile(text_encoder, backend="eager", dynamic=True)
124
+ # vae = torch.compile(vae, backend="eager", dynamic=True)
125
+ # unet = torch.compile(unet, mode="reduce-overhead", dynamic=True)
126
+ # for model in lineart_models:
127
+ # model.model = torch.compile(model.model, backend="eager", dynamic=True)
128
+ k_sampler = KDiffusionSampler(
129
+ unet=unet,
130
+ timesteps=1000,
131
+ linear_start=0.00085,
132
+ linear_end=0.020,
133
+ linear=True,
134
+ )
135
+
136
+
137
+ @spaces.GPU
138
+ @torch.inference_mode()
139
+ def encode_cropped_prompt_77tokens(txt: str):
140
+ cond_ids = tokenizer(
141
+ txt,
142
+ padding="max_length",
143
+ max_length=tokenizer.model_max_length,
144
+ truncation=True,
145
+ return_tensors="pt",
146
+ ).input_ids.to(device=text_encoder.device)
147
+ text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
148
+ return text_cond
149
+
150
+
151
+ @spaces.GPU
152
+ @torch.inference_mode()
153
+ def encode_cropped_prompt(txt: str, max_length=225):
154
+ cond_ids = tokenizer(
155
+ txt,
156
+ padding="max_length",
157
+ max_length=max_length + 2,
158
+ truncation=True,
159
+ return_tensors="pt",
160
+ ).input_ids.to(device=text_encoder.device)
161
+ if max_length + 2 > tokenizer.model_max_length:
162
+ input_ids = cond_ids.squeeze(0)
163
+ id_list = list(range(1, max_length + 2 - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2))
164
+ text_cond_list = []
165
+ for i in id_list:
166
+ ids_chunk = (
167
+ input_ids[0].unsqueeze(0),
168
+ input_ids[i : i + tokenizer.model_max_length - 2],
169
+ input_ids[-1].unsqueeze(0),
170
+ )
171
+ if torch.all(ids_chunk[1] == tokenizer.pad_token_id):
172
+ break
173
+ text_cond = text_encoder(torch.concat(ids_chunk).unsqueeze(0)).last_hidden_state
174
+ if text_cond_list == []:
175
+ text_cond_list.append(text_cond[:, :1])
176
+ text_cond_list.append(text_cond[:, 1:tokenizer.model_max_length - 1])
177
+ text_cond_list.append(text_cond[:, -1:])
178
+ text_cond = torch.concat(text_cond_list, dim=1)
179
+ else:
180
+ text_cond = text_encoder(
181
+ cond_ids, attention_mask=None
182
+ ).last_hidden_state
183
+ return text_cond.flatten(0, 1).unsqueeze(0)
184
+
185
+
186
+ @spaces.GPU
187
+ @torch.inference_mode()
188
+ def pytorch2numpy(imgs):
189
+ results = []
190
+ for x in imgs:
191
+ y = x.movedim(0, -1)
192
+ y = y * 127.5 + 127.5
193
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
194
+ results.append(y)
195
+ return results
196
+
197
+
198
+ @spaces.GPU
199
+ @torch.inference_mode()
200
+ def numpy2pytorch(imgs):
201
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
202
+ h = h.movedim(-1, 1)
203
+ return h
204
+
205
+
206
+ @spaces.GPU
207
+ @torch.inference_mode()
208
+ def interrogator_process(x):
209
+ img = Image.fromarray(x)
210
+ rating, features, chars = get_wd14_tags(img, general_threshold=0.25, no_underline=True)
211
+ result = ""
212
+ for char in chars:
213
+ result += char
214
+ result += ", "
215
+ for feature in features:
216
+ result += feature
217
+ result += ", "
218
+ result += max(rating, key=rating.get)
219
+ return result
220
+
221
+
222
+ @spaces.GPU
223
+ @torch.inference_mode()
224
+ def process(
225
+ input_fg,
226
+ prompt,
227
+ input_undo_steps,
228
+ image_width,
229
+ seed,
230
+ steps,
231
+ n_prompt,
232
+ cfg,
233
+ num_sets,
234
+ progress=gr.Progress(),
235
+ ):
236
+ lineart_fg = input_fg
237
+ linearts = []
238
+ for model in lineart_models:
239
+ linearts.append(model(lineart_fg))
240
+ fg = resize_and_center_crop(input_fg, image_width)
241
+ for i, lineart in enumerate(linearts):
242
+ lineart = resize_and_center_crop(lineart, fg.shape[1], fg.shape[0])
243
+ linearts[i] = lineart
244
+
245
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
246
+ concat_conds = (
247
+ vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
248
+ )
249
+
250
+ conds = encode_cropped_prompt(prompt)
251
+ unconds = encode_cropped_prompt_77tokens(n_prompt)
252
+ print(conds.shape, unconds.shape)
253
+ torch.cuda.empty_cache()
254
+
255
+ fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
256
+ initial_latents = torch.zeros_like(concat_conds)
257
+ concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
258
+ latents = []
259
+ rng = torch.Generator(device=DEVICE).manual_seed(int(seed))
260
+ latents = (
261
+ k_sampler(
262
+ initial_latent=initial_latents,
263
+ strength=1.0,
264
+ num_inference_steps=steps,
265
+ guidance_scale=cfg,
266
+ batch_size=len(input_undo_steps) * num_sets,
267
+ generator=rng,
268
+ prompt_embeds=conds,
269
+ negative_prompt_embeds=unconds,
270
+ cross_attention_kwargs={
271
+ "concat_conds": concat_conds,
272
+ "coded_conds": fs,
273
+ },
274
+ same_noise_in_batch=False,
275
+ progress_tqdm=functools.partial(
276
+ progress.tqdm, desc="Generating Key Frames"
277
+ ),
278
+ ).to(vae.dtype)
279
+ / vae.config.scaling_factor
280
+ )
281
+ torch.cuda.empty_cache()
282
+
283
+ pixels = torch.concat(
284
+ [vae.decode(latent.unsqueeze(0)).sample for latent in latents]
285
+ )
286
+ pixels = pytorch2numpy(pixels)
287
+ pixels_with_lpips = []
288
+ lineart_pils = [Image.fromarray(lineart) for lineart in linearts]
289
+ for pixel in pixels:
290
+ pixel_pil = Image.fromarray(pixel)
291
+ pixels_with_lpips.append(
292
+ (
293
+ sum(
294
+ [
295
+ lpips_difference(lineart_pil, pixel_pil)
296
+ for lineart_pil in lineart_pils
297
+ ]
298
+ ),
299
+ pixel,
300
+ )
301
+ )
302
+ pixels = np.stack(
303
+ [i[1] for i in sorted(pixels_with_lpips, key=lambda x: x[0])], axis=0
304
+ )
305
+ torch.cuda.empty_cache()
306
+
307
+ return pixels, np.stack(linearts)
308
+
309
+
310
+ block = gr.Blocks().queue()
311
+ with block:
312
+ gr.Markdown("# Sketch/Lineart extractor")
313
+
314
+ with gr.Row():
315
+ with gr.Column():
316
+ input_fg = gr.Image(
317
+ sources=["upload"], type="numpy", label="Image", height=384
318
+ )
319
+ with gr.Row():
320
+ with gr.Column(scale=2, variant="compact"):
321
+ prompt = gr.Textbox(label="Output Prompt", interactive=True)
322
+ with gr.Column(scale=1, variant="compact", min_width=160):
323
+ n_prompt = gr.Textbox(
324
+ label="Negative Prompt",
325
+ value="lowres, worst quality, bad anatomy, bad hands, text, extra digit, fewer digits, cropped, low quality, jpeg artifacts, signature, watermark, username",
326
+ )
327
+ with gr.Row():
328
+ input_undo_steps = gr.Dropdown(
329
+ label="Operation Steps",
330
+ value=[850, 875, 900, 925, 950, 975],
331
+ choices=list(range(0, 1000, 25)),
332
+ multiselect=True,
333
+ )
334
+ num_sets = gr.Slider(
335
+ label="Num Sets", minimum=1, maximum=10, value=4, step=1
336
+ )
337
+ with gr.Row():
338
+ seed = gr.Slider(
339
+ label="Seed", minimum=0, maximum=50000, step=1, value=37462
340
+ )
341
+ image_width = gr.Slider(
342
+ label="Target size", minimum=512, maximum=1024, value=768, step=32
343
+ )
344
+ steps = gr.Slider(
345
+ label="Steps", minimum=1, maximum=32, value=16, step=1
346
+ )
347
+ cfg = gr.Slider(
348
+ label="CFG Scale", minimum=1.0, maximum=16, value=5, step=0.05
349
+ )
350
+ key_gen_button = gr.Button(value="Generate Sketch", interactive=False)
351
+
352
+ with gr.Column():
353
+ gr.Markdown("#### Sketch Outputs")
354
+ result_gallery = gr.Gallery(
355
+ height=384, object_fit="contain", label="Sketch Outputs", columns=4
356
+ )
357
+ gr.Markdown("#### Line Art Outputs")
358
+ lineart_result = gr.Gallery(
359
+ height=384,
360
+ object_fit="contain",
361
+ label="LineArt outputs",
362
+ )
363
+
364
+ input_fg.change(
365
+ lambda x: [
366
+ interrogator_process(x) if x is not None else "",
367
+ gr.update(interactive=True),
368
+ ],
369
+ inputs=[input_fg],
370
+ outputs=[prompt, key_gen_button],
371
+ )
372
+
373
+ key_gen_button.click(
374
+ fn=process,
375
+ inputs=[
376
+ input_fg,
377
+ prompt,
378
+ input_undo_steps,
379
+ image_width,
380
+ seed,
381
+ steps,
382
+ n_prompt,
383
+ cfg,
384
+ num_sets,
385
+ ],
386
+ outputs=[result_gallery, lineart_result],
387
+ ).then(
388
+ lambda: gr.update(interactive=True),
389
+ outputs=[key_gen_button],
390
+ )
391
+
392
+ block.queue().launch()
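
The "clip concat" logic in encode_cropped_prompt above handles prompts longer than CLIP's 77-token window: the token ids are cut into 75-token chunks, each chunk is re-wrapped with the BOS and EOS ids and encoded separately, and the per-chunk hidden states are concatenated along the token axis. A minimal standalone sketch of the same idea (chunked_text_embedding is a hypothetical helper; tokenizer and text_encoder are assumed to be the CLIP objects loaded above):

    import torch

    @torch.inference_mode()
    def chunked_text_embedding(tokenizer, text_encoder, txt, max_length=225):
        # Tokenize with room for BOS/EOS, then encode 75-token windows separately.
        ids = tokenizer(
            txt, padding="max_length", max_length=max_length + 2,
            truncation=True, return_tensors="pt",
        ).input_ids.to(text_encoder.device).squeeze(0)
        window = tokenizer.model_max_length  # 77 for CLIP
        chunks = []
        for start in range(1, max_length + 1, window - 2):
            body = ids[start : start + window - 2]
            if torch.all(body == tokenizer.pad_token_id):
                break  # remaining windows are pure padding
            wrapped = torch.cat([ids[:1], body, ids[-1:]]).unsqueeze(0)
            chunks.append(text_encoder(wrapped).last_hidden_state)
        return torch.cat(chunks, dim=1)  # (1, n_chunks * 77, hidden_dim)
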
diffusers_helper/attention.py ADDED
@@ -0,0 +1,86 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ try:
6
+ from xformers.ops import memory_efficient_attention
7
+ XFORMERS_AVAIL = True
8
+ except ImportError:
9
+ XFORMERS_AVAIL = False
10
+
11
+
12
+ class AttnProcessor2_0_xformers:
13
+ def __call__(
14
+ self,
15
+ attn,
16
+ hidden_states: torch.Tensor,
17
+ encoder_hidden_states: Optional[torch.Tensor] = None,
18
+ attention_mask: Optional[torch.Tensor] = None,
19
+ temb: Optional[torch.Tensor] = None,
20
+ *args,
21
+ **kwargs,
22
+ ) -> torch.Tensor:
23
+ residual = hidden_states
24
+ if attn.spatial_norm is not None:
25
+ hidden_states = attn.spatial_norm(hidden_states, temb)
26
+
27
+ input_ndim = hidden_states.ndim
28
+
29
+ if input_ndim == 4:
30
+ batch_size, channel, height, width = hidden_states.shape
31
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
32
+
33
+ batch_size, sequence_length, _ = (
34
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
35
+ )
36
+
37
+ if attention_mask is not None:
38
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
39
+ # scaled_dot_product_attention expects attention_mask shape to be
40
+ # (batch, heads, source_length, target_length)
41
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
42
+
43
+ if attn.group_norm is not None:
44
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
45
+
46
+ query = attn.to_q(hidden_states)
47
+
48
+ if encoder_hidden_states is None:
49
+ encoder_hidden_states = hidden_states
50
+ elif attn.norm_cross:
51
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
52
+
53
+ key = attn.to_k(encoder_hidden_states)
54
+ value = attn.to_v(encoder_hidden_states)
55
+
56
+ inner_dim = key.shape[-1]
57
+ head_dim = inner_dim // attn.heads
58
+
59
+ query = query.view(batch_size, -1, attn.heads, head_dim)
60
+
61
+ key = key.view(batch_size, -1, attn.heads, head_dim)
62
+ value = value.view(batch_size, -1, attn.heads, head_dim)
63
+
64
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
65
+ # TODO: add support for attn.scale when we move to Torch 2.1
66
+ hidden_states = memory_efficient_attention(
67
+ query, key, value, attention_mask, p=0.0
68
+ )
69
+
70
+ hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
71
+ hidden_states = hidden_states.to(query.dtype)
72
+
73
+ # linear proj
74
+ hidden_states = attn.to_out[0](hidden_states)
75
+ # dropout
76
+ hidden_states = attn.to_out[1](hidden_states)
77
+
78
+ if input_ndim == 4:
79
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
80
+
81
+ if attn.residual_connection:
82
+ hidden_states = hidden_states + residual
83
+
84
+ hidden_states = hidden_states / attn.rescale_output_factor
85
+
86
+ return hidden_states
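
How this processor is meant to be installed (mirroring app.py): if xformers imports cleanly, XFORMERS_AVAIL is True and the memory-efficient processor replaces the default one on both the UNet and the VAE; otherwise diffusers' stock AttnProcessor2_0 is used.

    from diffusers.models.attention_processor import AttnProcessor2_0
    from diffusers_helper.attention import AttnProcessor2_0_xformers, XFORMERS_AVAIL

    # unet and vae are assumed to be already-loaded diffusers models.
    processor = AttnProcessor2_0_xformers() if XFORMERS_AVAIL else AttnProcessor2_0()
    unet.set_attn_processor(processor)
    vae.set_attn_processor(processor)
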
diffusers_helper/cat_cond.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+
3
+
4
+ def unet_add_concat_conds(unet, new_channels=4):
5
+ with torch.no_grad():
6
+ new_conv_in = torch.nn.Conv2d(4 + new_channels, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
7
+ new_conv_in.weight.zero_()
8
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
9
+ new_conv_in.bias = unet.conv_in.bias
10
+ unet.conv_in = new_conv_in
11
+
12
+ unet_original_forward = unet.forward
13
+
14
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
15
+ cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()}
16
+ c_concat = cross_attention_kwargs.pop('concat_conds')
17
+ kwargs['cross_attention_kwargs'] = cross_attention_kwargs
18
+
19
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0).to(sample)
20
+ new_sample = torch.cat([sample, c_concat], dim=1)
21
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
22
+
23
+ unet.forward = hooked_unet_forward
24
+ return
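
A sketch of how this hook is used (consistent with ModifiedUNet in app.py): unet_add_concat_conds widens conv_in by four channels, and the VAE-encoded reference latents are then passed at call time under the 'concat_conds' key of cross_attention_kwargs, from which the hooked forward pops them and concatenates them onto the noisy sample. The tensor names below are placeholders, not definitions from this repo.

    from diffusers import UNet2DConditionModel
    from diffusers_helper.cat_cond import unet_add_concat_conds

    unet = UNet2DConditionModel.from_pretrained(
        "lllyasviel/paints_undo_single_frame", subfolder="unet"
    )
    unet_add_concat_conds(unet, new_channels=4)

    # concat_conds: reference-image latents with shape (1, 4, h, w); repeated over the batch.
    noise_pred = unet(
        noisy_latents, timesteps, prompt_embeds,
        cross_attention_kwargs={"concat_conds": concat_conds},
        return_dict=False,
    )[0]
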
diffusers_helper/code_cond.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+
3
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
4
+
5
+
6
+ def unet_add_coded_conds(unet, added_number_count=1):
7
+ unet.add_time_proj = Timesteps(256, True, 0)
8
+ unet.add_embedding = TimestepEmbedding(256 * added_number_count, 1280)
9
+
10
+ def get_aug_embed(emb, encoder_hidden_states, added_cond_kwargs):
11
+ coded_conds = added_cond_kwargs.get("coded_conds")
12
+ batch_size = coded_conds.shape[0]
13
+ time_embeds = unet.add_time_proj(coded_conds.flatten())
14
+ time_embeds = time_embeds.reshape((batch_size, -1))
15
+ time_embeds = time_embeds.to(emb)
16
+ aug_emb = unet.add_embedding(time_embeds)
17
+ return aug_emb
18
+
19
+ unet.get_aug_embed = get_aug_embed
20
+
21
+ unet_original_forward = unet.forward
22
+
23
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
24
+ cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()}
25
+ coded_conds = cross_attention_kwargs.pop('coded_conds')
26
+ kwargs['cross_attention_kwargs'] = cross_attention_kwargs
27
+
28
+ coded_conds = torch.cat([coded_conds] * (sample.shape[0] // coded_conds.shape[0]), dim=0).to(sample.device)
29
+ kwargs['added_cond_kwargs'] = dict(coded_conds=coded_conds)
30
+ return unet_original_forward(sample, timestep, encoder_hidden_states, **kwargs)
31
+
32
+ unet.forward = hooked_unet_forward
33
+
34
+ return
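
The companion hook for the integer "operation step" condition (as used in app.py): each code is projected with a sinusoidal Timesteps embedding and injected through add_embedding on top of the normal time embedding. A usage sketch, again with placeholder tensors:

    import torch
    from diffusers_helper.code_cond import unet_add_coded_conds

    unet_add_coded_conds(unet, added_number_count=1)  # patches the UNet in place

    # One integer code per sample, e.g. the "Operation Steps" values from the UI.
    coded_conds = torch.tensor([850, 900, 950], dtype=torch.long, device=unet.device)
    noise_pred = unet(
        noisy_latents, timesteps, prompt_embeds,
        cross_attention_kwargs={"coded_conds": coded_conds},
        return_dict=False,
    )[0]
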
diffusers_helper/k_diffusion.py ADDED
@@ -0,0 +1,145 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from tqdm import tqdm
5
+
6
+
7
+ @torch.no_grad()
8
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, progress_tqdm=None):
9
+ """DPM-Solver++(2M)."""
10
+ extra_args = {} if extra_args is None else extra_args
11
+ s_in = x.new_ones([x.shape[0]])
12
+ sigma_fn = lambda t: t.neg().exp()
13
+ t_fn = lambda sigma: sigma.log().neg()
14
+ old_denoised = None
15
+
16
+ bar = tqdm if progress_tqdm is None else progress_tqdm
17
+
18
+ for i in bar(range(len(sigmas) - 1)):
19
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
20
+ if callback is not None:
21
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
22
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
23
+ h = t_next - t
24
+ if old_denoised is None or sigmas[i + 1] == 0:
25
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
26
+ else:
27
+ h_last = t - t_fn(sigmas[i - 1])
28
+ r = h_last / h
29
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
30
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
31
+ old_denoised = denoised
32
+ return x
33
+
34
+
35
+ class KModel:
36
+ def __init__(self, unet, timesteps=1000, linear_start=0.00085, linear_end=0.012, linear=False):
37
+ if linear:
38
+ betas = torch.linspace(linear_start, linear_end, timesteps, dtype=torch.float64)
39
+ else:
40
+ betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float64) ** 2
41
+
42
+ alphas = 1. - betas
43
+ alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
44
+
45
+ self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
46
+ self.log_sigmas = self.sigmas.log()
47
+ self.sigma_data = 1.0
48
+ self.unet = unet
49
+ return
50
+
51
+ @property
52
+ def sigma_min(self):
53
+ return self.sigmas[0]
54
+
55
+ @property
56
+ def sigma_max(self):
57
+ return self.sigmas[-1]
58
+
59
+ def timestep(self, sigma):
60
+ log_sigma = sigma.log()
61
+ dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
62
+ return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
63
+
64
+ def get_sigmas_karras(self, n, rho=7.):
65
+ ramp = torch.linspace(0, 1, n)
66
+ min_inv_rho = self.sigma_min ** (1 / rho)
67
+ max_inv_rho = self.sigma_max ** (1 / rho)
68
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
69
+ return torch.cat([sigmas, sigmas.new_zeros([1])])
70
+
71
+ def __call__(self, x, sigma, **extra_args):
72
+ x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
73
+ x_ddim_space = x_ddim_space.to(dtype=self.unet.dtype)
74
+ t = self.timestep(sigma)
75
+ cfg_scale = extra_args['cfg_scale']
76
+ eps_positive = self.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
77
+ eps_negative = self.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
78
+ noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
79
+ return x - noise_pred * sigma[:, None, None, None]
80
+
81
+
82
+ class KDiffusionSampler:
83
+ def __init__(self, unet, **kwargs):
84
+ self.unet = unet
85
+ self.k_model = KModel(unet=unet, **kwargs)
86
+
87
+ @torch.inference_mode()
88
+ def __call__(
89
+ self,
90
+ initial_latent = None,
91
+ strength = 1.0,
92
+ num_inference_steps = 25,
93
+ guidance_scale = 5.0,
94
+ batch_size = 1,
95
+ generator = None,
96
+ prompt_embeds = None,
97
+ negative_prompt_embeds = None,
98
+ cross_attention_kwargs = None,
99
+ same_noise_in_batch = False,
100
+ progress_tqdm = None,
101
+ ):
102
+
103
+ device = self.unet.device
104
+
105
+ # Sigmas
106
+
107
+ sigmas = self.k_model.get_sigmas_karras(int(num_inference_steps/strength))
108
+ sigmas = sigmas[-(num_inference_steps + 1):].to(device)
109
+
110
+ # Initial latents
111
+
112
+ if same_noise_in_batch:
113
+ noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype).repeat(batch_size, 1, 1, 1)
114
+ initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype)
115
+ else:
116
+ initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype)
117
+ noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype)
118
+
119
+ latents = initial_latent + noise * sigmas[0].to(initial_latent)
120
+
121
+ # Batch
122
+
123
+ latents = latents.to(device)
124
+ prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1).to(device)
125
+ negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1).to(device)
126
+
127
+ # Feeds
128
+
129
+ sampler_kwargs = dict(
130
+ cfg_scale=guidance_scale,
131
+ positive=dict(
132
+ encoder_hidden_states=prompt_embeds,
133
+ cross_attention_kwargs=cross_attention_kwargs
134
+ ),
135
+ negative=dict(
136
+ encoder_hidden_states=negative_prompt_embeds,
137
+ cross_attention_kwargs=cross_attention_kwargs,
138
+ )
139
+ )
140
+
141
+ # Sample
142
+
143
+ results = sample_dpmpp_2m(self.k_model, latents, sigmas, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)
144
+
145
+ return results
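
How the sampler is driven from app.py, as a sketch: KModel turns the UNet's epsilon prediction into a denoised estimate with classifier-free guidance, sample_dpmpp_2m integrates it over a Karras sigma schedule, and KDiffusionSampler ties the two together. The conditioning tensors below are assumed to come from the encoders built in app.py.

    import torch
    from diffusers_helper.k_diffusion import KDiffusionSampler

    sampler = KDiffusionSampler(
        unet=unet, timesteps=1000, linear_start=0.00085, linear_end=0.020, linear=True
    )
    latents = sampler(
        initial_latent=torch.zeros_like(concat_conds),  # blank starting latent, as in app.py
        strength=1.0,
        num_inference_steps=16,
        guidance_scale=5.0,
        batch_size=4,
        generator=torch.Generator(device=unet.device).manual_seed(37462),
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        cross_attention_kwargs={"concat_conds": concat_conds, "coded_conds": coded_conds},
    )
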
diffusers_helper/utils.py ADDED
@@ -0,0 +1,136 @@
1
+ import os
2
+ import json
3
+ import random
4
+ import glob
5
+ import torch
6
+ import einops
7
+ import torchvision
8
+
9
+ import safetensors.torch as sf
10
+
11
+
12
+ def write_to_json(data, file_path):
13
+ temp_file_path = file_path + ".tmp"
14
+ with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
15
+ json.dump(data, temp_file, indent=4)
16
+ os.replace(temp_file_path, file_path)
17
+ return
18
+
19
+
20
+ def read_from_json(file_path):
21
+ with open(file_path, 'rt', encoding='utf-8') as file:
22
+ data = json.load(file)
23
+ return data
24
+
25
+
26
+ def get_active_parameters(m):
27
+ return {k:v for k, v in m.named_parameters() if v.requires_grad}
28
+
29
+
30
+ def cast_training_params(m, dtype=torch.float32):
31
+ for param in m.parameters():
32
+ if param.requires_grad:
33
+ param.data = param.to(dtype)
34
+ return
35
+
36
+
37
+ def set_attr_recursive(obj, attr, value):
38
+ attrs = attr.split(".")
39
+ for name in attrs[:-1]:
40
+ obj = getattr(obj, name)
41
+ setattr(obj, attrs[-1], value)
42
+ return
43
+
44
+
45
+ @torch.no_grad()
46
+ def batch_mixture(a, b, probability_a=0.5, mask_a=None):
47
+ assert a.shape == b.shape, "Tensors must have the same shape"
48
+ batch_size = a.size(0)
49
+
50
+ if mask_a is None:
51
+ mask_a = torch.rand(batch_size) < probability_a
52
+
53
+ mask_a = mask_a.to(a.device)
54
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
55
+ result = torch.where(mask_a, a, b)
56
+ return result
57
+
58
+
59
+ @torch.no_grad()
60
+ def zero_module(module):
61
+ for p in module.parameters():
62
+ p.detach().zero_()
63
+ return module
64
+
65
+
66
+ def load_last_state(model, folder='accelerator_output'):
67
+ file_pattern = os.path.join(folder, '**', 'model.safetensors')
68
+ files = glob.glob(file_pattern, recursive=True)
69
+
70
+ if not files:
71
+ print("No model.safetensors files found in the specified folder.")
72
+ return
73
+
74
+ newest_file = max(files, key=os.path.getmtime)
75
+ state_dict = sf.load_file(newest_file)
76
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
77
+
78
+ if missing_keys:
79
+ print("Missing keys:", missing_keys)
80
+ if unexpected_keys:
81
+ print("Unexpected keys:", unexpected_keys)
82
+
83
+ print("Loaded model state from:", newest_file)
84
+ return
85
+
86
+
87
+ def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
88
+ tags = tags_str.split(', ')
89
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
90
+ prompt = ', '.join(tags)
91
+ return prompt
92
+
93
+
94
+ def save_bcthw_as_mp4(x, output_filename, fps=10):
95
+ b, c, t, h, w = x.shape
96
+
97
+ per_row = b
98
+ for p in [6, 5, 4, 3, 2]:
99
+ if b % p == 0:
100
+ per_row = p
101
+ break
102
+
103
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
104
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
105
+ x = x.detach().cpu().to(torch.uint8)
106
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
107
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '0'})
108
+ return x
109
+
110
+
111
+ def save_bcthw_as_png(x, output_filename):
112
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
113
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
114
+ x = x.detach().cpu().to(torch.uint8)
115
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
116
+ torchvision.io.write_png(x, output_filename)
117
+ return output_filename
118
+
119
+
120
+ def add_tensors_with_padding(tensor1, tensor2):
121
+ if tensor1.shape == tensor2.shape:
122
+ return tensor1 + tensor2
123
+
124
+ shape1 = tensor1.shape
125
+ shape2 = tensor2.shape
126
+
127
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
128
+
129
+ padded_tensor1 = torch.zeros(new_shape)
130
+ padded_tensor2 = torch.zeros(new_shape)
131
+
132
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
133
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
134
+
135
+ result = padded_tensor1 + padded_tensor2
136
+ return result
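
Most helpers here are small I/O and tensor utilities; as one self-contained example, generate_random_prompt_from_tags shuffles a random subset of a comma-separated tag string into a prompt:

    from diffusers_helper.utils import generate_random_prompt_from_tags

    tags = "1girl, solo, long hair, smile, outdoors, school uniform"
    print(generate_random_prompt_from_tags(tags, min_length=3, max_length=5))
    # e.g. "school uniform, smile, outdoors"
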
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ diffusers==0.28.0
2
+ transformers==4.41.1
3
+ gradio==4.31.5
4
+ bitsandbytes==0.43.1
5
+ accelerate==0.30.1
6
+ protobuf==3.20
7
+ opencv-python
8
+ tensorboardX
9
+ safetensors
10
+ pillow
11
+ einops
12
+ torch
13
+ torchvision
14
+ dghs-imgutils
15
+ spaces
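
With these pinned dependencies, the Space should in principle run locally via pip install -r requirements.txt followed by python app.py. Note that app.py hard-codes DEVICE = "cuda" and downloads the lllyasviel/paints_undo_single_frame weights on first use, so a CUDA-capable GPU and network access are assumed.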