Zafaflahfksdf committed
Commit da3eeba
Parent(s): ca1233a

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitignore +142 -0
  2. LICENSE +201 -0
  3. README.md +149 -12
  4. README_DEV.md +61 -0
  5. fast_sam/__init__.py +9 -0
  6. fast_sam/fast_sam_wrapper.py +90 -0
  7. ia_check_versions.py +74 -0
  8. ia_config.py +115 -0
  9. ia_devices.py +10 -0
  10. ia_file_manager.py +71 -0
  11. ia_get_dataset_colormap.py +416 -0
  12. ia_logging.py +14 -0
  13. ia_sam_manager.py +182 -0
  14. ia_threading.py +55 -0
  15. ia_ui_gradio.py +30 -0
  16. ia_ui_items.py +110 -0
  17. iasam_app.py +809 -0
  18. images/inpaint_anything_explanation_image_1.png +0 -0
  19. images/inpaint_anything_ui_image_1.png +0 -0
  20. images/sample_input_image.png +0 -0
  21. images/sample_mask_image.png +0 -0
  22. images/sample_seg_color_image.png +0 -0
  23. inpalib/__init__.py +18 -0
  24. inpalib/masklib.py +106 -0
  25. inpalib/samlib.py +256 -0
  26. javascript/inpaint-anything.js +458 -0
  27. lama_cleaner/__init__.py +19 -0
  28. lama_cleaner/benchmark.py +109 -0
  29. lama_cleaner/const.py +173 -0
  30. lama_cleaner/file_manager/__init__.py +1 -0
  31. lama_cleaner/file_manager/file_manager.py +265 -0
  32. lama_cleaner/file_manager/storage_backends.py +46 -0
  33. lama_cleaner/file_manager/utils.py +67 -0
  34. lama_cleaner/helper.py +292 -0
  35. lama_cleaner/installer.py +12 -0
  36. lama_cleaner/model/__init__.py +0 -0
  37. lama_cleaner/model/base.py +298 -0
  38. lama_cleaner/model/controlnet.py +289 -0
  39. lama_cleaner/model/ddim_sampler.py +193 -0
  40. lama_cleaner/model/fcf.py +1733 -0
  41. lama_cleaner/model/instruct_pix2pix.py +83 -0
  42. lama_cleaner/model/lama.py +51 -0
  43. lama_cleaner/model/ldm.py +333 -0
  44. lama_cleaner/model/manga.py +91 -0
  45. lama_cleaner/model/mat.py +1935 -0
  46. lama_cleaner/model/opencv2.py +28 -0
  47. lama_cleaner/model/paint_by_example.py +79 -0
  48. lama_cleaner/model/pipeline/__init__.py +3 -0
  49. lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py +585 -0
  50. lama_cleaner/model/plms_sampler.py +225 -0
.gitignore ADDED
@@ -0,0 +1,142 @@
+ *.pth
+ *.pt
+ *.pyc
+ src/
+ outputs/
+ models/
+ models
+ .DS_Store
+ ia_config.ini
+ .eslintrc
+ .eslintrc.json
+ pyproject.toml
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
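
The appendix above says the boilerplate notice should be wrapped in the comment syntax of each file format; as an illustration only (bracketed fields left as placeholders, exactly as the license requires), this is how it might look at the top of a Python source file:

```python
# Copyright [yyyy] [name of copyright owner]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```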
README.md CHANGED
@@ -1,12 +1,149 @@
- ---
- title: ' '
- emoji: 🚀
- colorFrom: pink
- colorTo: blue
- sdk: gradio
- sdk_version: 4.40.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: _
+ app_file: iasam_app.py
+ sdk: gradio
+ sdk_version: 3.50.2
+ ---
+ # Inpaint Anything (Inpainting with Segment Anything)
+
+ Inpaint Anything performs Stable Diffusion inpainting in a browser UI, using any mask selected from the output of [Segment Anything](https://github.com/facebookresearch/segment-anything).
+
+ Using Segment Anything lets users specify masks by simply pointing to the desired areas instead of filling them in manually. This makes mask creation faster and more accurate, leading to potentially higher-quality inpainting results while saving time and effort.
+
+ [Extension version for AUTOMATIC1111's Web UI](https://github.com/Uminosachi/sd-webui-inpaint-anything)
+
+ ![Explanation image](images/inpaint_anything_explanation_image_1.png)
+
+ ## Installation
+
+ Please follow these steps to install the software:
+
+ * Create a new conda environment:
+
+ ```bash
+ conda create -n inpaint python=3.10
+ conda activate inpaint
+ ```
+
+ * Clone the software repository:
+
+ ```bash
+ git clone https://github.com/Uminosachi/inpaint-anything.git
+ cd inpaint-anything
+ ```
+
+ * For a CUDA environment, install the following packages:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ * If you are using macOS, please install the packages from the following file instead:
+
+ ```bash
+ pip install -r requirements_mac.txt
+ ```
+
+ ## Running the application
+
+ ```bash
+ python iasam_app.py
+ ```
+
+ * Open http://127.0.0.1:7860/ in your browser.
+ * Note: If you have a privacy protection extension enabled in your web browser, such as DuckDuckGo, you may not be able to retrieve the mask from your sketch.
+
+ ### Options
+
+ * `--save-seg`: Save the segmentation image generated by SAM.
+ * `--offline`: Execute inpainting using an offline network.
+ * `--sam-cpu`: Perform the Segment Anything operation on the CPU.
+
+ ## Downloading the Model
+
+ * Launch this application.
+ * Click on the `Download model` button, located next to the [Segment Anything Model ID](https://github.com/facebookresearch/segment-anything#model-checkpoints). This includes the [SAM 2](https://github.com/facebookresearch/segment-anything-2), [Segment Anything in High Quality Model ID](https://github.com/SysCV/sam-hq), [Fast Segment Anything](https://github.com/CASIA-IVA-Lab/FastSAM), and [Faster Segment Anything (MobileSAM)](https://github.com/ChaoningZhang/MobileSAM).
+ * Please note that SAM is available in three sizes: Base, Large, and Huge. Larger sizes consume more VRAM.
+ * Wait for the download to complete.
+ * The downloaded model file will be stored in the `models` directory of this application's repository.
+
+ ## Usage
+
+ * Drag and drop your image onto the input image area.
+ * Outpainting can be achieved via the `Padding options`: configure the scale and balance, then click the `Run Padding` button.
+ * The `Anime Style` checkbox enhances segmentation mask detection, particularly in anime-style images, at the expense of a slight reduction in mask quality.
+ * Click on the `Run Segment Anything` button.
+ * Use sketching to point to the area you want to inpaint. You can undo and adjust the pen size.
+ * Hover over either the SAM image or the mask image and press the `S` key for fullscreen mode, or the `R` key to reset the zoom.
+ * Click on the `Create mask` button. The mask will appear in the selected mask image area.
+
+ ### Mask Adjustment
+
+ * `Expand mask region` button: Use this to slightly expand the area of the mask for broader coverage.
+ * `Trim mask by sketch` button: Clicking this will exclude the sketched area from the mask.
+ * `Add mask by sketch` button: Clicking this will add the sketched area to the mask.
+
+ ### Inpainting Tab
+
+ * Enter your desired Prompt and Negative Prompt, then choose the Inpainting Model ID.
+ * Click on the `Run Inpainting` button (**please note that it may take some time to download the model the first time**).
+ * In the Advanced options, you can adjust the Sampler, Sampling Steps, Guidance Scale, and Seed.
+ * If you enable the `Mask area Only` option, modifications will be confined to the designated mask area.
+ * Adjust the iteration slider to perform inpainting multiple times with different seeds.
+ * The inpainting process is powered by [diffusers](https://github.com/huggingface/diffusers).
+
+ #### Tips
+
+ * You can drag and drop the inpainted image directly into the input image field on the Web UI (useful with Chrome and Edge browsers).
+
+ #### Model Cache
+
+ * Any inpainting model that is saved in Hugging Face's cache and includes `inpaint` (case-insensitive) in its repo_id will also be added to the Inpainting Model ID dropdown list.
+ * If there's a specific model you'd like to use, you can cache it in advance using the following Python commands:
+ ```bash
+ python
+ ```
+ ```python
+ from diffusers import StableDiffusionInpaintPipeline
+ pipe = StableDiffusionInpaintPipeline.from_pretrained("Uminosachi/dreamshaper_5-inpainting")
+ exit()
+ ```
+ * The model downloaded by diffusers is typically stored in your home directory: `/home/username/.cache/huggingface/hub` on Linux and macOS, or `C:\Users\username\.cache\huggingface\hub` on Windows.
+ * When executing inpainting, if the following error is output to the console, try deleting the corresponding model from the cache folder mentioned above:
+ ```
+ An error occurred while trying to fetch model name...
+ ```
+
+ ### Cleaner Tab
+
+ * Choose the Cleaner Model ID.
+ * Click on the `Run Cleaner` button (**please note that it may take some time to download the model the first time**).
+ * The cleaning process is performed using [Lama Cleaner](https://github.com/Sanster/lama-cleaner).
+
+ ### Mask only Tab
+
+ * Lets you save just the mask, without any other processing, so it can then be used in other graphics applications.
+ * `Get mask as alpha of image` button: Saves the mask as an RGBA image, with the mask placed in the alpha channel of the input image.
+ * `Get mask` button: Saves the mask as an RGB image.
+
+ ![UI image](images/inpaint_anything_ui_image_1.png)
+
+ ## Auto-saving images
+
+ * The inpainted image will be automatically saved in the folder that matches the current date within the `outputs` directory.
+
+ ## Development
+
+ With the [Inpaint Anything library](README_DEV.md), you can perform segmentation and create masks using sketches from other applications.
+
+ ## License
+
+ The source code is licensed under the [Apache 2.0 license](LICENSE).
+
+ ## References
+
+ * Ravi, N., Gabeur, V., Hu, Y.-T., Hu, R., Ryali, C., Ma, T., Khedr, H., Rädle, R., Rolland, C., Gustafson, L., Mintun, E., Pan, J., Alwala, K. V., Carion, N., Wu, C.-Y., Girshick, R., Dollár, P., & Feichtenhofer, C. (2024). [SAM 2: Segment Anything in Images and Videos](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/). arXiv preprint.
+ * Kirillov, A., Mintun, E., Ravi, N., Mao, H., Rolland, C., Gustafson, L., Xiao, T., Whitehead, S., Berg, A. C., Lo, W-Y., Dollár, P., & Girshick, R. (2023). [Segment Anything](https://arxiv.org/abs/2304.02643). arXiv:2304.02643.
+ * Ke, L., Ye, M., Danelljan, M., Liu, Y., Tai, Y-W., Tang, C-K., & Yu, F. (2023). [Segment Anything in High Quality](https://arxiv.org/abs/2306.01567). arXiv:2306.01567.
+ * Zhao, X., Ding, W., An, Y., Du, Y., Yu, T., Li, M., Tang, M., & Wang, J. (2023). [Fast Segment Anything](https://arxiv.org/abs/2306.12156). arXiv:2306.12156.
+ * Zhang, C., Han, D., Qiao, Y., Kim, J. U., Bae, S-H., Lee, S., & Hong, C. S. (2023). [Faster Segment Anything: Towards Lightweight SAM for Mobile Applications](https://arxiv.org/abs/2306.14289). arXiv:2306.14289.
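
The `Get mask as alpha of image` behavior described in the README's Mask only tab can be approximated outside the app with Pillow. A minimal sketch (file names are placeholders; this is not the application's own code):

```python
from PIL import Image

# Load the input image and a mask saved from the "Mask only" tab
# (placeholder file names).
input_image = Image.open("sample_input_image.png").convert("RGB")
mask_image = Image.open("sample_mask_image.png").convert("L")

# Place the grayscale mask into the alpha channel of the input image.
rgba_image = input_image.copy()
rgba_image.putalpha(mask_image)
rgba_image.save("input_with_mask_alpha.png")
```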
README_DEV.md ADDED
@@ -0,0 +1,61 @@
+ # Usage of Inpaint Anything Library
+
+ ## Introduction
+
+ The `inpalib` from the `inpaint-anything` package lets you segment images and create masks using sketches from other applications.
+
+ ## Code Breakdown
+
+ ### Imports and Module Initialization
+
+ ```python
+ import importlib
+
+ import numpy as np
+ from PIL import Image, ImageDraw
+
+ inpalib = importlib.import_module("inpaint-anything.inpalib")
+ ```
+
+ ### Fetch Model IDs
+
+ ```python
+ available_sam_ids = inpalib.get_available_sam_ids()
+
+ use_sam_id = "sam_hq_vit_l.pth"
+ # assert use_sam_id in available_sam_ids, f"Invalid SAM ID: {use_sam_id}"
+ ```
+
+ Note: Only the models downloaded via Inpaint Anything are available.
+
+ ### Generate Segments Image
+
+ ```python
+ input_image = np.array(Image.open("/path/to/image.png"))
+
+ sam_masks = inpalib.generate_sam_masks(input_image, use_sam_id, anime_style_chk=False)
+ sam_masks = inpalib.sort_masks_by_area(sam_masks)
+
+ seg_color_image = inpalib.create_seg_color_image(input_image, sam_masks)
+
+ Image.fromarray(seg_color_image).save("/path/to/seg_color_image.png")
+ ```
+
+ <img src="images/sample_input_image.png" alt="drawing" width="256"/> <img src="images/sample_seg_color_image.png" alt="drawing" width="256"/>
+
+ ### Create Mask from Sketch
+
+ ```python
+ sketch_image = Image.fromarray(np.zeros_like(input_image))
+
+ draw = ImageDraw.Draw(sketch_image)
+ draw.point((input_image.shape[1] // 2, input_image.shape[0] // 2), fill=(255, 255, 255))
+
+ mask_image = inpalib.create_mask_image(np.array(sketch_image), sam_masks, ignore_black_chk=True)
+
+ Image.fromarray(mask_image).save("/path/to/mask_image.png")
+ ```
+
+ <img src="images/sample_mask_image.png" alt="drawing" width="256"/>
+
+ Note: Ensure you adjust the file paths before executing the code.
fast_sam/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .fast_sam_wrapper import FastSAM
+ from .fast_sam_wrapper import FastSamAutomaticMaskGenerator
+
+ fast_sam_model_registry = {
+     "FastSAM-x": FastSAM,
+     "FastSAM-s": FastSAM,
+ }
+
+ __all__ = ["FastSAM", "FastSamAutomaticMaskGenerator", "fast_sam_model_registry"]
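
A minimal usage sketch of this registry (the checkpoint path is a placeholder): both keys map to the same `FastSAM` wrapper class, and the checkpoint file selects the actual weights.

```python
from fast_sam import fast_sam_model_registry

# Look up the wrapper class by model ID and load weights (placeholder path).
model = fast_sam_model_registry["FastSAM-x"](checkpoint="models/FastSAM-x.pt")
print(type(model).__name__)  # -> "FastSAM"
```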
fast_sam/fast_sam_wrapper.py ADDED
@@ -0,0 +1,90 @@
+ import inspect
+ import math
+ from typing import Any, Dict, List
+
+ import cv2
+ import numpy as np
+ import torch
+ import ultralytics
+
+ if hasattr(ultralytics, "FastSAM"):
+     from ultralytics import FastSAM as YOLO
+ else:
+     from ultralytics import YOLO
+
+
+ class FastSAM:
+     def __init__(
+         self,
+         checkpoint: str,
+     ) -> None:
+         self.model_path = checkpoint
+         self.model = YOLO(self.model_path)
+
+         if not hasattr(torch.nn.Upsample, "recompute_scale_factor"):
+             torch.nn.Upsample.recompute_scale_factor = None
+
+     def to(self, device) -> None:
+         self.model.to(device)
+
+     @property
+     def device(self) -> Any:
+         return self.model.device
+
+     def __call__(self, source=None, stream=False, **kwargs) -> Any:
+         return self.model(source=source, stream=stream, **kwargs)
+
+
+ class FastSamAutomaticMaskGenerator:
+     def __init__(
+         self,
+         model: FastSAM,
+         points_per_batch: int = None,
+         pred_iou_thresh: float = None,
+         stability_score_thresh: float = None,
+     ) -> None:
+         self.model = model
+         self.points_per_batch = points_per_batch
+         self.pred_iou_thresh = pred_iou_thresh
+         self.stability_score_thresh = stability_score_thresh
+         self.conf = 0.25 if stability_score_thresh >= 0.95 else 0.15
+
+     def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+         height, width = image.shape[:2]
+         new_height = math.ceil(height / 32) * 32
+         new_width = math.ceil(width / 32) * 32
+         resize_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
+
+         backup_nn_dict = {}
+         for key, _ in torch.nn.__dict__.copy().items():
+             if not inspect.isclass(torch.nn.__dict__.get(key)) and "Norm" in key:
+                 backup_nn_dict[key] = torch.nn.__dict__.pop(key)
+
+         results = self.model(
+             source=resize_image,
+             stream=False,
+             imgsz=max(new_height, new_width),
+             device=self.model.device,
+             retina_masks=True,
+             iou=0.7,
+             conf=self.conf,
+             max_det=256)
+
+         for key, value in backup_nn_dict.items():
+             setattr(torch.nn, key, value)
+             # assert backup_nn_dict[key] == torch.nn.__dict__[key]
+
+         annotations = results[0].masks.data
+
+         if isinstance(annotations[0], torch.Tensor):
+             annotations = np.array(annotations.cpu())
+
+         annotations_list = []
+         for mask in annotations:
+             mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
+             mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((7, 7), np.uint8))
+             mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_AREA)
+
+             annotations_list.append(dict(segmentation=mask.astype(bool)))
+
+         return annotations_list
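
A hedged end-to-end sketch of the wrapper (checkpoint and image paths are placeholders): `generate` returns one dict per mask with a boolean `segmentation` array resized back to the original image resolution.

```python
import numpy as np
from PIL import Image

from fast_sam import FastSAM, FastSamAutomaticMaskGenerator

model = FastSAM(checkpoint="models/FastSAM-x.pt")  # placeholder path
generator = FastSamAutomaticMaskGenerator(model, stability_score_thresh=0.95)

image = np.array(Image.open("input.png").convert("RGB"))  # placeholder path
masks = generator.generate(image)
print(len(masks), masks[0]["segmentation"].shape)  # N masks, each (H, W) bool
```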
ia_check_versions.py ADDED
@@ -0,0 +1,74 @@
+ from functools import cached_property
+ from importlib.metadata import version
+ from importlib.util import find_spec
+
+ import torch
+ from packaging.version import parse
+
+
+ def get_module_version(module_name):
+     try:
+         module_version = version(module_name)
+     except Exception:
+         module_version = None
+     return module_version
+
+
+ def compare_version(version1, version2):
+     if not isinstance(version1, str) or not isinstance(version2, str):
+         return None
+
+     if parse(version1) > parse(version2):
+         return 1
+     elif parse(version1) < parse(version2):
+         return -1
+     else:
+         return 0
+
+
+ def compare_module_version(module_name, version_string):
+     module_version = get_module_version(module_name)
+
+     result = compare_version(module_version, version_string)
+     return result if result is not None else -2
+
+
+ class IACheckVersions:
+     @cached_property
+     def diffusers_enable_cpu_offload(self):
+         if (find_spec("diffusers") is not None and compare_module_version("diffusers", "0.15.0") >= 0 and
+                 find_spec("accelerate") is not None and compare_module_version("accelerate", "0.17.0") >= 0 and
+                 torch.cuda.is_available()):
+             return True
+         else:
+             return False
+
+     @cached_property
+     def torch_mps_is_available(self):
+         if compare_module_version("torch", "2.0.1") < 0:
+             if not getattr(torch, "has_mps", False):
+                 return False
+             try:
+                 torch.zeros(1).to(torch.device("mps"))
+                 return True
+             except Exception:
+                 return False
+         else:
+             return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+     @cached_property
+     def torch_on_amd_rocm(self):
+         if find_spec("torch") is not None and "rocm" in version("torch"):
+             return True
+         else:
+             return False
+
+     @cached_property
+     def gradio_version_is_old(self):
+         if find_spec("gradio") is not None and compare_module_version("gradio", "3.34.0") <= 0:
+             return True
+         else:
+             return False
+
+
+ ia_check_versions = IACheckVersions()
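
A short usage sketch of these helpers: `compare_module_version` returns 1, 0, or -1 like a comparator, and -2 when the module is not installed; the `IACheckVersions` properties are computed once and then cached.

```python
from ia_check_versions import compare_module_version, ia_check_versions

if compare_module_version("diffusers", "0.15.0") >= 0:
    print("diffusers 0.15.0 or newer is installed")

# Cached properties: evaluated on first access, memoized afterwards.
print(ia_check_versions.torch_mps_is_available)
print(ia_check_versions.diffusers_enable_cpu_offload)
```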
ia_config.py ADDED
@@ -0,0 +1,115 @@
+ import configparser
+ # import json
+ import os
+ from types import SimpleNamespace
+
+ from ia_ui_items import get_inp_model_ids, get_sam_model_ids
+
+
+ class IAConfig:
+     SECTIONS = SimpleNamespace(
+         DEFAULT=configparser.DEFAULTSECT,
+         USER="USER",
+     )
+
+     KEYS = SimpleNamespace(
+         SAM_MODEL_ID="sam_model_id",
+         INP_MODEL_ID="inp_model_id",
+     )
+
+     PATHS = SimpleNamespace(
+         INI=os.path.join(os.path.dirname(os.path.realpath(__file__)), "ia_config.ini"),
+     )
+
+     global_args = {}
+
+     def __init__(self):
+         self.ids_dict = {}
+         self.ids_dict[IAConfig.KEYS.SAM_MODEL_ID] = {
+             "list": get_sam_model_ids(),
+             "index": 1,
+         }
+         self.ids_dict[IAConfig.KEYS.INP_MODEL_ID] = {
+             "list": get_inp_model_ids(),
+             "index": 0,
+         }
+
+
+ ia_config = IAConfig()
+
+
+ def setup_ia_config_ini():
+     ia_config_ini = configparser.ConfigParser(defaults={})
+     if os.path.isfile(IAConfig.PATHS.INI):
+         ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8")
+
+     changed = False
+     for key, ids_info in ia_config.ids_dict.items():
+         if not ia_config_ini.has_option(IAConfig.SECTIONS.DEFAULT, key):
+             if len(ids_info["list"]) > ids_info["index"]:
+                 ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] = ids_info["list"][ids_info["index"]]
+                 changed = True
+         else:
+             if len(ids_info["list"]) > ids_info["index"] and ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] != ids_info["list"][ids_info["index"]]:
+                 ia_config_ini[IAConfig.SECTIONS.DEFAULT][key] = ids_info["list"][ids_info["index"]]
+                 changed = True
+
+     if changed:
+         with open(IAConfig.PATHS.INI, "w", encoding="utf-8") as f:
+             ia_config_ini.write(f)
+
+
+ def get_ia_config(key, section=IAConfig.SECTIONS.DEFAULT):
+     setup_ia_config_ini()
+
+     ia_config_ini = configparser.ConfigParser(defaults={})
+     ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8")
+
+     if ia_config_ini.has_option(section, key):
+         return ia_config_ini[section][key]
+
+     section = IAConfig.SECTIONS.DEFAULT
+     if ia_config_ini.has_option(section, key):
+         return ia_config_ini[section][key]
+
+     return None
+
+
+ def get_ia_config_index(key, section=IAConfig.SECTIONS.DEFAULT):
+     value = get_ia_config(key, section)
+
+     ids_dict = ia_config.ids_dict
+     if value is None:
+         if key in ids_dict.keys():
+             ids_info = ids_dict[key]
+             return ids_info["index"]
+         else:
+             return 0
+     else:
+         if key in ids_dict.keys():
+             ids_info = ids_dict[key]
+             return ids_info["list"].index(value) if value in ids_info["list"] else ids_info["index"]
+         else:
+             return 0
+
+
+ def set_ia_config(key, value, section=IAConfig.SECTIONS.DEFAULT):
+     setup_ia_config_ini()
+
+     ia_config_ini = configparser.ConfigParser(defaults={})
+     ia_config_ini.read(IAConfig.PATHS.INI, encoding="utf-8")
+
+     if ia_config_ini.has_option(section, key) and ia_config_ini[section][key] == value:
+         return
+
+     if section != IAConfig.SECTIONS.DEFAULT and not ia_config_ini.has_section(section):
+         ia_config_ini[section] = {}
+
+     try:
+         ia_config_ini[section][key] = value
+     except Exception:
+         ia_config_ini[section] = {}
+         ia_config_ini[section][key] = value
+
+     with open(IAConfig.PATHS.INI, "w", encoding="utf-8") as f:
+         ia_config_ini.write(f)
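
A hedged usage sketch (the model ID string is only an example value): `set_ia_config` persists a value to `ia_config.ini`, and lookups in a named section fall back to the `DEFAULT` section when the key is absent.

```python
from ia_config import IAConfig, get_ia_config, get_ia_config_index, set_ia_config

# Persist an example value to the USER section of ia_config.ini.
set_ia_config(IAConfig.KEYS.SAM_MODEL_ID, "sam_hq_vit_l.pth", IAConfig.SECTIONS.USER)

# Read it back; missing USER keys fall back to the DEFAULT section.
value = get_ia_config(IAConfig.KEYS.SAM_MODEL_ID, IAConfig.SECTIONS.USER)
index = get_ia_config_index(IAConfig.KEYS.SAM_MODEL_ID, IAConfig.SECTIONS.USER)
```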
ia_devices.py ADDED
@@ -0,0 +1,10 @@
+ import torch
+
+
+ class TorchDevices:
+     def __init__(self):
+         self.cpu = torch.device("cpu")
+         self.device = torch.device("cuda") if torch.cuda.is_available() else self.cpu
+
+
+ devices = TorchDevices()
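
A one-line usage sketch: code elsewhere in the repository moves tensors and models with `devices.device`, which resolves to CUDA when available and CPU otherwise.

```python
import torch

from ia_devices import devices

x = torch.zeros(1).to(devices.device)  # CUDA if available, else CPU
```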
ia_file_manager.py ADDED
@@ -0,0 +1,71 @@
+ import os
+ from datetime import datetime
+ from huggingface_hub import snapshot_download
+ from ia_logging import ia_logging
+
+
+ class IAFileManager:
+     DOWNLOAD_COMPLETE = "Download complete"
+
+     def __init__(self) -> None:
+         self._ia_outputs_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                                             "outputs",
+                                             datetime.now().strftime("%Y-%m-%d"))
+
+         self._ia_models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+
+     @property
+     def outputs_dir(self) -> str:
+         """Get inpaint-anything outputs directory.
+
+         Returns:
+             str: inpaint-anything outputs directory
+         """
+         if not os.path.isdir(self._ia_outputs_dir):
+             os.makedirs(self._ia_outputs_dir, exist_ok=True)
+         return self._ia_outputs_dir
+
+     @property
+     def models_dir(self) -> str:
+         """Get inpaint-anything models directory.
+
+         Returns:
+             str: inpaint-anything models directory
+         """
+         if not os.path.isdir(self._ia_models_dir):
+             os.makedirs(self._ia_models_dir, exist_ok=True)
+         return self._ia_models_dir
+
+     @property
+     def savename_prefix(self) -> str:
+         """Get inpaint-anything savename prefix.
+
+         Returns:
+             str: inpaint-anything savename prefix
+         """
+         return datetime.now().strftime("%Y%m%d-%H%M%S")
+
+
+ ia_file_manager = IAFileManager()
+
+
+ def download_model_from_hf(hf_model_id, local_files_only=False):
+     """Download model from HuggingFace Hub.
+
+     Args:
+         hf_model_id (str): HuggingFace model id
+         local_files_only (bool, optional): If True, use only local files. Defaults to False.
+
+     Returns:
+         str: download status
+     """
+     if not local_files_only:
+         ia_logging.info(f"Downloading {hf_model_id}")
+     try:
+         snapshot_download(repo_id=hf_model_id, local_files_only=local_files_only)
+     except FileNotFoundError:
+         return f"{hf_model_id} not found, please download"
+     except Exception as e:
+         return str(e)
+
+     return IAFileManager.DOWNLOAD_COMPLETE
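
A sketch combining these helpers (the repo ID is the one from the README's model-cache example; the output file name suffix is illustrative):

```python
import os

from ia_file_manager import download_model_from_hf, ia_file_manager

# Dated output directory plus a timestamped file name (suffix illustrative).
save_path = os.path.join(ia_file_manager.outputs_dir,
                         f"{ia_file_manager.savename_prefix}_inpainted.png")

# Returns "Download complete" on success, or an error message string.
status = download_model_from_hf("Uminosachi/dreamshaper_5-inpainting")
print(status, save_path)
```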
ia_get_dataset_colormap.py ADDED
@@ -0,0 +1,416 @@
+ # Lint as: python2, python3
+ # Copyright 2018 The TensorFlow Authors All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Visualizes the segmentation results via specified color map.
+
+ Visualizes the semantic segmentation results by the color map
+ defined by the different datasets. Supported colormaps are:
+
+ * ADE20K (http://groups.csail.mit.edu/vision/datasets/ADE20K/).
+
+ * Cityscapes dataset (https://www.cityscapes-dataset.com).
+
+ * Mapillary Vistas (https://research.mapillary.com).
+
+ * PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/).
+ """
+
+ from __future__ import absolute_import, division, print_function
+
+ import numpy as np
+
+ # from six.moves import range
+
+ # Dataset names.
+ _ADE20K = 'ade20k'
+ _CITYSCAPES = 'cityscapes'
+ _MAPILLARY_VISTAS = 'mapillary_vistas'
+ _PASCAL = 'pascal'
+
+ # Max number of entries in the colormap for each dataset.
+ _DATASET_MAX_ENTRIES = {
+     _ADE20K: 151,
+     _CITYSCAPES: 256,
+     _MAPILLARY_VISTAS: 66,
+     _PASCAL: 512,
+ }
+
+
+ def create_ade20k_label_colormap():
+     """Creates a label colormap used in ADE20K segmentation benchmark.
+
+     Returns:
+         A colormap for visualizing segmentation results.
+     """
+     return np.asarray([
+         [0, 0, 0],
+         [120, 120, 120],
+         [180, 120, 120],
+         [6, 230, 230],
+         [80, 50, 50],
+         [4, 200, 3],
+         [120, 120, 80],
+         [140, 140, 140],
+         [204, 5, 255],
+         [230, 230, 230],
+         [4, 250, 7],
+         [224, 5, 255],
+         [235, 255, 7],
+         [150, 5, 61],
+         [120, 120, 70],
+         [8, 255, 51],
+         [255, 6, 82],
+         [143, 255, 140],
+         [204, 255, 4],
+         [255, 51, 7],
+         [204, 70, 3],
+         [0, 102, 200],
+         [61, 230, 250],
+         [255, 6, 51],
+         [11, 102, 255],
+         [255, 7, 71],
+         [255, 9, 224],
+         [9, 7, 230],
+         [220, 220, 220],
+         [255, 9, 92],
+         [112, 9, 255],
+         [8, 255, 214],
+         [7, 255, 224],
+         [255, 184, 6],
+         [10, 255, 71],
+         [255, 41, 10],
+         [7, 255, 255],
+         [224, 255, 8],
+         [102, 8, 255],
+         [255, 61, 6],
+         [255, 194, 7],
+         [255, 122, 8],
+         [0, 255, 20],
+         [255, 8, 41],
+         [255, 5, 153],
+         [6, 51, 255],
+         [235, 12, 255],
+         [160, 150, 20],
+         [0, 163, 255],
+         [140, 140, 140],
+         [250, 10, 15],
+         [20, 255, 0],
+         [31, 255, 0],
+         [255, 31, 0],
+         [255, 224, 0],
+         [153, 255, 0],
+         [0, 0, 255],
+         [255, 71, 0],
+         [0, 235, 255],
+         [0, 173, 255],
+         [31, 0, 255],
+         [11, 200, 200],
+         [255, 82, 0],
+         [0, 255, 245],
+         [0, 61, 255],
+         [0, 255, 112],
+         [0, 255, 133],
+         [255, 0, 0],
+         [255, 163, 0],
+         [255, 102, 0],
+         [194, 255, 0],
+         [0, 143, 255],
+         [51, 255, 0],
+         [0, 82, 255],
+         [0, 255, 41],
+         [0, 255, 173],
+         [10, 0, 255],
+         [173, 255, 0],
+         [0, 255, 153],
+         [255, 92, 0],
+         [255, 0, 255],
+         [255, 0, 245],
+         [255, 0, 102],
+         [255, 173, 0],
+         [255, 0, 20],
+         [255, 184, 184],
+         [0, 31, 255],
+         [0, 255, 61],
+         [0, 71, 255],
+         [255, 0, 204],
+         [0, 255, 194],
+         [0, 255, 82],
+         [0, 10, 255],
+         [0, 112, 255],
+         [51, 0, 255],
+         [0, 194, 255],
+         [0, 122, 255],
+         [0, 255, 163],
+         [255, 153, 0],
+         [0, 255, 10],
+         [255, 112, 0],
+         [143, 255, 0],
+         [82, 0, 255],
+         [163, 255, 0],
+         [255, 235, 0],
+         [8, 184, 170],
+         [133, 0, 255],
+         [0, 255, 92],
+         [184, 0, 255],
+         [255, 0, 31],
+         [0, 184, 255],
+         [0, 214, 255],
+         [255, 0, 112],
+         [92, 255, 0],
+         [0, 224, 255],
+         [112, 224, 255],
+         [70, 184, 160],
+         [163, 0, 255],
+         [153, 0, 255],
+         [71, 255, 0],
+         [255, 0, 163],
+         [255, 204, 0],
+         [255, 0, 143],
+         [0, 255, 235],
+         [133, 255, 0],
+         [255, 0, 235],
+         [245, 0, 255],
+         [255, 0, 122],
+         [255, 245, 0],
+         [10, 190, 212],
+         [214, 255, 0],
+         [0, 204, 255],
+         [20, 0, 255],
+         [255, 255, 0],
+         [0, 153, 255],
+         [0, 41, 255],
+         [0, 255, 204],
+         [41, 0, 255],
+         [41, 255, 0],
+         [173, 0, 255],
+         [0, 245, 255],
+         [71, 0, 255],
+         [122, 0, 255],
+         [0, 255, 184],
+         [0, 92, 255],
+         [184, 255, 0],
+         [0, 133, 255],
+         [255, 214, 0],
+         [25, 194, 194],
+         [102, 255, 0],
+         [92, 0, 255],
+     ])
+
+
+ def create_cityscapes_label_colormap():
+     """Creates a label colormap used in CITYSCAPES segmentation benchmark.
+
+     Returns:
+         A colormap for visualizing segmentation results.
+     """
+     colormap = np.zeros((256, 3), dtype=np.uint8)
+     colormap[0] = [128, 64, 128]
+     colormap[1] = [244, 35, 232]
+     colormap[2] = [70, 70, 70]
+     colormap[3] = [102, 102, 156]
+     colormap[4] = [190, 153, 153]
+     colormap[5] = [153, 153, 153]
+     colormap[6] = [250, 170, 30]
+     colormap[7] = [220, 220, 0]
+     colormap[8] = [107, 142, 35]
+     colormap[9] = [152, 251, 152]
+     colormap[10] = [70, 130, 180]
+     colormap[11] = [220, 20, 60]
+     colormap[12] = [255, 0, 0]
+     colormap[13] = [0, 0, 142]
+     colormap[14] = [0, 0, 70]
+     colormap[15] = [0, 60, 100]
+     colormap[16] = [0, 80, 100]
+     colormap[17] = [0, 0, 230]
+     colormap[18] = [119, 11, 32]
+     return colormap
+
+
+ def create_mapillary_vistas_label_colormap():
+     """Creates a label colormap used in Mapillary Vistas segmentation benchmark.
+
+     Returns:
+         A colormap for visualizing segmentation results.
+     """
+     return np.asarray([
+         [165, 42, 42],
+         [0, 192, 0],
+         [196, 196, 196],
+         [190, 153, 153],
+         [180, 165, 180],
+         [102, 102, 156],
+         [102, 102, 156],
+         [128, 64, 255],
+         [140, 140, 200],
+         [170, 170, 170],
+         [250, 170, 160],
+         [96, 96, 96],
+         [230, 150, 140],
+         [128, 64, 128],
+         [110, 110, 110],
+         [244, 35, 232],
+         [150, 100, 100],
+         [70, 70, 70],
+         [150, 120, 90],
+         [220, 20, 60],
+         [255, 0, 0],
+         [255, 0, 0],
+         [255, 0, 0],
+         [200, 128, 128],
+         [255, 255, 255],
+         [64, 170, 64],
+         [128, 64, 64],
+         [70, 130, 180],
+         [255, 255, 255],
+         [152, 251, 152],
+         [107, 142, 35],
+         [0, 170, 30],
+         [255, 255, 128],
+         [250, 0, 30],
+         [0, 0, 0],
+         [220, 220, 220],
+         [170, 170, 170],
+         [222, 40, 40],
+         [100, 170, 30],
+         [40, 40, 40],
+         [33, 33, 33],
+         [170, 170, 170],
+         [0, 0, 142],
+         [170, 170, 170],
+         [210, 170, 100],
+         [153, 153, 153],
+         [128, 128, 128],
+         [0, 0, 142],
+         [250, 170, 30],
+         [192, 192, 192],
+         [220, 220, 0],
+         [180, 165, 180],
+         [119, 11, 32],
+         [0, 0, 142],
+         [0, 60, 100],
+         [0, 0, 142],
+         [0, 0, 90],
+         [0, 0, 230],
+         [0, 80, 100],
+         [128, 64, 64],
+         [0, 0, 110],
+         [0, 0, 70],
+         [0, 0, 192],
+         [32, 32, 32],
+         [0, 0, 0],
+         [0, 0, 0],
+     ])
+
+
+ def create_pascal_label_colormap():
+     """Creates a label colormap used in PASCAL VOC segmentation benchmark.
+
+     Returns:
+         A colormap for visualizing segmentation results.
+     """
+     colormap = np.zeros((_DATASET_MAX_ENTRIES[_PASCAL], 3), dtype=int)
+     ind = np.arange(_DATASET_MAX_ENTRIES[_PASCAL], dtype=int)
+
+     for shift in reversed(list(range(8))):
+         for channel in range(3):
+             colormap[:, channel] |= bit_get(ind, channel) << shift
+         ind >>= 3
+
+     return colormap
+
+
+ def get_ade20k_name():
+     return _ADE20K
+
+
+ def get_cityscapes_name():
+     return _CITYSCAPES
+
+
+ def get_mapillary_vistas_name():
+     return _MAPILLARY_VISTAS
+
+
+ def get_pascal_name():
+     return _PASCAL
+
+
+ def bit_get(val, idx):
+     """Gets the bit value.
+
+     Args:
+         val: Input value, int or numpy int array.
+         idx: Which bit of the input val.
+
+     Returns:
+         The "idx"-th bit of input val.
+     """
+     return (val >> idx) & 1
+
+
+ def create_label_colormap(dataset=_PASCAL):
+     """Creates a label colormap for the specified dataset.
+
+     Args:
+         dataset: The colormap used in the dataset.
+
+     Returns:
+         A numpy array of the dataset colormap.
+
+     Raises:
+         ValueError: If the dataset is not supported.
+     """
+     if dataset == _ADE20K:
+         return create_ade20k_label_colormap()
+     elif dataset == _CITYSCAPES:
+         return create_cityscapes_label_colormap()
+     elif dataset == _MAPILLARY_VISTAS:
+         return create_mapillary_vistas_label_colormap()
+     elif dataset == _PASCAL:
+         return create_pascal_label_colormap()
+     else:
+         raise ValueError('Unsupported dataset.')
+
+
+ def label_to_color_image(label, dataset=_PASCAL):
+     """Adds color defined by the dataset colormap to the label.
+
+     Args:
+         label: A 2D array with integer type, storing the segmentation label.
+         dataset: The colormap used in the dataset.
+
+     Returns:
+         result: A 2D array with floating type. The element of the array
+             is the color indexed by the corresponding element in the input label
+             to the dataset color map.
+
+     Raises:
+         ValueError: If label is not of rank 2 or its value is larger than color
+             map maximum entry.
+     """
+     if label.ndim != 2:
+         raise ValueError('Expect 2-D input label. Got {}'.format(label.shape))
+
+     if np.max(label) >= _DATASET_MAX_ENTRIES[dataset]:
+         raise ValueError(
+             'label value too large: {} >= {}.'.format(
+                 np.max(label), _DATASET_MAX_ENTRIES[dataset]))
+
+     colormap = create_label_colormap(dataset)
+     return colormap[label]
+
+
+ def get_dataset_colormap_max_entries(dataset):
+     return _DATASET_MAX_ENTRIES[dataset]
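
A small usage sketch of the colormap utilities above: a 2-D integer label map is converted to an `(H, W, 3)` color image via the dataset colormap.

```python
import numpy as np

from ia_get_dataset_colormap import get_pascal_name, label_to_color_image

label = np.array([[0, 1],
                  [2, 3]])  # 2-D integer segmentation labels
color = label_to_color_image(label, dataset=get_pascal_name())
print(color.shape)  # (2, 2, 3)
```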
ia_logging.py ADDED
@@ -0,0 +1,14 @@
+ import logging
+ import warnings
+
+ warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
+ warnings.filterwarnings(action="ignore", category=FutureWarning, module="huggingface_hub")
+
+ ia_logging = logging.getLogger("Inpaint Anything")
+ ia_logging.setLevel(logging.INFO)
+ ia_logging.propagate = False
+
+ ia_logging_sh = logging.StreamHandler()
+ ia_logging_sh.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
+ ia_logging_sh.setLevel(logging.INFO)
+ ia_logging.addHandler(ia_logging_sh)
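
Usage is an ordinary named logger; with the handler configured above, a call such as the following prints a record like `<timestamp> - Inpaint Anything - INFO - Downloading model` (timestamp illustrative):

```python
from ia_logging import ia_logging

ia_logging.info("Downloading model")
```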
ia_sam_manager.py ADDED
@@ -0,0 +1,182 @@
+ import os
+ import platform
+ from functools import partial
+
+ import torch
+
+ from fast_sam import FastSamAutomaticMaskGenerator, fast_sam_model_registry
+ from ia_check_versions import ia_check_versions
+ from ia_config import IAConfig
+ from ia_devices import devices
+ from ia_logging import ia_logging
+ from mobile_sam import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorMobile
+ from mobile_sam import SamPredictor as SamPredictorMobile
+ from mobile_sam import sam_model_registry as sam_model_registry_mobile
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+ from sam2.build_sam import build_sam2
+ from segment_anything_fb import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
+ from segment_anything_hq import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorHQ
+ from segment_anything_hq import SamPredictor as SamPredictorHQ
+ from segment_anything_hq import sam_model_registry as sam_model_registry_hq
+
+
+ def check_bfloat16_support() -> bool:
+     if torch.cuda.is_available():
+         compute_capability = torch.cuda.get_device_capability(torch.cuda.current_device())
+         if compute_capability[0] >= 8:
+             ia_logging.debug("The CUDA device supports bfloat16")
+             return True
+         else:
+             ia_logging.debug("The CUDA device does not support bfloat16")
+             return False
+     else:
+         ia_logging.debug("CUDA is not available")
+         return False
+
+
+ def partial_from_end(func, /, *fixed_args, **fixed_kwargs):
+     def wrapper(*args, **kwargs):
+         updated_kwargs = {**fixed_kwargs, **kwargs}
+         return func(*args, *fixed_args, **updated_kwargs)
+     return wrapper
+
+
+ def rename_args(func, arg_map):
+     def wrapper(*args, **kwargs):
+         new_kwargs = {arg_map.get(k, k): v for k, v in kwargs.items()}
+         return func(*args, **new_kwargs)
+     return wrapper
+
+
+ arg_map = {"checkpoint": "ckpt_path"}
+ rename_build_sam2 = rename_args(build_sam2, arg_map)
+ end_kwargs = dict(device="cpu", mode="eval", hydra_overrides_extra=[], apply_postprocessing=False)
+ sam2_model_registry = {
+     "sam2_hiera_large": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_l.yaml"),
+     "sam2_hiera_base_plus": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_b+.yaml"),
+     "sam2_hiera_small": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_s.yaml"),
+     "sam2_hiera_tiny": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_t.yaml"),
+ }
+
+
+ def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False):
+     """Get SAM mask generator.
+
+     Args:
+         sam_checkpoint (str): SAM checkpoint path
+         anime_style_chk (bool, optional): Use anime-style thresholds. Defaults to False.
+
+     Returns:
+         SamAutomaticMaskGenerator or None: SAM mask generator
+     """
+     points_per_batch = 64
+     if "_hq_" in os.path.basename(sam_checkpoint):
+         model_type = os.path.basename(sam_checkpoint)[7:12]
+         sam_model_registry_local = sam_model_registry_hq
+         SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorHQ
+         points_per_batch = 32
+     elif "FastSAM" in os.path.basename(sam_checkpoint):
+         model_type = os.path.splitext(os.path.basename(sam_checkpoint))[0]
+         sam_model_registry_local = fast_sam_model_registry
+         SamAutomaticMaskGeneratorLocal = FastSamAutomaticMaskGenerator
+         points_per_batch = None
+     elif "mobile_sam" in os.path.basename(sam_checkpoint):
+         model_type = "vit_t"
+         sam_model_registry_local = sam_model_registry_mobile
+         SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorMobile
+         points_per_batch = 64
+     elif "sam2_" in os.path.basename(sam_checkpoint):
+         model_type = os.path.splitext(os.path.basename(sam_checkpoint))[0]
+         sam_model_registry_local = sam2_model_registry
+         SamAutomaticMaskGeneratorLocal = SAM2AutomaticMaskGenerator
+         points_per_batch = 128
+     else:
+         model_type = os.path.basename(sam_checkpoint)[4:9]
+         sam_model_registry_local = sam_model_registry
+         SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGenerator
+         points_per_batch = 64
+
+     pred_iou_thresh = 0.88 if not anime_style_chk else 0.83
+     stability_score_thresh = 0.95 if not anime_style_chk else 0.9
+
+     if "sam2_" in model_type:
+         pred_iou_thresh = round(pred_iou_thresh - 0.18, 2)
+         stability_score_thresh = round(stability_score_thresh - 0.03, 2)
+         sam2_gen_kwargs = dict(
+             points_per_side=64,
+             points_per_batch=points_per_batch,
+             pred_iou_thresh=pred_iou_thresh,
+             stability_score_thresh=stability_score_thresh,
+             stability_score_offset=0.7,
+             crop_n_layers=1,
+             box_nms_thresh=0.7,
+             crop_n_points_downscale_factor=2)
+         if platform.system() == "Darwin":
+             sam2_gen_kwargs.update(dict(points_per_side=32, points_per_batch=64, crop_n_points_downscale_factor=1))
+
+     if os.path.isfile(sam_checkpoint):
+         sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
+         if platform.system() == "Darwin":
+             if "FastSAM" in os.path.basename(sam_checkpoint) or not ia_check_versions.torch_mps_is_available:
+                 sam.to(device=torch.device("cpu"))
+             else:
+                 sam.to(device=torch.device("mps"))
+         else:
+             if IAConfig.global_args.get("sam_cpu", False):
+                 ia_logging.info("SAM is running on CPU... (the option has been selected)")
+                 sam.to(device=devices.cpu)
+             else:
+                 sam.to(device=devices.device)
+         sam_gen_kwargs = dict(
+             model=sam, points_per_batch=points_per_batch, pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
+         if "sam2_" in model_type:
+             sam_gen_kwargs.update(sam2_gen_kwargs)
+         sam_mask_generator = SamAutomaticMaskGeneratorLocal(**sam_gen_kwargs)
+     else:
+         sam_mask_generator = None
+
+     return sam_mask_generator
+
+
+ def get_sam_predictor(sam_checkpoint):
+     """Get SAM predictor.
+
+     Args:
+         sam_checkpoint (str): SAM checkpoint path
+
+     Returns:
+         SamPredictor or None: SAM predictor
+     """
+     # model_type = "vit_h"
+     if "_hq_" in os.path.basename(sam_checkpoint):
+         model_type = os.path.basename(sam_checkpoint)[7:12]
+         sam_model_registry_local = sam_model_registry_hq
+         SamPredictorLocal = SamPredictorHQ
+     elif "FastSAM" in os.path.basename(sam_checkpoint):
+         raise NotImplementedError("FastSAM predictor is not implemented yet.")
+     elif "mobile_sam" in os.path.basename(sam_checkpoint):
+         model_type = "vit_t"
+         sam_model_registry_local = sam_model_registry_mobile
+         SamPredictorLocal = SamPredictorMobile
+     else:
+         model_type = os.path.basename(sam_checkpoint)[4:9]
+         sam_model_registry_local = sam_model_registry
+         SamPredictorLocal = SamPredictor
+
+     if os.path.isfile(sam_checkpoint):
+         sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
+         if platform.system() == "Darwin":
+             if "FastSAM" in os.path.basename(sam_checkpoint) or not ia_check_versions.torch_mps_is_available:
+                 sam.to(device=torch.device("cpu"))
+             else:
+                 sam.to(device=torch.device("mps"))
+         else:
+             if IAConfig.global_args.get("sam_cpu", False):
+                 ia_logging.info("SAM is running on CPU... (the option has been selected)")
+                 sam.to(device=devices.cpu)
+             else:
+                 sam.to(device=devices.device)
+         sam_predictor = SamPredictorLocal(sam)
+     else:
+         sam_predictor = None
+
+     return sam_predictor
ia_threading.py ADDED
@@ -0,0 +1,55 @@
import gc
import inspect
import threading
from functools import wraps

import torch

from ia_check_versions import ia_check_versions

model_access_sem = threading.Semaphore(1)


def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    if ia_check_versions.torch_mps_is_available:
        if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()


def clear_cache():
    gc.collect()
    torch_gc()


def post_clear_cache(sem):
    with sem:
        gc.collect()
        torch_gc()


def async_post_clear_cache():
    thread = threading.Thread(target=post_clear_cache, args=(model_access_sem,))
    thread.start()

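# Decorator that runs clear_cache() before and after the wrapped call.
# Generator functions get a generator-aware wrapper so the trailing
# clear_cache() only runs once the generator is exhausted.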
def clear_cache_decorator(func):
    @wraps(func)
    def yield_wrapper(*args, **kwargs):
        clear_cache()
        yield from func(*args, **kwargs)
        clear_cache()

    @wraps(func)
    def wrapper(*args, **kwargs):
        clear_cache()
        res = func(*args, **kwargs)
        clear_cache()
        return res

    if inspect.isgeneratorfunction(func):
        return yield_wrapper
    else:
        return wrapper
ia_ui_gradio.py ADDED
@@ -0,0 +1,30 @@
import os

import gradio as gr

GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse


def webpath(fn):
    web_path = os.path.realpath(fn)

    return f'file={web_path}?{os.path.getmtime(fn)}'


def javascript_html():
    script_path = os.path.join(os.path.dirname(__file__), "javascript", "inpaint-anything.js")
    head = f'<script type="text/javascript" src="{webpath(script_path)}"></script>\n'

    return head

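# Monkey-patch Gradio's TemplateResponse so every page it serves gets the
# inpaint-anything script tag injected into its <head>.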
def reload_javascript():
    js = javascript_html()

    def template_response(*args, **kwargs):
        res = GradioTemplateResponseOriginal(*args, **kwargs)
        res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response
ia_ui_items.py ADDED
@@ -0,0 +1,110 @@
from huggingface_hub import scan_cache_dir


def get_sampler_names():
    """Get sampler name list.

    Returns:
        list: sampler name list
    """
    sampler_names = [
        "DDIM",
        "Euler",
        "Euler a",
        "DPM2 Karras",
        "DPM2 a Karras",
    ]
    return sampler_names


def get_sam_model_ids():
    """Get SAM model ids list.

    Returns:
        list: SAM model ids list
    """
    sam_model_ids = [
        "sam2_hiera_large.pt",
        "sam2_hiera_base_plus.pt",
        "sam2_hiera_small.pt",
        "sam2_hiera_tiny.pt",
        "sam_vit_h_4b8939.pth",
        "sam_vit_l_0b3195.pth",
        "sam_vit_b_01ec64.pth",
        "sam_hq_vit_h.pth",
        "sam_hq_vit_l.pth",
        "sam_hq_vit_b.pth",
        "FastSAM-x.pt",
        "FastSAM-s.pt",
        "mobile_sam.pt",
    ]
    return sam_model_ids


inp_list_from_cache = None


def get_inp_model_ids():
    """Get inpainting model ids list.

    Returns:
        list: model ids list
    """
    global inp_list_from_cache
    model_ids = [
        "stabilityai/stable-diffusion-2-inpainting",
        "Uminosachi/dreamshaper_8Inpainting",
        "Uminosachi/deliberate_v3-inpainting",
        "Uminosachi/realisticVisionV51_v51VAE-inpainting",
        "Uminosachi/revAnimated_v121Inp-inpainting",
        "runwayml/stable-diffusion-inpainting",
    ]
    if inp_list_from_cache is not None and isinstance(inp_list_from_cache, list):
        model_ids.extend(inp_list_from_cache)
        return model_ids
    try:
        hf_cache_info = scan_cache_dir()
        inpaint_repos = []
        for repo in hf_cache_info.repos:
            if repo.repo_type == "model" and "inpaint" in repo.repo_id.lower() and repo.repo_id not in model_ids:
                inpaint_repos.append(repo.repo_id)
        inp_list_from_cache = sorted(inpaint_repos, reverse=True, key=lambda x: x.split("/")[-1])
        model_ids.extend(inp_list_from_cache)
        return model_ids
    except Exception:
        return model_ids


def get_cleaner_model_ids():
    """Get cleaner model ids list.

    Returns:
        list: model ids list
    """
    model_ids = [
        "lama",
        "ldm",
        "zits",
        "mat",
        "fcf",
        "manga",
    ]
    return model_ids


def get_padding_mode_names():
    """Get padding mode name list.

    Returns:
        list: padding mode name list
    """
    padding_mode_names = [
        "constant",
        "edge",
        "reflect",
        "mean",
        "median",
        "maximum",
        "minimum",
    ]
    return padding_mode_names
iasam_app.py ADDED
@@ -0,0 +1,809 @@
import argparse
# import math
import gc
import os
import platform

if platform.system() == "Darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

if platform.system() == "Windows":
    os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"

import random
import traceback
from importlib.util import find_spec

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import (DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler,
                       StableDiffusionInpaintPipeline)
from PIL import Image, ImageFilter
from PIL.PngImagePlugin import PngInfo
from torch.hub import download_url_to_file
from torchvision import transforms

import inpalib
from ia_check_versions import ia_check_versions
from ia_config import IAConfig, get_ia_config_index, set_ia_config, setup_ia_config_ini
from ia_devices import devices
from ia_file_manager import IAFileManager, download_model_from_hf, ia_file_manager
from ia_logging import ia_logging
from ia_threading import clear_cache_decorator
from ia_ui_gradio import reload_javascript
from ia_ui_items import (get_cleaner_model_ids, get_inp_model_ids, get_padding_mode_names,
                         get_sam_model_ids, get_sampler_names)
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler

print("platform:", platform.system())

reload_javascript()

if find_spec("xformers") is not None:
    xformers_available = True
else:
    xformers_available = False

parser = argparse.ArgumentParser(description="Inpaint Anything")
parser.add_argument("--save-seg", action="store_true", help="Save the segmentation image generated by SAM.")
parser.add_argument("--offline", action="store_true", help="Execute inpainting using an offline network.")
parser.add_argument("--sam-cpu", action="store_true", help="Perform the Segment Anything operation on CPU.")
args = parser.parse_args()
IAConfig.global_args.update(args.__dict__)


@clear_cache_decorator
def download_model(sam_model_id):
    """Download SAM model.

    Args:
        sam_model_id (str): SAM model id

    Returns:
        str: download status
    """
    if "_hq_" in sam_model_id:
        url_sam = "https://huggingface.co/Uminosachi/sam-hq/resolve/main/" + sam_model_id
    elif "FastSAM" in sam_model_id:
        url_sam = "https://huggingface.co/Uminosachi/FastSAM/resolve/main/" + sam_model_id
    elif "mobile_sam" in sam_model_id:
        url_sam = "https://huggingface.co/Uminosachi/MobileSAM/resolve/main/" + sam_model_id
    elif "sam2_" in sam_model_id:
        url_sam = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/" + sam_model_id
    else:
        url_sam = "https://dl.fbaipublicfiles.com/segment_anything/" + sam_model_id

    sam_checkpoint = os.path.join(ia_file_manager.models_dir, sam_model_id)
    if not os.path.isfile(sam_checkpoint):
        try:
            download_url_to_file(url_sam, sam_checkpoint)
        except Exception as e:
            ia_logging.error(str(e))
            return str(e)

        return IAFileManager.DOWNLOAD_COMPLETE
    else:
        return "Model already exists"


sam_dict = dict(sam_masks=None, mask_image=None, cnet=None, orig_image=None, pad_mask=None)


def save_mask_image(mask_image, save_mask_chk=False):
    """Save mask image.

    Args:
        mask_image (np.ndarray): mask image
        save_mask_chk (bool, optional): If True, save mask image. Defaults to False.

    Returns:
        None
    """
    if save_mask_chk:
        save_name = "_".join([ia_file_manager.savename_prefix, "created_mask"]) + ".png"
        save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
        Image.fromarray(mask_image).save(save_name)


@clear_cache_decorator
def input_image_upload(input_image, sam_image, sel_mask):
    global sam_dict
    sam_dict["orig_image"] = input_image
    sam_dict["pad_mask"] = None

    if (sam_dict["mask_image"] is None or not isinstance(sam_dict["mask_image"], np.ndarray) or
            sam_dict["mask_image"].shape != input_image.shape):
        sam_dict["mask_image"] = np.zeros_like(input_image, dtype=np.uint8)

    ret_sel_image = cv2.addWeighted(input_image, 0.5, sam_dict["mask_image"], 0.5, 0)

    if sam_image is None or not isinstance(sam_image, dict) or "image" not in sam_image:
        sam_dict["sam_masks"] = None
        ret_sam_image = np.zeros_like(input_image, dtype=np.uint8)
    elif sam_image["image"].shape == input_image.shape:
        ret_sam_image = gr.update()
    else:
        sam_dict["sam_masks"] = None
        ret_sam_image = gr.update(value=np.zeros_like(input_image, dtype=np.uint8))

    if sel_mask is None or not isinstance(sel_mask, dict) or "image" not in sel_mask:
        ret_sel_mask = ret_sel_image
    elif sel_mask["image"].shape == ret_sel_image.shape and np.all(sel_mask["image"] == ret_sel_image):
        ret_sel_mask = gr.update()
    else:
        ret_sel_mask = gr.update(value=ret_sel_image)

    return ret_sam_image, ret_sel_mask, gr.update(interactive=True)


@clear_cache_decorator
def run_padding(input_image, pad_scale_width, pad_scale_height, pad_lr_barance, pad_tb_barance, padding_mode="edge"):
    global sam_dict
    if input_image is None or sam_dict["orig_image"] is None:
        sam_dict["orig_image"] = None
        sam_dict["pad_mask"] = None
        return None, "Input image not found"

    orig_image = sam_dict["orig_image"]

    height, width = orig_image.shape[:2]
    pad_width, pad_height = (int(width * pad_scale_width), int(height * pad_scale_height))
    ia_logging.info(f"resize by padding: ({height}, {width}) -> ({pad_height}, {pad_width})")

    pad_size_w, pad_size_h = (pad_width - width, pad_height - height)
    pad_size_l = int(pad_size_w * pad_lr_barance)
    pad_size_r = pad_size_w - pad_size_l
    pad_size_t = int(pad_size_h * pad_tb_barance)
    pad_size_b = pad_size_h - pad_size_t

    pad_width = [(pad_size_t, pad_size_b), (pad_size_l, pad_size_r), (0, 0)]
    if padding_mode == "constant":
        fill_value = 127
        pad_image = np.pad(orig_image, pad_width=pad_width, mode=padding_mode, constant_values=fill_value)
    else:
        pad_image = np.pad(orig_image, pad_width=pad_width, mode=padding_mode)

    mask_pad_width = [(pad_size_t, pad_size_b), (pad_size_l, pad_size_r)]
    pad_mask = np.zeros((height, width), dtype=np.uint8)
    pad_mask = np.pad(pad_mask, pad_width=mask_pad_width, mode="constant", constant_values=255)
    sam_dict["pad_mask"] = dict(segmentation=pad_mask.astype(bool))

    return pad_image, "Padding done"


@clear_cache_decorator
def run_sam(input_image, sam_model_id, sam_image, anime_style_chk=False):
    global sam_dict
    if not inpalib.sam_file_exists(sam_model_id):
        ret_sam_image = None if sam_image is None else gr.update()
        return ret_sam_image, f"{sam_model_id} not found, please download"

    if input_image is None:
        ret_sam_image = None if sam_image is None else gr.update()
        return ret_sam_image, "Input image not found"

    set_ia_config(IAConfig.KEYS.SAM_MODEL_ID, sam_model_id, IAConfig.SECTIONS.USER)

    if sam_dict["sam_masks"] is not None:
        sam_dict["sam_masks"] = None
        gc.collect()

    ia_logging.info(f"input_image: {input_image.shape} {input_image.dtype}")

    try:
        sam_masks = inpalib.generate_sam_masks(input_image, sam_model_id, anime_style_chk)
        sam_masks = inpalib.sort_masks_by_area(sam_masks)
        sam_masks = inpalib.insert_mask_to_sam_masks(sam_masks, sam_dict["pad_mask"])

        seg_image = inpalib.create_seg_color_image(input_image, sam_masks)

        sam_dict["sam_masks"] = sam_masks

    except Exception as e:
        print(traceback.format_exc())
        ia_logging.error(str(e))
        ret_sam_image = None if sam_image is None else gr.update()
        return ret_sam_image, "Segment Anything failed"

    if IAConfig.global_args.get("save_seg", False):
        save_name = "_".join([ia_file_manager.savename_prefix, os.path.splitext(sam_model_id)[0]]) + ".png"
        save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
        Image.fromarray(seg_image).save(save_name)

    if sam_image is None:
        return seg_image, "Segment Anything complete"
    else:
        if sam_image["image"].shape == seg_image.shape and np.all(sam_image["image"] == seg_image):
            return gr.update(), "Segment Anything complete"
        else:
            return gr.update(value=seg_image), "Segment Anything complete"


@clear_cache_decorator
def select_mask(input_image, sam_image, invert_chk, ignore_black_chk, sel_mask):
    global sam_dict
    if sam_dict["sam_masks"] is None or sam_image is None:
        ret_sel_mask = None if sel_mask is None else gr.update()
        return ret_sel_mask
    sam_masks = sam_dict["sam_masks"]

    # image = sam_image["image"]
    mask = sam_image["mask"][:, :, 0:1]

    try:
        seg_image = inpalib.create_mask_image(mask, sam_masks, ignore_black_chk)
        if invert_chk:
            seg_image = inpalib.invert_mask(seg_image)

        sam_dict["mask_image"] = seg_image

    except Exception as e:
        print(traceback.format_exc())
        ia_logging.error(str(e))
        ret_sel_mask = None if sel_mask is None else gr.update()
        return ret_sel_mask

    if input_image is not None and input_image.shape == seg_image.shape:
        ret_image = cv2.addWeighted(input_image, 0.5, seg_image, 0.5, 0)
    else:
        ret_image = seg_image

    if sel_mask is None:
        return ret_image
    else:
        if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image):
            return gr.update()
        else:
            return gr.update(value=ret_image)


@clear_cache_decorator
def expand_mask(input_image, sel_mask, expand_iteration=1):
    global sam_dict
    if sam_dict["mask_image"] is None or sel_mask is None:
        return None

    new_sel_mask = sam_dict["mask_image"]

    expand_iteration = int(np.clip(expand_iteration, 1, 100))

    new_sel_mask = cv2.dilate(new_sel_mask, np.ones((3, 3), dtype=np.uint8), iterations=expand_iteration)

    sam_dict["mask_image"] = new_sel_mask

    if input_image is not None and input_image.shape == new_sel_mask.shape:
        ret_image = cv2.addWeighted(input_image, 0.5, new_sel_mask, 0.5, 0)
    else:
        ret_image = new_sel_mask

    if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image):
        return gr.update()
    else:
        return gr.update(value=ret_image)


@clear_cache_decorator
def apply_mask(input_image, sel_mask):
    global sam_dict
    if sam_dict["mask_image"] is None or sel_mask is None:
        return None

    sel_mask_image = sam_dict["mask_image"]
    sel_mask_mask = np.logical_not(sel_mask["mask"][:, :, 0:3].astype(bool)).astype(np.uint8)
    new_sel_mask = sel_mask_image * sel_mask_mask

    sam_dict["mask_image"] = new_sel_mask

    if input_image is not None and input_image.shape == new_sel_mask.shape:
        ret_image = cv2.addWeighted(input_image, 0.5, new_sel_mask, 0.5, 0)
    else:
        ret_image = new_sel_mask

    if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image):
        return gr.update()
    else:
        return gr.update(value=ret_image)


@clear_cache_decorator
def add_mask(input_image, sel_mask):
    global sam_dict
    if sam_dict["mask_image"] is None or sel_mask is None:
        return None

    sel_mask_image = sam_dict["mask_image"]
    sel_mask_mask = sel_mask["mask"][:, :, 0:3].astype(bool).astype(np.uint8)
    new_sel_mask = sel_mask_image + (sel_mask_mask * np.invert(sel_mask_image, dtype=np.uint8))

    sam_dict["mask_image"] = new_sel_mask

    if input_image is not None and input_image.shape == new_sel_mask.shape:
        ret_image = cv2.addWeighted(input_image, 0.5, new_sel_mask, 0.5, 0)
    else:
        ret_image = new_sel_mask

    if sel_mask["image"].shape == ret_image.shape and np.all(sel_mask["image"] == ret_image):
        return gr.update()
    else:
        return gr.update(value=ret_image)

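# The inpainting pipeline expects dimensions that are multiples of 8, so resize
# (and center-crop if necessary) both the image and the mask before inference.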
def auto_resize_to_pil(input_image, mask_image):
    init_image = Image.fromarray(input_image).convert("RGB")
    mask_image = Image.fromarray(mask_image).convert("RGB")
    assert init_image.size == mask_image.size, "The sizes of the image and mask do not match"
    width, height = init_image.size

    new_height = (height // 8) * 8
    new_width = (width // 8) * 8
    if new_width < width or new_height < height:
        if (new_width / width) < (new_height / height):
            scale = new_height / height
        else:
            scale = new_width / width
        resize_height = int(height*scale+0.5)
        resize_width = int(width*scale+0.5)
        if height != resize_height or width != resize_width:
            ia_logging.info(f"resize: ({height}, {width}) -> ({resize_height}, {resize_width})")
            init_image = transforms.functional.resize(init_image, (resize_height, resize_width), transforms.InterpolationMode.LANCZOS)
            mask_image = transforms.functional.resize(mask_image, (resize_height, resize_width), transforms.InterpolationMode.LANCZOS)
        if resize_height != new_height or resize_width != new_width:
            ia_logging.info(f"center_crop: ({resize_height}, {resize_width}) -> ({new_height}, {new_width})")
            init_image = transforms.functional.center_crop(init_image, (new_height, new_width))
            mask_image = transforms.functional.center_crop(mask_image, (new_height, new_width))

    return init_image, mask_image


@clear_cache_decorator
def run_inpaint(input_image, sel_mask, prompt, n_prompt, ddim_steps, cfg_scale, seed, inp_model_id, save_mask_chk, composite_chk,
                sampler_name="DDIM", iteration_count=1):
    global sam_dict
    if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
        ia_logging.error("The image or mask does not exist")
        return

    mask_image = sam_dict["mask_image"]
    if input_image.shape != mask_image.shape:
        ia_logging.error("The sizes of the image and mask do not match")
        return

    set_ia_config(IAConfig.KEYS.INP_MODEL_ID, inp_model_id, IAConfig.SECTIONS.USER)

    save_mask_image(mask_image, save_mask_chk)

    ia_logging.info(f"Loading model {inp_model_id}")
    config_offline_inpainting = IAConfig.global_args.get("offline", False)
    if config_offline_inpainting:
        ia_logging.info("Run Inpainting on offline network: {}".format(str(config_offline_inpainting)))
    local_files_only = False
    local_file_status = download_model_from_hf(inp_model_id, local_files_only=True)
    if local_file_status != IAFileManager.DOWNLOAD_COMPLETE:
        if config_offline_inpainting:
            ia_logging.warning(local_file_status)
            return
    else:
        local_files_only = True
    ia_logging.info("local_files_only: {}".format(str(local_files_only)))

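    # Use full precision on macOS, CPU, and AMD ROCm; half precision elsewhere.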
+ if platform.system() == "Darwin" or devices.device == devices.cpu or ia_check_versions.torch_on_amd_rocm:
394
+ torch_dtype = torch.float32
395
+ else:
396
+ torch_dtype = torch.float16
397
+
398
+ try:
399
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
400
+ inp_model_id, torch_dtype=torch_dtype, local_files_only=local_files_only, use_safetensors=True)
401
+ except Exception as e:
402
+ ia_logging.error(str(e))
403
+ if not config_offline_inpainting:
404
+ try:
405
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
406
+ inp_model_id, torch_dtype=torch_dtype, use_safetensors=True)
407
+ except Exception as e:
408
+ ia_logging.error(str(e))
409
+ try:
410
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
411
+ inp_model_id, torch_dtype=torch_dtype, force_download=True, use_safetensors=True)
412
+ except Exception as e:
413
+ ia_logging.error(str(e))
414
+ return
415
+ else:
416
+ return
417
+ pipe.safety_checker = None
418
+
419
+ ia_logging.info(f"Using sampler {sampler_name}")
420
+ if sampler_name == "DDIM":
421
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
422
+ elif sampler_name == "Euler":
423
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
424
+ elif sampler_name == "Euler a":
425
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
426
+ elif sampler_name == "DPM2 Karras":
427
+ pipe.scheduler = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config)
428
+ elif sampler_name == "DPM2 a Karras":
429
+ pipe.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipe.scheduler.config)
430
+ else:
431
+ ia_logging.info("Sampler fallback to DDIM")
432
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
433
+
434
+ if platform.system() == "Darwin":
435
+ pipe = pipe.to("mps" if ia_check_versions.torch_mps_is_available else "cpu")
436
+ pipe.enable_attention_slicing()
437
+ torch_generator = torch.Generator(devices.cpu)
438
+ else:
439
+ if ia_check_versions.diffusers_enable_cpu_offload and devices.device != devices.cpu:
440
+ ia_logging.info("Enable model cpu offload")
441
+ pipe.enable_model_cpu_offload()
442
+ else:
443
+ pipe = pipe.to(devices.device)
444
+ if xformers_available:
445
+ ia_logging.info("Enable xformers memory efficient attention")
446
+ pipe.enable_xformers_memory_efficient_attention()
447
+ else:
448
+ ia_logging.info("Enable attention slicing")
449
+ pipe.enable_attention_slicing()
450
+ if "privateuseone" in str(getattr(devices.device, "type", "")):
451
+ torch_generator = torch.Generator(devices.cpu)
452
+ else:
453
+ torch_generator = torch.Generator(devices.device)
454
+
455
+ init_image, mask_image = auto_resize_to_pil(input_image, mask_image)
456
+ width, height = init_image.size
457
+
458
+ output_list = []
459
+ iteration_count = iteration_count if iteration_count is not None else 1
460
+ for count in range(int(iteration_count)):
461
+ gc.collect()
462
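        # A negative seed means "randomize"; every iteration after the first also reseeds.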
        if seed < 0 or count > 0:
            seed = random.randint(0, 2147483647)

        generator = torch_generator.manual_seed(seed)

        pipe_args_dict = {
            "prompt": prompt,
            "image": init_image,
            "width": width,
            "height": height,
            "mask_image": mask_image,
            "num_inference_steps": ddim_steps,
            "guidance_scale": cfg_scale,
            "negative_prompt": n_prompt,
            "generator": generator,
        }

        output_image = pipe(**pipe_args_dict).images[0]

        if composite_chk:
            dilate_mask_image = Image.fromarray(cv2.dilate(np.array(mask_image), np.ones((3, 3), dtype=np.uint8), iterations=4))
            output_image = Image.composite(output_image, init_image, dilate_mask_image.convert("L").filter(ImageFilter.GaussianBlur(3)))

        generation_params = {
            "Steps": ddim_steps,
            "Sampler": sampler_name,
            "CFG scale": cfg_scale,
            "Seed": seed,
            "Size": f"{width}x{height}",
            "Model": inp_model_id,
        }

        generation_params_text = ", ".join([k if k == v else f"{k}: {v}" for k, v in generation_params.items() if v is not None])
        prompt_text = prompt if prompt else ""
        negative_prompt_text = "\nNegative prompt: " + n_prompt if n_prompt else ""
        infotext = f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()

        metadata = PngInfo()
        metadata.add_text("parameters", infotext)

        save_name = "_".join([ia_file_manager.savename_prefix, os.path.basename(inp_model_id), str(seed)]) + ".png"
        save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
        output_image.save(save_name, pnginfo=metadata)

        output_list.append(output_image)

        yield output_list, max([1, iteration_count - (count + 1)])


@clear_cache_decorator
def run_cleaner(input_image, sel_mask, cleaner_model_id, cleaner_save_mask_chk):
    global sam_dict
    if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
        ia_logging.error("The image or mask does not exist")
        return None

    mask_image = sam_dict["mask_image"]
    if input_image.shape != mask_image.shape:
        ia_logging.error("The sizes of the image and mask do not match")
        return None

    save_mask_image(mask_image, cleaner_save_mask_chk)

    ia_logging.info(f"Loading model {cleaner_model_id}")
    if platform.system() == "Darwin":
        model = ModelManager(name=cleaner_model_id, device=devices.cpu)
    else:
        model = ModelManager(name=cleaner_model_id, device=devices.device)

    init_image, mask_image = auto_resize_to_pil(input_image, mask_image)
    width, height = init_image.size

    init_image = np.array(init_image)
    mask_image = np.array(mask_image.convert("L"))

    config = Config(
        ldm_steps=20,
        ldm_sampler=LDMSampler.ddim,
        hd_strategy=HDStrategy.ORIGINAL,
        hd_strategy_crop_margin=32,
        hd_strategy_crop_trigger_size=512,
        hd_strategy_resize_limit=512,
        prompt="",
        sd_steps=20,
        sd_sampler=SDSampler.ddim
    )

    output_image = model(image=init_image, mask=mask_image, config=config)
    output_image = cv2.cvtColor(output_image.astype(np.uint8), cv2.COLOR_BGR2RGB)
    output_image = Image.fromarray(output_image)

    save_name = "_".join([ia_file_manager.savename_prefix, os.path.basename(cleaner_model_id)]) + ".png"
    save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
    output_image.save(save_name)

    del model
    return [output_image]


@clear_cache_decorator
def run_get_alpha_image(input_image, sel_mask):
    global sam_dict
    if input_image is None or sam_dict["mask_image"] is None or sel_mask is None:
        ia_logging.error("The image or mask does not exist")
        return None, ""

    mask_image = sam_dict["mask_image"]
    if input_image.shape != mask_image.shape:
        ia_logging.error("The sizes of the image and mask do not match")
        return None, ""

    alpha_image = Image.fromarray(input_image).convert("RGBA")
    mask_image = Image.fromarray(mask_image).convert("L")

    alpha_image.putalpha(mask_image)

    save_name = "_".join([ia_file_manager.savename_prefix, "rgba_image"]) + ".png"
    save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
    alpha_image.save(save_name)

    return alpha_image, f"saved: {save_name}"


@clear_cache_decorator
def run_get_mask(sel_mask):
    global sam_dict
    if sam_dict["mask_image"] is None or sel_mask is None:
        return None

    mask_image = sam_dict["mask_image"]

    save_name = "_".join([ia_file_manager.savename_prefix, "created_mask"]) + ".png"
    save_name = os.path.join(ia_file_manager.outputs_dir, save_name)
    Image.fromarray(mask_image).save(save_name)

    return mask_image


def on_ui_tabs():
    setup_ia_config_ini()
    sampler_names = get_sampler_names()
    sam_model_ids = get_sam_model_ids()
    sam_model_index = get_ia_config_index(IAConfig.KEYS.SAM_MODEL_ID, IAConfig.SECTIONS.USER)
    inp_model_ids = get_inp_model_ids()
    inp_model_index = get_ia_config_index(IAConfig.KEYS.INP_MODEL_ID, IAConfig.SECTIONS.USER)
    cleaner_model_ids = get_cleaner_model_ids()
    padding_mode_names = get_padding_mode_names()

    out_gallery_kwargs = dict(columns=2, height=520, object_fit="contain", preview=True)

    block = gr.Blocks(analytics_enabled=False).queue()
    block.title = "Inpaint Anything"
    with block as inpaint_anything_interface:
        with gr.Row():
            gr.Markdown("## Inpainting with Segment Anything")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        sam_model_id = gr.Dropdown(label="Segment Anything Model ID", elem_id="sam_model_id", choices=sam_model_ids,
                                                   value=sam_model_ids[sam_model_index], show_label=True)
                    with gr.Column():
                        with gr.Row():
                            load_model_btn = gr.Button("Download model", elem_id="load_model_btn")
                        with gr.Row():
                            status_text = gr.Textbox(label="", elem_id="status_text", max_lines=1, show_label=False, interactive=False)
                with gr.Row():
                    input_image = gr.Image(label="Input image", elem_id="ia_input_image", source="upload", type="numpy", interactive=True)

                with gr.Row():
                    with gr.Accordion("Padding options", elem_id="padding_options", open=False):
                        with gr.Row():
                            with gr.Column():
                                pad_scale_width = gr.Slider(label="Scale Width", elem_id="pad_scale_width", minimum=1.0, maximum=1.5, value=1.0, step=0.01)
                            with gr.Column():
                                pad_lr_barance = gr.Slider(label="Left/Right Balance", elem_id="pad_lr_barance", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
                        with gr.Row():
                            with gr.Column():
                                pad_scale_height = gr.Slider(label="Scale Height", elem_id="pad_scale_height", minimum=1.0, maximum=1.5, value=1.0, step=0.01)
                            with gr.Column():
                                pad_tb_barance = gr.Slider(label="Top/Bottom Balance", elem_id="pad_tb_barance", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
                        with gr.Row():
                            with gr.Column():
                                padding_mode = gr.Dropdown(label="Padding Mode", elem_id="padding_mode", choices=padding_mode_names, value="edge")
                            with gr.Column():
                                padding_btn = gr.Button("Run Padding", elem_id="padding_btn")

                with gr.Row():
                    with gr.Column():
                        anime_style_chk = gr.Checkbox(label="Anime Style (increases detection, decreases mask quality)", elem_id="anime_style_chk",
                                                      show_label=True, interactive=True)
                    with gr.Column():
                        sam_btn = gr.Button("Run Segment Anything", elem_id="sam_btn", variant="primary", interactive=False)

                with gr.Tab("Inpainting", elem_id="inpainting_tab"):
                    prompt = gr.Textbox(label="Inpainting Prompt", elem_id="sd_prompt")
                    n_prompt = gr.Textbox(label="Negative Prompt", elem_id="sd_n_prompt")
                    with gr.Accordion("Advanced options", elem_id="inp_advanced_options", open=False):
                        composite_chk = gr.Checkbox(label="Mask area Only", elem_id="composite_chk", value=True, show_label=True, interactive=True)
                        with gr.Row():
                            with gr.Column():
                                sampler_name = gr.Dropdown(label="Sampler", elem_id="sampler_name", choices=sampler_names,
                                                           value=sampler_names[0], show_label=True)
                            with gr.Column():
                                ddim_steps = gr.Slider(label="Sampling Steps", elem_id="ddim_steps", minimum=1, maximum=100, value=20, step=1)
                        cfg_scale = gr.Slider(label="Guidance Scale", elem_id="cfg_scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                        seed = gr.Slider(
                            label="Seed",
                            elem_id="sd_seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1,
                        )
                    with gr.Row():
                        with gr.Column():
                            inp_model_id = gr.Dropdown(label="Inpainting Model ID", elem_id="inp_model_id",
                                                       choices=inp_model_ids, value=inp_model_ids[inp_model_index], show_label=True)
                        with gr.Column():
                            with gr.Row():
                                inpaint_btn = gr.Button("Run Inpainting", elem_id="inpaint_btn", variant="primary")
                            with gr.Row():
                                save_mask_chk = gr.Checkbox(label="Save mask", elem_id="save_mask_chk",
                                                            value=False, show_label=False, interactive=False, visible=False)
                                iteration_count = gr.Slider(label="Iterations", elem_id="iteration_count", minimum=1, maximum=10, value=1, step=1)

                    with gr.Row():
                        if ia_check_versions.gradio_version_is_old:
                            out_image = gr.Gallery(label="Inpainted image", elem_id="ia_out_image", show_label=False
                                                   ).style(**out_gallery_kwargs)
                        else:
                            out_image = gr.Gallery(label="Inpainted image", elem_id="ia_out_image", show_label=False,
                                                   **out_gallery_kwargs)

                with gr.Tab("Cleaner", elem_id="cleaner_tab"):
                    with gr.Row():
                        with gr.Column():
                            cleaner_model_id = gr.Dropdown(label="Cleaner Model ID", elem_id="cleaner_model_id",
                                                           choices=cleaner_model_ids, value=cleaner_model_ids[0], show_label=True)
                        with gr.Column():
                            with gr.Row():
                                cleaner_btn = gr.Button("Run Cleaner", elem_id="cleaner_btn", variant="primary")
                            with gr.Row():
                                cleaner_save_mask_chk = gr.Checkbox(label="Save mask", elem_id="cleaner_save_mask_chk",
                                                                    value=False, show_label=False, interactive=False, visible=False)

                    with gr.Row():
                        if ia_check_versions.gradio_version_is_old:
                            cleaner_out_image = gr.Gallery(label="Cleaned image", elem_id="ia_cleaner_out_image", show_label=False
                                                           ).style(**out_gallery_kwargs)
                        else:
                            cleaner_out_image = gr.Gallery(label="Cleaned image", elem_id="ia_cleaner_out_image", show_label=False,
                                                           **out_gallery_kwargs)

                with gr.Tab("Mask only", elem_id="mask_only_tab"):
                    with gr.Row():
                        with gr.Column():
                            get_alpha_image_btn = gr.Button("Get mask as alpha of image", elem_id="get_alpha_image_btn")
                        with gr.Column():
                            get_mask_btn = gr.Button("Get mask", elem_id="get_mask_btn")

                    with gr.Row():
                        with gr.Column():
                            alpha_out_image = gr.Image(label="Alpha channel image", elem_id="alpha_out_image", type="pil", image_mode="RGBA", interactive=False)
                        with gr.Column():
                            mask_out_image = gr.Image(label="Mask image", elem_id="mask_out_image", type="numpy", interactive=False)

                    with gr.Row():
                        with gr.Column():
                            get_alpha_status_text = gr.Textbox(label="", elem_id="get_alpha_status_text", max_lines=1, show_label=False, interactive=False)
                        with gr.Column():
                            gr.Markdown("")

            with gr.Column():
                with gr.Row():
                    gr.Markdown("Mouse over image: Press `S` key for Fullscreen mode, `R` key to Reset zoom")
                with gr.Row():
                    if ia_check_versions.gradio_version_is_old:
                        sam_image = gr.Image(label="Segment Anything image", elem_id="ia_sam_image", type="numpy", tool="sketch", brush_radius=8,
                                             show_label=False, interactive=True).style(height=480)
                    else:
                        sam_image = gr.Image(label="Segment Anything image", elem_id="ia_sam_image", type="numpy", tool="sketch", brush_radius=8,
                                             show_label=False, interactive=True, height=480)

                with gr.Row():
                    with gr.Column():
                        select_btn = gr.Button("Create Mask", elem_id="select_btn", variant="primary")
                    with gr.Column():
                        with gr.Row():
                            invert_chk = gr.Checkbox(label="Invert mask", elem_id="invert_chk", show_label=True, interactive=True)
                            ignore_black_chk = gr.Checkbox(label="Ignore black area", elem_id="ignore_black_chk", value=True, show_label=True, interactive=True)

                with gr.Row():
                    if ia_check_versions.gradio_version_is_old:
                        sel_mask = gr.Image(label="Selected mask image", elem_id="ia_sel_mask", type="numpy", tool="sketch", brush_radius=12,
                                            show_label=False, interactive=True).style(height=480)
                    else:
                        sel_mask = gr.Image(label="Selected mask image", elem_id="ia_sel_mask", type="numpy", tool="sketch", brush_radius=12,
                                            show_label=False, interactive=True, height=480)

                with gr.Row():
                    with gr.Column():
                        expand_mask_btn = gr.Button("Expand mask region", elem_id="expand_mask_btn")
                        expand_mask_iteration_count = gr.Slider(label="Expand Mask Iterations",
                                                                elem_id="expand_mask_iteration_count", minimum=1, maximum=100, value=1, step=1)
                    with gr.Column():
                        apply_mask_btn = gr.Button("Trim mask by sketch", elem_id="apply_mask_btn")
                        add_mask_btn = gr.Button("Add mask by sketch", elem_id="add_mask_btn")

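        # The _js callbacks below invoke helpers defined in javascript/inpaint-anything.js.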
        load_model_btn.click(download_model, inputs=[sam_model_id], outputs=[status_text])
        input_image.upload(input_image_upload, inputs=[input_image, sam_image, sel_mask], outputs=[sam_image, sel_mask, sam_btn]).then(
            fn=None, inputs=None, outputs=None, _js="inpaintAnything_initSamSelMask")
        padding_btn.click(run_padding, inputs=[input_image, pad_scale_width, pad_scale_height, pad_lr_barance, pad_tb_barance, padding_mode],
                          outputs=[input_image, status_text])
        sam_btn.click(run_sam, inputs=[input_image, sam_model_id, sam_image, anime_style_chk], outputs=[sam_image, status_text]).then(
            fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSamMask")
        select_btn.click(select_mask, inputs=[input_image, sam_image, invert_chk, ignore_black_chk, sel_mask], outputs=[sel_mask]).then(
            fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")
        expand_mask_btn.click(expand_mask, inputs=[input_image, sel_mask, expand_mask_iteration_count], outputs=[sel_mask]).then(
            fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")
        apply_mask_btn.click(apply_mask, inputs=[input_image, sel_mask], outputs=[sel_mask]).then(
            fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")
        add_mask_btn.click(add_mask, inputs=[input_image, sel_mask], outputs=[sel_mask]).then(
            fn=None, inputs=None, outputs=None, _js="inpaintAnything_clearSelMask")

        inpaint_btn.click(
            run_inpaint,
            inputs=[input_image, sel_mask, prompt, n_prompt, ddim_steps, cfg_scale, seed, inp_model_id, save_mask_chk, composite_chk,
                    sampler_name, iteration_count],
            outputs=[out_image, iteration_count])
        cleaner_btn.click(
            run_cleaner,
            inputs=[input_image, sel_mask, cleaner_model_id, cleaner_save_mask_chk],
            outputs=[cleaner_out_image])
        get_alpha_image_btn.click(
            run_get_alpha_image,
            inputs=[input_image, sel_mask],
            outputs=[alpha_out_image, get_alpha_status_text])
        get_mask_btn.click(
            run_get_mask,
            inputs=[sel_mask],
            outputs=[mask_out_image])

    return [(inpaint_anything_interface, "Inpaint Anything", "inpaint_anything")]


block, _, _ = on_ui_tabs()[0]
block.launch(share=True)
images/inpaint_anything_explanation_image_1.png ADDED
images/inpaint_anything_ui_image_1.png ADDED
images/sample_input_image.png ADDED
images/sample_mask_image.png ADDED
images/sample_seg_color_image.png ADDED
inpalib/__init__.py ADDED
@@ -0,0 +1,18 @@
from .masklib import create_mask_image, invert_mask
from .samlib import (create_seg_color_image, generate_sam_masks, get_all_sam_ids,
                     get_available_sam_ids, get_seg_colormap, insert_mask_to_sam_masks,
                     sam_file_exists, sam_file_path, sort_masks_by_area)

__all__ = [
    "create_mask_image",
    "invert_mask",
    "create_seg_color_image",
    "generate_sam_masks",
    "get_all_sam_ids",
    "get_available_sam_ids",
    "get_seg_colormap",
    "insert_mask_to_sam_masks",
    "sam_file_exists",
    "sam_file_path",
    "sort_masks_by_area",
]
inpalib/masklib.py ADDED
@@ -0,0 +1,106 @@
from typing import Any, Dict, List, Union

import numpy as np
from PIL import Image


def invert_mask(mask: np.ndarray) -> np.ndarray:
    """Invert mask.

    Args:
        mask (np.ndarray): mask

    Returns:
        np.ndarray: inverted mask
    """
    if mask is None or not isinstance(mask, np.ndarray):
        raise ValueError("Invalid mask")

    # return np.logical_not(mask.astype(bool)).astype(np.uint8) * 255
    return np.invert(mask.astype(np.uint8))


def check_inputs_create_mask_image(
    mask: Union[np.ndarray, Image.Image],
    sam_masks: List[Dict[str, Any]],
    ignore_black_chk: bool = True,
) -> None:
    """Check create mask image inputs.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask
        sam_masks (List[Dict[str, Any]]): SAM masks
        ignore_black_chk (bool): ignore black check

    Returns:
        None
    """
    if mask is None or not isinstance(mask, (np.ndarray, Image.Image)):
        raise ValueError("Invalid mask")

    if sam_masks is None or not isinstance(sam_masks, list):
        raise ValueError("Invalid SAM masks")

    if ignore_black_chk is None or not isinstance(ignore_black_chk, bool):
        raise ValueError("Invalid ignore black check")


def convert_mask(mask: Union[np.ndarray, Image.Image]) -> np.ndarray:
    """Convert mask.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask

    Returns:
        np.ndarray: converted mask
    """
    if isinstance(mask, Image.Image):
        mask = np.array(mask)

    if mask.ndim == 2:
        mask = mask[:, :, np.newaxis]

    if mask.shape[2] != 1:
        mask = mask[:, :, 0:1]

    return mask


def create_mask_image(
    mask: Union[np.ndarray, Image.Image],
    sam_masks: List[Dict[str, Any]],
    ignore_black_chk: bool = True,
) -> np.ndarray:
    """Create mask image.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask
        sam_masks (List[Dict[str, Any]]): SAM masks
        ignore_black_chk (bool): ignore black check

    Returns:
        np.ndarray: mask image
    """
    check_inputs_create_mask_image(mask, sam_masks, ignore_black_chk)
    mask = convert_mask(mask)

    canvas_image = np.zeros(mask.shape, dtype=np.uint8)
    mask_region = np.zeros(mask.shape, dtype=np.uint8)
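    # First-come-first-served painting: each pixel belongs to the first SAM segment
    # that claims it, and a segment contributes to the mask only if the user's
    # sketch touches its still-unclaimed area.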
    for seg_dict in sam_masks:
        seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1)
        canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8)
        if (seg_mask * canvas_mask * mask).astype(bool).any():
            mask_region = mask_region + (seg_mask * canvas_mask)
        seg_color = seg_mask * canvas_mask
        canvas_image = canvas_image + seg_color

    if not ignore_black_chk:
        canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8)
        if (canvas_mask * mask).astype(bool).any():
            mask_region = mask_region + (canvas_mask)

    mask_region = np.tile(mask_region * 255, (1, 1, 3))

    seg_image = mask_region.astype(np.uint8)

    return seg_image
inpalib/samlib.py ADDED
@@ -0,0 +1,256 @@
import copy
import os
import sys
from typing import Any, Dict, List, Union

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

inpa_basedir = os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))
if inpa_basedir not in sys.path:
    sys.path.append(inpa_basedir)

from ia_file_manager import ia_file_manager  # noqa: E402
from ia_get_dataset_colormap import create_pascal_label_colormap  # noqa: E402
from ia_logging import ia_logging  # noqa: E402
from ia_sam_manager import check_bfloat16_support, get_sam_mask_generator  # noqa: E402
from ia_ui_items import get_sam_model_ids  # noqa: E402


def get_all_sam_ids() -> List[str]:
    """Get all SAM IDs.

    Returns:
        List[str]: SAM IDs
    """
    return get_sam_model_ids()


def sam_file_path(sam_id: str) -> str:
    """Get SAM file path.

    Args:
        sam_id (str): SAM ID

    Returns:
        str: SAM file path
    """
    return os.path.join(ia_file_manager.models_dir, sam_id)


def sam_file_exists(sam_id: str) -> bool:
    """Check if SAM file exists.

    Args:
        sam_id (str): SAM ID

    Returns:
        bool: True if SAM file exists else False
    """
    sam_checkpoint = sam_file_path(sam_id)

    return os.path.isfile(sam_checkpoint)


def get_available_sam_ids() -> List[str]:
    """Get available SAM IDs.

    Returns:
        List[str]: available SAM IDs
    """
    all_sam_ids = get_all_sam_ids()
    for sam_id in all_sam_ids.copy():
        if not sam_file_exists(sam_id):
            all_sam_ids.remove(sam_id)

    return all_sam_ids


def check_inputs_generate_sam_masks(
    input_image: Union[np.ndarray, Image.Image],
    sam_id: str,
    anime_style_chk: bool = False,
) -> None:
    """Check generate SAM masks inputs.

    Args:
        input_image (Union[np.ndarray, Image.Image]): input image
        sam_id (str): SAM ID
        anime_style_chk (bool): anime style check

    Returns:
        None
    """
    if input_image is None or not isinstance(input_image, (np.ndarray, Image.Image)):
        raise ValueError("Invalid input image")

    if sam_id is None or not isinstance(sam_id, str):
        raise ValueError("Invalid SAM ID")

    if anime_style_chk is None or not isinstance(anime_style_chk, bool):
        raise ValueError("Invalid anime style check")


def convert_input_image(input_image: Union[np.ndarray, Image.Image]) -> np.ndarray:
    """Convert input image.

    Args:
        input_image (Union[np.ndarray, Image.Image]): input image

    Returns:
        np.ndarray: converted input image
    """
    if isinstance(input_image, Image.Image):
        input_image = np.array(input_image)

    if input_image.ndim == 2:
        input_image = input_image[:, :, np.newaxis]

    if input_image.shape[2] == 1:
        input_image = np.concatenate([input_image] * 3, axis=-1)

    return input_image


def generate_sam_masks(
    input_image: Union[np.ndarray, Image.Image],
    sam_id: str,
    anime_style_chk: bool = False,
) -> List[Dict[str, Any]]:
    """Generate SAM masks.

    Args:
        input_image (Union[np.ndarray, Image.Image]): input image
        sam_id (str): SAM ID
        anime_style_chk (bool): anime style check

    Returns:
        List[Dict[str, Any]]: SAM masks
    """
    check_inputs_generate_sam_masks(input_image, sam_id, anime_style_chk)
    input_image = convert_input_image(input_image)

    sam_checkpoint = sam_file_path(sam_id)
    sam_mask_generator = get_sam_mask_generator(sam_checkpoint, anime_style_chk)
    ia_logging.info(f"{sam_mask_generator.__class__.__name__} {sam_id}")

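    # SAM 2 inference runs under autocast: bfloat16 on GPUs that support it
    # (compute capability >= 8), float16 otherwise.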
+ if "sam2_" in sam_id:
141
+ device = "cuda" if torch.cuda.is_available() else "cpu"
142
+ torch_dtype = torch.bfloat16 if check_bfloat16_support() else torch.float16
143
+ with torch.inference_mode(), torch.autocast(device, dtype=torch_dtype):
144
+ sam_masks = sam_mask_generator.generate(input_image)
145
+ else:
146
+ sam_masks = sam_mask_generator.generate(input_image)
147
+
148
+ if anime_style_chk:
149
+ for sam_mask in sam_masks:
150
+ sam_mask_seg = sam_mask["segmentation"]
151
+ sam_mask_seg = cv2.morphologyEx(sam_mask_seg.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
152
+ sam_mask_seg = cv2.morphologyEx(sam_mask_seg.astype(np.uint8), cv2.MORPH_OPEN, np.ones((5, 5), np.uint8))
153
+ sam_mask["segmentation"] = sam_mask_seg.astype(bool)
154
+
155
+ ia_logging.info("sam_masks: {}".format(len(sam_masks)))
156
+
157
+ sam_masks = copy.deepcopy(sam_masks)
158
+ return sam_masks
159
+
160
+
161
+ def sort_masks_by_area(
162
+ sam_masks: List[Dict[str, Any]],
163
+ ) -> List[Dict[str, Any]]:
164
+ """Sort mask by area.
165
+
166
+ Args:
167
+ sam_masks (List[Dict[str, Any]]): SAM masks
168
+
169
+ Returns:
170
+ List[Dict[str, Any]]: sorted SAM masks
171
+ """
172
+ return sorted(sam_masks, key=lambda x: np.sum(x.get("segmentation").astype(np.uint32)))
173
+
174
+
175
+ def get_seg_colormap() -> np.ndarray:
176
+ """Get segmentation colormap.
177
+
178
+ Returns:
179
+ np.ndarray: segmentation colormap
180
+ """
181
+ cm_pascal = create_pascal_label_colormap()
182
+ seg_colormap = cm_pascal
183
+ seg_colormap = np.array([c for c in seg_colormap if max(c) >= 64], dtype=np.uint8)
184
+
185
+ return seg_colormap
186
+
187
+
188
+ def insert_mask_to_sam_masks(
189
+ sam_masks: List[Dict[str, Any]],
190
+ insert_mask: Dict[str, Any],
191
+ ) -> List[Dict[str, Any]]:
192
+ """Insert mask to SAM masks.
193
+
194
+ Args:
195
+ sam_masks (List[Dict[str, Any]]): SAM masks
196
+ insert_mask (Dict[str, Any]): insert mask
197
+
198
+ Returns:
199
+ List[Dict[str, Any]]: SAM masks
200
+ """
201
+ if insert_mask is not None and isinstance(insert_mask, dict) and "segmentation" in insert_mask:
202
+ if (len(sam_masks) > 0 and
203
+ sam_masks[0]["segmentation"].shape == insert_mask["segmentation"].shape and
204
+ np.any(insert_mask["segmentation"])):
205
+ sam_masks.insert(0, insert_mask)
206
+ ia_logging.info("insert mask to sam_masks")
207
+
208
+ return sam_masks
209
+
210
+
211
+ def create_seg_color_image(
212
+ input_image: Union[np.ndarray, Image.Image],
213
+ sam_masks: List[Dict[str, Any]],
214
+ ) -> np.ndarray:
215
+ """Create segmentation color image.
216
+
217
+ Args:
218
+ input_image (Union[np.ndarray, Image.Image]): input image
219
+ sam_masks (List[Dict[str, Any]]): SAM masks
220
+
221
+ Returns:
222
+ np.ndarray: segmentation color image
223
+ """
224
+ input_image = convert_input_image(input_image)
225
+
226
+ seg_colormap = get_seg_colormap()
227
+ sam_masks = sam_masks[:len(seg_colormap)]
228
+
229
+ with tqdm(total=len(sam_masks), desc="Processing segments") as progress_bar:
230
+ canvas_image = np.zeros((*input_image.shape[:2], 1), dtype=np.uint8)
231
+ for idx, seg_dict in enumerate(sam_masks[0:min(255, len(sam_masks))]):
232
+ seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1)
233
+ canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8)
234
+ seg_color = np.array([idx+1], dtype=np.uint8) * seg_mask * canvas_mask
235
+ canvas_image = canvas_image + seg_color
236
+ progress_bar.update(1)
237
+ seg_colormap = np.insert(seg_colormap, 0, [0, 0, 0], axis=0)
238
+ temp_canvas_image = np.apply_along_axis(lambda x: seg_colormap[x[0]], axis=-1, arr=canvas_image)
239
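        # A uint8 canvas can index at most 255 segments, so segments beyond the
        # first 255 are painted onto a second canvas, colored with the remainder
        # of the colormap, and merged with the first pass.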
+ if len(sam_masks) > 255:
240
+ canvas_image = canvas_image.astype(bool).astype(np.uint8)
241
+ for idx, seg_dict in enumerate(sam_masks[255:min(509, len(sam_masks))]):
242
+ seg_mask = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1)
243
+ canvas_mask = np.logical_not(canvas_image.astype(bool)).astype(np.uint8)
244
+ seg_color = np.array([idx+2], dtype=np.uint8) * seg_mask * canvas_mask
245
+ canvas_image = canvas_image + seg_color
246
+ progress_bar.update(1)
247
+ seg_colormap = seg_colormap[256:]
248
+ seg_colormap = np.insert(seg_colormap, 0, [0, 0, 0], axis=0)
249
+ seg_colormap = np.insert(seg_colormap, 0, [0, 0, 0], axis=0)
250
+ canvas_image = np.apply_along_axis(lambda x: seg_colormap[x[0]], axis=-1, arr=canvas_image)
251
+ canvas_image = temp_canvas_image + canvas_image
252
+ else:
253
+ canvas_image = temp_canvas_image
254
+ ret_seg_image = canvas_image.astype(np.uint8)
255
+
256
+ return ret_seg_image
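
For orientation, a minimal sketch of how these helpers chain together. `generate_sam_masks` stands in for the mask-generator function this hunk belongs to (its actual name and signature are not shown here), and the sample image path comes from the `images/` folder in this commit:

    # Hypothetical usage; everything not defined in this file is an assumption.
    import numpy as np
    from PIL import Image

    image = np.array(Image.open("images/sample_input_image.png").convert("RGB"))

    sam_masks = generate_sam_masks(image)                  # assumed entry point in this module
    sam_masks = sort_masks_by_area(sam_masks)              # smallest segments first
    sam_masks = insert_mask_to_sam_masks(sam_masks, None)  # no-op when insert_mask is None
    seg_image = create_seg_color_image(image, sam_masks)   # (H, W, 3) uint8 color map
    Image.fromarray(seg_image).save("seg_color.png")
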
javascript/inpaint-anything.js ADDED
@@ -0,0 +1,458 @@
const inpaintAnything_waitForElement = async (parent, selector, exist) => {
  return new Promise((resolve) => {
    const observer = new MutationObserver(() => {
      if (!!parent.querySelector(selector) != exist) {
        return;
      }
      observer.disconnect();
      resolve(undefined);
    });

    observer.observe(parent, {
      childList: true,
      subtree: true,
    });

    if (!!parent.querySelector(selector) == exist) {
      resolve(undefined);
    }
  });
};

const inpaintAnything_waitForStyle = async (parent, selector, style) => {
  return new Promise((resolve) => {
    const observer = new MutationObserver(() => {
      if (!parent.querySelector(selector) || !parent.querySelector(selector).style[style]) {
        return;
      }
      observer.disconnect();
      resolve(undefined);
    });

    observer.observe(parent, {
      childList: true,
      subtree: true,
      attributes: true,
      attributeFilter: ["style"],
    });

    if (!!parent.querySelector(selector) && !!parent.querySelector(selector).style[style]) {
      resolve(undefined);
    }
  });
};

const inpaintAnything_timeout = (ms) => {
  return new Promise(function (resolve, reject) {
    setTimeout(() => reject("Timeout"), ms);
  });
};

async function inpaintAnything_clearSamMask() {
  const waitForElementToBeInDocument = (parent, selector) =>
    Promise.race([inpaintAnything_waitForElement(parent, selector, true), inpaintAnything_timeout(1000)]);

  const elemId = "#ia_sam_image";

  const targetElement = document.querySelector(elemId);
  if (!targetElement) {
    return;
  }
  await waitForElementToBeInDocument(targetElement, "button[aria-label='Clear']");

  targetElement.style.transform = null;
  targetElement.style.zIndex = null;
  targetElement.style.overflow = "auto";

  const samMaskClear = targetElement.querySelector("button[aria-label='Clear']");
  if (!samMaskClear) {
    return;
  }
  const removeImageButton = targetElement.querySelector("button[aria-label='Remove Image']");
  if (!removeImageButton) {
    return;
  }
  samMaskClear?.click();

  if (typeof inpaintAnything_clearSamMask.clickRemoveImage === "undefined") {
    inpaintAnything_clearSamMask.clickRemoveImage = () => {
      targetElement.style.transform = null;
      targetElement.style.zIndex = null;
    };
  } else {
    removeImageButton.removeEventListener("click", inpaintAnything_clearSamMask.clickRemoveImage);
  }
  removeImageButton.addEventListener("click", inpaintAnything_clearSamMask.clickRemoveImage);
}

async function inpaintAnything_clearSelMask() {
  const waitForElementToBeInDocument = (parent, selector) =>
    Promise.race([inpaintAnything_waitForElement(parent, selector, true), inpaintAnything_timeout(1000)]);

  const elemId = "#ia_sel_mask";

  const targetElement = document.querySelector(elemId);
  if (!targetElement) {
    return;
  }
  await waitForElementToBeInDocument(targetElement, "button[aria-label='Clear']");

  targetElement.style.transform = null;
  targetElement.style.zIndex = null;
  targetElement.style.overflow = "auto";

  const selMaskClear = targetElement.querySelector("button[aria-label='Clear']");
  if (!selMaskClear) {
    return;
  }
  const removeImageButton = targetElement.querySelector("button[aria-label='Remove Image']");
  if (!removeImageButton) {
    return;
  }
  selMaskClear?.click();

  if (typeof inpaintAnything_clearSelMask.clickRemoveImage === "undefined") {
    inpaintAnything_clearSelMask.clickRemoveImage = () => {
      targetElement.style.transform = null;
      targetElement.style.zIndex = null;
    };
  } else {
    removeImageButton.removeEventListener("click", inpaintAnything_clearSelMask.clickRemoveImage);
  }
  removeImageButton.addEventListener("click", inpaintAnything_clearSelMask.clickRemoveImage);
}

async function inpaintAnything_initSamSelMask() {
  inpaintAnything_clearSamMask();
  inpaintAnything_clearSelMask();
}

var uiLoadedCallbacks = [];

function gradioApp() {
  const elems = document.getElementsByTagName("gradio-app");
  const elem = elems.length == 0 ? document : elems[0];

  if (elem !== document) {
    elem.getElementById = function (id) {
      return document.getElementById(id);
    };
  }
  return elem.shadowRoot ? elem.shadowRoot : elem;
}

function onUiLoaded(callback) {
  uiLoadedCallbacks.push(callback);
}

function executeCallbacks(queue) {
  for (const callback of queue) {
    try {
      callback();
    } catch (e) {
      console.error("error running callback", callback, ":", e);
    }
  }
}

onUiLoaded(async () => {
  const elementIDs = {
    ia_sam_image: "#ia_sam_image",
    ia_sel_mask: "#ia_sel_mask",
    ia_out_image: "#ia_out_image",
    ia_cleaner_out_image: "#ia_cleaner_out_image",
  };

  function setStyleHeight(elemId, height) {
    const elem = gradioApp().querySelector(elemId);
    if (elem) {
      if (!elem.style.height) {
        elem.style.height = height;
        const observer = new MutationObserver(() => {
          const divPreview = elem.querySelector(".preview");
          if (divPreview) {
            divPreview.classList.remove("fixed-height");
          }
        });
        observer.observe(elem, {
          childList: true,
          attributes: true,
          attributeFilter: ["class"],
        });
      }
    }
  }

  setStyleHeight(elementIDs.ia_out_image, "520px");
  setStyleHeight(elementIDs.ia_cleaner_out_image, "520px");

  // Default config
  const defaultHotkeysConfig = {
    canvas_hotkey_reset: "KeyR",
    canvas_hotkey_fullscreen: "KeyS",
  };

  const elemData = {};
  let activeElement;

  function applyZoomAndPan(elemId) {
    const targetElement = gradioApp().querySelector(elemId);

    if (!targetElement) {
      console.log("Element not found");
      return;
    }

    targetElement.style.transformOrigin = "0 0";

    elemData[elemId] = {
      zoomLevel: 1,
      panX: 0,
      panY: 0,
    };
    let fullScreenMode = false;

    // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
    function toggleOverlap(forced = "") {
      // const zIndex1 = "0";
      const zIndex1 = null;
      const zIndex2 = "998";

      targetElement.style.zIndex = targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1;

      if (forced === "off") {
        targetElement.style.zIndex = zIndex1;
      } else if (forced === "on") {
        targetElement.style.zIndex = zIndex2;
      }
    }

    /**
     * This function fits the target element to its parent element by calculating
     * the required scale and offsets. It also updates the global variables
     * zoomLevel, panX, and panY to reflect the new state.
     */
    function fitToElement() {
      // Reset Zoom
      targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;

      // Get element and screen dimensions
      const elementWidth = targetElement.offsetWidth;
      const elementHeight = targetElement.offsetHeight;
      const parentElement = targetElement.parentElement;
      const screenWidth = parentElement.clientWidth;
      const screenHeight = parentElement.clientHeight;

      // Get element's coordinates relative to the parent element
      const elementRect = targetElement.getBoundingClientRect();
      const parentRect = parentElement.getBoundingClientRect();
      const elementX = elementRect.x - parentRect.x;

      // Calculate scale and offsets
      const scaleX = screenWidth / elementWidth;
      const scaleY = screenHeight / elementHeight;
      const scale = Math.min(scaleX, scaleY);

      const transformOrigin = window.getComputedStyle(targetElement).transformOrigin;
      const [originX, originY] = transformOrigin.split(" ");
      const originXValue = parseFloat(originX);
      const originYValue = parseFloat(originY);

      const offsetX = (screenWidth - elementWidth * scale) / 2 - originXValue * (1 - scale);
      const offsetY = (screenHeight - elementHeight * scale) / 2.5 - originYValue * (1 - scale);

      // Apply scale and offsets to the element
      targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;

      // Update global variables
      elemData[elemId].zoomLevel = scale;
      elemData[elemId].panX = offsetX;
      elemData[elemId].panY = offsetY;

      fullScreenMode = false;
      toggleOverlap("off");
    }

    // Reset the zoom level and pan position of the target element to their initial values
    function resetZoom() {
      elemData[elemId] = {
        zoomLevel: 1,
        panX: 0,
        panY: 0,
      };

      // fixCanvas();
      targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;

      // const canvas = gradioApp().querySelector(`${elemId} canvas[key="interface"]`);

      toggleOverlap("off");
      fullScreenMode = false;

      // if (
      //   canvas &&
      //   parseFloat(canvas.style.width) > 865 &&
      //   parseFloat(targetElement.style.width) > 865
      // ) {
      //   fitToElement();
      //   return;
      // }

      // targetElement.style.width = "";
      // if (canvas) {
      //   targetElement.style.height = canvas.style.height;
      // }
      targetElement.style.width = null;
      targetElement.style.height = "480px"; // style.height needs px units; a bare number is ignored
    }

    /**
     * This function fits the target element to the screen by calculating
     * the required scale and offsets. It also updates the global variables
     * zoomLevel, panX, and panY to reflect the new state.
     */
    // Fullscreen mode
    function fitToScreen() {
      const canvas = gradioApp().querySelector(`${elemId} canvas[key="interface"]`);
      const img = gradioApp().querySelector(`${elemId} img`);

      if (!canvas && !img) return;

      // if (canvas.offsetWidth > 862) {
      //   targetElement.style.width = canvas.offsetWidth + "px";
      // }

      if (fullScreenMode) {
        resetZoom();
        fullScreenMode = false;
        return;
      }

      // Reset Zoom
      targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;

      // Get scrollbar width to right-align the image
      const scrollbarWidth = window.innerWidth - document.documentElement.clientWidth;

      // Get element and screen dimensions
      const elementWidth = targetElement.offsetWidth;
      const elementHeight = targetElement.offsetHeight;
      const screenWidth = window.innerWidth - scrollbarWidth;
      const screenHeight = window.innerHeight;

      // Get element's coordinates relative to the page
      const elementRect = targetElement.getBoundingClientRect();
      const elementY = elementRect.y;
      const elementX = elementRect.x;

      // Calculate scale and offsets
      const scaleX = screenWidth / elementWidth;
      const scaleY = screenHeight / elementHeight;
      const scale = Math.min(scaleX, scaleY);

      // Get the current transformOrigin
      const computedStyle = window.getComputedStyle(targetElement);
      const transformOrigin = computedStyle.transformOrigin;
      const [originX, originY] = transformOrigin.split(" ");
      const originXValue = parseFloat(originX);
      const originYValue = parseFloat(originY);

      // Calculate offsets with respect to the transformOrigin
      const offsetX = (screenWidth - elementWidth * scale) / 2 - elementX - originXValue * (1 - scale);
      const offsetY = (screenHeight - elementHeight * scale) / 2 - elementY - originYValue * (1 - scale);

      // Apply scale and offsets to the element
      targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;

      // Update global variables
      elemData[elemId].zoomLevel = scale;
      elemData[elemId].panX = offsetX;
      elemData[elemId].panY = offsetY;

      fullScreenMode = true;
      toggleOverlap("on");
    }

    // Reset zoom when uploading a new image
    const fileInput = gradioApp().querySelector(`${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`);
    if (fileInput) {
      fileInput.addEventListener("click", resetZoom);
    }

    // Handle keydown events
    function handleKeyDown(event) {
      // Disable key locks to make pasting from the buffer work correctly
      if (
        (event.ctrlKey && event.code === "KeyV") ||
        (event.ctrlKey && event.code === "KeyC") ||
        event.code === "F5"
      ) {
        return;
      }

      // Before activating a shortcut, ensure the user is not actively typing in an input field
      if (event.target.nodeName === "TEXTAREA" || event.target.nodeName === "INPUT") {
        return;
      }

      const hotkeyActions = {
        [defaultHotkeysConfig.canvas_hotkey_reset]: resetZoom,
        [defaultHotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen,
      };

      const action = hotkeyActions[event.code];
      if (action) {
        event.preventDefault();
        action(event);
      }
    }

    // Handle events only inside the targetElement
    let isKeyDownHandlerAttached = false;

    function handleMouseMove() {
      if (!isKeyDownHandlerAttached) {
        document.addEventListener("keydown", handleKeyDown);
        isKeyDownHandlerAttached = true;

        activeElement = elemId;
      }
    }

    function handleMouseLeave() {
      if (isKeyDownHandlerAttached) {
        document.removeEventListener("keydown", handleKeyDown);
        isKeyDownHandlerAttached = false;

        activeElement = null;
      }
    }

    // Add mouse event handlers
    targetElement.addEventListener("mousemove", handleMouseMove);
    targetElement.addEventListener("mouseleave", handleMouseLeave);
  }

  applyZoomAndPan(elementIDs.ia_sam_image);
  applyZoomAndPan(elementIDs.ia_sel_mask);
  // applyZoomAndPan(elementIDs.ia_out_image);
  // applyZoomAndPan(elementIDs.ia_cleaner_out_image);
});

var executedOnLoaded = false;

document.addEventListener("DOMContentLoaded", function () {
  var mutationObserver = new MutationObserver(function () {
    if (
      !executedOnLoaded &&
      gradioApp().querySelector("#ia_sam_image") &&
      gradioApp().querySelector("#ia_sel_mask")
    ) {
      executedOnLoaded = true;
      executeCallbacks(uiLoadedCallbacks);
    }
  });
  mutationObserver.observe(gradioApp(), { childList: true, subtree: true });
});
lama_cleaner/__init__.py ADDED
@@ -0,0 +1,19 @@
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import warnings  # noqa: E402

warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
warnings.filterwarnings("ignore", category=UserWarning, module="lama_cleaner")

from lama_cleaner.parse_args import parse_args  # noqa: E402


def entry_point():
    args = parse_args()
    # Import lazily so that os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
    # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
    from lama_cleaner.server import main

    main(args)
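
A sketch of driving this entry point programmatically; the flag names below are assumptions based on `lama_cleaner/const.py` rather than on `parse_args`, which is not shown in this commit:

    # Minimal sketch: arguments are read from sys.argv by parse_args(),
    # so populate it before calling entry_point(). Flag names are assumed.
    import sys

    from lama_cleaner import entry_point

    sys.argv = ["lama-cleaner", "--model", "lama", "--device", "cpu", "--port", "8080"]
    entry_point()  # parses args, then hands off to lama_cleaner.server.main(args)
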
lama_cleaner/benchmark.py ADDED
@@ -0,0 +1,109 @@
#!/usr/bin/env python3

import argparse
import os
import time

import numpy as np
import nvidia_smi
import psutil
import torch

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, SDSampler

try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass

NUM_THREADS = str(4)

os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]


def run_model(model, size):
    # RGB
    image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
    mask = np.random.randint(0, 255, size).astype(np.uint8)

    config = Config(
        ldm_steps=2,
        hd_strategy=HDStrategy.ORIGINAL,
        hd_strategy_crop_margin=128,
        hd_strategy_crop_trigger_size=128,
        hd_strategy_resize_limit=128,
        prompt="a fox is sitting on a bench",
        sd_steps=5,
        sd_sampler=SDSampler.ddim,
    )
    model(image, mask, config)


def benchmark(model, times: int, empty_cache: bool):
    sizes = [(512, 512)]

    nvidia_smi.nvmlInit()
    device_id = 0
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)

    def format(metrics):
        return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"

    process = psutil.Process(os.getpid())
    # Report GPU-memory and RAM usage metrics for each size
    for size in sizes:
        torch.cuda.empty_cache()
        time_metrics = []
        cpu_metrics = []
        memory_metrics = []
        gpu_memory_metrics = []
        for _ in range(times):
            start = time.time()
            run_model(model, size)
            torch.cuda.synchronize()

            # cpu_metrics.append(process.cpu_percent())
            time_metrics.append((time.time() - start) * 1000)
            memory_metrics.append(process.memory_info().rss / 1024 / 1024)
            gpu_memory_metrics.append(nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024)

        print(f"size: {size}".center(80, "-"))
        # print(f"cpu: {format(cpu_metrics)}")
        print(f"latency: {format(time_metrics)}ms")
        print(f"memory: {format(memory_metrics)} MB")
        print(f"gpu memory: {format(gpu_memory_metrics)} MB")

    nvidia_smi.nvmlShutdown()


def get_args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--name")
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--times", default=10, type=int)
    parser.add_argument("--empty-cache", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args_parser()
    device = torch.device(args.device)
    model = ModelManager(
        name=args.name,
        device=device,
        sd_run_local=True,
        disable_nsfw=True,
        sd_cpu_textencoder=True,
        hf_access_token="123",
    )
    benchmark(model, args.times, args.empty_cache)
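
Given the parser above, a typical run is:

    python lama_cleaner/benchmark.py --name lama --device cuda --times 10

Note that `--empty-cache` is parsed and forwarded, but `benchmark()` does not currently act on its `empty_cache` argument; `torch.cuda.empty_cache()` is called once per size regardless.
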
lama_cleaner/const.py ADDED
@@ -0,0 +1,173 @@
import json
import os
from enum import Enum
from pydantic import BaseModel


MPS_SUPPORT_MODELS = [
    "instruct_pix2pix",
    "sd1.5",
    "anything4",
    "realisticVision1.4",
    "sd2",
    "paint_by_example",
    "controlnet",
]

DEFAULT_MODEL = "lama"
AVAILABLE_MODELS = [
    "lama",
    "ldm",
    "zits",
    "mat",
    "fcf",
    "sd1.5",
    "anything4",
    "realisticVision1.4",
    "cv2",
    "manga",
    "sd2",
    "paint_by_example",
    "instruct_pix2pix",
]
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]

AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
DEFAULT_DEVICE = "cuda"

NO_HALF_HELP = """
Use the full-precision model.
If your generated results are always black or green, use this argument. (sd/paint_by_example)
"""

CPU_OFFLOAD_HELP = """
Offload all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example)
"""

DISABLE_NSFW_HELP = """
Disable the NSFW checker. (sd/paint_by_example)
"""

SD_CPU_TEXTENCODER_HELP = """
Run the Stable Diffusion text encoder on CPU to save GPU memory.
"""

SD_CONTROLNET_HELP = """
Run the Stable Diffusion inpainting model with ControlNet. You can switch the control method in the webui.
"""
DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny"
SD_CONTROLNET_CHOICES = [
    "control_v11p_sd15_canny",
    "control_v11p_sd15_openpose",
    "control_v11p_sd15_inpaint",
    "control_v11f1p_sd15_depth",
]

SD_LOCAL_MODEL_HELP = """
Load a Stable Diffusion 1.5 model (ckpt/safetensors) from a local path.
"""

LOCAL_FILES_ONLY_HELP = """
Use local files only, without connecting to the Hugging Face server. (sd/paint_by_example)
"""

ENABLE_XFORMERS_HELP = """
Enable xFormers optimizations. Requires the xformers package to be installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example)
"""

DEFAULT_MODEL_DIR = os.getenv(
    "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
)
MODEL_DIR_HELP = """
Model download directory (set via the XDG_CACHE_HOME environment variable); by default models are downloaded to ~/.cache
"""

OUTPUT_DIR_HELP = """
Result images will be saved to the output directory automatically, without confirmation.
"""

INPUT_HELP = """
If the input is an image, it will be loaded by default.
If the input is a directory, you can browse and select an image in the file manager.
"""

GUI_HELP = """
Launch Lama Cleaner as a desktop app
"""

NO_GUI_AUTO_CLOSE_HELP = """
Prevent the backend from closing automatically after the GUI window is closed.
"""

QUALITY_HELP = """
Quality of image encoding, 0-100. Default is 95; higher quality generates a larger file size.
"""


class RealESRGANModelName(str, Enum):
    realesr_general_x4v3 = "realesr-general-x4v3"
    RealESRGAN_x4plus = "RealESRGAN_x4plus"
    RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"


RealESRGANModelNameList = [e.value for e in RealESRGANModelName]

INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
INTERACTIVE_SEG_MODEL_HELP = "Model size: vit_b < vit_l < vit_h. A bigger model gives better segmentation but is slower."
AVAILABLE_INTERACTIVE_SEG_MODELS = ["vit_b", "vit_l", "vit_h"]
AVAILABLE_INTERACTIVE_SEG_DEVICES = ["cuda", "cpu", "mps"]
REMOVE_BG_HELP = "Enable background removal. Always runs on CPU"
ANIMESEG_HELP = "Enable anime segmentation. Always runs on CPU"
REALESRGAN_HELP = "Enable RealESRGAN super resolution"
REALESRGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
GFPGAN_HELP = (
    "Enable GFPGAN face restoration. To enhance the background, use with --enable-realesrgan"
)
GFPGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
RESTOREFORMER_HELP = "Enable RestoreFormer face restoration. To enhance the background, use with --enable-realesrgan"
RESTOREFORMER_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
GIF_HELP = "Enable the GIF plugin. Makes a GIF to compare the original and cleaned image"


class Config(BaseModel):
    host: str = "127.0.0.1"
    port: int = 8080
    model: str = DEFAULT_MODEL
    sd_local_model_path: str = None
    sd_controlnet: bool = False
    sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD
    device: str = DEFAULT_DEVICE
    gui: bool = False
    no_gui_auto_close: bool = False
    no_half: bool = False
    cpu_offload: bool = False
    disable_nsfw: bool = False
    sd_cpu_textencoder: bool = False
    enable_xformers: bool = False
    local_files_only: bool = False
    model_dir: str = DEFAULT_MODEL_DIR
    input: str = None
    output_dir: str = None
    # plugins
    enable_interactive_seg: bool = False
    interactive_seg_model: str = "vit_l"
    interactive_seg_device: str = "cpu"
    enable_remove_bg: bool = False
    enable_anime_seg: bool = False
    enable_realesrgan: bool = False
    realesrgan_device: str = "cpu"
    realesrgan_model: str = RealESRGANModelName.realesr_general_x4v3.value
    realesrgan_no_half: bool = False
    enable_gfpgan: bool = False
    gfpgan_device: str = "cpu"
    enable_restoreformer: bool = False
    restoreformer_device: str = "cpu"
    enable_gif: bool = False


def load_config(installer_config: str):
    if os.path.exists(installer_config):
        with open(installer_config, "r", encoding="utf-8") as f:
            return Config(**json.load(f))
    else:
        return Config()
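
A short sketch of `load_config` behavior; the JSON path is hypothetical, and extra keys in the file are ignored under pydantic's default model config:

    import json

    with open("installer_config.json", "w", encoding="utf-8") as f:  # hypothetical path
        json.dump({"model": "ldm", "port": 8081, "device": "cpu"}, f)

    cfg = load_config("installer_config.json")
    assert cfg.model == "ldm" and cfg.port == 8081
    assert load_config("missing.json").model == DEFAULT_MODEL  # absent file -> defaults
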
lama_cleaner/file_manager/__init__.py ADDED
@@ -0,0 +1 @@
from .file_manager import FileManager
lama_cleaner/file_manager/file_manager.py ADDED
@@ -0,0 +1,265 @@
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/thumbnail.py
import os
from datetime import datetime

import cv2
import time
from io import BytesIO
from pathlib import Path
import numpy as np
# from watchdog.events import FileSystemEventHandler
# from watchdog.observers import Observer

from PIL import Image, ImageOps, PngImagePlugin
from loguru import logger

LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
from .storage_backends import FilesystemStorageBackend
from .utils import aspect_to_string, generate_filename, glob_img


class FileManager:
    def __init__(self, app=None):
        self.app = app
        self._default_root_directory = "media"
        self._default_thumbnail_directory = "media"
        self._default_root_url = "/"
        self._default_thumbnail_root_url = "/"
        self._default_format = "JPEG"
        self.output_dir: Path = None

        if app is not None:
            self.init_app(app)

        self.image_dir_filenames = []
        self.output_dir_filenames = []

        self.image_dir_observer = None
        self.output_dir_observer = None

        self.modified_time = {
            "image": datetime.utcnow(),
            "output": datetime.utcnow(),
        }

    # def start(self):
    #     self.image_dir_filenames = self._media_names(self.root_directory)
    #     self.output_dir_filenames = self._media_names(self.output_dir)
    #
    #     logger.info(f"Start watching image directory: {self.root_directory}")
    #     self.image_dir_observer = Observer()
    #     self.image_dir_observer.schedule(self, self.root_directory, recursive=False)
    #     self.image_dir_observer.start()
    #
    #     logger.info(f"Start watching output directory: {self.output_dir}")
    #     self.output_dir_observer = Observer()
    #     self.output_dir_observer.schedule(self, self.output_dir, recursive=False)
    #     self.output_dir_observer.start()

    def on_modified(self, event):
        if not os.path.isdir(event.src_path):
            return
        if event.src_path == str(self.root_directory):
            logger.info(f"Image directory {event.src_path} modified")
            self.image_dir_filenames = self._media_names(self.root_directory)
            self.modified_time["image"] = datetime.utcnow()
        elif event.src_path == str(self.output_dir):
            logger.info(f"Output directory {event.src_path} modified")
            self.output_dir_filenames = self._media_names(self.output_dir)
            self.modified_time["output"] = datetime.utcnow()

    def init_app(self, app):
        if self.app is None:
            self.app = app
        app.thumbnail_instance = self

        if not hasattr(app, "extensions"):
            app.extensions = {}

        if "thumbnail" in app.extensions:
            raise RuntimeError("Flask-thumbnail extension already initialized")

        app.extensions["thumbnail"] = self

        app.config.setdefault("THUMBNAIL_MEDIA_ROOT", self._default_root_directory)
        app.config.setdefault(
            "THUMBNAIL_MEDIA_THUMBNAIL_ROOT", self._default_thumbnail_directory
        )
        app.config.setdefault("THUMBNAIL_MEDIA_URL", self._default_root_url)
        app.config.setdefault(
            "THUMBNAIL_MEDIA_THUMBNAIL_URL", self._default_thumbnail_root_url
        )
        app.config.setdefault("THUMBNAIL_DEFAULT_FORMAT", self._default_format)

    @property
    def root_directory(self):
        path = self.app.config["THUMBNAIL_MEDIA_ROOT"]

        if os.path.isabs(path):
            return path
        else:
            return os.path.join(self.app.root_path, path)

    @property
    def thumbnail_directory(self):
        path = self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"]

        if os.path.isabs(path):
            return path
        else:
            return os.path.join(self.app.root_path, path)

    @property
    def root_url(self):
        return self.app.config["THUMBNAIL_MEDIA_URL"]

    @property
    def media_names(self):
        # return self.image_dir_filenames
        return self._media_names(self.root_directory)

    @property
    def output_media_names(self):
        return self._media_names(self.output_dir)
        # return self.output_dir_filenames

    @staticmethod
    def _media_names(directory: Path):
        names = sorted([it.name for it in glob_img(directory)])
        res = []
        for name in names:
            path = os.path.join(directory, name)
            img = Image.open(path)
            res.append(
                {
                    "name": name,
                    "height": img.height,
                    "width": img.width,
                    "ctime": os.path.getctime(path),
                    "mtime": os.path.getmtime(path),
                }
            )
        return res

    @property
    def thumbnail_url(self):
        return self.app.config["THUMBNAIL_MEDIA_THUMBNAIL_URL"]

    def get_thumbnail(
        self, directory: Path, original_filename: str, width, height, **options
    ):
        storage = FilesystemStorageBackend(self.app)
        crop = options.get("crop", "fit")
        background = options.get("background")
        quality = options.get("quality", 90)

        original_path, original_filename = os.path.split(original_filename)
        original_filepath = os.path.join(directory, original_path, original_filename)
        image = Image.open(BytesIO(storage.read(original_filepath)))

        # keep-ratio resize: derive the missing dimension from the given one
        if width is not None:
            height = int(image.height * width / image.width)
        else:
            width = int(image.width * height / image.height)

        thumbnail_size = (width, height)

        thumbnail_filename = generate_filename(
            original_filename,
            aspect_to_string(thumbnail_size),
            crop,
            background,
            quality,
        )

        thumbnail_filepath = os.path.join(
            self.thumbnail_directory, original_path, thumbnail_filename
        )
        thumbnail_url = os.path.join(
            self.thumbnail_url, original_path, thumbnail_filename
        )

        if storage.exists(thumbnail_filepath):
            return thumbnail_url, (width, height)

        try:
            image.load()
        except (IOError, OSError):
            self.app.logger.warning("Thumbnail could not load image: %s", original_filepath)
            return thumbnail_url, (width, height)

        # keep the original image format
        options["format"] = options.get("format", image.format)

        image = self._create_thumbnail(
            image, thumbnail_size, crop, background=background
        )

        raw_data = self.get_raw_data(image, **options)
        storage.save(thumbnail_filepath, raw_data)

        return thumbnail_url, (width, height)

    def get_raw_data(self, image, **options):
        data = {
            "format": self._get_format(image, **options),
            "quality": options.get("quality", 90),
        }

        _file = BytesIO()
        image.save(_file, **data)
        return _file.getvalue()

    @staticmethod
    def colormode(image, colormode="RGB"):
        if colormode == "RGB" or colormode == "RGBA":
            if image.mode == "RGBA":
                return image
            if image.mode == "LA":
                return image.convert("RGBA")
            return image.convert(colormode)

        if colormode == "GRAY":
            return image.convert("L")

        return image.convert(colormode)

    @staticmethod
    def background(original_image, color=0xFF):
        size = (max(original_image.size),) * 2
        image = Image.new("L", size, color)
        image.paste(
            original_image,
            # integer offsets: PIL's paste rejects float coordinates
            tuple(map(lambda x: (x[0] - x[1]) // 2, zip(size, original_image.size))),
        )

        return image

    def _get_format(self, image, **options):
        if options.get("format"):
            return options.get("format")
        if image.format:
            return image.format

        return self.app.config["THUMBNAIL_DEFAULT_FORMAT"]

    def _create_thumbnail(self, image, size, crop="fit", background=None):
        try:
            resample = Image.Resampling.LANCZOS
        except AttributeError:  # pylint: disable=raise-missing-from
            resample = Image.ANTIALIAS

        if crop == "fit":
            image = ImageOps.fit(image, size, resample)
        else:
            image = image.copy()
            image.thumbnail(size, resample=resample)

        if background is not None:
            image = self.background(image)

        image = self.colormode(image)

        return image
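
One detail worth calling out: `get_thumbnail` derives whichever of width/height is missing so the aspect ratio is preserved. A numeric sketch of that step (the source dimensions here are hypothetical):

    image_width, image_height = 1600, 900
    width, height = 400, None

    if width is not None:
        height = int(image_height * width / image_width)
    else:
        width = int(image_width * height / image_height)

    assert (width, height) == (400, 225)
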
lama_cleaner/file_manager/storage_backends.py ADDED
@@ -0,0 +1,46 @@
# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
import errno
import os
from abc import ABC, abstractmethod


class BaseStorageBackend(ABC):
    def __init__(self, app=None):
        self.app = app

    @abstractmethod
    def read(self, filepath, mode="rb", **kwargs):
        raise NotImplementedError

    @abstractmethod
    def exists(self, filepath):
        raise NotImplementedError

    @abstractmethod
    def save(self, filepath, data):
        raise NotImplementedError


class FilesystemStorageBackend(BaseStorageBackend):
    def read(self, filepath, mode="rb", **kwargs):
        with open(filepath, mode) as f:  # pylint: disable=unspecified-encoding
            return f.read()

    def exists(self, filepath):
        return os.path.exists(filepath)

    def save(self, filepath, data):
        directory = os.path.dirname(filepath)

        if not os.path.exists(directory):
            try:
                os.makedirs(directory)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        if not os.path.isdir(directory):
            raise IOError("{} is not a directory".format(directory))

        with open(filepath, "wb") as f:
            f.write(data)
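
A short round trip through the filesystem backend; `save` creates missing parent directories before writing (the paths below are hypothetical):

    backend = FilesystemStorageBackend()

    backend.save("media/thumbnails/demo.bin", b"\x89PNG")  # makes media/thumbnails/ if absent
    assert backend.exists("media/thumbnails/demo.bin")
    assert backend.read("media/thumbnails/demo.bin") == b"\x89PNG"
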
lama_cleaner/file_manager/utils.py ADDED
@@ -0,0 +1,67 @@
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
import importlib
import os
from pathlib import Path

from typing import Union


def generate_filename(original_filename, *options):
    name, ext = os.path.splitext(original_filename)
    for v in options:
        if v:
            name += "_%s" % v
    name += ext

    return name


def parse_size(size):
    if isinstance(size, int):
        # If the size parameter is a single number, assume square aspect.
        return [size, size]

    if isinstance(size, (tuple, list)):
        if len(size) == 1:
            # If a single-value tuple/list is provided, expand it to two elements
            return size + type(size)(size)
        return size

    try:
        thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
    except ValueError:
        raise ValueError(  # pylint: disable=raise-missing-from
            "Bad thumbnail size format. Valid format is INTxINT."
        )

    if len(thumbnail_size) == 1:
        # If the size parameter only contains a single integer, assume square aspect.
        thumbnail_size.append(thumbnail_size[0])

    return thumbnail_size


def aspect_to_string(size):
    if isinstance(size, str):
        return size

    return "x".join(map(str, size))


IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'}


def glob_img(p: Union[Path, str], recursive: bool = False):
    p = Path(p)
    if p.is_file() and p.suffix in IMG_SUFFIX:
        yield p
    else:
        if recursive:
            files = Path(p).glob("**/*.*")
        else:
            files = Path(p).glob("*.*")

        for it in files:
            if it.suffix not in IMG_SUFFIX:
                continue
            yield it
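
A few concrete cases for the helpers above; `parse_size` accepts an int, a 1- or 2-element sequence, or an `INTxINT` string:

    assert parse_size(256) == [256, 256]             # int -> square
    assert parse_size((300,)) == (300, 300)          # 1-tuple expanded to two elements
    assert parse_size("640x480") == [640, 480]       # "INTxINT" string
    assert aspect_to_string((640, 480)) == "640x480"
    assert generate_filename("photo.jpg", "100x100", "fit", None, 90) == "photo_100x100_fit_90.jpg"
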
lama_cleaner/helper.py ADDED
@@ -0,0 +1,292 @@
import io
import os
import sys
from typing import List, Optional

from urllib.parse import urlparse
import cv2
from PIL import Image, ImageOps, PngImagePlugin
import numpy as np
import torch
from lama_cleaner.const import MPS_SUPPORT_MODELS
from loguru import logger
from torch.hub import download_url_to_file, get_dir
import hashlib


def md5sum(filename):
    md5 = hashlib.md5()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def switch_mps_device(model_name, device):
    if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps":
        logger.info(f"{model_name} does not support mps, switching to cpu")
        return torch.device("cpu")
    return device


def get_cache_path_by_url(url):
    parts = urlparse(url)
    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, "checkpoints")
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    return cached_file


def download_model(url, model_md5: str = None):
    cached_file = get_cache_path_by_url(url)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
        if model_md5:
            _md5 = md5sum(cached_file)
            if model_md5 == _md5:
                logger.info(f"Download model success, md5: {_md5}")
            else:
                try:
                    os.remove(cached_file)
                    logger.error(
                        f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart lama-cleaner. "
                        f"If you still have errors, please try downloading the model manually first: https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
                    )
                except Exception:
                    logger.error(
                        f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart lama-cleaner."
                    )
                exit(-1)

    return cached_file


def ceil_modulo(x, mod):
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def handle_error(model_path, model_md5, e):
    _md5 = md5sum(model_path)
    if _md5 != model_md5:
        try:
            os.remove(model_path)
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart lama-cleaner. "
                f"If you still have errors, please try downloading the model manually first: https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
            )
        except Exception:
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart lama-cleaner."
            )
    else:
        logger.error(
            f"Failed to load model {model_path}, "
            f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
        )
    exit(-1)


def load_jit_model(url_or_path, device, model_md5: str):
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path, model_md5)

    logger.info(f"Loading model from: {model_path}")
    try:
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        handle_error(model_path, model_md5, e)
    model.eval()
    return model


def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path, model_md5)

    try:
        logger.info(f"Loading model from: {model_path}")
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
    except Exception as e:
        handle_error(model_path, model_md5, e)
    model.eval()
    return model


def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    data = cv2.imencode(
        f".{ext}",
        image_numpy,
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )[1]
    image_bytes = data.tobytes()
    return image_bytes


def pil_to_bytes(pil_img, ext: str, quality: int = 95, exif_infos={}) -> bytes:
    with io.BytesIO() as output:
        kwargs = {k: v for k, v in exif_infos.items() if v is not None}
        if ext == "png" and "parameters" in kwargs:
            pnginfo_data = PngImagePlugin.PngInfo()
            pnginfo_data.add_text("parameters", kwargs["parameters"])
            kwargs["pnginfo"] = pnginfo_data

        pil_img.save(
            output,
            format=ext,
            quality=quality,
            **kwargs,
        )
        image_bytes = output.getvalue()
    return image_bytes


def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
    alpha_channel = None
    image = Image.open(io.BytesIO(img_bytes))

    if return_exif:
        info = image.info or {}
        exif_infos = {"exif": image.getexif(), "parameters": info.get("parameters")}

    try:
        image = ImageOps.exif_transpose(image)
    except Exception:
        pass

    if gray:
        image = image.convert("L")
        np_img = np.array(image)
    else:
        if image.mode == "RGBA":
            np_img = np.array(image)
            alpha_channel = np_img[:, :, -1]
            np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
        else:
            image = image.convert("RGB")
            np_img = np.array(image)

    if return_exif:
        return np_img, alpha_channel, exif_infos
    return np_img, alpha_channel


def norm_img(np_img):
    if len(np_img.shape) == 2:
        np_img = np_img[:, :, np.newaxis]
    np_img = np.transpose(np_img, (2, 0, 1))
    np_img = np_img.astype("float32") / 255
    return np_img


def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    # Resize the image's longer side down to size_limit if it exceeds size_limit
    h, w = np_img.shape[:2]
    if max(h, w) > size_limit:
        ratio = size_limit / max(h, w)
        new_w = int(w * ratio + 0.5)
        new_h = int(h * ratio + 0.5)
        return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
    else:
        return np_img


def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """

    Args:
        img: [H, W, C]
        mod:
        square: whether to pad the result to a square
        min_size:

    Returns:

    """
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)

    if min_size is not None:
        assert min_size % mod == 0
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)

    if square:
        max_size = max(out_height, out_width)
        out_height = max_size
        out_width = max_size

    return np.pad(
        img,
        ((0, out_height - height), (0, out_width - width), (0, 0)),
        mode="symmetric",
    )


def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """
    Args:
        mask: (h, w, 1) 0~255

    Returns:

    """
    height, width = mask.shape[:2]
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        box = np.array([x, y, x + w, y + h]).astype(int)

        box[::2] = np.clip(box[::2], 0, width)
        box[1::2] = np.clip(box[1::2], 0, height)
        boxes.append(box)

    return boxes


def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
    """
    Args:
        mask: (h, w) 0~255

    Returns:

    """
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    max_area = 0
    max_index = -1
    for i, cnt in enumerate(contours):
        area = cv2.contourArea(cnt)
        if area > max_area:
            max_area = area
            max_index = i

    if max_index != -1:
        new_mask = np.zeros_like(mask)
        return cv2.drawContours(new_mask, contours, max_index, 255, -1)
    else:
        return mask
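
`ceil_modulo` and `pad_img_to_modulo` are the sizing core here; a quick numeric sketch:

    import numpy as np

    assert ceil_modulo(512, 8) == 512     # already a multiple, unchanged
    assert ceil_modulo(300, 8) == 304     # rounded up to the next multiple of 8

    img = np.zeros((300, 500, 3), dtype=np.uint8)
    padded = pad_img_to_modulo(img, mod=8)             # symmetric bottom/right padding
    assert padded.shape == (304, 504, 3)
    padded_sq = pad_img_to_modulo(img, mod=8, square=True)
    assert padded_sq.shape == (504, 504, 3)
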
lama_cleaner/installer.py ADDED
@@ -0,0 +1,12 @@
import subprocess
import sys


def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


def install_plugins_package():
    install("rembg")
    install("realesrgan")
    install("gfpgan")
lama_cleaner/model/__init__.py ADDED
File without changes
lama_cleaner/model/base.py ADDED
@@ -0,0 +1,298 @@
1
+ import abc
2
+ from typing import Optional
3
+
4
+ import cv2
5
+ import torch
6
+ import numpy as np
7
+ from loguru import logger
8
+
9
+ from lama_cleaner.helper import (
10
+ boxes_from_mask,
11
+ resize_max_size,
12
+ pad_img_to_modulo,
13
+ switch_mps_device,
14
+ )
15
+ from lama_cleaner.schema import Config, HDStrategy
16
+
17
+
18
+ class InpaintModel:
19
+ name = "base"
20
+ min_size: Optional[int] = None
21
+ pad_mod = 8
22
+ pad_to_square = False
23
+
24
+ def __init__(self, device, **kwargs):
25
+ """
26
+
27
+ Args:
28
+ device:
29
+ """
30
+ device = switch_mps_device(self.name, device)
31
+ self.device = device
32
+ self.init_model(device, **kwargs)
33
+
34
+ @abc.abstractmethod
35
+ def init_model(self, device, **kwargs):
36
+ ...
37
+
38
+ @staticmethod
39
+ @abc.abstractmethod
40
+ def is_downloaded() -> bool:
41
+ ...
42
+
43
+ @abc.abstractmethod
44
+ def forward(self, image, mask, config: Config):
45
+ """Input images and output images have same size
46
+ images: [H, W, C] RGB
47
+ masks: [H, W, 1] 255 为 masks 区域
48
+ return: BGR IMAGE
49
+ """
50
+ ...
51
+
52
+ def _pad_forward(self, image, mask, config: Config):
53
+ origin_height, origin_width = image.shape[:2]
54
+ pad_image = pad_img_to_modulo(
55
+ image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
56
+ )
57
+ pad_mask = pad_img_to_modulo(
58
+ mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
59
+ )
60
+
61
+ logger.info(f"final forward pad size: {pad_image.shape}")
62
+
63
+ result = self.forward(pad_image, pad_mask, config)
64
+ result = result[0:origin_height, 0:origin_width, :]
65
+
66
+ result, image, mask = self.forward_post_process(result, image, mask, config)
67
+
68
+ mask = mask[:, :, np.newaxis]
69
+ result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
70
+ return result
71
+
72
+ def forward_post_process(self, result, image, mask, config):
73
+ return result, image, mask
74
+
75
+ @torch.no_grad()
76
+ def __call__(self, image, mask, config: Config):
77
+ """
78
+ images: [H, W, C] RGB, not normalized
79
+ masks: [H, W]
80
+ return: BGR IMAGE
81
+ """
82
+ inpaint_result = None
83
+ logger.info(f"hd_strategy: {config.hd_strategy}")
84
+ if config.hd_strategy == HDStrategy.CROP:
85
+ if max(image.shape) > config.hd_strategy_crop_trigger_size:
86
+ logger.info(f"Run crop strategy")
87
+ boxes = boxes_from_mask(mask)
88
+ crop_result = []
89
+ for box in boxes:
90
+ crop_image, crop_box = self._run_box(image, mask, box, config)
91
+ crop_result.append((crop_image, crop_box))
92
+
93
+ inpaint_result = image[:, :, ::-1]
94
+ for crop_image, crop_box in crop_result:
95
+ x1, y1, x2, y2 = crop_box
96
+ inpaint_result[y1:y2, x1:x2, :] = crop_image
97
+
98
+ elif config.hd_strategy == HDStrategy.RESIZE:
99
+ if max(image.shape) > config.hd_strategy_resize_limit:
100
+ origin_size = image.shape[:2]
101
+ downsize_image = resize_max_size(
102
+ image, size_limit=config.hd_strategy_resize_limit
103
+ )
104
+ downsize_mask = resize_max_size(
105
+                 mask, size_limit=config.hd_strategy_resize_limit
+             )
+
+             logger.info(
+                 f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
+             )
+             inpaint_result = self._pad_forward(
+                 downsize_image, downsize_mask, config
+             )
+
+             # only paste masked area result
+             inpaint_result = cv2.resize(
+                 inpaint_result,
+                 (origin_size[1], origin_size[0]),
+                 interpolation=cv2.INTER_CUBIC,
+             )
+             original_pixel_indices = mask < 127
+             inpaint_result[original_pixel_indices] = image[:, :, ::-1][
+                 original_pixel_indices
+             ]
+
+         if inpaint_result is None:
+             inpaint_result = self._pad_forward(image, mask, config)
+
+         return inpaint_result
+
+     def _crop_box(self, image, mask, box, config: Config):
+         """
+
+         Args:
+             image: [H, W, C] RGB
+             mask: [H, W, 1]
+             box: [left,top,right,bottom]
+
+         Returns:
+             BGR IMAGE, (l, t, r, b)
+         """
+         box_h = box[3] - box[1]
+         box_w = box[2] - box[0]
+         cx = (box[0] + box[2]) // 2
+         cy = (box[1] + box[3]) // 2
+         img_h, img_w = image.shape[:2]
+
+         w = box_w + config.hd_strategy_crop_margin * 2
+         h = box_h + config.hd_strategy_crop_margin * 2
+
+         _l = cx - w // 2
+         _r = cx + w // 2
+         _t = cy - h // 2
+         _b = cy + h // 2
+
+         l = max(_l, 0)
+         r = min(_r, img_w)
+         t = max(_t, 0)
+         b = min(_b, img_h)
+
+         # try to get more context when crop around image edge
+         if _l < 0:
+             r += abs(_l)
+         if _r > img_w:
+             l -= _r - img_w
+         if _t < 0:
+             b += abs(_t)
+         if _b > img_h:
+             t -= _b - img_h
+
+         l = max(l, 0)
+         r = min(r, img_w)
+         t = max(t, 0)
+         b = min(b, img_h)
+
+         crop_img = image[t:b, l:r, :]
+         crop_mask = mask[t:b, l:r]
+
+         logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
+
+         return crop_img, crop_mask, [l, t, r, b]
+
+     def _calculate_cdf(self, histogram):
+         cdf = histogram.cumsum()
+         normalized_cdf = cdf / float(cdf.max())
+         return normalized_cdf
+
+     def _calculate_lookup(self, source_cdf, reference_cdf):
+         lookup_table = np.zeros(256)
+         lookup_val = 0
+         for source_index, source_val in enumerate(source_cdf):
+             for reference_index, reference_val in enumerate(reference_cdf):
+                 if reference_val >= source_val:
+                     lookup_val = reference_index
+                     break
+             lookup_table[source_index] = lookup_val
+         return lookup_table
+
+     def _match_histograms(self, source, reference, mask):
+         transformed_channels = []
+         for channel in range(source.shape[-1]):
+             source_channel = source[:, :, channel]
+             reference_channel = reference[:, :, channel]
+
+             # only calculate histograms for non-masked parts
+             source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
+             reference_histogram, _ = np.histogram(
+                 reference_channel[mask == 0], 256, [0, 256]
+             )
+
+             source_cdf = self._calculate_cdf(source_histogram)
+             reference_cdf = self._calculate_cdf(reference_histogram)
+
+             lookup = self._calculate_lookup(source_cdf, reference_cdf)
+
+             transformed_channels.append(cv2.LUT(source_channel, lookup))
+
+         result = cv2.merge(transformed_channels)
+         result = cv2.convertScaleAbs(result)
+
+         return result
+
+     def _apply_cropper(self, image, mask, config: Config):
+         img_h, img_w = image.shape[:2]
+         l, t, w, h = (
+             config.croper_x,
+             config.croper_y,
+             config.croper_width,
+             config.croper_height,
+         )
+         r = l + w
+         b = t + h
+
+         l = max(l, 0)
+         r = min(r, img_w)
+         t = max(t, 0)
+         b = min(b, img_h)
+
+         crop_img = image[t:b, l:r, :]
+         crop_mask = mask[t:b, l:r]
+         return crop_img, crop_mask, (l, t, r, b)
+
+     def _run_box(self, image, mask, box, config: Config):
+         """
+
+         Args:
+             image: [H, W, C] RGB
+             mask: [H, W, 1]
+             box: [left,top,right,bottom]
+
+         Returns:
+             BGR IMAGE
+         """
+         crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
+
+         return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
+
+
+ class DiffusionInpaintModel(InpaintModel):
+     @torch.no_grad()
+     def __call__(self, image, mask, config: Config):
+         """
+         images: [H, W, C] RGB, not normalized
+         masks: [H, W]
+         return: BGR IMAGE
+         """
+         # boxes = boxes_from_mask(mask)
+         if config.use_croper:
+             crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
+             crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
+             inpaint_result = image[:, :, ::-1]
+             inpaint_result[t:b, l:r, :] = crop_image
+         else:
+             inpaint_result = self._scaled_pad_forward(image, mask, config)
+
+         return inpaint_result
+
+     def _scaled_pad_forward(self, image, mask, config: Config):
+         longer_side_length = int(config.sd_scale * max(image.shape[:2]))
+         origin_size = image.shape[:2]
+         downsize_image = resize_max_size(image, size_limit=longer_side_length)
+         downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
+         if config.sd_scale != 1:
+             logger.info(
+                 f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
+             )
+         inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
+         # only paste masked area result
+         inpaint_result = cv2.resize(
+             inpaint_result,
+             (origin_size[1], origin_size[0]),
+             interpolation=cv2.INTER_CUBIC,
+         )
+         original_pixel_indices = mask < 127
+         inpaint_result[original_pixel_indices] = image[:, :, ::-1][
+             original_pixel_indices
+         ]
+         return inpaint_result
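Note: both the resize strategy in `InpaintModel.__call__` and `_scaled_pad_forward` above end with the same paste-back step: run the model at a reduced size, upscale the result, then restore every pixel outside the mask from the original image so only the masked region is actually replaced. A minimal standalone sketch of that step (the name `paste_back` is illustrative, not part of this codebase):

```python
import cv2
import numpy as np

def paste_back(image_rgb, mask, small_bgr):
    """image_rgb: [H, W, 3] uint8 RGB original; mask: [H, W] uint8, 255 = repaint;
    small_bgr: downscaled BGR inpainting result from the model."""
    h, w = image_rgb.shape[:2]
    # upscale the inpainted result back to the original resolution
    result = cv2.resize(small_bgr, (w, h), interpolation=cv2.INTER_CUBIC)
    keep = mask < 127  # pixels the model should not change
    # convert the original to BGR (matching the model output) and restore it
    result[keep] = image_rgb[:, :, ::-1][keep]
    return result
```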
lama_cleaner/model/controlnet.py ADDED
@@ -0,0 +1,289 @@
+ import gc
+
+ import PIL.Image
+ import cv2
+ import numpy as np
+ import torch
+ from diffusers import ControlNetModel
+ from loguru import logger
+
+ from lama_cleaner.model.base import DiffusionInpaintModel
+ from lama_cleaner.model.utils import torch_gc, get_scheduler
+ from lama_cleaner.schema import Config
+
+
+ class CPUTextEncoderWrapper:
+     def __init__(self, text_encoder, torch_dtype):
+         self.config = text_encoder.config
+         self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
+         self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
+         self.torch_dtype = torch_dtype
+         del text_encoder
+         torch_gc()
+
+     def __call__(self, x, **kwargs):
+         input_device = x.device
+         return [
+             self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
+             .to(input_device)
+             .to(self.torch_dtype)
+         ]
+
+     @property
+     def dtype(self):
+         return self.torch_dtype
+
+
+ NAMES_MAP = {
+     "sd1.5": "runwayml/stable-diffusion-inpainting",
+     "anything4": "Sanster/anything-4.0-inpainting",
+     "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
+ }
+
+ NATIVE_NAMES_MAP = {
+     "sd1.5": "runwayml/stable-diffusion-v1-5",
+     "anything4": "andite/anything-v4.0",
+     "realisticVision1.4": "SG161222/Realistic_Vision_V1.4",
+ }
+
+
+ def make_inpaint_condition(image, image_mask):
+     """
+     image: [H, W, C] RGB
+     mask: [H, W, 1] 255 means area to repaint
+     """
+     image = image.astype(np.float32) / 255.0
+     image[image_mask[:, :, -1] > 128] = -1.0  # set as masked pixel
+     image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+     image = torch.from_numpy(image)
+     return image
+
+
+ def load_from_local_model(
+     local_model_path, torch_dtype, controlnet, pipe_class, is_native_control_inpaint
+ ):
+     from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+         download_from_original_stable_diffusion_ckpt,
+     )
+
+     logger.info(f"Converting {local_model_path} to diffusers controlnet pipeline")
+
+     try:
+         pipe = download_from_original_stable_diffusion_ckpt(
+             local_model_path,
+             num_in_channels=4 if is_native_control_inpaint else 9,
+             from_safetensors=local_model_path.endswith("safetensors"),
+             device="cpu",
+             load_safety_checker=False,
+         )
+     except Exception as e:
+         err_msg = str(e)
+         logger.exception(e)
+         if is_native_control_inpaint and "[320, 9, 3, 3]" in err_msg:
+             logger.error(
+                 "control_v11p_sd15_inpaint method requires normal SD model, not inpainting SD model"
+             )
+         if not is_native_control_inpaint and "[320, 4, 3, 3]" in err_msg:
+             logger.error(
+                 f"{controlnet.config['_name_or_path']} method requires inpainting SD model, "
+                 f"you can convert any SD model to inpainting model in AUTO1111: \n"
+                 f"https://www.reddit.com/r/StableDiffusion/comments/zyi24j/how_to_turn_any_model_into_an_inpainting_model/"
+             )
+         exit(-1)
+
+     inpaint_pipe = pipe_class(
+         vae=pipe.vae,
+         text_encoder=pipe.text_encoder,
+         tokenizer=pipe.tokenizer,
+         unet=pipe.unet,
+         controlnet=controlnet,
+         scheduler=pipe.scheduler,
+         safety_checker=None,
+         feature_extractor=None,
+         requires_safety_checker=False,
+     )
+
+     del pipe
+     gc.collect()
+     return inpaint_pipe.to(torch_dtype=torch_dtype)
+
+
+ class ControlNet(DiffusionInpaintModel):
+     name = "controlnet"
+     pad_mod = 8
+     min_size = 512
+
+     def init_model(self, device: torch.device, **kwargs):
+         fp16 = not kwargs.get("no_half", False)
+
+         model_kwargs = {
+             "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
+         }
+         if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
+             logger.info("Disable Stable Diffusion Model NSFW checker")
+             model_kwargs.update(
+                 dict(
+                     safety_checker=None,
+                     feature_extractor=None,
+                     requires_safety_checker=False,
+                 )
+             )
+
+         use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
+         torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
+
+         sd_controlnet_method = kwargs["sd_controlnet_method"]
+         self.sd_controlnet_method = sd_controlnet_method
+
+         if sd_controlnet_method == "control_v11p_sd15_inpaint":
+             from diffusers import StableDiffusionControlNetPipeline as PipeClass
+
+             self.is_native_control_inpaint = True
+         else:
+             from .pipeline import StableDiffusionControlNetInpaintPipeline as PipeClass
+
+             self.is_native_control_inpaint = False
+
+         if self.is_native_control_inpaint:
+             model_id = NATIVE_NAMES_MAP[kwargs["name"]]
+         else:
+             model_id = NAMES_MAP[kwargs["name"]]
+
+         controlnet = ControlNetModel.from_pretrained(
+             f"lllyasviel/{sd_controlnet_method}", torch_dtype=torch_dtype
+         )
+         self.is_local_sd_model = False
+         if kwargs.get("sd_local_model_path", None):
+             self.is_local_sd_model = True
+             self.model = load_from_local_model(
+                 kwargs["sd_local_model_path"],
+                 torch_dtype=torch_dtype,
+                 controlnet=controlnet,
+                 pipe_class=PipeClass,
+                 is_native_control_inpaint=self.is_native_control_inpaint,
+             )
+         else:
+             self.model = PipeClass.from_pretrained(
+                 model_id,
+                 controlnet=controlnet,
+                 revision="fp16" if use_gpu and fp16 else "main",
+                 torch_dtype=torch_dtype,
+                 **model_kwargs,
+             )
+
+         # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
+         self.model.enable_attention_slicing()
+         # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
+         if kwargs.get("enable_xformers", False):
+             self.model.enable_xformers_memory_efficient_attention()
+
+         if kwargs.get("cpu_offload", False) and use_gpu:
+             logger.info("Enable sequential cpu offload")
+             self.model.enable_sequential_cpu_offload(gpu_id=0)
+         else:
+             self.model = self.model.to(device)
+             if kwargs["sd_cpu_textencoder"]:
+                 logger.info("Run Stable Diffusion TextEncoder on CPU")
+                 self.model.text_encoder = CPUTextEncoderWrapper(
+                     self.model.text_encoder, torch_dtype
+                 )
+
+         self.callback = kwargs.pop("callback", None)
+
+     def forward(self, image, mask, config: Config):
+         """Input image and output image have same size
+         image: [H, W, C] RGB
+         mask: [H, W, 1] 255 means area to repaint
+         return: BGR IMAGE
+         """
+         scheduler_config = self.model.scheduler.config
+         scheduler = get_scheduler(config.sd_sampler, scheduler_config)
+         self.model.scheduler = scheduler
+
+         if config.sd_mask_blur != 0:
+             k = 2 * config.sd_mask_blur + 1
+             mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
+
+         img_h, img_w = image.shape[:2]
+
+         if self.is_native_control_inpaint:
+             control_image = make_inpaint_condition(image, mask)
+             output = self.model(
+                 prompt=config.prompt,
+                 image=control_image,
+                 height=img_h,
+                 width=img_w,
+                 num_inference_steps=config.sd_steps,
+                 guidance_scale=config.sd_guidance_scale,
+                 controlnet_conditioning_scale=config.controlnet_conditioning_scale,
+                 negative_prompt=config.negative_prompt,
+                 generator=torch.manual_seed(config.sd_seed),
+                 output_type="np.array",
+                 callback=self.callback,
+             ).images[0]
+         else:
+             if "canny" in self.sd_controlnet_method:
+                 canny_image = cv2.Canny(image, 100, 200)
+                 canny_image = canny_image[:, :, None]
+                 canny_image = np.concatenate(
+                     [canny_image, canny_image, canny_image], axis=2
+                 )
+                 canny_image = PIL.Image.fromarray(canny_image)
+                 control_image = canny_image
+             elif "openpose" in self.sd_controlnet_method:
+                 from controlnet_aux import OpenposeDetector
+
+                 processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
+                 control_image = processor(image, hand_and_face=True)
+             elif "depth" in self.sd_controlnet_method:
+                 from transformers import pipeline
+
+                 depth_estimator = pipeline("depth-estimation")
+                 depth_image = depth_estimator(PIL.Image.fromarray(image))["depth"]
+                 depth_image = np.array(depth_image)
+                 depth_image = depth_image[:, :, None]
+                 depth_image = np.concatenate(
+                     [depth_image, depth_image, depth_image], axis=2
+                 )
+                 control_image = PIL.Image.fromarray(depth_image)
+             else:
+                 raise NotImplementedError(
+                     f"{self.sd_controlnet_method} not implemented"
+                 )
+
+             mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
+             image = PIL.Image.fromarray(image)
+
+             output = self.model(
+                 image=image,
+                 control_image=control_image,
+                 prompt=config.prompt,
+                 negative_prompt=config.negative_prompt,
+                 mask_image=mask_image,
+                 num_inference_steps=config.sd_steps,
+                 guidance_scale=config.sd_guidance_scale,
+                 output_type="np.array",
+                 callback=self.callback,
+                 height=img_h,
+                 width=img_w,
+                 generator=torch.manual_seed(config.sd_seed),
+                 controlnet_conditioning_scale=config.controlnet_conditioning_scale,
+             ).images[0]
+
+         output = (output * 255).round().astype("uint8")
+         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+         return output
+
+     def forward_post_process(self, result, image, mask, config):
+         if config.sd_match_histograms:
+             result = self._match_histograms(result, image[:, :, ::-1], mask)
+
+         if config.sd_mask_blur != 0:
+             k = 2 * config.sd_mask_blur + 1
+             mask = cv2.GaussianBlur(mask, (k, k), 0)
+         return result, image, mask
+
+     @staticmethod
+     def is_downloaded() -> bool:
+         # the model is downloaded when the app starts and can't be switched in the frontend settings
+         return True
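Note: for the native `control_v11p_sd15_inpaint` path, `make_inpaint_condition` encodes the mask directly into the control image by setting every repaint pixel to -1.0, so the ControlNet can distinguish masked from valid content. A small usage sketch (the array shapes and mask region are illustrative):

```python
import numpy as np

# Illustrative inputs; any [H, W, 3] RGB image and [H, W, 1] mask work.
image = np.zeros((512, 512, 3), dtype=np.uint8)
mask = np.zeros((512, 512, 1), dtype=np.uint8)
mask[100:200, 100:200] = 255  # region to repaint

control = make_inpaint_condition(image, mask)
# float tensor of shape [1, 3, 512, 512]; masked pixels are -1.0,
# all others are the original values scaled to [0, 1]
assert control.shape == (1, 3, 512, 512)
assert float(control[0, :, 150, 150].min()) == -1.0
```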
lama_cleaner/model/ddim_sampler.py ADDED
@@ -0,0 +1,193 @@
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+
+ from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
+
+ from loguru import logger
+
+
+ class DDIMSampler(object):
+     def __init__(self, model, schedule="linear"):
+         super().__init__()
+         self.model = model
+         self.ddpm_num_timesteps = model.num_timesteps
+         self.schedule = schedule
+
+     def register_buffer(self, name, attr):
+         setattr(self, name, attr)
+
+     def make_schedule(
+         self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+     ):
+         self.ddim_timesteps = make_ddim_timesteps(
+             ddim_discr_method=ddim_discretize,
+             num_ddim_timesteps=ddim_num_steps,
+             # array([1])
+             num_ddpm_timesteps=self.ddpm_num_timesteps,
+             verbose=verbose,
+         )
+         alphas_cumprod = self.model.alphas_cumprod  # torch.Size([1000])
+         assert (
+             alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+         ), "alphas have to be defined for each timestep"
+         to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+         self.register_buffer("betas", to_torch(self.model.betas))
+         self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+         self.register_buffer(
+             "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+         )
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.register_buffer(
+             "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+         )
+         self.register_buffer(
+             "sqrt_one_minus_alphas_cumprod",
+             to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+         )
+         self.register_buffer(
+             "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+         )
+         self.register_buffer(
+             "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+         )
+         self.register_buffer(
+             "sqrt_recipm1_alphas_cumprod",
+             to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+         )
+
+         # ddim sampling parameters
+         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+             alphacums=alphas_cumprod.cpu(),
+             ddim_timesteps=self.ddim_timesteps,
+             eta=ddim_eta,
+             verbose=verbose,
+         )
+         self.register_buffer("ddim_sigmas", ddim_sigmas)
+         self.register_buffer("ddim_alphas", ddim_alphas)
+         self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+         self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+             (1 - self.alphas_cumprod_prev)
+             / (1 - self.alphas_cumprod)
+             * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+         )
+         self.register_buffer(
+             "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+         )
+
+     @torch.no_grad()
+     def sample(self, steps, conditioning, batch_size, shape):
+         self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
+         # sampling
+         C, H, W = shape
+         size = (batch_size, C, H, W)
+
+         # samples: 1,3,128,128
+         return self.ddim_sampling(
+             conditioning,
+             size,
+             quantize_denoised=False,
+             ddim_use_original_steps=False,
+             noise_dropout=0,
+             temperature=1.0,
+         )
+
+     @torch.no_grad()
+     def ddim_sampling(
+         self,
+         cond,
+         shape,
+         ddim_use_original_steps=False,
+         quantize_denoised=False,
+         temperature=1.0,
+         noise_dropout=0.0,
+     ):
+         device = self.model.betas.device
+         b = shape[0]
+         img = torch.randn(shape, device=device, dtype=cond.dtype)
+         timesteps = (
+             self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+         )
+
+         time_range = (
+             reversed(range(0, timesteps))
+             if ddim_use_original_steps
+             else np.flip(timesteps)
+         )
+         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+         logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
+
+         iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
+
+         for i, step in enumerate(iterator):
+             index = total_steps - i - 1
+             ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+             outs = self.p_sample_ddim(
+                 img,
+                 cond,
+                 ts,
+                 index=index,
+                 use_original_steps=ddim_use_original_steps,
+                 quantize_denoised=quantize_denoised,
+                 temperature=temperature,
+                 noise_dropout=noise_dropout,
+             )
+             img, _ = outs
+
+         return img
+
+     @torch.no_grad()
+     def p_sample_ddim(
+         self,
+         x,
+         c,
+         t,
+         index,
+         repeat_noise=False,
+         use_original_steps=False,
+         quantize_denoised=False,
+         temperature=1.0,
+         noise_dropout=0.0,
+     ):
+         b, *_, device = *x.shape, x.device
+         e_t = self.model.apply_model(x, t, c)
+
+         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+         alphas_prev = (
+             self.model.alphas_cumprod_prev
+             if use_original_steps
+             else self.ddim_alphas_prev
+         )
+         sqrt_one_minus_alphas = (
+             self.model.sqrt_one_minus_alphas_cumprod
+             if use_original_steps
+             else self.ddim_sqrt_one_minus_alphas
+         )
+         sigmas = (
+             self.model.ddim_sigmas_for_original_num_steps
+             if use_original_steps
+             else self.ddim_sigmas
+         )
+         # select parameters corresponding to the currently considered timestep
+         a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+         a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+         sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+         sqrt_one_minus_at = torch.full(
+             (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+         )
+
+         # current prediction for x_0
+         pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+         if quantize_denoised:  # not used
+             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+         # direction pointing to x_t
+         dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
+         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+         if noise_dropout > 0.0:  # not used
+             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+         x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+         return x_prev, pred_x0
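Note: `p_sample_ddim` above implements the standard DDIM update step (Song et al., 2021), where the alphas are the cumulative products stored in `ddim_alphas` and `sigma_t` controls the stochasticity (eta = 0 gives a deterministic sampler):

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, c)}{\sqrt{\bar\alpha_t}},
\qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0
        + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t, c)
        + \sigma_t z, \quad z \sim \mathcal{N}(0, I)
```

The three terms correspond directly to `pred_x0`, `dir_xt`, and `noise` in the code.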
lama_cleaner/model/fcf.py ADDED
@@ -0,0 +1,1733 @@
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import torch
6
+ import numpy as np
7
+ import torch.fft as fft
8
+
9
+ from lama_cleaner.schema import Config
10
+
11
+ from lama_cleaner.helper import (
12
+ load_model,
13
+ get_cache_path_by_url,
14
+ norm_img,
15
+ boxes_from_mask,
16
+ resize_max_size,
17
+ )
18
+ from lama_cleaner.model.base import InpaintModel
19
+ from torch import conv2d, nn
20
+ import torch.nn.functional as F
21
+
22
+ from lama_cleaner.model.utils import (
23
+ setup_filter,
24
+ _parse_scaling,
25
+ _parse_padding,
26
+ Conv2dLayer,
27
+ FullyConnectedLayer,
28
+ MinibatchStdLayer,
29
+ activation_funcs,
30
+ conv2d_resample,
31
+ bias_act,
32
+ upsample2d,
33
+ normalize_2nd_moment,
34
+ downsample2d,
35
+ )
36
+
37
+
38
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
39
+ assert isinstance(x, torch.Tensor)
40
+ return _upfirdn2d_ref(
41
+ x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
42
+ )
43
+
44
+
45
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
46
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
47
+ # Validate arguments.
48
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
49
+ if f is None:
50
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
51
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
52
+ assert f.dtype == torch.float32 and not f.requires_grad
53
+ batch_size, num_channels, in_height, in_width = x.shape
54
+ upx, upy = _parse_scaling(up)
55
+ downx, downy = _parse_scaling(down)
56
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
57
+
58
+ # Upsample by inserting zeros.
59
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
60
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
61
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
62
+
63
+ # Pad or crop.
64
+ x = torch.nn.functional.pad(
65
+ x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
66
+ )
67
+ x = x[
68
+ :,
69
+ :,
70
+ max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
71
+ max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
72
+ ]
73
+
74
+ # Setup filter.
75
+ f = f * (gain ** (f.ndim / 2))
76
+ f = f.to(x.dtype)
77
+ if not flip_filter:
78
+ f = f.flip(list(range(f.ndim)))
79
+
80
+ # Convolve with the filter.
81
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
82
+ if f.ndim == 4:
83
+ x = conv2d(input=x, weight=f, groups=num_channels)
84
+ else:
85
+ x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
86
+ x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
87
+
88
+ # Downsample by throwing away pixels.
89
+ x = x[:, :, ::downy, ::downx]
90
+ return x
91
+
92
+
93
+ class EncoderEpilogue(torch.nn.Module):
94
+ def __init__(
95
+ self,
96
+ in_channels, # Number of input channels.
97
+ cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
98
+ z_dim, # Output Latent (Z) dimensionality.
99
+ resolution, # Resolution of this block.
100
+ img_channels, # Number of input color channels.
101
+ architecture="resnet", # Architecture: 'orig', 'skip', 'resnet'.
102
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
103
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
104
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
105
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
106
+ ):
107
+ assert architecture in ["orig", "skip", "resnet"]
108
+ super().__init__()
109
+ self.in_channels = in_channels
110
+ self.cmap_dim = cmap_dim
111
+ self.resolution = resolution
112
+ self.img_channels = img_channels
113
+ self.architecture = architecture
114
+
115
+ if architecture == "skip":
116
+ self.fromrgb = Conv2dLayer(
117
+ self.img_channels, in_channels, kernel_size=1, activation=activation
118
+ )
119
+ self.mbstd = (
120
+ MinibatchStdLayer(
121
+ group_size=mbstd_group_size, num_channels=mbstd_num_channels
122
+ )
123
+ if mbstd_num_channels > 0
124
+ else None
125
+ )
126
+ self.conv = Conv2dLayer(
127
+ in_channels + mbstd_num_channels,
128
+ in_channels,
129
+ kernel_size=3,
130
+ activation=activation,
131
+ conv_clamp=conv_clamp,
132
+ )
133
+ self.fc = FullyConnectedLayer(
134
+ in_channels * (resolution**2), z_dim, activation=activation
135
+ )
136
+ self.dropout = torch.nn.Dropout(p=0.5)
137
+
138
+ def forward(self, x, cmap, force_fp32=False):
139
+ _ = force_fp32 # unused
140
+ dtype = torch.float32
141
+ memory_format = torch.contiguous_format
142
+
143
+ # FromRGB.
144
+ x = x.to(dtype=dtype, memory_format=memory_format)
145
+
146
+ # Main layers.
147
+ if self.mbstd is not None:
148
+ x = self.mbstd(x)
149
+ const_e = self.conv(x)
150
+ x = self.fc(const_e.flatten(1))
151
+ x = self.dropout(x)
152
+
153
+ # Conditioning.
154
+ if self.cmap_dim > 0:
155
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
156
+
157
+ assert x.dtype == dtype
158
+ return x, const_e
159
+
160
+
161
+ class EncoderBlock(torch.nn.Module):
162
+ def __init__(
163
+ self,
164
+ in_channels, # Number of input channels, 0 = first block.
165
+ tmp_channels, # Number of intermediate channels.
166
+ out_channels, # Number of output channels.
167
+ resolution, # Resolution of this block.
168
+ img_channels, # Number of input color channels.
169
+ first_layer_idx, # Index of the first layer.
170
+ architecture="skip", # Architecture: 'orig', 'skip', 'resnet'.
171
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
172
+ resample_filter=[
173
+ 1,
174
+ 3,
175
+ 3,
176
+ 1,
177
+ ], # Low-pass filter to apply when resampling activations.
178
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
179
+ use_fp16=False, # Use FP16 for this block?
180
+ fp16_channels_last=False, # Use channels-last memory format with FP16?
181
+ freeze_layers=0, # Freeze-D: Number of layers to freeze.
182
+ ):
183
+ assert in_channels in [0, tmp_channels]
184
+ assert architecture in ["orig", "skip", "resnet"]
185
+ super().__init__()
186
+ self.in_channels = in_channels
187
+ self.resolution = resolution
188
+ self.img_channels = img_channels + 1
189
+ self.first_layer_idx = first_layer_idx
190
+ self.architecture = architecture
191
+ self.use_fp16 = use_fp16
192
+ self.channels_last = use_fp16 and fp16_channels_last
193
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
194
+
195
+ self.num_layers = 0
196
+
197
+ def trainable_gen():
198
+ while True:
199
+ layer_idx = self.first_layer_idx + self.num_layers
200
+ trainable = layer_idx >= freeze_layers
201
+ self.num_layers += 1
202
+ yield trainable
203
+
204
+ trainable_iter = trainable_gen()
205
+
206
+ if in_channels == 0:
207
+ self.fromrgb = Conv2dLayer(
208
+ self.img_channels,
209
+ tmp_channels,
210
+ kernel_size=1,
211
+ activation=activation,
212
+ trainable=next(trainable_iter),
213
+ conv_clamp=conv_clamp,
214
+ channels_last=self.channels_last,
215
+ )
216
+
217
+ self.conv0 = Conv2dLayer(
218
+ tmp_channels,
219
+ tmp_channels,
220
+ kernel_size=3,
221
+ activation=activation,
222
+ trainable=next(trainable_iter),
223
+ conv_clamp=conv_clamp,
224
+ channels_last=self.channels_last,
225
+ )
226
+
227
+ self.conv1 = Conv2dLayer(
228
+ tmp_channels,
229
+ out_channels,
230
+ kernel_size=3,
231
+ activation=activation,
232
+ down=2,
233
+ trainable=next(trainable_iter),
234
+ resample_filter=resample_filter,
235
+ conv_clamp=conv_clamp,
236
+ channels_last=self.channels_last,
237
+ )
238
+
239
+ if architecture == "resnet":
240
+ self.skip = Conv2dLayer(
241
+ tmp_channels,
242
+ out_channels,
243
+ kernel_size=1,
244
+ bias=False,
245
+ down=2,
246
+ trainable=next(trainable_iter),
247
+ resample_filter=resample_filter,
248
+ channels_last=self.channels_last,
249
+ )
250
+
251
+ def forward(self, x, img, force_fp32=False):
252
+ # dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
253
+ dtype = torch.float32
254
+ memory_format = (
255
+ torch.channels_last
256
+ if self.channels_last and not force_fp32
257
+ else torch.contiguous_format
258
+ )
259
+
260
+ # Input.
261
+ if x is not None:
262
+ x = x.to(dtype=dtype, memory_format=memory_format)
263
+
264
+ # FromRGB.
265
+ if self.in_channels == 0:
266
+ img = img.to(dtype=dtype, memory_format=memory_format)
267
+ y = self.fromrgb(img)
268
+ x = x + y if x is not None else y
269
+ img = (
270
+ downsample2d(img, self.resample_filter)
271
+ if self.architecture == "skip"
272
+ else None
273
+ )
274
+
275
+ # Main layers.
276
+ if self.architecture == "resnet":
277
+ y = self.skip(x, gain=np.sqrt(0.5))
278
+ x = self.conv0(x)
279
+ feat = x.clone()
280
+ x = self.conv1(x, gain=np.sqrt(0.5))
281
+ x = y.add_(x)
282
+ else:
283
+ x = self.conv0(x)
284
+ feat = x.clone()
285
+ x = self.conv1(x)
286
+
287
+ assert x.dtype == dtype
288
+ return x, img, feat
289
+
290
+
291
+ class EncoderNetwork(torch.nn.Module):
292
+ def __init__(
293
+ self,
294
+ c_dim, # Conditioning label (C) dimensionality.
295
+ z_dim, # Input latent (Z) dimensionality.
296
+ img_resolution, # Input resolution.
297
+ img_channels, # Number of input color channels.
298
+ architecture="orig", # Architecture: 'orig', 'skip', 'resnet'.
299
+ channel_base=16384, # Overall multiplier for the number of channels.
300
+ channel_max=512, # Maximum number of channels in any layer.
301
+ num_fp16_res=0, # Use FP16 for the N highest resolutions.
302
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
303
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
304
+ block_kwargs={}, # Arguments for DiscriminatorBlock.
305
+ mapping_kwargs={}, # Arguments for MappingNetwork.
306
+ epilogue_kwargs={}, # Arguments for EncoderEpilogue.
307
+ ):
308
+ super().__init__()
309
+ self.c_dim = c_dim
310
+ self.z_dim = z_dim
311
+ self.img_resolution = img_resolution
312
+ self.img_resolution_log2 = int(np.log2(img_resolution))
313
+ self.img_channels = img_channels
314
+ self.block_resolutions = [
315
+ 2**i for i in range(self.img_resolution_log2, 2, -1)
316
+ ]
317
+ channels_dict = {
318
+ res: min(channel_base // res, channel_max)
319
+ for res in self.block_resolutions + [4]
320
+ }
321
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
322
+
323
+ if cmap_dim is None:
324
+ cmap_dim = channels_dict[4]
325
+ if c_dim == 0:
326
+ cmap_dim = 0
327
+
328
+ common_kwargs = dict(
329
+ img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp
330
+ )
331
+ cur_layer_idx = 0
332
+ for res in self.block_resolutions:
333
+ in_channels = channels_dict[res] if res < img_resolution else 0
334
+ tmp_channels = channels_dict[res]
335
+ out_channels = channels_dict[res // 2]
336
+ use_fp16 = res >= fp16_resolution
337
+ use_fp16 = False
338
+ block = EncoderBlock(
339
+ in_channels,
340
+ tmp_channels,
341
+ out_channels,
342
+ resolution=res,
343
+ first_layer_idx=cur_layer_idx,
344
+ use_fp16=use_fp16,
345
+ **block_kwargs,
346
+ **common_kwargs,
347
+ )
348
+ setattr(self, f"b{res}", block)
349
+ cur_layer_idx += block.num_layers
350
+ if c_dim > 0:
351
+ self.mapping = MappingNetwork(
352
+ z_dim=0,
353
+ c_dim=c_dim,
354
+ w_dim=cmap_dim,
355
+ num_ws=None,
356
+ w_avg_beta=None,
357
+ **mapping_kwargs,
358
+ )
359
+ self.b4 = EncoderEpilogue(
360
+ channels_dict[4],
361
+ cmap_dim=cmap_dim,
362
+ z_dim=z_dim * 2,
363
+ resolution=4,
364
+ **epilogue_kwargs,
365
+ **common_kwargs,
366
+ )
367
+
368
+ def forward(self, img, c, **block_kwargs):
369
+ x = None
370
+ feats = {}
371
+ for res in self.block_resolutions:
372
+ block = getattr(self, f"b{res}")
373
+ x, img, feat = block(x, img, **block_kwargs)
374
+ feats[res] = feat
375
+
376
+ cmap = None
377
+ if self.c_dim > 0:
378
+ cmap = self.mapping(None, c)
379
+ x, const_e = self.b4(x, cmap)
380
+ feats[4] = const_e
381
+
382
+ B, _ = x.shape
383
+ z = torch.zeros(
384
+ (B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device
385
+ ) ## Noise for Co-Modulation
386
+ return x, z, feats
387
+
388
+
389
+ def fma(a, b, c): # => a * b + c
390
+ return _FusedMultiplyAdd.apply(a, b, c)
391
+
392
+
393
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
394
+ @staticmethod
395
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
396
+ out = torch.addcmul(c, a, b)
397
+ ctx.save_for_backward(a, b)
398
+ ctx.c_shape = c.shape
399
+ return out
400
+
401
+ @staticmethod
402
+ def backward(ctx, dout): # pylint: disable=arguments-differ
403
+ a, b = ctx.saved_tensors
404
+ c_shape = ctx.c_shape
405
+ da = None
406
+ db = None
407
+ dc = None
408
+
409
+ if ctx.needs_input_grad[0]:
410
+ da = _unbroadcast(dout * b, a.shape)
411
+
412
+ if ctx.needs_input_grad[1]:
413
+ db = _unbroadcast(dout * a, b.shape)
414
+
415
+ if ctx.needs_input_grad[2]:
416
+ dc = _unbroadcast(dout, c_shape)
417
+
418
+ return da, db, dc
419
+
420
+
421
+ def _unbroadcast(x, shape):
422
+ extra_dims = x.ndim - len(shape)
423
+ assert extra_dims >= 0
424
+ dim = [
425
+ i
426
+ for i in range(x.ndim)
427
+ if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)
428
+ ]
429
+ if len(dim):
430
+ x = x.sum(dim=dim, keepdim=True)
431
+ if extra_dims:
432
+ x = x.reshape(-1, *x.shape[extra_dims + 1 :])
433
+ assert x.shape == shape
434
+ return x
435
+
436
+
437
+ def modulated_conv2d(
438
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
439
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
440
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
441
+ noise=None, # Optional noise tensor to add to the output activations.
442
+ up=1, # Integer upsampling factor.
443
+ down=1, # Integer downsampling factor.
444
+ padding=0, # Padding with respect to the upsampled image.
445
+ resample_filter=None,
446
+ # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
447
+ demodulate=True, # Apply weight demodulation?
448
+ flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
449
+ fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation?
450
+ ):
451
+ batch_size = x.shape[0]
452
+ out_channels, in_channels, kh, kw = weight.shape
453
+
454
+ # Pre-normalize inputs to avoid FP16 overflow.
455
+ if x.dtype == torch.float16 and demodulate:
456
+ weight = weight * (
457
+ 1
458
+ / np.sqrt(in_channels * kh * kw)
459
+ / weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True)
460
+ ) # max_Ikk
461
+ styles = styles / styles.norm(float("inf"), dim=1, keepdim=True) # max_I
462
+
463
+ # Calculate per-sample weights and demodulation coefficients.
464
+ w = None
465
+ dcoefs = None
466
+ if demodulate or fused_modconv:
467
+ w = weight.unsqueeze(0) # [NOIkk]
468
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
469
+ if demodulate:
470
+ dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]
471
+ if demodulate and fused_modconv:
472
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
473
+ # Execute by scaling the activations before and after the convolution.
474
+ if not fused_modconv:
475
+ x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
476
+ x = conv2d_resample.conv2d_resample(
477
+ x=x,
478
+ w=weight.to(x.dtype),
479
+ f=resample_filter,
480
+ up=up,
481
+ down=down,
482
+ padding=padding,
483
+ flip_weight=flip_weight,
484
+ )
485
+ if demodulate and noise is not None:
486
+ x = fma(
487
+ x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)
488
+ )
489
+ elif demodulate:
490
+ x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
491
+ elif noise is not None:
492
+ x = x.add_(noise.to(x.dtype))
493
+ return x
494
+
495
+ # Execute as one fused op using grouped convolution.
496
+ batch_size = int(batch_size)
497
+ x = x.reshape(1, -1, *x.shape[2:])
498
+ w = w.reshape(-1, in_channels, kh, kw)
499
+ x = conv2d_resample(
500
+ x=x,
501
+ w=w.to(x.dtype),
502
+ f=resample_filter,
503
+ up=up,
504
+ down=down,
505
+ padding=padding,
506
+ groups=batch_size,
507
+ flip_weight=flip_weight,
508
+ )
509
+ x = x.reshape(batch_size, -1, *x.shape[2:])
510
+ if noise is not None:
511
+ x = x.add_(noise)
512
+ return x
513
+
514
+
515
+ class SynthesisLayer(torch.nn.Module):
516
+ def __init__(
517
+ self,
518
+ in_channels, # Number of input channels.
519
+ out_channels, # Number of output channels.
520
+ w_dim, # Intermediate latent (W) dimensionality.
521
+ resolution, # Resolution of this layer.
522
+ kernel_size=3, # Convolution kernel size.
523
+ up=1, # Integer upsampling factor.
524
+ use_noise=True, # Enable noise input?
525
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
526
+ resample_filter=[
527
+ 1,
528
+ 3,
529
+ 3,
530
+ 1,
531
+ ], # Low-pass filter to apply when resampling activations.
532
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
533
+ channels_last=False, # Use channels_last format for the weights?
534
+ ):
535
+ super().__init__()
536
+ self.resolution = resolution
537
+ self.up = up
538
+ self.use_noise = use_noise
539
+ self.activation = activation
540
+ self.conv_clamp = conv_clamp
541
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
542
+ self.padding = kernel_size // 2
543
+ self.act_gain = activation_funcs[activation].def_gain
544
+
545
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
546
+ memory_format = (
547
+ torch.channels_last if channels_last else torch.contiguous_format
548
+ )
549
+ self.weight = torch.nn.Parameter(
550
+ torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
551
+ memory_format=memory_format
552
+ )
553
+ )
554
+ if use_noise:
555
+ self.register_buffer("noise_const", torch.randn([resolution, resolution]))
556
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
557
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
558
+
559
+ def forward(self, x, w, noise_mode="none", fused_modconv=True, gain=1):
560
+ assert noise_mode in ["random", "const", "none"]
561
+ in_resolution = self.resolution // self.up
562
+ styles = self.affine(w)
563
+
564
+ noise = None
565
+ if self.use_noise and noise_mode == "random":
566
+ noise = (
567
+ torch.randn(
568
+ [x.shape[0], 1, self.resolution, self.resolution], device=x.device
569
+ )
570
+ * self.noise_strength
571
+ )
572
+ if self.use_noise and noise_mode == "const":
573
+ noise = self.noise_const * self.noise_strength
574
+
575
+ flip_weight = self.up == 1 # slightly faster
576
+ x = modulated_conv2d(
577
+ x=x,
578
+ weight=self.weight,
579
+ styles=styles,
580
+ noise=noise,
581
+ up=self.up,
582
+ padding=self.padding,
583
+ resample_filter=self.resample_filter,
584
+ flip_weight=flip_weight,
585
+ fused_modconv=fused_modconv,
586
+ )
587
+
588
+ act_gain = self.act_gain * gain
589
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
590
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=False)
591
+ if act_gain != 1:
592
+ x = x * act_gain
593
+ if act_clamp is not None:
594
+ x = x.clamp(-act_clamp, act_clamp)
595
+ return x
596
+
597
+
598
+ class ToRGBLayer(torch.nn.Module):
599
+ def __init__(
600
+ self,
601
+ in_channels,
602
+ out_channels,
603
+ w_dim,
604
+ kernel_size=1,
605
+ conv_clamp=None,
606
+ channels_last=False,
607
+ ):
608
+ super().__init__()
609
+ self.conv_clamp = conv_clamp
610
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
611
+ memory_format = (
612
+ torch.channels_last if channels_last else torch.contiguous_format
613
+ )
614
+ self.weight = torch.nn.Parameter(
615
+ torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
616
+ memory_format=memory_format
617
+ )
618
+ )
619
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
620
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
621
+
622
+ def forward(self, x, w, fused_modconv=True):
623
+ styles = self.affine(w) * self.weight_gain
624
+ x = modulated_conv2d(
625
+ x=x,
626
+ weight=self.weight,
627
+ styles=styles,
628
+ demodulate=False,
629
+ fused_modconv=fused_modconv,
630
+ )
631
+ x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
632
+ return x
633
+
634
+
635
+ class SynthesisForeword(torch.nn.Module):
636
+ def __init__(
637
+ self,
638
+ z_dim, # Output Latent (Z) dimensionality.
639
+ resolution, # Resolution of this block.
640
+ in_channels,
641
+ img_channels, # Number of input color channels.
642
+ architecture="skip", # Architecture: 'orig', 'skip', 'resnet'.
643
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
644
+ ):
645
+ super().__init__()
646
+ self.in_channels = in_channels
647
+ self.z_dim = z_dim
648
+ self.resolution = resolution
649
+ self.img_channels = img_channels
650
+ self.architecture = architecture
651
+
652
+ self.fc = FullyConnectedLayer(
653
+ self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation
654
+ )
655
+ self.conv = SynthesisLayer(
656
+ self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4
657
+ )
658
+
659
+ if architecture == "skip":
660
+ self.torgb = ToRGBLayer(
661
+ self.in_channels,
662
+ self.img_channels,
663
+ kernel_size=1,
664
+ w_dim=(z_dim // 2) * 3,
665
+ )
666
+
667
+ def forward(self, x, ws, feats, img, force_fp32=False):
668
+ _ = force_fp32 # unused
669
+ dtype = torch.float32
670
+ memory_format = torch.contiguous_format
671
+
672
+ x_global = x.clone()
673
+ # ToRGB.
674
+ x = self.fc(x)
675
+ x = x.view(-1, self.z_dim // 2, 4, 4)
676
+ x = x.to(dtype=dtype, memory_format=memory_format)
677
+
678
+ # Main layers.
679
+ x_skip = feats[4].clone()
680
+ x = x + x_skip
681
+
682
+ mod_vector = []
683
+ mod_vector.append(ws[:, 0])
684
+ mod_vector.append(x_global.clone())
685
+ mod_vector = torch.cat(mod_vector, dim=1)
686
+
687
+ x = self.conv(x, mod_vector)
688
+
689
+ mod_vector = []
690
+ mod_vector.append(ws[:, 2 * 2 - 3])
691
+ mod_vector.append(x_global.clone())
692
+ mod_vector = torch.cat(mod_vector, dim=1)
693
+
694
+ if self.architecture == "skip":
695
+ img = self.torgb(x, mod_vector)
696
+ img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
697
+
698
+ assert x.dtype == dtype
699
+ return x, img
700
+
701
+
702
+ class SELayer(nn.Module):
703
+ def __init__(self, channel, reduction=16):
704
+ super(SELayer, self).__init__()
705
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
706
+ self.fc = nn.Sequential(
707
+ nn.Linear(channel, channel // reduction, bias=False),
708
+ nn.ReLU(inplace=False),
709
+ nn.Linear(channel // reduction, channel, bias=False),
710
+ nn.Sigmoid(),
711
+ )
712
+
713
+ def forward(self, x):
714
+ b, c, _, _ = x.size()
715
+ y = self.avg_pool(x).view(b, c)
716
+ y = self.fc(y).view(b, c, 1, 1)
717
+ res = x * y.expand_as(x)
718
+ return res
719
+
720
+
721
+ class FourierUnit(nn.Module):
722
+ def __init__(
723
+ self,
724
+ in_channels,
725
+ out_channels,
726
+ groups=1,
727
+ spatial_scale_factor=None,
728
+ spatial_scale_mode="bilinear",
729
+ spectral_pos_encoding=False,
730
+ use_se=False,
731
+ se_kwargs=None,
732
+ ffc3d=False,
733
+ fft_norm="ortho",
734
+ ):
735
+ # bn_layer not used
736
+ super(FourierUnit, self).__init__()
737
+ self.groups = groups
738
+
739
+ self.conv_layer = torch.nn.Conv2d(
740
+ in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
741
+ out_channels=out_channels * 2,
742
+ kernel_size=1,
743
+ stride=1,
744
+ padding=0,
745
+ groups=self.groups,
746
+ bias=False,
747
+ )
748
+ self.relu = torch.nn.ReLU(inplace=False)
749
+
750
+ # squeeze and excitation block
751
+ self.use_se = use_se
752
+ if use_se:
753
+ if se_kwargs is None:
754
+ se_kwargs = {}
755
+ self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
756
+
757
+ self.spatial_scale_factor = spatial_scale_factor
758
+ self.spatial_scale_mode = spatial_scale_mode
759
+ self.spectral_pos_encoding = spectral_pos_encoding
760
+ self.ffc3d = ffc3d
761
+ self.fft_norm = fft_norm
762
+
763
+ def forward(self, x):
764
+ batch = x.shape[0]
765
+
766
+ if self.spatial_scale_factor is not None:
767
+ orig_size = x.shape[-2:]
768
+ x = F.interpolate(
769
+ x,
770
+ scale_factor=self.spatial_scale_factor,
771
+ mode=self.spatial_scale_mode,
772
+ align_corners=False,
773
+ )
774
+
775
+ r_size = x.size()
776
+ # (batch, c, h, w/2+1, 2)
777
+ fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
778
+ ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
779
+ ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
780
+ ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
781
+ ffted = ffted.view(
782
+ (
783
+ batch,
784
+ -1,
785
+ )
786
+ + ffted.size()[3:]
787
+ )
788
+
789
+ if self.spectral_pos_encoding:
790
+ height, width = ffted.shape[-2:]
791
+ coords_vert = (
792
+ torch.linspace(0, 1, height)[None, None, :, None]
793
+ .expand(batch, 1, height, width)
794
+ .to(ffted)
795
+ )
796
+ coords_hor = (
797
+ torch.linspace(0, 1, width)[None, None, None, :]
798
+ .expand(batch, 1, height, width)
799
+ .to(ffted)
800
+ )
801
+ ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
802
+
803
+ if self.use_se:
804
+ ffted = self.se(ffted)
805
+
806
+ ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
807
+ ffted = self.relu(ffted)
808
+
809
+ ffted = (
810
+ ffted.view(
811
+ (
812
+ batch,
813
+ -1,
814
+ 2,
815
+ )
816
+ + ffted.size()[2:]
817
+ )
818
+ .permute(0, 1, 3, 4, 2)
819
+ .contiguous()
820
+ ) # (batch,c, t, h, w/2+1, 2)
821
+ ffted = torch.complex(ffted[..., 0], ffted[..., 1])
822
+
823
+ ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
824
+ output = torch.fft.irfftn(
825
+ ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm
826
+ )
827
+
828
+ if self.spatial_scale_factor is not None:
829
+ output = F.interpolate(
830
+ output,
831
+ size=orig_size,
832
+ mode=self.spatial_scale_mode,
833
+ align_corners=False,
834
+ )
835
+
836
+ return output
837
+
838
+
839
+ class SpectralTransform(nn.Module):
840
+ def __init__(
841
+ self,
842
+ in_channels,
843
+ out_channels,
844
+ stride=1,
845
+ groups=1,
846
+ enable_lfu=True,
847
+ **fu_kwargs,
848
+ ):
849
+ # bn_layer not used
850
+ super(SpectralTransform, self).__init__()
851
+ self.enable_lfu = enable_lfu
852
+ if stride == 2:
853
+ self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
854
+ else:
855
+ self.downsample = nn.Identity()
856
+
857
+ self.stride = stride
858
+ self.conv1 = nn.Sequential(
859
+ nn.Conv2d(
860
+ in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
861
+ ),
862
+ # nn.BatchNorm2d(out_channels // 2),
863
+ nn.ReLU(inplace=True),
864
+ )
865
+ self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
866
+ if self.enable_lfu:
867
+ self.lfu = FourierUnit(out_channels // 2, out_channels // 2, groups)
868
+ self.conv2 = torch.nn.Conv2d(
869
+ out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
870
+ )
871
+
872
+ def forward(self, x):
873
+
874
+ x = self.downsample(x)
875
+ x = self.conv1(x)
876
+ output = self.fu(x)
877
+
878
+ if self.enable_lfu:
879
+ n, c, h, w = x.shape
880
+ split_no = 2
881
+ split_s = h // split_no
882
+ xs = torch.cat(
883
+ torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
884
+ ).contiguous()
885
+ xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
886
+ xs = self.lfu(xs)
887
+ xs = xs.repeat(1, 1, split_no, split_no).contiguous()
888
+ else:
889
+ xs = 0
890
+
891
+ output = self.conv2(x + output + xs)
892
+
893
+ return output
894
+
895
+
896
+ class FFC(nn.Module):
897
+ def __init__(
898
+ self,
899
+ in_channels,
900
+ out_channels,
901
+ kernel_size,
902
+ ratio_gin,
903
+ ratio_gout,
904
+ stride=1,
905
+ padding=0,
906
+ dilation=1,
907
+ groups=1,
908
+ bias=False,
909
+ enable_lfu=True,
910
+ padding_type="reflect",
911
+ gated=False,
912
+ **spectral_kwargs,
913
+ ):
914
+ super(FFC, self).__init__()
915
+
916
+ assert stride == 1 or stride == 2, "Stride should be 1 or 2."
917
+ self.stride = stride
918
+
919
+ in_cg = int(in_channels * ratio_gin)
920
+ in_cl = in_channels - in_cg
921
+ out_cg = int(out_channels * ratio_gout)
922
+ out_cl = out_channels - out_cg
923
+ # groups_g = 1 if groups == 1 else int(groups * ratio_gout)
924
+ # groups_l = 1 if groups == 1 else groups - groups_g
925
+
926
+ self.ratio_gin = ratio_gin
927
+ self.ratio_gout = ratio_gout
928
+ self.global_in_num = in_cg
929
+
930
+ module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
931
+ self.convl2l = module(
932
+ in_cl,
933
+ out_cl,
934
+ kernel_size,
935
+ stride,
936
+ padding,
937
+ dilation,
938
+ groups,
939
+ bias,
940
+ padding_mode=padding_type,
941
+ )
942
+ module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
943
+ self.convl2g = module(
944
+ in_cl,
945
+ out_cg,
946
+ kernel_size,
947
+ stride,
948
+ padding,
949
+ dilation,
950
+ groups,
951
+ bias,
952
+ padding_mode=padding_type,
953
+ )
954
+ module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
955
+ self.convg2l = module(
956
+ in_cg,
957
+ out_cl,
958
+ kernel_size,
959
+ stride,
960
+ padding,
961
+ dilation,
962
+ groups,
963
+ bias,
964
+ padding_mode=padding_type,
965
+ )
966
+ module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
967
+ self.convg2g = module(
968
+ in_cg,
969
+ out_cg,
970
+ stride,
971
+ 1 if groups == 1 else groups // 2,
972
+ enable_lfu,
973
+ **spectral_kwargs,
974
+ )
975
+
976
+ self.gated = gated
977
+ module = (
978
+ nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
979
+ )
980
+ self.gate = module(in_channels, 2, 1)
981
+
982
+     def forward(self, x, fname=None):
+         x_l, x_g = x if type(x) is tuple else (x, 0)
+         out_xl, out_xg = 0, 0
+
+         if self.gated:
+             total_input_parts = [x_l]
+             if torch.is_tensor(x_g):
+                 total_input_parts.append(x_g)
+             total_input = torch.cat(total_input_parts, dim=1)
+
+             gates = torch.sigmoid(self.gate(total_input))
+             g2l_gate, l2g_gate = gates.chunk(2, dim=1)
+         else:
+             g2l_gate, l2g_gate = 1, 1
+
+         spec_x = self.convg2g(x_g)
+
+         if self.ratio_gout != 1:
+             out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
+         if self.ratio_gout != 0:
+             out_xg = self.convl2g(x_l) * l2g_gate + spec_x
+
+         return out_xl, out_xg
+
+
+ class FFC_BN_ACT(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         ratio_gin,
+         ratio_gout,
+         stride=1,
+         padding=0,
+         dilation=1,
+         groups=1,
+         bias=False,
+         norm_layer=nn.SyncBatchNorm,
+         activation_layer=nn.Identity,
+         padding_type="reflect",
+         enable_lfu=True,
+         **kwargs,
+     ):
+         super(FFC_BN_ACT, self).__init__()
+         self.ffc = FFC(
+             in_channels,
+             out_channels,
+             kernel_size,
+             ratio_gin,
+             ratio_gout,
+             stride,
+             padding,
+             dilation,
+             groups,
+             bias,
+             enable_lfu,
+             padding_type=padding_type,
+             **kwargs,
+         )
+         lnorm = nn.Identity if ratio_gout == 1 else norm_layer
+         gnorm = nn.Identity if ratio_gout == 0 else norm_layer
+         global_channels = int(out_channels * ratio_gout)
+         # self.bn_l = lnorm(out_channels - global_channels)
+         # self.bn_g = gnorm(global_channels)
+
+         lact = nn.Identity if ratio_gout == 1 else activation_layer
+         gact = nn.Identity if ratio_gout == 0 else activation_layer
+         self.act_l = lact(inplace=True)
+         self.act_g = gact(inplace=True)
+
+     def forward(self, x, fname=None):
+         x_l, x_g = self.ffc(
+             x,
+             fname=fname,
+         )
+         x_l = self.act_l(x_l)
+         x_g = self.act_g(x_g)
+         return x_l, x_g
+
+
+ class FFCResnetBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         padding_type,
+         norm_layer,
+         activation_layer=nn.ReLU,
+         dilation=1,
+         spatial_transform_kwargs=None,
+         inline=False,
+         ratio_gin=0.75,
+         ratio_gout=0.75,
+     ):
+         super().__init__()
+         self.conv1 = FFC_BN_ACT(
+             dim,
+             dim,
+             kernel_size=3,
+             padding=dilation,
+             dilation=dilation,
+             norm_layer=norm_layer,
+             activation_layer=activation_layer,
+             padding_type=padding_type,
+             ratio_gin=ratio_gin,
+             ratio_gout=ratio_gout,
+         )
+         self.conv2 = FFC_BN_ACT(
+             dim,
+             dim,
+             kernel_size=3,
+             padding=dilation,
+             dilation=dilation,
+             norm_layer=norm_layer,
+             activation_layer=activation_layer,
+             padding_type=padding_type,
+             ratio_gin=ratio_gin,
+             ratio_gout=ratio_gout,
+         )
+         self.inline = inline
+
+     def forward(self, x, fname=None):
+         if self.inline:
+             x_l, x_g = (
+                 x[:, : -self.conv1.ffc.global_in_num],
+                 x[:, -self.conv1.ffc.global_in_num :],
+             )
+         else:
+             x_l, x_g = x if type(x) is tuple else (x, 0)
+
+         id_l, id_g = x_l, x_g
+
+         x_l, x_g = self.conv1((x_l, x_g), fname=fname)
+         x_l, x_g = self.conv2((x_l, x_g), fname=fname)
+
+         x_l, x_g = id_l + x_l, id_g + x_g
+         out = x_l, x_g
+         if self.inline:
+             out = torch.cat(out, dim=1)
+         return out
+
+
+ class ConcatTupleLayer(nn.Module):
+     def forward(self, x):
+         assert isinstance(x, tuple)
+         x_l, x_g = x
+         assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
+         if not torch.is_tensor(x_g):
+             return x_l
+         return torch.cat(x, dim=1)
+
+
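Note: throughout this file, features flow through the FFC stack as a (local, global) pair, and the global element degenerates to the integer 0 whenever the global ratio is zero. A minimal sketch of the convention ConcatTupleLayer handles (shapes are illustrative, not from the source):

    import torch

    layer = ConcatTupleLayer()
    x_l = torch.randn(1, 48, 32, 32)  # local (spatial) branch
    x_g = torch.randn(1, 16, 32, 32)  # global (spectral) branch
    print(layer((x_l, x_g)).shape)    # torch.Size([1, 64, 32, 32]), channels concatenated

    # With no global branch the tuple carries the int 0 and the local
    # tensor is returned unchanged.
    print(layer((x_l, 0)) is x_l)     # True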
+ class FFCBlock(torch.nn.Module):
+     def __init__(
+         self,
+         dim,  # Number of output/input channels.
+         kernel_size,  # Width and height of the convolution kernel.
+         padding,
+         ratio_gin=0.75,
+         ratio_gout=0.75,
+         activation="linear",  # Activation function: 'relu', 'lrelu', etc.
+     ):
+         super().__init__()
+         if activation == "linear":
+             self.activation = nn.Identity
+         else:
+             self.activation = nn.ReLU
+         self.padding = padding
+         self.kernel_size = kernel_size
+         self.ffc_block = FFCResnetBlock(
+             dim=dim,
+             padding_type="reflect",
+             norm_layer=nn.SyncBatchNorm,
+             activation_layer=self.activation,
+             dilation=1,
+             ratio_gin=ratio_gin,
+             ratio_gout=ratio_gout,
+         )
+
+         self.concat_layer = ConcatTupleLayer()
+
+     def forward(self, gen_ft, mask, fname=None):
+         x = gen_ft.float()
+
+         x_l, x_g = (
+             x[:, : -self.ffc_block.conv1.ffc.global_in_num],
+             x[:, -self.ffc_block.conv1.ffc.global_in_num :],
+         )
+         id_l, id_g = x_l, x_g
+
+         x_l, x_g = self.ffc_block((x_l, x_g), fname=fname)
+         x_l, x_g = id_l + x_l, id_g + x_g
+         x = self.concat_layer((x_l, x_g))
+
+         return x + gen_ft.float()
+
+
+ class FFCSkipLayer(torch.nn.Module):
+     def __init__(
+         self,
+         dim,  # Number of input/output channels.
+         kernel_size=3,  # Convolution kernel size.
+         ratio_gin=0.75,
+         ratio_gout=0.75,
+     ):
+         super().__init__()
+         self.padding = kernel_size // 2
+
+         self.ffc_act = FFCBlock(
+             dim=dim,
+             kernel_size=kernel_size,
+             activation=nn.ReLU,
+             padding=self.padding,
+             ratio_gin=ratio_gin,
+             ratio_gout=ratio_gout,
+         )
+
+     def forward(self, gen_ft, mask, fname=None):
+         x = self.ffc_act(gen_ft, mask, fname=fname)
+         return x
+
+
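Note: FFCBlock.forward splits the incoming feature map by channel count rather than by a stored ratio; the last global_in_num channels feed the spectral branch. A toy illustration (global_in_num is hard-coded here only for the sketch; in the real block it comes from conv1.ffc):

    import torch

    global_in_num = 16
    x = torch.randn(1, 64, 32, 32)
    x_l, x_g = x[:, :-global_in_num], x[:, -global_in_num:]
    print(x_l.shape[1], x_g.shape[1])  # 48 16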
+ class SynthesisBlock(torch.nn.Module):
+     def __init__(
+         self,
+         in_channels,  # Number of input channels, 0 = first block.
+         out_channels,  # Number of output channels.
+         w_dim,  # Intermediate latent (W) dimensionality.
+         resolution,  # Resolution of this block.
+         img_channels,  # Number of output color channels.
+         is_last,  # Is this the last block?
+         architecture="skip",  # Architecture: 'orig', 'skip', 'resnet'.
+         resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
+         conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
+         use_fp16=False,  # Use FP16 for this block?
+         fp16_channels_last=False,  # Use channels-last memory format with FP16?
+         **layer_kwargs,  # Arguments for SynthesisLayer.
+     ):
+         assert architecture in ["orig", "skip", "resnet"]
+         super().__init__()
+         self.in_channels = in_channels
+         self.w_dim = w_dim
+         self.resolution = resolution
+         self.img_channels = img_channels
+         self.is_last = is_last
+         self.architecture = architecture
+         self.use_fp16 = use_fp16
+         self.channels_last = use_fp16 and fp16_channels_last
+         self.register_buffer("resample_filter", setup_filter(resample_filter))
+         self.num_conv = 0
+         self.num_torgb = 0
+         self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}
+
+         if in_channels != 0 and resolution >= 8:
+             self.ffc_skip = nn.ModuleList()
+             for _ in range(self.res_ffc[resolution]):
+                 self.ffc_skip.append(FFCSkipLayer(dim=out_channels))
+
+         if in_channels == 0:
+             self.const = torch.nn.Parameter(
+                 torch.randn([out_channels, resolution, resolution])
+             )
+
+         if in_channels != 0:
+             self.conv0 = SynthesisLayer(
+                 in_channels,
+                 out_channels,
+                 w_dim=w_dim * 3,
+                 resolution=resolution,
+                 up=2,
+                 resample_filter=resample_filter,
+                 conv_clamp=conv_clamp,
+                 channels_last=self.channels_last,
+                 **layer_kwargs,
+             )
+             self.num_conv += 1
+
+         self.conv1 = SynthesisLayer(
+             out_channels,
+             out_channels,
+             w_dim=w_dim * 3,
+             resolution=resolution,
+             conv_clamp=conv_clamp,
+             channels_last=self.channels_last,
+             **layer_kwargs,
+         )
+         self.num_conv += 1
+
+         if is_last or architecture == "skip":
+             self.torgb = ToRGBLayer(
+                 out_channels,
+                 img_channels,
+                 w_dim=w_dim * 3,
+                 conv_clamp=conv_clamp,
+                 channels_last=self.channels_last,
+             )
+             self.num_torgb += 1
+
+         if in_channels != 0 and architecture == "resnet":
+             self.skip = Conv2dLayer(
+                 in_channels,
+                 out_channels,
+                 kernel_size=1,
+                 bias=False,
+                 up=2,
+                 resample_filter=resample_filter,
+                 channels_last=self.channels_last,
+             )
+
+     def forward(
+         self,
+         x,
+         mask,
+         feats,
+         img,
+         ws,
+         fname=None,
+         force_fp32=False,
+         fused_modconv=None,
+         **layer_kwargs,
+     ):
+         dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+         dtype = torch.float32
+         memory_format = (
+             torch.channels_last
+             if self.channels_last and not force_fp32
+             else torch.contiguous_format
+         )
+         if fused_modconv is None:
+             fused_modconv = (not self.training) and (
+                 dtype == torch.float32 or int(x.shape[0]) == 1
+             )
+
+         x = x.to(dtype=dtype, memory_format=memory_format)
+         x_skip = (
+             feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)
+         )
+
+         # Main layers.
+         if self.in_channels == 0:
+             x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
+         elif self.architecture == "resnet":
+             y = self.skip(x, gain=np.sqrt(0.5))
+             x = self.conv0(
+                 x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs
+             )
+             if len(self.ffc_skip) > 0:
+                 mask = F.interpolate(
+                     mask,
+                     size=x_skip.shape[2:],
+                 )
+                 z = x + x_skip
+                 for fres in self.ffc_skip:
+                     z = fres(z, mask)
+                 x = x + z
+             else:
+                 x = x + x_skip
+             x = self.conv1(
+                 x,
+                 ws[1].clone(),
+                 fused_modconv=fused_modconv,
+                 gain=np.sqrt(0.5),
+                 **layer_kwargs,
+             )
+             x = y.add_(x)
+         else:
+             x = self.conv0(
+                 x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs
+             )
+             if len(self.ffc_skip) > 0:
+                 mask = F.interpolate(
+                     mask,
+                     size=x_skip.shape[2:],
+                 )
+                 z = x + x_skip
+                 for fres in self.ffc_skip:
+                     z = fres(z, mask)
+                 x = x + z
+             else:
+                 x = x + x_skip
+             x = self.conv1(
+                 x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs
+             )
+         # ToRGB.
+         if img is not None:
+             img = upsample2d(img, self.resample_filter)
+         if self.is_last or self.architecture == "skip":
+             y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
+             y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+             img = img.add_(y) if img is not None else y
+
+         x = x.to(dtype=dtype)
+         assert x.dtype == dtype
+         assert img is None or img.dtype == torch.float32
+         return x, img
+
+
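Note: inside SynthesisBlock.forward the inpainting mask is resized to the spatial size of the encoder skip feature before the FFC skip layers consume it. A self-contained sketch of just that step (sizes are illustrative):

    import torch
    import torch.nn.functional as F

    mask = torch.zeros(1, 1, 512, 512)
    mask[:, :, 128:384, 128:384] = 1.0    # hole to be filled
    x_skip = torch.randn(1, 512, 32, 32)  # encoder feature at this block

    mask_small = F.interpolate(mask, size=x_skip.shape[2:])
    print(mask_small.shape)               # torch.Size([1, 1, 32, 32])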
+ class SynthesisNetwork(torch.nn.Module):
+     def __init__(
+         self,
+         w_dim,  # Intermediate latent (W) dimensionality.
+         z_dim,  # Output Latent (Z) dimensionality.
+         img_resolution,  # Output image resolution.
+         img_channels,  # Number of color channels.
+         channel_base=16384,  # Overall multiplier for the number of channels.
+         channel_max=512,  # Maximum number of channels in any layer.
+         num_fp16_res=0,  # Use FP16 for the N highest resolutions.
+         **block_kwargs,  # Arguments for SynthesisBlock.
+     ):
+         assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
+         super().__init__()
+         self.w_dim = w_dim
+         self.img_resolution = img_resolution
+         self.img_resolution_log2 = int(np.log2(img_resolution))
+         self.img_channels = img_channels
+         self.block_resolutions = [
+             2**i for i in range(3, self.img_resolution_log2 + 1)
+         ]
+         channels_dict = {
+             res: min(channel_base // res, channel_max) for res in self.block_resolutions
+         }
+         fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+         self.foreword = SynthesisForeword(
+             img_channels=img_channels,
+             in_channels=min(channel_base // 4, channel_max),
+             z_dim=z_dim * 2,
+             resolution=4,
+         )
+
+         self.num_ws = self.img_resolution_log2 * 2 - 2
+         for res in self.block_resolutions:
+             if res // 2 in channels_dict.keys():
+                 in_channels = channels_dict[res // 2] if res > 4 else 0
+             else:
+                 in_channels = min(channel_base // (res // 2), channel_max)
+             out_channels = channels_dict[res]
+             use_fp16 = res >= fp16_resolution
+             use_fp16 = False
+             is_last = res == self.img_resolution
+             block = SynthesisBlock(
+                 in_channels,
+                 out_channels,
+                 w_dim=w_dim,
+                 resolution=res,
+                 img_channels=img_channels,
+                 is_last=is_last,
+                 use_fp16=use_fp16,
+                 **block_kwargs,
+             )
+             setattr(self, f"b{res}", block)
+
+     def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
+         img = None
+
+         x, img = self.foreword(x_global, ws, feats, img)
+
+         for res in self.block_resolutions:
+             block = getattr(self, f"b{res}")
+             mod_vector0 = []
+             mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5])
+             mod_vector0.append(x_global.clone())
+             mod_vector0 = torch.cat(mod_vector0, dim=1)
+
+             mod_vector1 = []
+             mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4])
+             mod_vector1.append(x_global.clone())
+             mod_vector1 = torch.cat(mod_vector1, dim=1)
+
+             mod_vector_rgb = []
+             mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3])
+             mod_vector_rgb.append(x_global.clone())
+             mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1)
+             x, img = block(
+                 x,
+                 mask,
+                 feats,
+                 img,
+                 (mod_vector0, mod_vector1, mod_vector_rgb),
+                 fname=fname,
+                 **block_kwargs,
+             )
+         return img
+
+
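Note: the style indexing in SynthesisNetwork.forward can be checked with plain arithmetic: the block at resolution res reads ws slices 2*log2(res)-5, -4 and -3 for conv0, conv1 and torgb, and num_ws = 2*log2(img_resolution) - 2. A quick sanity check (index 0 is presumably consumed by SynthesisForeword, which is defined earlier in this file):

    import math

    img_resolution = 512
    num_ws = int(math.log2(img_resolution)) * 2 - 2  # 16
    for res in [8, 16, 32, 64, 128, 256, 512]:
        k = int(math.log2(res))
        print(res, (2 * k - 5, 2 * k - 4, 2 * k - 3))
    # 8 -> (1, 2, 3), ..., 512 -> (13, 14, 15): the last index is num_ws - 1.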
+ class MappingNetwork(torch.nn.Module):
+     def __init__(
+         self,
+         z_dim,  # Input latent (Z) dimensionality, 0 = no latent.
+         c_dim,  # Conditioning label (C) dimensionality, 0 = no label.
+         w_dim,  # Intermediate latent (W) dimensionality.
+         num_ws,  # Number of intermediate latents to output, None = do not broadcast.
+         num_layers=8,  # Number of mapping layers.
+         embed_features=None,  # Label embedding dimensionality, None = same as w_dim.
+         layer_features=None,  # Number of intermediate features in the mapping layers, None = same as w_dim.
+         activation="lrelu",  # Activation function: 'relu', 'lrelu', etc.
+         lr_multiplier=0.01,  # Learning rate multiplier for the mapping layers.
+         w_avg_beta=0.995,  # Decay for tracking the moving average of W during training, None = do not track.
+     ):
+         super().__init__()
+         self.z_dim = z_dim
+         self.c_dim = c_dim
+         self.w_dim = w_dim
+         self.num_ws = num_ws
+         self.num_layers = num_layers
+         self.w_avg_beta = w_avg_beta
+
+         if embed_features is None:
+             embed_features = w_dim
+         if c_dim == 0:
+             embed_features = 0
+         if layer_features is None:
+             layer_features = w_dim
+         features_list = (
+             [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
+         )
+
+         if c_dim > 0:
+             self.embed = FullyConnectedLayer(c_dim, embed_features)
+         for idx in range(num_layers):
+             in_features = features_list[idx]
+             out_features = features_list[idx + 1]
+             layer = FullyConnectedLayer(
+                 in_features,
+                 out_features,
+                 activation=activation,
+                 lr_multiplier=lr_multiplier,
+             )
+             setattr(self, f"fc{idx}", layer)
+
+         if num_ws is not None and w_avg_beta is not None:
+             self.register_buffer("w_avg", torch.zeros([w_dim]))
+
+     def forward(
+         self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
+     ):
+         # Embed, normalize, and concat inputs.
+         x = None
+         with torch.autograd.profiler.record_function("input"):
+             if self.z_dim > 0:
+                 x = normalize_2nd_moment(z.to(torch.float32))
+             if self.c_dim > 0:
+                 y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
+                 x = torch.cat([x, y], dim=1) if x is not None else y
+
+         # Main layers.
+         for idx in range(self.num_layers):
+             layer = getattr(self, f"fc{idx}")
+             x = layer(x)
+
+         # Update moving average of W.
+         if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
+             with torch.autograd.profiler.record_function("update_w_avg"):
+                 self.w_avg.copy_(
+                     x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)
+                 )
+
+         # Broadcast.
+         if self.num_ws is not None:
+             with torch.autograd.profiler.record_function("broadcast"):
+                 x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+
+         # Apply truncation.
+         if truncation_psi != 1:
+             with torch.autograd.profiler.record_function("truncate"):
+                 assert self.w_avg_beta is not None
+                 if self.num_ws is None or truncation_cutoff is None:
+                     x = self.w_avg.lerp(x, truncation_psi)
+                 else:
+                     x[:, :truncation_cutoff] = self.w_avg.lerp(
+                         x[:, :truncation_cutoff], truncation_psi
+                     )
+         return x
+
+
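Note: the truncation trick at the end of MappingNetwork.forward is a lerp toward the tracked average latent. A minimal standalone sketch (w_avg is zero here only for illustration):

    import torch

    w_dim, num_ws = 512, 16
    w_avg = torch.zeros(w_dim)
    x = torch.randn(2, num_ws, w_dim)

    truncation_psi = 0.5
    x_trunc = w_avg.lerp(x, truncation_psi)              # w_avg + psi * (x - w_avg)
    print(float(x_trunc.abs().mean() / x.abs().mean()))  # roughly 0.5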
+ class Generator(torch.nn.Module):
+     def __init__(
+         self,
+         z_dim,  # Input latent (Z) dimensionality.
+         c_dim,  # Conditioning label (C) dimensionality.
+         w_dim,  # Intermediate latent (W) dimensionality.
+         img_resolution,  # Output resolution.
+         img_channels,  # Number of output color channels.
+         encoder_kwargs={},  # Arguments for EncoderNetwork.
+         mapping_kwargs={},  # Arguments for MappingNetwork.
+         synthesis_kwargs={},  # Arguments for SynthesisNetwork.
+     ):
+         super().__init__()
+         self.z_dim = z_dim
+         self.c_dim = c_dim
+         self.w_dim = w_dim
+         self.img_resolution = img_resolution
+         self.img_channels = img_channels
+         self.encoder = EncoderNetwork(
+             c_dim=c_dim,
+             z_dim=z_dim,
+             img_resolution=img_resolution,
+             img_channels=img_channels,
+             **encoder_kwargs,
+         )
+         self.synthesis = SynthesisNetwork(
+             z_dim=z_dim,
+             w_dim=w_dim,
+             img_resolution=img_resolution,
+             img_channels=img_channels,
+             **synthesis_kwargs,
+         )
+         self.num_ws = self.synthesis.num_ws
+         self.mapping = MappingNetwork(
+             z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs
+         )
+
+     def forward(
+         self,
+         img,
+         c,
+         fname=None,
+         truncation_psi=1,
+         truncation_cutoff=None,
+         **synthesis_kwargs,
+     ):
+         mask = img[:, -1].unsqueeze(1)
+         x_global, z, feats = self.encoder(img, c)
+         ws = self.mapping(
+             z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff
+         )
+         img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
+         return img
+
+
+ FCF_MODEL_URL = os.environ.get(
+     "FCF_MODEL_URL",
+     "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth",
+ )
+ FCF_MODEL_MD5 = os.environ.get("FCF_MODEL_MD5", "3323152bc01bf1c56fd8aba74435a211")
+
+
+ class FcF(InpaintModel):
+     name = "fcf"
+     min_size = 512
+     pad_mod = 512
+     pad_to_square = True
+
+     def init_model(self, device, **kwargs):
+         seed = 0
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+         kwargs = {
+             "channel_base": 1 * 32768,
+             "channel_max": 512,
+             "num_fp16_res": 4,
+             "conv_clamp": 256,
+         }
+         G = Generator(
+             z_dim=512,
+             c_dim=0,
+             w_dim=512,
+             img_resolution=512,
+             img_channels=3,
+             synthesis_kwargs=kwargs,
+             encoder_kwargs=kwargs,
+             mapping_kwargs={"num_layers": 2},
+         )
+         self.model = load_model(G, FCF_MODEL_URL, device, FCF_MODEL_MD5)
+         self.label = torch.zeros([1, self.model.c_dim], device=device)
+
+     @staticmethod
+     def is_downloaded() -> bool:
+         return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))
+
+     @torch.no_grad()
+     def __call__(self, image, mask, config: Config):
+         """
+         images: [H, W, C] RGB, not normalized
+         masks: [H, W]
+         return: BGR IMAGE
+         """
+         if image.shape[0] == 512 and image.shape[1] == 512:
+             return self._pad_forward(image, mask, config)
+
+         boxes = boxes_from_mask(mask)
+         crop_result = []
+         config.hd_strategy_crop_margin = 128
+         for box in boxes:
+             crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
+             origin_size = crop_image.shape[:2]
+             resize_image = resize_max_size(crop_image, size_limit=512)
+             resize_mask = resize_max_size(crop_mask, size_limit=512)
+             inpaint_result = self._pad_forward(resize_image, resize_mask, config)
+
+             # only paste masked area result
+             inpaint_result = cv2.resize(
+                 inpaint_result,
+                 (origin_size[1], origin_size[0]),
+                 interpolation=cv2.INTER_CUBIC,
+             )
+
+             original_pixel_indices = crop_mask < 127
+             inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][
+                 original_pixel_indices
+             ]
+
+             crop_result.append((inpaint_result, crop_box))
+
+         inpaint_result = image[:, :, ::-1]
+         for crop_image, crop_box in crop_result:
+             x1, y1, x2, y2 = crop_box
+             inpaint_result[y1:y2, x1:x2, :] = crop_image
+
+         return inpaint_result
+
+     def forward(self, image, mask, config: Config):
+         """Input images and output images have the same size
+         images: [H, W, C] RGB
+         masks: [H, W] mask area == 255
+         return: BGR IMAGE
+         """
+
+         image = norm_img(image)  # [0, 1]
+         image = image * 2 - 1  # [0, 1] -> [-1, 1]
+         mask = (mask > 120) * 255
+         mask = norm_img(mask)
+
+         image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+         mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
+
+         erased_img = image * (1 - mask)
+         input_image = torch.cat([0.5 - mask, erased_img], dim=1)
+
+         output = self.model(
+             input_image, self.label, truncation_psi=0.1, noise_mode="none"
+         )
+         output = (
+             (output.permute(0, 2, 3, 1) * 127.5 + 127.5)
+             .round()
+             .clamp(0, 255)
+             .to(torch.uint8)
+         )
+         output = output[0].cpu().numpy()
+         cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+         return cur_res
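Note: the input packing in FcF.forward is easy to verify in isolation: the image is mapped to [-1, 1], the mask is binarized, and the network sees a 4-channel concat of (0.5 - mask) and the erased image. A hedged sketch, assuming norm_img scales to [0, 1] in CHW order as its use above implies:

    import numpy as np
    import torch

    image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
    mask = np.zeros((512, 512), dtype=np.uint8)
    mask[100:200, 100:200] = 255

    img = image.astype(np.float32).transpose(2, 0, 1) / 255.0  # stand-in for norm_img
    img = img * 2 - 1
    m = ((mask > 120) * 255).astype(np.float32)[None] / 255.0

    img_t = torch.from_numpy(img).unsqueeze(0)
    m_t = torch.from_numpy(m).unsqueeze(0)
    net_in = torch.cat([0.5 - m_t, img_t * (1 - m_t)], dim=1)
    print(net_in.shape)  # torch.Size([1, 4, 512, 512])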
lama_cleaner/model/instruct_pix2pix.py ADDED
@@ -0,0 +1,83 @@
+ import PIL.Image
+ import cv2
+ import torch
+ from loguru import logger
+
+ from lama_cleaner.model.base import DiffusionInpaintModel
+ from lama_cleaner.model.utils import set_seed
+ from lama_cleaner.schema import Config
+
+
+ class InstructPix2Pix(DiffusionInpaintModel):
+     name = "instruct_pix2pix"
+     pad_mod = 8
+     min_size = 512
+
+     def init_model(self, device: torch.device, **kwargs):
+         from diffusers import StableDiffusionInstructPix2PixPipeline
+
+         fp16 = not kwargs.get('no_half', False)
+
+         model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
+         if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
+             logger.info("Disable Stable Diffusion Model NSFW checker")
+             model_kwargs.update(dict(
+                 safety_checker=None,
+                 feature_extractor=None,
+                 requires_safety_checker=False
+             ))
+
+         use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
+         torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
+         self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+             "timbrooks/instruct-pix2pix",
+             revision="fp16" if use_gpu and fp16 else "main",
+             torch_dtype=torch_dtype,
+             **model_kwargs
+         )
+
+         self.model.enable_attention_slicing()
+         if kwargs.get('enable_xformers', False):
+             self.model.enable_xformers_memory_efficient_attention()
+
+         if kwargs.get('cpu_offload', False) and use_gpu:
+             logger.info("Enable sequential cpu offload")
+             self.model.enable_sequential_cpu_offload(gpu_id=0)
+         else:
+             self.model = self.model.to(device)
+
+     def forward(self, image, mask, config: Config):
+         """Input image and output image have the same size
+         image: [H, W, C] RGB
+         mask: [H, W, 1] 255 means area to repaint
+         return: BGR IMAGE
+         edit = pipe(prompt, image=image, num_inference_steps=20, image_guidance_scale=1.5, guidance_scale=7).images[0]
+         """
+         output = self.model(
+             image=PIL.Image.fromarray(image),
+             prompt=config.prompt,
+             negative_prompt=config.negative_prompt,
+             num_inference_steps=config.p2p_steps,
+             image_guidance_scale=config.p2p_image_guidance_scale,
+             guidance_scale=config.p2p_guidance_scale,
+             output_type="np.array",
+             generator=torch.manual_seed(config.sd_seed)
+         ).images[0]
+
+         output = (output * 255).round().astype("uint8")
+         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+         return output
+
+     #
+     # def forward_post_process(self, result, image, mask, config):
+     #     if config.sd_match_histograms:
+     #         result = self._match_histograms(result, image[:, :, ::-1], mask)
+     #
+     #     if config.sd_mask_blur != 0:
+     #         k = 2 * config.sd_mask_blur + 1
+     #         mask = cv2.GaussianBlur(mask, (k, k), 0)
+     #     return result, image, mask
+
+     @staticmethod
+     def is_downloaded() -> bool:
+         # model will be downloaded when the app starts, and can't be switched in frontend settings
+         return True
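Note: forward() above only reads the prompt/seed and the three p2p_* fields from Config. An illustrative, untested construction (assuming the remaining Config fields have defaults; check lama_cleaner/schema.py):

    from lama_cleaner.schema import Config

    config = Config(
        prompt="make it snow",
        negative_prompt="",
        p2p_steps=20,
        p2p_image_guidance_scale=1.5,
        p2p_guidance_scale=7.0,
        sd_seed=42,
    )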
lama_cleaner/model/lama.py ADDED
@@ -0,0 +1,51 @@
+ import os
+
+ import cv2
+ import numpy as np
+ import torch
+
+ from lama_cleaner.helper import (
+     norm_img,
+     get_cache_path_by_url,
+     load_jit_model,
+ )
+ from lama_cleaner.model.base import InpaintModel
+ from lama_cleaner.schema import Config
+
+ LAMA_MODEL_URL = os.environ.get(
+     "LAMA_MODEL_URL",
+     "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+ )
+ LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")
+
+
+ class LaMa(InpaintModel):
+     name = "lama"
+     pad_mod = 8
+
+     def init_model(self, device, **kwargs):
+         self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()
+
+     @staticmethod
+     def is_downloaded() -> bool:
+         return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
+
+     def forward(self, image, mask, config: Config):
+         """Input image and output image have the same size
+         image: [H, W, C] RGB
+         mask: [H, W]
+         return: BGR IMAGE
+         """
+         image = norm_img(image)
+         mask = norm_img(mask)
+
+         mask = (mask > 0) * 1
+         image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+         mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
+
+         inpainted_image = self.model(image, mask)
+
+         cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
+         cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
+         cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
+         return cur_res
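Note: a minimal (hypothetical) call of the model above, assuming the InpaintModel base class can be constructed from a device as init_model suggests. config is accepted but unused by LaMa.forward, so None is passed here:

    import numpy as np
    import torch

    from lama_cleaner.model.lama import LaMa

    model = LaMa(torch.device("cpu"))
    image = np.zeros((256, 256, 3), dtype=np.uint8)  # RGB, HWC
    mask = np.zeros((256, 256), dtype=np.uint8)
    mask[64:128, 64:128] = 255                       # area to inpaint
    result_bgr = model.forward(image, mask, config=None)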
lama_cleaner/model/ldm.py ADDED
@@ -0,0 +1,333 @@
+ import os
+ from functools import wraps
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from lama_cleaner.helper import get_cache_path_by_url, load_jit_model, norm_img
+ from lama_cleaner.model.base import InpaintModel
+ from lama_cleaner.model.ddim_sampler import DDIMSampler
+ from lama_cleaner.model.plms_sampler import PLMSSampler
+ from lama_cleaner.model.utils import make_beta_schedule, timestep_embedding
+ from lama_cleaner.schema import Config, LDMSampler
+
+ # torch.manual_seed(42)
+
+
+ def conditional_autocast(func):
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         if torch.cuda.is_available():
+             with torch.cuda.amp.autocast():
+                 return func(*args, **kwargs)
+         else:
+             return func(*args, **kwargs)
+
+     return wrapper
+
+
+ LDM_ENCODE_MODEL_URL = os.environ.get(
+     "LDM_ENCODE_MODEL_URL",
+     "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
+ )
+ LDM_ENCODE_MODEL_MD5 = os.environ.get(
+     "LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296"
+ )
+
+ LDM_DECODE_MODEL_URL = os.environ.get(
+     "LDM_DECODE_MODEL_URL",
+     "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
+ )
+ LDM_DECODE_MODEL_MD5 = os.environ.get(
+     "LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c"
+ )
+
+ LDM_DIFFUSION_MODEL_URL = os.environ.get(
+     "LDM_DIFFUSION_MODEL_URL",
+     "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
+ )
+
+ LDM_DIFFUSION_MODEL_MD5 = os.environ.get(
+     "LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d"
+ )
+
+
+ class DDPM(nn.Module):
+     # classic DDPM with Gaussian diffusion, in image space
+     def __init__(
+         self,
+         device,
+         timesteps=1000,
+         beta_schedule="linear",
+         linear_start=0.0015,
+         linear_end=0.0205,
+         cosine_s=0.008,
+         original_elbo_weight=0.0,
+         v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+         l_simple_weight=1.0,
+         parameterization="eps",  # all assuming fixed variance schedules
+         use_positional_encodings=False,
+     ):
+         super().__init__()
+         self.device = device
+         self.parameterization = parameterization
+         self.use_positional_encodings = use_positional_encodings
+
+         self.v_posterior = v_posterior
+         self.original_elbo_weight = original_elbo_weight
+         self.l_simple_weight = l_simple_weight
+
+         self.register_schedule(
+             beta_schedule=beta_schedule,
+             timesteps=timesteps,
+             linear_start=linear_start,
+             linear_end=linear_end,
+             cosine_s=cosine_s,
+         )
+
+     def register_schedule(
+         self,
+         given_betas=None,
+         beta_schedule="linear",
+         timesteps=1000,
+         linear_start=1e-4,
+         linear_end=2e-2,
+         cosine_s=8e-3,
+     ):
+         betas = make_beta_schedule(
+             self.device,
+             beta_schedule,
+             timesteps,
+             linear_start=linear_start,
+             linear_end=linear_end,
+             cosine_s=cosine_s,
+         )
+         alphas = 1.0 - betas
+         alphas_cumprod = np.cumprod(alphas, axis=0)
+         alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+         (timesteps,) = betas.shape
+         self.num_timesteps = int(timesteps)
+         self.linear_start = linear_start
+         self.linear_end = linear_end
+         assert (
+             alphas_cumprod.shape[0] == self.num_timesteps
+         ), "alphas have to be defined for each timestep"
+
+         def to_torch(x):
+             return torch.tensor(x, dtype=torch.float32).to(self.device)
+
+         self.register_buffer("betas", to_torch(betas))
+         self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+         self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+         self.register_buffer(
+             "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+         )
+         self.register_buffer(
+             "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+         )
+         self.register_buffer(
+             "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+         )
+         self.register_buffer(
+             "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+         )
+
+         # calculations for posterior q(x_{t-1} | x_t, x_0)
+         posterior_variance = (1 - self.v_posterior) * betas * (
+             1.0 - alphas_cumprod_prev
+         ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
+         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+         self.register_buffer("posterior_variance", to_torch(posterior_variance))
+         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+         self.register_buffer(
+             "posterior_log_variance_clipped",
+             to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
+         )
+         self.register_buffer(
+             "posterior_mean_coef1",
+             to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
+         )
+         self.register_buffer(
+             "posterior_mean_coef2",
+             to_torch(
+                 (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+             ),
+         )
+
+         if self.parameterization == "eps":
+             lvlb_weights = self.betas**2 / (
+                 2
+                 * self.posterior_variance
+                 * to_torch(alphas)
+                 * (1 - self.alphas_cumprod)
+             )
+         elif self.parameterization == "x0":
+             lvlb_weights = (
+                 0.5
+                 * np.sqrt(torch.Tensor(alphas_cumprod))
+                 / (2.0 * 1 - torch.Tensor(alphas_cumprod))
+             )
+         else:
+             raise NotImplementedError("mu not supported")
+         # TODO how to choose this term
+         lvlb_weights[0] = lvlb_weights[1]
+         self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
+         assert not torch.isnan(self.lvlb_weights).all()
+
+
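Note: the schedule quantities registered above can be reproduced numerically. A worked sketch with the file's defaults, assuming make_beta_schedule's "linear" schedule is linear in sqrt(beta) as in the original latent-diffusion code:

    import numpy as np

    timesteps, linear_start, linear_end = 1000, 0.0015, 0.0205
    betas = np.linspace(linear_start**0.5, linear_end**0.5, timesteps) ** 2
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    print(alphas_cumprod[0], alphas_cumprod[-1])  # ~0.9985 ... ~1e-5 (near-total noise)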
+ class LatentDiffusion(DDPM):
+     def __init__(
+         self,
+         diffusion_model,
+         device,
+         cond_stage_key="image",
+         cond_stage_trainable=False,
+         concat_mode=True,
+         scale_factor=1.0,
+         scale_by_std=False,
+         *args,
+         **kwargs,
+     ):
+         self.num_timesteps_cond = 1
+         self.scale_by_std = scale_by_std
+         super().__init__(device, *args, **kwargs)
+         self.diffusion_model = diffusion_model
+         self.concat_mode = concat_mode
+         self.cond_stage_trainable = cond_stage_trainable
+         self.cond_stage_key = cond_stage_key
+         self.num_downs = 2
+         self.scale_factor = scale_factor
+
+     def make_cond_schedule(
+         self,
+     ):
+         self.cond_ids = torch.full(
+             size=(self.num_timesteps,),
+             fill_value=self.num_timesteps - 1,
+             dtype=torch.long,
+         )
+         ids = torch.round(
+             torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
+         ).long()
+         self.cond_ids[: self.num_timesteps_cond] = ids
+
+     def register_schedule(
+         self,
+         given_betas=None,
+         beta_schedule="linear",
+         timesteps=1000,
+         linear_start=1e-4,
+         linear_end=2e-2,
+         cosine_s=8e-3,
+     ):
+         super().register_schedule(
+             given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
+         )
+
+         self.shorten_cond_schedule = self.num_timesteps_cond > 1
+         if self.shorten_cond_schedule:
+             self.make_cond_schedule()
+
+     def apply_model(self, x_noisy, t, cond):
+         # x_recon = self.model(x_noisy, t, cond['c_concat'][0])  # cond['c_concat'][0].shape 1,4,128,128
+         t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
+         x_recon = self.diffusion_model(x_noisy, t_emb, cond)
+         return x_recon
+
+
+ class LDM(InpaintModel):
+     name = "ldm"
+     pad_mod = 32
+
+     def __init__(self, device, fp16: bool = True, **kwargs):
+         self.fp16 = fp16
+         super().__init__(device)
+         self.device = device
+
+     def init_model(self, device, **kwargs):
+         self.diffusion_model = load_jit_model(
+             LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5
+         )
+         self.cond_stage_model_decode = load_jit_model(
+             LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5
+         )
+         self.cond_stage_model_encode = load_jit_model(
+             LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5
+         )
+         if self.fp16 and "cuda" in str(device):
+             self.diffusion_model = self.diffusion_model.half()
+             self.cond_stage_model_decode = self.cond_stage_model_decode.half()
+             self.cond_stage_model_encode = self.cond_stage_model_encode.half()
+
+         self.model = LatentDiffusion(self.diffusion_model, device)
+
+     @staticmethod
+     def is_downloaded() -> bool:
+         model_paths = [
+             get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
+             get_cache_path_by_url(LDM_DECODE_MODEL_URL),
+             get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
+         ]
+         return all([os.path.exists(it) for it in model_paths])
+
+     @conditional_autocast
+     def forward(self, image, mask, config: Config):
+         """
+         image: [H, W, C] RGB
+         mask: [H, W, 1]
+         return: BGR IMAGE
+         """
+         # image [1,3,512,512] float32
+         # mask: [1,1,512,512] float32
+         # masked_image: [1,3,512,512] float32
+         if config.ldm_sampler == LDMSampler.ddim:
+             sampler = DDIMSampler(self.model)
+         elif config.ldm_sampler == LDMSampler.plms:
+             sampler = PLMSSampler(self.model)
+         else:
+             raise ValueError()
+
+         steps = config.ldm_steps
+         image = norm_img(image)
+         mask = norm_img(mask)
+
+         mask[mask < 0.5] = 0
+         mask[mask >= 0.5] = 1
+
+         image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+         mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
+         masked_image = (1 - mask) * image
+
+         mask = self._norm(mask)
+         masked_image = self._norm(masked_image)
+
+         c = self.cond_stage_model_encode(masked_image)
+         torch.cuda.empty_cache()
+
+         cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:])  # 1,1,128,128
+         c = torch.cat((c, cc), dim=1)  # 1,4,128,128
+
+         shape = (c.shape[1] - 1,) + c.shape[2:]
+         samples_ddim = sampler.sample(
+             steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
+         )
+         torch.cuda.empty_cache()
+         x_samples_ddim = self.cond_stage_model_decode(
+             samples_ddim
+         )  # samples_ddim: 1, 3, 128, 128 float32
+         torch.cuda.empty_cache()
+
+         # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+         # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
+         inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+         # inpainted = (1 - mask) * image + mask * predicted_image
+         inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
+         inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
+         return inpainted_image
+
+     def _norm(self, tensor):
+         return tensor * 2.0 - 1.0
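Note: the conditioning assembled in LDM.forward is just the encoded masked image with the downscaled mask appended as a fourth channel; the latent sampled by DDIM/PLMS then has one channel fewer. A shape walk-through matching the inline comments above:

    import torch
    import torch.nn.functional as F

    c = torch.randn(1, 3, 128, 128)              # cond_stage_model_encode output
    mask = torch.rand(1, 1, 512, 512)
    cc = F.interpolate(mask, size=c.shape[-2:])  # 1, 1, 128, 128
    c = torch.cat((c, cc), dim=1)                # 1, 4, 128, 128
    shape = (c.shape[1] - 1,) + c.shape[2:]
    print(shape)                                 # (3, 128, 128)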
lama_cleaner/model/manga.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import random
+ import time
+
+ import cv2
+ import numpy as np
+ import torch
+ from loguru import logger
+
+ from lama_cleaner.helper import get_cache_path_by_url, load_jit_model
+ from lama_cleaner.model.base import InpaintModel
+ from lama_cleaner.schema import Config
+
+
+ MANGA_INPAINTOR_MODEL_URL = os.environ.get(
+     "MANGA_INPAINTOR_MODEL_URL",
+     "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit",
+ )
+ MANGA_INPAINTOR_MODEL_MD5 = os.environ.get(
+     "MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c"
+ )
+
+ MANGA_LINE_MODEL_URL = os.environ.get(
+     "MANGA_LINE_MODEL_URL",
+     "https://github.com/Sanster/models/releases/download/manga/erika.jit",
+ )
+ MANGA_LINE_MODEL_MD5 = os.environ.get(
+     "MANGA_LINE_MODEL_MD5", "0c926d5a4af8450b0d00bc5b9a095644"
+ )
+
+
+ class Manga(InpaintModel):
+     name = "manga"
+     pad_mod = 16
+
+     def init_model(self, device, **kwargs):
+         self.inpaintor_model = load_jit_model(
+             MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5
+         )
+         self.line_model = load_jit_model(
+             MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5
+         )
+         self.seed = 42
+
+     @staticmethod
+     def is_downloaded() -> bool:
+         model_paths = [
+             get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL),
+             get_cache_path_by_url(MANGA_LINE_MODEL_URL),
+         ]
+         return all([os.path.exists(it) for it in model_paths])
+
+     def forward(self, image, mask, config: Config):
+         """
+         image: [H, W, C] RGB
+         mask: [H, W, 1]
+         return: BGR IMAGE
+         """
+         seed = self.seed
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+
+         gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+         gray_img = torch.from_numpy(
+             gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)
+         ).to(self.device)
+         start = time.time()
+         lines = self.line_model(gray_img)
+         torch.cuda.empty_cache()
+         lines = torch.clamp(lines, 0, 255)
+         logger.info(f"erika_model time: {time.time() - start}")
+
+         mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device)
+         mask = mask.permute(0, 3, 1, 2)
+         mask = torch.where(mask > 0.5, 1.0, 0.0)
+         noise = torch.randn_like(mask)
+         ones = torch.ones_like(mask)
+
+         gray_img = gray_img / 255 * 2 - 1.0
+         lines = lines / 255 * 2 - 1.0
+
+         start = time.time()
+         inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise, ones)
+         logger.info(f"image_inpaintor_model time: {time.time() - start}")
+
+         cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
+         cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8)
+         cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR)
+         return cur_res
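Note: forward() above feeds the inpaintor values in [-1, 1] for both the grayscale page and the extracted line art. A tiny sketch of that normalization (random data stands in for a real page):

    import numpy as np
    import torch

    gray = torch.from_numpy(
        np.random.randint(0, 256, (1, 1, 64, 64)).astype(np.float32)
    )
    lines = torch.clamp(gray, 0, 255)  # stand-in for the line model output
    gray = gray / 255 * 2 - 1.0
    lines = lines / 255 * 2 - 1.0
    print(float(gray.min()) >= -1.0, float(gray.max()) <= 1.0)  # True True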
lama_cleaner/model/mat.py ADDED
@@ -0,0 +1,1935 @@
+ import os
+ import random
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+
+ from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img
+ from lama_cleaner.model.base import InpaintModel
+ from lama_cleaner.model.utils import (
+     setup_filter,
+     Conv2dLayer,
+     FullyConnectedLayer,
+     conv2d_resample,
+     bias_act,
+     upsample2d,
+     activation_funcs,
+     MinibatchStdLayer,
+     to_2tuple,
+     normalize_2nd_moment,
+     set_seed,
+ )
+ from lama_cleaner.schema import Config
+
+
+ class ModulatedConv2d(nn.Module):
+     def __init__(
+         self,
+         in_channels,  # Number of input channels.
+         out_channels,  # Number of output channels.
+         kernel_size,  # Width and height of the convolution kernel.
+         style_dim,  # Dimension of the style code.
+         demodulate=True,  # Perform demodulation.
+         up=1,  # Integer upsampling factor.
+         down=1,  # Integer downsampling factor.
+         resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
+         conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
+     ):
+         super().__init__()
+         self.demodulate = demodulate
+
+         self.weight = torch.nn.Parameter(
+             torch.randn([1, out_channels, in_channels, kernel_size, kernel_size])
+         )
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
+         self.padding = self.kernel_size // 2
+         self.up = up
+         self.down = down
+         self.register_buffer("resample_filter", setup_filter(resample_filter))
+         self.conv_clamp = conv_clamp
+
+         self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1)
+
+     def forward(self, x, style):
+         batch, in_channels, height, width = x.shape
+         style = self.affine(style).view(batch, 1, in_channels, 1, 1)
+         weight = self.weight * self.weight_gain * style
+
+         if self.demodulate:
+             decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
+             weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1)
+
+         weight = weight.view(
+             batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size
+         )
+         x = x.view(1, batch * in_channels, height, width)
+         x = conv2d_resample(
+             x=x,
+             w=weight,
+             f=self.resample_filter,
+             up=self.up,
+             down=self.down,
+             padding=self.padding,
+             groups=batch,
+         )
+         out = x.view(batch, self.out_channels, *x.shape[2:])
+
+         return out
+
+
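Note: the modulate/demodulate arithmetic in ModulatedConv2d.forward is the StyleGAN2 recipe and can be checked standalone: after demodulation each per-sample output filter has approximately unit L2 norm. Toy sizes below:

    import torch

    batch, cin, cout, k = 2, 8, 16, 3
    weight = torch.randn(1, cout, cin, k, k)
    style = torch.randn(batch, 1, cin, 1, 1)  # per-sample, per-input-channel scales

    w = weight * style                                   # modulate
    decoefs = (w.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt()
    w = w * decoefs.view(batch, cout, 1, 1, 1)           # demodulate
    print(w.pow(2).sum(dim=[2, 3, 4]).sqrt())            # ~1 everywhere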
+ class StyleConv(torch.nn.Module):
92
+ def __init__(
93
+ self,
94
+ in_channels, # Number of input channels.
95
+ out_channels, # Number of output channels.
96
+ style_dim, # Intermediate latent (W) dimensionality.
97
+ resolution, # Resolution of this layer.
98
+ kernel_size=3, # Convolution kernel size.
99
+ up=1, # Integer upsampling factor.
100
+ use_noise=False, # Enable noise input?
101
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
102
+ resample_filter=[
103
+ 1,
104
+ 3,
105
+ 3,
106
+ 1,
107
+ ], # Low-pass filter to apply when resampling activations.
108
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
109
+ demodulate=True, # perform demodulation
110
+ ):
111
+ super().__init__()
112
+
113
+ self.conv = ModulatedConv2d(
114
+ in_channels=in_channels,
115
+ out_channels=out_channels,
116
+ kernel_size=kernel_size,
117
+ style_dim=style_dim,
118
+ demodulate=demodulate,
119
+ up=up,
120
+ resample_filter=resample_filter,
121
+ conv_clamp=conv_clamp,
122
+ )
123
+
124
+ self.use_noise = use_noise
125
+ self.resolution = resolution
126
+ if use_noise:
127
+ self.register_buffer("noise_const", torch.randn([resolution, resolution]))
128
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
129
+
130
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
131
+ self.activation = activation
132
+ self.act_gain = activation_funcs[activation].def_gain
133
+ self.conv_clamp = conv_clamp
134
+
135
+ def forward(self, x, style, noise_mode="random", gain=1):
136
+ x = self.conv(x, style)
137
+
138
+ assert noise_mode in ["random", "const", "none"]
139
+
140
+ if self.use_noise:
141
+ if noise_mode == "random":
142
+ xh, xw = x.size()[-2:]
143
+ noise = (
144
+ torch.randn([x.shape[0], 1, xh, xw], device=x.device)
145
+ * self.noise_strength
146
+ )
147
+ if noise_mode == "const":
148
+ noise = self.noise_const * self.noise_strength
149
+ x = x + noise
150
+
151
+ act_gain = self.act_gain * gain
152
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
153
+ out = bias_act(
154
+ x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
155
+ )
156
+
157
+ return out
158
+
159
+
160
+ class ToRGB(torch.nn.Module):
161
+ def __init__(
162
+ self,
163
+ in_channels,
164
+ out_channels,
165
+ style_dim,
166
+ kernel_size=1,
167
+ resample_filter=[1, 3, 3, 1],
168
+ conv_clamp=None,
169
+ demodulate=False,
170
+ ):
171
+ super().__init__()
172
+
173
+ self.conv = ModulatedConv2d(
174
+ in_channels=in_channels,
175
+ out_channels=out_channels,
176
+ kernel_size=kernel_size,
177
+ style_dim=style_dim,
178
+ demodulate=demodulate,
179
+ resample_filter=resample_filter,
180
+ conv_clamp=conv_clamp,
181
+ )
182
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
183
+ self.register_buffer("resample_filter", setup_filter(resample_filter))
184
+ self.conv_clamp = conv_clamp
185
+
186
+ def forward(self, x, style, skip=None):
187
+ x = self.conv(x, style)
188
+ out = bias_act(x, self.bias, clamp=self.conv_clamp)
189
+
190
+ if skip is not None:
191
+ if skip.shape != out.shape:
192
+ skip = upsample2d(skip, self.resample_filter)
193
+ out = out + skip
194
+
195
+ return out
196
+
197
+
198
+ def get_style_code(a, b):
199
+ return torch.cat([a, b], dim=1)
200
+
201
+
202
+ class DecBlockFirst(nn.Module):
203
+ def __init__(
204
+ self,
205
+ in_channels,
206
+ out_channels,
207
+ activation,
208
+ style_dim,
209
+ use_noise,
210
+ demodulate,
211
+ img_channels,
212
+ ):
213
+ super().__init__()
214
+ self.fc = FullyConnectedLayer(
215
+ in_features=in_channels * 2,
216
+ out_features=in_channels * 4 ** 2,
217
+ activation=activation,
218
+ )
219
+ self.conv = StyleConv(
220
+ in_channels=in_channels,
221
+ out_channels=out_channels,
222
+ style_dim=style_dim,
223
+ resolution=4,
224
+ kernel_size=3,
225
+ use_noise=use_noise,
226
+ activation=activation,
227
+ demodulate=demodulate,
228
+ )
229
+ self.toRGB = ToRGB(
230
+ in_channels=out_channels,
231
+ out_channels=img_channels,
232
+ style_dim=style_dim,
233
+ kernel_size=1,
234
+ demodulate=False,
235
+ )
236
+
237
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
238
+ x = self.fc(x).view(x.shape[0], -1, 4, 4)
239
+ x = x + E_features[2]
240
+ style = get_style_code(ws[:, 0], gs)
241
+ x = self.conv(x, style, noise_mode=noise_mode)
242
+ style = get_style_code(ws[:, 1], gs)
243
+ img = self.toRGB(x, style, skip=None)
244
+
245
+ return x, img
246
+
247
+
248
+ class DecBlockFirstV2(nn.Module):
249
+ def __init__(
250
+ self,
251
+ in_channels,
252
+ out_channels,
253
+ activation,
254
+ style_dim,
255
+ use_noise,
256
+ demodulate,
257
+ img_channels,
258
+ ):
259
+ super().__init__()
260
+ self.conv0 = Conv2dLayer(
261
+ in_channels=in_channels,
262
+ out_channels=in_channels,
263
+ kernel_size=3,
264
+ activation=activation,
265
+ )
266
+ self.conv1 = StyleConv(
267
+ in_channels=in_channels,
268
+ out_channels=out_channels,
269
+ style_dim=style_dim,
270
+ resolution=4,
271
+ kernel_size=3,
272
+ use_noise=use_noise,
273
+ activation=activation,
274
+ demodulate=demodulate,
275
+ )
276
+ self.toRGB = ToRGB(
277
+ in_channels=out_channels,
278
+ out_channels=img_channels,
279
+ style_dim=style_dim,
280
+ kernel_size=1,
281
+ demodulate=False,
282
+ )
283
+
284
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
285
+ # x = self.fc(x).view(x.shape[0], -1, 4, 4)
286
+ x = self.conv0(x)
287
+ x = x + E_features[2]
288
+ style = get_style_code(ws[:, 0], gs)
289
+ x = self.conv1(x, style, noise_mode=noise_mode)
290
+ style = get_style_code(ws[:, 1], gs)
291
+ img = self.toRGB(x, style, skip=None)
292
+
293
+ return x, img
294
+
295
+
296
+ class DecBlock(nn.Module):
297
+ def __init__(
298
+ self,
299
+ res,
300
+ in_channels,
301
+ out_channels,
302
+ activation,
303
+ style_dim,
304
+ use_noise,
305
+ demodulate,
306
+ img_channels,
307
+ ): # res = 2, ..., resolution_log2
308
+ super().__init__()
309
+ self.res = res
310
+
311
+ self.conv0 = StyleConv(
312
+ in_channels=in_channels,
313
+ out_channels=out_channels,
314
+ style_dim=style_dim,
315
+ resolution=2 ** res,
316
+ kernel_size=3,
317
+ up=2,
318
+ use_noise=use_noise,
319
+ activation=activation,
320
+ demodulate=demodulate,
321
+ )
322
+ self.conv1 = StyleConv(
323
+ in_channels=out_channels,
324
+ out_channels=out_channels,
325
+ style_dim=style_dim,
326
+ resolution=2 ** res,
327
+ kernel_size=3,
328
+ use_noise=use_noise,
329
+ activation=activation,
330
+ demodulate=demodulate,
331
+ )
332
+ self.toRGB = ToRGB(
333
+ in_channels=out_channels,
334
+ out_channels=img_channels,
335
+ style_dim=style_dim,
336
+ kernel_size=1,
337
+ demodulate=False,
338
+ )
339
+
340
+ def forward(self, x, img, ws, gs, E_features, noise_mode="random"):
341
+ style = get_style_code(ws[:, self.res * 2 - 5], gs)
342
+ x = self.conv0(x, style, noise_mode=noise_mode)
343
+ x = x + E_features[self.res]
344
+ style = get_style_code(ws[:, self.res * 2 - 4], gs)
345
+ x = self.conv1(x, style, noise_mode=noise_mode)
346
+ style = get_style_code(ws[:, self.res * 2 - 3], gs)
347
+ img = self.toRGB(x, style, skip=img)
348
+
349
+ return x, img
350
+
351
+
352
+ class MappingNet(torch.nn.Module):
353
+ def __init__(
354
+ self,
355
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
356
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
357
+ w_dim, # Intermediate latent (W) dimensionality.
358
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
359
+ num_layers=8, # Number of mapping layers.
360
+ embed_features=None, # Label embedding dimensionality, None = same as w_dim.
361
+ layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
362
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
363
+ lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
364
+ w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
365
+ torch_dtype=torch.float32,
366
+ ):
367
+ super().__init__()
368
+ self.z_dim = z_dim
369
+ self.c_dim = c_dim
370
+ self.w_dim = w_dim
371
+ self.num_ws = num_ws
372
+ self.num_layers = num_layers
373
+ self.w_avg_beta = w_avg_beta
374
+ self.torch_dtype = torch_dtype
375
+
376
+ if embed_features is None:
377
+ embed_features = w_dim
378
+ if c_dim == 0:
379
+ embed_features = 0
380
+ if layer_features is None:
381
+ layer_features = w_dim
382
+ features_list = (
383
+ [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
384
+ )
385
+
386
+ if c_dim > 0:
387
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
388
+ for idx in range(num_layers):
389
+ in_features = features_list[idx]
390
+ out_features = features_list[idx + 1]
391
+ layer = FullyConnectedLayer(
392
+ in_features,
393
+ out_features,
394
+ activation=activation,
395
+ lr_multiplier=lr_multiplier,
396
+ )
397
+ setattr(self, f"fc{idx}", layer)
398
+
399
+ if num_ws is not None and w_avg_beta is not None:
400
+ self.register_buffer("w_avg", torch.zeros([w_dim]))
401
+
402
+ def forward(
403
+ self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
404
+ ):
405
+ # Embed, normalize, and concat inputs.
406
+ x = None
407
+ if self.z_dim > 0:
408
+ x = normalize_2nd_moment(z)
409
+ if self.c_dim > 0:
410
+ y = normalize_2nd_moment(self.embed(c))
411
+ x = torch.cat([x, y], dim=1) if x is not None else y
412
+
413
+ # Main layers.
414
+ for idx in range(self.num_layers):
415
+ layer = getattr(self, f"fc{idx}")
416
+ x = layer(x)
417
+
418
+ # Update moving average of W.
419
+ if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
420
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
421
+
422
+ # Broadcast.
423
+ if self.num_ws is not None:
424
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
425
+
426
+ # Apply truncation.
427
+ if truncation_psi != 1:
428
+ assert self.w_avg_beta is not None
429
+ if self.num_ws is None or truncation_cutoff is None:
430
+ x = self.w_avg.lerp(x, truncation_psi)
431
+ else:
432
+ x[:, :truncation_cutoff] = self.w_avg.lerp(
433
+ x[:, :truncation_cutoff], truncation_psi
434
+ )
435
+
436
+ return x
437
+
438
+
439
+ class DisFromRGB(nn.Module):
440
+ def __init__(
441
+ self, in_channels, out_channels, activation
442
+ ): # res = 2, ..., resolution_log2
443
+ super().__init__()
444
+ self.conv = Conv2dLayer(
445
+ in_channels=in_channels,
446
+ out_channels=out_channels,
447
+ kernel_size=1,
448
+ activation=activation,
449
+ )
450
+
451
+ def forward(self, x):
452
+ return self.conv(x)
453
+
454
+
455
+ class DisBlock(nn.Module):
456
+ def __init__(
457
+ self, in_channels, out_channels, activation
458
+ ): # res = 2, ..., resolution_log2
459
+ super().__init__()
460
+ self.conv0 = Conv2dLayer(
461
+ in_channels=in_channels,
462
+ out_channels=in_channels,
463
+ kernel_size=3,
464
+ activation=activation,
465
+ )
466
+ self.conv1 = Conv2dLayer(
467
+ in_channels=in_channels,
468
+ out_channels=out_channels,
469
+ kernel_size=3,
470
+ down=2,
471
+ activation=activation,
472
+ )
473
+ self.skip = Conv2dLayer(
474
+ in_channels=in_channels,
475
+ out_channels=out_channels,
476
+ kernel_size=1,
477
+ down=2,
478
+ bias=False,
479
+ )
480
+
481
+ def forward(self, x):
482
+ skip = self.skip(x, gain=np.sqrt(0.5))
483
+ x = self.conv0(x)
484
+ x = self.conv1(x, gain=np.sqrt(0.5))
485
+ out = skip + x
486
+
487
+ return out
488
+
489
+
490
+ class Discriminator(torch.nn.Module):
491
+ def __init__(
492
+ self,
493
+ c_dim, # Conditioning label (C) dimensionality.
494
+ img_resolution, # Input resolution.
495
+ img_channels, # Number of input color channels.
496
+ channel_base=32768, # Overall multiplier for the number of channels.
497
+ channel_max=512, # Maximum number of channels in any layer.
498
+ channel_decay=1,
499
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
500
+ activation="lrelu",
501
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
502
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
503
+ ):
504
+ super().__init__()
505
+ self.c_dim = c_dim
506
+ self.img_resolution = img_resolution
507
+ self.img_channels = img_channels
508
+
509
+ resolution_log2 = int(np.log2(img_resolution))
510
+ assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
511
+ self.resolution_log2 = resolution_log2
512
+
513
+ def nf(stage):
514
+ return np.clip(
515
+ int(channel_base / 2 ** (stage * channel_decay)), 1, channel_max
516
+ )
517
+
518
+ if cmap_dim is None:
519
+ cmap_dim = nf(2)
520
+ if c_dim == 0:
521
+ cmap_dim = 0
522
+ self.cmap_dim = cmap_dim
523
+
524
+ if c_dim > 0:
525
+ self.mapping = MappingNet(
526
+ z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None
527
+ )
528
+
529
+ Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
530
+ for res in range(resolution_log2, 2, -1):
531
+ Dis.append(DisBlock(nf(res), nf(res - 1), activation))
532
+
533
+ if mbstd_num_channels > 0:
534
+ Dis.append(
535
+ MinibatchStdLayer(
536
+ group_size=mbstd_group_size, num_channels=mbstd_num_channels
537
+ )
538
+ )
539
+ Dis.append(
540
+ Conv2dLayer(
541
+ nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation
542
+ )
543
+ )
544
+ self.Dis = nn.Sequential(*Dis)
545
+
546
+ self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
547
+ self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
548
+
549
+ def forward(self, images_in, masks_in, c):
550
+ x = torch.cat([masks_in - 0.5, images_in], dim=1)
551
+ x = self.Dis(x)
552
+ x = self.fc1(self.fc0(x.flatten(start_dim=1)))
553
+
554
+ if self.c_dim > 0:
555
+ cmap = self.mapping(None, c)
556
+
557
+ if self.cmap_dim > 0:
558
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
559
+
560
+ return x
561
+
562
+
563
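The conditional branch of Discriminator.forward above is a projection discriminator: the final score is the dot product of the image features with the mapped label embedding, scaled by 1/sqrt(cmap_dim). A small hedged sketch with illustrative shapes:

import numpy as np
import torch

cmap_dim = 512
feats = torch.randn(8, cmap_dim)   # output of self.fc1
cmap = torch.randn(8, cmap_dim)    # output of self.mapping(None, c)
score = (feats * cmap).sum(dim=1, keepdim=True) / np.sqrt(cmap_dim)  # (8, 1)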
+ def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
+ # Fixed per-resolution channel counts; the keyword arguments are unused
+ # here and kept only for signature compatibility with the method above.
+ NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
+ return NF[2 ** stage]
566
+
567
+
568
+ class Mlp(nn.Module):
569
+ def __init__(
570
+ self,
571
+ in_features,
572
+ hidden_features=None,
573
+ out_features=None,
574
+ act_layer=nn.GELU,
575
+ drop=0.0,
576
+ ):
577
+ super().__init__()
578
+ out_features = out_features or in_features
579
+ hidden_features = hidden_features or in_features
580
+ self.fc1 = FullyConnectedLayer(
581
+ in_features=in_features, out_features=hidden_features, activation="lrelu"
582
+ )
583
+ self.fc2 = FullyConnectedLayer(
584
+ in_features=hidden_features, out_features=out_features
585
+ )
586
+
587
+ def forward(self, x):
588
+ x = self.fc1(x)
589
+ x = self.fc2(x)
590
+ return x
591
+
592
+
593
+ def window_partition(x, window_size):
594
+ """
595
+ Args:
596
+ x: (B, H, W, C)
597
+ window_size (int): window size
598
+ Returns:
599
+ windows: (num_windows*B, window_size, window_size, C)
600
+ """
601
+ B, H, W, C = x.shape
602
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
603
+ windows = (
604
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
605
+ )
606
+ return windows
607
+
608
+
609
+ def window_reverse(windows, window_size: int, H: int, W: int):
610
+ """
611
+ Args:
612
+ windows: (num_windows*B, window_size, window_size, C)
613
+ window_size (int): Window size
614
+ H (int): Height of image
615
+ W (int): Width of image
616
+ Returns:
617
+ x: (B, H, W, C)
618
+ """
619
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
620
+ # B = windows.shape[0] / (H * W / window_size / window_size)
621
+ x = windows.view(
622
+ B, H // window_size, W // window_size, window_size, window_size, -1
623
+ )
624
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
625
+ return x
626
+
627
+
628
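window_partition and window_reverse are exact inverses whenever H and W are multiples of the window size, which the surrounding code guarantees. A quick self-contained round-trip check (the tensor sizes are arbitrary):

import torch

x = torch.randn(2, 16, 16, 4)                        # (B, H, W, C)
wins = window_partition(x, window_size=8)            # (2 * 4 windows, 8, 8, 4)
y = window_reverse(wins, window_size=8, H=16, W=16)  # back to (2, 16, 16, 4)
assert torch.equal(x, y)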
+ class Conv2dLayerPartial(nn.Module):
629
+ def __init__(
630
+ self,
631
+ in_channels, # Number of input channels.
632
+ out_channels, # Number of output channels.
633
+ kernel_size, # Width and height of the convolution kernel.
634
+ bias=True, # Apply additive bias before the activation function?
635
+ activation="linear", # Activation function: 'relu', 'lrelu', etc.
636
+ up=1, # Integer upsampling factor.
637
+ down=1, # Integer downsampling factor.
638
+ resample_filter=[
639
+ 1,
640
+ 3,
641
+ 3,
642
+ 1,
643
+ ], # Low-pass filter to apply when resampling activations.
644
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
645
+ trainable=True, # Update the weights of this layer during training?
646
+ ):
647
+ super().__init__()
648
+ self.conv = Conv2dLayer(
649
+ in_channels,
650
+ out_channels,
651
+ kernel_size,
652
+ bias,
653
+ activation,
654
+ up,
655
+ down,
656
+ resample_filter,
657
+ conv_clamp,
658
+ trainable,
659
+ )
660
+
661
+ self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
662
+ self.slide_winsize = kernel_size ** 2
663
+ self.stride = down
664
+ self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
665
+
666
+ def forward(self, x, mask=None):
667
+ if mask is not None:
668
+ with torch.no_grad():
669
+ if self.weight_maskUpdater.type() != x.type():
670
+ self.weight_maskUpdater = self.weight_maskUpdater.to(x)
671
+ update_mask = F.conv2d(
672
+ mask,
673
+ self.weight_maskUpdater,
674
+ bias=None,
675
+ stride=self.stride,
676
+ padding=self.padding,
677
+ )
678
+ mask_ratio = self.slide_winsize / (update_mask.to(torch.float32) + 1e-8)
679
+ update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1
680
+ mask_ratio = torch.mul(mask_ratio, update_mask).to(x.dtype)
681
+ x = self.conv(x)
682
+ x = torch.mul(x, mask_ratio)
683
+ return x, update_mask
684
+ else:
685
+ x = self.conv(x)
686
+ return x, None
687
+
688
+
689
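The mask-update rule in Conv2dLayerPartial.forward can be reproduced in isolation: a box filter counts valid pixels per window, outputs are rescaled by window_size**2 / valid_count, and the mask is re-binarized. A hedged standalone sketch:

import torch
import torch.nn.functional as F

mask = torch.zeros(1, 1, 6, 6)
mask[..., :3, :] = 1.0                              # top half of the image is valid
weight = torch.ones(1, 1, 3, 3)                     # 3x3 box filter
cover = F.conv2d(mask, weight, padding=1)           # valid pixels per window
ratio = (9.0 / (cover + 1e-8)) * cover.clamp(0, 1)  # rescale factor, 0 where empty
new_mask = cover.clamp(0, 1)                        # updated binary mask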
+ class WindowAttention(nn.Module):
690
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
691
+ It supports both shifted and non-shifted windows.
692
+ Args:
693
+ dim (int): Number of input channels.
694
+ window_size (tuple[int]): The height and width of the window.
695
+ num_heads (int): Number of attention heads.
696
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
697
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
698
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
699
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
700
+ """
701
+
702
+ def __init__(
703
+ self,
704
+ dim,
705
+ window_size,
706
+ num_heads,
707
+ down_ratio=1,
708
+ qkv_bias=True,
709
+ qk_scale=None,
710
+ attn_drop=0.0,
711
+ proj_drop=0.0,
712
+ ):
713
+ super().__init__()
714
+ self.dim = dim
715
+ self.window_size = window_size # Wh, Ww
716
+ self.num_heads = num_heads
717
+ head_dim = dim // num_heads
718
+ self.scale = qk_scale or head_dim ** -0.5
719
+
720
+ self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
721
+ self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
722
+ self.v = FullyConnectedLayer(in_features=dim, out_features=dim)
723
+ self.proj = FullyConnectedLayer(in_features=dim, out_features=dim)
724
+
725
+ self.softmax = nn.Softmax(dim=-1)
726
+
727
+ def forward(self, x, mask_windows=None, mask=None):
728
+ """
729
+ Args:
730
+ x: input features with shape of (num_windows*B, N, C)
731
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
732
+ """
733
+ B_, N, C = x.shape
734
+ norm_x = F.normalize(x, p=2.0, dim=-1, eps=torch.finfo(x.dtype).eps)
735
+ q = (
736
+ self.q(norm_x)
737
+ .reshape(B_, N, self.num_heads, C // self.num_heads)
738
+ .permute(0, 2, 1, 3)
739
+ )
740
+ k = (
741
+ self.k(norm_x)
742
+ .view(B_, -1, self.num_heads, C // self.num_heads)
743
+ .permute(0, 2, 3, 1)
744
+ )
745
+ v = (
746
+ self.v(x)
747
+ .view(B_, -1, self.num_heads, C // self.num_heads)
748
+ .permute(0, 2, 1, 3)
749
+ )
750
+
751
+ attn = (q @ k) * self.scale
752
+
753
+ if mask is not None:
754
+ nW = mask.shape[0]
755
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
756
+ 1
757
+ ).unsqueeze(0)
758
+ attn = attn.view(-1, self.num_heads, N, N)
759
+
760
+ if mask_windows is not None:
761
+ attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1)
762
+ attn = attn + attn_mask_windows.masked_fill(
763
+ attn_mask_windows == 0, float(-100.0)
764
+ ).masked_fill(attn_mask_windows == 1, float(0.0))
765
+ with torch.no_grad():
766
+ mask_windows = torch.clamp(
767
+ torch.sum(mask_windows, dim=1, keepdim=True), 0, 1
768
+ ).repeat(1, N, 1)
769
+
770
+ attn = self.softmax(attn)
771
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
772
+ x = self.proj(x)
773
+ return x, mask_windows
774
+
775
+
776
+ class SwinTransformerBlock(nn.Module):
777
+ r"""Swin Transformer Block.
778
+ Args:
779
+ dim (int): Number of input channels.
780
+ input_resolution (tuple[int]): Input resolution.
781
+ num_heads (int): Number of attention heads.
782
+ window_size (int): Window size.
783
+ shift_size (int): Shift size for SW-MSA.
784
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
785
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
786
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
787
+ drop (float, optional): Dropout rate. Default: 0.0
788
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
789
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
790
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
791
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
792
+ """
793
+
794
+ def __init__(
795
+ self,
796
+ dim,
797
+ input_resolution,
798
+ num_heads,
799
+ down_ratio=1,
800
+ window_size=7,
801
+ shift_size=0,
802
+ mlp_ratio=4.0,
803
+ qkv_bias=True,
804
+ qk_scale=None,
805
+ drop=0.0,
806
+ attn_drop=0.0,
807
+ drop_path=0.0,
808
+ act_layer=nn.GELU,
809
+ norm_layer=nn.LayerNorm,
810
+ ):
811
+ super().__init__()
812
+ self.dim = dim
813
+ self.input_resolution = input_resolution
814
+ self.num_heads = num_heads
815
+ self.window_size = window_size
816
+ self.shift_size = shift_size
817
+ self.mlp_ratio = mlp_ratio
818
+ if min(self.input_resolution) <= self.window_size:
819
+ # if window size is larger than input resolution, we don't partition windows
820
+ self.shift_size = 0
821
+ self.window_size = min(self.input_resolution)
822
+ assert (
823
+ 0 <= self.shift_size < self.window_size
824
+ ), "shift_size must in 0-window_size"
825
+
826
+ if self.shift_size > 0:
827
+ down_ratio = 1
828
+ self.attn = WindowAttention(
829
+ dim,
830
+ window_size=to_2tuple(self.window_size),
831
+ num_heads=num_heads,
832
+ down_ratio=down_ratio,
833
+ qkv_bias=qkv_bias,
834
+ qk_scale=qk_scale,
835
+ attn_drop=attn_drop,
836
+ proj_drop=drop,
837
+ )
838
+
839
+ self.fuse = FullyConnectedLayer(
840
+ in_features=dim * 2, out_features=dim, activation="lrelu"
841
+ )
842
+
843
+ mlp_hidden_dim = int(dim * mlp_ratio)
844
+ self.mlp = Mlp(
845
+ in_features=dim,
846
+ hidden_features=mlp_hidden_dim,
847
+ act_layer=act_layer,
848
+ drop=drop,
849
+ )
850
+
851
+ if self.shift_size > 0:
852
+ attn_mask = self.calculate_mask(self.input_resolution)
853
+ else:
854
+ attn_mask = None
855
+
856
+ self.register_buffer("attn_mask", attn_mask)
857
+
858
+ def calculate_mask(self, x_size):
859
+ # calculate attention mask for SW-MSA
860
+ H, W = x_size
861
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
862
+ h_slices = (
863
+ slice(0, -self.window_size),
864
+ slice(-self.window_size, -self.shift_size),
865
+ slice(-self.shift_size, None),
866
+ )
867
+ w_slices = (
868
+ slice(0, -self.window_size),
869
+ slice(-self.window_size, -self.shift_size),
870
+ slice(-self.shift_size, None),
871
+ )
872
+ cnt = 0
873
+ for h in h_slices:
874
+ for w in w_slices:
875
+ img_mask[:, h, w, :] = cnt
876
+ cnt += 1
877
+
878
+ mask_windows = window_partition(
879
+ img_mask, self.window_size
880
+ ) # nW, window_size, window_size, 1
881
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
882
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
883
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
884
+ attn_mask == 0, float(0.0)
885
+ )
886
+
887
+ return attn_mask
888
+
889
+ def forward(self, x, x_size, mask=None):
890
+ # H, W = self.input_resolution
891
+ H, W = x_size
892
+ B, L, C = x.shape
893
+ # assert L == H * W, "input feature has wrong size"
894
+
895
+ shortcut = x
896
+ x = x.view(B, H, W, C)
897
+ if mask is not None:
898
+ mask = mask.view(B, H, W, 1)
899
+
900
+ # cyclic shift
901
+ if self.shift_size > 0:
902
+ shifted_x = torch.roll(
903
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
904
+ )
905
+ if mask is not None:
906
+ shifted_mask = torch.roll(
907
+ mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
908
+ )
909
+ else:
910
+ shifted_x = x
911
+ if mask is not None:
912
+ shifted_mask = mask
913
+
914
+ # partition windows
915
+ x_windows = window_partition(
916
+ shifted_x, self.window_size
917
+ ) # nW*B, window_size, window_size, C
918
+ x_windows = x_windows.view(
919
+ -1, self.window_size * self.window_size, C
920
+ ) # nW*B, window_size*window_size, C
921
+ if mask is not None:
922
+ mask_windows = window_partition(shifted_mask, self.window_size)
923
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1)
924
+ else:
925
+ mask_windows = None
926
+
927
+ # W-MSA/SW-MSA (reuse the precomputed mask at the training resolution; recompute it otherwise, so inference works on any image whose sides are multiples of the window size)
928
+ if self.input_resolution == x_size:
929
+ attn_windows, mask_windows = self.attn(
930
+ x_windows, mask_windows, mask=self.attn_mask
931
+ ) # nW*B, window_size*window_size, C
932
+ else:
933
+ attn_windows, mask_windows = self.attn(
934
+ x_windows,
935
+ mask_windows,
936
+ mask=self.calculate_mask(x_size).to(x.dtype).to(x.device),
937
+ ) # nW*B, window_size*window_size, C
938
+
939
+ # merge windows
940
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
941
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
942
+ if mask is not None:
943
+ mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1)
944
+ shifted_mask = window_reverse(mask_windows, self.window_size, H, W)
945
+
946
+ # reverse cyclic shift
947
+ if self.shift_size > 0:
948
+ x = torch.roll(
949
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
950
+ )
951
+ if mask is not None:
952
+ mask = torch.roll(
953
+ shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
954
+ )
955
+ else:
956
+ x = shifted_x
957
+ if mask is not None:
958
+ mask = shifted_mask
959
+ x = x.view(B, H * W, C)
960
+ if mask is not None:
961
+ mask = mask.view(B, H * W, 1)
962
+
963
+ # FFN
964
+ x = self.fuse(torch.cat([shortcut, x], dim=-1))
965
+ x = self.mlp(x)
966
+
967
+ return x, mask
968
+
969
+
970
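The cyclic shift used for SW-MSA above is just torch.roll: features are shifted so that new windows straddle the old window boundaries, and the shift is undone after attention. A tiny illustration:

import torch

x = torch.arange(16.0).view(1, 4, 4, 1)               # (B, H, W, C)
shifted = torch.roll(x, shifts=(-2, -2), dims=(1, 2))
restored = torch.roll(shifted, shifts=(2, 2), dims=(1, 2))
assert torch.equal(x, restored)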
+ class PatchMerging(nn.Module):
971
+ def __init__(self, in_channels, out_channels, down=2):
972
+ super().__init__()
973
+ self.conv = Conv2dLayerPartial(
974
+ in_channels=in_channels,
975
+ out_channels=out_channels,
976
+ kernel_size=3,
977
+ activation="lrelu",
978
+ down=down,
979
+ )
980
+ self.down = down
981
+
982
+ def forward(self, x, x_size, mask=None):
983
+ x = token2feature(x, x_size)
984
+ if mask is not None:
985
+ mask = token2feature(mask, x_size)
986
+ x, mask = self.conv(x, mask)
987
+ if self.down != 1:
988
+ ratio = 1 / self.down
989
+ x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio))
990
+ x = feature2token(x)
991
+ if mask is not None:
992
+ mask = feature2token(mask)
993
+ return x, x_size, mask
994
+
995
+
996
+ class PatchUpsampling(nn.Module):
997
+ def __init__(self, in_channels, out_channels, up=2):
998
+ super().__init__()
999
+ self.conv = Conv2dLayerPartial(
1000
+ in_channels=in_channels,
1001
+ out_channels=out_channels,
1002
+ kernel_size=3,
1003
+ activation="lrelu",
1004
+ up=up,
1005
+ )
1006
+ self.up = up
1007
+
1008
+ def forward(self, x, x_size, mask=None):
1009
+ x = token2feature(x, x_size)
1010
+ if mask is not None:
1011
+ mask = token2feature(mask, x_size)
1012
+ x, mask = self.conv(x, mask)
1013
+ if self.up != 1:
1014
+ x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up))
1015
+ x = feature2token(x)
1016
+ if mask is not None:
1017
+ mask = feature2token(mask)
1018
+ return x, x_size, mask
1019
+
1020
+
1021
+ class BasicLayer(nn.Module):
1022
+ """A basic Swin Transformer layer for one stage.
1023
+ Args:
1024
+ dim (int): Number of input channels.
1025
+ input_resolution (tuple[int]): Input resolution.
1026
+ depth (int): Number of blocks.
1027
+ num_heads (int): Number of attention heads.
1028
+ window_size (int): Local window size.
1029
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
1030
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
1031
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
1032
+ drop (float, optional): Dropout rate. Default: 0.0
1033
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
1034
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
1035
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
1036
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
1037
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
1038
+ """
1039
+
1040
+ def __init__(
1041
+ self,
1042
+ dim,
1043
+ input_resolution,
1044
+ depth,
1045
+ num_heads,
1046
+ window_size,
1047
+ down_ratio=1,
1048
+ mlp_ratio=2.0,
1049
+ qkv_bias=True,
1050
+ qk_scale=None,
1051
+ drop=0.0,
1052
+ attn_drop=0.0,
1053
+ drop_path=0.0,
1054
+ norm_layer=nn.LayerNorm,
1055
+ downsample=None,
1056
+ use_checkpoint=False,
1057
+ ):
1058
+ super().__init__()
1059
+ self.dim = dim
1060
+ self.input_resolution = input_resolution
1061
+ self.depth = depth
1062
+ self.use_checkpoint = use_checkpoint
1063
+
1064
+ # patch merging layer
1065
+ if downsample is not None:
1066
+ # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
1067
+ self.downsample = downsample
1068
+ else:
1069
+ self.downsample = None
1070
+
1071
+ # build blocks
1072
+ self.blocks = nn.ModuleList(
1073
+ [
1074
+ SwinTransformerBlock(
1075
+ dim=dim,
1076
+ input_resolution=input_resolution,
1077
+ num_heads=num_heads,
1078
+ down_ratio=down_ratio,
1079
+ window_size=window_size,
1080
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
1081
+ mlp_ratio=mlp_ratio,
1082
+ qkv_bias=qkv_bias,
1083
+ qk_scale=qk_scale,
1084
+ drop=drop,
1085
+ attn_drop=attn_drop,
1086
+ drop_path=drop_path[i]
1087
+ if isinstance(drop_path, list)
1088
+ else drop_path,
1089
+ norm_layer=norm_layer,
1090
+ )
1091
+ for i in range(depth)
1092
+ ]
1093
+ )
1094
+
1095
+ self.conv = Conv2dLayerPartial(
1096
+ in_channels=dim, out_channels=dim, kernel_size=3, activation="lrelu"
1097
+ )
1098
+
1099
+ def forward(self, x, x_size, mask=None):
1100
+ if self.downsample is not None:
1101
+ x, x_size, mask = self.downsample(x, x_size, mask)
1102
+ identity = x
1103
+ for blk in self.blocks:
1104
+ if self.use_checkpoint:
1105
+ x, mask = checkpoint.checkpoint(blk, x, x_size, mask)
1106
+ else:
1107
+ x, mask = blk(x, x_size, mask)
1108
+ if mask is not None:
1109
+ mask = token2feature(mask, x_size)
1110
+ x, mask = self.conv(token2feature(x, x_size), mask)
1111
+ x = feature2token(x) + identity
1112
+ if mask is not None:
1113
+ mask = feature2token(mask)
1114
+ return x, x_size, mask
1115
+
1116
+
1117
+ class ToToken(nn.Module):
1118
+ def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1):
1119
+ super().__init__()
1120
+
1121
+ self.proj = Conv2dLayerPartial(
1122
+ in_channels=in_channels,
1123
+ out_channels=dim,
1124
+ kernel_size=kernel_size,
1125
+ activation="lrelu",
1126
+ )
1127
+
1128
+ def forward(self, x, mask):
1129
+ x, mask = self.proj(x, mask)
1130
+
1131
+ return x, mask
1132
+
1133
+
1134
+ class EncFromRGB(nn.Module):
1135
+ def __init__(
1136
+ self, in_channels, out_channels, activation
1137
+ ): # res = 2, ..., resolution_log2
1138
+ super().__init__()
1139
+ self.conv0 = Conv2dLayer(
1140
+ in_channels=in_channels,
1141
+ out_channels=out_channels,
1142
+ kernel_size=1,
1143
+ activation=activation,
1144
+ )
1145
+ self.conv1 = Conv2dLayer(
1146
+ in_channels=out_channels,
1147
+ out_channels=out_channels,
1148
+ kernel_size=3,
1149
+ activation=activation,
1150
+ )
1151
+
1152
+ def forward(self, x):
1153
+ x = self.conv0(x)
1154
+ x = self.conv1(x)
1155
+
1156
+ return x
1157
+
1158
+
1159
+ class ConvBlockDown(nn.Module):
1160
+ def __init__(
1161
+ self, in_channels, out_channels, activation
1162
+ ): # res = 2, ..., resolution_log
1163
+ super().__init__()
1164
+
1165
+ self.conv0 = Conv2dLayer(
1166
+ in_channels=in_channels,
1167
+ out_channels=out_channels,
1168
+ kernel_size=3,
1169
+ activation=activation,
1170
+ down=2,
1171
+ )
1172
+ self.conv1 = Conv2dLayer(
1173
+ in_channels=out_channels,
1174
+ out_channels=out_channels,
1175
+ kernel_size=3,
1176
+ activation=activation,
1177
+ )
1178
+
1179
+ def forward(self, x):
1180
+ x = self.conv0(x)
1181
+ x = self.conv1(x)
1182
+
1183
+ return x
1184
+
1185
+
1186
+ def token2feature(x, x_size):
1187
+ B, N, C = x.shape
1188
+ h, w = x_size
1189
+ x = x.permute(0, 2, 1).reshape(B, C, h, w)
1190
+ return x
1191
+
1192
+
1193
+ def feature2token(x):
1194
+ B, C, H, W = x.shape
1195
+ x = x.view(B, C, -1).transpose(1, 2)
1196
+ return x
1197
+
1198
+
1199
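token2feature and feature2token above are inverse reshapes between the transformer token layout (B, N, C) and the convolutional layout (B, C, H, W):

import torch

feat = torch.randn(2, 180, 16, 16)
tok = feature2token(feat)                              # (2, 256, 180)
assert torch.equal(token2feature(tok, (16, 16)), feat)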
+ class Encoder(nn.Module):
1200
+ def __init__(
1201
+ self,
1202
+ res_log2,
1203
+ img_channels,
1204
+ activation,
1205
+ patch_size=5,
1206
+ channels=16,
1207
+ drop_path_rate=0.1,
1208
+ ):
1209
+ super().__init__()
1210
+
1211
+ self.resolution = []
1212
+
1213
+ for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16
1214
+ res = 2 ** i
1215
+ self.resolution.append(res)
1216
+ if i == res_log2:
1217
+ block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
1218
+ else:
1219
+ block = ConvBlockDown(nf(i + 1), nf(i), activation)
1220
+ setattr(self, "EncConv_Block_%dx%d" % (res, res), block)
1221
+
1222
+ def forward(self, x):
1223
+ out = {}
1224
+ for res in self.resolution:
1225
+ res_log2 = int(np.log2(res))
1226
+ x = getattr(self, "EncConv_Block_%dx%d" % (res, res))(x)
1227
+ out[res_log2] = x
1228
+
1229
+ return out
1230
+
1231
+
1232
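Encoder.forward returns a feature pyramid keyed by log2 resolution, which the decoder blocks later index as E_features[self.res]. A hedged sketch, assuming the Conv2dLayer primitive defined earlier in this file and a 512x512 input (3 image + 1 mask + 3 stage-one channels = 7):

import torch

enc = Encoder(res_log2=9, img_channels=3, activation="lrelu")
feats = enc(torch.randn(1, 7, 512, 512))
print(sorted(feats.keys()))   # [4, 5, 6, 7, 8, 9], i.e. 16x16 up to 512x512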
+ class ToStyle(nn.Module):
1233
+ def __init__(self, in_channels, out_channels, activation, drop_rate):
1234
+ super().__init__()
1235
+ self.conv = nn.Sequential(
1236
+ Conv2dLayer(
1237
+ in_channels=in_channels,
1238
+ out_channels=in_channels,
1239
+ kernel_size=3,
1240
+ activation=activation,
1241
+ down=2,
1242
+ ),
1243
+ Conv2dLayer(
1244
+ in_channels=in_channels,
1245
+ out_channels=in_channels,
1246
+ kernel_size=3,
1247
+ activation=activation,
1248
+ down=2,
1249
+ ),
1250
+ Conv2dLayer(
1251
+ in_channels=in_channels,
1252
+ out_channels=in_channels,
1253
+ kernel_size=3,
1254
+ activation=activation,
1255
+ down=2,
1256
+ ),
1257
+ )
1258
+
1259
+ self.pool = nn.AdaptiveAvgPool2d(1)
1260
+ self.fc = FullyConnectedLayer(
1261
+ in_features=in_channels, out_features=out_channels, activation=activation
1262
+ )
1263
+ # self.dropout = nn.Dropout(drop_rate)
1264
+
1265
+ def forward(self, x):
1266
+ x = self.conv(x)
1267
+ x = self.pool(x)
1268
+ x = self.fc(x.flatten(start_dim=1))
1269
+ # x = self.dropout(x)
1270
+
1271
+ return x
1272
+
1273
+
1274
+ class DecBlockFirstV2(nn.Module):
1275
+ def __init__(
1276
+ self,
1277
+ res,
1278
+ in_channels,
1279
+ out_channels,
1280
+ activation,
1281
+ style_dim,
1282
+ use_noise,
1283
+ demodulate,
1284
+ img_channels,
1285
+ ):
1286
+ super().__init__()
1287
+ self.res = res
1288
+
1289
+ self.conv0 = Conv2dLayer(
1290
+ in_channels=in_channels,
1291
+ out_channels=in_channels,
1292
+ kernel_size=3,
1293
+ activation=activation,
1294
+ )
1295
+ self.conv1 = StyleConv(
1296
+ in_channels=in_channels,
1297
+ out_channels=out_channels,
1298
+ style_dim=style_dim,
1299
+ resolution=2 ** res,
1300
+ kernel_size=3,
1301
+ use_noise=use_noise,
1302
+ activation=activation,
1303
+ demodulate=demodulate,
1304
+ )
1305
+ self.toRGB = ToRGB(
1306
+ in_channels=out_channels,
1307
+ out_channels=img_channels,
1308
+ style_dim=style_dim,
1309
+ kernel_size=1,
1310
+ demodulate=False,
1311
+ )
1312
+
1313
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
1314
+ # x = self.fc(x).view(x.shape[0], -1, 4, 4)
1315
+ x = self.conv0(x)
1316
+ x = x + E_features[self.res]
1317
+ style = get_style_code(ws[:, 0], gs)
1318
+ x = self.conv1(x, style, noise_mode=noise_mode)
1319
+ style = get_style_code(ws[:, 1], gs)
1320
+ img = self.toRGB(x, style, skip=None)
1321
+
1322
+ return x, img
1323
+
1324
+
1325
+ class DecBlock(nn.Module):
1326
+ def __init__(
1327
+ self,
1328
+ res,
1329
+ in_channels,
1330
+ out_channels,
1331
+ activation,
1332
+ style_dim,
1333
+ use_noise,
1334
+ demodulate,
1335
+ img_channels,
1336
+ ): # res = 4, ..., resolution_log2
1337
+ super().__init__()
1338
+ self.res = res
1339
+
1340
+ self.conv0 = StyleConv(
1341
+ in_channels=in_channels,
1342
+ out_channels=out_channels,
1343
+ style_dim=style_dim,
1344
+ resolution=2 ** res,
1345
+ kernel_size=3,
1346
+ up=2,
1347
+ use_noise=use_noise,
1348
+ activation=activation,
1349
+ demodulate=demodulate,
1350
+ )
1351
+ self.conv1 = StyleConv(
1352
+ in_channels=out_channels,
1353
+ out_channels=out_channels,
1354
+ style_dim=style_dim,
1355
+ resolution=2 ** res,
1356
+ kernel_size=3,
1357
+ use_noise=use_noise,
1358
+ activation=activation,
1359
+ demodulate=demodulate,
1360
+ )
1361
+ self.toRGB = ToRGB(
1362
+ in_channels=out_channels,
1363
+ out_channels=img_channels,
1364
+ style_dim=style_dim,
1365
+ kernel_size=1,
1366
+ demodulate=False,
1367
+ )
1368
+
1369
+ def forward(self, x, img, ws, gs, E_features, noise_mode="random"):
1370
+ style = get_style_code(ws[:, self.res * 2 - 9], gs)
1371
+ x = self.conv0(x, style, noise_mode=noise_mode)
1372
+ x = x + E_features[self.res]
1373
+ style = get_style_code(ws[:, self.res * 2 - 8], gs)
1374
+ x = self.conv1(x, style, noise_mode=noise_mode)
1375
+ style = get_style_code(ws[:, self.res * 2 - 7], gs)
1376
+ img = self.toRGB(x, style, skip=img)
1377
+
1378
+ return x, img
1379
+
1380
+
1381
+ class Decoder(nn.Module):
1382
+ def __init__(
1383
+ self, res_log2, activation, style_dim, use_noise, demodulate, img_channels
1384
+ ):
1385
+ super().__init__()
1386
+ self.Dec_16x16 = DecBlockFirstV2(
1387
+ 4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels
1388
+ )
1389
+ for res in range(5, res_log2 + 1):
1390
+ setattr(
1391
+ self,
1392
+ "Dec_%dx%d" % (2 ** res, 2 ** res),
1393
+ DecBlock(
1394
+ res,
1395
+ nf(res - 1),
1396
+ nf(res),
1397
+ activation,
1398
+ style_dim,
1399
+ use_noise,
1400
+ demodulate,
1401
+ img_channels,
1402
+ ),
1403
+ )
1404
+ self.res_log2 = res_log2
1405
+
1406
+ def forward(self, x, ws, gs, E_features, noise_mode="random"):
1407
+ x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
1408
+ for res in range(5, self.res_log2 + 1):
1409
+ block = getattr(self, "Dec_%dx%d" % (2 ** res, 2 ** res))
1410
+ x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
1411
+
1412
+ return img
1413
+
1414
+
1415
+ class DecStyleBlock(nn.Module):
1416
+ def __init__(
1417
+ self,
1418
+ res,
1419
+ in_channels,
1420
+ out_channels,
1421
+ activation,
1422
+ style_dim,
1423
+ use_noise,
1424
+ demodulate,
1425
+ img_channels,
1426
+ ):
1427
+ super().__init__()
1428
+ self.res = res
1429
+
1430
+ self.conv0 = StyleConv(
1431
+ in_channels=in_channels,
1432
+ out_channels=out_channels,
1433
+ style_dim=style_dim,
1434
+ resolution=2 ** res,
1435
+ kernel_size=3,
1436
+ up=2,
1437
+ use_noise=use_noise,
1438
+ activation=activation,
1439
+ demodulate=demodulate,
1440
+ )
1441
+ self.conv1 = StyleConv(
1442
+ in_channels=out_channels,
1443
+ out_channels=out_channels,
1444
+ style_dim=style_dim,
1445
+ resolution=2 ** res,
1446
+ kernel_size=3,
1447
+ use_noise=use_noise,
1448
+ activation=activation,
1449
+ demodulate=demodulate,
1450
+ )
1451
+ self.toRGB = ToRGB(
1452
+ in_channels=out_channels,
1453
+ out_channels=img_channels,
1454
+ style_dim=style_dim,
1455
+ kernel_size=1,
1456
+ demodulate=False,
1457
+ )
1458
+
1459
+ def forward(self, x, img, style, skip, noise_mode="random"):
1460
+ x = self.conv0(x, style, noise_mode=noise_mode)
1461
+ x = x + skip
1462
+ x = self.conv1(x, style, noise_mode=noise_mode)
1463
+ img = self.toRGB(x, style, skip=img)
1464
+
1465
+ return x, img
1466
+
1467
+
1468
+ class FirstStage(nn.Module):
1469
+ def __init__(
1470
+ self,
1471
+ img_channels,
1472
+ img_resolution=256,
1473
+ dim=180,
1474
+ w_dim=512,
1475
+ use_noise=False,
1476
+ demodulate=True,
1477
+ activation="lrelu",
1478
+ ):
1479
+ super().__init__()
1480
+ res = 64
1481
+
1482
+ self.conv_first = Conv2dLayerPartial(
1483
+ in_channels=img_channels + 1,
1484
+ out_channels=dim,
1485
+ kernel_size=3,
1486
+ activation=activation,
1487
+ )
1488
+ self.enc_conv = nn.ModuleList()
1489
+ down_time = int(np.log2(img_resolution // res))
1490
+ # build the number of downsampling layers in front of the Swin Transformer stages according to the image size
1491
+ for i in range(down_time): # from input size to 64
1492
+ self.enc_conv.append(
1493
+ Conv2dLayerPartial(
1494
+ in_channels=dim,
1495
+ out_channels=dim,
1496
+ kernel_size=3,
1497
+ down=2,
1498
+ activation=activation,
1499
+ )
1500
+ )
1501
+
1502
+ # from 64 -> 16 -> 64
1503
+ depths = [2, 3, 4, 3, 2]
1504
+ ratios = [1, 1 / 2, 1 / 2, 2, 2]
1505
+ num_heads = 6
1506
+ window_sizes = [8, 16, 16, 16, 8]
1507
+ drop_path_rate = 0.1
1508
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
1509
+
1510
+ self.tran = nn.ModuleList()
1511
+ for i, depth in enumerate(depths):
1512
+ res = int(res * ratios[i])
1513
+ if ratios[i] < 1:
1514
+ merge = PatchMerging(dim, dim, down=int(1 / ratios[i]))
1515
+ elif ratios[i] > 1:
1516
+ merge = PatchUpsampling(dim, dim, up=ratios[i])
1517
+ else:
1518
+ merge = None
1519
+ self.tran.append(
1520
+ BasicLayer(
1521
+ dim=dim,
1522
+ input_resolution=[res, res],
1523
+ depth=depth,
1524
+ num_heads=num_heads,
1525
+ window_size=window_sizes[i],
1526
+ drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
1527
+ downsample=merge,
1528
+ )
1529
+ )
1530
+
1531
+ # global style
1532
+ down_conv = []
1533
+ for i in range(int(np.log2(16))):
1534
+ down_conv.append(
1535
+ Conv2dLayer(
1536
+ in_channels=dim,
1537
+ out_channels=dim,
1538
+ kernel_size=3,
1539
+ down=2,
1540
+ activation=activation,
1541
+ )
1542
+ )
1543
+ down_conv.append(nn.AdaptiveAvgPool2d((1, 1)))
1544
+ self.down_conv = nn.Sequential(*down_conv)
1545
+ self.to_style = FullyConnectedLayer(
1546
+ in_features=dim, out_features=dim * 2, activation=activation
1547
+ )
1548
+ self.ws_style = FullyConnectedLayer(
1549
+ in_features=w_dim, out_features=dim, activation=activation
1550
+ )
1551
+ self.to_square = FullyConnectedLayer(
1552
+ in_features=dim, out_features=16 * 16, activation=activation
1553
+ )
1554
+
1555
+ style_dim = dim * 3
1556
+ self.dec_conv = nn.ModuleList()
1557
+ for i in range(down_time): # from 64 to input size
1558
+ res = res * 2
1559
+ self.dec_conv.append(
1560
+ DecStyleBlock(
1561
+ res,
1562
+ dim,
1563
+ dim,
1564
+ activation,
1565
+ style_dim,
1566
+ use_noise,
1567
+ demodulate,
1568
+ img_channels,
1569
+ )
1570
+ )
1571
+
1572
+ def forward(self, images_in, masks_in, ws, noise_mode="random"):
1573
+ x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1)
1574
+
1575
+ skips = []
1576
+ x, mask = self.conv_first(x, masks_in) # input size
1577
+ skips.append(x)
1578
+ for i, block in enumerate(self.enc_conv): # input size to 64
1579
+ x, mask = block(x, mask)
1580
+ if i != len(self.enc_conv) - 1:
1581
+ skips.append(x)
1582
+
1583
+ x_size = x.size()[-2:]
1584
+ x = feature2token(x)
1585
+ mask = feature2token(mask)
1586
+ mid = len(self.tran) // 2
1587
+ for i, block in enumerate(self.tran): # 64 to 16
1588
+ if i < mid:
1589
+ x, x_size, mask = block(x, x_size, mask)
1590
+ skips.append(x)
1591
+ elif i > mid:
1592
+ x, x_size, mask = block(x, x_size, None)
1593
+ x = x + skips[mid - i]
1594
+ else:
1595
+ x, x_size, mask = block(x, x_size, None)
1596
+
1597
+ mul_map = torch.ones_like(x) * 0.5
1598
+ mul_map = F.dropout(mul_map, training=True)
1599
+ ws = self.ws_style(ws[:, -1])
1600
+ add_n = self.to_square(ws).unsqueeze(1)
1601
+ add_n = (
1602
+ F.interpolate(
1603
+ add_n, size=x.size(1), mode="linear", align_corners=False
1604
+ )
1605
+ .squeeze(1)
1606
+ .unsqueeze(-1)
1607
+ )
1608
+ x = x * mul_map + add_n * (1 - mul_map)
1609
+ gs = self.to_style(
1610
+ self.down_conv(token2feature(x, x_size)).flatten(start_dim=1)
1611
+ )
1612
+ style = torch.cat([gs, ws], dim=1)
1613
+
1614
+ x = token2feature(x, x_size).contiguous()
1615
+ img = None
1616
+ for i, block in enumerate(self.dec_conv):
1617
+ x, img = block(
1618
+ x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode
1619
+ )
1620
+
1621
+ # composite: keep known pixels from the input, fill the hole with generated content
1622
+ img = img * (1 - masks_in) + images_in * masks_in
1623
+
1624
+ return img
1625
+
1626
+
1627
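The final composite in FirstStage.forward keeps known pixels from the input (masks_in is 1 on known regions in this codebase) and takes generated content only inside the hole. In isolation:

import torch

gen = torch.full((1, 3, 4, 4), 0.5)     # generator output
real = torch.zeros(1, 3, 4, 4)          # input image
known = torch.ones(1, 1, 4, 4)
known[:, :, :2, :] = 0                  # top half is the hole
out = gen * (1 - known) + real * known  # hole filled by gen, rest untouched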
+ class SynthesisNet(nn.Module):
1628
+ def __init__(
1629
+ self,
1630
+ w_dim, # Intermediate latent (W) dimensionality.
1631
+ img_resolution, # Output image resolution.
1632
+ img_channels=3, # Number of color channels.
1633
+ channel_base=32768, # Overall multiplier for the number of channels.
1634
+ channel_decay=1.0,
1635
+ channel_max=512, # Maximum number of channels in any layer.
1636
+ activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
1637
+ drop_rate=0.5,
1638
+ use_noise=False,
1639
+ demodulate=True,
1640
+ ):
1641
+ super().__init__()
1642
+ resolution_log2 = int(np.log2(img_resolution))
1643
+ assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
1644
+
1645
+ self.num_layers = resolution_log2 * 2 - 3 * 2
1646
+ self.img_resolution = img_resolution
1647
+ self.resolution_log2 = resolution_log2
1648
+
1649
+ # first stage
1650
+ self.first_stage = FirstStage(
1651
+ img_channels,
1652
+ img_resolution=img_resolution,
1653
+ w_dim=w_dim,
1654
+ use_noise=False,
1655
+ demodulate=demodulate,
1656
+ )
1657
+
1658
+ # second stage
1659
+ self.enc = Encoder(
1660
+ resolution_log2, img_channels, activation, patch_size=5, channels=16
1661
+ )
1662
+ self.to_square = FullyConnectedLayer(
1663
+ in_features=w_dim, out_features=16 * 16, activation=activation
1664
+ )
1665
+ self.to_style = ToStyle(
1666
+ in_channels=nf(4),
1667
+ out_channels=nf(2) * 2,
1668
+ activation=activation,
1669
+ drop_rate=drop_rate,
1670
+ )
1671
+ style_dim = w_dim + nf(2) * 2
1672
+ self.dec = Decoder(
1673
+ resolution_log2, activation, style_dim, use_noise, demodulate, img_channels
1674
+ )
1675
+
1676
+ def forward(self, images_in, masks_in, ws, noise_mode="random", return_stg1=False):
1677
+ out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)
1678
+
1679
+ # encoder
1680
+ x = images_in * masks_in + out_stg1 * (1 - masks_in)
1681
+ x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1)
1682
+ E_features = self.enc(x)
1683
+
1684
+ fea_16 = E_features[4]
1685
+ mul_map = torch.ones_like(fea_16) * 0.5
1686
+ mul_map = F.dropout(mul_map, training=True)
1687
+ add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1)
1688
+ add_n = F.interpolate(
1689
+ add_n, size=fea_16.size()[-2:], mode="bilinear", align_corners=False
1690
+ )
1691
+ fea_16 = fea_16 * mul_map + add_n * (1 - mul_map)
1692
+ E_features[4] = fea_16
1693
+
1694
+ # style
1695
+ gs = self.to_style(fea_16)
1696
+
1697
+ # decoder
1698
+ img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode)
1699
+
1700
+ # composite: keep known pixels from the input, fill the hole with generated content
1701
+ img = img * (1 - masks_in) + images_in * masks_in
1702
+
1703
+ if not return_stg1:
1704
+ return img
1705
+ else:
1706
+ return img, out_stg1
1707
+
1708
+
1709
+ class Generator(nn.Module):
1710
+ def __init__(
1711
+ self,
1712
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
1713
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
1714
+ w_dim, # Intermediate latent (W) dimensionality.
1715
+ img_resolution, # resolution of generated image
1716
+ img_channels, # Number of input color channels.
1717
+ synthesis_kwargs={}, # Arguments for SynthesisNetwork.
1718
+ mapping_kwargs={}, # Arguments for MappingNetwork.
1719
+ ):
1720
+ super().__init__()
1721
+ self.z_dim = z_dim
1722
+ self.c_dim = c_dim
1723
+ self.w_dim = w_dim
1724
+ self.img_resolution = img_resolution
1725
+ self.img_channels = img_channels
1726
+
1727
+ self.synthesis = SynthesisNet(
1728
+ w_dim=w_dim,
1729
+ img_resolution=img_resolution,
1730
+ img_channels=img_channels,
1731
+ **synthesis_kwargs,
1732
+ )
1733
+ self.mapping = MappingNet(
1734
+ z_dim=z_dim,
1735
+ c_dim=c_dim,
1736
+ w_dim=w_dim,
1737
+ num_ws=self.synthesis.num_layers,
1738
+ **mapping_kwargs,
1739
+ )
1740
+
1741
+ def forward(
1742
+ self,
1743
+ images_in,
1744
+ masks_in,
1745
+ z,
1746
+ c,
1747
+ truncation_psi=1,
1748
+ truncation_cutoff=None,
1749
+ skip_w_avg_update=False,
1750
+ noise_mode="none",
1751
+ return_stg1=False,
1752
+ ):
1753
+ ws = self.mapping(
1754
+ z,
1755
+ c,
1756
+ truncation_psi=truncation_psi,
1757
+ truncation_cutoff=truncation_cutoff,
1758
+ skip_w_avg_update=skip_w_avg_update,
1759
+ )
1760
+ img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode)
1761
+ return img
1762
+
1763
+
1764
+ # Note: this two-branch Discriminator (full image plus 64x64 first-stage output) shadows the single-branch definition above.
+ class Discriminator(torch.nn.Module):
1765
+ def __init__(
1766
+ self,
1767
+ c_dim, # Conditioning label (C) dimensionality.
1768
+ img_resolution, # Input resolution.
1769
+ img_channels, # Number of input color channels.
1770
+ channel_base=32768, # Overall multiplier for the number of channels.
1771
+ channel_max=512, # Maximum number of channels in any layer.
1772
+ channel_decay=1,
1773
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
1774
+ activation="lrelu",
1775
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
1776
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
1777
+ ):
1778
+ super().__init__()
1779
+ self.c_dim = c_dim
1780
+ self.img_resolution = img_resolution
1781
+ self.img_channels = img_channels
1782
+
1783
+ resolution_log2 = int(np.log2(img_resolution))
1784
+ assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
1785
+ self.resolution_log2 = resolution_log2
1786
+
1787
+ if cmap_dim is None:
1788
+ cmap_dim = nf(2)
1789
+ if c_dim == 0:
1790
+ cmap_dim = 0
1791
+ self.cmap_dim = cmap_dim
1792
+
1793
+ if c_dim > 0:
1794
+ self.mapping = MappingNet(
1795
+ z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None
1796
+ )
1797
+
1798
+ Dis = [DisFromRGB(img_channels + 1, nf(resolution_log2), activation)]
1799
+ for res in range(resolution_log2, 2, -1):
1800
+ Dis.append(DisBlock(nf(res), nf(res - 1), activation))
1801
+
1802
+ if mbstd_num_channels > 0:
1803
+ Dis.append(
1804
+ MinibatchStdLayer(
1805
+ group_size=mbstd_group_size, num_channels=mbstd_num_channels
1806
+ )
1807
+ )
1808
+ Dis.append(
1809
+ Conv2dLayer(
1810
+ nf(2) + mbstd_num_channels, nf(2), kernel_size=3, activation=activation
1811
+ )
1812
+ )
1813
+ self.Dis = nn.Sequential(*Dis)
1814
+
1815
+ self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
1816
+ self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
1817
+
1818
+ # for 64x64
1819
+ Dis_stg1 = [DisFromRGB(img_channels + 1, nf(resolution_log2) // 2, activation)]
1820
+ for res in range(resolution_log2, 2, -1):
1821
+ Dis_stg1.append(DisBlock(nf(res) // 2, nf(res - 1) // 2, activation))
1822
+
1823
+ if mbstd_num_channels > 0:
1824
+ Dis_stg1.append(
1825
+ MinibatchStdLayer(
1826
+ group_size=mbstd_group_size, num_channels=mbstd_num_channels
1827
+ )
1828
+ )
1829
+ Dis_stg1.append(
1830
+ Conv2dLayer(
1831
+ nf(2) // 2 + mbstd_num_channels,
1832
+ nf(2) // 2,
1833
+ kernel_size=3,
1834
+ activation=activation,
1835
+ )
1836
+ )
1837
+ self.Dis_stg1 = nn.Sequential(*Dis_stg1)
1838
+
1839
+ self.fc0_stg1 = FullyConnectedLayer(
1840
+ nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation
1841
+ )
1842
+ self.fc1_stg1 = FullyConnectedLayer(
1843
+ nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim
1844
+ )
1845
+
1846
+ def forward(self, images_in, masks_in, images_stg1, c):
1847
+ x = self.Dis(torch.cat([masks_in - 0.5, images_in], dim=1))
1848
+ x = self.fc1(self.fc0(x.flatten(start_dim=1)))
1849
+
1850
+ x_stg1 = self.Dis_stg1(torch.cat([masks_in - 0.5, images_stg1], dim=1))
1851
+ x_stg1 = self.fc1_stg1(self.fc0_stg1(x_stg1.flatten(start_dim=1)))
1852
+
1853
+ if self.c_dim > 0:
1854
+ cmap = self.mapping(None, c)
1855
+
1856
+ if self.cmap_dim > 0:
1857
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
1858
+ x_stg1 = (x_stg1 * cmap).sum(dim=1, keepdim=True) * (
1859
+ 1 / np.sqrt(self.cmap_dim)
1860
+ )
1861
+
1862
+ return x, x_stg1
1863
+
1864
+
1865
+ MAT_MODEL_URL = os.environ.get(
1866
+ "MAT_MODEL_URL",
1867
+ "https://github.com/Sanster/models/releases/download/add_mat/Places_512_FullData_G.pth",
1868
+ )
1869
+
1870
+ MAT_MODEL_MD5 = os.environ.get("MAT_MODEL_MD5", "8ca927835fa3f5e21d65ffcb165377ed")
1871
+
1872
+
1873
+ class MAT(InpaintModel):
1874
+ name = "mat"
1875
+ min_size = 512
1876
+ pad_mod = 512
1877
+ pad_to_square = True
1878
+
1879
+ def init_model(self, device, **kwargs):
1880
+ seed = 240 # fixed seed so inpainting results are reproducible
1881
+ set_seed(seed)
1882
+
1883
+ fp16 = not kwargs.get("no_half", False)
1884
+ use_gpu = "cuda" in str(device) and torch.cuda.is_available()
1885
+ self.torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
1886
+
1887
+ G = Generator(
1888
+ z_dim=512,
1889
+ c_dim=0,
1890
+ w_dim=512,
1891
+ img_resolution=512,
1892
+ img_channels=3,
1893
+ mapping_kwargs={"torch_dtype": self.torch_dtype},
1894
+ ).to(self.torch_dtype)
1895
+ # fmt: off
1896
+ self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5)
1897
+ self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(self.torch_dtype).to(device)
1898
+ self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype)
1899
+ # fmt: on
1900
+
1901
+ @staticmethod
1902
+ def is_downloaded() -> bool:
1903
+ return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
1904
+
1905
+ def forward(self, image, mask, config: Config):
1906
+ """Input images and output images have same size
1907
+ images: [H, W, C] RGB
1908
+ masks: [H, W] mask area == 255
1909
+ return: BGR IMAGE
1910
+ """
1911
+
1912
+ image = norm_img(image) # [0, 1]
1913
+ image = image * 2 - 1 # [0, 1] -> [-1, 1]
1914
+
1915
+ mask = (mask > 127) * 255
1916
+ mask = 255 - mask
1917
+ mask = norm_img(mask)
1918
+
1919
+ image = (
1920
+ torch.from_numpy(image).unsqueeze(0).to(self.torch_dtype).to(self.device)
1921
+ )
1922
+ mask = torch.from_numpy(mask).unsqueeze(0).to(self.torch_dtype).to(self.device)
1923
+
1924
+ output = self.model(
1925
+ image, mask, self.z, self.label, truncation_psi=1, noise_mode="none"
1926
+ )
1927
+ output = (
1928
+ (output.permute(0, 2, 3, 1) * 127.5 + 127.5)
1929
+ .round()
1930
+ .clamp(0, 255)
1931
+ .to(torch.uint8)
1932
+ )
1933
+ output = output[0].cpu().numpy()
1934
+ cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
1935
+ return cur_res
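Note the mask convention handled in MAT.forward above: callers pass 255 for the region to repaint, while the network expects 1 on known pixels, hence the threshold-and-invert. In plain numpy:

import numpy as np

mask = np.array([[0, 200], [30, 255]], dtype=np.uint8)
net_mask = 255 - (mask > 127) * 255   # known pixels end up as 255 before norm_img
# net_mask == [[255, 0], [255, 0]]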
lama_cleaner/model/opencv2.py ADDED
@@ -0,0 +1,28 @@
1
+ import cv2
2
+ from lama_cleaner.model.base import InpaintModel
3
+ from lama_cleaner.schema import Config
4
+
5
+ flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
6
+
7
+
8
+ class OpenCV2(InpaintModel):
9
+ name = "cv2"
10
+ pad_mod = 1
11
+
12
+ @staticmethod
13
+ def is_downloaded() -> bool:
14
+ return True
15
+
16
+ def forward(self, image, mask, config: Config):
17
+ """Input image and output image have same size
18
+ image: [H, W, C] RGB
19
+ mask: [H, W, 1]
20
+ return: BGR IMAGE
21
+ """
22
+ cur_res = cv2.inpaint(
23
+ image[:, :, ::-1],
24
+ mask,
25
+ inpaintRadius=config.cv2_radius,
26
+ flags=flag_map[config.cv2_flag],
27
+ )
28
+ return cur_res
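For reference, the classical OpenCV backend needs no weights at all; a minimal usage sketch of cv2.inpaint with assumed toy inputs:

import cv2
import numpy as np

img = np.zeros((64, 64, 3), dtype=np.uint8)
mask = np.zeros((64, 64), dtype=np.uint8)
mask[20:40, 20:40] = 255   # nonzero marks the hole
result = cv2.inpaint(img, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)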
lama_cleaner/model/paint_by_example.py ADDED
@@ -0,0 +1,79 @@
1
+ import PIL
2
+ import PIL.Image
3
+ import cv2
4
+ import torch
5
+ from diffusers import DiffusionPipeline
6
+ from loguru import logger
7
+
8
+ from lama_cleaner.model.base import DiffusionInpaintModel
9
+ from lama_cleaner.model.utils import set_seed
10
+ from lama_cleaner.schema import Config
11
+
12
+
13
+ class PaintByExample(DiffusionInpaintModel):
14
+ name = "paint_by_example"
15
+ pad_mod = 8
16
+ min_size = 512
17
+
18
+ def init_model(self, device: torch.device, **kwargs):
19
+ fp16 = not kwargs.get('no_half', False)
20
+ use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
21
+ torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
22
+ model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
23
+
24
+ if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
25
+ logger.info("Disable Paint By Example Model NSFW checker")
26
+ model_kwargs.update(dict(
27
+ safety_checker=None,
28
+ requires_safety_checker=False
29
+ ))
30
+
31
+ self.model = DiffusionPipeline.from_pretrained(
32
+ "Fantasy-Studio/Paint-by-Example",
33
+ torch_dtype=torch_dtype,
34
+ **model_kwargs
35
+ )
36
+
37
+ self.model.enable_attention_slicing()
38
+ if kwargs.get('enable_xformers', False):
39
+ self.model.enable_xformers_memory_efficient_attention()
40
+
41
+ # TODO: gpu_id
42
+ if kwargs.get('cpu_offload', False) and use_gpu:
43
+ self.model.image_encoder = self.model.image_encoder.to(device)
44
+ self.model.enable_sequential_cpu_offload(gpu_id=0)
45
+ else:
46
+ self.model = self.model.to(device)
47
+
48
+ def forward(self, image, mask, config: Config):
49
+ """Input image and output image have same size
50
+ image: [H, W, C] RGB
51
+ mask: [H, W, 1] 255 means area to repaint
52
+ return: BGR IMAGE
53
+ """
54
+ output = self.model(
55
+ image=PIL.Image.fromarray(image),
56
+ mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
57
+ example_image=config.paint_by_example_example_image,
58
+ num_inference_steps=config.paint_by_example_steps,
59
+ output_type='np.array',
60
+ generator=torch.manual_seed(config.paint_by_example_seed)
61
+ ).images[0]
62
+
63
+ output = (output * 255).round().astype("uint8")
64
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
65
+ return output
66
+
67
+ def forward_post_process(self, result, image, mask, config):
68
+ if config.paint_by_example_match_histograms:
69
+ result = self._match_histograms(result, image[:, :, ::-1], mask)
70
+
71
+ if config.paint_by_example_mask_blur != 0:
72
+ k = 2 * config.paint_by_example_mask_blur + 1
73
+ mask = cv2.GaussianBlur(mask, (k, k), 0)
74
+ return result, image, mask
75
+
76
+ @staticmethod
77
+ def is_downloaded() -> bool:
78
+ # model will be downloaded when app start, and can't switch in frontend settings
79
+ return True
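A side note on forward_post_process above: cv2.GaussianBlur requires an odd kernel size, which is why the code computes k = 2 * blur + 1. A minimal sketch with an assumed blur radius of 4:

import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 255
k = 2 * 4 + 1                               # radius 4 -> 9x9 kernel (odd)
soft = cv2.GaussianBlur(mask, (k, k), 0)    # feathered mask edges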
lama_cleaner/model/pipeline/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .pipeline_stable_diffusion_controlnet_inpaint import (
2
+ StableDiffusionControlNetInpaintPipeline,
3
+ )
lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py ADDED
@@ -0,0 +1,585 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
16
+
17
+ import torch
18
+ import PIL.Image
19
+ import numpy as np
20
+
21
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
22
+
23
+ EXAMPLE_DOC_STRING = """
24
+ Examples:
25
+ ```py
26
+ >>> # !pip install opencv-python transformers accelerate
27
+ >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
28
+ >>> from diffusers.utils import load_image
29
+ >>> import numpy as np
30
+ >>> import torch
31
+
32
+ >>> import cv2
33
+ >>> from PIL import Image
34
+ >>> # download an image
35
+ >>> image = load_image(
36
+ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
37
+ ... )
38
+ >>> image = np.array(image)
39
+ >>> mask_image = load_image(
40
+ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
41
+ ... )
42
+ >>> mask_image = np.array(mask_image)
43
+ >>> # get canny image
44
+ >>> canny_image = cv2.Canny(image, 100, 200)
45
+ >>> canny_image = canny_image[:, :, None]
46
+ >>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
47
+ >>> canny_image = Image.fromarray(canny_image)
48
+
49
+ >>> # load control net and stable diffusion v1-5
50
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
51
+ >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
52
+ ... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
53
+ ... )
54
+
55
+ >>> # speed up diffusion process with faster scheduler and memory optimization
56
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
57
+ >>> # remove following line if xformers is not installed
58
+ >>> pipe.enable_xformers_memory_efficient_attention()
59
+
60
+ >>> pipe.enable_model_cpu_offload()
61
+
62
+ >>> # generate image
63
+ >>> generator = torch.manual_seed(0)
64
+ >>> image = pipe(
65
+ ... "futuristic-looking doggo",
66
+ ... num_inference_steps=20,
67
+ ... generator=generator,
68
+ ... image=image,
69
+ ... control_image=canny_image,
70
+ ... mask_image=mask_image
71
+ ... ).images[0]
72
+ ```
73
+ """
74
+
75
+
76
+ def prepare_mask_and_masked_image(image, mask):
77
+ """
78
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
79
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
80
+ ``image`` and ``1`` for the ``mask``.
81
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
82
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
83
+ Args:
84
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
85
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
86
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
87
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
88
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
89
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
90
+ Raises:
91
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
92
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
93
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
94
+ (ot the other way around).
95
+ Returns:
96
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
97
+ dimensions: ``batch x channels x height x width``.
98
+ """
99
+ if isinstance(image, torch.Tensor):
100
+ if not isinstance(mask, torch.Tensor):
101
+ raise TypeError(
102
+ f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
103
+ )
104
+
105
+ # Batch single image
106
+ if image.ndim == 3:
107
+ assert (
108
+ image.shape[0] == 3
109
+ ), "Image outside a batch should be of shape (3, H, W)"
110
+ image = image.unsqueeze(0)
111
+
112
+ # Batch and add channel dim for single mask
113
+ if mask.ndim == 2:
114
+ mask = mask.unsqueeze(0).unsqueeze(0)
115
+
116
+ # Batch single mask or add channel dim
117
+ if mask.ndim == 3:
118
+ # Single batched mask, no channel dim or single mask not batched but channel dim
119
+ if mask.shape[0] == 1:
120
+ mask = mask.unsqueeze(0)
121
+
122
+ # Batched masks no channel dim
123
+ else:
124
+ mask = mask.unsqueeze(1)
125
+
126
+ assert (
127
+ image.ndim == 4 and mask.ndim == 4
128
+ ), "Image and Mask must have 4 dimensions"
129
+ assert (
130
+ image.shape[-2:] == mask.shape[-2:]
131
+ ), "Image and Mask must have the same spatial dimensions"
132
+ assert (
133
+ image.shape[0] == mask.shape[0]
134
+ ), "Image and Mask must have the same batch size"
135
+
136
+ # Check image is in [-1, 1]
137
+ if image.min() < -1 or image.max() > 1:
138
+ raise ValueError("Image should be in [-1, 1] range")
139
+
140
+ # Check mask is in [0, 1]
141
+ if mask.min() < 0 or mask.max() > 1:
142
+ raise ValueError("Mask should be in [0, 1] range")
143
+
144
+ # Binarize mask
145
+ mask[mask < 0.5] = 0
146
+ mask[mask >= 0.5] = 1
147
+
148
+ # Image as float32
149
+ image = image.to(dtype=torch.float32)
150
+ elif isinstance(mask, torch.Tensor):
151
+ raise TypeError(
152
+            f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not"
153
+ )
154
+ else:
155
+ # preprocess image
156
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
157
+ image = [image]
158
+
159
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
160
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
161
+ image = np.concatenate(image, axis=0)
162
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
163
+ image = np.concatenate([i[None, :] for i in image], axis=0)
164
+
165
+ image = image.transpose(0, 3, 1, 2)
166
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
167
+
168
+ # preprocess mask
169
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
170
+ mask = [mask]
171
+
172
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
173
+ mask = np.concatenate(
174
+ [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
175
+ )
176
+ mask = mask.astype(np.float32) / 255.0
177
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
178
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
179
+
180
+ mask[mask < 0.5] = 0
181
+ mask[mask >= 0.5] = 1
182
+ mask = torch.from_numpy(mask)
183
+
184
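+    # zero out the pixels to be inpainted (mask == 1) so that only the kept
+    # content survives in the masked image that the VAE later encodes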
+ masked_image = image * (mask < 0.5)
185
+
186
+ return mask, masked_image
187
+
188
+
189
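A minimal sketch of what this helper produces, assuming Pillow inputs and a 512x512 canvas (the sample image/mask values below are illustrative, not from the repository):

```python
import PIL.Image

# white 512x512 image; the mask marks the left half for inpainting
image = PIL.Image.new("RGB", (512, 512), "white")
mask = PIL.Image.new("L", (512, 512), 0)
mask.paste(255, (0, 0, 256, 512))

m, masked = prepare_mask_and_masked_image(image, mask)
print(m.shape)       # torch.Size([1, 1, 512, 512]), binarized to {0., 1.}
print(masked.shape)  # torch.Size([1, 3, 512, 512]), values in [-1, 1]
print(masked[..., :256].abs().max())  # tensor(0.) -- the masked half is zeroed
```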
+ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
190
+ r"""
191
+ Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
192
+
193
+ This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
194
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
195
+
196
+ Args:
197
+ vae ([`AutoencoderKL`]):
198
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
199
+ text_encoder ([`CLIPTextModel`]):
200
+ Frozen text-encoder. Stable Diffusion uses the text portion of
201
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
202
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
203
+ tokenizer (`CLIPTokenizer`):
204
+ Tokenizer of class
205
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
206
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
207
+ controlnet ([`ControlNetModel`]):
208
+            Provides additional conditioning to the UNet during the denoising process.
209
+ scheduler ([`SchedulerMixin`]):
210
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
211
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
212
+ safety_checker ([`StableDiffusionSafetyChecker`]):
213
+ Classification module that estimates whether generated images could be considered offensive or harmful.
214
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
215
+ feature_extractor ([`CLIPFeatureExtractor`]):
216
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
217
+ """
218
+
219
+ def prepare_mask_latents(
220
+ self,
221
+ mask,
222
+ masked_image,
223
+ batch_size,
224
+ height,
225
+ width,
226
+ dtype,
227
+ device,
228
+ generator,
229
+ do_classifier_free_guidance,
230
+ ):
231
+ # resize the mask to latents shape as we concatenate the mask to the latents
232
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
233
+ # and half precision
234
+ mask = torch.nn.functional.interpolate(
235
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
236
+ )
237
+ mask = mask.to(device=device, dtype=dtype)
238
+
239
+ masked_image = masked_image.to(device=device, dtype=dtype)
240
+
241
+ # encode the mask image into latents space so we can concatenate it to the latents
242
+ if isinstance(generator, list):
243
+ masked_image_latents = [
244
+ self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
245
+ generator=generator[i]
246
+ )
247
+ for i in range(batch_size)
248
+ ]
249
+ masked_image_latents = torch.cat(masked_image_latents, dim=0)
250
+ else:
251
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(
252
+ generator=generator
253
+ )
254
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
255
+
256
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
257
+ if mask.shape[0] < batch_size:
258
+ if not batch_size % mask.shape[0] == 0:
259
+ raise ValueError(
260
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
261
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
262
+ " of masks that you pass is divisible by the total requested batch size."
263
+ )
264
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
265
+ if masked_image_latents.shape[0] < batch_size:
266
+ if not batch_size % masked_image_latents.shape[0] == 0:
267
+ raise ValueError(
268
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
269
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
270
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
271
+ )
272
+ masked_image_latents = masked_image_latents.repeat(
273
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
274
+ )
275
+
276
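+        # duplicate for classifier-free guidance so the unconditional and
+        # conditional branches can be run in a single batched forward pass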
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
277
+ masked_image_latents = (
278
+ torch.cat([masked_image_latents] * 2)
279
+ if do_classifier_free_guidance
280
+ else masked_image_latents
281
+ )
282
+
283
+ # aligning device to prevent device errors when concating it with the latent model input
284
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
285
+ return mask, masked_image_latents
286
+
287
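A quick sanity check of the interpolation step above: with the Stable Diffusion v1.x VAE (assumed here), `vae_scale_factor` is 8, so a 512x512 pixel-space mask is shrunk to the 64x64 latent grid before being concatenated to the latents:

```python
import torch

vae_scale_factor = 8  # assumption: SD v1.x VAE downsamples by 8x
mask = torch.zeros(1, 1, 512, 512)
mask[..., :256] = 1.0  # left half marked for inpainting
latent_mask = torch.nn.functional.interpolate(
    mask, size=(512 // vae_scale_factor, 512 // vae_scale_factor)
)
print(latent_mask.shape)  # torch.Size([1, 1, 64, 64])
```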
+ @torch.no_grad()
288
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
289
+ def __call__(
290
+ self,
291
+ prompt: Union[str, List[str]] = None,
292
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
293
+ control_image: Union[
294
+ torch.FloatTensor,
295
+ PIL.Image.Image,
296
+ List[torch.FloatTensor],
297
+ List[PIL.Image.Image],
298
+ ] = None,
299
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
300
+ height: Optional[int] = None,
301
+ width: Optional[int] = None,
302
+ num_inference_steps: int = 50,
303
+ guidance_scale: float = 7.5,
304
+ negative_prompt: Optional[Union[str, List[str]]] = None,
305
+ num_images_per_prompt: Optional[int] = 1,
306
+ eta: float = 0.0,
307
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
308
+ latents: Optional[torch.FloatTensor] = None,
309
+ prompt_embeds: Optional[torch.FloatTensor] = None,
310
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
311
+ output_type: Optional[str] = "pil",
312
+ return_dict: bool = True,
313
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
314
+ callback_steps: int = 1,
315
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
316
+ controlnet_conditioning_scale: float = 1.0,
317
+ ):
318
+ r"""
319
+ Function invoked when calling the pipeline for generation.
320
+ Args:
321
+ prompt (`str` or `List[str]`, *optional*):
322
+                The prompt or prompts to guide the image generation. If not defined, one has to pass
323
+                `prompt_embeds` instead.
324
+ image (`PIL.Image.Image`):
325
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
326
+ be masked out with `mask_image` and repainted according to `prompt`.
327
+ control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
328
+                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If
329
+                the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
330
+                also be accepted as an image. The control image is automatically resized to fit the output image.
331
+ mask_image (`PIL.Image.Image`):
332
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
333
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
334
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
335
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
336
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
337
+ The height in pixels of the generated image.
338
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
339
+ The width in pixels of the generated image.
340
+ num_inference_steps (`int`, *optional*, defaults to 50):
341
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
342
+ expense of slower inference.
343
+ guidance_scale (`float`, *optional*, defaults to 7.5):
344
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
345
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
346
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
347
+                1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
348
+ usually at the expense of lower image quality.
349
+ negative_prompt (`str` or `List[str]`, *optional*):
350
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
351
+                `negative_prompt_embeds` instead.
352
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
353
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
354
+ The number of images to generate per prompt.
355
+ eta (`float`, *optional*, defaults to 0.0):
356
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
357
+ [`schedulers.DDIMScheduler`], will be ignored for others.
358
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
359
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
360
+ to make generation deterministic.
361
+ latents (`torch.FloatTensor`, *optional*):
362
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
363
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
364
+                tensor will be generated by sampling using the supplied random `generator`.
365
+ prompt_embeds (`torch.FloatTensor`, *optional*):
366
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
367
+ provided, text embeddings will be generated from `prompt` input argument.
368
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
369
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
370
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
371
+ argument.
372
+ output_type (`str`, *optional*, defaults to `"pil"`):
373
+                The output format of the generated image. Choose between
374
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
375
+ return_dict (`bool`, *optional*, defaults to `True`):
376
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
377
+ plain tuple.
378
+ callback (`Callable`, *optional*):
379
+ A function that will be called every `callback_steps` steps during inference. The function will be
380
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
381
+ callback_steps (`int`, *optional*, defaults to 1):
382
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
383
+ called at every step.
384
+ cross_attention_kwargs (`dict`, *optional*):
385
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
386
+ `self.processor` in
387
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
388
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
389
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
390
+ to the residual in the original unet.
391
+ Examples:
392
+ Returns:
393
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
394
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
395
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
396
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
397
+ (nsfw) content, according to the `safety_checker`.
398
+ """
399
+ # 0. Default height and width to unet
400
+ height, width = self._default_height_width(height, width, control_image)
401
+
402
+ # 1. Check inputs. Raise error if not correct
403
+ self.check_inputs(
404
+ prompt,
405
+ control_image,
406
+ height,
407
+ width,
408
+ callback_steps,
409
+ negative_prompt,
410
+ prompt_embeds,
411
+ negative_prompt_embeds,
412
+ )
413
+
414
+ # 2. Define call parameters
415
+ if prompt is not None and isinstance(prompt, str):
416
+ batch_size = 1
417
+ elif prompt is not None and isinstance(prompt, list):
418
+ batch_size = len(prompt)
419
+ else:
420
+ batch_size = prompt_embeds.shape[0]
421
+
422
+ device = self._execution_device
423
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
424
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
425
+ # corresponds to doing no classifier free guidance.
426
+ do_classifier_free_guidance = guidance_scale > 1.0
427
+
428
+ # 3. Encode input prompt
429
+ prompt_embeds = self._encode_prompt(
430
+ prompt,
431
+ device,
432
+ num_images_per_prompt,
433
+ do_classifier_free_guidance,
434
+ negative_prompt,
435
+ prompt_embeds=prompt_embeds,
436
+ negative_prompt_embeds=negative_prompt_embeds,
437
+ )
438
+
439
+ # 4. Prepare image
440
+ control_image = self.prepare_image(
441
+ control_image,
442
+ width,
443
+ height,
444
+ batch_size * num_images_per_prompt,
445
+ num_images_per_prompt,
446
+ device,
447
+ self.controlnet.dtype,
448
+ )
449
+
450
+ if do_classifier_free_guidance:
451
+ control_image = torch.cat([control_image] * 2)
452
+
453
+ # 5. Prepare timesteps
454
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
455
+ timesteps = self.scheduler.timesteps
456
+
457
+ # 6. Prepare latent variables
458
+ num_channels_latents = self.controlnet.config.in_channels
459
+ latents = self.prepare_latents(
460
+ batch_size * num_images_per_prompt,
461
+ num_channels_latents,
462
+ height,
463
+ width,
464
+ prompt_embeds.dtype,
465
+ device,
466
+ generator,
467
+ latents,
468
+ )
469
+
470
+ # EXTRA: prepare mask latents
471
+ mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
472
+ mask, masked_image_latents = self.prepare_mask_latents(
473
+ mask,
474
+ masked_image,
475
+ batch_size * num_images_per_prompt,
476
+ height,
477
+ width,
478
+ prompt_embeds.dtype,
479
+ device,
480
+ generator,
481
+ do_classifier_free_guidance,
482
+ )
483
+
484
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
485
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
486
+
487
+ # 8. Denoising loop
488
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
489
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
490
+ for i, t in enumerate(timesteps):
491
+ # expand the latents if we are doing classifier free guidance
492
+ latent_model_input = (
493
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
494
+ )
495
+ latent_model_input = self.scheduler.scale_model_input(
496
+ latent_model_input, t
497
+ )
498
+
499
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
500
+ latent_model_input,
501
+ t,
502
+ encoder_hidden_states=prompt_embeds,
503
+ controlnet_cond=control_image,
504
+ return_dict=False,
505
+ )
506
+
507
+ down_block_res_samples = [
508
+ down_block_res_sample * controlnet_conditioning_scale
509
+ for down_block_res_sample in down_block_res_samples
510
+ ]
511
+ mid_block_res_sample *= controlnet_conditioning_scale
512
+
513
+ # predict the noise residual
514
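+                # NOTE: with the SD v1.x inpainting UNet (assumed by this pipeline),
+                # the concat below yields 4 (noisy latents) + 1 (mask) + 4
+                # (masked-image latents) = 9 input channels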
+ latent_model_input = torch.cat(
515
+ [latent_model_input, mask, masked_image_latents], dim=1
516
+ )
517
+ noise_pred = self.unet(
518
+ latent_model_input,
519
+ t,
520
+ encoder_hidden_states=prompt_embeds,
521
+ cross_attention_kwargs=cross_attention_kwargs,
522
+ down_block_additional_residuals=down_block_res_samples,
523
+ mid_block_additional_residual=mid_block_res_sample,
524
+ ).sample
525
+
526
+ # perform guidance
527
+ if do_classifier_free_guidance:
528
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
529
+ noise_pred = noise_pred_uncond + guidance_scale * (
530
+ noise_pred_text - noise_pred_uncond
531
+ )
532
+
533
+ # compute the previous noisy sample x_t -> x_t-1
534
+ latents = self.scheduler.step(
535
+ noise_pred, t, latents, **extra_step_kwargs
536
+ ).prev_sample
537
+
538
+ # call the callback, if provided
539
+ if i == len(timesteps) - 1 or (
540
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
541
+ ):
542
+ progress_bar.update()
543
+ if callback is not None and i % callback_steps == 0:
544
+ callback(i, t, latents)
545
+
546
+ # If we do sequential model offloading, let's offload unet and controlnet
547
+ # manually for max memory savings
548
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
549
+ self.unet.to("cpu")
550
+ self.controlnet.to("cpu")
551
+ torch.cuda.empty_cache()
552
+
553
+ if output_type == "latent":
554
+ image = latents
555
+ has_nsfw_concept = None
556
+ elif output_type == "pil":
557
+ # 8. Post-processing
558
+ image = self.decode_latents(latents)
559
+
560
+ # 9. Run safety checker
561
+ image, has_nsfw_concept = self.run_safety_checker(
562
+ image, device, prompt_embeds.dtype
563
+ )
564
+
565
+ # 10. Convert to PIL
566
+ image = self.numpy_to_pil(image)
567
+ else:
568
+ # 8. Post-processing
569
+ image = self.decode_latents(latents)
570
+
571
+ # 9. Run safety checker
572
+ image, has_nsfw_concept = self.run_safety_checker(
573
+ image, device, prompt_embeds.dtype
574
+ )
575
+
576
+ # Offload last model to CPU
577
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
578
+ self.final_offload_hook.offload()
579
+
580
+ if not return_dict:
581
+ return (image, has_nsfw_concept)
582
+
583
+ return StableDiffusionPipelineOutput(
584
+ images=image, nsfw_content_detected=has_nsfw_concept
585
+ )
lama_cleaner/model/plms_sampler.py ADDED
@@ -0,0 +1,225 @@
1
+ # From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
2
+ import torch
3
+ import numpy as np
4
+ from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
5
+ from tqdm import tqdm
6
+
7
+
8
+ class PLMSSampler(object):
9
+ def __init__(self, model, schedule="linear", **kwargs):
10
+ super().__init__()
11
+ self.model = model
12
+ self.ddpm_num_timesteps = model.num_timesteps
13
+ self.schedule = schedule
14
+
15
+ def register_buffer(self, name, attr):
16
+ setattr(self, name, attr)
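+        # NOTE: unlike torch.nn.Module.register_buffer, this just stores the tensor
+        # as a plain attribute; PLMSSampler is not an nn.Module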
17
+
18
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
19
+ if ddim_eta != 0:
20
+ raise ValueError('ddim_eta must be 0 for PLMS')
21
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
22
+ num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
23
+ alphas_cumprod = self.model.alphas_cumprod
24
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
25
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
26
+
27
+ self.register_buffer('betas', to_torch(self.model.betas))
28
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
29
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
30
+
31
+ # calculations for diffusion q(x_t | x_{t-1}) and others
32
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
33
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
34
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
35
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
36
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
37
+
38
+ # ddim sampling parameters
39
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
40
+ ddim_timesteps=self.ddim_timesteps,
41
+ eta=ddim_eta, verbose=verbose)
42
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
43
+ self.register_buffer('ddim_alphas', ddim_alphas)
44
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
45
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
46
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
47
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
48
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
49
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
50
+
51
+ @torch.no_grad()
52
+ def sample(self,
53
+ steps,
54
+ batch_size,
55
+ shape,
56
+ conditioning=None,
57
+ callback=None,
58
+ normals_sequence=None,
59
+ img_callback=None,
60
+ quantize_x0=False,
61
+ eta=0.,
62
+ mask=None,
63
+ x0=None,
64
+ temperature=1.,
65
+ noise_dropout=0.,
66
+ score_corrector=None,
67
+ corrector_kwargs=None,
68
+ verbose=False,
69
+ x_T=None,
70
+ log_every_t=100,
71
+ unconditional_guidance_scale=1.,
72
+ unconditional_conditioning=None,
73
+                 # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
74
+ **kwargs
75
+ ):
76
+ if conditioning is not None:
77
+ if isinstance(conditioning, dict):
78
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
79
+ if cbs != batch_size:
80
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
81
+ else:
82
+ if conditioning.shape[0] != batch_size:
83
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
84
+
85
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
86
+ # sampling
87
+ C, H, W = shape
88
+ size = (batch_size, C, H, W)
89
+ print(f'Data shape for PLMS sampling is {size}')
90
+
91
+ samples = self.plms_sampling(conditioning, size,
92
+ callback=callback,
93
+ img_callback=img_callback,
94
+ quantize_denoised=quantize_x0,
95
+ mask=mask, x0=x0,
96
+ ddim_use_original_steps=False,
97
+ noise_dropout=noise_dropout,
98
+ temperature=temperature,
99
+ score_corrector=score_corrector,
100
+ corrector_kwargs=corrector_kwargs,
101
+ x_T=x_T,
102
+ log_every_t=log_every_t,
103
+ unconditional_guidance_scale=unconditional_guidance_scale,
104
+ unconditional_conditioning=unconditional_conditioning,
105
+ )
106
+ return samples
107
+
108
+ @torch.no_grad()
109
+ def plms_sampling(self, cond, shape,
110
+ x_T=None, ddim_use_original_steps=False,
111
+ callback=None, timesteps=None, quantize_denoised=False,
112
+ mask=None, x0=None, img_callback=None, log_every_t=100,
113
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
114
+ unconditional_guidance_scale=1., unconditional_conditioning=None, ):
115
+ device = self.model.betas.device
116
+ b = shape[0]
117
+ if x_T is None:
118
+ img = torch.randn(shape, device=device)
119
+ else:
120
+ img = x_T
121
+
122
+ if timesteps is None:
123
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
124
+ elif timesteps is not None and not ddim_use_original_steps:
125
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
126
+ timesteps = self.ddim_timesteps[:subset_end]
127
+
128
+ time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
129
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
130
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
131
+
132
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
133
+ old_eps = []
134
+
135
+ for i, step in enumerate(iterator):
136
+ index = total_steps - i - 1
137
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
138
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
139
+
140
+ if mask is not None:
141
+ assert x0 is not None
142
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
143
+ img = img_orig * mask + (1. - mask) * img
144
+
145
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
146
+ quantize_denoised=quantize_denoised, temperature=temperature,
147
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
148
+ corrector_kwargs=corrector_kwargs,
149
+ unconditional_guidance_scale=unconditional_guidance_scale,
150
+ unconditional_conditioning=unconditional_conditioning,
151
+ old_eps=old_eps, t_next=ts_next)
152
+ img, pred_x0, e_t = outs
153
+ old_eps.append(e_t)
154
+ if len(old_eps) >= 4:
155
+ old_eps.pop(0)
156
+            if callback:
+                callback(i)
157
+            if img_callback:
+                img_callback(pred_x0, i)
158
+
159
+ return img
160
+
161
+ @torch.no_grad()
162
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
163
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
164
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
165
+ b, *_, device = *x.shape, x.device
166
+
167
+ def get_model_output(x, t):
168
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
169
+ e_t = self.model.apply_model(x, t, c)
170
+ else:
171
+ x_in = torch.cat([x] * 2)
172
+ t_in = torch.cat([t] * 2)
173
+ c_in = torch.cat([unconditional_conditioning, c])
174
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
175
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
176
+
177
+ if score_corrector is not None:
178
+ assert self.model.parameterization == "eps"
179
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
180
+
181
+ return e_t
182
+
183
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
184
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
185
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
186
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
187
+
188
+ def get_x_prev_and_pred_x0(e_t, index):
189
+ # select parameters corresponding to the currently considered timestep
190
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
191
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
192
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
193
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
194
+
195
+ # current prediction for x_0
196
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
197
+ if quantize_denoised:
198
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
199
+ # direction pointing to x_t
200
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
201
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
202
+ if noise_dropout > 0.:
203
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
204
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
205
+ return x_prev, pred_x0
206
+
207
+ e_t = get_model_output(x, t)
208
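+        # pseudo linear multistep: combine the current noise prediction with up to
+        # three previous ones via Adams-Bashforth coefficients (orders 1-4)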
+ if len(old_eps) == 0:
209
+ # Pseudo Improved Euler (2nd order)
210
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
211
+ e_t_next = get_model_output(x_prev, t_next)
212
+ e_t_prime = (e_t + e_t_next) / 2
213
+ elif len(old_eps) == 1:
214
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
215
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
216
+ elif len(old_eps) == 2:
217
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
218
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
219
+ elif len(old_eps) >= 3:
220
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
221
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
222
+
223
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
224
+
225
+ return x_prev, pred_x0, e_t
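A hedged sketch of how this sampler is typically driven (here `ldm`, `c`, and `uc` are assumptions: a loaded latent-diffusion model exposing `num_timesteps`, `betas`, `alphas_cumprod`, and `apply_model`, plus conditional/unconditional embeddings from its encoder):

```python
sampler = PLMSSampler(ldm)
samples = sampler.sample(
    steps=50,
    batch_size=1,
    shape=(4, 64, 64),  # (C, H, W) in latent space, e.g. for a 512x512 output
    conditioning=c,
    unconditional_guidance_scale=7.5,
    unconditional_conditioning=uc,
)
```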