Spaces:
Runtime error
Runtime error
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -0
- .gitignore +6 -0
- .gitmodules +6 -0
- GAUSSIAN_SPLATTING_LICENSE.md +83 -0
- LICENSE.txt +21 -0
- arguments/__init__.py +258 -0
- configs/axe.yaml +76 -0
- configs/bagel.yaml +74 -0
- configs/cat_armor.yaml +74 -0
- configs/crown.yaml +74 -0
- configs/football_helmet.yaml +75 -0
- configs/hamburger.yaml +75 -0
- configs/ts_lora.yaml +76 -0
- configs/white_hair_ironman.yaml +73 -0
- configs/zombie_joker.yaml +75 -0
- environment.yml +29 -0
- example/Donut.mp4 +3 -0
- example/boots.mp4 +3 -0
- example/durian.mp4 +3 -0
- example/pillow_huskies.mp4 +3 -0
- example/wooden_car.mp4 +3 -0
- gaussian_renderer/__init__.py +168 -0
- gaussian_renderer/network_gui.py +95 -0
- gradio_demo.py +62 -0
- guidance/perpneg_utils.py +48 -0
- guidance/sd_step.py +264 -0
- guidance/sd_utils.py +487 -0
- lora_diffusion/__init__.py +5 -0
- lora_diffusion/cli_lora_add.py +187 -0
- lora_diffusion/cli_lora_pti.py +1040 -0
- lora_diffusion/cli_pt_to_safetensors.py +85 -0
- lora_diffusion/cli_svd.py +146 -0
- lora_diffusion/dataset.py +311 -0
- lora_diffusion/lora.py +1110 -0
- lora_diffusion/lora_manager.py +144 -0
- lora_diffusion/preprocess_files.py +327 -0
- lora_diffusion/safe_open.py +68 -0
- lora_diffusion/to_ckpt_v2.py +232 -0
- lora_diffusion/utils.py +214 -0
- lora_diffusion/xformers_utils.py +70 -0
- scene/__init__.py +98 -0
- scene/cameras.py +138 -0
- scene/dataset_readers.py +466 -0
- scene/gaussian_model.py +458 -0
- train.py +553 -0
- train.sh +1 -0
- utils/camera_utils.py +98 -0
- utils/general_utils.py +141 -0
- utils/graphics_utils.py +81 -0
- utils/image_utils.py +19 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
example/boots.mp4 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
example/Donut.mp4 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
example/durian.mp4 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
example/pillow_huskies.mp4 filter=lfs diff=lfs merge=lfs -text
|
40 |
+
example/wooden_car.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pyc
|
2 |
+
.vscode
|
3 |
+
output
|
4 |
+
build
|
5 |
+
output/
|
6 |
+
point_e_model_cache/
|
.gitmodules
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "submodules/diff-gaussian-rasterization"]
|
2 |
+
path = submodules/diff-gaussian-rasterization
|
3 |
+
url = https://github.com/YixunLiang/diff-gaussian-rasterization.git
|
4 |
+
[submodule "submodules/simple-knn"]
|
5 |
+
path = submodules/simple-knn
|
6 |
+
url = https://github.com/YixunLiang/simple-knn.git
|
GAUSSIAN_SPLATTING_LICENSE.md
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Gaussian-Splatting License
|
2 |
+
===========================
|
3 |
+
|
4 |
+
**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
|
5 |
+
The *Software* is in the process of being registered with the Agence pour la Protection des
|
6 |
+
Programmes (APP).
|
7 |
+
|
8 |
+
The *Software* is still being developed by the *Licensor*.
|
9 |
+
|
10 |
+
*Licensor*'s goal is to allow the research community to use, test and evaluate
|
11 |
+
the *Software*.
|
12 |
+
|
13 |
+
## 1. Definitions
|
14 |
+
|
15 |
+
*Licensee* means any person or entity that uses the *Software* and distributes
|
16 |
+
its *Work*.
|
17 |
+
|
18 |
+
*Licensor* means the owners of the *Software*, i.e Inria and MPII
|
19 |
+
|
20 |
+
*Software* means the original work of authorship made available under this
|
21 |
+
License ie gaussian-splatting.
|
22 |
+
|
23 |
+
*Work* means the *Software* and any additions to or derivative works of the
|
24 |
+
*Software* that are made available under this License.
|
25 |
+
|
26 |
+
|
27 |
+
## 2. Purpose
|
28 |
+
This license is intended to define the rights granted to the *Licensee* by
|
29 |
+
Licensors under the *Software*.
|
30 |
+
|
31 |
+
## 3. Rights granted
|
32 |
+
|
33 |
+
For the above reasons Licensors have decided to distribute the *Software*.
|
34 |
+
Licensors grant non-exclusive rights to use the *Software* for research purposes
|
35 |
+
to research users (both academic and industrial), free of charge, without right
|
36 |
+
to sublicense.. The *Software* may be used "non-commercially", i.e., for research
|
37 |
+
and/or evaluation purposes only.
|
38 |
+
|
39 |
+
Subject to the terms and conditions of this License, you are granted a
|
40 |
+
non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
|
41 |
+
publicly display, publicly perform and distribute its *Work* and any resulting
|
42 |
+
derivative works in any form.
|
43 |
+
|
44 |
+
## 4. Limitations
|
45 |
+
|
46 |
+
**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
|
47 |
+
so under this License, (b) you include a complete copy of this License with
|
48 |
+
your distribution, and (c) you retain without modification any copyright,
|
49 |
+
patent, trademark, or attribution notices that are present in the *Work*.
|
50 |
+
|
51 |
+
**4.2 Derivative Works.** You may specify that additional or different terms apply
|
52 |
+
to the use, reproduction, and distribution of your derivative works of the *Work*
|
53 |
+
("Your Terms") only if (a) Your Terms provide that the use limitation in
|
54 |
+
Section 2 applies to your derivative works, and (b) you identify the specific
|
55 |
+
derivative works that are subject to Your Terms. Notwithstanding Your Terms,
|
56 |
+
this License (including the redistribution requirements in Section 3.1) will
|
57 |
+
continue to apply to the *Work* itself.
|
58 |
+
|
59 |
+
**4.3** Any other use without of prior consent of Licensors is prohibited. Research
|
60 |
+
users explicitly acknowledge having received from Licensors all information
|
61 |
+
allowing to appreciate the adequacy between of the *Software* and their needs and
|
62 |
+
to undertake all necessary precautions for its execution and use.
|
63 |
+
|
64 |
+
**4.4** The *Software* is provided both as a compiled library file and as source
|
65 |
+
code. In case of using the *Software* for a publication or other results obtained
|
66 |
+
through the use of the *Software*, users are strongly encouraged to cite the
|
67 |
+
corresponding publications as explained in the documentation of the *Software*.
|
68 |
+
|
69 |
+
## 5. Disclaimer
|
70 |
+
|
71 |
+
THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
|
72 |
+
WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
|
73 |
+
UNAUTHORIZED USE: [email protected] . ANY SUCH ACTION WILL
|
74 |
+
CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
|
75 |
+
OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
|
76 |
+
USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
|
77 |
+
ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
|
78 |
+
AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
79 |
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
|
80 |
+
GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
|
81 |
+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
82 |
+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
|
83 |
+
IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
|
LICENSE.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 dreamgaussian
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
arguments/__init__.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
from argparse import ArgumentParser, Namespace
|
13 |
+
import sys
|
14 |
+
import os
|
15 |
+
|
16 |
+
class GroupParams:
|
17 |
+
pass
|
18 |
+
|
19 |
+
class ParamGroup:
|
20 |
+
def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
|
21 |
+
group = parser.add_argument_group(name)
|
22 |
+
for key, value in vars(self).items():
|
23 |
+
shorthand = False
|
24 |
+
if key.startswith("_"):
|
25 |
+
shorthand = True
|
26 |
+
key = key[1:]
|
27 |
+
t = type(value)
|
28 |
+
value = value if not fill_none else None
|
29 |
+
if shorthand:
|
30 |
+
if t == bool:
|
31 |
+
group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
|
32 |
+
else:
|
33 |
+
group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
|
34 |
+
else:
|
35 |
+
if t == bool:
|
36 |
+
group.add_argument("--" + key, default=value, action="store_true")
|
37 |
+
else:
|
38 |
+
group.add_argument("--" + key, default=value, type=t)
|
39 |
+
|
40 |
+
def extract(self, args):
|
41 |
+
group = GroupParams()
|
42 |
+
for arg in vars(args).items():
|
43 |
+
if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
|
44 |
+
setattr(group, arg[0], arg[1])
|
45 |
+
return group
|
46 |
+
|
47 |
+
def load_yaml(self, opts=None):
|
48 |
+
if opts is None:
|
49 |
+
return
|
50 |
+
else:
|
51 |
+
for key, value in opts.items():
|
52 |
+
try:
|
53 |
+
setattr(self, key, value)
|
54 |
+
except:
|
55 |
+
raise Exception(f'Unknown attribute {key}')
|
56 |
+
|
57 |
+
class GuidanceParams(ParamGroup):
|
58 |
+
def __init__(self, parser, opts=None):
|
59 |
+
self.guidance = "SD"
|
60 |
+
self.g_device = "cuda"
|
61 |
+
|
62 |
+
self.model_key = None
|
63 |
+
self.is_safe_tensor = False
|
64 |
+
self.base_model_key = None
|
65 |
+
|
66 |
+
self.controlnet_model_key = None
|
67 |
+
|
68 |
+
self.perpneg = True
|
69 |
+
self.negative_w = -2.
|
70 |
+
self.front_decay_factor = 2.
|
71 |
+
self.side_decay_factor = 10.
|
72 |
+
|
73 |
+
self.vram_O = False
|
74 |
+
self.fp16 = True
|
75 |
+
self.hf_key = None
|
76 |
+
self.t_range = [0.02, 0.5]
|
77 |
+
self.max_t_range = 0.98
|
78 |
+
|
79 |
+
self.scheduler_type = 'DDIM'
|
80 |
+
self.num_train_timesteps = None
|
81 |
+
|
82 |
+
self.sds = False
|
83 |
+
self.fix_noise = False
|
84 |
+
self.noise_seed = 0
|
85 |
+
|
86 |
+
self.ddim_inv = False
|
87 |
+
self.delta_t = 80
|
88 |
+
self.delta_t_start = 100
|
89 |
+
self.annealing_intervals = True
|
90 |
+
self.text = ''
|
91 |
+
self.inverse_text = ''
|
92 |
+
self.textual_inversion_path = None
|
93 |
+
self.LoRA_path = None
|
94 |
+
self.controlnet_ratio = 0.5
|
95 |
+
self.negative = ""
|
96 |
+
self.guidance_scale = 7.5
|
97 |
+
self.denoise_guidance_scale = 1.0
|
98 |
+
self.lambda_guidance = 1.
|
99 |
+
|
100 |
+
self.xs_delta_t = 200
|
101 |
+
self.xs_inv_steps = 5
|
102 |
+
self.xs_eta = 0.0
|
103 |
+
|
104 |
+
# multi-batch
|
105 |
+
self.C_batch_size = 1
|
106 |
+
|
107 |
+
self.vis_interval = 100
|
108 |
+
|
109 |
+
super().__init__(parser, "Guidance Model Parameters")
|
110 |
+
|
111 |
+
|
112 |
+
class ModelParams(ParamGroup):
|
113 |
+
def __init__(self, parser, sentinel=False, opts=None):
|
114 |
+
self.sh_degree = 0
|
115 |
+
self._source_path = ""
|
116 |
+
self._model_path = ""
|
117 |
+
self.pretrained_model_path = None
|
118 |
+
self._images = "images"
|
119 |
+
self.workspace = "debug"
|
120 |
+
self.batch = 10
|
121 |
+
self._resolution = -1
|
122 |
+
self._white_background = True
|
123 |
+
self.data_device = "cuda"
|
124 |
+
self.eval = False
|
125 |
+
self.opt_path = None
|
126 |
+
|
127 |
+
# augmentation
|
128 |
+
self.sh_deg_aug_ratio = 0.1
|
129 |
+
self.bg_aug_ratio = 0.5
|
130 |
+
self.shs_aug_ratio = 0.0
|
131 |
+
self.scale_aug_ratio = 1.0
|
132 |
+
super().__init__(parser, "Loading Parameters", sentinel)
|
133 |
+
|
134 |
+
def extract(self, args):
|
135 |
+
g = super().extract(args)
|
136 |
+
g.source_path = os.path.abspath(g.source_path)
|
137 |
+
return g
|
138 |
+
|
139 |
+
|
140 |
+
class PipelineParams(ParamGroup):
|
141 |
+
def __init__(self, parser, opts=None):
|
142 |
+
self.convert_SHs_python = False
|
143 |
+
self.compute_cov3D_python = False
|
144 |
+
self.debug = False
|
145 |
+
super().__init__(parser, "Pipeline Parameters")
|
146 |
+
|
147 |
+
|
148 |
+
class OptimizationParams(ParamGroup):
|
149 |
+
def __init__(self, parser, opts=None):
|
150 |
+
self.iterations = 5000# 10_000
|
151 |
+
self.position_lr_init = 0.00016
|
152 |
+
self.position_lr_final = 0.0000016
|
153 |
+
self.position_lr_delay_mult = 0.01
|
154 |
+
self.position_lr_max_steps = 30_000
|
155 |
+
self.feature_lr = 0.0050
|
156 |
+
self.feature_lr_final = 0.0030
|
157 |
+
|
158 |
+
self.opacity_lr = 0.05
|
159 |
+
self.scaling_lr = 0.005
|
160 |
+
self.rotation_lr = 0.001
|
161 |
+
|
162 |
+
|
163 |
+
self.geo_iter = 0
|
164 |
+
self.as_latent_ratio = 0.2
|
165 |
+
# dense
|
166 |
+
|
167 |
+
self.resnet_lr = 1e-4
|
168 |
+
self.resnet_lr_init = 2e-3
|
169 |
+
self.resnet_lr_final = 5e-5
|
170 |
+
|
171 |
+
|
172 |
+
self.scaling_lr_final = 0.001
|
173 |
+
self.rotation_lr_final = 0.0002
|
174 |
+
|
175 |
+
self.percent_dense = 0.003
|
176 |
+
self.densify_grad_threshold = 0.00075
|
177 |
+
|
178 |
+
self.lambda_tv = 1.0 # 0.1
|
179 |
+
self.lambda_bin = 10.0
|
180 |
+
self.lambda_scale = 1.0
|
181 |
+
self.lambda_sat = 1.0
|
182 |
+
self.lambda_radius = 1.0
|
183 |
+
self.densification_interval = 100
|
184 |
+
self.opacity_reset_interval = 300
|
185 |
+
self.densify_from_iter = 100
|
186 |
+
self.densify_until_iter = 30_00
|
187 |
+
|
188 |
+
self.use_control_net_iter = 10000000
|
189 |
+
self.warmup_iter = 1500
|
190 |
+
|
191 |
+
self.use_progressive = False
|
192 |
+
self.save_process = True
|
193 |
+
self.pro_frames_num = 600
|
194 |
+
self.pro_render_45 = False
|
195 |
+
self.progressive_view_iter = 500
|
196 |
+
self.progressive_view_init_ratio = 0.2
|
197 |
+
|
198 |
+
self.scale_up_cameras_iter = 500
|
199 |
+
self.scale_up_factor = 0.95
|
200 |
+
self.fovy_scale_up_factor = [0.75, 1.1]
|
201 |
+
self.phi_scale_up_factor = 1.5
|
202 |
+
super().__init__(parser, "Optimization Parameters")
|
203 |
+
|
204 |
+
|
205 |
+
class GenerateCamParams(ParamGroup):
|
206 |
+
def __init__(self, parser):
|
207 |
+
self.init_shape = 'sphere'
|
208 |
+
self.init_prompt = ''
|
209 |
+
self.use_pointe_rgb = False
|
210 |
+
self.radius_range = [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
|
211 |
+
self.max_radius_range = [3.5, 5.0]
|
212 |
+
self.default_radius = 3.5
|
213 |
+
self.theta_range = [45, 105]
|
214 |
+
self.max_theta_range = [45, 105]
|
215 |
+
self.phi_range = [-180, 180]
|
216 |
+
self.max_phi_range = [-180, 180]
|
217 |
+
self.fovy_range = [0.32, 0.60] #[0.3, 1.5] #[0.5, 0.8] #[10, 30]
|
218 |
+
self.max_fovy_range = [0.16, 0.60]
|
219 |
+
self.rand_cam_gamma = 1.0
|
220 |
+
self.angle_overhead = 30
|
221 |
+
self.angle_front =60
|
222 |
+
self.render_45 = True
|
223 |
+
self.uniform_sphere_rate = 0
|
224 |
+
self.image_w = 512
|
225 |
+
self.image_h = 512 # 512
|
226 |
+
self.SSAA = 1
|
227 |
+
self.init_num_pts = 100_000
|
228 |
+
self.default_polar = 90
|
229 |
+
self.default_azimuth = 0
|
230 |
+
self.default_fovy = 0.55 #20
|
231 |
+
self.jitter_pose = True
|
232 |
+
self.jitter_center = 0.05
|
233 |
+
self.jitter_target = 0.05
|
234 |
+
self.jitter_up = 0.01
|
235 |
+
self.device = "cuda"
|
236 |
+
super().__init__(parser, "Generate Cameras Parameters")
|
237 |
+
|
238 |
+
def get_combined_args(parser : ArgumentParser):
|
239 |
+
cmdlne_string = sys.argv[1:]
|
240 |
+
cfgfile_string = "Namespace()"
|
241 |
+
args_cmdline = parser.parse_args(cmdlne_string)
|
242 |
+
|
243 |
+
try:
|
244 |
+
cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
|
245 |
+
print("Looking for config file in", cfgfilepath)
|
246 |
+
with open(cfgfilepath) as cfg_file:
|
247 |
+
print("Config file found: {}".format(cfgfilepath))
|
248 |
+
cfgfile_string = cfg_file.read()
|
249 |
+
except TypeError:
|
250 |
+
print("Config file not found at")
|
251 |
+
pass
|
252 |
+
args_cfgfile = eval(cfgfile_string)
|
253 |
+
|
254 |
+
merged_dict = vars(args_cfgfile).copy()
|
255 |
+
for k,v in vars(args_cmdline).items():
|
256 |
+
if v != None:
|
257 |
+
merged_dict[k] = v
|
258 |
+
return Namespace(**merged_dict)
|
configs/axe.yaml
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
port: 2355
|
2 |
+
save_video: true
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
PipelineParams:
|
6 |
+
convert_SHs_python: False #true = using direct rgb
|
7 |
+
ModelParams:
|
8 |
+
workspace: viking_axe
|
9 |
+
sh_degree: 0
|
10 |
+
bg_aug_ratio: 0.66
|
11 |
+
|
12 |
+
GuidanceParams:
|
13 |
+
model_key: 'stabilityai/stable-diffusion-2-1-base'
|
14 |
+
text: 'Viking axe, fantasy, weapon, blender, 8k, HDR.'
|
15 |
+
negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution, oversaturation.'
|
16 |
+
inverse_text: ''
|
17 |
+
perpneg: false
|
18 |
+
C_batch_size: 4
|
19 |
+
|
20 |
+
t_range: [0.02, 0.5]
|
21 |
+
max_t_range: 0.98
|
22 |
+
lambda_guidance: 0.1
|
23 |
+
guidance_scale: 7.5
|
24 |
+
denoise_guidance_scale: 1.0
|
25 |
+
noise_seed: 0
|
26 |
+
|
27 |
+
ddim_inv: true
|
28 |
+
accum: false
|
29 |
+
annealing_intervals: true
|
30 |
+
|
31 |
+
xs_delta_t: 200
|
32 |
+
xs_inv_steps: 5
|
33 |
+
xs_eta: 0.0
|
34 |
+
|
35 |
+
delta_t: 25
|
36 |
+
delta_t_start: 100
|
37 |
+
|
38 |
+
GenerateCamParams:
|
39 |
+
init_shape: 'pointe'
|
40 |
+
init_prompt: 'A flag.'
|
41 |
+
use_pointe_rgb: false
|
42 |
+
init_num_pts: 100_000
|
43 |
+
phi_range: [-180, 180]
|
44 |
+
max_phi_range: [-180, 180]
|
45 |
+
rand_cam_gamma: 1.
|
46 |
+
|
47 |
+
theta_range: [45, 105]
|
48 |
+
max_theta_range: [45, 105]
|
49 |
+
|
50 |
+
radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
|
51 |
+
max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
|
52 |
+
default_radius: 3.5
|
53 |
+
|
54 |
+
default_fovy: 0.55
|
55 |
+
fovy_range: [0.32, 0.60]
|
56 |
+
max_fovy_range: [0.16, 0.60]
|
57 |
+
|
58 |
+
OptimizationParams:
|
59 |
+
iterations: 5000
|
60 |
+
save_process: True
|
61 |
+
pro_frames_num: 600
|
62 |
+
pro_render_45: False
|
63 |
+
warmup_iter: 1500 # 2500
|
64 |
+
|
65 |
+
as_latent_ratio : 0.2
|
66 |
+
geo_iter : 0
|
67 |
+
densify_from_iter: 100
|
68 |
+
densify_until_iter: 3000
|
69 |
+
percent_dense: 0.003
|
70 |
+
densify_grad_threshold: 0.00075
|
71 |
+
progressive_view_iter: 500 #1500
|
72 |
+
opacity_reset_interval: 300 #500
|
73 |
+
|
74 |
+
scale_up_cameras_iter: 500
|
75 |
+
fovy_scale_up_factor: [0.75, 1.1]
|
76 |
+
phi_scale_up_factor: 1.5
|
configs/bagel.yaml
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
port: 2355
|
2 |
+
save_video: true
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
PipelineParams:
|
6 |
+
convert_SHs_python: False #true = using direct rgb
|
7 |
+
ModelParams:
|
8 |
+
workspace: bagel
|
9 |
+
sh_degree: 0
|
10 |
+
bg_aug_ratio: 0.66
|
11 |
+
|
12 |
+
GuidanceParams:
|
13 |
+
model_key: 'stabilityai/stable-diffusion-2-1-base'
|
14 |
+
text: 'a DSLR photo of a bagel filled with cream cheese and lox.'
|
15 |
+
negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.'
|
16 |
+
inverse_text: ''
|
17 |
+
perpneg: false
|
18 |
+
C_batch_size: 4
|
19 |
+
t_range: [0.02, 0.5]
|
20 |
+
max_t_range: 0.98
|
21 |
+
lambda_guidance: 0.1
|
22 |
+
guidance_scale: 7.5
|
23 |
+
denoise_guidance_scale: 1.0
|
24 |
+
noise_seed: 0
|
25 |
+
|
26 |
+
ddim_inv: true
|
27 |
+
annealing_intervals: true
|
28 |
+
|
29 |
+
xs_delta_t: 200
|
30 |
+
xs_inv_steps: 5
|
31 |
+
xs_eta: 0.0
|
32 |
+
|
33 |
+
delta_t: 80
|
34 |
+
delta_t_start: 100
|
35 |
+
|
36 |
+
GenerateCamParams:
|
37 |
+
init_shape: 'pointe'
|
38 |
+
init_prompt: 'a bagel.'
|
39 |
+
use_pointe_rgb: false
|
40 |
+
init_num_pts: 100_000
|
41 |
+
phi_range: [-180, 180]
|
42 |
+
max_phi_range: [-180, 180]
|
43 |
+
rand_cam_gamma: 1.
|
44 |
+
|
45 |
+
theta_range: [45, 105]
|
46 |
+
max_theta_range: [45, 105]
|
47 |
+
|
48 |
+
radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
|
49 |
+
max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
|
50 |
+
default_radius: 3.5
|
51 |
+
|
52 |
+
default_fovy: 0.55
|
53 |
+
fovy_range: [0.32, 0.60]
|
54 |
+
max_fovy_range: [0.16, 0.60]
|
55 |
+
|
56 |
+
OptimizationParams:
|
57 |
+
iterations: 5000
|
58 |
+
save_process: True
|
59 |
+
pro_frames_num: 600
|
60 |
+
pro_render_45: False
|
61 |
+
warmup_iter: 1500 # 2500
|
62 |
+
|
63 |
+
as_latent_ratio : 0.2
|
64 |
+
geo_iter : 0
|
65 |
+
densify_from_iter: 100
|
66 |
+
densify_until_iter: 3000
|
67 |
+
percent_dense: 0.003
|
68 |
+
densify_grad_threshold: 0.00075
|
69 |
+
progressive_view_iter: 500 #1500
|
70 |
+
opacity_reset_interval: 300 #500
|
71 |
+
|
72 |
+
scale_up_cameras_iter: 500
|
73 |
+
fovy_scale_up_factor: [0.75, 1.1]
|
74 |
+
phi_scale_up_factor: 1.5
|
configs/cat_armor.yaml
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
port: 2355
|
2 |
+
save_video: true
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
PipelineParams:
|
6 |
+
convert_SHs_python: False #true = using direct rgb
|
7 |
+
ModelParams:
|
8 |
+
workspace: cat_armor
|
9 |
+
sh_degree: 0
|
10 |
+
bg_aug_ratio: 0.66
|
11 |
+
|
12 |
+
GuidanceParams:
|
13 |
+
model_key: 'stabilityai/stable-diffusion-2-1-base'
|
14 |
+
text: 'a DSLR photo of a cat wearing armor.'
|
15 |
+
negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution, oversaturation.'
|
16 |
+
inverse_text: ''
|
17 |
+
perpneg: true
|
18 |
+
C_batch_size: 4
|
19 |
+
t_range: [0.02, 0.5]
|
20 |
+
max_t_range: 0.98
|
21 |
+
lambda_guidance: 0.1
|
22 |
+
guidance_scale: 7.5
|
23 |
+
denoise_guidance_scale: 1.0
|
24 |
+
noise_seed: 0
|
25 |
+
|
26 |
+
ddim_inv: true
|
27 |
+
annealing_intervals: true
|
28 |
+
|
29 |
+
xs_delta_t: 200
|
30 |
+
xs_inv_steps: 5
|
31 |
+
xs_eta: 0.0
|
32 |
+
|
33 |
+
delta_t: 80
|
34 |
+
delta_t_start: 100
|
35 |
+
|
36 |
+
GenerateCamParams:
|
37 |
+
init_shape: 'pointe'
|
38 |
+
init_prompt: 'a cat.'
|
39 |
+
use_pointe_rgb: false
|
40 |
+
init_num_pts: 100_000
|
41 |
+
phi_range: [-180, 180]
|
42 |
+
max_phi_range: [-180, 180]
|
43 |
+
rand_cam_gamma: 1.5
|
44 |
+
|
45 |
+
theta_range: [60, 90]
|
46 |
+
max_theta_range: [60, 90]
|
47 |
+
|
48 |
+
radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
|
49 |
+
max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
|
50 |
+
default_radius: 3.5
|
51 |
+
|
52 |
+
default_fovy: 0.55
|
53 |
+
fovy_range: [0.32, 0.60]
|
54 |
+
max_fovy_range: [0.16, 0.60]
|
55 |
+
|
56 |
+
OptimizationParams:
|
57 |
+
iterations: 5000
|
58 |
+
save_process: True
|
59 |
+
pro_frames_num: 600
|
60 |
+
pro_render_45: False
|
61 |
+
warmup_iter: 1500 # 2500
|
62 |
+
|
63 |
+
as_latent_ratio : 0.2
|
64 |
+
geo_iter : 0
|
65 |
+
densify_from_iter: 100
|
66 |
+
densify_until_iter: 3000
|
67 |
+
percent_dense: 0.003
|
68 |
+
densify_grad_threshold: 0.00075
|
69 |
+
progressive_view_iter: 500 #1500
|
70 |
+
opacity_reset_interval: 300 #500
|
71 |
+
|
72 |
+
scale_up_cameras_iter: 500
|
73 |
+
fovy_scale_up_factor: [0.75, 1.1]
|
74 |
+
phi_scale_up_factor: 1.5
|
configs/crown.yaml
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
port: 2355
|
2 |
+
save_video: true
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
PipelineParams:
|
6 |
+
convert_SHs_python: False #true = using direct rgb
|
7 |
+
ModelParams:
|
8 |
+
workspace: crown
|
9 |
+
sh_degree: 0
|
10 |
+
bg_aug_ratio: 0.66
|
11 |
+
|
12 |
+
GuidanceParams:
|
13 |
+
model_key: 'stabilityai/stable-diffusion-2-1-base'
|
14 |
+
text: 'a DSLR photo of the Imperial State Crown of England.'
|
15 |
+
negative: 'unrealistic, blurry, low quality.'
|
16 |
+
inverse_text: ''
|
17 |
+
perpneg: false
|
18 |
+
C_batch_size: 4
|
19 |
+
t_range: [0.02, 0.5]
|
20 |
+
max_t_range: 0.98
|
21 |
+
lambda_guidance: 0.1
|
22 |
+
guidance_scale: 7.5
|
23 |
+
denoise_guidance_scale: 1.0
|
24 |
+
noise_seed: 0
|
25 |
+
|
26 |
+
ddim_inv: true
|
27 |
+
annealing_intervals: true
|
28 |
+
|
29 |
+
xs_delta_t: 200
|
30 |
+
xs_inv_steps: 5
|
31 |
+
xs_eta: 0.0
|
32 |
+
|
33 |
+
delta_t: 80
|
34 |
+
delta_t_start: 100
|
35 |
+
|
36 |
+
GenerateCamParams:
|
37 |
+
init_shape: 'pointe'
|
38 |
+
init_prompt: 'the Imperial State Crown of England.'
|
39 |
+
use_pointe_rgb: false
|
40 |
+
init_num_pts: 100_000
|
41 |
+
phi_range: [-180, 180]
|
42 |
+
max_phi_range: [-180, 180]
|
43 |
+
rand_cam_gamma: 1.
|
44 |
+
|
45 |
+
theta_range: [45, 105]
|
46 |
+
max_theta_range: [45, 105]
|
47 |
+
|
48 |
+
radius_range: [5.2, 5.5]
|
49 |
+
max_radius_range: [3.5, 5.0]
|
50 |
+
default_radius: 3.5
|
51 |
+
|
52 |
+
default_fovy: 0.55
|
53 |
+
fovy_range: [0.32, 0.60]
|
54 |
+
max_fovy_range: [0.16, 0.60]
|
55 |
+
|
56 |
+
OptimizationParams:
|
57 |
+
iterations: 5000
|
58 |
+
save_process: True
|
59 |
+
pro_frames_num: 600
|
60 |
+
pro_render_45: False
|
61 |
+
warmup_iter: 1500 # 2500
|
62 |
+
|
63 |
+
as_latent_ratio : 0.2
|
64 |
+
geo_iter : 0
|
65 |
+
densify_from_iter: 100
|
66 |
+
densify_until_iter: 3000
|
67 |
+
percent_dense: 0.003
|
68 |
+
densify_grad_threshold: 0.00075
|
69 |
+
progressive_view_iter: 500
|
70 |
+
opacity_reset_interval: 300
|
71 |
+
|
72 |
+
scale_up_cameras_iter: 500
|
73 |
+
fovy_scale_up_factor: [0.75, 1.1]
|
74 |
+
phi_scale_up_factor: 1.5
|
configs/football_helmet.yaml
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
port: 2355
|
2 |
+
save_video: true
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
PipelineParams:
|
6 |
+
convert_SHs_python: False #true = using direct rgb
|
7 |
+
ModelParams:
|
8 |
+
workspace: football_helmet
|
9 |
+
sh_degree: 0
|
10 |
+
bg_aug_ratio: 0.66
|
11 |
+
|
12 |
+
GuidanceParams:
|
13 |
+
model_key: 'stabilityai/stable-diffusion-2-1-base'
|
14 |
+
text: 'a DSLR photo of a football helmet.'
|
15 |
+
negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution, oversaturation.'
|
16 |
+
inverse_text: ''
|
17 |
+
perpneg: false
|
18 |
+
C_batch_size: 4
|
19 |
+
t_range: [0.02, 0.5]
|
20 |
+
max_t_range: 0.98
|
21 |
+
lambda_guidance: 0.1
|
22 |
+
guidance_scale: 7.5
|
23 |
+
denoise_guidance_scale: 1.0
|
24 |
+
|
25 |
+
noise_seed: 0
|
26 |
+
|
27 |
+
ddim_inv: true
|
28 |
+
accum: false
|
29 |
+
annealing_intervals: true
|
30 |
+
|
31 |
+
xs_delta_t: 200
|
32 |
+
xs_inv_steps: 5
|
33 |
+
xs_eta: 0.0
|
34 |
+
|
35 |
+
delta_t: 50
|
36 |
+
delta_t_start: 100
|
37 |
+
|
38 |
+
GenerateCamParams:
|
39 |
+
init_shape: 'pointe'
|
40 |
+
init_prompt: 'a football helmet.'
|
41 |
+
use_pointe_rgb: false
|
42 |
+
init_num_pts: 100_000
|
43 |
+
phi_range: [-180, 180]
|
44 |
+
max_phi_range: [-180, 180]
|
45 |
+
|
46 |
+
theta_range: [45, 90]
|
47 |
+
max_theta_range: [45, 90]
|
48 |
+
|
49 |
+
radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
|
50 |
+
max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
|
51 |
+
default_radius: 3.5
|
52 |
+
|
53 |
+
default_fovy: 0.55
|
54 |
+
fovy_range: [0.32, 0.60]
|
55 |
+
max_fovy_range: [0.16, 0.60]
|
56 |
+
|
57 |
+
OptimizationParams:
|
58 |
+
iterations: 5000
|
59 |
+
save_process: True
|
60 |
+
pro_frames_num: 600
|
61 |
+
pro_render_45: False
|
62 |
+
warmup_iter: 1500 # 2500
|
63 |
+
|
64 |
+
as_latent_ratio : 0.2
|
65 |
+
geo_iter : 0
|
66 |
+
densify_from_iter: 100
|
67 |
+
densify_until_iter: 3000
|
68 |
+
percent_dense: 0.003
|
69 |
+
densify_grad_threshold: 0.00075
|
70 |
+
progressive_view_iter: 500 #1500
|
71 |
+
opacity_reset_interval: 300 #500
|
72 |
+
|
73 |
+
scale_up_cameras_iter: 500
|
74 |
+
fovy_scale_up_factor: [0.75, 1.1]
|
75 |
+
phi_scale_up_factor: 1.5
|
configs/hamburger.yaml
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
port: 2355
|
2 |
+
save_video: true
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
PipelineParams:
|
6 |
+
convert_SHs_python: False #true = using direct rgb
|
7 |
+
ModelParams:
|
8 |
+
workspace: hamburger
|
9 |
+
sh_degree: 0
|
10 |
+
bg_aug_ratio: 0.66
|
11 |
+
|
12 |
+
GuidanceParams:
|
13 |
+
model_key: 'stabilityai/stable-diffusion-2-1-base'
|
14 |
+
text: 'A delicious hamburger.'
|
15 |
+
negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.'
|
16 |
+
inverse_text: ''
|
17 |
+
perpneg: false
|
18 |
+
C_batch_size: 4
|
19 |
+
t_range: [0.02, 0.5]
|
20 |
+
max_t_range: 0.98
|
21 |
+
lambda_guidance: 0.1
|
22 |
+
guidance_scale: 7.5
|
23 |
+
denoise_guidance_scale: 1.0
|
24 |
+
|
25 |
+
noise_seed: 0
|
26 |
+
|
27 |
+
ddim_inv: true
|
28 |
+
annealing_intervals: true
|
29 |
+
|
30 |
+
xs_delta_t: 200
|
31 |
+
xs_inv_steps: 5
|
32 |
+
xs_eta: 0.0
|
33 |
+
|
34 |
+
delta_t: 50
|
35 |
+
delta_t_start: 100
|
36 |
+
|
37 |
+
GenerateCamParams:
|
38 |
+
init_shape: 'sphere'
|
39 |
+
init_prompt: '.'
|
40 |
+
use_pointe_rgb: false
|
41 |
+
init_num_pts: 100_000
|
42 |
+
phi_range: [-180, 180]
|
43 |
+
max_phi_range: [-180, 180]
|
44 |
+
rand_cam_gamma: 1.
|
45 |
+
|
46 |
+
theta_range: [45, 105]
|
47 |
+
max_theta_range: [45, 105]
|
48 |
+
|
49 |
+
radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
|
50 |
+
max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
|
51 |
+
default_radius: 3.5
|
52 |
+
|
53 |
+
default_fovy: 0.55
|
54 |
+
fovy_range: [0.32, 0.60]
|
55 |
+
max_fovy_range: [0.16, 0.60]
|
56 |
+
|
57 |
+
OptimizationParams:
|
58 |
+
iterations: 5000
|
59 |
+
save_process: True
|
60 |
+
pro_frames_num: 600
|
61 |
+
pro_render_45: False
|
62 |
+
warmup_iter: 1500 # 2500
|
63 |
+
|
64 |
+
as_latent_ratio : 0.2
|
65 |
+
geo_iter : 0
|
66 |
+
densify_from_iter: 100
|
67 |
+
densify_until_iter: 3000
|
68 |
+
percent_dense: 0.003
|
69 |
+
densify_grad_threshold: 0.00075
|
70 |
+
progressive_view_iter: 500 #1500
|
71 |
+
opacity_reset_interval: 300 #500
|
72 |
+
|
73 |
+
scale_up_cameras_iter: 500
|
74 |
+
fovy_scale_up_factor: [0.75, 1.1]
|
75 |
+
phi_scale_up_factor: 1.5
|
configs/ts_lora.yaml
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
port: 2355
|
2 |
+
save_video: true
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
PipelineParams:
|
6 |
+
convert_SHs_python: False #true = using direct rgb
|
7 |
+
ModelParams:
|
8 |
+
workspace: TS_lora
|
9 |
+
sh_degree: 0
|
10 |
+
bg_aug_ratio: 0.66
|
11 |
+
|
12 |
+
GuidanceParams:
|
13 |
+
model_key: 'stabilityai/stable-diffusion-2-1-base'
|
14 |
+
text: 'A <Taylor_Swift> wearing sunglasses.'
|
15 |
+
LoRA_path: "./custom_example/lora/Taylor_Swift/step_inv_1000.safetensors"
|
16 |
+
negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.'
|
17 |
+
inverse_text: ''
|
18 |
+
perpneg: true
|
19 |
+
C_batch_size: 4
|
20 |
+
t_range: [0.02, 0.5]
|
21 |
+
max_t_range: 0.98
|
22 |
+
lambda_guidance: 0.1
|
23 |
+
guidance_scale: 7.5
|
24 |
+
denoise_guidance_scale: 1.0
|
25 |
+
|
26 |
+
noise_seed: 0
|
27 |
+
|
28 |
+
ddim_inv: true
|
29 |
+
annealing_intervals: true
|
30 |
+
|
31 |
+
xs_delta_t: 200
|
32 |
+
xs_inv_steps: 5
|
33 |
+
xs_eta: 0.0
|
34 |
+
|
35 |
+
delta_t: 80
|
36 |
+
delta_t_start: 100
|
37 |
+
|
38 |
+
GenerateCamParams:
|
39 |
+
init_shape: 'pointe'
|
40 |
+
init_prompt: 'a girl head.'
|
41 |
+
use_pointe_rgb: false
|
42 |
+
init_num_pts: 100_000
|
43 |
+
phi_range: [-80, 80]
|
44 |
+
max_phi_range: [-180, 180]
|
45 |
+
rand_cam_gamma: 1.5
|
46 |
+
|
47 |
+
theta_range: [60, 120]
|
48 |
+
max_theta_range: [60, 120]
|
49 |
+
|
50 |
+
radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
|
51 |
+
max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
|
52 |
+
default_radius: 3.5
|
53 |
+
|
54 |
+
default_fovy: 0.55
|
55 |
+
fovy_range: [0.32, 0.60]
|
56 |
+
max_fovy_range: [0.16, 0.60]
|
57 |
+
|
58 |
+
OptimizationParams:
|
59 |
+
iterations: 5000
|
60 |
+
save_process: True
|
61 |
+
pro_frames_num: 600
|
62 |
+
pro_render_45: False
|
63 |
+
warmup_iter: 1500 # 2500
|
64 |
+
|
65 |
+
as_latent_ratio : 0.2
|
66 |
+
geo_iter : 0
|
67 |
+
densify_from_iter: 100
|
68 |
+
densify_until_iter: 3000
|
69 |
+
percent_dense: 0.003
|
70 |
+
densify_grad_threshold: 0.00075
|
71 |
+
progressive_view_iter: 500 #1500
|
72 |
+
opacity_reset_interval: 300 #500
|
73 |
+
|
74 |
+
scale_up_cameras_iter: 500
|
75 |
+
fovy_scale_up_factor: [0.75, 1.1]
|
76 |
+
phi_scale_up_factor: 1.5
|
configs/white_hair_ironman.yaml
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
port: 2355
|
2 |
+
save_video: true
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
PipelineParams:
|
6 |
+
convert_SHs_python: False #true = using direct rgb
|
7 |
+
ModelParams:
|
8 |
+
workspace: white_hair_IRONMAN
|
9 |
+
sh_degree: 0
|
10 |
+
bg_aug_ratio: 0.66
|
11 |
+
|
12 |
+
GuidanceParams:
|
13 |
+
model_key: 'stabilityai/stable-diffusion-2-1-base'
|
14 |
+
text: 'A portrait of IRONMAN, white hair, head, photorealistic, 8K, HDR.'
|
15 |
+
negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution.'
|
16 |
+
inverse_text: ''
|
17 |
+
perpneg: true
|
18 |
+
C_batch_size: 4
|
19 |
+
max_t_range: 0.98
|
20 |
+
lambda_guidance: 0.1
|
21 |
+
guidance_scale: 7.5
|
22 |
+
denoise_guidance_scale: 1.0
|
23 |
+
noise_seed: 0
|
24 |
+
|
25 |
+
ddim_inv: true
|
26 |
+
annealing_intervals: true
|
27 |
+
|
28 |
+
xs_delta_t: 200
|
29 |
+
xs_inv_steps: 5
|
30 |
+
xs_eta: 0.0
|
31 |
+
|
32 |
+
delta_t: 50
|
33 |
+
delta_t_start: 100
|
34 |
+
|
35 |
+
GenerateCamParams:
|
36 |
+
init_shape: 'pointe'
|
37 |
+
init_prompt: 'a man head.'
|
38 |
+
use_pointe_rgb: false
|
39 |
+
init_num_pts: 100_000
|
40 |
+
phi_range: [-80, 80]
|
41 |
+
max_phi_range: [-180, 180]
|
42 |
+
rand_cam_gamma: 1.5
|
43 |
+
|
44 |
+
theta_range: [45, 90]
|
45 |
+
max_theta_range: [45, 90]
|
46 |
+
|
47 |
+
radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
|
48 |
+
max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
|
49 |
+
default_radius: 3.5
|
50 |
+
|
51 |
+
default_fovy: 0.55
|
52 |
+
fovy_range: [0.32, 0.60]
|
53 |
+
max_fovy_range: [0.16, 0.60]
|
54 |
+
|
55 |
+
OptimizationParams:
|
56 |
+
iterations: 5000
|
57 |
+
save_process: True
|
58 |
+
pro_frames_num: 600
|
59 |
+
pro_render_45: False
|
60 |
+
warmup_iter: 1500 # 2500
|
61 |
+
|
62 |
+
as_latent_ratio : 0.2
|
63 |
+
geo_iter : 0
|
64 |
+
densify_from_iter: 100
|
65 |
+
densify_until_iter: 3000
|
66 |
+
percent_dense: 0.003
|
67 |
+
densify_grad_threshold: 0.00075
|
68 |
+
progressive_view_iter: 500 #1500
|
69 |
+
opacity_reset_interval: 300 #500
|
70 |
+
|
71 |
+
scale_up_cameras_iter: 500
|
72 |
+
fovy_scale_up_factor: [0.75, 1.1]
|
73 |
+
phi_scale_up_factor: 1.5
|
configs/zombie_joker.yaml
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
port: 2355
|
2 |
+
save_video: true
|
3 |
+
seed: 0
|
4 |
+
|
5 |
+
PipelineParams:
|
6 |
+
convert_SHs_python: False #true = using direct rgb
|
7 |
+
ModelParams:
|
8 |
+
workspace: zombie_joker
|
9 |
+
sh_degree: 0
|
10 |
+
bg_aug_ratio: 0.66
|
11 |
+
|
12 |
+
GuidanceParams:
|
13 |
+
model_key: 'stabilityai/stable-diffusion-2-1-base'
|
14 |
+
text: 'Zombie JOKER, head, photorealistic, 8K, HDR.'
|
15 |
+
negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.'
|
16 |
+
inverse_text: ''
|
17 |
+
perpneg: true
|
18 |
+
C_batch_size: 4
|
19 |
+
|
20 |
+
t_range: [0.02, 0.5]
|
21 |
+
max_t_range: 0.98
|
22 |
+
lambda_guidance: 0.1
|
23 |
+
guidance_scale: 7.5
|
24 |
+
denoise_guidance_scale: 1.0
|
25 |
+
noise_seed: 0
|
26 |
+
|
27 |
+
ddim_inv: true
|
28 |
+
annealing_intervals: true
|
29 |
+
|
30 |
+
xs_delta_t: 200
|
31 |
+
xs_inv_steps: 5
|
32 |
+
xs_eta: 0.0
|
33 |
+
|
34 |
+
delta_t: 50
|
35 |
+
delta_t_start: 100
|
36 |
+
|
37 |
+
GenerateCamParams:
|
38 |
+
init_shape: 'pointe'
|
39 |
+
init_prompt: 'a man head.'
|
40 |
+
use_pointe_rgb: false
|
41 |
+
init_num_pts: 100_000
|
42 |
+
phi_range: [-80, 80]
|
43 |
+
max_phi_range: [-180, 180]
|
44 |
+
rand_cam_gamma: 1.5
|
45 |
+
|
46 |
+
theta_range: [45, 90]
|
47 |
+
max_theta_range: [45, 90]
|
48 |
+
|
49 |
+
radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
|
50 |
+
max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
|
51 |
+
default_radius: 3.5
|
52 |
+
|
53 |
+
default_fovy: 0.55
|
54 |
+
fovy_range: [0.32, 0.60]
|
55 |
+
max_fovy_range: [0.16, 0.60]
|
56 |
+
|
57 |
+
OptimizationParams:
|
58 |
+
iterations: 5000
|
59 |
+
save_process: True
|
60 |
+
pro_frames_num: 600
|
61 |
+
pro_render_45: False
|
62 |
+
warmup_iter: 1500 # 2500
|
63 |
+
|
64 |
+
as_latent_ratio : 0.2
|
65 |
+
geo_iter : 0
|
66 |
+
densify_from_iter: 100
|
67 |
+
densify_until_iter: 3000
|
68 |
+
percent_dense: 0.003
|
69 |
+
densify_grad_threshold: 0.00075
|
70 |
+
progressive_view_iter: 500 #1500
|
71 |
+
opacity_reset_interval: 300 #500
|
72 |
+
|
73 |
+
scale_up_cameras_iter: 500
|
74 |
+
fovy_scale_up_factor: [0.75, 1.1]
|
75 |
+
phi_scale_up_factor: 1.5
|
environment.yml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: LucidDreamer
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- conda-forge
|
5 |
+
- defaults
|
6 |
+
dependencies:
|
7 |
+
- cudatoolkit=11.6
|
8 |
+
- plyfile=0.8.1
|
9 |
+
- python=3.9
|
10 |
+
- pip=22.3.1
|
11 |
+
- pytorch=1.12.1
|
12 |
+
- torchaudio=0.12.1
|
13 |
+
- torchvision=0.15.2
|
14 |
+
- tqdm
|
15 |
+
- pip:
|
16 |
+
- mediapipe
|
17 |
+
- Pillow
|
18 |
+
- diffusers==0.18.2
|
19 |
+
- xformers==0.0.20
|
20 |
+
- transformers==4.30.2
|
21 |
+
- fire==0.5.0
|
22 |
+
- huggingface_hub==0.16.4
|
23 |
+
- imageio==2.31.1
|
24 |
+
- imageio-ffmpeg
|
25 |
+
- PyYAML
|
26 |
+
- safetensors
|
27 |
+
- wandb
|
28 |
+
- accelerate
|
29 |
+
- triton
|
example/Donut.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4633e31ff1ff161e0bd015c166c507cad140e14aef616eecef95c32da5dd1902
|
3 |
+
size 2264633
|
example/boots.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2f117d721a095ae913d17072ee5ed4373c95f1a8851ca6e9e254bf5efeaf56cb
|
3 |
+
size 5358683
|
example/durian.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:da35c90e1212627da08180fcb513a0d402dc189c613c00d18a3b3937c992b47d
|
3 |
+
size 9316285
|
example/pillow_huskies.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fc53845fdca59e413765833aed51a9a93e2962719a63f4471e9fa7943e217cf6
|
3 |
+
size 3586741
|
example/wooden_car.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c4e3b6b31a1d2c9e3791c4c2c7b278d1eb3ae209c94b5d5f3834b4ea5d6d3c16
|
3 |
+
size 1660564
|
gaussian_renderer/__init__.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import math
|
14 |
+
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
|
15 |
+
from scene.gaussian_model import GaussianModel
|
16 |
+
from utils.sh_utils import eval_sh, SH2RGB
|
17 |
+
from utils.graphics_utils import fov2focal
|
18 |
+
import random
|
19 |
+
|
20 |
+
|
21 |
+
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, black_video = False,
|
22 |
+
override_color = None, sh_deg_aug_ratio = 0.1, bg_aug_ratio = 0.3, shs_aug_ratio=1.0, scale_aug_ratio=1.0, test = False):
|
23 |
+
"""
|
24 |
+
Render the scene.
|
25 |
+
|
26 |
+
Background tensor (bg_color) must be on GPU!
|
27 |
+
"""
|
28 |
+
|
29 |
+
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
|
30 |
+
screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
|
31 |
+
try:
|
32 |
+
screenspace_points.retain_grad()
|
33 |
+
except:
|
34 |
+
pass
|
35 |
+
|
36 |
+
if black_video:
|
37 |
+
bg_color = torch.zeros_like(bg_color)
|
38 |
+
#Aug
|
39 |
+
if random.random() < sh_deg_aug_ratio and not test:
|
40 |
+
act_SH = 0
|
41 |
+
else:
|
42 |
+
act_SH = pc.active_sh_degree
|
43 |
+
|
44 |
+
if random.random() < bg_aug_ratio and not test:
|
45 |
+
if random.random() < 0.5:
|
46 |
+
bg_color = torch.rand_like(bg_color)
|
47 |
+
else:
|
48 |
+
bg_color = torch.zeros_like(bg_color)
|
49 |
+
# bg_color = torch.zeros_like(bg_color)
|
50 |
+
|
51 |
+
#bg_color = torch.zeros_like(bg_color)
|
52 |
+
# Set up rasterization configuration
|
53 |
+
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
|
54 |
+
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
|
55 |
+
try:
|
56 |
+
raster_settings = GaussianRasterizationSettings(
|
57 |
+
image_height=int(viewpoint_camera.image_height),
|
58 |
+
image_width=int(viewpoint_camera.image_width),
|
59 |
+
tanfovx=tanfovx,
|
60 |
+
tanfovy=tanfovy,
|
61 |
+
bg=bg_color,
|
62 |
+
scale_modifier=scaling_modifier,
|
63 |
+
viewmatrix=viewpoint_camera.world_view_transform,
|
64 |
+
projmatrix=viewpoint_camera.full_proj_transform,
|
65 |
+
sh_degree=act_SH,
|
66 |
+
campos=viewpoint_camera.camera_center,
|
67 |
+
prefiltered=False
|
68 |
+
)
|
69 |
+
except TypeError as e:
|
70 |
+
raster_settings = GaussianRasterizationSettings(
|
71 |
+
image_height=int(viewpoint_camera.image_height),
|
72 |
+
image_width=int(viewpoint_camera.image_width),
|
73 |
+
tanfovx=tanfovx,
|
74 |
+
tanfovy=tanfovy,
|
75 |
+
bg=bg_color,
|
76 |
+
scale_modifier=scaling_modifier,
|
77 |
+
viewmatrix=viewpoint_camera.world_view_transform,
|
78 |
+
projmatrix=viewpoint_camera.full_proj_transform,
|
79 |
+
sh_degree=act_SH,
|
80 |
+
campos=viewpoint_camera.camera_center,
|
81 |
+
prefiltered=False,
|
82 |
+
debug=False
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
|
87 |
+
|
88 |
+
means3D = pc.get_xyz
|
89 |
+
means2D = screenspace_points
|
90 |
+
opacity = pc.get_opacity
|
91 |
+
|
92 |
+
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
|
93 |
+
# scaling / rotation by the rasterizer.
|
94 |
+
scales = None
|
95 |
+
rotations = None
|
96 |
+
cov3D_precomp = None
|
97 |
+
if pipe.compute_cov3D_python:
|
98 |
+
cov3D_precomp = pc.get_covariance(scaling_modifier)
|
99 |
+
else:
|
100 |
+
scales = pc.get_scaling
|
101 |
+
rotations = pc.get_rotation
|
102 |
+
|
103 |
+
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
|
104 |
+
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
|
105 |
+
shs = None
|
106 |
+
colors_precomp = None
|
107 |
+
if colors_precomp is None:
|
108 |
+
if pipe.convert_SHs_python:
|
109 |
+
raw_rgb = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2).squeeze()[:,:3]
|
110 |
+
rgb = torch.sigmoid(raw_rgb)
|
111 |
+
colors_precomp = rgb
|
112 |
+
else:
|
113 |
+
shs = pc.get_features
|
114 |
+
else:
|
115 |
+
colors_precomp = override_color
|
116 |
+
|
117 |
+
if random.random() < shs_aug_ratio and not test:
|
118 |
+
variance = (0.2 ** 0.5) * shs
|
119 |
+
shs = shs + (torch.randn_like(shs) * variance)
|
120 |
+
|
121 |
+
# add noise to scales
|
122 |
+
if random.random() < scale_aug_ratio and not test:
|
123 |
+
variance = (0.2 ** 0.5) * scales / 4
|
124 |
+
scales = torch.clamp(scales + (torch.randn_like(scales) * variance), 0.0)
|
125 |
+
|
126 |
+
# Rasterize visible Gaussians to image, obtain their radii (on screen).
|
127 |
+
|
128 |
+
rendered_image, radii, depth_alpha = rasterizer(
|
129 |
+
means3D = means3D,
|
130 |
+
means2D = means2D,
|
131 |
+
shs = shs,
|
132 |
+
colors_precomp = colors_precomp,
|
133 |
+
opacities = opacity,
|
134 |
+
scales = scales,
|
135 |
+
rotations = rotations,
|
136 |
+
cov3D_precomp = cov3D_precomp)
|
137 |
+
depth, alpha = torch.chunk(depth_alpha, 2)
|
138 |
+
# bg_train = pc.get_background
|
139 |
+
# rendered_image = bg_train*alpha.repeat(3,1,1) + rendered_image
|
140 |
+
# focal = 1 / (2 * math.tan(viewpoint_camera.FoVx / 2)) #torch.tan(torch.tensor(viewpoint_camera.FoVx) / 2) * (2. / 2
|
141 |
+
# disparity = focal / (depth + 1e-9)
|
142 |
+
# max_disp = torch.max(disparity)
|
143 |
+
# min_disp = torch.min(disparity[disparity > 0])
|
144 |
+
# norm_disparity = (disparity - min_disp) / (max_disp - min_disp)
|
145 |
+
# # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
|
146 |
+
# # They will be excluded from value updates used in the splitting criteria.
|
147 |
+
# return {"render": rendered_image,
|
148 |
+
# "depth": norm_disparity,
|
149 |
+
|
150 |
+
focal = 1 / (2 * math.tan(viewpoint_camera.FoVx / 2))
|
151 |
+
disp = focal / (depth + (alpha * 10) + 1e-5)
|
152 |
+
|
153 |
+
try:
|
154 |
+
min_d = disp[alpha <= 0.1].min()
|
155 |
+
except Exception:
|
156 |
+
min_d = disp.min()
|
157 |
+
|
158 |
+
disp = torch.clamp((disp - min_d) / (disp.max() - min_d), 0.0, 1.0)
|
159 |
+
|
160 |
+
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
|
161 |
+
# They will be excluded from value updates used in the splitting criteria.
|
162 |
+
return {"render": rendered_image,
|
163 |
+
"depth": disp,
|
164 |
+
"alpha": alpha,
|
165 |
+
"viewspace_points": screenspace_points,
|
166 |
+
"visibility_filter" : radii > 0,
|
167 |
+
"radii": radii,
|
168 |
+
"scales": scales}
|
gaussian_renderer/network_gui.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import traceback
|
14 |
+
import socket
|
15 |
+
import json
|
16 |
+
from scene.cameras import MiniCam
|
17 |
+
|
18 |
+
host = "127.0.0.1"
|
19 |
+
port = 6009
|
20 |
+
|
21 |
+
conn = None
|
22 |
+
addr = None
|
23 |
+
|
24 |
+
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
25 |
+
|
26 |
+
def init(wish_host, wish_port):
|
27 |
+
global host, port, listener
|
28 |
+
host = wish_host
|
29 |
+
port = wish_port
|
30 |
+
cnt = 0
|
31 |
+
while True:
|
32 |
+
try:
|
33 |
+
listener.bind((host, port))
|
34 |
+
break
|
35 |
+
except:
|
36 |
+
if cnt == 10:
|
37 |
+
break
|
38 |
+
cnt += 1
|
39 |
+
port += 1
|
40 |
+
listener.listen()
|
41 |
+
listener.settimeout(0)
|
42 |
+
|
43 |
+
def try_connect():
|
44 |
+
global conn, addr, listener
|
45 |
+
try:
|
46 |
+
conn, addr = listener.accept()
|
47 |
+
print(f"\nConnected by {addr}")
|
48 |
+
conn.settimeout(None)
|
49 |
+
except Exception as inst:
|
50 |
+
pass
|
51 |
+
|
52 |
+
def read():
|
53 |
+
global conn
|
54 |
+
messageLength = conn.recv(4)
|
55 |
+
messageLength = int.from_bytes(messageLength, 'little')
|
56 |
+
message = conn.recv(messageLength)
|
57 |
+
return json.loads(message.decode("utf-8"))
|
58 |
+
|
59 |
+
def send(message_bytes, verify):
|
60 |
+
global conn
|
61 |
+
if message_bytes != None:
|
62 |
+
conn.sendall(message_bytes)
|
63 |
+
conn.sendall(len(verify).to_bytes(4, 'little'))
|
64 |
+
conn.sendall(bytes(verify, 'ascii'))
|
65 |
+
|
66 |
+
def receive():
|
67 |
+
message = read()
|
68 |
+
|
69 |
+
width = message["resolution_x"]
|
70 |
+
height = message["resolution_y"]
|
71 |
+
|
72 |
+
if width != 0 and height != 0:
|
73 |
+
try:
|
74 |
+
do_training = bool(message["train"])
|
75 |
+
fovy = message["fov_y"]
|
76 |
+
fovx = message["fov_x"]
|
77 |
+
znear = message["z_near"]
|
78 |
+
zfar = message["z_far"]
|
79 |
+
do_shs_python = bool(message["shs_python"])
|
80 |
+
do_rot_scale_python = bool(message["rot_scale_python"])
|
81 |
+
keep_alive = bool(message["keep_alive"])
|
82 |
+
scaling_modifier = message["scaling_modifier"]
|
83 |
+
world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
|
84 |
+
world_view_transform[:,1] = -world_view_transform[:,1]
|
85 |
+
world_view_transform[:,2] = -world_view_transform[:,2]
|
86 |
+
full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
|
87 |
+
full_proj_transform[:,1] = -full_proj_transform[:,1]
|
88 |
+
custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
|
89 |
+
except Exception as e:
|
90 |
+
print("")
|
91 |
+
traceback.print_exc()
|
92 |
+
raise e
|
93 |
+
return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
|
94 |
+
else:
|
95 |
+
return None, None, None, None, None, None
|
gradio_demo.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from train import *
|
4 |
+
|
5 |
+
example_inputs = [[
|
6 |
+
"A DSLR photo of a Rugged, vintage-inspired hiking boots with a weathered leather finish, best quality, 4K, HD.",
|
7 |
+
"Rugged, vintage-inspired hiking boots with a weathered leather finish."
|
8 |
+
], [
|
9 |
+
"a DSLR photo of a Cream Cheese Donut.",
|
10 |
+
"a Donut."
|
11 |
+
], [
|
12 |
+
"A durian, 8k, HDR.",
|
13 |
+
"A durian"
|
14 |
+
], [
|
15 |
+
"A pillow with huskies printed on it",
|
16 |
+
"A pillow"
|
17 |
+
], [
|
18 |
+
"A DSLR photo of a wooden car, super detailed, best quality, 4K, HD.",
|
19 |
+
"a wooden car."
|
20 |
+
]]
|
21 |
+
example_outputs = [
|
22 |
+
gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/boots.mp4'), autoplay=True),
|
23 |
+
gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/Donut.mp4'), autoplay=True),
|
24 |
+
gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/durian.mp4'), autoplay=True),
|
25 |
+
gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/pillow_huskies.mp4'), autoplay=True),
|
26 |
+
gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/wooden_car.mp4'), autoplay=True)
|
27 |
+
]
|
28 |
+
|
29 |
+
def main(prompt, init_prompt, negative_prompt, num_iter, CFG, seed):
|
30 |
+
if [prompt, init_prompt] in example_inputs:
|
31 |
+
return example_outputs[example_inputs.index([prompt, init_prompt])]
|
32 |
+
args, lp, op, pp, gcp, gp = args_parser(default_opt=os.path.join(os.path.dirname(__file__), 'configs/white_hair_ironman.yaml'))
|
33 |
+
gp.text = prompt
|
34 |
+
gp.negative = negative_prompt
|
35 |
+
if len(init_prompt) > 1:
|
36 |
+
gcp.init_shape = 'pointe'
|
37 |
+
gcp.init_prompt = init_prompt
|
38 |
+
else:
|
39 |
+
gcp.init_shape = 'sphere'
|
40 |
+
gcp.init_prompt = '.'
|
41 |
+
op.iterations = num_iter
|
42 |
+
gp.guidance_scale = CFG
|
43 |
+
gp.noise_seed = int(seed)
|
44 |
+
lp.workspace = 'gradio_demo'
|
45 |
+
video_path = start_training(args, lp, op, pp, gcp, gp)
|
46 |
+
return gr.Video(value=video_path, autoplay=True)
|
47 |
+
|
48 |
+
with gr.Blocks() as demo:
|
49 |
+
gr.Markdown("# <center>LucidDreamer: Towards High-Fidelity Text-to-3D Generation via Interval Score Matching</center>")
|
50 |
+
gr.Markdown("<center>Yixun Liang*, Xin Yang*, Jiantao Lin, Haodong Li, Xiaogang Xu, Yingcong Chen**</center>")
|
51 |
+
gr.Markdown("<center>*: Equal contribution. **: Corresponding author.</center>")
|
52 |
+
gr.Markdown("We present a text-to-3D generation framework, named the *LucidDreamer*, to distill high-fidelity textures and shapes from pretrained 2D diffusion models.")
|
53 |
+
gr.Markdown("<details><summary><strong>CLICK for the full abstract</strong></summary>The recent advancements in text-to-3D generation mark a significant milestone in generative models, unlocking new possibilities for creating imaginative 3D assets across various real-world scenarios. While recent advancements in text-to-3D generation have shown promise, they often fall short in rendering detailed and high-quality 3D models. This problem is especially prevalent as many methods base themselves on Score Distillation Sampling (SDS). This paper identifies a notable deficiency in SDS, that it brings inconsistent and low-quality updating direction for the 3D model, causing the over-smoothing effect. To address this, we propose a novel approach called Interval Score Matching (ISM). ISM employs deterministic diffusing trajectories and utilizes interval-based score matching to counteract over-smoothing. Furthermore, we incorporate 3D Gaussian Splatting into our text-to-3D generation pipeline. Extensive experiments show that our model largely outperforms the state-of-the-art in quality and training efficiency.</details>")
|
54 |
+
gr.Interface(fn=main, inputs=[gr.Textbox(lines=2, value="A portrait of IRONMAN, white hair, head, photorealistic, 8K, HDR.", label="Your prompt"),
|
55 |
+
gr.Textbox(lines=1, value="a man head.", label="Point-E init prompt (optional)"),
|
56 |
+
gr.Textbox(lines=2, value="unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution.", label="Negative prompt (optional)"),
|
57 |
+
gr.Slider(1000, 5000, value=5000, label="Number of iterations"),
|
58 |
+
gr.Slider(7.5, 100, value=7.5, label="CFG"),
|
59 |
+
gr.Number(value=0, label="Seed")],
|
60 |
+
outputs="playable_video",
|
61 |
+
examples=example_inputs)
|
62 |
+
demo.launch()
|
guidance/perpneg_utils.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
# Please refer to the https://perp-neg.github.io/ for details about the paper and algorithm
|
4 |
+
def get_perpendicular_component(x, y):
|
5 |
+
assert x.shape == y.shape
|
6 |
+
return x - ((torch.mul(x, y).sum())/max(torch.norm(y)**2, 1e-6)) * y
|
7 |
+
|
8 |
+
|
9 |
+
def batch_get_perpendicular_component(x, y):
|
10 |
+
assert x.shape == y.shape
|
11 |
+
result = []
|
12 |
+
for i in range(x.shape[0]):
|
13 |
+
result.append(get_perpendicular_component(x[i], y[i]))
|
14 |
+
return torch.stack(result)
|
15 |
+
|
16 |
+
|
17 |
+
def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size):
|
18 |
+
"""
|
19 |
+
Notes:
|
20 |
+
- weights: an array with the weights for combining the noise predictions
|
21 |
+
- delta_noise_preds: [B x K, 4, 64, 64], K = max_prompts_per_dir
|
22 |
+
"""
|
23 |
+
delta_noise_preds = delta_noise_preds.split(batch_size, dim=0) # K x [B, 4, 64, 64]
|
24 |
+
weights = weights.split(batch_size, dim=0) # K x [B]
|
25 |
+
# print(f"{weights[0].shape = } {weights = }")
|
26 |
+
|
27 |
+
assert torch.all(weights[0] == 1.0)
|
28 |
+
|
29 |
+
main_positive = delta_noise_preds[0] # [B, 4, 64, 64]
|
30 |
+
|
31 |
+
accumulated_output = torch.zeros_like(main_positive)
|
32 |
+
for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1):
|
33 |
+
# print(f"\n{i = }, {weights[i] = }, {weights[i].shape = }\n")
|
34 |
+
|
35 |
+
idx_non_zero = torch.abs(weights[i]) > 1e-4
|
36 |
+
|
37 |
+
# print(f"{idx_non_zero.shape = }, {idx_non_zero = }")
|
38 |
+
# print(f"{weights[i][idx_non_zero].shape = }, {weights[i][idx_non_zero] = }")
|
39 |
+
# print(f"{complementary_noise_pred.shape = }, {complementary_noise_pred[idx_non_zero].shape = }")
|
40 |
+
# print(f"{main_positive.shape = }, {main_positive[idx_non_zero].shape = }")
|
41 |
+
if sum(idx_non_zero) == 0:
|
42 |
+
continue
|
43 |
+
accumulated_output[idx_non_zero] += weights[i][idx_non_zero].reshape(-1, 1, 1, 1) * batch_get_perpendicular_component(complementary_noise_pred[idx_non_zero], main_positive[idx_non_zero])
|
44 |
+
|
45 |
+
#assert accumulated_output.shape == main_positive.shape,# f"{accumulated_output.shape = }, {main_positive.shape = }"
|
46 |
+
|
47 |
+
|
48 |
+
return accumulated_output + main_positive
|
guidance/sd_step.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CLIPTextModel, CLIPTokenizer, logging
|
2 |
+
from diffusers import StableDiffusionPipeline, DiffusionPipeline, DDPMScheduler, DDIMScheduler, EulerDiscreteScheduler, \
|
3 |
+
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ControlNetModel, \
|
4 |
+
DDIMInverseScheduler
|
5 |
+
from diffusers.utils import BaseOutput, deprecate
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torchvision.transforms as T
|
12 |
+
|
13 |
+
from typing import List, Optional, Tuple, Union
|
14 |
+
from dataclasses import dataclass
|
15 |
+
|
16 |
+
from diffusers.utils import BaseOutput, randn_tensor
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
21 |
+
class DDIMSchedulerOutput(BaseOutput):
|
22 |
+
"""
|
23 |
+
Output class for the scheduler's `step` function output.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
27 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
28 |
+
denoising loop.
|
29 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
30 |
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
31 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
32 |
+
"""
|
33 |
+
|
34 |
+
prev_sample: torch.FloatTensor
|
35 |
+
pred_original_sample: Optional[torch.FloatTensor] = None
|
36 |
+
|
37 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
38 |
+
def ddim_add_noise(
|
39 |
+
self,
|
40 |
+
original_samples: torch.FloatTensor,
|
41 |
+
noise: torch.FloatTensor,
|
42 |
+
timesteps: torch.IntTensor,
|
43 |
+
) -> torch.FloatTensor:
|
44 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
45 |
+
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
46 |
+
timesteps = timesteps.to(original_samples.device)
|
47 |
+
|
48 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
49 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
50 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
51 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
52 |
+
|
53 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
54 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
55 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
56 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
57 |
+
|
58 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
59 |
+
return noisy_samples
|
60 |
+
|
61 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.step
|
62 |
+
def ddim_step(
|
63 |
+
self,
|
64 |
+
model_output: torch.FloatTensor,
|
65 |
+
timestep: int,
|
66 |
+
sample: torch.FloatTensor,
|
67 |
+
delta_timestep: int = None,
|
68 |
+
eta: float = 0.0,
|
69 |
+
use_clipped_model_output: bool = False,
|
70 |
+
generator=None,
|
71 |
+
variance_noise: Optional[torch.FloatTensor] = None,
|
72 |
+
return_dict: bool = True,
|
73 |
+
**kwargs
|
74 |
+
) -> Union[DDIMSchedulerOutput, Tuple]:
|
75 |
+
"""
|
76 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
77 |
+
process from the learned model outputs (most often the predicted noise).
|
78 |
+
|
79 |
+
Args:
|
80 |
+
model_output (`torch.FloatTensor`):
|
81 |
+
The direct output from learned diffusion model.
|
82 |
+
timestep (`float`):
|
83 |
+
The current discrete timestep in the diffusion chain.
|
84 |
+
sample (`torch.FloatTensor`):
|
85 |
+
A current instance of a sample created by the diffusion process.
|
86 |
+
eta (`float`):
|
87 |
+
The weight of noise for added noise in diffusion step.
|
88 |
+
use_clipped_model_output (`bool`, defaults to `False`):
|
89 |
+
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
|
90 |
+
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
|
91 |
+
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
|
92 |
+
`use_clipped_model_output` has no effect.
|
93 |
+
generator (`torch.Generator`, *optional*):
|
94 |
+
A random number generator.
|
95 |
+
variance_noise (`torch.FloatTensor`):
|
96 |
+
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
97 |
+
itself. Useful for methods such as [`CycleDiffusion`].
|
98 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
99 |
+
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
|
103 |
+
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
|
104 |
+
tuple is returned where the first element is the sample tensor.
|
105 |
+
|
106 |
+
"""
|
107 |
+
if self.num_inference_steps is None:
|
108 |
+
raise ValueError(
|
109 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
110 |
+
)
|
111 |
+
|
112 |
+
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
113 |
+
# Ideally, read DDIM paper in-detail understanding
|
114 |
+
|
115 |
+
# Notation (<variable name> -> <name in paper>
|
116 |
+
# - pred_noise_t -> e_theta(x_t, t)
|
117 |
+
# - pred_original_sample -> f_theta(x_t, t) or x_0
|
118 |
+
# - std_dev_t -> sigma_t
|
119 |
+
# - eta -> η
|
120 |
+
# - pred_sample_direction -> "direction pointing to x_t"
|
121 |
+
# - pred_prev_sample -> "x_t-1"
|
122 |
+
|
123 |
+
|
124 |
+
if delta_timestep is None:
|
125 |
+
# 1. get previous step value (=t+1)
|
126 |
+
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
127 |
+
else:
|
128 |
+
prev_timestep = timestep - delta_timestep
|
129 |
+
|
130 |
+
# 2. compute alphas, betas
|
131 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
132 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
133 |
+
|
134 |
+
beta_prod_t = 1 - alpha_prod_t
|
135 |
+
|
136 |
+
# 3. compute predicted original sample from predicted noise also called
|
137 |
+
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
138 |
+
if self.config.prediction_type == "epsilon":
|
139 |
+
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
140 |
+
pred_epsilon = model_output
|
141 |
+
elif self.config.prediction_type == "sample":
|
142 |
+
pred_original_sample = model_output
|
143 |
+
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
144 |
+
elif self.config.prediction_type == "v_prediction":
|
145 |
+
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
146 |
+
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
147 |
+
else:
|
148 |
+
raise ValueError(
|
149 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
150 |
+
" `v_prediction`"
|
151 |
+
)
|
152 |
+
|
153 |
+
# 4. Clip or threshold "predicted x_0"
|
154 |
+
if self.config.thresholding:
|
155 |
+
pred_original_sample = self._threshold_sample(pred_original_sample)
|
156 |
+
elif self.config.clip_sample:
|
157 |
+
pred_original_sample = pred_original_sample.clamp(
|
158 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
159 |
+
)
|
160 |
+
|
161 |
+
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
162 |
+
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
163 |
+
# if prev_timestep < timestep:
|
164 |
+
# else:
|
165 |
+
# variance = abs(self._get_variance(prev_timestep, timestep))
|
166 |
+
|
167 |
+
variance = abs(self._get_variance(timestep, prev_timestep))
|
168 |
+
|
169 |
+
std_dev_t = eta * variance
|
170 |
+
std_dev_t = min((1 - alpha_prod_t_prev) / 2, std_dev_t) ** 0.5
|
171 |
+
|
172 |
+
if use_clipped_model_output:
|
173 |
+
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
|
174 |
+
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
175 |
+
|
176 |
+
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
177 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
|
178 |
+
|
179 |
+
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
180 |
+
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
181 |
+
|
182 |
+
if eta > 0:
|
183 |
+
if variance_noise is not None and generator is not None:
|
184 |
+
raise ValueError(
|
185 |
+
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
|
186 |
+
" `variance_noise` stays `None`."
|
187 |
+
)
|
188 |
+
|
189 |
+
if variance_noise is None:
|
190 |
+
variance_noise = randn_tensor(
|
191 |
+
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
192 |
+
)
|
193 |
+
variance = std_dev_t * variance_noise
|
194 |
+
|
195 |
+
prev_sample = prev_sample + variance
|
196 |
+
|
197 |
+
prev_sample = torch.nan_to_num(prev_sample)
|
198 |
+
|
199 |
+
if not return_dict:
|
200 |
+
return (prev_sample,)
|
201 |
+
|
202 |
+
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
203 |
+
|
204 |
+
def pred_original(
|
205 |
+
self,
|
206 |
+
model_output: torch.FloatTensor,
|
207 |
+
timesteps: int,
|
208 |
+
sample: torch.FloatTensor,
|
209 |
+
):
|
210 |
+
if isinstance(self, DDPMScheduler) or isinstance(self, DDIMScheduler):
|
211 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
212 |
+
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
213 |
+
timesteps = timesteps.to(sample.device)
|
214 |
+
|
215 |
+
# 1. compute alphas, betas
|
216 |
+
alpha_prod_t = alphas_cumprod[timesteps]
|
217 |
+
while len(alpha_prod_t.shape) < len(sample.shape):
|
218 |
+
alpha_prod_t = alpha_prod_t.unsqueeze(-1)
|
219 |
+
|
220 |
+
beta_prod_t = 1 - alpha_prod_t
|
221 |
+
|
222 |
+
# 2. compute predicted original sample from predicted noise also called
|
223 |
+
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
224 |
+
if self.config.prediction_type == "epsilon":
|
225 |
+
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
226 |
+
elif self.config.prediction_type == "sample":
|
227 |
+
pred_original_sample = model_output
|
228 |
+
elif self.config.prediction_type == "v_prediction":
|
229 |
+
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
230 |
+
else:
|
231 |
+
raise ValueError(
|
232 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
|
233 |
+
" `v_prediction` for the DDPMScheduler."
|
234 |
+
)
|
235 |
+
|
236 |
+
# 3. Clip or threshold "predicted x_0"
|
237 |
+
if self.config.thresholding:
|
238 |
+
pred_original_sample = self._threshold_sample(pred_original_sample)
|
239 |
+
elif self.config.clip_sample:
|
240 |
+
pred_original_sample = pred_original_sample.clamp(
|
241 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
242 |
+
)
|
243 |
+
elif isinstance(self, EulerAncestralDiscreteScheduler) or isinstance(self, EulerDiscreteScheduler):
|
244 |
+
timestep = timesteps.to(self.timesteps.device)
|
245 |
+
|
246 |
+
step_index = (self.timesteps == timestep).nonzero().item()
|
247 |
+
sigma = self.sigmas[step_index].to(device=sample.device, dtype=sample.dtype)
|
248 |
+
|
249 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
250 |
+
if self.config.prediction_type == "epsilon":
|
251 |
+
pred_original_sample = sample - sigma * model_output
|
252 |
+
elif self.config.prediction_type == "v_prediction":
|
253 |
+
# * c_out + input * c_skip
|
254 |
+
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
255 |
+
elif self.config.prediction_type == "sample":
|
256 |
+
raise NotImplementedError("prediction_type not implemented yet: sample")
|
257 |
+
else:
|
258 |
+
raise ValueError(
|
259 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
260 |
+
)
|
261 |
+
else:
|
262 |
+
raise NotImplementedError
|
263 |
+
|
264 |
+
return pred_original_sample
|
guidance/sd_utils.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from audioop import mul
|
2 |
+
from transformers import CLIPTextModel, CLIPTokenizer, logging
|
3 |
+
from diffusers import StableDiffusionPipeline, DiffusionPipeline, DDPMScheduler, DDIMScheduler, EulerDiscreteScheduler, \
|
4 |
+
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ControlNetModel, \
|
5 |
+
DDIMInverseScheduler, UNet2DConditionModel
|
6 |
+
from diffusers.utils.import_utils import is_xformers_available
|
7 |
+
from os.path import isfile
|
8 |
+
from pathlib import Path
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
|
12 |
+
import torchvision.transforms as T
|
13 |
+
# suppress partial model loading warning
|
14 |
+
logging.set_verbosity_error()
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
import torchvision.transforms as T
|
21 |
+
from torchvision.utils import save_image
|
22 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
23 |
+
from .perpneg_utils import weighted_perpendicular_aggregator
|
24 |
+
|
25 |
+
from .sd_step import *
|
26 |
+
|
27 |
+
def rgb2sat(img, T=None):
|
28 |
+
max_ = torch.max(img, dim=1, keepdim=True).values + 1e-5
|
29 |
+
min_ = torch.min(img, dim=1, keepdim=True).values
|
30 |
+
sat = (max_ - min_) / max_
|
31 |
+
if T is not None:
|
32 |
+
sat = (1 - T) * sat
|
33 |
+
return sat
|
34 |
+
|
35 |
+
class SpecifyGradient(torch.autograd.Function):
|
36 |
+
@staticmethod
|
37 |
+
@custom_fwd
|
38 |
+
def forward(ctx, input_tensor, gt_grad):
|
39 |
+
ctx.save_for_backward(gt_grad)
|
40 |
+
# we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
|
41 |
+
return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
@custom_bwd
|
45 |
+
def backward(ctx, grad_scale):
|
46 |
+
gt_grad, = ctx.saved_tensors
|
47 |
+
gt_grad = gt_grad * grad_scale
|
48 |
+
return gt_grad, None
|
49 |
+
|
50 |
+
def seed_everything(seed):
|
51 |
+
torch.manual_seed(seed)
|
52 |
+
torch.cuda.manual_seed(seed)
|
53 |
+
#torch.backends.cudnn.deterministic = True
|
54 |
+
#torch.backends.cudnn.benchmark = True
|
55 |
+
|
56 |
+
class StableDiffusion(nn.Module):
|
57 |
+
def __init__(self, device, fp16, vram_O, t_range=[0.02, 0.98], max_t_range=0.98, num_train_timesteps=None,
|
58 |
+
ddim_inv=False, use_control_net=False, textual_inversion_path = None,
|
59 |
+
LoRA_path = None, guidance_opt=None):
|
60 |
+
super().__init__()
|
61 |
+
|
62 |
+
self.device = device
|
63 |
+
self.precision_t = torch.float16 if fp16 else torch.float32
|
64 |
+
|
65 |
+
print(f'[INFO] loading stable diffusion...')
|
66 |
+
|
67 |
+
model_key = guidance_opt.model_key
|
68 |
+
assert model_key is not None
|
69 |
+
|
70 |
+
is_safe_tensor = guidance_opt.is_safe_tensor
|
71 |
+
base_model_key = "stabilityai/stable-diffusion-v1-5" if guidance_opt.base_model_key is None else guidance_opt.base_model_key # for finetuned model only
|
72 |
+
|
73 |
+
if is_safe_tensor:
|
74 |
+
pipe = StableDiffusionPipeline.from_single_file(model_key, use_safetensors=True, torch_dtype=self.precision_t, load_safety_checker=False)
|
75 |
+
else:
|
76 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.precision_t)
|
77 |
+
|
78 |
+
self.ism = not guidance_opt.sds
|
79 |
+
self.scheduler = DDIMScheduler.from_pretrained(model_key if not is_safe_tensor else base_model_key, subfolder="scheduler", torch_dtype=self.precision_t)
|
80 |
+
self.sche_func = ddim_step
|
81 |
+
|
82 |
+
if use_control_net:
|
83 |
+
controlnet_model_key = guidance_opt.controlnet_model_key
|
84 |
+
self.controlnet_depth = ControlNetModel.from_pretrained(controlnet_model_key,torch_dtype=self.precision_t).to(device)
|
85 |
+
|
86 |
+
if vram_O:
|
87 |
+
pipe.enable_sequential_cpu_offload()
|
88 |
+
pipe.enable_vae_slicing()
|
89 |
+
pipe.unet.to(memory_format=torch.channels_last)
|
90 |
+
pipe.enable_attention_slicing(1)
|
91 |
+
pipe.enable_model_cpu_offload()
|
92 |
+
|
93 |
+
pipe.enable_xformers_memory_efficient_attention()
|
94 |
+
|
95 |
+
pipe = pipe.to(self.device)
|
96 |
+
if textual_inversion_path is not None:
|
97 |
+
pipe.load_textual_inversion(textual_inversion_path)
|
98 |
+
print("load textual inversion in:.{}".format(textual_inversion_path))
|
99 |
+
|
100 |
+
if LoRA_path is not None:
|
101 |
+
from lora_diffusion import tune_lora_scale, patch_pipe
|
102 |
+
print("load lora in:.{}".format(LoRA_path))
|
103 |
+
patch_pipe(
|
104 |
+
pipe,
|
105 |
+
LoRA_path,
|
106 |
+
patch_text=True,
|
107 |
+
patch_ti=True,
|
108 |
+
patch_unet=True,
|
109 |
+
)
|
110 |
+
tune_lora_scale(pipe.unet, 1.00)
|
111 |
+
tune_lora_scale(pipe.text_encoder, 1.00)
|
112 |
+
|
113 |
+
self.pipe = pipe
|
114 |
+
self.vae = pipe.vae
|
115 |
+
self.tokenizer = pipe.tokenizer
|
116 |
+
self.text_encoder = pipe.text_encoder
|
117 |
+
self.unet = pipe.unet
|
118 |
+
|
119 |
+
self.num_train_timesteps = num_train_timesteps if num_train_timesteps is not None else self.scheduler.config.num_train_timesteps
|
120 |
+
self.scheduler.set_timesteps(self.num_train_timesteps, device=device)
|
121 |
+
|
122 |
+
self.timesteps = torch.flip(self.scheduler.timesteps, dims=(0, ))
|
123 |
+
self.min_step = int(self.num_train_timesteps * t_range[0])
|
124 |
+
self.max_step = int(self.num_train_timesteps * t_range[1])
|
125 |
+
self.warmup_step = int(self.num_train_timesteps*(max_t_range-t_range[1]))
|
126 |
+
|
127 |
+
self.noise_temp = None
|
128 |
+
self.noise_gen = torch.Generator(self.device)
|
129 |
+
self.noise_gen.manual_seed(guidance_opt.noise_seed)
|
130 |
+
|
131 |
+
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
|
132 |
+
self.rgb_latent_factors = torch.tensor([
|
133 |
+
# R G B
|
134 |
+
[ 0.298, 0.207, 0.208],
|
135 |
+
[ 0.187, 0.286, 0.173],
|
136 |
+
[-0.158, 0.189, 0.264],
|
137 |
+
[-0.184, -0.271, -0.473]
|
138 |
+
], device=self.device)
|
139 |
+
|
140 |
+
|
141 |
+
print(f'[INFO] loaded stable diffusion!')
|
142 |
+
|
143 |
+
def augmentation(self, *tensors):
|
144 |
+
augs = T.Compose([
|
145 |
+
T.RandomHorizontalFlip(p=0.5),
|
146 |
+
])
|
147 |
+
|
148 |
+
channels = [ten.shape[1] for ten in tensors]
|
149 |
+
tensors_concat = torch.concat(tensors, dim=1)
|
150 |
+
tensors_concat = augs(tensors_concat)
|
151 |
+
|
152 |
+
results = []
|
153 |
+
cur_c = 0
|
154 |
+
for i in range(len(channels)):
|
155 |
+
results.append(tensors_concat[:, cur_c:cur_c + channels[i], ...])
|
156 |
+
cur_c += channels[i]
|
157 |
+
return (ten for ten in results)
|
158 |
+
|
159 |
+
def add_noise_with_cfg(self, latents, noise,
|
160 |
+
ind_t, ind_prev_t,
|
161 |
+
text_embeddings=None, cfg=1.0,
|
162 |
+
delta_t=1, inv_steps=1,
|
163 |
+
is_noisy_latent=False,
|
164 |
+
eta=0.0):
|
165 |
+
|
166 |
+
text_embeddings = text_embeddings.to(self.precision_t)
|
167 |
+
if cfg <= 1.0:
|
168 |
+
uncond_text_embedding = text_embeddings.reshape(2, -1, text_embeddings.shape[-2], text_embeddings.shape[-1])[1]
|
169 |
+
|
170 |
+
unet = self.unet
|
171 |
+
|
172 |
+
if is_noisy_latent:
|
173 |
+
prev_noisy_lat = latents
|
174 |
+
else:
|
175 |
+
prev_noisy_lat = self.scheduler.add_noise(latents, noise, self.timesteps[ind_prev_t])
|
176 |
+
|
177 |
+
cur_ind_t = ind_prev_t
|
178 |
+
cur_noisy_lat = prev_noisy_lat
|
179 |
+
|
180 |
+
pred_scores = []
|
181 |
+
|
182 |
+
for i in range(inv_steps):
|
183 |
+
# pred noise
|
184 |
+
cur_noisy_lat_ = self.scheduler.scale_model_input(cur_noisy_lat, self.timesteps[cur_ind_t]).to(self.precision_t)
|
185 |
+
|
186 |
+
if cfg > 1.0:
|
187 |
+
latent_model_input = torch.cat([cur_noisy_lat_, cur_noisy_lat_])
|
188 |
+
timestep_model_input = self.timesteps[cur_ind_t].reshape(1, 1).repeat(latent_model_input.shape[0], 1).reshape(-1)
|
189 |
+
unet_output = unet(latent_model_input, timestep_model_input,
|
190 |
+
encoder_hidden_states=text_embeddings).sample
|
191 |
+
|
192 |
+
uncond, cond = torch.chunk(unet_output, chunks=2)
|
193 |
+
|
194 |
+
unet_output = cond + cfg * (uncond - cond) # reverse cfg to enhance the distillation
|
195 |
+
else:
|
196 |
+
timestep_model_input = self.timesteps[cur_ind_t].reshape(1, 1).repeat(cur_noisy_lat_.shape[0], 1).reshape(-1)
|
197 |
+
unet_output = unet(cur_noisy_lat_, timestep_model_input,
|
198 |
+
encoder_hidden_states=uncond_text_embedding).sample
|
199 |
+
|
200 |
+
pred_scores.append((cur_ind_t, unet_output))
|
201 |
+
|
202 |
+
next_ind_t = min(cur_ind_t + delta_t, ind_t)
|
203 |
+
cur_t, next_t = self.timesteps[cur_ind_t], self.timesteps[next_ind_t]
|
204 |
+
delta_t_ = next_t-cur_t if isinstance(self.scheduler, DDIMScheduler) else next_ind_t-cur_ind_t
|
205 |
+
|
206 |
+
cur_noisy_lat = self.sche_func(self.scheduler, unet_output, cur_t, cur_noisy_lat, -delta_t_, eta).prev_sample
|
207 |
+
cur_ind_t = next_ind_t
|
208 |
+
|
209 |
+
del unet_output
|
210 |
+
torch.cuda.empty_cache()
|
211 |
+
|
212 |
+
if cur_ind_t == ind_t:
|
213 |
+
break
|
214 |
+
|
215 |
+
return prev_noisy_lat, cur_noisy_lat, pred_scores[::-1]
|
216 |
+
|
217 |
+
|
218 |
+
@torch.no_grad()
|
219 |
+
def get_text_embeds(self, prompt, resolution=(512, 512)):
|
220 |
+
inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
|
221 |
+
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
|
222 |
+
return embeddings
|
223 |
+
|
224 |
+
def train_step_perpneg(self, text_embeddings, pred_rgb, pred_depth=None, pred_alpha=None,
|
225 |
+
grad_scale=1,use_control_net=False,
|
226 |
+
save_folder:Path=None, iteration=0, warm_up_rate = 0, weights = 0,
|
227 |
+
resolution=(512, 512), guidance_opt=None,as_latent=False, embedding_inverse = None):
|
228 |
+
|
229 |
+
|
230 |
+
# flip aug
|
231 |
+
pred_rgb, pred_depth, pred_alpha = self.augmentation(pred_rgb, pred_depth, pred_alpha)
|
232 |
+
|
233 |
+
B = pred_rgb.shape[0]
|
234 |
+
K = text_embeddings.shape[0] - 1
|
235 |
+
|
236 |
+
if as_latent:
|
237 |
+
latents,_ = self.encode_imgs(pred_depth.repeat(1,3,1,1).to(self.precision_t))
|
238 |
+
else:
|
239 |
+
latents,_ = self.encode_imgs(pred_rgb.to(self.precision_t))
|
240 |
+
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
|
241 |
+
|
242 |
+
weights = weights.reshape(-1)
|
243 |
+
noise = torch.randn((latents.shape[0], 4, resolution[0] // 8, resolution[1] // 8, ), dtype=latents.dtype, device=latents.device, generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1)
|
244 |
+
|
245 |
+
inverse_text_embeddings = embedding_inverse.unsqueeze(1).repeat(1, B, 1, 1).reshape(-1, embedding_inverse.shape[-2], embedding_inverse.shape[-1])
|
246 |
+
|
247 |
+
text_embeddings = text_embeddings.reshape(-1, text_embeddings.shape[-2], text_embeddings.shape[-1]) # make it k+1, c * t, ...
|
248 |
+
|
249 |
+
if guidance_opt.annealing_intervals:
|
250 |
+
current_delta_t = int(guidance_opt.delta_t + np.ceil((warm_up_rate)*(guidance_opt.delta_t_start - guidance_opt.delta_t)))
|
251 |
+
else:
|
252 |
+
current_delta_t = guidance_opt.delta_t
|
253 |
+
|
254 |
+
ind_t = torch.randint(self.min_step, self.max_step + int(self.warmup_step*warm_up_rate), (1, ), dtype=torch.long, generator=self.noise_gen, device=self.device)[0]
|
255 |
+
ind_prev_t = max(ind_t - current_delta_t, torch.ones_like(ind_t) * 0)
|
256 |
+
|
257 |
+
t = self.timesteps[ind_t]
|
258 |
+
prev_t = self.timesteps[ind_prev_t]
|
259 |
+
|
260 |
+
with torch.no_grad():
|
261 |
+
# step unroll via ddim inversion
|
262 |
+
if not self.ism:
|
263 |
+
prev_latents_noisy = self.scheduler.add_noise(latents, noise, prev_t)
|
264 |
+
latents_noisy = self.scheduler.add_noise(latents, noise, t)
|
265 |
+
target = noise
|
266 |
+
else:
|
267 |
+
# Step 1: sample x_s with larger steps
|
268 |
+
xs_delta_t = guidance_opt.xs_delta_t if guidance_opt.xs_delta_t is not None else current_delta_t
|
269 |
+
xs_inv_steps = guidance_opt.xs_inv_steps if guidance_opt.xs_inv_steps is not None else int(np.ceil(ind_prev_t / xs_delta_t))
|
270 |
+
starting_ind = max(ind_prev_t - xs_delta_t * xs_inv_steps, torch.ones_like(ind_t) * 0)
|
271 |
+
|
272 |
+
_, prev_latents_noisy, pred_scores_xs = self.add_noise_with_cfg(latents, noise, ind_prev_t, starting_ind, inverse_text_embeddings,
|
273 |
+
guidance_opt.denoise_guidance_scale, xs_delta_t, xs_inv_steps, eta=guidance_opt.xs_eta)
|
274 |
+
# Step 2: sample x_t
|
275 |
+
_, latents_noisy, pred_scores_xt = self.add_noise_with_cfg(prev_latents_noisy, noise, ind_t, ind_prev_t, inverse_text_embeddings,
|
276 |
+
guidance_opt.denoise_guidance_scale, current_delta_t, 1, is_noisy_latent=True)
|
277 |
+
|
278 |
+
pred_scores = pred_scores_xt + pred_scores_xs
|
279 |
+
target = pred_scores[0][1]
|
280 |
+
|
281 |
+
|
282 |
+
with torch.no_grad():
|
283 |
+
latent_model_input = latents_noisy[None, :, ...].repeat(1 + K, 1, 1, 1, 1).reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, )
|
284 |
+
tt = t.reshape(1, 1).repeat(latent_model_input.shape[0], 1).reshape(-1)
|
285 |
+
|
286 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, tt[0])
|
287 |
+
if use_control_net:
|
288 |
+
pred_depth_input = pred_depth_input[None, :, ...].repeat(1 + K, 1, 3, 1, 1).reshape(-1, 3, 512, 512).half()
|
289 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet_depth(
|
290 |
+
latent_model_input,
|
291 |
+
tt,
|
292 |
+
encoder_hidden_states=text_embeddings,
|
293 |
+
controlnet_cond=pred_depth_input,
|
294 |
+
return_dict=False,
|
295 |
+
)
|
296 |
+
unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings,
|
297 |
+
down_block_additional_residuals=down_block_res_samples,
|
298 |
+
mid_block_additional_residual=mid_block_res_sample).sample
|
299 |
+
else:
|
300 |
+
unet_output = self.unet(latent_model_input.to(self.precision_t), tt.to(self.precision_t), encoder_hidden_states=text_embeddings.to(self.precision_t)).sample
|
301 |
+
|
302 |
+
unet_output = unet_output.reshape(1 + K, -1, 4, resolution[0] // 8, resolution[1] // 8, )
|
303 |
+
noise_pred_uncond, noise_pred_text = unet_output[:1].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, ), unet_output[1:].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, )
|
304 |
+
delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)
|
305 |
+
delta_DSD = weighted_perpendicular_aggregator(delta_noise_preds,\
|
306 |
+
weights,\
|
307 |
+
B)
|
308 |
+
|
309 |
+
pred_noise = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD
|
310 |
+
w = lambda alphas: (((1 - alphas) / alphas) ** 0.5)
|
311 |
+
|
312 |
+
grad = w(self.alphas[t]) * (pred_noise - target)
|
313 |
+
|
314 |
+
grad = torch.nan_to_num(grad_scale * grad)
|
315 |
+
loss = SpecifyGradient.apply(latents, grad)
|
316 |
+
|
317 |
+
if iteration % guidance_opt.vis_interval == 0:
|
318 |
+
noise_pred_post = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD
|
319 |
+
lat2rgb = lambda x: torch.clip((x.permute(0,2,3,1) @ self.rgb_latent_factors.to(x.dtype)).permute(0,3,1,2), 0., 1.)
|
320 |
+
save_path_iter = os.path.join(save_folder,"iter_{}_step_{}.jpg".format(iteration,prev_t.item()))
|
321 |
+
with torch.no_grad():
|
322 |
+
pred_x0_latent_sp = pred_original(self.scheduler, noise_pred_uncond, prev_t, prev_latents_noisy)
|
323 |
+
pred_x0_latent_pos = pred_original(self.scheduler, noise_pred_post, prev_t, prev_latents_noisy)
|
324 |
+
pred_x0_pos = self.decode_latents(pred_x0_latent_pos.type(self.precision_t))
|
325 |
+
pred_x0_sp = self.decode_latents(pred_x0_latent_sp.type(self.precision_t))
|
326 |
+
|
327 |
+
grad_abs = torch.abs(grad.detach())
|
328 |
+
norm_grad = F.interpolate((grad_abs / grad_abs.max()).mean(dim=1,keepdim=True), (resolution[0], resolution[1]), mode='bilinear', align_corners=False).repeat(1,3,1,1)
|
329 |
+
|
330 |
+
latents_rgb = F.interpolate(lat2rgb(latents), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)
|
331 |
+
latents_sp_rgb = F.interpolate(lat2rgb(pred_x0_latent_sp), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)
|
332 |
+
|
333 |
+
viz_images = torch.cat([pred_rgb,
|
334 |
+
pred_depth.repeat(1, 3, 1, 1),
|
335 |
+
pred_alpha.repeat(1, 3, 1, 1),
|
336 |
+
rgb2sat(pred_rgb, pred_alpha).repeat(1, 3, 1, 1),
|
337 |
+
latents_rgb, latents_sp_rgb,
|
338 |
+
norm_grad,
|
339 |
+
pred_x0_sp, pred_x0_pos],dim=0)
|
340 |
+
save_image(viz_images, save_path_iter)
|
341 |
+
|
342 |
+
|
343 |
+
return loss
|
344 |
+
|
345 |
+
|
346 |
+
def train_step(self, text_embeddings, pred_rgb, pred_depth=None, pred_alpha=None,
|
347 |
+
grad_scale=1,use_control_net=False,
|
348 |
+
save_folder:Path=None, iteration=0, warm_up_rate = 0,
|
349 |
+
resolution=(512, 512), guidance_opt=None,as_latent=False, embedding_inverse = None):
|
350 |
+
|
351 |
+
pred_rgb, pred_depth, pred_alpha = self.augmentation(pred_rgb, pred_depth, pred_alpha)
|
352 |
+
|
353 |
+
B = pred_rgb.shape[0]
|
354 |
+
K = text_embeddings.shape[0] - 1
|
355 |
+
|
356 |
+
if as_latent:
|
357 |
+
latents,_ = self.encode_imgs(pred_depth.repeat(1,3,1,1).to(self.precision_t))
|
358 |
+
else:
|
359 |
+
latents,_ = self.encode_imgs(pred_rgb.to(self.precision_t))
|
360 |
+
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
|
361 |
+
|
362 |
+
if self.noise_temp is None:
|
363 |
+
self.noise_temp = torch.randn((latents.shape[0], 4, resolution[0] // 8, resolution[1] // 8, ), dtype=latents.dtype, device=latents.device, generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1)
|
364 |
+
|
365 |
+
if guidance_opt.fix_noise:
|
366 |
+
noise = self.noise_temp
|
367 |
+
else:
|
368 |
+
noise = torch.randn((latents.shape[0], 4, resolution[0] // 8, resolution[1] // 8, ), dtype=latents.dtype, device=latents.device, generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1)
|
369 |
+
|
370 |
+
text_embeddings = text_embeddings[:, :, ...]
|
371 |
+
text_embeddings = text_embeddings.reshape(-1, text_embeddings.shape[-2], text_embeddings.shape[-1]) # make it k+1, c * t, ...
|
372 |
+
|
373 |
+
inverse_text_embeddings = embedding_inverse.unsqueeze(1).repeat(1, B, 1, 1).reshape(-1, embedding_inverse.shape[-2], embedding_inverse.shape[-1])
|
374 |
+
|
375 |
+
if guidance_opt.annealing_intervals:
|
376 |
+
current_delta_t = int(guidance_opt.delta_t + (warm_up_rate)*(guidance_opt.delta_t_start - guidance_opt.delta_t))
|
377 |
+
else:
|
378 |
+
current_delta_t = guidance_opt.delta_t
|
379 |
+
|
380 |
+
ind_t = torch.randint(self.min_step, self.max_step + int(self.warmup_step*warm_up_rate), (1, ), dtype=torch.long, generator=self.noise_gen, device=self.device)[0]
|
381 |
+
ind_prev_t = max(ind_t - current_delta_t, torch.ones_like(ind_t) * 0)
|
382 |
+
|
383 |
+
t = self.timesteps[ind_t]
|
384 |
+
prev_t = self.timesteps[ind_prev_t]
|
385 |
+
|
386 |
+
with torch.no_grad():
|
387 |
+
# step unroll via ddim inversion
|
388 |
+
if not self.ism:
|
389 |
+
prev_latents_noisy = self.scheduler.add_noise(latents, noise, prev_t)
|
390 |
+
latents_noisy = self.scheduler.add_noise(latents, noise, t)
|
391 |
+
target = noise
|
392 |
+
else:
|
393 |
+
# Step 1: sample x_s with larger steps
|
394 |
+
xs_delta_t = guidance_opt.xs_delta_t if guidance_opt.xs_delta_t is not None else current_delta_t
|
395 |
+
xs_inv_steps = guidance_opt.xs_inv_steps if guidance_opt.xs_inv_steps is not None else int(np.ceil(ind_prev_t / xs_delta_t))
|
396 |
+
starting_ind = max(ind_prev_t - xs_delta_t * xs_inv_steps, torch.ones_like(ind_t) * 0)
|
397 |
+
|
398 |
+
_, prev_latents_noisy, pred_scores_xs = self.add_noise_with_cfg(latents, noise, ind_prev_t, starting_ind, inverse_text_embeddings,
|
399 |
+
guidance_opt.denoise_guidance_scale, xs_delta_t, xs_inv_steps, eta=guidance_opt.xs_eta)
|
400 |
+
# Step 2: sample x_t
|
401 |
+
_, latents_noisy, pred_scores_xt = self.add_noise_with_cfg(prev_latents_noisy, noise, ind_t, ind_prev_t, inverse_text_embeddings,
|
402 |
+
guidance_opt.denoise_guidance_scale, current_delta_t, 1, is_noisy_latent=True)
|
403 |
+
|
404 |
+
pred_scores = pred_scores_xt + pred_scores_xs
|
405 |
+
target = pred_scores[0][1]
|
406 |
+
|
407 |
+
|
408 |
+
with torch.no_grad():
|
409 |
+
latent_model_input = latents_noisy[None, :, ...].repeat(2, 1, 1, 1, 1).reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, )
|
410 |
+
tt = t.reshape(1, 1).repeat(latent_model_input.shape[0], 1).reshape(-1)
|
411 |
+
|
412 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, tt[0])
|
413 |
+
if use_control_net:
|
414 |
+
pred_depth_input = pred_depth_input[None, :, ...].repeat(1 + K, 1, 3, 1, 1).reshape(-1, 3, 512, 512).half()
|
415 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet_depth(
|
416 |
+
latent_model_input,
|
417 |
+
tt,
|
418 |
+
encoder_hidden_states=text_embeddings,
|
419 |
+
controlnet_cond=pred_depth_input,
|
420 |
+
return_dict=False,
|
421 |
+
)
|
422 |
+
unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings,
|
423 |
+
down_block_additional_residuals=down_block_res_samples,
|
424 |
+
mid_block_additional_residual=mid_block_res_sample).sample
|
425 |
+
else:
|
426 |
+
unet_output = self.unet(latent_model_input.to(self.precision_t), tt.to(self.precision_t), encoder_hidden_states=text_embeddings.to(self.precision_t)).sample
|
427 |
+
|
428 |
+
unet_output = unet_output.reshape(2, -1, 4, resolution[0] // 8, resolution[1] // 8, )
|
429 |
+
noise_pred_uncond, noise_pred_text = unet_output[:1].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, ), unet_output[1:].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, )
|
430 |
+
delta_DSD = noise_pred_text - noise_pred_uncond
|
431 |
+
|
432 |
+
pred_noise = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD
|
433 |
+
|
434 |
+
w = lambda alphas: (((1 - alphas) / alphas) ** 0.5)
|
435 |
+
|
436 |
+
grad = w(self.alphas[t]) * (pred_noise - target)
|
437 |
+
|
438 |
+
grad = torch.nan_to_num(grad_scale * grad)
|
439 |
+
loss = SpecifyGradient.apply(latents, grad)
|
440 |
+
|
441 |
+
if iteration % guidance_opt.vis_interval == 0:
|
442 |
+
noise_pred_post = noise_pred_uncond + 7.5* delta_DSD
|
443 |
+
lat2rgb = lambda x: torch.clip((x.permute(0,2,3,1) @ self.rgb_latent_factors.to(x.dtype)).permute(0,3,1,2), 0., 1.)
|
444 |
+
save_path_iter = os.path.join(save_folder,"iter_{}_step_{}.jpg".format(iteration,prev_t.item()))
|
445 |
+
with torch.no_grad():
|
446 |
+
pred_x0_latent_sp = pred_original(self.scheduler, noise_pred_uncond, prev_t, prev_latents_noisy)
|
447 |
+
pred_x0_latent_pos = pred_original(self.scheduler, noise_pred_post, prev_t, prev_latents_noisy)
|
448 |
+
pred_x0_pos = self.decode_latents(pred_x0_latent_pos.type(self.precision_t))
|
449 |
+
pred_x0_sp = self.decode_latents(pred_x0_latent_sp.type(self.precision_t))
|
450 |
+
# pred_x0_uncond = pred_x0_sp[:1, ...]
|
451 |
+
|
452 |
+
grad_abs = torch.abs(grad.detach())
|
453 |
+
norm_grad = F.interpolate((grad_abs / grad_abs.max()).mean(dim=1,keepdim=True), (resolution[0], resolution[1]), mode='bilinear', align_corners=False).repeat(1,3,1,1)
|
454 |
+
|
455 |
+
latents_rgb = F.interpolate(lat2rgb(latents), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)
|
456 |
+
latents_sp_rgb = F.interpolate(lat2rgb(pred_x0_latent_sp), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)
|
457 |
+
|
458 |
+
viz_images = torch.cat([pred_rgb,
|
459 |
+
pred_depth.repeat(1, 3, 1, 1),
|
460 |
+
pred_alpha.repeat(1, 3, 1, 1),
|
461 |
+
rgb2sat(pred_rgb, pred_alpha).repeat(1, 3, 1, 1),
|
462 |
+
latents_rgb, latents_sp_rgb, norm_grad,
|
463 |
+
pred_x0_sp, pred_x0_pos],dim=0)
|
464 |
+
save_image(viz_images, save_path_iter)
|
465 |
+
|
466 |
+
return loss
|
467 |
+
|
468 |
+
def decode_latents(self, latents):
|
469 |
+
target_dtype = latents.dtype
|
470 |
+
latents = latents / self.vae.config.scaling_factor
|
471 |
+
|
472 |
+
imgs = self.vae.decode(latents.to(self.vae.dtype)).sample
|
473 |
+
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
474 |
+
|
475 |
+
return imgs.to(target_dtype)
|
476 |
+
|
477 |
+
def encode_imgs(self, imgs):
|
478 |
+
target_dtype = imgs.dtype
|
479 |
+
# imgs: [B, 3, H, W]
|
480 |
+
imgs = 2 * imgs - 1
|
481 |
+
|
482 |
+
posterior = self.vae.encode(imgs.to(self.vae.dtype)).latent_dist
|
483 |
+
kl_divergence = posterior.kl()
|
484 |
+
|
485 |
+
latents = posterior.sample() * self.vae.config.scaling_factor
|
486 |
+
|
487 |
+
return latents.to(target_dtype), kl_divergence
|
lora_diffusion/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .lora import *
|
2 |
+
from .dataset import *
|
3 |
+
from .utils import *
|
4 |
+
from .preprocess_files import *
|
5 |
+
from .lora_manager import *
|
lora_diffusion/cli_lora_add.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal, Union, Dict
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import fire
|
5 |
+
from diffusers import StableDiffusionPipeline
|
6 |
+
from safetensors.torch import safe_open, save_file
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from .lora import (
|
10 |
+
tune_lora_scale,
|
11 |
+
patch_pipe,
|
12 |
+
collapse_lora,
|
13 |
+
monkeypatch_remove_lora,
|
14 |
+
)
|
15 |
+
from .lora_manager import lora_join
|
16 |
+
from .to_ckpt_v2 import convert_to_ckpt
|
17 |
+
|
18 |
+
|
19 |
+
def _text_lora_path(path: str) -> str:
|
20 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
21 |
+
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
|
22 |
+
|
23 |
+
|
24 |
+
def add(
|
25 |
+
path_1: str,
|
26 |
+
path_2: str,
|
27 |
+
output_path: str,
|
28 |
+
alpha_1: float = 0.5,
|
29 |
+
alpha_2: float = 0.5,
|
30 |
+
mode: Literal[
|
31 |
+
"lpl",
|
32 |
+
"upl",
|
33 |
+
"upl-ckpt-v2",
|
34 |
+
] = "lpl",
|
35 |
+
with_text_lora: bool = False,
|
36 |
+
):
|
37 |
+
print("Lora Add, mode " + mode)
|
38 |
+
if mode == "lpl":
|
39 |
+
if path_1.endswith(".pt") and path_2.endswith(".pt"):
|
40 |
+
for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
|
41 |
+
[(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
|
42 |
+
if with_text_lora
|
43 |
+
else []
|
44 |
+
):
|
45 |
+
print("Loading", _path_1, _path_2)
|
46 |
+
out_list = []
|
47 |
+
if opt == "text_encoder":
|
48 |
+
if not os.path.exists(_path_1):
|
49 |
+
print(f"No text encoder found in {_path_1}, skipping...")
|
50 |
+
continue
|
51 |
+
if not os.path.exists(_path_2):
|
52 |
+
print(f"No text encoder found in {_path_1}, skipping...")
|
53 |
+
continue
|
54 |
+
|
55 |
+
l1 = torch.load(_path_1)
|
56 |
+
l2 = torch.load(_path_2)
|
57 |
+
|
58 |
+
l1pairs = zip(l1[::2], l1[1::2])
|
59 |
+
l2pairs = zip(l2[::2], l2[1::2])
|
60 |
+
|
61 |
+
for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
|
62 |
+
# print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
|
63 |
+
x1.data = alpha_1 * x1.data + alpha_2 * x2.data
|
64 |
+
y1.data = alpha_1 * y1.data + alpha_2 * y2.data
|
65 |
+
|
66 |
+
out_list.append(x1)
|
67 |
+
out_list.append(y1)
|
68 |
+
|
69 |
+
if opt == "unet":
|
70 |
+
|
71 |
+
print("Saving merged UNET to", output_path)
|
72 |
+
torch.save(out_list, output_path)
|
73 |
+
|
74 |
+
elif opt == "text_encoder":
|
75 |
+
print("Saving merged text encoder to", _text_lora_path(output_path))
|
76 |
+
torch.save(
|
77 |
+
out_list,
|
78 |
+
_text_lora_path(output_path),
|
79 |
+
)
|
80 |
+
|
81 |
+
elif path_1.endswith(".safetensors") and path_2.endswith(".safetensors"):
|
82 |
+
safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
|
83 |
+
safeloras_2 = safe_open(path_2, framework="pt", device="cpu")
|
84 |
+
|
85 |
+
metadata = dict(safeloras_1.metadata())
|
86 |
+
metadata.update(dict(safeloras_2.metadata()))
|
87 |
+
|
88 |
+
ret_tensor = {}
|
89 |
+
|
90 |
+
for keys in set(list(safeloras_1.keys()) + list(safeloras_2.keys())):
|
91 |
+
if keys.startswith("text_encoder") or keys.startswith("unet"):
|
92 |
+
|
93 |
+
tens1 = safeloras_1.get_tensor(keys)
|
94 |
+
tens2 = safeloras_2.get_tensor(keys)
|
95 |
+
|
96 |
+
tens = alpha_1 * tens1 + alpha_2 * tens2
|
97 |
+
ret_tensor[keys] = tens
|
98 |
+
else:
|
99 |
+
if keys in safeloras_1.keys():
|
100 |
+
|
101 |
+
tens1 = safeloras_1.get_tensor(keys)
|
102 |
+
else:
|
103 |
+
tens1 = safeloras_2.get_tensor(keys)
|
104 |
+
|
105 |
+
ret_tensor[keys] = tens1
|
106 |
+
|
107 |
+
save_file(ret_tensor, output_path, metadata)
|
108 |
+
|
109 |
+
elif mode == "upl":
|
110 |
+
|
111 |
+
print(
|
112 |
+
f"Merging UNET/CLIP from {path_1} with LoRA from {path_2} to {output_path}. Merging ratio : {alpha_1}."
|
113 |
+
)
|
114 |
+
|
115 |
+
loaded_pipeline = StableDiffusionPipeline.from_pretrained(
|
116 |
+
path_1,
|
117 |
+
).to("cpu")
|
118 |
+
|
119 |
+
patch_pipe(loaded_pipeline, path_2)
|
120 |
+
|
121 |
+
collapse_lora(loaded_pipeline.unet, alpha_1)
|
122 |
+
collapse_lora(loaded_pipeline.text_encoder, alpha_1)
|
123 |
+
|
124 |
+
monkeypatch_remove_lora(loaded_pipeline.unet)
|
125 |
+
monkeypatch_remove_lora(loaded_pipeline.text_encoder)
|
126 |
+
|
127 |
+
loaded_pipeline.save_pretrained(output_path)
|
128 |
+
|
129 |
+
elif mode == "upl-ckpt-v2":
|
130 |
+
|
131 |
+
assert output_path.endswith(".ckpt"), "Only .ckpt files are supported"
|
132 |
+
name = os.path.basename(output_path)[0:-5]
|
133 |
+
|
134 |
+
print(
|
135 |
+
f"You will be using {name} as the token in A1111 webui. Make sure {name} is unique enough token."
|
136 |
+
)
|
137 |
+
|
138 |
+
loaded_pipeline = StableDiffusionPipeline.from_pretrained(
|
139 |
+
path_1,
|
140 |
+
).to("cpu")
|
141 |
+
|
142 |
+
tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)
|
143 |
+
|
144 |
+
collapse_lora(loaded_pipeline.unet, alpha_1)
|
145 |
+
collapse_lora(loaded_pipeline.text_encoder, alpha_1)
|
146 |
+
|
147 |
+
monkeypatch_remove_lora(loaded_pipeline.unet)
|
148 |
+
monkeypatch_remove_lora(loaded_pipeline.text_encoder)
|
149 |
+
|
150 |
+
_tmp_output = output_path + ".tmp"
|
151 |
+
|
152 |
+
loaded_pipeline.save_pretrained(_tmp_output)
|
153 |
+
convert_to_ckpt(_tmp_output, output_path, as_half=True)
|
154 |
+
# remove the tmp_output folder
|
155 |
+
shutil.rmtree(_tmp_output)
|
156 |
+
|
157 |
+
keys = sorted(tok_dict.keys())
|
158 |
+
tok_catted = torch.stack([tok_dict[k] for k in keys])
|
159 |
+
ret = {
|
160 |
+
"string_to_token": {"*": torch.tensor(265)},
|
161 |
+
"string_to_param": {"*": tok_catted},
|
162 |
+
"name": name,
|
163 |
+
}
|
164 |
+
|
165 |
+
torch.save(ret, output_path[:-5] + ".pt")
|
166 |
+
print(
|
167 |
+
f"Textual embedding saved as {output_path[:-5]}.pt, put it in the embedding folder and use it as {name} in A1111 repo, "
|
168 |
+
)
|
169 |
+
elif mode == "ljl":
|
170 |
+
print("Using Join mode : alpha will not have an effect here.")
|
171 |
+
assert path_1.endswith(".safetensors") and path_2.endswith(
|
172 |
+
".safetensors"
|
173 |
+
), "Only .safetensors files are supported"
|
174 |
+
|
175 |
+
safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
|
176 |
+
safeloras_2 = safe_open(path_2, framework="pt", device="cpu")
|
177 |
+
|
178 |
+
total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
|
179 |
+
save_file(total_tensor, output_path, total_metadata)
|
180 |
+
|
181 |
+
else:
|
182 |
+
print("Unknown mode", mode)
|
183 |
+
raise ValueError(f"Unknown mode {mode}")
|
184 |
+
|
185 |
+
|
186 |
+
def main():
|
187 |
+
fire.Fire(add)
|
lora_diffusion/cli_lora_pti.py
ADDED
@@ -0,0 +1,1040 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Bootstrapped from:
|
2 |
+
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import hashlib
|
6 |
+
import inspect
|
7 |
+
import itertools
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import re
|
12 |
+
from pathlib import Path
|
13 |
+
from typing import Optional, List, Literal
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torch.optim as optim
|
18 |
+
import torch.utils.checkpoint
|
19 |
+
from diffusers import (
|
20 |
+
AutoencoderKL,
|
21 |
+
DDPMScheduler,
|
22 |
+
StableDiffusionPipeline,
|
23 |
+
UNet2DConditionModel,
|
24 |
+
)
|
25 |
+
from diffusers.optimization import get_scheduler
|
26 |
+
from huggingface_hub import HfFolder, Repository, whoami
|
27 |
+
from PIL import Image
|
28 |
+
from torch.utils.data import Dataset
|
29 |
+
from torchvision import transforms
|
30 |
+
from tqdm.auto import tqdm
|
31 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
32 |
+
import wandb
|
33 |
+
import fire
|
34 |
+
|
35 |
+
from lora_diffusion import (
|
36 |
+
PivotalTuningDatasetCapation,
|
37 |
+
extract_lora_ups_down,
|
38 |
+
inject_trainable_lora,
|
39 |
+
inject_trainable_lora_extended,
|
40 |
+
inspect_lora,
|
41 |
+
save_lora_weight,
|
42 |
+
save_all,
|
43 |
+
prepare_clip_model_sets,
|
44 |
+
evaluate_pipe,
|
45 |
+
UNET_EXTENDED_TARGET_REPLACE,
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
def get_models(
|
50 |
+
pretrained_model_name_or_path,
|
51 |
+
pretrained_vae_name_or_path,
|
52 |
+
revision,
|
53 |
+
placeholder_tokens: List[str],
|
54 |
+
initializer_tokens: List[str],
|
55 |
+
device="cuda:0",
|
56 |
+
):
|
57 |
+
|
58 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
59 |
+
pretrained_model_name_or_path,
|
60 |
+
subfolder="tokenizer",
|
61 |
+
revision=revision,
|
62 |
+
)
|
63 |
+
|
64 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
65 |
+
pretrained_model_name_or_path,
|
66 |
+
subfolder="text_encoder",
|
67 |
+
revision=revision,
|
68 |
+
)
|
69 |
+
|
70 |
+
placeholder_token_ids = []
|
71 |
+
|
72 |
+
for token, init_tok in zip(placeholder_tokens, initializer_tokens):
|
73 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
74 |
+
if num_added_tokens == 0:
|
75 |
+
raise ValueError(
|
76 |
+
f"The tokenizer already contains the token {token}. Please pass a different"
|
77 |
+
" `placeholder_token` that is not already in the tokenizer."
|
78 |
+
)
|
79 |
+
|
80 |
+
placeholder_token_id = tokenizer.convert_tokens_to_ids(token)
|
81 |
+
|
82 |
+
placeholder_token_ids.append(placeholder_token_id)
|
83 |
+
|
84 |
+
# Load models and create wrapper for stable diffusion
|
85 |
+
|
86 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
87 |
+
token_embeds = text_encoder.get_input_embeddings().weight.data
|
88 |
+
if init_tok.startswith("<rand"):
|
89 |
+
# <rand-"sigma">, e.g. <rand-0.5>
|
90 |
+
sigma_val = float(re.findall(r"<rand-(.*)>", init_tok)[0])
|
91 |
+
|
92 |
+
token_embeds[placeholder_token_id] = (
|
93 |
+
torch.randn_like(token_embeds[0]) * sigma_val
|
94 |
+
)
|
95 |
+
print(
|
96 |
+
f"Initialized {token} with random noise (sigma={sigma_val}), empirically {token_embeds[placeholder_token_id].mean().item():.3f} +- {token_embeds[placeholder_token_id].std().item():.3f}"
|
97 |
+
)
|
98 |
+
print(f"Norm : {token_embeds[placeholder_token_id].norm():.4f}")
|
99 |
+
|
100 |
+
elif init_tok == "<zero>":
|
101 |
+
token_embeds[placeholder_token_id] = torch.zeros_like(token_embeds[0])
|
102 |
+
else:
|
103 |
+
token_ids = tokenizer.encode(init_tok, add_special_tokens=False)
|
104 |
+
# Check if initializer_token is a single token or a sequence of tokens
|
105 |
+
if len(token_ids) > 1:
|
106 |
+
raise ValueError("The initializer token must be a single token.")
|
107 |
+
|
108 |
+
initializer_token_id = token_ids[0]
|
109 |
+
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
110 |
+
|
111 |
+
vae = AutoencoderKL.from_pretrained(
|
112 |
+
pretrained_vae_name_or_path or pretrained_model_name_or_path,
|
113 |
+
subfolder=None if pretrained_vae_name_or_path else "vae",
|
114 |
+
revision=None if pretrained_vae_name_or_path else revision,
|
115 |
+
)
|
116 |
+
unet = UNet2DConditionModel.from_pretrained(
|
117 |
+
pretrained_model_name_or_path,
|
118 |
+
subfolder="unet",
|
119 |
+
revision=revision,
|
120 |
+
)
|
121 |
+
|
122 |
+
return (
|
123 |
+
text_encoder.to(device),
|
124 |
+
vae.to(device),
|
125 |
+
unet.to(device),
|
126 |
+
tokenizer,
|
127 |
+
placeholder_token_ids,
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
@torch.no_grad()
|
132 |
+
def text2img_dataloader(
|
133 |
+
train_dataset,
|
134 |
+
train_batch_size,
|
135 |
+
tokenizer,
|
136 |
+
vae,
|
137 |
+
text_encoder,
|
138 |
+
cached_latents: bool = False,
|
139 |
+
):
|
140 |
+
|
141 |
+
if cached_latents:
|
142 |
+
cached_latents_dataset = []
|
143 |
+
for idx in tqdm(range(len(train_dataset))):
|
144 |
+
batch = train_dataset[idx]
|
145 |
+
# rint(batch)
|
146 |
+
latents = vae.encode(
|
147 |
+
batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
|
148 |
+
).latent_dist.sample()
|
149 |
+
latents = latents * 0.18215
|
150 |
+
batch["instance_images"] = latents.squeeze(0)
|
151 |
+
cached_latents_dataset.append(batch)
|
152 |
+
|
153 |
+
def collate_fn(examples):
|
154 |
+
input_ids = [example["instance_prompt_ids"] for example in examples]
|
155 |
+
pixel_values = [example["instance_images"] for example in examples]
|
156 |
+
pixel_values = torch.stack(pixel_values)
|
157 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
158 |
+
|
159 |
+
input_ids = tokenizer.pad(
|
160 |
+
{"input_ids": input_ids},
|
161 |
+
padding="max_length",
|
162 |
+
max_length=tokenizer.model_max_length,
|
163 |
+
return_tensors="pt",
|
164 |
+
).input_ids
|
165 |
+
|
166 |
+
batch = {
|
167 |
+
"input_ids": input_ids,
|
168 |
+
"pixel_values": pixel_values,
|
169 |
+
}
|
170 |
+
|
171 |
+
if examples[0].get("mask", None) is not None:
|
172 |
+
batch["mask"] = torch.stack([example["mask"] for example in examples])
|
173 |
+
|
174 |
+
return batch
|
175 |
+
|
176 |
+
if cached_latents:
|
177 |
+
|
178 |
+
train_dataloader = torch.utils.data.DataLoader(
|
179 |
+
cached_latents_dataset,
|
180 |
+
batch_size=train_batch_size,
|
181 |
+
shuffle=True,
|
182 |
+
collate_fn=collate_fn,
|
183 |
+
)
|
184 |
+
|
185 |
+
print("PTI : Using cached latent.")
|
186 |
+
|
187 |
+
else:
|
188 |
+
train_dataloader = torch.utils.data.DataLoader(
|
189 |
+
train_dataset,
|
190 |
+
batch_size=train_batch_size,
|
191 |
+
shuffle=True,
|
192 |
+
collate_fn=collate_fn,
|
193 |
+
)
|
194 |
+
|
195 |
+
return train_dataloader
|
196 |
+
|
197 |
+
|
198 |
+
def inpainting_dataloader(
|
199 |
+
train_dataset, train_batch_size, tokenizer, vae, text_encoder
|
200 |
+
):
|
201 |
+
def collate_fn(examples):
|
202 |
+
input_ids = [example["instance_prompt_ids"] for example in examples]
|
203 |
+
pixel_values = [example["instance_images"] for example in examples]
|
204 |
+
mask_values = [example["instance_masks"] for example in examples]
|
205 |
+
masked_image_values = [
|
206 |
+
example["instance_masked_images"] for example in examples
|
207 |
+
]
|
208 |
+
|
209 |
+
# Concat class and instance examples for prior preservation.
|
210 |
+
# We do this to avoid doing two forward passes.
|
211 |
+
if examples[0].get("class_prompt_ids", None) is not None:
|
212 |
+
input_ids += [example["class_prompt_ids"] for example in examples]
|
213 |
+
pixel_values += [example["class_images"] for example in examples]
|
214 |
+
mask_values += [example["class_masks"] for example in examples]
|
215 |
+
masked_image_values += [
|
216 |
+
example["class_masked_images"] for example in examples
|
217 |
+
]
|
218 |
+
|
219 |
+
pixel_values = (
|
220 |
+
torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
|
221 |
+
)
|
222 |
+
mask_values = (
|
223 |
+
torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
|
224 |
+
)
|
225 |
+
masked_image_values = (
|
226 |
+
torch.stack(masked_image_values)
|
227 |
+
.to(memory_format=torch.contiguous_format)
|
228 |
+
.float()
|
229 |
+
)
|
230 |
+
|
231 |
+
input_ids = tokenizer.pad(
|
232 |
+
{"input_ids": input_ids},
|
233 |
+
padding="max_length",
|
234 |
+
max_length=tokenizer.model_max_length,
|
235 |
+
return_tensors="pt",
|
236 |
+
).input_ids
|
237 |
+
|
238 |
+
batch = {
|
239 |
+
"input_ids": input_ids,
|
240 |
+
"pixel_values": pixel_values,
|
241 |
+
"mask_values": mask_values,
|
242 |
+
"masked_image_values": masked_image_values,
|
243 |
+
}
|
244 |
+
|
245 |
+
if examples[0].get("mask", None) is not None:
|
246 |
+
batch["mask"] = torch.stack([example["mask"] for example in examples])
|
247 |
+
|
248 |
+
return batch
|
249 |
+
|
250 |
+
train_dataloader = torch.utils.data.DataLoader(
|
251 |
+
train_dataset,
|
252 |
+
batch_size=train_batch_size,
|
253 |
+
shuffle=True,
|
254 |
+
collate_fn=collate_fn,
|
255 |
+
)
|
256 |
+
|
257 |
+
return train_dataloader
|
258 |
+
|
259 |
+
|
260 |
+
def loss_step(
|
261 |
+
batch,
|
262 |
+
unet,
|
263 |
+
vae,
|
264 |
+
text_encoder,
|
265 |
+
scheduler,
|
266 |
+
train_inpainting=False,
|
267 |
+
t_mutliplier=1.0,
|
268 |
+
mixed_precision=False,
|
269 |
+
mask_temperature=1.0,
|
270 |
+
cached_latents: bool = False,
|
271 |
+
):
|
272 |
+
weight_dtype = torch.float32
|
273 |
+
if not cached_latents:
|
274 |
+
latents = vae.encode(
|
275 |
+
batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
|
276 |
+
).latent_dist.sample()
|
277 |
+
latents = latents * 0.18215
|
278 |
+
|
279 |
+
if train_inpainting:
|
280 |
+
masked_image_latents = vae.encode(
|
281 |
+
batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
|
282 |
+
).latent_dist.sample()
|
283 |
+
masked_image_latents = masked_image_latents * 0.18215
|
284 |
+
mask = F.interpolate(
|
285 |
+
batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
|
286 |
+
scale_factor=1 / 8,
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
latents = batch["pixel_values"]
|
290 |
+
|
291 |
+
if train_inpainting:
|
292 |
+
masked_image_latents = batch["masked_image_latents"]
|
293 |
+
mask = batch["mask_values"]
|
294 |
+
|
295 |
+
noise = torch.randn_like(latents)
|
296 |
+
bsz = latents.shape[0]
|
297 |
+
|
298 |
+
timesteps = torch.randint(
|
299 |
+
0,
|
300 |
+
int(scheduler.config.num_train_timesteps * t_mutliplier),
|
301 |
+
(bsz,),
|
302 |
+
device=latents.device,
|
303 |
+
)
|
304 |
+
timesteps = timesteps.long()
|
305 |
+
|
306 |
+
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
|
307 |
+
|
308 |
+
if train_inpainting:
|
309 |
+
latent_model_input = torch.cat(
|
310 |
+
[noisy_latents, mask, masked_image_latents], dim=1
|
311 |
+
)
|
312 |
+
else:
|
313 |
+
latent_model_input = noisy_latents
|
314 |
+
|
315 |
+
if mixed_precision:
|
316 |
+
with torch.cuda.amp.autocast():
|
317 |
+
|
318 |
+
encoder_hidden_states = text_encoder(
|
319 |
+
batch["input_ids"].to(text_encoder.device)
|
320 |
+
)[0]
|
321 |
+
|
322 |
+
model_pred = unet(
|
323 |
+
latent_model_input, timesteps, encoder_hidden_states
|
324 |
+
).sample
|
325 |
+
else:
|
326 |
+
|
327 |
+
encoder_hidden_states = text_encoder(
|
328 |
+
batch["input_ids"].to(text_encoder.device)
|
329 |
+
)[0]
|
330 |
+
|
331 |
+
model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
|
332 |
+
|
333 |
+
if scheduler.config.prediction_type == "epsilon":
|
334 |
+
target = noise
|
335 |
+
elif scheduler.config.prediction_type == "v_prediction":
|
336 |
+
target = scheduler.get_velocity(latents, noise, timesteps)
|
337 |
+
else:
|
338 |
+
raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
|
339 |
+
|
340 |
+
if batch.get("mask", None) is not None:
|
341 |
+
|
342 |
+
mask = (
|
343 |
+
batch["mask"]
|
344 |
+
.to(model_pred.device)
|
345 |
+
.reshape(
|
346 |
+
model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
|
347 |
+
)
|
348 |
+
)
|
349 |
+
# resize to match model_pred
|
350 |
+
mask = F.interpolate(
|
351 |
+
mask.float(),
|
352 |
+
size=model_pred.shape[-2:],
|
353 |
+
mode="nearest",
|
354 |
+
)
|
355 |
+
|
356 |
+
mask = (mask + 0.01).pow(mask_temperature)
|
357 |
+
|
358 |
+
mask = mask / mask.max()
|
359 |
+
|
360 |
+
model_pred = model_pred * mask
|
361 |
+
|
362 |
+
target = target * mask
|
363 |
+
|
364 |
+
loss = (
|
365 |
+
F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
366 |
+
.mean([1, 2, 3])
|
367 |
+
.mean()
|
368 |
+
)
|
369 |
+
|
370 |
+
return loss
|
371 |
+
|
372 |
+
|
373 |
+
def train_inversion(
|
374 |
+
unet,
|
375 |
+
vae,
|
376 |
+
text_encoder,
|
377 |
+
dataloader,
|
378 |
+
num_steps: int,
|
379 |
+
scheduler,
|
380 |
+
index_no_updates,
|
381 |
+
optimizer,
|
382 |
+
save_steps: int,
|
383 |
+
placeholder_token_ids,
|
384 |
+
placeholder_tokens,
|
385 |
+
save_path: str,
|
386 |
+
tokenizer,
|
387 |
+
lr_scheduler,
|
388 |
+
test_image_path: str,
|
389 |
+
cached_latents: bool,
|
390 |
+
accum_iter: int = 1,
|
391 |
+
log_wandb: bool = False,
|
392 |
+
wandb_log_prompt_cnt: int = 10,
|
393 |
+
class_token: str = "person",
|
394 |
+
train_inpainting: bool = False,
|
395 |
+
mixed_precision: bool = False,
|
396 |
+
clip_ti_decay: bool = True,
|
397 |
+
):
|
398 |
+
|
399 |
+
progress_bar = tqdm(range(num_steps))
|
400 |
+
progress_bar.set_description("Steps")
|
401 |
+
global_step = 0
|
402 |
+
|
403 |
+
# Original Emb for TI
|
404 |
+
orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()
|
405 |
+
|
406 |
+
if log_wandb:
|
407 |
+
preped_clip = prepare_clip_model_sets()
|
408 |
+
|
409 |
+
index_updates = ~index_no_updates
|
410 |
+
loss_sum = 0.0
|
411 |
+
|
412 |
+
for epoch in range(math.ceil(num_steps / len(dataloader))):
|
413 |
+
unet.eval()
|
414 |
+
text_encoder.train()
|
415 |
+
for batch in dataloader:
|
416 |
+
|
417 |
+
lr_scheduler.step()
|
418 |
+
|
419 |
+
with torch.set_grad_enabled(True):
|
420 |
+
loss = (
|
421 |
+
loss_step(
|
422 |
+
batch,
|
423 |
+
unet,
|
424 |
+
vae,
|
425 |
+
text_encoder,
|
426 |
+
scheduler,
|
427 |
+
train_inpainting=train_inpainting,
|
428 |
+
mixed_precision=mixed_precision,
|
429 |
+
cached_latents=cached_latents,
|
430 |
+
)
|
431 |
+
/ accum_iter
|
432 |
+
)
|
433 |
+
|
434 |
+
loss.backward()
|
435 |
+
loss_sum += loss.detach().item()
|
436 |
+
|
437 |
+
if global_step % accum_iter == 0:
|
438 |
+
# print gradient of text encoder embedding
|
439 |
+
print(
|
440 |
+
text_encoder.get_input_embeddings()
|
441 |
+
.weight.grad[index_updates, :]
|
442 |
+
.norm(dim=-1)
|
443 |
+
.mean()
|
444 |
+
)
|
445 |
+
optimizer.step()
|
446 |
+
optimizer.zero_grad()
|
447 |
+
|
448 |
+
with torch.no_grad():
|
449 |
+
|
450 |
+
# normalize embeddings
|
451 |
+
if clip_ti_decay:
|
452 |
+
pre_norm = (
|
453 |
+
text_encoder.get_input_embeddings()
|
454 |
+
.weight[index_updates, :]
|
455 |
+
.norm(dim=-1, keepdim=True)
|
456 |
+
)
|
457 |
+
|
458 |
+
lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
|
459 |
+
text_encoder.get_input_embeddings().weight[
|
460 |
+
index_updates
|
461 |
+
] = F.normalize(
|
462 |
+
text_encoder.get_input_embeddings().weight[
|
463 |
+
index_updates, :
|
464 |
+
],
|
465 |
+
dim=-1,
|
466 |
+
) * (
|
467 |
+
pre_norm + lambda_ * (0.4 - pre_norm)
|
468 |
+
)
|
469 |
+
print(pre_norm)
|
470 |
+
|
471 |
+
current_norm = (
|
472 |
+
text_encoder.get_input_embeddings()
|
473 |
+
.weight[index_updates, :]
|
474 |
+
.norm(dim=-1)
|
475 |
+
)
|
476 |
+
|
477 |
+
text_encoder.get_input_embeddings().weight[
|
478 |
+
index_no_updates
|
479 |
+
] = orig_embeds_params[index_no_updates]
|
480 |
+
|
481 |
+
print(f"Current Norm : {current_norm}")
|
482 |
+
|
483 |
+
global_step += 1
|
484 |
+
progress_bar.update(1)
|
485 |
+
|
486 |
+
logs = {
|
487 |
+
"loss": loss.detach().item(),
|
488 |
+
"lr": lr_scheduler.get_last_lr()[0],
|
489 |
+
}
|
490 |
+
progress_bar.set_postfix(**logs)
|
491 |
+
|
492 |
+
if global_step % save_steps == 0:
|
493 |
+
save_all(
|
494 |
+
unet=unet,
|
495 |
+
text_encoder=text_encoder,
|
496 |
+
placeholder_token_ids=placeholder_token_ids,
|
497 |
+
placeholder_tokens=placeholder_tokens,
|
498 |
+
save_path=os.path.join(
|
499 |
+
save_path, f"step_inv_{global_step}.safetensors"
|
500 |
+
),
|
501 |
+
save_lora=False,
|
502 |
+
)
|
503 |
+
if log_wandb:
|
504 |
+
with torch.no_grad():
|
505 |
+
pipe = StableDiffusionPipeline(
|
506 |
+
vae=vae,
|
507 |
+
text_encoder=text_encoder,
|
508 |
+
tokenizer=tokenizer,
|
509 |
+
unet=unet,
|
510 |
+
scheduler=scheduler,
|
511 |
+
safety_checker=None,
|
512 |
+
feature_extractor=None,
|
513 |
+
)
|
514 |
+
|
515 |
+
# open all images in test_image_path
|
516 |
+
images = []
|
517 |
+
for file in os.listdir(test_image_path):
|
518 |
+
if (
|
519 |
+
file.lower().endswith(".png")
|
520 |
+
or file.lower().endswith(".jpg")
|
521 |
+
or file.lower().endswith(".jpeg")
|
522 |
+
):
|
523 |
+
images.append(
|
524 |
+
Image.open(os.path.join(test_image_path, file))
|
525 |
+
)
|
526 |
+
|
527 |
+
wandb.log({"loss": loss_sum / save_steps})
|
528 |
+
loss_sum = 0.0
|
529 |
+
wandb.log(
|
530 |
+
evaluate_pipe(
|
531 |
+
pipe,
|
532 |
+
target_images=images,
|
533 |
+
class_token=class_token,
|
534 |
+
learnt_token="".join(placeholder_tokens),
|
535 |
+
n_test=wandb_log_prompt_cnt,
|
536 |
+
n_step=50,
|
537 |
+
clip_model_sets=preped_clip,
|
538 |
+
)
|
539 |
+
)
|
540 |
+
|
541 |
+
if global_step >= num_steps:
|
542 |
+
return
|
543 |
+
|
544 |
+
|
545 |
+
def perform_tuning(
|
546 |
+
unet,
|
547 |
+
vae,
|
548 |
+
text_encoder,
|
549 |
+
dataloader,
|
550 |
+
num_steps,
|
551 |
+
scheduler,
|
552 |
+
optimizer,
|
553 |
+
save_steps: int,
|
554 |
+
placeholder_token_ids,
|
555 |
+
placeholder_tokens,
|
556 |
+
save_path,
|
557 |
+
lr_scheduler_lora,
|
558 |
+
lora_unet_target_modules,
|
559 |
+
lora_clip_target_modules,
|
560 |
+
mask_temperature,
|
561 |
+
out_name: str,
|
562 |
+
tokenizer,
|
563 |
+
test_image_path: str,
|
564 |
+
cached_latents: bool,
|
565 |
+
log_wandb: bool = False,
|
566 |
+
wandb_log_prompt_cnt: int = 10,
|
567 |
+
class_token: str = "person",
|
568 |
+
train_inpainting: bool = False,
|
569 |
+
):
|
570 |
+
|
571 |
+
progress_bar = tqdm(range(num_steps))
|
572 |
+
progress_bar.set_description("Steps")
|
573 |
+
global_step = 0
|
574 |
+
|
575 |
+
weight_dtype = torch.float16
|
576 |
+
|
577 |
+
unet.train()
|
578 |
+
text_encoder.train()
|
579 |
+
|
580 |
+
if log_wandb:
|
581 |
+
preped_clip = prepare_clip_model_sets()
|
582 |
+
|
583 |
+
loss_sum = 0.0
|
584 |
+
|
585 |
+
for epoch in range(math.ceil(num_steps / len(dataloader))):
|
586 |
+
for batch in dataloader:
|
587 |
+
lr_scheduler_lora.step()
|
588 |
+
|
589 |
+
optimizer.zero_grad()
|
590 |
+
|
591 |
+
loss = loss_step(
|
592 |
+
batch,
|
593 |
+
unet,
|
594 |
+
vae,
|
595 |
+
text_encoder,
|
596 |
+
scheduler,
|
597 |
+
train_inpainting=train_inpainting,
|
598 |
+
t_mutliplier=0.8,
|
599 |
+
mixed_precision=True,
|
600 |
+
mask_temperature=mask_temperature,
|
601 |
+
cached_latents=cached_latents,
|
602 |
+
)
|
603 |
+
loss_sum += loss.detach().item()
|
604 |
+
|
605 |
+
loss.backward()
|
606 |
+
torch.nn.utils.clip_grad_norm_(
|
607 |
+
itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
|
608 |
+
)
|
609 |
+
optimizer.step()
|
610 |
+
progress_bar.update(1)
|
611 |
+
logs = {
|
612 |
+
"loss": loss.detach().item(),
|
613 |
+
"lr": lr_scheduler_lora.get_last_lr()[0],
|
614 |
+
}
|
615 |
+
progress_bar.set_postfix(**logs)
|
616 |
+
|
617 |
+
global_step += 1
|
618 |
+
|
619 |
+
if global_step % save_steps == 0:
|
620 |
+
save_all(
|
621 |
+
unet,
|
622 |
+
text_encoder,
|
623 |
+
placeholder_token_ids=placeholder_token_ids,
|
624 |
+
placeholder_tokens=placeholder_tokens,
|
625 |
+
save_path=os.path.join(
|
626 |
+
save_path, f"step_{global_step}.safetensors"
|
627 |
+
),
|
628 |
+
target_replace_module_text=lora_clip_target_modules,
|
629 |
+
target_replace_module_unet=lora_unet_target_modules,
|
630 |
+
)
|
631 |
+
moved = (
|
632 |
+
torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
|
633 |
+
.mean()
|
634 |
+
.item()
|
635 |
+
)
|
636 |
+
|
637 |
+
print("LORA Unet Moved", moved)
|
638 |
+
moved = (
|
639 |
+
torch.tensor(
|
640 |
+
list(itertools.chain(*inspect_lora(text_encoder).values()))
|
641 |
+
)
|
642 |
+
.mean()
|
643 |
+
.item()
|
644 |
+
)
|
645 |
+
|
646 |
+
print("LORA CLIP Moved", moved)
|
647 |
+
|
648 |
+
if log_wandb:
|
649 |
+
with torch.no_grad():
|
650 |
+
pipe = StableDiffusionPipeline(
|
651 |
+
vae=vae,
|
652 |
+
text_encoder=text_encoder,
|
653 |
+
tokenizer=tokenizer,
|
654 |
+
unet=unet,
|
655 |
+
scheduler=scheduler,
|
656 |
+
safety_checker=None,
|
657 |
+
feature_extractor=None,
|
658 |
+
)
|
659 |
+
|
660 |
+
# open all images in test_image_path
|
661 |
+
images = []
|
662 |
+
for file in os.listdir(test_image_path):
|
663 |
+
if file.endswith(".png") or file.endswith(".jpg"):
|
664 |
+
images.append(
|
665 |
+
Image.open(os.path.join(test_image_path, file))
|
666 |
+
)
|
667 |
+
|
668 |
+
wandb.log({"loss": loss_sum / save_steps})
|
669 |
+
loss_sum = 0.0
|
670 |
+
wandb.log(
|
671 |
+
evaluate_pipe(
|
672 |
+
pipe,
|
673 |
+
target_images=images,
|
674 |
+
class_token=class_token,
|
675 |
+
learnt_token="".join(placeholder_tokens),
|
676 |
+
n_test=wandb_log_prompt_cnt,
|
677 |
+
n_step=50,
|
678 |
+
clip_model_sets=preped_clip,
|
679 |
+
)
|
680 |
+
)
|
681 |
+
|
682 |
+
if global_step >= num_steps:
|
683 |
+
break
|
684 |
+
|
685 |
+
save_all(
|
686 |
+
unet,
|
687 |
+
text_encoder,
|
688 |
+
placeholder_token_ids=placeholder_token_ids,
|
689 |
+
placeholder_tokens=placeholder_tokens,
|
690 |
+
save_path=os.path.join(save_path, f"{out_name}.safetensors"),
|
691 |
+
target_replace_module_text=lora_clip_target_modules,
|
692 |
+
target_replace_module_unet=lora_unet_target_modules,
|
693 |
+
)
|
694 |
+
|
695 |
+
|
696 |
+
def train(
|
697 |
+
instance_data_dir: str,
|
698 |
+
pretrained_model_name_or_path: str,
|
699 |
+
output_dir: str,
|
700 |
+
train_text_encoder: bool = True,
|
701 |
+
pretrained_vae_name_or_path: str = None,
|
702 |
+
revision: Optional[str] = None,
|
703 |
+
perform_inversion: bool = True,
|
704 |
+
use_template: Literal[None, "object", "style"] = None,
|
705 |
+
train_inpainting: bool = False,
|
706 |
+
placeholder_tokens: str = "",
|
707 |
+
placeholder_token_at_data: Optional[str] = None,
|
708 |
+
initializer_tokens: Optional[str] = None,
|
709 |
+
seed: int = 42,
|
710 |
+
resolution: int = 512,
|
711 |
+
color_jitter: bool = True,
|
712 |
+
train_batch_size: int = 1,
|
713 |
+
sample_batch_size: int = 1,
|
714 |
+
max_train_steps_tuning: int = 1000,
|
715 |
+
max_train_steps_ti: int = 1000,
|
716 |
+
save_steps: int = 100,
|
717 |
+
gradient_accumulation_steps: int = 4,
|
718 |
+
gradient_checkpointing: bool = False,
|
719 |
+
lora_rank: int = 4,
|
720 |
+
lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
|
721 |
+
lora_clip_target_modules={"CLIPAttention"},
|
722 |
+
lora_dropout_p: float = 0.0,
|
723 |
+
lora_scale: float = 1.0,
|
724 |
+
use_extended_lora: bool = False,
|
725 |
+
clip_ti_decay: bool = True,
|
726 |
+
learning_rate_unet: float = 1e-4,
|
727 |
+
learning_rate_text: float = 1e-5,
|
728 |
+
learning_rate_ti: float = 5e-4,
|
729 |
+
continue_inversion: bool = False,
|
730 |
+
continue_inversion_lr: Optional[float] = None,
|
731 |
+
use_face_segmentation_condition: bool = False,
|
732 |
+
cached_latents: bool = True,
|
733 |
+
use_mask_captioned_data: bool = False,
|
734 |
+
mask_temperature: float = 1.0,
|
735 |
+
scale_lr: bool = False,
|
736 |
+
lr_scheduler: str = "linear",
|
737 |
+
lr_warmup_steps: int = 0,
|
738 |
+
lr_scheduler_lora: str = "linear",
|
739 |
+
lr_warmup_steps_lora: int = 0,
|
740 |
+
weight_decay_ti: float = 0.00,
|
741 |
+
weight_decay_lora: float = 0.001,
|
742 |
+
use_8bit_adam: bool = False,
|
743 |
+
device="cuda:0",
|
744 |
+
extra_args: Optional[dict] = None,
|
745 |
+
log_wandb: bool = False,
|
746 |
+
wandb_log_prompt_cnt: int = 10,
|
747 |
+
wandb_project_name: str = "new_pti_project",
|
748 |
+
wandb_entity: str = "new_pti_entity",
|
749 |
+
proxy_token: str = "person",
|
750 |
+
enable_xformers_memory_efficient_attention: bool = False,
|
751 |
+
out_name: str = "final_lora",
|
752 |
+
):
|
753 |
+
torch.manual_seed(seed)
|
754 |
+
|
755 |
+
if log_wandb:
|
756 |
+
wandb.init(
|
757 |
+
project=wandb_project_name,
|
758 |
+
entity=wandb_entity,
|
759 |
+
name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}",
|
760 |
+
reinit=True,
|
761 |
+
config={
|
762 |
+
**(extra_args if extra_args is not None else {}),
|
763 |
+
},
|
764 |
+
)
|
765 |
+
|
766 |
+
if output_dir is not None:
|
767 |
+
os.makedirs(output_dir, exist_ok=True)
|
768 |
+
# print(placeholder_tokens, initializer_tokens)
|
769 |
+
if len(placeholder_tokens) == 0:
|
770 |
+
placeholder_tokens = []
|
771 |
+
print("PTI : Placeholder Tokens not given, using null token")
|
772 |
+
else:
|
773 |
+
placeholder_tokens = placeholder_tokens.split("|")
|
774 |
+
|
775 |
+
assert (
|
776 |
+
sorted(placeholder_tokens) == placeholder_tokens
|
777 |
+
), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}'"
|
778 |
+
|
779 |
+
if initializer_tokens is None:
|
780 |
+
print("PTI : Initializer Tokens not given, doing random inits")
|
781 |
+
initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
|
782 |
+
else:
|
783 |
+
initializer_tokens = initializer_tokens.split("|")
|
784 |
+
|
785 |
+
assert len(initializer_tokens) == len(
|
786 |
+
placeholder_tokens
|
787 |
+
), "Unequal Initializer token for Placeholder tokens."
|
788 |
+
|
789 |
+
if proxy_token is not None:
|
790 |
+
class_token = proxy_token
|
791 |
+
class_token = "".join(initializer_tokens)
|
792 |
+
|
793 |
+
if placeholder_token_at_data is not None:
|
794 |
+
tok, pat = placeholder_token_at_data.split("|")
|
795 |
+
token_map = {tok: pat}
|
796 |
+
|
797 |
+
else:
|
798 |
+
token_map = {"DUMMY": "".join(placeholder_tokens)}
|
799 |
+
|
800 |
+
print("PTI : Placeholder Tokens", placeholder_tokens)
|
801 |
+
print("PTI : Initializer Tokens", initializer_tokens)
|
802 |
+
|
803 |
+
# get the models
|
804 |
+
text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
|
805 |
+
pretrained_model_name_or_path,
|
806 |
+
pretrained_vae_name_or_path,
|
807 |
+
revision,
|
808 |
+
placeholder_tokens,
|
809 |
+
initializer_tokens,
|
810 |
+
device=device,
|
811 |
+
)
|
812 |
+
|
813 |
+
noise_scheduler = DDPMScheduler.from_config(
|
814 |
+
pretrained_model_name_or_path, subfolder="scheduler"
|
815 |
+
)
|
816 |
+
|
817 |
+
if gradient_checkpointing:
|
818 |
+
unet.enable_gradient_checkpointing()
|
819 |
+
|
820 |
+
if enable_xformers_memory_efficient_attention:
|
821 |
+
from diffusers.utils.import_utils import is_xformers_available
|
822 |
+
|
823 |
+
if is_xformers_available():
|
824 |
+
unet.enable_xformers_memory_efficient_attention()
|
825 |
+
else:
|
826 |
+
raise ValueError(
|
827 |
+
"xformers is not available. Make sure it is installed correctly"
|
828 |
+
)
|
829 |
+
|
830 |
+
if scale_lr:
|
831 |
+
unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
|
832 |
+
text_encoder_lr = (
|
833 |
+
learning_rate_text * gradient_accumulation_steps * train_batch_size
|
834 |
+
)
|
835 |
+
ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size
|
836 |
+
else:
|
837 |
+
unet_lr = learning_rate_unet
|
838 |
+
text_encoder_lr = learning_rate_text
|
839 |
+
ti_lr = learning_rate_ti
|
840 |
+
|
841 |
+
train_dataset = PivotalTuningDatasetCapation(
|
842 |
+
instance_data_root=instance_data_dir,
|
843 |
+
token_map=token_map,
|
844 |
+
use_template=use_template,
|
845 |
+
tokenizer=tokenizer,
|
846 |
+
size=resolution,
|
847 |
+
color_jitter=color_jitter,
|
848 |
+
use_face_segmentation_condition=use_face_segmentation_condition,
|
849 |
+
use_mask_captioned_data=use_mask_captioned_data,
|
850 |
+
train_inpainting=train_inpainting,
|
851 |
+
)
|
852 |
+
|
853 |
+
train_dataset.blur_amount = 200
|
854 |
+
|
855 |
+
if train_inpainting:
|
856 |
+
assert not cached_latents, "Cached latents not supported for inpainting"
|
857 |
+
|
858 |
+
train_dataloader = inpainting_dataloader(
|
859 |
+
train_dataset, train_batch_size, tokenizer, vae, text_encoder
|
860 |
+
)
|
861 |
+
else:
|
862 |
+
train_dataloader = text2img_dataloader(
|
863 |
+
train_dataset,
|
864 |
+
train_batch_size,
|
865 |
+
tokenizer,
|
866 |
+
vae,
|
867 |
+
text_encoder,
|
868 |
+
cached_latents=cached_latents,
|
869 |
+
)
|
870 |
+
|
871 |
+
index_no_updates = torch.arange(len(tokenizer)) != -1
|
872 |
+
|
873 |
+
for tok_id in placeholder_token_ids:
|
874 |
+
index_no_updates[tok_id] = False
|
875 |
+
|
876 |
+
unet.requires_grad_(False)
|
877 |
+
vae.requires_grad_(False)
|
878 |
+
|
879 |
+
params_to_freeze = itertools.chain(
|
880 |
+
text_encoder.text_model.encoder.parameters(),
|
881 |
+
text_encoder.text_model.final_layer_norm.parameters(),
|
882 |
+
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
883 |
+
)
|
884 |
+
for param in params_to_freeze:
|
885 |
+
param.requires_grad = False
|
886 |
+
|
887 |
+
if cached_latents:
|
888 |
+
vae = None
|
889 |
+
# STEP 1 : Perform Inversion
|
890 |
+
if perform_inversion:
|
891 |
+
ti_optimizer = optim.AdamW(
|
892 |
+
text_encoder.get_input_embeddings().parameters(),
|
893 |
+
lr=ti_lr,
|
894 |
+
betas=(0.9, 0.999),
|
895 |
+
eps=1e-08,
|
896 |
+
weight_decay=weight_decay_ti,
|
897 |
+
)
|
898 |
+
|
899 |
+
lr_scheduler = get_scheduler(
|
900 |
+
lr_scheduler,
|
901 |
+
optimizer=ti_optimizer,
|
902 |
+
num_warmup_steps=lr_warmup_steps,
|
903 |
+
num_training_steps=max_train_steps_ti,
|
904 |
+
)
|
905 |
+
|
906 |
+
train_inversion(
|
907 |
+
unet,
|
908 |
+
vae,
|
909 |
+
text_encoder,
|
910 |
+
train_dataloader,
|
911 |
+
max_train_steps_ti,
|
912 |
+
cached_latents=cached_latents,
|
913 |
+
accum_iter=gradient_accumulation_steps,
|
914 |
+
scheduler=noise_scheduler,
|
915 |
+
index_no_updates=index_no_updates,
|
916 |
+
optimizer=ti_optimizer,
|
917 |
+
lr_scheduler=lr_scheduler,
|
918 |
+
save_steps=save_steps,
|
919 |
+
placeholder_tokens=placeholder_tokens,
|
920 |
+
placeholder_token_ids=placeholder_token_ids,
|
921 |
+
save_path=output_dir,
|
922 |
+
test_image_path=instance_data_dir,
|
923 |
+
log_wandb=log_wandb,
|
924 |
+
wandb_log_prompt_cnt=wandb_log_prompt_cnt,
|
925 |
+
class_token=class_token,
|
926 |
+
train_inpainting=train_inpainting,
|
927 |
+
mixed_precision=False,
|
928 |
+
tokenizer=tokenizer,
|
929 |
+
clip_ti_decay=clip_ti_decay,
|
930 |
+
)
|
931 |
+
|
932 |
+
del ti_optimizer
|
933 |
+
|
934 |
+
# Next perform Tuning with LoRA:
|
935 |
+
if not use_extended_lora:
|
936 |
+
unet_lora_params, _ = inject_trainable_lora(
|
937 |
+
unet,
|
938 |
+
r=lora_rank,
|
939 |
+
target_replace_module=lora_unet_target_modules,
|
940 |
+
dropout_p=lora_dropout_p,
|
941 |
+
scale=lora_scale,
|
942 |
+
)
|
943 |
+
else:
|
944 |
+
print("PTI : USING EXTENDED UNET!!!")
|
945 |
+
lora_unet_target_modules = (
|
946 |
+
lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
|
947 |
+
)
|
948 |
+
print("PTI : Will replace modules: ", lora_unet_target_modules)
|
949 |
+
|
950 |
+
unet_lora_params, _ = inject_trainable_lora_extended(
|
951 |
+
unet, r=lora_rank, target_replace_module=lora_unet_target_modules
|
952 |
+
)
|
953 |
+
print(f"PTI : has {len(unet_lora_params)} lora")
|
954 |
+
|
955 |
+
print("PTI : Before training:")
|
956 |
+
inspect_lora(unet)
|
957 |
+
|
958 |
+
params_to_optimize = [
|
959 |
+
{"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
|
960 |
+
]
|
961 |
+
|
962 |
+
text_encoder.requires_grad_(False)
|
963 |
+
|
964 |
+
if continue_inversion:
|
965 |
+
params_to_optimize += [
|
966 |
+
{
|
967 |
+
"params": text_encoder.get_input_embeddings().parameters(),
|
968 |
+
"lr": continue_inversion_lr
|
969 |
+
if continue_inversion_lr is not None
|
970 |
+
else ti_lr,
|
971 |
+
}
|
972 |
+
]
|
973 |
+
text_encoder.requires_grad_(True)
|
974 |
+
params_to_freeze = itertools.chain(
|
975 |
+
text_encoder.text_model.encoder.parameters(),
|
976 |
+
text_encoder.text_model.final_layer_norm.parameters(),
|
977 |
+
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
978 |
+
)
|
979 |
+
for param in params_to_freeze:
|
980 |
+
param.requires_grad = False
|
981 |
+
else:
|
982 |
+
text_encoder.requires_grad_(False)
|
983 |
+
if train_text_encoder:
|
984 |
+
text_encoder_lora_params, _ = inject_trainable_lora(
|
985 |
+
text_encoder,
|
986 |
+
target_replace_module=lora_clip_target_modules,
|
987 |
+
r=lora_rank,
|
988 |
+
)
|
989 |
+
params_to_optimize += [
|
990 |
+
{
|
991 |
+
"params": itertools.chain(*text_encoder_lora_params),
|
992 |
+
"lr": text_encoder_lr,
|
993 |
+
}
|
994 |
+
]
|
995 |
+
inspect_lora(text_encoder)
|
996 |
+
|
997 |
+
lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)
|
998 |
+
|
999 |
+
unet.train()
|
1000 |
+
if train_text_encoder:
|
1001 |
+
text_encoder.train()
|
1002 |
+
|
1003 |
+
train_dataset.blur_amount = 70
|
1004 |
+
|
1005 |
+
lr_scheduler_lora = get_scheduler(
|
1006 |
+
lr_scheduler_lora,
|
1007 |
+
optimizer=lora_optimizers,
|
1008 |
+
num_warmup_steps=lr_warmup_steps_lora,
|
1009 |
+
num_training_steps=max_train_steps_tuning,
|
1010 |
+
)
|
1011 |
+
|
1012 |
+
perform_tuning(
|
1013 |
+
unet,
|
1014 |
+
vae,
|
1015 |
+
text_encoder,
|
1016 |
+
train_dataloader,
|
1017 |
+
max_train_steps_tuning,
|
1018 |
+
cached_latents=cached_latents,
|
1019 |
+
scheduler=noise_scheduler,
|
1020 |
+
optimizer=lora_optimizers,
|
1021 |
+
save_steps=save_steps,
|
1022 |
+
placeholder_tokens=placeholder_tokens,
|
1023 |
+
placeholder_token_ids=placeholder_token_ids,
|
1024 |
+
save_path=output_dir,
|
1025 |
+
lr_scheduler_lora=lr_scheduler_lora,
|
1026 |
+
lora_unet_target_modules=lora_unet_target_modules,
|
1027 |
+
lora_clip_target_modules=lora_clip_target_modules,
|
1028 |
+
mask_temperature=mask_temperature,
|
1029 |
+
tokenizer=tokenizer,
|
1030 |
+
out_name=out_name,
|
1031 |
+
test_image_path=instance_data_dir,
|
1032 |
+
log_wandb=log_wandb,
|
1033 |
+
wandb_log_prompt_cnt=wandb_log_prompt_cnt,
|
1034 |
+
class_token=class_token,
|
1035 |
+
train_inpainting=train_inpainting,
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
|
1039 |
+
def main():
|
1040 |
+
fire.Fire(train)
|
lora_diffusion/cli_pt_to_safetensors.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import fire
|
4 |
+
import torch
|
5 |
+
from lora_diffusion import (
|
6 |
+
DEFAULT_TARGET_REPLACE,
|
7 |
+
TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
8 |
+
UNET_DEFAULT_TARGET_REPLACE,
|
9 |
+
convert_loras_to_safeloras_with_embeds,
|
10 |
+
safetensors_available,
|
11 |
+
)
|
12 |
+
|
13 |
+
_target_by_name = {
|
14 |
+
"unet": UNET_DEFAULT_TARGET_REPLACE,
|
15 |
+
"text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
16 |
+
}
|
17 |
+
|
18 |
+
|
19 |
+
def convert(*paths, outpath, overwrite=False, **settings):
|
20 |
+
"""
|
21 |
+
Converts one or more pytorch Lora and/or Textual Embedding pytorch files
|
22 |
+
into a safetensor file.
|
23 |
+
|
24 |
+
Pass all the input paths as arguments. Whether they are Textual Embedding
|
25 |
+
or Lora models will be auto-detected.
|
26 |
+
|
27 |
+
For Lora models, their name will be taken from the path, i.e.
|
28 |
+
"lora_weight.pt" => unet
|
29 |
+
"lora_weight.text_encoder.pt" => text_encoder
|
30 |
+
|
31 |
+
You can also set target_modules and/or rank by providing an argument prefixed
|
32 |
+
by the name.
|
33 |
+
|
34 |
+
So a complete example might be something like:
|
35 |
+
|
36 |
+
```
|
37 |
+
python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
|
38 |
+
```
|
39 |
+
"""
|
40 |
+
modelmap = {}
|
41 |
+
embeds = {}
|
42 |
+
|
43 |
+
if os.path.exists(outpath) and not overwrite:
|
44 |
+
raise ValueError(
|
45 |
+
f"Output path {outpath} already exists, and overwrite is not True"
|
46 |
+
)
|
47 |
+
|
48 |
+
for path in paths:
|
49 |
+
data = torch.load(path)
|
50 |
+
|
51 |
+
if isinstance(data, dict):
|
52 |
+
print(f"Loading textual inversion embeds {data.keys()} from {path}")
|
53 |
+
embeds.update(data)
|
54 |
+
|
55 |
+
else:
|
56 |
+
name_parts = os.path.split(path)[1].split(".")
|
57 |
+
name = name_parts[-2] if len(name_parts) > 2 else "unet"
|
58 |
+
|
59 |
+
model_settings = {
|
60 |
+
"target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
|
61 |
+
"rank": 4,
|
62 |
+
}
|
63 |
+
|
64 |
+
prefix = f"{name}."
|
65 |
+
|
66 |
+
arg_settings = { k[len(prefix) :]: v for k, v in settings.items() if k.startswith(prefix) }
|
67 |
+
model_settings = { **model_settings, **arg_settings }
|
68 |
+
|
69 |
+
print(f"Loading Lora for {name} from {path} with settings {model_settings}")
|
70 |
+
|
71 |
+
modelmap[name] = (
|
72 |
+
path,
|
73 |
+
model_settings["target_modules"],
|
74 |
+
model_settings["rank"],
|
75 |
+
)
|
76 |
+
|
77 |
+
convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)
|
78 |
+
|
79 |
+
|
80 |
+
def main():
|
81 |
+
fire.Fire(convert)
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == "__main__":
|
85 |
+
main()
|
lora_diffusion/cli_svd.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fire
|
2 |
+
from diffusers import StableDiffusionPipeline
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .lora import (
|
7 |
+
save_all,
|
8 |
+
_find_modules,
|
9 |
+
LoraInjectedConv2d,
|
10 |
+
LoraInjectedLinear,
|
11 |
+
inject_trainable_lora,
|
12 |
+
inject_trainable_lora_extended,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
def _iter_lora(model):
|
17 |
+
for module in model.modules():
|
18 |
+
if isinstance(module, LoraInjectedConv2d) or isinstance(
|
19 |
+
module, LoraInjectedLinear
|
20 |
+
):
|
21 |
+
yield module
|
22 |
+
|
23 |
+
|
24 |
+
def overwrite_base(base_model, tuned_model, rank, clamp_quantile):
|
25 |
+
device = base_model.device
|
26 |
+
dtype = base_model.dtype
|
27 |
+
|
28 |
+
for lor_base, lor_tune in zip(_iter_lora(base_model), _iter_lora(tuned_model)):
|
29 |
+
|
30 |
+
if isinstance(lor_base, LoraInjectedLinear):
|
31 |
+
residual = lor_tune.linear.weight.data - lor_base.linear.weight.data
|
32 |
+
# SVD on residual
|
33 |
+
print("Distill Linear shape ", residual.shape)
|
34 |
+
residual = residual.float()
|
35 |
+
U, S, Vh = torch.linalg.svd(residual)
|
36 |
+
U = U[:, :rank]
|
37 |
+
S = S[:rank]
|
38 |
+
U = U @ torch.diag(S)
|
39 |
+
|
40 |
+
Vh = Vh[:rank, :]
|
41 |
+
|
42 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
43 |
+
hi_val = torch.quantile(dist, clamp_quantile)
|
44 |
+
low_val = -hi_val
|
45 |
+
|
46 |
+
U = U.clamp(low_val, hi_val)
|
47 |
+
Vh = Vh.clamp(low_val, hi_val)
|
48 |
+
|
49 |
+
assert lor_base.lora_up.weight.shape == U.shape
|
50 |
+
assert lor_base.lora_down.weight.shape == Vh.shape
|
51 |
+
|
52 |
+
lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
|
53 |
+
lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)
|
54 |
+
|
55 |
+
if isinstance(lor_base, LoraInjectedConv2d):
|
56 |
+
residual = lor_tune.conv.weight.data - lor_base.conv.weight.data
|
57 |
+
print("Distill Conv shape ", residual.shape)
|
58 |
+
|
59 |
+
residual = residual.float()
|
60 |
+
residual = residual.flatten(start_dim=1)
|
61 |
+
|
62 |
+
# SVD on residual
|
63 |
+
U, S, Vh = torch.linalg.svd(residual)
|
64 |
+
U = U[:, :rank]
|
65 |
+
S = S[:rank]
|
66 |
+
U = U @ torch.diag(S)
|
67 |
+
|
68 |
+
Vh = Vh[:rank, :]
|
69 |
+
|
70 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
71 |
+
hi_val = torch.quantile(dist, clamp_quantile)
|
72 |
+
low_val = -hi_val
|
73 |
+
|
74 |
+
U = U.clamp(low_val, hi_val)
|
75 |
+
Vh = Vh.clamp(low_val, hi_val)
|
76 |
+
|
77 |
+
# U is (out_channels, rank) with 1x1 conv. So,
|
78 |
+
U = U.reshape(U.shape[0], U.shape[1], 1, 1)
|
79 |
+
# V is (rank, in_channels * kernel_size1 * kernel_size2)
|
80 |
+
# now reshape:
|
81 |
+
Vh = Vh.reshape(
|
82 |
+
Vh.shape[0],
|
83 |
+
lor_base.conv.in_channels,
|
84 |
+
lor_base.conv.kernel_size[0],
|
85 |
+
lor_base.conv.kernel_size[1],
|
86 |
+
)
|
87 |
+
|
88 |
+
assert lor_base.lora_up.weight.shape == U.shape
|
89 |
+
assert lor_base.lora_down.weight.shape == Vh.shape
|
90 |
+
|
91 |
+
lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
|
92 |
+
lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)
|
93 |
+
|
94 |
+
|
95 |
+
def svd_distill(
|
96 |
+
target_model: str,
|
97 |
+
base_model: str,
|
98 |
+
rank: int = 4,
|
99 |
+
clamp_quantile: float = 0.99,
|
100 |
+
device: str = "cuda:0",
|
101 |
+
save_path: str = "svd_distill.safetensors",
|
102 |
+
):
|
103 |
+
pipe_base = StableDiffusionPipeline.from_pretrained(
|
104 |
+
base_model, torch_dtype=torch.float16
|
105 |
+
).to(device)
|
106 |
+
|
107 |
+
pipe_tuned = StableDiffusionPipeline.from_pretrained(
|
108 |
+
target_model, torch_dtype=torch.float16
|
109 |
+
).to(device)
|
110 |
+
|
111 |
+
# Inject unet
|
112 |
+
_ = inject_trainable_lora_extended(pipe_base.unet, r=rank)
|
113 |
+
_ = inject_trainable_lora_extended(pipe_tuned.unet, r=rank)
|
114 |
+
|
115 |
+
overwrite_base(
|
116 |
+
pipe_base.unet, pipe_tuned.unet, rank=rank, clamp_quantile=clamp_quantile
|
117 |
+
)
|
118 |
+
|
119 |
+
# Inject text encoder
|
120 |
+
_ = inject_trainable_lora(
|
121 |
+
pipe_base.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
|
122 |
+
)
|
123 |
+
_ = inject_trainable_lora(
|
124 |
+
pipe_tuned.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
|
125 |
+
)
|
126 |
+
|
127 |
+
overwrite_base(
|
128 |
+
pipe_base.text_encoder,
|
129 |
+
pipe_tuned.text_encoder,
|
130 |
+
rank=rank,
|
131 |
+
clamp_quantile=clamp_quantile,
|
132 |
+
)
|
133 |
+
|
134 |
+
save_all(
|
135 |
+
unet=pipe_base.unet,
|
136 |
+
text_encoder=pipe_base.text_encoder,
|
137 |
+
placeholder_token_ids=None,
|
138 |
+
placeholder_tokens=None,
|
139 |
+
save_path=save_path,
|
140 |
+
save_lora=True,
|
141 |
+
save_ti=False,
|
142 |
+
)
|
143 |
+
|
144 |
+
|
145 |
+
def main():
|
146 |
+
fire.Fire(svd_distill)
|
lora_diffusion/dataset.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
from torch import zeros_like
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from torchvision import transforms
|
9 |
+
import glob
|
10 |
+
from .preprocess_files import face_mask_google_mediapipe
|
11 |
+
|
12 |
+
OBJECT_TEMPLATE = [
|
13 |
+
"a photo of a {}",
|
14 |
+
"a rendering of a {}",
|
15 |
+
"a cropped photo of the {}",
|
16 |
+
"the photo of a {}",
|
17 |
+
"a photo of a clean {}",
|
18 |
+
"a photo of a dirty {}",
|
19 |
+
"a dark photo of the {}",
|
20 |
+
"a photo of my {}",
|
21 |
+
"a photo of the cool {}",
|
22 |
+
"a close-up photo of a {}",
|
23 |
+
"a bright photo of the {}",
|
24 |
+
"a cropped photo of a {}",
|
25 |
+
"a photo of the {}",
|
26 |
+
"a good photo of the {}",
|
27 |
+
"a photo of one {}",
|
28 |
+
"a close-up photo of the {}",
|
29 |
+
"a rendition of the {}",
|
30 |
+
"a photo of the clean {}",
|
31 |
+
"a rendition of a {}",
|
32 |
+
"a photo of a nice {}",
|
33 |
+
"a good photo of a {}",
|
34 |
+
"a photo of the nice {}",
|
35 |
+
"a photo of the small {}",
|
36 |
+
"a photo of the weird {}",
|
37 |
+
"a photo of the large {}",
|
38 |
+
"a photo of a cool {}",
|
39 |
+
"a photo of a small {}",
|
40 |
+
]
|
41 |
+
|
42 |
+
STYLE_TEMPLATE = [
|
43 |
+
"a painting in the style of {}",
|
44 |
+
"a rendering in the style of {}",
|
45 |
+
"a cropped painting in the style of {}",
|
46 |
+
"the painting in the style of {}",
|
47 |
+
"a clean painting in the style of {}",
|
48 |
+
"a dirty painting in the style of {}",
|
49 |
+
"a dark painting in the style of {}",
|
50 |
+
"a picture in the style of {}",
|
51 |
+
"a cool painting in the style of {}",
|
52 |
+
"a close-up painting in the style of {}",
|
53 |
+
"a bright painting in the style of {}",
|
54 |
+
"a cropped painting in the style of {}",
|
55 |
+
"a good painting in the style of {}",
|
56 |
+
"a close-up painting in the style of {}",
|
57 |
+
"a rendition in the style of {}",
|
58 |
+
"a nice painting in the style of {}",
|
59 |
+
"a small painting in the style of {}",
|
60 |
+
"a weird painting in the style of {}",
|
61 |
+
"a large painting in the style of {}",
|
62 |
+
]
|
63 |
+
|
64 |
+
NULL_TEMPLATE = ["{}"]
|
65 |
+
|
66 |
+
TEMPLATE_MAP = {
|
67 |
+
"object": OBJECT_TEMPLATE,
|
68 |
+
"style": STYLE_TEMPLATE,
|
69 |
+
"null": NULL_TEMPLATE,
|
70 |
+
}
|
71 |
+
|
72 |
+
|
73 |
+
def _randomset(lis):
|
74 |
+
ret = []
|
75 |
+
for i in range(len(lis)):
|
76 |
+
if random.random() < 0.5:
|
77 |
+
ret.append(lis[i])
|
78 |
+
return ret
|
79 |
+
|
80 |
+
|
81 |
+
def _shuffle(lis):
|
82 |
+
|
83 |
+
return random.sample(lis, len(lis))
|
84 |
+
|
85 |
+
|
86 |
+
def _get_cutout_holes(
|
87 |
+
height,
|
88 |
+
width,
|
89 |
+
min_holes=8,
|
90 |
+
max_holes=32,
|
91 |
+
min_height=16,
|
92 |
+
max_height=128,
|
93 |
+
min_width=16,
|
94 |
+
max_width=128,
|
95 |
+
):
|
96 |
+
holes = []
|
97 |
+
for _n in range(random.randint(min_holes, max_holes)):
|
98 |
+
hole_height = random.randint(min_height, max_height)
|
99 |
+
hole_width = random.randint(min_width, max_width)
|
100 |
+
y1 = random.randint(0, height - hole_height)
|
101 |
+
x1 = random.randint(0, width - hole_width)
|
102 |
+
y2 = y1 + hole_height
|
103 |
+
x2 = x1 + hole_width
|
104 |
+
holes.append((x1, y1, x2, y2))
|
105 |
+
return holes
|
106 |
+
|
107 |
+
|
108 |
+
def _generate_random_mask(image):
|
109 |
+
mask = zeros_like(image[:1])
|
110 |
+
holes = _get_cutout_holes(mask.shape[1], mask.shape[2])
|
111 |
+
for (x1, y1, x2, y2) in holes:
|
112 |
+
mask[:, y1:y2, x1:x2] = 1.0
|
113 |
+
if random.uniform(0, 1) < 0.25:
|
114 |
+
mask.fill_(1.0)
|
115 |
+
masked_image = image * (mask < 0.5)
|
116 |
+
return mask, masked_image
|
117 |
+
|
118 |
+
|
119 |
+
class PivotalTuningDatasetCapation(Dataset):
|
120 |
+
"""
|
121 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
122 |
+
It pre-processes the images and the tokenizes prompts.
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(
|
126 |
+
self,
|
127 |
+
instance_data_root,
|
128 |
+
tokenizer,
|
129 |
+
token_map: Optional[dict] = None,
|
130 |
+
use_template: Optional[str] = None,
|
131 |
+
size=512,
|
132 |
+
h_flip=True,
|
133 |
+
color_jitter=False,
|
134 |
+
resize=True,
|
135 |
+
use_mask_captioned_data=False,
|
136 |
+
use_face_segmentation_condition=False,
|
137 |
+
train_inpainting=False,
|
138 |
+
blur_amount: int = 70,
|
139 |
+
):
|
140 |
+
self.size = size
|
141 |
+
self.tokenizer = tokenizer
|
142 |
+
self.resize = resize
|
143 |
+
self.train_inpainting = train_inpainting
|
144 |
+
|
145 |
+
instance_data_root = Path(instance_data_root)
|
146 |
+
if not instance_data_root.exists():
|
147 |
+
raise ValueError("Instance images root doesn't exists.")
|
148 |
+
|
149 |
+
self.instance_images_path = []
|
150 |
+
self.mask_path = []
|
151 |
+
|
152 |
+
assert not (
|
153 |
+
use_mask_captioned_data and use_template
|
154 |
+
), "Can't use both mask caption data and template."
|
155 |
+
|
156 |
+
# Prepare the instance images
|
157 |
+
if use_mask_captioned_data:
|
158 |
+
src_imgs = glob.glob(str(instance_data_root) + "/*src.jpg")
|
159 |
+
for f in src_imgs:
|
160 |
+
idx = int(str(Path(f).stem).split(".")[0])
|
161 |
+
mask_path = f"{instance_data_root}/{idx}.mask.png"
|
162 |
+
|
163 |
+
if Path(mask_path).exists():
|
164 |
+
self.instance_images_path.append(f)
|
165 |
+
self.mask_path.append(mask_path)
|
166 |
+
else:
|
167 |
+
print(f"Mask not found for {f}")
|
168 |
+
|
169 |
+
self.captions = open(f"{instance_data_root}/caption.txt").readlines()
|
170 |
+
|
171 |
+
else:
|
172 |
+
possibily_src_images = (
|
173 |
+
glob.glob(str(instance_data_root) + "/*.jpg")
|
174 |
+
+ glob.glob(str(instance_data_root) + "/*.png")
|
175 |
+
+ glob.glob(str(instance_data_root) + "/*.jpeg")
|
176 |
+
)
|
177 |
+
possibily_src_images = (
|
178 |
+
set(possibily_src_images)
|
179 |
+
- set(glob.glob(str(instance_data_root) + "/*mask.png"))
|
180 |
+
- set([str(instance_data_root) + "/caption.txt"])
|
181 |
+
)
|
182 |
+
|
183 |
+
self.instance_images_path = list(set(possibily_src_images))
|
184 |
+
self.captions = [
|
185 |
+
x.split("/")[-1].split(".")[0] for x in self.instance_images_path
|
186 |
+
]
|
187 |
+
|
188 |
+
assert (
|
189 |
+
len(self.instance_images_path) > 0
|
190 |
+
), "No images found in the instance data root."
|
191 |
+
|
192 |
+
self.instance_images_path = sorted(self.instance_images_path)
|
193 |
+
|
194 |
+
self.use_mask = use_face_segmentation_condition or use_mask_captioned_data
|
195 |
+
self.use_mask_captioned_data = use_mask_captioned_data
|
196 |
+
|
197 |
+
if use_face_segmentation_condition:
|
198 |
+
|
199 |
+
for idx in range(len(self.instance_images_path)):
|
200 |
+
targ = f"{instance_data_root}/{idx}.mask.png"
|
201 |
+
# see if the mask exists
|
202 |
+
if not Path(targ).exists():
|
203 |
+
print(f"Mask not found for {targ}")
|
204 |
+
|
205 |
+
print(
|
206 |
+
"Warning : this will pre-process all the images in the instance data root."
|
207 |
+
)
|
208 |
+
|
209 |
+
if len(self.mask_path) > 0:
|
210 |
+
print(
|
211 |
+
"Warning : masks already exists, but will be overwritten."
|
212 |
+
)
|
213 |
+
|
214 |
+
masks = face_mask_google_mediapipe(
|
215 |
+
[
|
216 |
+
Image.open(f).convert("RGB")
|
217 |
+
for f in self.instance_images_path
|
218 |
+
]
|
219 |
+
)
|
220 |
+
for idx, mask in enumerate(masks):
|
221 |
+
mask.save(f"{instance_data_root}/{idx}.mask.png")
|
222 |
+
|
223 |
+
break
|
224 |
+
|
225 |
+
for idx in range(len(self.instance_images_path)):
|
226 |
+
self.mask_path.append(f"{instance_data_root}/{idx}.mask.png")
|
227 |
+
|
228 |
+
self.num_instance_images = len(self.instance_images_path)
|
229 |
+
self.token_map = token_map
|
230 |
+
|
231 |
+
self.use_template = use_template
|
232 |
+
if use_template is not None:
|
233 |
+
self.templates = TEMPLATE_MAP[use_template]
|
234 |
+
|
235 |
+
self._length = self.num_instance_images
|
236 |
+
|
237 |
+
self.h_flip = h_flip
|
238 |
+
self.image_transforms = transforms.Compose(
|
239 |
+
[
|
240 |
+
transforms.Resize(
|
241 |
+
size, interpolation=transforms.InterpolationMode.BILINEAR
|
242 |
+
)
|
243 |
+
if resize
|
244 |
+
else transforms.Lambda(lambda x: x),
|
245 |
+
transforms.ColorJitter(0.1, 0.1)
|
246 |
+
if color_jitter
|
247 |
+
else transforms.Lambda(lambda x: x),
|
248 |
+
transforms.CenterCrop(size),
|
249 |
+
transforms.ToTensor(),
|
250 |
+
transforms.Normalize([0.5], [0.5]),
|
251 |
+
]
|
252 |
+
)
|
253 |
+
|
254 |
+
self.blur_amount = blur_amount
|
255 |
+
|
256 |
+
def __len__(self):
|
257 |
+
return self._length
|
258 |
+
|
259 |
+
def __getitem__(self, index):
|
260 |
+
example = {}
|
261 |
+
instance_image = Image.open(
|
262 |
+
self.instance_images_path[index % self.num_instance_images]
|
263 |
+
)
|
264 |
+
if not instance_image.mode == "RGB":
|
265 |
+
instance_image = instance_image.convert("RGB")
|
266 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
267 |
+
|
268 |
+
if self.train_inpainting:
|
269 |
+
(
|
270 |
+
example["instance_masks"],
|
271 |
+
example["instance_masked_images"],
|
272 |
+
) = _generate_random_mask(example["instance_images"])
|
273 |
+
|
274 |
+
if self.use_template:
|
275 |
+
assert self.token_map is not None
|
276 |
+
input_tok = list(self.token_map.values())[0]
|
277 |
+
|
278 |
+
text = random.choice(self.templates).format(input_tok)
|
279 |
+
else:
|
280 |
+
text = self.captions[index % self.num_instance_images].strip()
|
281 |
+
|
282 |
+
if self.token_map is not None:
|
283 |
+
for token, value in self.token_map.items():
|
284 |
+
text = text.replace(token, value)
|
285 |
+
|
286 |
+
print(text)
|
287 |
+
|
288 |
+
if self.use_mask:
|
289 |
+
example["mask"] = (
|
290 |
+
self.image_transforms(
|
291 |
+
Image.open(self.mask_path[index % self.num_instance_images])
|
292 |
+
)
|
293 |
+
* 0.5
|
294 |
+
+ 1.0
|
295 |
+
)
|
296 |
+
|
297 |
+
if self.h_flip and random.random() > 0.5:
|
298 |
+
hflip = transforms.RandomHorizontalFlip(p=1)
|
299 |
+
|
300 |
+
example["instance_images"] = hflip(example["instance_images"])
|
301 |
+
if self.use_mask:
|
302 |
+
example["mask"] = hflip(example["mask"])
|
303 |
+
|
304 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
305 |
+
text,
|
306 |
+
padding="do_not_pad",
|
307 |
+
truncation=True,
|
308 |
+
max_length=self.tokenizer.model_max_length,
|
309 |
+
).input_ids
|
310 |
+
|
311 |
+
return example
|
lora_diffusion/lora.py
ADDED
@@ -0,0 +1,1110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import math
|
3 |
+
from itertools import groupby
|
4 |
+
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import PIL
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
try:
|
13 |
+
from safetensors.torch import safe_open
|
14 |
+
from safetensors.torch import save_file as safe_save
|
15 |
+
|
16 |
+
safetensors_available = True
|
17 |
+
except ImportError:
|
18 |
+
from .safe_open import safe_open
|
19 |
+
|
20 |
+
def safe_save(
|
21 |
+
tensors: Dict[str, torch.Tensor],
|
22 |
+
filename: str,
|
23 |
+
metadata: Optional[Dict[str, str]] = None,
|
24 |
+
) -> None:
|
25 |
+
raise EnvironmentError(
|
26 |
+
"Saving safetensors requires the safetensors library. Please install with pip or similar."
|
27 |
+
)
|
28 |
+
|
29 |
+
safetensors_available = False
|
30 |
+
|
31 |
+
|
32 |
+
class LoraInjectedLinear(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
if r > min(in_features, out_features):
|
39 |
+
raise ValueError(
|
40 |
+
f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
|
41 |
+
)
|
42 |
+
self.r = r
|
43 |
+
self.linear = nn.Linear(in_features, out_features, bias)
|
44 |
+
self.lora_down = nn.Linear(in_features, r, bias=False)
|
45 |
+
self.dropout = nn.Dropout(dropout_p)
|
46 |
+
self.lora_up = nn.Linear(r, out_features, bias=False)
|
47 |
+
self.scale = scale
|
48 |
+
self.selector = nn.Identity()
|
49 |
+
|
50 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
51 |
+
nn.init.zeros_(self.lora_up.weight)
|
52 |
+
|
53 |
+
def forward(self, input):
|
54 |
+
return (
|
55 |
+
self.linear(input)
|
56 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
57 |
+
* self.scale
|
58 |
+
)
|
59 |
+
|
60 |
+
def realize_as_lora(self):
|
61 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
62 |
+
|
63 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
64 |
+
# diag is a 1D tensor of size (r,)
|
65 |
+
assert diag.shape == (self.r,)
|
66 |
+
self.selector = nn.Linear(self.r, self.r, bias=False)
|
67 |
+
self.selector.weight.data = torch.diag(diag)
|
68 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
69 |
+
self.lora_up.weight.device
|
70 |
+
).to(self.lora_up.weight.dtype)
|
71 |
+
|
72 |
+
|
73 |
+
class LoraInjectedConv2d(nn.Module):
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
in_channels: int,
|
77 |
+
out_channels: int,
|
78 |
+
kernel_size,
|
79 |
+
stride=1,
|
80 |
+
padding=0,
|
81 |
+
dilation=1,
|
82 |
+
groups: int = 1,
|
83 |
+
bias: bool = True,
|
84 |
+
r: int = 4,
|
85 |
+
dropout_p: float = 0.1,
|
86 |
+
scale: float = 1.0,
|
87 |
+
):
|
88 |
+
super().__init__()
|
89 |
+
if r > min(in_channels, out_channels):
|
90 |
+
raise ValueError(
|
91 |
+
f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
|
92 |
+
)
|
93 |
+
self.r = r
|
94 |
+
self.conv = nn.Conv2d(
|
95 |
+
in_channels=in_channels,
|
96 |
+
out_channels=out_channels,
|
97 |
+
kernel_size=kernel_size,
|
98 |
+
stride=stride,
|
99 |
+
padding=padding,
|
100 |
+
dilation=dilation,
|
101 |
+
groups=groups,
|
102 |
+
bias=bias,
|
103 |
+
)
|
104 |
+
|
105 |
+
self.lora_down = nn.Conv2d(
|
106 |
+
in_channels=in_channels,
|
107 |
+
out_channels=r,
|
108 |
+
kernel_size=kernel_size,
|
109 |
+
stride=stride,
|
110 |
+
padding=padding,
|
111 |
+
dilation=dilation,
|
112 |
+
groups=groups,
|
113 |
+
bias=False,
|
114 |
+
)
|
115 |
+
self.dropout = nn.Dropout(dropout_p)
|
116 |
+
self.lora_up = nn.Conv2d(
|
117 |
+
in_channels=r,
|
118 |
+
out_channels=out_channels,
|
119 |
+
kernel_size=1,
|
120 |
+
stride=1,
|
121 |
+
padding=0,
|
122 |
+
bias=False,
|
123 |
+
)
|
124 |
+
self.selector = nn.Identity()
|
125 |
+
self.scale = scale
|
126 |
+
|
127 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
128 |
+
nn.init.zeros_(self.lora_up.weight)
|
129 |
+
|
130 |
+
def forward(self, input):
|
131 |
+
return (
|
132 |
+
self.conv(input)
|
133 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
134 |
+
* self.scale
|
135 |
+
)
|
136 |
+
|
137 |
+
def realize_as_lora(self):
|
138 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
139 |
+
|
140 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
141 |
+
# diag is a 1D tensor of size (r,)
|
142 |
+
assert diag.shape == (self.r,)
|
143 |
+
self.selector = nn.Conv2d(
|
144 |
+
in_channels=self.r,
|
145 |
+
out_channels=self.r,
|
146 |
+
kernel_size=1,
|
147 |
+
stride=1,
|
148 |
+
padding=0,
|
149 |
+
bias=False,
|
150 |
+
)
|
151 |
+
self.selector.weight.data = torch.diag(diag)
|
152 |
+
|
153 |
+
# same device + dtype as lora_up
|
154 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
155 |
+
self.lora_up.weight.device
|
156 |
+
).to(self.lora_up.weight.dtype)
|
157 |
+
|
158 |
+
|
159 |
+
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
|
160 |
+
|
161 |
+
UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
|
162 |
+
|
163 |
+
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
|
164 |
+
|
165 |
+
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
|
166 |
+
|
167 |
+
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
|
168 |
+
|
169 |
+
EMBED_FLAG = "<embed>"
|
170 |
+
|
171 |
+
|
172 |
+
def _find_children(
|
173 |
+
model,
|
174 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
175 |
+
):
|
176 |
+
"""
|
177 |
+
Find all modules of a certain class (or union of classes).
|
178 |
+
|
179 |
+
Returns all matching modules, along with the parent of those moduless and the
|
180 |
+
names they are referenced by.
|
181 |
+
"""
|
182 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
183 |
+
for parent in model.modules():
|
184 |
+
for name, module in parent.named_children():
|
185 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
186 |
+
yield parent, name, module
|
187 |
+
|
188 |
+
|
189 |
+
def _find_modules_v2(
|
190 |
+
model,
|
191 |
+
ancestor_class: Optional[Set[str]] = None,
|
192 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
193 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [
|
194 |
+
LoraInjectedLinear,
|
195 |
+
LoraInjectedConv2d,
|
196 |
+
],
|
197 |
+
):
|
198 |
+
"""
|
199 |
+
Find all modules of a certain class (or union of classes) that are direct or
|
200 |
+
indirect descendants of other modules of a certain class (or union of classes).
|
201 |
+
|
202 |
+
Returns all matching modules, along with the parent of those moduless and the
|
203 |
+
names they are referenced by.
|
204 |
+
"""
|
205 |
+
|
206 |
+
# Get the targets we should replace all linears under
|
207 |
+
if ancestor_class is not None:
|
208 |
+
ancestors = (
|
209 |
+
module
|
210 |
+
for module in model.modules()
|
211 |
+
if module.__class__.__name__ in ancestor_class
|
212 |
+
)
|
213 |
+
else:
|
214 |
+
# this, incase you want to naively iterate over all modules.
|
215 |
+
ancestors = [module for module in model.modules()]
|
216 |
+
|
217 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
218 |
+
for ancestor in ancestors:
|
219 |
+
for fullname, module in ancestor.named_modules():
|
220 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
221 |
+
# Find the direct parent if this is a descendant, not a child, of target
|
222 |
+
*path, name = fullname.split(".")
|
223 |
+
parent = ancestor
|
224 |
+
while path:
|
225 |
+
parent = parent.get_submodule(path.pop(0))
|
226 |
+
# Skip this linear if it's a child of a LoraInjectedLinear
|
227 |
+
if exclude_children_of and any(
|
228 |
+
[isinstance(parent, _class) for _class in exclude_children_of]
|
229 |
+
):
|
230 |
+
continue
|
231 |
+
# Otherwise, yield it
|
232 |
+
yield parent, name, module
|
233 |
+
|
234 |
+
|
235 |
+
def _find_modules_old(
|
236 |
+
model,
|
237 |
+
ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
|
238 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
239 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
|
240 |
+
):
|
241 |
+
ret = []
|
242 |
+
for _module in model.modules():
|
243 |
+
if _module.__class__.__name__ in ancestor_class:
|
244 |
+
|
245 |
+
for name, _child_module in _module.named_modules():
|
246 |
+
if _child_module.__class__ in search_class:
|
247 |
+
ret.append((_module, name, _child_module))
|
248 |
+
print(ret)
|
249 |
+
return ret
|
250 |
+
|
251 |
+
|
252 |
+
_find_modules = _find_modules_v2
|
253 |
+
|
254 |
+
|
255 |
+
def inject_trainable_lora(
|
256 |
+
model: nn.Module,
|
257 |
+
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
|
258 |
+
r: int = 4,
|
259 |
+
loras=None, # path to lora .pt
|
260 |
+
verbose: bool = False,
|
261 |
+
dropout_p: float = 0.0,
|
262 |
+
scale: float = 1.0,
|
263 |
+
):
|
264 |
+
"""
|
265 |
+
inject lora into model, and returns lora parameter groups.
|
266 |
+
"""
|
267 |
+
|
268 |
+
require_grad_params = []
|
269 |
+
names = []
|
270 |
+
|
271 |
+
if loras != None:
|
272 |
+
loras = torch.load(loras)
|
273 |
+
|
274 |
+
for _module, name, _child_module in _find_modules(
|
275 |
+
model, target_replace_module, search_class=[nn.Linear]
|
276 |
+
):
|
277 |
+
weight = _child_module.weight
|
278 |
+
bias = _child_module.bias
|
279 |
+
if verbose:
|
280 |
+
print("LoRA Injection : injecting lora into ", name)
|
281 |
+
print("LoRA Injection : weight shape", weight.shape)
|
282 |
+
_tmp = LoraInjectedLinear(
|
283 |
+
_child_module.in_features,
|
284 |
+
_child_module.out_features,
|
285 |
+
_child_module.bias is not None,
|
286 |
+
r=r,
|
287 |
+
dropout_p=dropout_p,
|
288 |
+
scale=scale,
|
289 |
+
)
|
290 |
+
_tmp.linear.weight = weight
|
291 |
+
if bias is not None:
|
292 |
+
_tmp.linear.bias = bias
|
293 |
+
|
294 |
+
# switch the module
|
295 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
296 |
+
_module._modules[name] = _tmp
|
297 |
+
|
298 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
299 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
300 |
+
|
301 |
+
if loras != None:
|
302 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
303 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
304 |
+
|
305 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
306 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
307 |
+
names.append(name)
|
308 |
+
|
309 |
+
return require_grad_params, names
|
310 |
+
|
311 |
+
|
312 |
+
def inject_trainable_lora_extended(
|
313 |
+
model: nn.Module,
|
314 |
+
target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
|
315 |
+
r: int = 4,
|
316 |
+
loras=None, # path to lora .pt
|
317 |
+
):
|
318 |
+
"""
|
319 |
+
inject lora into model, and returns lora parameter groups.
|
320 |
+
"""
|
321 |
+
|
322 |
+
require_grad_params = []
|
323 |
+
names = []
|
324 |
+
|
325 |
+
if loras != None:
|
326 |
+
loras = torch.load(loras)
|
327 |
+
|
328 |
+
for _module, name, _child_module in _find_modules(
|
329 |
+
model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
|
330 |
+
):
|
331 |
+
if _child_module.__class__ == nn.Linear:
|
332 |
+
weight = _child_module.weight
|
333 |
+
bias = _child_module.bias
|
334 |
+
_tmp = LoraInjectedLinear(
|
335 |
+
_child_module.in_features,
|
336 |
+
_child_module.out_features,
|
337 |
+
_child_module.bias is not None,
|
338 |
+
r=r,
|
339 |
+
)
|
340 |
+
_tmp.linear.weight = weight
|
341 |
+
if bias is not None:
|
342 |
+
_tmp.linear.bias = bias
|
343 |
+
elif _child_module.__class__ == nn.Conv2d:
|
344 |
+
weight = _child_module.weight
|
345 |
+
bias = _child_module.bias
|
346 |
+
_tmp = LoraInjectedConv2d(
|
347 |
+
_child_module.in_channels,
|
348 |
+
_child_module.out_channels,
|
349 |
+
_child_module.kernel_size,
|
350 |
+
_child_module.stride,
|
351 |
+
_child_module.padding,
|
352 |
+
_child_module.dilation,
|
353 |
+
_child_module.groups,
|
354 |
+
_child_module.bias is not None,
|
355 |
+
r=r,
|
356 |
+
)
|
357 |
+
|
358 |
+
_tmp.conv.weight = weight
|
359 |
+
if bias is not None:
|
360 |
+
_tmp.conv.bias = bias
|
361 |
+
|
362 |
+
# switch the module
|
363 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
364 |
+
if bias is not None:
|
365 |
+
_tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
|
366 |
+
|
367 |
+
_module._modules[name] = _tmp
|
368 |
+
|
369 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
370 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
371 |
+
|
372 |
+
if loras != None:
|
373 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
374 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
375 |
+
|
376 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
377 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
378 |
+
names.append(name)
|
379 |
+
|
380 |
+
return require_grad_params, names
|
381 |
+
|
382 |
+
|
383 |
+
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
|
384 |
+
|
385 |
+
loras = []
|
386 |
+
|
387 |
+
for _m, _n, _child_module in _find_modules(
|
388 |
+
model,
|
389 |
+
target_replace_module,
|
390 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d],
|
391 |
+
):
|
392 |
+
loras.append((_child_module.lora_up, _child_module.lora_down))
|
393 |
+
|
394 |
+
if len(loras) == 0:
|
395 |
+
raise ValueError("No lora injected.")
|
396 |
+
|
397 |
+
return loras
|
398 |
+
|
399 |
+
|
400 |
+
def extract_lora_as_tensor(
|
401 |
+
model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
|
402 |
+
):
|
403 |
+
|
404 |
+
loras = []
|
405 |
+
|
406 |
+
for _m, _n, _child_module in _find_modules(
|
407 |
+
model,
|
408 |
+
target_replace_module,
|
409 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d],
|
410 |
+
):
|
411 |
+
up, down = _child_module.realize_as_lora()
|
412 |
+
if as_fp16:
|
413 |
+
up = up.to(torch.float16)
|
414 |
+
down = down.to(torch.float16)
|
415 |
+
|
416 |
+
loras.append((up, down))
|
417 |
+
|
418 |
+
if len(loras) == 0:
|
419 |
+
raise ValueError("No lora injected.")
|
420 |
+
|
421 |
+
return loras
|
422 |
+
|
423 |
+
|
424 |
+
def save_lora_weight(
|
425 |
+
model,
|
426 |
+
path="./lora.pt",
|
427 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
428 |
+
):
|
429 |
+
weights = []
|
430 |
+
for _up, _down in extract_lora_ups_down(
|
431 |
+
model, target_replace_module=target_replace_module
|
432 |
+
):
|
433 |
+
weights.append(_up.weight.to("cpu").to(torch.float16))
|
434 |
+
weights.append(_down.weight.to("cpu").to(torch.float16))
|
435 |
+
|
436 |
+
torch.save(weights, path)
|
437 |
+
|
438 |
+
|
439 |
+
def save_lora_as_json(model, path="./lora.json"):
|
440 |
+
weights = []
|
441 |
+
for _up, _down in extract_lora_ups_down(model):
|
442 |
+
weights.append(_up.weight.detach().cpu().numpy().tolist())
|
443 |
+
weights.append(_down.weight.detach().cpu().numpy().tolist())
|
444 |
+
|
445 |
+
import json
|
446 |
+
|
447 |
+
with open(path, "w") as f:
|
448 |
+
json.dump(weights, f)
|
449 |
+
|
450 |
+
|
451 |
+
def save_safeloras_with_embeds(
|
452 |
+
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
|
453 |
+
embeds: Dict[str, torch.Tensor] = {},
|
454 |
+
outpath="./lora.safetensors",
|
455 |
+
):
|
456 |
+
"""
|
457 |
+
Saves the Lora from multiple modules in a single safetensor file.
|
458 |
+
|
459 |
+
modelmap is a dictionary of {
|
460 |
+
"module name": (module, target_replace_module)
|
461 |
+
}
|
462 |
+
"""
|
463 |
+
weights = {}
|
464 |
+
metadata = {}
|
465 |
+
|
466 |
+
for name, (model, target_replace_module) in modelmap.items():
|
467 |
+
metadata[name] = json.dumps(list(target_replace_module))
|
468 |
+
|
469 |
+
for i, (_up, _down) in enumerate(
|
470 |
+
extract_lora_as_tensor(model, target_replace_module)
|
471 |
+
):
|
472 |
+
rank = _down.shape[0]
|
473 |
+
|
474 |
+
metadata[f"{name}:{i}:rank"] = str(rank)
|
475 |
+
weights[f"{name}:{i}:up"] = _up
|
476 |
+
weights[f"{name}:{i}:down"] = _down
|
477 |
+
|
478 |
+
for token, tensor in embeds.items():
|
479 |
+
metadata[token] = EMBED_FLAG
|
480 |
+
weights[token] = tensor
|
481 |
+
|
482 |
+
print(f"Saving weights to {outpath}")
|
483 |
+
safe_save(weights, outpath, metadata)
|
484 |
+
|
485 |
+
|
486 |
+
def save_safeloras(
|
487 |
+
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
|
488 |
+
outpath="./lora.safetensors",
|
489 |
+
):
|
490 |
+
return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
491 |
+
|
492 |
+
|
493 |
+
def convert_loras_to_safeloras_with_embeds(
|
494 |
+
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
|
495 |
+
embeds: Dict[str, torch.Tensor] = {},
|
496 |
+
outpath="./lora.safetensors",
|
497 |
+
):
|
498 |
+
"""
|
499 |
+
Converts the Lora from multiple pytorch .pt files into a single safetensor file.
|
500 |
+
|
501 |
+
modelmap is a dictionary of {
|
502 |
+
"module name": (pytorch_model_path, target_replace_module, rank)
|
503 |
+
}
|
504 |
+
"""
|
505 |
+
|
506 |
+
weights = {}
|
507 |
+
metadata = {}
|
508 |
+
|
509 |
+
for name, (path, target_replace_module, r) in modelmap.items():
|
510 |
+
metadata[name] = json.dumps(list(target_replace_module))
|
511 |
+
|
512 |
+
lora = torch.load(path)
|
513 |
+
for i, weight in enumerate(lora):
|
514 |
+
is_up = i % 2 == 0
|
515 |
+
i = i // 2
|
516 |
+
|
517 |
+
if is_up:
|
518 |
+
metadata[f"{name}:{i}:rank"] = str(r)
|
519 |
+
weights[f"{name}:{i}:up"] = weight
|
520 |
+
else:
|
521 |
+
weights[f"{name}:{i}:down"] = weight
|
522 |
+
|
523 |
+
for token, tensor in embeds.items():
|
524 |
+
metadata[token] = EMBED_FLAG
|
525 |
+
weights[token] = tensor
|
526 |
+
|
527 |
+
print(f"Saving weights to {outpath}")
|
528 |
+
safe_save(weights, outpath, metadata)
|
529 |
+
|
530 |
+
|
531 |
+
def convert_loras_to_safeloras(
|
532 |
+
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
|
533 |
+
outpath="./lora.safetensors",
|
534 |
+
):
|
535 |
+
convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
536 |
+
|
537 |
+
|
538 |
+
def parse_safeloras(
|
539 |
+
safeloras,
|
540 |
+
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
|
541 |
+
"""
|
542 |
+
Converts a loaded safetensor file that contains a set of module Loras
|
543 |
+
into Parameters and other information
|
544 |
+
|
545 |
+
Output is a dictionary of {
|
546 |
+
"module name": (
|
547 |
+
[list of weights],
|
548 |
+
[list of ranks],
|
549 |
+
target_replacement_modules
|
550 |
+
)
|
551 |
+
}
|
552 |
+
"""
|
553 |
+
loras = {}
|
554 |
+
metadata = safeloras.metadata()
|
555 |
+
|
556 |
+
get_name = lambda k: k.split(":")[0]
|
557 |
+
|
558 |
+
keys = list(safeloras.keys())
|
559 |
+
keys.sort(key=get_name)
|
560 |
+
|
561 |
+
for name, module_keys in groupby(keys, get_name):
|
562 |
+
info = metadata.get(name)
|
563 |
+
|
564 |
+
if not info:
|
565 |
+
raise ValueError(
|
566 |
+
f"Tensor {name} has no metadata - is this a Lora safetensor?"
|
567 |
+
)
|
568 |
+
|
569 |
+
# Skip Textual Inversion embeds
|
570 |
+
if info == EMBED_FLAG:
|
571 |
+
continue
|
572 |
+
|
573 |
+
# Handle Loras
|
574 |
+
# Extract the targets
|
575 |
+
target = json.loads(info)
|
576 |
+
|
577 |
+
# Build the result lists - Python needs us to preallocate lists to insert into them
|
578 |
+
module_keys = list(module_keys)
|
579 |
+
ranks = [4] * (len(module_keys) // 2)
|
580 |
+
weights = [None] * len(module_keys)
|
581 |
+
|
582 |
+
for key in module_keys:
|
583 |
+
# Split the model name and index out of the key
|
584 |
+
_, idx, direction = key.split(":")
|
585 |
+
idx = int(idx)
|
586 |
+
|
587 |
+
# Add the rank
|
588 |
+
ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
|
589 |
+
|
590 |
+
# Insert the weight into the list
|
591 |
+
idx = idx * 2 + (1 if direction == "down" else 0)
|
592 |
+
weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
|
593 |
+
|
594 |
+
loras[name] = (weights, ranks, target)
|
595 |
+
|
596 |
+
return loras
|
597 |
+
|
598 |
+
|
599 |
+
def parse_safeloras_embeds(
|
600 |
+
safeloras,
|
601 |
+
) -> Dict[str, torch.Tensor]:
|
602 |
+
"""
|
603 |
+
Converts a loaded safetensor file that contains Textual Inversion embeds into
|
604 |
+
a dictionary of embed_token: Tensor
|
605 |
+
"""
|
606 |
+
embeds = {}
|
607 |
+
metadata = safeloras.metadata()
|
608 |
+
|
609 |
+
for key in safeloras.keys():
|
610 |
+
# Only handle Textual Inversion embeds
|
611 |
+
meta = metadata.get(key)
|
612 |
+
if not meta or meta != EMBED_FLAG:
|
613 |
+
continue
|
614 |
+
|
615 |
+
embeds[key] = safeloras.get_tensor(key)
|
616 |
+
|
617 |
+
return embeds
|
618 |
+
|
619 |
+
|
620 |
+
def load_safeloras(path, device="cpu"):
|
621 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
622 |
+
return parse_safeloras(safeloras)
|
623 |
+
|
624 |
+
|
625 |
+
def load_safeloras_embeds(path, device="cpu"):
|
626 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
627 |
+
return parse_safeloras_embeds(safeloras)
|
628 |
+
|
629 |
+
|
630 |
+
def load_safeloras_both(path, device="cpu"):
|
631 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
632 |
+
return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
|
633 |
+
|
634 |
+
|
635 |
+
def collapse_lora(model, alpha=1.0):
|
636 |
+
|
637 |
+
for _module, name, _child_module in _find_modules(
|
638 |
+
model,
|
639 |
+
UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
|
640 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d],
|
641 |
+
):
|
642 |
+
|
643 |
+
if isinstance(_child_module, LoraInjectedLinear):
|
644 |
+
print("Collapsing Lin Lora in", name)
|
645 |
+
|
646 |
+
_child_module.linear.weight = nn.Parameter(
|
647 |
+
_child_module.linear.weight.data
|
648 |
+
+ alpha
|
649 |
+
* (
|
650 |
+
_child_module.lora_up.weight.data
|
651 |
+
@ _child_module.lora_down.weight.data
|
652 |
+
)
|
653 |
+
.type(_child_module.linear.weight.dtype)
|
654 |
+
.to(_child_module.linear.weight.device)
|
655 |
+
)
|
656 |
+
|
657 |
+
else:
|
658 |
+
print("Collapsing Conv Lora in", name)
|
659 |
+
_child_module.conv.weight = nn.Parameter(
|
660 |
+
_child_module.conv.weight.data
|
661 |
+
+ alpha
|
662 |
+
* (
|
663 |
+
_child_module.lora_up.weight.data.flatten(start_dim=1)
|
664 |
+
@ _child_module.lora_down.weight.data.flatten(start_dim=1)
|
665 |
+
)
|
666 |
+
.reshape(_child_module.conv.weight.data.shape)
|
667 |
+
.type(_child_module.conv.weight.dtype)
|
668 |
+
.to(_child_module.conv.weight.device)
|
669 |
+
)
|
670 |
+
|
671 |
+
|
672 |
+
def monkeypatch_or_replace_lora(
|
673 |
+
model,
|
674 |
+
loras,
|
675 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
676 |
+
r: Union[int, List[int]] = 4,
|
677 |
+
):
|
678 |
+
for _module, name, _child_module in _find_modules(
|
679 |
+
model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
|
680 |
+
):
|
681 |
+
_source = (
|
682 |
+
_child_module.linear
|
683 |
+
if isinstance(_child_module, LoraInjectedLinear)
|
684 |
+
else _child_module
|
685 |
+
)
|
686 |
+
|
687 |
+
weight = _source.weight
|
688 |
+
bias = _source.bias
|
689 |
+
_tmp = LoraInjectedLinear(
|
690 |
+
_source.in_features,
|
691 |
+
_source.out_features,
|
692 |
+
_source.bias is not None,
|
693 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
694 |
+
)
|
695 |
+
_tmp.linear.weight = weight
|
696 |
+
|
697 |
+
if bias is not None:
|
698 |
+
_tmp.linear.bias = bias
|
699 |
+
|
700 |
+
# switch the module
|
701 |
+
_module._modules[name] = _tmp
|
702 |
+
|
703 |
+
up_weight = loras.pop(0)
|
704 |
+
down_weight = loras.pop(0)
|
705 |
+
|
706 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
707 |
+
up_weight.type(weight.dtype)
|
708 |
+
)
|
709 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
710 |
+
down_weight.type(weight.dtype)
|
711 |
+
)
|
712 |
+
|
713 |
+
_module._modules[name].to(weight.device)
|
714 |
+
|
715 |
+
|
716 |
+
def monkeypatch_or_replace_lora_extended(
|
717 |
+
model,
|
718 |
+
loras,
|
719 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
720 |
+
r: Union[int, List[int]] = 4,
|
721 |
+
):
|
722 |
+
for _module, name, _child_module in _find_modules(
|
723 |
+
model,
|
724 |
+
target_replace_module,
|
725 |
+
search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
|
726 |
+
):
|
727 |
+
|
728 |
+
if (_child_module.__class__ == nn.Linear) or (
|
729 |
+
_child_module.__class__ == LoraInjectedLinear
|
730 |
+
):
|
731 |
+
if len(loras[0].shape) != 2:
|
732 |
+
continue
|
733 |
+
|
734 |
+
_source = (
|
735 |
+
_child_module.linear
|
736 |
+
if isinstance(_child_module, LoraInjectedLinear)
|
737 |
+
else _child_module
|
738 |
+
)
|
739 |
+
|
740 |
+
weight = _source.weight
|
741 |
+
bias = _source.bias
|
742 |
+
_tmp = LoraInjectedLinear(
|
743 |
+
_source.in_features,
|
744 |
+
_source.out_features,
|
745 |
+
_source.bias is not None,
|
746 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
747 |
+
)
|
748 |
+
_tmp.linear.weight = weight
|
749 |
+
|
750 |
+
if bias is not None:
|
751 |
+
_tmp.linear.bias = bias
|
752 |
+
|
753 |
+
elif (_child_module.__class__ == nn.Conv2d) or (
|
754 |
+
_child_module.__class__ == LoraInjectedConv2d
|
755 |
+
):
|
756 |
+
if len(loras[0].shape) != 4:
|
757 |
+
continue
|
758 |
+
_source = (
|
759 |
+
_child_module.conv
|
760 |
+
if isinstance(_child_module, LoraInjectedConv2d)
|
761 |
+
else _child_module
|
762 |
+
)
|
763 |
+
|
764 |
+
weight = _source.weight
|
765 |
+
bias = _source.bias
|
766 |
+
_tmp = LoraInjectedConv2d(
|
767 |
+
_source.in_channels,
|
768 |
+
_source.out_channels,
|
769 |
+
_source.kernel_size,
|
770 |
+
_source.stride,
|
771 |
+
_source.padding,
|
772 |
+
_source.dilation,
|
773 |
+
_source.groups,
|
774 |
+
_source.bias is not None,
|
775 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
776 |
+
)
|
777 |
+
|
778 |
+
_tmp.conv.weight = weight
|
779 |
+
|
780 |
+
if bias is not None:
|
781 |
+
_tmp.conv.bias = bias
|
782 |
+
|
783 |
+
# switch the module
|
784 |
+
_module._modules[name] = _tmp
|
785 |
+
|
786 |
+
up_weight = loras.pop(0)
|
787 |
+
down_weight = loras.pop(0)
|
788 |
+
|
789 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
790 |
+
up_weight.type(weight.dtype)
|
791 |
+
)
|
792 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
793 |
+
down_weight.type(weight.dtype)
|
794 |
+
)
|
795 |
+
|
796 |
+
_module._modules[name].to(weight.device)
|
797 |
+
|
798 |
+
|
799 |
+
def monkeypatch_or_replace_safeloras(models, safeloras):
|
800 |
+
loras = parse_safeloras(safeloras)
|
801 |
+
|
802 |
+
for name, (lora, ranks, target) in loras.items():
|
803 |
+
model = getattr(models, name, None)
|
804 |
+
|
805 |
+
if not model:
|
806 |
+
print(f"No model provided for {name}, contained in Lora")
|
807 |
+
continue
|
808 |
+
|
809 |
+
monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
|
810 |
+
|
811 |
+
|
812 |
+
def monkeypatch_remove_lora(model):
|
813 |
+
for _module, name, _child_module in _find_modules(
|
814 |
+
model, search_class=[LoraInjectedLinear, LoraInjectedConv2d]
|
815 |
+
):
|
816 |
+
if isinstance(_child_module, LoraInjectedLinear):
|
817 |
+
_source = _child_module.linear
|
818 |
+
weight, bias = _source.weight, _source.bias
|
819 |
+
|
820 |
+
_tmp = nn.Linear(
|
821 |
+
_source.in_features, _source.out_features, bias is not None
|
822 |
+
)
|
823 |
+
|
824 |
+
_tmp.weight = weight
|
825 |
+
if bias is not None:
|
826 |
+
_tmp.bias = bias
|
827 |
+
|
828 |
+
else:
|
829 |
+
_source = _child_module.conv
|
830 |
+
weight, bias = _source.weight, _source.bias
|
831 |
+
|
832 |
+
_tmp = nn.Conv2d(
|
833 |
+
in_channels=_source.in_channels,
|
834 |
+
out_channels=_source.out_channels,
|
835 |
+
kernel_size=_source.kernel_size,
|
836 |
+
stride=_source.stride,
|
837 |
+
padding=_source.padding,
|
838 |
+
dilation=_source.dilation,
|
839 |
+
groups=_source.groups,
|
840 |
+
bias=bias is not None,
|
841 |
+
)
|
842 |
+
|
843 |
+
_tmp.weight = weight
|
844 |
+
if bias is not None:
|
845 |
+
_tmp.bias = bias
|
846 |
+
|
847 |
+
_module._modules[name] = _tmp
|
848 |
+
|
849 |
+
|
850 |
+
def monkeypatch_add_lora(
|
851 |
+
model,
|
852 |
+
loras,
|
853 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
854 |
+
alpha: float = 1.0,
|
855 |
+
beta: float = 1.0,
|
856 |
+
):
|
857 |
+
for _module, name, _child_module in _find_modules(
|
858 |
+
model, target_replace_module, search_class=[LoraInjectedLinear]
|
859 |
+
):
|
860 |
+
weight = _child_module.linear.weight
|
861 |
+
|
862 |
+
up_weight = loras.pop(0)
|
863 |
+
down_weight = loras.pop(0)
|
864 |
+
|
865 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
866 |
+
up_weight.type(weight.dtype).to(weight.device) * alpha
|
867 |
+
+ _module._modules[name].lora_up.weight.to(weight.device) * beta
|
868 |
+
)
|
869 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
870 |
+
down_weight.type(weight.dtype).to(weight.device) * alpha
|
871 |
+
+ _module._modules[name].lora_down.weight.to(weight.device) * beta
|
872 |
+
)
|
873 |
+
|
874 |
+
_module._modules[name].to(weight.device)
|
875 |
+
|
876 |
+
|
877 |
+
def tune_lora_scale(model, alpha: float = 1.0):
|
878 |
+
for _module in model.modules():
|
879 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
|
880 |
+
_module.scale = alpha
|
881 |
+
|
882 |
+
|
883 |
+
def set_lora_diag(model, diag: torch.Tensor):
|
884 |
+
for _module in model.modules():
|
885 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
|
886 |
+
_module.set_selector_from_diag(diag)
|
887 |
+
|
888 |
+
|
889 |
+
def _text_lora_path(path: str) -> str:
|
890 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
891 |
+
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
|
892 |
+
|
893 |
+
|
894 |
+
def _ti_lora_path(path: str) -> str:
|
895 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
896 |
+
return ".".join(path.split(".")[:-1] + ["ti", "pt"])
|
897 |
+
|
898 |
+
|
899 |
+
def apply_learned_embed_in_clip(
|
900 |
+
learned_embeds,
|
901 |
+
text_encoder,
|
902 |
+
tokenizer,
|
903 |
+
token: Optional[Union[str, List[str]]] = None,
|
904 |
+
idempotent=False,
|
905 |
+
):
|
906 |
+
if isinstance(token, str):
|
907 |
+
trained_tokens = [token]
|
908 |
+
elif isinstance(token, list):
|
909 |
+
assert len(learned_embeds.keys()) == len(
|
910 |
+
token
|
911 |
+
), "The number of tokens and the number of embeds should be the same"
|
912 |
+
trained_tokens = token
|
913 |
+
else:
|
914 |
+
trained_tokens = list(learned_embeds.keys())
|
915 |
+
|
916 |
+
for token in trained_tokens:
|
917 |
+
print(token)
|
918 |
+
embeds = learned_embeds[token]
|
919 |
+
|
920 |
+
# cast to dtype of text_encoder
|
921 |
+
dtype = text_encoder.get_input_embeddings().weight.dtype
|
922 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
923 |
+
|
924 |
+
i = 1
|
925 |
+
if not idempotent:
|
926 |
+
while num_added_tokens == 0:
|
927 |
+
print(f"The tokenizer already contains the token {token}.")
|
928 |
+
token = f"{token[:-1]}-{i}>"
|
929 |
+
print(f"Attempting to add the token {token}.")
|
930 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
931 |
+
i += 1
|
932 |
+
elif num_added_tokens == 0 and idempotent:
|
933 |
+
print(f"The tokenizer already contains the token {token}.")
|
934 |
+
print(f"Replacing {token} embedding.")
|
935 |
+
|
936 |
+
# resize the token embeddings
|
937 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
938 |
+
|
939 |
+
# get the id for the token and assign the embeds
|
940 |
+
token_id = tokenizer.convert_tokens_to_ids(token)
|
941 |
+
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
942 |
+
return token
|
943 |
+
|
944 |
+
|
945 |
+
def load_learned_embed_in_clip(
|
946 |
+
learned_embeds_path,
|
947 |
+
text_encoder,
|
948 |
+
tokenizer,
|
949 |
+
token: Optional[Union[str, List[str]]] = None,
|
950 |
+
idempotent=False,
|
951 |
+
):
|
952 |
+
learned_embeds = torch.load(learned_embeds_path)
|
953 |
+
apply_learned_embed_in_clip(
|
954 |
+
learned_embeds, text_encoder, tokenizer, token, idempotent
|
955 |
+
)
|
956 |
+
|
957 |
+
|
958 |
+
def patch_pipe(
|
959 |
+
pipe,
|
960 |
+
maybe_unet_path,
|
961 |
+
token: Optional[str] = None,
|
962 |
+
r: int = 4,
|
963 |
+
patch_unet=True,
|
964 |
+
patch_text=True,
|
965 |
+
patch_ti=True,
|
966 |
+
idempotent_token=True,
|
967 |
+
unet_target_replace_module=DEFAULT_TARGET_REPLACE,
|
968 |
+
text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
969 |
+
):
|
970 |
+
if maybe_unet_path.endswith(".pt"):
|
971 |
+
# torch format
|
972 |
+
|
973 |
+
if maybe_unet_path.endswith(".ti.pt"):
|
974 |
+
unet_path = maybe_unet_path[:-6] + ".pt"
|
975 |
+
elif maybe_unet_path.endswith(".text_encoder.pt"):
|
976 |
+
unet_path = maybe_unet_path[:-16] + ".pt"
|
977 |
+
else:
|
978 |
+
unet_path = maybe_unet_path
|
979 |
+
|
980 |
+
ti_path = _ti_lora_path(unet_path)
|
981 |
+
text_path = _text_lora_path(unet_path)
|
982 |
+
|
983 |
+
if patch_unet:
|
984 |
+
print("LoRA : Patching Unet")
|
985 |
+
monkeypatch_or_replace_lora(
|
986 |
+
pipe.unet,
|
987 |
+
torch.load(unet_path),
|
988 |
+
r=r,
|
989 |
+
target_replace_module=unet_target_replace_module,
|
990 |
+
)
|
991 |
+
|
992 |
+
if patch_text:
|
993 |
+
print("LoRA : Patching text encoder")
|
994 |
+
monkeypatch_or_replace_lora(
|
995 |
+
pipe.text_encoder,
|
996 |
+
torch.load(text_path),
|
997 |
+
target_replace_module=text_target_replace_module,
|
998 |
+
r=r,
|
999 |
+
)
|
1000 |
+
if patch_ti:
|
1001 |
+
print("LoRA : Patching token input")
|
1002 |
+
token = load_learned_embed_in_clip(
|
1003 |
+
ti_path,
|
1004 |
+
pipe.text_encoder,
|
1005 |
+
pipe.tokenizer,
|
1006 |
+
token=token,
|
1007 |
+
idempotent=idempotent_token,
|
1008 |
+
)
|
1009 |
+
|
1010 |
+
elif maybe_unet_path.endswith(".safetensors"):
|
1011 |
+
safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
|
1012 |
+
monkeypatch_or_replace_safeloras(pipe, safeloras)
|
1013 |
+
tok_dict = parse_safeloras_embeds(safeloras)
|
1014 |
+
if patch_ti:
|
1015 |
+
apply_learned_embed_in_clip(
|
1016 |
+
tok_dict,
|
1017 |
+
pipe.text_encoder,
|
1018 |
+
pipe.tokenizer,
|
1019 |
+
token=token,
|
1020 |
+
idempotent=idempotent_token,
|
1021 |
+
)
|
1022 |
+
return tok_dict
|
1023 |
+
|
1024 |
+
|
1025 |
+
@torch.no_grad()
|
1026 |
+
def inspect_lora(model):
|
1027 |
+
moved = {}
|
1028 |
+
|
1029 |
+
for name, _module in model.named_modules():
|
1030 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
|
1031 |
+
ups = _module.lora_up.weight.data.clone()
|
1032 |
+
downs = _module.lora_down.weight.data.clone()
|
1033 |
+
|
1034 |
+
wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
|
1035 |
+
|
1036 |
+
dist = wght.flatten().abs().mean().item()
|
1037 |
+
if name in moved:
|
1038 |
+
moved[name].append(dist)
|
1039 |
+
else:
|
1040 |
+
moved[name] = [dist]
|
1041 |
+
|
1042 |
+
return moved
|
1043 |
+
|
1044 |
+
|
1045 |
+
def save_all(
|
1046 |
+
unet,
|
1047 |
+
text_encoder,
|
1048 |
+
save_path,
|
1049 |
+
placeholder_token_ids=None,
|
1050 |
+
placeholder_tokens=None,
|
1051 |
+
save_lora=True,
|
1052 |
+
save_ti=True,
|
1053 |
+
target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
1054 |
+
target_replace_module_unet=DEFAULT_TARGET_REPLACE,
|
1055 |
+
safe_form=True,
|
1056 |
+
):
|
1057 |
+
if not safe_form:
|
1058 |
+
# save ti
|
1059 |
+
if save_ti:
|
1060 |
+
ti_path = _ti_lora_path(save_path)
|
1061 |
+
learned_embeds_dict = {}
|
1062 |
+
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
|
1063 |
+
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
|
1064 |
+
print(
|
1065 |
+
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
|
1066 |
+
learned_embeds[:4],
|
1067 |
+
)
|
1068 |
+
learned_embeds_dict[tok] = learned_embeds.detach().cpu()
|
1069 |
+
|
1070 |
+
torch.save(learned_embeds_dict, ti_path)
|
1071 |
+
print("Ti saved to ", ti_path)
|
1072 |
+
|
1073 |
+
# save text encoder
|
1074 |
+
if save_lora:
|
1075 |
+
|
1076 |
+
save_lora_weight(
|
1077 |
+
unet, save_path, target_replace_module=target_replace_module_unet
|
1078 |
+
)
|
1079 |
+
print("Unet saved to ", save_path)
|
1080 |
+
|
1081 |
+
save_lora_weight(
|
1082 |
+
text_encoder,
|
1083 |
+
_text_lora_path(save_path),
|
1084 |
+
target_replace_module=target_replace_module_text,
|
1085 |
+
)
|
1086 |
+
print("Text Encoder saved to ", _text_lora_path(save_path))
|
1087 |
+
|
1088 |
+
else:
|
1089 |
+
assert save_path.endswith(
|
1090 |
+
".safetensors"
|
1091 |
+
), f"Save path : {save_path} should end with .safetensors"
|
1092 |
+
|
1093 |
+
loras = {}
|
1094 |
+
embeds = {}
|
1095 |
+
|
1096 |
+
if save_lora:
|
1097 |
+
|
1098 |
+
loras["unet"] = (unet, target_replace_module_unet)
|
1099 |
+
loras["text_encoder"] = (text_encoder, target_replace_module_text)
|
1100 |
+
|
1101 |
+
if save_ti:
|
1102 |
+
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
|
1103 |
+
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
|
1104 |
+
print(
|
1105 |
+
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
|
1106 |
+
learned_embeds[:4],
|
1107 |
+
)
|
1108 |
+
embeds[tok] = learned_embeds.detach().cpu()
|
1109 |
+
|
1110 |
+
save_safeloras_with_embeds(loras, embeds, save_path)
|
lora_diffusion/lora_manager.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import torch
|
3 |
+
from safetensors import safe_open
|
4 |
+
from diffusers import StableDiffusionPipeline
|
5 |
+
from .lora import (
|
6 |
+
monkeypatch_or_replace_safeloras,
|
7 |
+
apply_learned_embed_in_clip,
|
8 |
+
set_lora_diag,
|
9 |
+
parse_safeloras_embeds,
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
def lora_join(lora_safetenors: list):
|
14 |
+
metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
|
15 |
+
_total_metadata = {}
|
16 |
+
total_metadata = {}
|
17 |
+
total_tensor = {}
|
18 |
+
total_rank = 0
|
19 |
+
ranklist = []
|
20 |
+
for _metadata in metadatas:
|
21 |
+
rankset = []
|
22 |
+
for k, v in _metadata.items():
|
23 |
+
if k.endswith("rank"):
|
24 |
+
rankset.append(int(v))
|
25 |
+
|
26 |
+
assert len(set(rankset)) <= 1, "Rank should be the same per model"
|
27 |
+
if len(rankset) == 0:
|
28 |
+
rankset = [0]
|
29 |
+
|
30 |
+
total_rank += rankset[0]
|
31 |
+
_total_metadata.update(_metadata)
|
32 |
+
ranklist.append(rankset[0])
|
33 |
+
|
34 |
+
# remove metadata about tokens
|
35 |
+
for k, v in _total_metadata.items():
|
36 |
+
if v != "<embed>":
|
37 |
+
total_metadata[k] = v
|
38 |
+
|
39 |
+
tensorkeys = set()
|
40 |
+
for safelora in lora_safetenors:
|
41 |
+
tensorkeys.update(safelora.keys())
|
42 |
+
|
43 |
+
for keys in tensorkeys:
|
44 |
+
if keys.startswith("text_encoder") or keys.startswith("unet"):
|
45 |
+
tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]
|
46 |
+
|
47 |
+
is_down = keys.endswith("down")
|
48 |
+
|
49 |
+
if is_down:
|
50 |
+
_tensor = torch.cat(tensorset, dim=0)
|
51 |
+
assert _tensor.shape[0] == total_rank
|
52 |
+
else:
|
53 |
+
_tensor = torch.cat(tensorset, dim=1)
|
54 |
+
assert _tensor.shape[1] == total_rank
|
55 |
+
|
56 |
+
total_tensor[keys] = _tensor
|
57 |
+
keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
|
58 |
+
total_metadata[keys_rank] = str(total_rank)
|
59 |
+
token_size_list = []
|
60 |
+
for idx, safelora in enumerate(lora_safetenors):
|
61 |
+
tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
|
62 |
+
for jdx, token in enumerate(sorted(tokens)):
|
63 |
+
|
64 |
+
total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
|
65 |
+
total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
|
66 |
+
|
67 |
+
print(f"Embedding {token} replaced to <s{idx}-{jdx}>")
|
68 |
+
|
69 |
+
token_size_list.append(len(tokens))
|
70 |
+
|
71 |
+
return total_tensor, total_metadata, ranklist, token_size_list
|
72 |
+
|
73 |
+
|
74 |
+
class DummySafeTensorObject:
|
75 |
+
def __init__(self, tensor: dict, metadata):
|
76 |
+
self.tensor = tensor
|
77 |
+
self._metadata = metadata
|
78 |
+
|
79 |
+
def keys(self):
|
80 |
+
return self.tensor.keys()
|
81 |
+
|
82 |
+
def metadata(self):
|
83 |
+
return self._metadata
|
84 |
+
|
85 |
+
def get_tensor(self, key):
|
86 |
+
return self.tensor[key]
|
87 |
+
|
88 |
+
|
89 |
+
class LoRAManager:
|
90 |
+
def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):
|
91 |
+
|
92 |
+
self.lora_paths_list = lora_paths_list
|
93 |
+
self.pipe = pipe
|
94 |
+
self._setup()
|
95 |
+
|
96 |
+
def _setup(self):
|
97 |
+
|
98 |
+
self._lora_safetenors = [
|
99 |
+
safe_open(path, framework="pt", device="cpu")
|
100 |
+
for path in self.lora_paths_list
|
101 |
+
]
|
102 |
+
|
103 |
+
(
|
104 |
+
total_tensor,
|
105 |
+
total_metadata,
|
106 |
+
self.ranklist,
|
107 |
+
self.token_size_list,
|
108 |
+
) = lora_join(self._lora_safetenors)
|
109 |
+
|
110 |
+
self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)
|
111 |
+
|
112 |
+
monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
|
113 |
+
tok_dict = parse_safeloras_embeds(self.total_safelora)
|
114 |
+
|
115 |
+
apply_learned_embed_in_clip(
|
116 |
+
tok_dict,
|
117 |
+
self.pipe.text_encoder,
|
118 |
+
self.pipe.tokenizer,
|
119 |
+
token=None,
|
120 |
+
idempotent=True,
|
121 |
+
)
|
122 |
+
|
123 |
+
def tune(self, scales):
|
124 |
+
|
125 |
+
assert len(scales) == len(
|
126 |
+
self.ranklist
|
127 |
+
), "Scale list should be the same length as ranklist"
|
128 |
+
|
129 |
+
diags = []
|
130 |
+
for scale, rank in zip(scales, self.ranklist):
|
131 |
+
diags = diags + [scale] * rank
|
132 |
+
|
133 |
+
set_lora_diag(self.pipe.unet, torch.tensor(diags))
|
134 |
+
|
135 |
+
def prompt(self, prompt):
|
136 |
+
if prompt is not None:
|
137 |
+
for idx, tok_size in enumerate(self.token_size_list):
|
138 |
+
prompt = prompt.replace(
|
139 |
+
f"<{idx + 1}>",
|
140 |
+
"".join([f"<s{idx}-{jdx}>" for jdx in range(tok_size)]),
|
141 |
+
)
|
142 |
+
# TODO : Rescale LoRA + Text inputs based on prompt scale params
|
143 |
+
|
144 |
+
return prompt
|
lora_diffusion/preprocess_files.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Have SwinIR upsample
|
2 |
+
# Have BLIP auto caption
|
3 |
+
# Have CLIPSeg auto mask concept
|
4 |
+
|
5 |
+
from typing import List, Literal, Union, Optional, Tuple
|
6 |
+
import os
|
7 |
+
from PIL import Image, ImageFilter
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
import fire
|
11 |
+
from tqdm import tqdm
|
12 |
+
import glob
|
13 |
+
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
14 |
+
|
15 |
+
|
16 |
+
@torch.no_grad()
|
17 |
+
def swin_ir_sr(
|
18 |
+
images: List[Image.Image],
|
19 |
+
model_id: Literal[
|
20 |
+
"caidas/swin2SR-classical-sr-x2-64", "caidas/swin2SR-classical-sr-x4-48"
|
21 |
+
] = "caidas/swin2SR-classical-sr-x2-64",
|
22 |
+
target_size: Optional[Tuple[int, int]] = None,
|
23 |
+
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
|
24 |
+
**kwargs,
|
25 |
+
) -> List[Image.Image]:
|
26 |
+
"""
|
27 |
+
Upscales images using SwinIR. Returns a list of PIL images.
|
28 |
+
"""
|
29 |
+
# So this is currently in main branch, so this can be used in the future I guess?
|
30 |
+
from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor
|
31 |
+
|
32 |
+
model = Swin2SRForImageSuperResolution.from_pretrained(
|
33 |
+
model_id,
|
34 |
+
).to(device)
|
35 |
+
processor = Swin2SRImageProcessor()
|
36 |
+
|
37 |
+
out_images = []
|
38 |
+
|
39 |
+
for image in tqdm(images):
|
40 |
+
|
41 |
+
ori_w, ori_h = image.size
|
42 |
+
if target_size is not None:
|
43 |
+
if ori_w >= target_size[0] and ori_h >= target_size[1]:
|
44 |
+
out_images.append(image)
|
45 |
+
continue
|
46 |
+
|
47 |
+
inputs = processor(image, return_tensors="pt").to(device)
|
48 |
+
with torch.no_grad():
|
49 |
+
outputs = model(**inputs)
|
50 |
+
|
51 |
+
output = (
|
52 |
+
outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
53 |
+
)
|
54 |
+
output = np.moveaxis(output, source=0, destination=-1)
|
55 |
+
output = (output * 255.0).round().astype(np.uint8)
|
56 |
+
output = Image.fromarray(output)
|
57 |
+
|
58 |
+
out_images.append(output)
|
59 |
+
|
60 |
+
return out_images
|
61 |
+
|
62 |
+
|
63 |
+
@torch.no_grad()
|
64 |
+
def clipseg_mask_generator(
|
65 |
+
images: List[Image.Image],
|
66 |
+
target_prompts: Union[List[str], str],
|
67 |
+
model_id: Literal[
|
68 |
+
"CIDAS/clipseg-rd64-refined", "CIDAS/clipseg-rd16"
|
69 |
+
] = "CIDAS/clipseg-rd64-refined",
|
70 |
+
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
|
71 |
+
bias: float = 0.01,
|
72 |
+
temp: float = 1.0,
|
73 |
+
**kwargs,
|
74 |
+
) -> List[Image.Image]:
|
75 |
+
"""
|
76 |
+
Returns a greyscale mask for each image, where the mask is the probability of the target prompt being present in the image
|
77 |
+
"""
|
78 |
+
|
79 |
+
if isinstance(target_prompts, str):
|
80 |
+
print(
|
81 |
+
f'Warning: only one target prompt "{target_prompts}" was given, so it will be used for all images'
|
82 |
+
)
|
83 |
+
|
84 |
+
target_prompts = [target_prompts] * len(images)
|
85 |
+
|
86 |
+
processor = CLIPSegProcessor.from_pretrained(model_id)
|
87 |
+
model = CLIPSegForImageSegmentation.from_pretrained(model_id).to(device)
|
88 |
+
|
89 |
+
masks = []
|
90 |
+
|
91 |
+
for image, prompt in tqdm(zip(images, target_prompts)):
|
92 |
+
|
93 |
+
original_size = image.size
|
94 |
+
|
95 |
+
inputs = processor(
|
96 |
+
text=[prompt, ""],
|
97 |
+
images=[image] * 2,
|
98 |
+
padding="max_length",
|
99 |
+
truncation=True,
|
100 |
+
return_tensors="pt",
|
101 |
+
).to(device)
|
102 |
+
|
103 |
+
outputs = model(**inputs)
|
104 |
+
|
105 |
+
logits = outputs.logits
|
106 |
+
probs = torch.nn.functional.softmax(logits / temp, dim=0)[0]
|
107 |
+
probs = (probs + bias).clamp_(0, 1)
|
108 |
+
probs = 255 * probs / probs.max()
|
109 |
+
|
110 |
+
# make mask greyscale
|
111 |
+
mask = Image.fromarray(probs.cpu().numpy()).convert("L")
|
112 |
+
|
113 |
+
# resize mask to original size
|
114 |
+
mask = mask.resize(original_size)
|
115 |
+
|
116 |
+
masks.append(mask)
|
117 |
+
|
118 |
+
return masks
|
119 |
+
|
120 |
+
|
121 |
+
@torch.no_grad()
|
122 |
+
def blip_captioning_dataset(
|
123 |
+
images: List[Image.Image],
|
124 |
+
text: Optional[str] = None,
|
125 |
+
model_id: Literal[
|
126 |
+
"Salesforce/blip-image-captioning-large",
|
127 |
+
"Salesforce/blip-image-captioning-base",
|
128 |
+
] = "Salesforce/blip-image-captioning-large",
|
129 |
+
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
130 |
+
**kwargs,
|
131 |
+
) -> List[str]:
|
132 |
+
"""
|
133 |
+
Returns a list of captions for the given images
|
134 |
+
"""
|
135 |
+
|
136 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
137 |
+
|
138 |
+
processor = BlipProcessor.from_pretrained(model_id)
|
139 |
+
model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
|
140 |
+
captions = []
|
141 |
+
|
142 |
+
for image in tqdm(images):
|
143 |
+
inputs = processor(image, text=text, return_tensors="pt").to("cuda")
|
144 |
+
out = model.generate(
|
145 |
+
**inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
|
146 |
+
)
|
147 |
+
caption = processor.decode(out[0], skip_special_tokens=True)
|
148 |
+
|
149 |
+
captions.append(caption)
|
150 |
+
|
151 |
+
return captions
|
152 |
+
|
153 |
+
|
154 |
+
def face_mask_google_mediapipe(
|
155 |
+
images: List[Image.Image], blur_amount: float = 80.0, bias: float = 0.05
|
156 |
+
) -> List[Image.Image]:
|
157 |
+
"""
|
158 |
+
Returns a list of images with mask on the face parts.
|
159 |
+
"""
|
160 |
+
import mediapipe as mp
|
161 |
+
|
162 |
+
mp_face_detection = mp.solutions.face_detection
|
163 |
+
|
164 |
+
face_detection = mp_face_detection.FaceDetection(
|
165 |
+
model_selection=1, min_detection_confidence=0.5
|
166 |
+
)
|
167 |
+
|
168 |
+
masks = []
|
169 |
+
for image in tqdm(images):
|
170 |
+
|
171 |
+
image = np.array(image)
|
172 |
+
|
173 |
+
results = face_detection.process(image)
|
174 |
+
black_image = np.ones((image.shape[0], image.shape[1]), dtype=np.uint8)
|
175 |
+
|
176 |
+
if results.detections:
|
177 |
+
|
178 |
+
for detection in results.detections:
|
179 |
+
|
180 |
+
x_min = int(
|
181 |
+
detection.location_data.relative_bounding_box.xmin * image.shape[1]
|
182 |
+
)
|
183 |
+
y_min = int(
|
184 |
+
detection.location_data.relative_bounding_box.ymin * image.shape[0]
|
185 |
+
)
|
186 |
+
width = int(
|
187 |
+
detection.location_data.relative_bounding_box.width * image.shape[1]
|
188 |
+
)
|
189 |
+
height = int(
|
190 |
+
detection.location_data.relative_bounding_box.height
|
191 |
+
* image.shape[0]
|
192 |
+
)
|
193 |
+
|
194 |
+
# draw the colored rectangle
|
195 |
+
black_image[y_min : y_min + height, x_min : x_min + width] = 255
|
196 |
+
|
197 |
+
black_image = Image.fromarray(black_image)
|
198 |
+
masks.append(black_image)
|
199 |
+
|
200 |
+
return masks
|
201 |
+
|
202 |
+
|
203 |
+
def _crop_to_square(
|
204 |
+
image: Image.Image, com: List[Tuple[int, int]], resize_to: Optional[int] = None
|
205 |
+
):
|
206 |
+
cx, cy = com
|
207 |
+
width, height = image.size
|
208 |
+
if width > height:
|
209 |
+
left_possible = max(cx - height / 2, 0)
|
210 |
+
left = min(left_possible, width - height)
|
211 |
+
right = left + height
|
212 |
+
top = 0
|
213 |
+
bottom = height
|
214 |
+
else:
|
215 |
+
left = 0
|
216 |
+
right = width
|
217 |
+
top_possible = max(cy - width / 2, 0)
|
218 |
+
top = min(top_possible, height - width)
|
219 |
+
bottom = top + width
|
220 |
+
|
221 |
+
image = image.crop((left, top, right, bottom))
|
222 |
+
|
223 |
+
if resize_to:
|
224 |
+
image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS)
|
225 |
+
|
226 |
+
return image
|
227 |
+
|
228 |
+
|
229 |
+
def _center_of_mass(mask: Image.Image):
|
230 |
+
"""
|
231 |
+
Returns the center of mass of the mask
|
232 |
+
"""
|
233 |
+
x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1]))
|
234 |
+
|
235 |
+
x_ = x * np.array(mask)
|
236 |
+
y_ = y * np.array(mask)
|
237 |
+
|
238 |
+
x = np.sum(x_) / np.sum(mask)
|
239 |
+
y = np.sum(y_) / np.sum(mask)
|
240 |
+
|
241 |
+
return x, y
|
242 |
+
|
243 |
+
|
244 |
+
def load_and_save_masks_and_captions(
|
245 |
+
files: Union[str, List[str]],
|
246 |
+
output_dir: str,
|
247 |
+
caption_text: Optional[str] = None,
|
248 |
+
target_prompts: Optional[Union[List[str], str]] = None,
|
249 |
+
target_size: int = 512,
|
250 |
+
crop_based_on_salience: bool = True,
|
251 |
+
use_face_detection_instead: bool = False,
|
252 |
+
temp: float = 1.0,
|
253 |
+
n_length: int = -1,
|
254 |
+
):
|
255 |
+
"""
|
256 |
+
Loads images from the given files, generates masks for them, and saves the masks and captions and upscale images
|
257 |
+
to output dir.
|
258 |
+
"""
|
259 |
+
os.makedirs(output_dir, exist_ok=True)
|
260 |
+
|
261 |
+
# load images
|
262 |
+
if isinstance(files, str):
|
263 |
+
# check if it is a directory
|
264 |
+
if os.path.isdir(files):
|
265 |
+
# get all the .png .jpg in the directory
|
266 |
+
files = glob.glob(os.path.join(files, "*.png")) + glob.glob(
|
267 |
+
os.path.join(files, "*.jpg")
|
268 |
+
)
|
269 |
+
|
270 |
+
if len(files) == 0:
|
271 |
+
raise Exception(
|
272 |
+
f"No files found in {files}. Either {files} is not a directory or it does not contain any .png or .jpg files."
|
273 |
+
)
|
274 |
+
if n_length == -1:
|
275 |
+
n_length = len(files)
|
276 |
+
files = sorted(files)[:n_length]
|
277 |
+
|
278 |
+
images = [Image.open(file) for file in files]
|
279 |
+
|
280 |
+
# captions
|
281 |
+
print(f"Generating {len(images)} captions...")
|
282 |
+
captions = blip_captioning_dataset(images, text=caption_text)
|
283 |
+
|
284 |
+
if target_prompts is None:
|
285 |
+
target_prompts = captions
|
286 |
+
|
287 |
+
print(f"Generating {len(images)} masks...")
|
288 |
+
if not use_face_detection_instead:
|
289 |
+
seg_masks = clipseg_mask_generator(
|
290 |
+
images=images, target_prompts=target_prompts, temp=temp
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
seg_masks = face_mask_google_mediapipe(images=images)
|
294 |
+
|
295 |
+
# find the center of mass of the mask
|
296 |
+
if crop_based_on_salience:
|
297 |
+
coms = [_center_of_mass(mask) for mask in seg_masks]
|
298 |
+
else:
|
299 |
+
coms = [(image.size[0] / 2, image.size[1] / 2) for image in images]
|
300 |
+
# based on the center of mass, crop the image to a square
|
301 |
+
images = [
|
302 |
+
_crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms)
|
303 |
+
]
|
304 |
+
|
305 |
+
print(f"Upscaling {len(images)} images...")
|
306 |
+
# upscale images anyways
|
307 |
+
images = swin_ir_sr(images, target_size=(target_size, target_size))
|
308 |
+
images = [
|
309 |
+
image.resize((target_size, target_size), Image.Resampling.LANCZOS)
|
310 |
+
for image in images
|
311 |
+
]
|
312 |
+
|
313 |
+
seg_masks = [
|
314 |
+
_crop_to_square(mask, com, resize_to=target_size)
|
315 |
+
for mask, com in zip(seg_masks, coms)
|
316 |
+
]
|
317 |
+
with open(os.path.join(output_dir, "caption.txt"), "w") as f:
|
318 |
+
# save images and masks
|
319 |
+
for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)):
|
320 |
+
image.save(os.path.join(output_dir, f"{idx}.src.jpg"), quality=99)
|
321 |
+
mask.save(os.path.join(output_dir, f"{idx}.mask.png"))
|
322 |
+
|
323 |
+
f.write(caption + "\n")
|
324 |
+
|
325 |
+
|
326 |
+
def main():
|
327 |
+
fire.Fire(load_and_save_masks_and_captions)
|
lora_diffusion/safe_open.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Pure python version of Safetensors safe_open
|
3 |
+
From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import mmap
|
8 |
+
import os
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
|
13 |
+
class SafetensorsWrapper:
|
14 |
+
def __init__(self, metadata, tensors):
|
15 |
+
self._metadata = metadata
|
16 |
+
self._tensors = tensors
|
17 |
+
|
18 |
+
def metadata(self):
|
19 |
+
return self._metadata
|
20 |
+
|
21 |
+
def keys(self):
|
22 |
+
return self._tensors.keys()
|
23 |
+
|
24 |
+
def get_tensor(self, k):
|
25 |
+
return self._tensors[k]
|
26 |
+
|
27 |
+
|
28 |
+
DTYPES = {
|
29 |
+
"F32": torch.float32,
|
30 |
+
"F16": torch.float16,
|
31 |
+
"BF16": torch.bfloat16,
|
32 |
+
}
|
33 |
+
|
34 |
+
|
35 |
+
def create_tensor(storage, info, offset):
|
36 |
+
dtype = DTYPES[info["dtype"]]
|
37 |
+
shape = info["shape"]
|
38 |
+
start, stop = info["data_offsets"]
|
39 |
+
return (
|
40 |
+
torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8)
|
41 |
+
.view(dtype=dtype)
|
42 |
+
.reshape(shape)
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
def safe_open(filename, framework="pt", device="cpu"):
|
47 |
+
if framework != "pt":
|
48 |
+
raise ValueError("`framework` must be 'pt'")
|
49 |
+
|
50 |
+
with open(filename, mode="r", encoding="utf8") as file_obj:
|
51 |
+
with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
|
52 |
+
header = m.read(8)
|
53 |
+
n = int.from_bytes(header, "little")
|
54 |
+
metadata_bytes = m.read(n)
|
55 |
+
metadata = json.loads(metadata_bytes)
|
56 |
+
|
57 |
+
size = os.stat(filename).st_size
|
58 |
+
storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
|
59 |
+
offset = n + 8
|
60 |
+
|
61 |
+
return SafetensorsWrapper(
|
62 |
+
metadata=metadata.get("__metadata__", {}),
|
63 |
+
tensors={
|
64 |
+
name: create_tensor(storage, info, offset).to(device)
|
65 |
+
for name, info in metadata.items()
|
66 |
+
if name != "__metadata__"
|
67 |
+
},
|
68 |
+
)
|
lora_diffusion/to_ckpt_v2.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05
|
2 |
+
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
|
3 |
+
# *Only* converts the UNet, VAE, and Text Encoder.
|
4 |
+
# Does not convert optimizer state or any other thing.
|
5 |
+
# Written by jachiam
|
6 |
+
import argparse
|
7 |
+
import os.path as osp
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
# =================#
|
13 |
+
# UNet Conversion #
|
14 |
+
# =================#
|
15 |
+
|
16 |
+
unet_conversion_map = [
|
17 |
+
# (stable-diffusion, HF Diffusers)
|
18 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
19 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
20 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
21 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
22 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
23 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
24 |
+
("out.0.weight", "conv_norm_out.weight"),
|
25 |
+
("out.0.bias", "conv_norm_out.bias"),
|
26 |
+
("out.2.weight", "conv_out.weight"),
|
27 |
+
("out.2.bias", "conv_out.bias"),
|
28 |
+
]
|
29 |
+
|
30 |
+
unet_conversion_map_resnet = [
|
31 |
+
# (stable-diffusion, HF Diffusers)
|
32 |
+
("in_layers.0", "norm1"),
|
33 |
+
("in_layers.2", "conv1"),
|
34 |
+
("out_layers.0", "norm2"),
|
35 |
+
("out_layers.3", "conv2"),
|
36 |
+
("emb_layers.1", "time_emb_proj"),
|
37 |
+
("skip_connection", "conv_shortcut"),
|
38 |
+
]
|
39 |
+
|
40 |
+
unet_conversion_map_layer = []
|
41 |
+
# hardcoded number of downblocks and resnets/attentions...
|
42 |
+
# would need smarter logic for other networks.
|
43 |
+
for i in range(4):
|
44 |
+
# loop over downblocks/upblocks
|
45 |
+
|
46 |
+
for j in range(2):
|
47 |
+
# loop over resnets/attentions for downblocks
|
48 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
49 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
50 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
51 |
+
|
52 |
+
if i < 3:
|
53 |
+
# no attention layers in down_blocks.3
|
54 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
55 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
56 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
57 |
+
|
58 |
+
for j in range(3):
|
59 |
+
# loop over resnets/attentions for upblocks
|
60 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
61 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
62 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
63 |
+
|
64 |
+
if i > 0:
|
65 |
+
# no attention layers in up_blocks.0
|
66 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
67 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
68 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
69 |
+
|
70 |
+
if i < 3:
|
71 |
+
# no downsample in down_blocks.3
|
72 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
73 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
74 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
75 |
+
|
76 |
+
# no upsample in up_blocks.3
|
77 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
78 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
79 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
80 |
+
|
81 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
82 |
+
sd_mid_atn_prefix = "middle_block.1."
|
83 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
84 |
+
|
85 |
+
for j in range(2):
|
86 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
87 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
88 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
89 |
+
|
90 |
+
|
91 |
+
def convert_unet_state_dict(unet_state_dict):
|
92 |
+
# buyer beware: this is a *brittle* function,
|
93 |
+
# and correct output requires that all of these pieces interact in
|
94 |
+
# the exact order in which I have arranged them.
|
95 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
96 |
+
for sd_name, hf_name in unet_conversion_map:
|
97 |
+
mapping[hf_name] = sd_name
|
98 |
+
for k, v in mapping.items():
|
99 |
+
if "resnets" in k:
|
100 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
101 |
+
v = v.replace(hf_part, sd_part)
|
102 |
+
mapping[k] = v
|
103 |
+
for k, v in mapping.items():
|
104 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
105 |
+
v = v.replace(hf_part, sd_part)
|
106 |
+
mapping[k] = v
|
107 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
108 |
+
return new_state_dict
|
109 |
+
|
110 |
+
|
111 |
+
# ================#
|
112 |
+
# VAE Conversion #
|
113 |
+
# ================#
|
114 |
+
|
115 |
+
vae_conversion_map = [
|
116 |
+
# (stable-diffusion, HF Diffusers)
|
117 |
+
("nin_shortcut", "conv_shortcut"),
|
118 |
+
("norm_out", "conv_norm_out"),
|
119 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
120 |
+
]
|
121 |
+
|
122 |
+
for i in range(4):
|
123 |
+
# down_blocks have two resnets
|
124 |
+
for j in range(2):
|
125 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
126 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
127 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
128 |
+
|
129 |
+
if i < 3:
|
130 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
131 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
132 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
133 |
+
|
134 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
135 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
136 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
137 |
+
|
138 |
+
# up_blocks have three resnets
|
139 |
+
# also, up blocks in hf are numbered in reverse from sd
|
140 |
+
for j in range(3):
|
141 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
142 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
143 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
144 |
+
|
145 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
146 |
+
for i in range(2):
|
147 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
148 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
149 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
150 |
+
|
151 |
+
|
152 |
+
vae_conversion_map_attn = [
|
153 |
+
# (stable-diffusion, HF Diffusers)
|
154 |
+
("norm.", "group_norm."),
|
155 |
+
("q.", "query."),
|
156 |
+
("k.", "key."),
|
157 |
+
("v.", "value."),
|
158 |
+
("proj_out.", "proj_attn."),
|
159 |
+
]
|
160 |
+
|
161 |
+
|
162 |
+
def reshape_weight_for_sd(w):
|
163 |
+
# convert HF linear weights to SD conv2d weights
|
164 |
+
return w.reshape(*w.shape, 1, 1)
|
165 |
+
|
166 |
+
|
167 |
+
def convert_vae_state_dict(vae_state_dict):
|
168 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
169 |
+
for k, v in mapping.items():
|
170 |
+
for sd_part, hf_part in vae_conversion_map:
|
171 |
+
v = v.replace(hf_part, sd_part)
|
172 |
+
mapping[k] = v
|
173 |
+
for k, v in mapping.items():
|
174 |
+
if "attentions" in k:
|
175 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
176 |
+
v = v.replace(hf_part, sd_part)
|
177 |
+
mapping[k] = v
|
178 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
179 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
180 |
+
for k, v in new_state_dict.items():
|
181 |
+
for weight_name in weights_to_convert:
|
182 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
183 |
+
print(f"Reshaping {k} for SD format")
|
184 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
185 |
+
return new_state_dict
|
186 |
+
|
187 |
+
|
188 |
+
# =========================#
|
189 |
+
# Text Encoder Conversion #
|
190 |
+
# =========================#
|
191 |
+
# pretty much a no-op
|
192 |
+
|
193 |
+
|
194 |
+
def convert_text_enc_state_dict(text_enc_dict):
|
195 |
+
return text_enc_dict
|
196 |
+
|
197 |
+
|
198 |
+
def convert_to_ckpt(model_path, checkpoint_path, as_half):
|
199 |
+
|
200 |
+
assert model_path is not None, "Must provide a model path!"
|
201 |
+
|
202 |
+
assert checkpoint_path is not None, "Must provide a checkpoint path!"
|
203 |
+
|
204 |
+
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
|
205 |
+
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
|
206 |
+
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
207 |
+
|
208 |
+
# Convert the UNet model
|
209 |
+
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
210 |
+
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
211 |
+
unet_state_dict = {
|
212 |
+
"model.diffusion_model." + k: v for k, v in unet_state_dict.items()
|
213 |
+
}
|
214 |
+
|
215 |
+
# Convert the VAE model
|
216 |
+
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
217 |
+
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
218 |
+
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
219 |
+
|
220 |
+
# Convert the text encoder model
|
221 |
+
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
222 |
+
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
223 |
+
text_enc_dict = {
|
224 |
+
"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
|
225 |
+
}
|
226 |
+
|
227 |
+
# Put together new checkpoint
|
228 |
+
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
229 |
+
if as_half:
|
230 |
+
state_dict = {k: v.half() for k, v in state_dict.items()}
|
231 |
+
state_dict = {"state_dict": state_dict}
|
232 |
+
torch.save(state_dict, checkpoint_path)
|
lora_diffusion/utils.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
from transformers import (
|
6 |
+
CLIPProcessor,
|
7 |
+
CLIPTextModelWithProjection,
|
8 |
+
CLIPTokenizer,
|
9 |
+
CLIPVisionModelWithProjection,
|
10 |
+
)
|
11 |
+
|
12 |
+
from diffusers import StableDiffusionPipeline
|
13 |
+
from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path
|
14 |
+
import os
|
15 |
+
import glob
|
16 |
+
import math
|
17 |
+
|
18 |
+
EXAMPLE_PROMPTS = [
|
19 |
+
"<obj> swimming in a pool",
|
20 |
+
"<obj> at a beach with a view of seashore",
|
21 |
+
"<obj> in times square",
|
22 |
+
"<obj> wearing sunglasses",
|
23 |
+
"<obj> in a construction outfit",
|
24 |
+
"<obj> playing with a ball",
|
25 |
+
"<obj> wearing headphones",
|
26 |
+
"<obj> oil painting ghibli inspired",
|
27 |
+
"<obj> working on the laptop",
|
28 |
+
"<obj> with mountains and sunset in background",
|
29 |
+
"Painting of <obj> at a beach by artist claude monet",
|
30 |
+
"<obj> digital painting 3d render geometric style",
|
31 |
+
"A screaming <obj>",
|
32 |
+
"A depressed <obj>",
|
33 |
+
"A sleeping <obj>",
|
34 |
+
"A sad <obj>",
|
35 |
+
"A joyous <obj>",
|
36 |
+
"A frowning <obj>",
|
37 |
+
"A sculpture of <obj>",
|
38 |
+
"<obj> near a pool",
|
39 |
+
"<obj> at a beach with a view of seashore",
|
40 |
+
"<obj> in a garden",
|
41 |
+
"<obj> in grand canyon",
|
42 |
+
"<obj> floating in ocean",
|
43 |
+
"<obj> and an armchair",
|
44 |
+
"A maple tree on the side of <obj>",
|
45 |
+
"<obj> and an orange sofa",
|
46 |
+
"<obj> with chocolate cake on it",
|
47 |
+
"<obj> with a vase of rose flowers on it",
|
48 |
+
"A digital illustration of <obj>",
|
49 |
+
"Georgia O'Keeffe style <obj> painting",
|
50 |
+
"A watercolor painting of <obj> on a beach",
|
51 |
+
]
|
52 |
+
|
53 |
+
|
54 |
+
def image_grid(_imgs, rows=None, cols=None):
|
55 |
+
|
56 |
+
if rows is None and cols is None:
|
57 |
+
rows = cols = math.ceil(len(_imgs) ** 0.5)
|
58 |
+
|
59 |
+
if rows is None:
|
60 |
+
rows = math.ceil(len(_imgs) / cols)
|
61 |
+
if cols is None:
|
62 |
+
cols = math.ceil(len(_imgs) / rows)
|
63 |
+
|
64 |
+
w, h = _imgs[0].size
|
65 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
66 |
+
grid_w, grid_h = grid.size
|
67 |
+
|
68 |
+
for i, img in enumerate(_imgs):
|
69 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
70 |
+
return grid
|
71 |
+
|
72 |
+
|
73 |
+
def text_img_alignment(img_embeds, text_embeds, target_img_embeds):
|
74 |
+
# evaluation inspired from textual inversion paper
|
75 |
+
# https://arxiv.org/abs/2208.01618
|
76 |
+
|
77 |
+
# text alignment
|
78 |
+
assert img_embeds.shape[0] == text_embeds.shape[0]
|
79 |
+
text_img_sim = (img_embeds * text_embeds).sum(dim=-1) / (
|
80 |
+
img_embeds.norm(dim=-1) * text_embeds.norm(dim=-1)
|
81 |
+
)
|
82 |
+
|
83 |
+
# image alignment
|
84 |
+
img_embed_normalized = img_embeds / img_embeds.norm(dim=-1, keepdim=True)
|
85 |
+
|
86 |
+
avg_target_img_embed = (
|
87 |
+
(target_img_embeds / target_img_embeds.norm(dim=-1, keepdim=True))
|
88 |
+
.mean(dim=0)
|
89 |
+
.unsqueeze(0)
|
90 |
+
.repeat(img_embeds.shape[0], 1)
|
91 |
+
)
|
92 |
+
|
93 |
+
img_img_sim = (img_embed_normalized * avg_target_img_embed).sum(dim=-1)
|
94 |
+
|
95 |
+
return {
|
96 |
+
"text_alignment_avg": text_img_sim.mean().item(),
|
97 |
+
"image_alignment_avg": img_img_sim.mean().item(),
|
98 |
+
"text_alignment_all": text_img_sim.tolist(),
|
99 |
+
"image_alignment_all": img_img_sim.tolist(),
|
100 |
+
}
|
101 |
+
|
102 |
+
|
103 |
+
def prepare_clip_model_sets(eval_clip_id: str = "openai/clip-vit-large-patch14"):
|
104 |
+
text_model = CLIPTextModelWithProjection.from_pretrained(eval_clip_id)
|
105 |
+
tokenizer = CLIPTokenizer.from_pretrained(eval_clip_id)
|
106 |
+
vis_model = CLIPVisionModelWithProjection.from_pretrained(eval_clip_id)
|
107 |
+
processor = CLIPProcessor.from_pretrained(eval_clip_id)
|
108 |
+
|
109 |
+
return text_model, tokenizer, vis_model, processor
|
110 |
+
|
111 |
+
|
112 |
+
def evaluate_pipe(
|
113 |
+
pipe,
|
114 |
+
target_images: List[Image.Image],
|
115 |
+
class_token: str = "",
|
116 |
+
learnt_token: str = "",
|
117 |
+
guidance_scale: float = 5.0,
|
118 |
+
seed=0,
|
119 |
+
clip_model_sets=None,
|
120 |
+
eval_clip_id: str = "openai/clip-vit-large-patch14",
|
121 |
+
n_test: int = 10,
|
122 |
+
n_step: int = 50,
|
123 |
+
):
|
124 |
+
|
125 |
+
if clip_model_sets is not None:
|
126 |
+
text_model, tokenizer, vis_model, processor = clip_model_sets
|
127 |
+
else:
|
128 |
+
text_model, tokenizer, vis_model, processor = prepare_clip_model_sets(
|
129 |
+
eval_clip_id
|
130 |
+
)
|
131 |
+
|
132 |
+
images = []
|
133 |
+
img_embeds = []
|
134 |
+
text_embeds = []
|
135 |
+
for prompt in EXAMPLE_PROMPTS[:n_test]:
|
136 |
+
prompt = prompt.replace("<obj>", learnt_token)
|
137 |
+
torch.manual_seed(seed)
|
138 |
+
with torch.autocast("cuda"):
|
139 |
+
img = pipe(
|
140 |
+
prompt, num_inference_steps=n_step, guidance_scale=guidance_scale
|
141 |
+
).images[0]
|
142 |
+
images.append(img)
|
143 |
+
|
144 |
+
# image
|
145 |
+
inputs = processor(images=img, return_tensors="pt")
|
146 |
+
img_embed = vis_model(**inputs).image_embeds
|
147 |
+
img_embeds.append(img_embed)
|
148 |
+
|
149 |
+
prompt = prompt.replace(learnt_token, class_token)
|
150 |
+
# prompts
|
151 |
+
inputs = tokenizer([prompt], padding=True, return_tensors="pt")
|
152 |
+
outputs = text_model(**inputs)
|
153 |
+
text_embed = outputs.text_embeds
|
154 |
+
text_embeds.append(text_embed)
|
155 |
+
|
156 |
+
# target images
|
157 |
+
inputs = processor(images=target_images, return_tensors="pt")
|
158 |
+
target_img_embeds = vis_model(**inputs).image_embeds
|
159 |
+
|
160 |
+
img_embeds = torch.cat(img_embeds, dim=0)
|
161 |
+
text_embeds = torch.cat(text_embeds, dim=0)
|
162 |
+
|
163 |
+
return text_img_alignment(img_embeds, text_embeds, target_img_embeds)
|
164 |
+
|
165 |
+
|
166 |
+
def visualize_progress(
|
167 |
+
path_alls: Union[str, List[str]],
|
168 |
+
prompt: str,
|
169 |
+
model_id: str = "runwayml/stable-diffusion-v1-5",
|
170 |
+
device="cuda:0",
|
171 |
+
patch_unet=True,
|
172 |
+
patch_text=True,
|
173 |
+
patch_ti=True,
|
174 |
+
unet_scale=1.0,
|
175 |
+
text_sclae=1.0,
|
176 |
+
num_inference_steps=50,
|
177 |
+
guidance_scale=5.0,
|
178 |
+
offset: int = 0,
|
179 |
+
limit: int = 10,
|
180 |
+
seed: int = 0,
|
181 |
+
):
|
182 |
+
|
183 |
+
imgs = []
|
184 |
+
if isinstance(path_alls, str):
|
185 |
+
alls = list(set(glob.glob(path_alls)))
|
186 |
+
|
187 |
+
alls.sort(key=os.path.getmtime)
|
188 |
+
else:
|
189 |
+
alls = path_alls
|
190 |
+
|
191 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
192 |
+
model_id, torch_dtype=torch.float16
|
193 |
+
).to(device)
|
194 |
+
|
195 |
+
print(f"Found {len(alls)} checkpoints")
|
196 |
+
for path in alls[offset:limit]:
|
197 |
+
print(path)
|
198 |
+
|
199 |
+
patch_pipe(
|
200 |
+
pipe, path, patch_unet=patch_unet, patch_text=patch_text, patch_ti=patch_ti
|
201 |
+
)
|
202 |
+
|
203 |
+
tune_lora_scale(pipe.unet, unet_scale)
|
204 |
+
tune_lora_scale(pipe.text_encoder, text_sclae)
|
205 |
+
|
206 |
+
torch.manual_seed(seed)
|
207 |
+
image = pipe(
|
208 |
+
prompt,
|
209 |
+
num_inference_steps=num_inference_steps,
|
210 |
+
guidance_scale=guidance_scale,
|
211 |
+
).images[0]
|
212 |
+
imgs.append(image)
|
213 |
+
|
214 |
+
return imgs
|
lora_diffusion/xformers_utils.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers.models.attention import BasicTransformerBlock
|
5 |
+
from diffusers.utils.import_utils import is_xformers_available
|
6 |
+
|
7 |
+
from .lora import LoraInjectedLinear
|
8 |
+
|
9 |
+
if is_xformers_available():
|
10 |
+
import xformers
|
11 |
+
import xformers.ops
|
12 |
+
else:
|
13 |
+
xformers = None
|
14 |
+
|
15 |
+
|
16 |
+
@functools.cache
|
17 |
+
def test_xformers_backwards(size):
|
18 |
+
@torch.enable_grad()
|
19 |
+
def _grad(size):
|
20 |
+
q = torch.randn((1, 4, size), device="cuda")
|
21 |
+
k = torch.randn((1, 4, size), device="cuda")
|
22 |
+
v = torch.randn((1, 4, size), device="cuda")
|
23 |
+
|
24 |
+
q = q.detach().requires_grad_()
|
25 |
+
k = k.detach().requires_grad_()
|
26 |
+
v = v.detach().requires_grad_()
|
27 |
+
|
28 |
+
out = xformers.ops.memory_efficient_attention(q, k, v)
|
29 |
+
loss = out.sum(2).mean(0).sum()
|
30 |
+
|
31 |
+
return torch.autograd.grad(loss, v)
|
32 |
+
|
33 |
+
try:
|
34 |
+
_grad(size)
|
35 |
+
print(size, "pass")
|
36 |
+
return True
|
37 |
+
except Exception as e:
|
38 |
+
print(size, "fail")
|
39 |
+
return False
|
40 |
+
|
41 |
+
|
42 |
+
def set_use_memory_efficient_attention_xformers(
|
43 |
+
module: torch.nn.Module, valid: bool
|
44 |
+
) -> None:
|
45 |
+
def fn_test_dim_head(module: torch.nn.Module):
|
46 |
+
if isinstance(module, BasicTransformerBlock):
|
47 |
+
# dim_head isn't stored anywhere, so back-calculate
|
48 |
+
source = module.attn1.to_v
|
49 |
+
if isinstance(source, LoraInjectedLinear):
|
50 |
+
source = source.linear
|
51 |
+
|
52 |
+
dim_head = source.out_features // module.attn1.heads
|
53 |
+
|
54 |
+
result = test_xformers_backwards(dim_head)
|
55 |
+
|
56 |
+
# If dim_head > dim_head_max, turn xformers off
|
57 |
+
if not result:
|
58 |
+
module.set_use_memory_efficient_attention_xformers(False)
|
59 |
+
|
60 |
+
for child in module.children():
|
61 |
+
fn_test_dim_head(child)
|
62 |
+
|
63 |
+
if not is_xformers_available() and valid:
|
64 |
+
print("XFormers is not available. Skipping.")
|
65 |
+
return
|
66 |
+
|
67 |
+
module.set_use_memory_efficient_attention_xformers(valid)
|
68 |
+
|
69 |
+
if valid:
|
70 |
+
fn_test_dim_head(module)
|
scene/__init__.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import os
|
13 |
+
import random
|
14 |
+
import json
|
15 |
+
from utils.system_utils import searchForMaxIteration
|
16 |
+
from scene.dataset_readers import sceneLoadTypeCallbacks,GenerateRandomCameras,GeneratePurnCameras,GenerateCircleCameras
|
17 |
+
from scene.gaussian_model import GaussianModel
|
18 |
+
from arguments import ModelParams, GenerateCamParams
|
19 |
+
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON, cameraList_from_RcamInfos
|
20 |
+
|
21 |
+
class Scene:
|
22 |
+
|
23 |
+
gaussians : GaussianModel
|
24 |
+
|
25 |
+
def __init__(self, args : ModelParams, pose_args : GenerateCamParams, gaussians : GaussianModel, load_iteration=None, shuffle=False, resolution_scales=[1.0]):
|
26 |
+
"""b
|
27 |
+
:param path: Path to colmap scene main folder.
|
28 |
+
"""
|
29 |
+
self.model_path = args._model_path
|
30 |
+
self.pretrained_model_path = args.pretrained_model_path
|
31 |
+
self.loaded_iter = None
|
32 |
+
self.gaussians = gaussians
|
33 |
+
self.resolution_scales = resolution_scales
|
34 |
+
self.pose_args = pose_args
|
35 |
+
self.args = args
|
36 |
+
if load_iteration:
|
37 |
+
if load_iteration == -1:
|
38 |
+
self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
|
39 |
+
else:
|
40 |
+
self.loaded_iter = load_iteration
|
41 |
+
print("Loading trained model at iteration {}".format(self.loaded_iter))
|
42 |
+
|
43 |
+
self.test_cameras = {}
|
44 |
+
scene_info = sceneLoadTypeCallbacks["RandomCam"](self.model_path ,pose_args)
|
45 |
+
|
46 |
+
json_cams = []
|
47 |
+
camlist = []
|
48 |
+
if scene_info.test_cameras:
|
49 |
+
camlist.extend(scene_info.test_cameras)
|
50 |
+
for id, cam in enumerate(camlist):
|
51 |
+
json_cams.append(camera_to_JSON(id, cam))
|
52 |
+
with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
|
53 |
+
json.dump(json_cams, file)
|
54 |
+
|
55 |
+
if shuffle:
|
56 |
+
random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
|
57 |
+
self.cameras_extent = pose_args.default_radius # scene_info.nerf_normalization["radius"]
|
58 |
+
for resolution_scale in resolution_scales:
|
59 |
+
self.test_cameras[resolution_scale] = cameraList_from_RcamInfos(scene_info.test_cameras, resolution_scale, self.pose_args)
|
60 |
+
if self.loaded_iter:
|
61 |
+
self.gaussians.load_ply(os.path.join(self.model_path,
|
62 |
+
"point_cloud",
|
63 |
+
"iteration_" + str(self.loaded_iter),
|
64 |
+
"point_cloud.ply"))
|
65 |
+
elif self.pretrained_model_path is not None:
|
66 |
+
self.gaussians.load_ply(self.pretrained_model_path)
|
67 |
+
else:
|
68 |
+
self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
|
69 |
+
|
70 |
+
def save(self, iteration):
|
71 |
+
point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
|
72 |
+
self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
|
73 |
+
|
74 |
+
def getRandTrainCameras(self, scale=1.0):
|
75 |
+
rand_train_cameras = GenerateRandomCameras(self.pose_args, self.args.batch, SSAA=True)
|
76 |
+
train_cameras = {}
|
77 |
+
for resolution_scale in self.resolution_scales:
|
78 |
+
train_cameras[resolution_scale] = cameraList_from_RcamInfos(rand_train_cameras, resolution_scale, self.pose_args, SSAA=True)
|
79 |
+
return train_cameras[scale]
|
80 |
+
|
81 |
+
|
82 |
+
def getPurnTrainCameras(self, scale=1.0):
|
83 |
+
rand_train_cameras = GeneratePurnCameras(self.pose_args)
|
84 |
+
train_cameras = {}
|
85 |
+
for resolution_scale in self.resolution_scales:
|
86 |
+
train_cameras[resolution_scale] = cameraList_from_RcamInfos(rand_train_cameras, resolution_scale, self.pose_args)
|
87 |
+
return train_cameras[scale]
|
88 |
+
|
89 |
+
|
90 |
+
def getTestCameras(self, scale=1.0):
|
91 |
+
return self.test_cameras[scale]
|
92 |
+
|
93 |
+
def getCircleVideoCameras(self, scale=1.0,batch_size=120, render45 = True):
|
94 |
+
video_circle_cameras = GenerateCircleCameras(self.pose_args,batch_size,render45)
|
95 |
+
video_cameras = {}
|
96 |
+
for resolution_scale in self.resolution_scales:
|
97 |
+
video_cameras[resolution_scale] = cameraList_from_RcamInfos(video_circle_cameras, resolution_scale, self.pose_args)
|
98 |
+
return video_cameras[scale]
|
scene/cameras.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
import numpy as np
|
15 |
+
from utils.graphics_utils import getWorld2View2, getProjectionMatrix, fov2focal
|
16 |
+
|
17 |
+
def get_rays_torch(focal, c2w, H=64,W=64):
|
18 |
+
"""Computes rays using a General Pinhole Camera Model
|
19 |
+
Assumes self.h, self.w, self.focal, and self.cam_to_world exist
|
20 |
+
"""
|
21 |
+
x, y = torch.meshgrid(
|
22 |
+
torch.arange(W), # X-Axis (columns)
|
23 |
+
torch.arange(H), # Y-Axis (rows)
|
24 |
+
indexing='xy')
|
25 |
+
camera_directions = torch.stack(
|
26 |
+
[(x - W * 0.5 + 0.5) / focal,
|
27 |
+
-(y - H * 0.5 + 0.5) / focal,
|
28 |
+
-torch.ones_like(x)],
|
29 |
+
dim=-1).to(c2w)
|
30 |
+
|
31 |
+
# Rotate ray directions from camera frame to the world frame
|
32 |
+
directions = ((camera_directions[ None,..., None, :] * c2w[None,None, None, :3, :3]).sum(axis=-1)) # Translate camera frame's origin to the world frame
|
33 |
+
origins = torch.broadcast_to(c2w[ None,None, None, :3, -1], directions.shape)
|
34 |
+
viewdirs = directions / torch.linalg.norm(directions, axis=-1, keepdims=True)
|
35 |
+
|
36 |
+
return torch.cat((origins,viewdirs),dim=-1)
|
37 |
+
|
38 |
+
|
39 |
+
class Camera(nn.Module):
|
40 |
+
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
|
41 |
+
image_name, uid,
|
42 |
+
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
|
43 |
+
):
|
44 |
+
super(Camera, self).__init__()
|
45 |
+
|
46 |
+
self.uid = uid
|
47 |
+
self.colmap_id = colmap_id
|
48 |
+
self.R = R
|
49 |
+
self.T = T
|
50 |
+
self.FoVx = FoVx
|
51 |
+
self.FoVy = FoVy
|
52 |
+
self.image_name = image_name
|
53 |
+
|
54 |
+
try:
|
55 |
+
self.data_device = torch.device(data_device)
|
56 |
+
except Exception as e:
|
57 |
+
print(e)
|
58 |
+
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
|
59 |
+
self.data_device = torch.device("cuda")
|
60 |
+
|
61 |
+
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
|
62 |
+
self.image_width = self.original_image.shape[2]
|
63 |
+
self.image_height = self.original_image.shape[1]
|
64 |
+
|
65 |
+
if gt_alpha_mask is not None:
|
66 |
+
self.original_image *= gt_alpha_mask.to(self.data_device)
|
67 |
+
else:
|
68 |
+
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
|
69 |
+
|
70 |
+
self.zfar = 100.0
|
71 |
+
self.znear = 0.01
|
72 |
+
|
73 |
+
self.trans = trans
|
74 |
+
self.scale = scale
|
75 |
+
|
76 |
+
self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
|
77 |
+
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
|
78 |
+
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
|
79 |
+
self.camera_center = self.world_view_transform.inverse()[3, :3]
|
80 |
+
|
81 |
+
|
82 |
+
class RCamera(nn.Module):
|
83 |
+
def __init__(self, colmap_id, R, T, FoVx, FoVy, uid, delta_polar, delta_azimuth, delta_radius, opt,
|
84 |
+
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", SSAA=False
|
85 |
+
):
|
86 |
+
super(RCamera, self).__init__()
|
87 |
+
|
88 |
+
self.uid = uid
|
89 |
+
self.colmap_id = colmap_id
|
90 |
+
self.R = R
|
91 |
+
self.T = T
|
92 |
+
self.FoVx = FoVx
|
93 |
+
self.FoVy = FoVy
|
94 |
+
self.delta_polar = delta_polar
|
95 |
+
self.delta_azimuth = delta_azimuth
|
96 |
+
self.delta_radius = delta_radius
|
97 |
+
try:
|
98 |
+
self.data_device = torch.device(data_device)
|
99 |
+
except Exception as e:
|
100 |
+
print(e)
|
101 |
+
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
|
102 |
+
self.data_device = torch.device("cuda")
|
103 |
+
|
104 |
+
self.zfar = 100.0
|
105 |
+
self.znear = 0.01
|
106 |
+
|
107 |
+
if SSAA:
|
108 |
+
ssaa = opt.SSAA
|
109 |
+
else:
|
110 |
+
ssaa = 1
|
111 |
+
|
112 |
+
self.image_width = opt.image_w * ssaa
|
113 |
+
self.image_height = opt.image_h * ssaa
|
114 |
+
|
115 |
+
self.trans = trans
|
116 |
+
self.scale = scale
|
117 |
+
|
118 |
+
RT = torch.tensor(getWorld2View2(R, T, trans, scale))
|
119 |
+
self.world_view_transform = RT.transpose(0, 1).cuda()
|
120 |
+
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
|
121 |
+
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
|
122 |
+
self.camera_center = self.world_view_transform.inverse()[3, :3]
|
123 |
+
# self.rays = get_rays_torch(fov2focal(FoVx, 64), RT).cuda()
|
124 |
+
self.rays = get_rays_torch(fov2focal(FoVx, self.image_width//8), RT, H=self.image_height//8, W=self.image_width//8).cuda()
|
125 |
+
|
126 |
+
class MiniCam:
|
127 |
+
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
|
128 |
+
self.image_width = width
|
129 |
+
self.image_height = height
|
130 |
+
self.FoVy = fovy
|
131 |
+
self.FoVx = fovx
|
132 |
+
self.znear = znear
|
133 |
+
self.zfar = zfar
|
134 |
+
self.world_view_transform = world_view_transform
|
135 |
+
self.full_proj_transform = full_proj_transform
|
136 |
+
view_inv = torch.inverse(self.world_view_transform)
|
137 |
+
self.camera_center = view_inv[3][:3]
|
138 |
+
|
scene/dataset_readers.py
ADDED
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import os
|
13 |
+
import sys
|
14 |
+
import torch
|
15 |
+
import random
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from PIL import Image
|
18 |
+
from typing import NamedTuple
|
19 |
+
from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
|
20 |
+
import numpy as np
|
21 |
+
import json
|
22 |
+
from pathlib import Path
|
23 |
+
from utils.pointe_utils import init_from_pointe
|
24 |
+
from plyfile import PlyData, PlyElement
|
25 |
+
from utils.sh_utils import SH2RGB
|
26 |
+
from utils.general_utils import inverse_sigmoid_np
|
27 |
+
from scene.gaussian_model import BasicPointCloud
|
28 |
+
|
29 |
+
|
30 |
+
class RandCameraInfo(NamedTuple):
|
31 |
+
uid: int
|
32 |
+
R: np.array
|
33 |
+
T: np.array
|
34 |
+
FovY: np.array
|
35 |
+
FovX: np.array
|
36 |
+
width: int
|
37 |
+
height: int
|
38 |
+
delta_polar : np.array
|
39 |
+
delta_azimuth : np.array
|
40 |
+
delta_radius : np.array
|
41 |
+
|
42 |
+
|
43 |
+
class SceneInfo(NamedTuple):
|
44 |
+
point_cloud: BasicPointCloud
|
45 |
+
train_cameras: list
|
46 |
+
test_cameras: list
|
47 |
+
nerf_normalization: dict
|
48 |
+
ply_path: str
|
49 |
+
|
50 |
+
|
51 |
+
class RSceneInfo(NamedTuple):
|
52 |
+
point_cloud: BasicPointCloud
|
53 |
+
test_cameras: list
|
54 |
+
ply_path: str
|
55 |
+
|
56 |
+
# def getNerfppNorm(cam_info):
|
57 |
+
# def get_center_and_diag(cam_centers):
|
58 |
+
# cam_centers = np.hstack(cam_centers)
|
59 |
+
# avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
|
60 |
+
# center = avg_cam_center
|
61 |
+
# dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
|
62 |
+
# diagonal = np.max(dist)
|
63 |
+
# return center.flatten(), diagonal
|
64 |
+
|
65 |
+
# cam_centers = []
|
66 |
+
|
67 |
+
# for cam in cam_info:
|
68 |
+
# W2C = getWorld2View2(cam.R, cam.T)
|
69 |
+
# C2W = np.linalg.inv(W2C)
|
70 |
+
# cam_centers.append(C2W[:3, 3:4])
|
71 |
+
|
72 |
+
# center, diagonal = get_center_and_diag(cam_centers)
|
73 |
+
# radius = diagonal * 1.1
|
74 |
+
|
75 |
+
# translate = -center
|
76 |
+
|
77 |
+
# return {"translate": translate, "radius": radius}
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
def fetchPly(path):
|
82 |
+
plydata = PlyData.read(path)
|
83 |
+
vertices = plydata['vertex']
|
84 |
+
positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
|
85 |
+
colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
|
86 |
+
normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
|
87 |
+
return BasicPointCloud(points=positions, colors=colors, normals=normals)
|
88 |
+
|
89 |
+
def storePly(path, xyz, rgb):
|
90 |
+
# Define the dtype for the structured array
|
91 |
+
dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
|
92 |
+
('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
|
93 |
+
('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
|
94 |
+
|
95 |
+
normals = np.zeros_like(xyz)
|
96 |
+
|
97 |
+
elements = np.empty(xyz.shape[0], dtype=dtype)
|
98 |
+
attributes = np.concatenate((xyz, normals, rgb), axis=1)
|
99 |
+
elements[:] = list(map(tuple, attributes))
|
100 |
+
|
101 |
+
# Create the PlyData object and write to file
|
102 |
+
vertex_element = PlyElement.describe(elements, 'vertex')
|
103 |
+
ply_data = PlyData([vertex_element])
|
104 |
+
ply_data.write(path)
|
105 |
+
|
106 |
+
#only test_camera
|
107 |
+
def readCircleCamInfo(path,opt):
|
108 |
+
print("Reading Test Transforms")
|
109 |
+
test_cam_infos = GenerateCircleCameras(opt,render45 = opt.render_45)
|
110 |
+
ply_path = os.path.join(path, "init_points3d.ply")
|
111 |
+
if not os.path.exists(ply_path):
|
112 |
+
# Since this data set has no colmap data, we start with random points
|
113 |
+
num_pts = opt.init_num_pts
|
114 |
+
if opt.init_shape == 'sphere':
|
115 |
+
thetas = np.random.rand(num_pts)*np.pi
|
116 |
+
phis = np.random.rand(num_pts)*2*np.pi
|
117 |
+
radius = np.random.rand(num_pts)*0.5
|
118 |
+
# We create random points inside the bounds of sphere
|
119 |
+
xyz = np.stack([
|
120 |
+
radius * np.sin(thetas) * np.sin(phis),
|
121 |
+
radius * np.sin(thetas) * np.cos(phis),
|
122 |
+
radius * np.cos(thetas),
|
123 |
+
], axis=-1) # [B, 3]
|
124 |
+
elif opt.init_shape == 'box':
|
125 |
+
xyz = np.random.random((num_pts, 3)) * 1.0 - 0.5
|
126 |
+
elif opt.init_shape == 'rectangle_x':
|
127 |
+
xyz = np.random.random((num_pts, 3))
|
128 |
+
xyz[:, 0] = xyz[:, 0] * 0.6 - 0.3
|
129 |
+
xyz[:, 1] = xyz[:, 1] * 1.2 - 0.6
|
130 |
+
xyz[:, 2] = xyz[:, 2] * 0.5 - 0.25
|
131 |
+
elif opt.init_shape == 'rectangle_z':
|
132 |
+
xyz = np.random.random((num_pts, 3))
|
133 |
+
xyz[:, 0] = xyz[:, 0] * 0.8 - 0.4
|
134 |
+
xyz[:, 1] = xyz[:, 1] * 0.6 - 0.3
|
135 |
+
xyz[:, 2] = xyz[:, 2] * 1.2 - 0.6
|
136 |
+
elif opt.init_shape == 'pointe':
|
137 |
+
num_pts = int(num_pts/5000)
|
138 |
+
xyz,rgb = init_from_pointe(opt.init_prompt)
|
139 |
+
xyz[:,1] = - xyz[:,1]
|
140 |
+
xyz[:,2] = xyz[:,2] + 0.15
|
141 |
+
thetas = np.random.rand(num_pts)*np.pi
|
142 |
+
phis = np.random.rand(num_pts)*2*np.pi
|
143 |
+
radius = np.random.rand(num_pts)*0.05
|
144 |
+
# We create random points inside the bounds of sphere
|
145 |
+
xyz_ball = np.stack([
|
146 |
+
radius * np.sin(thetas) * np.sin(phis),
|
147 |
+
radius * np.sin(thetas) * np.cos(phis),
|
148 |
+
radius * np.cos(thetas),
|
149 |
+
], axis=-1) # [B, 3]expend_dims
|
150 |
+
rgb_ball = np.random.random((4096, num_pts, 3))*0.0001
|
151 |
+
rgb = (np.expand_dims(rgb,axis=1)+rgb_ball).reshape(-1,3)
|
152 |
+
xyz = (np.expand_dims(xyz,axis=1)+np.expand_dims(xyz_ball,axis=0)).reshape(-1,3)
|
153 |
+
xyz = xyz * 1.
|
154 |
+
num_pts = xyz.shape[0]
|
155 |
+
elif opt.init_shape == 'scene':
|
156 |
+
thetas = np.random.rand(num_pts)*np.pi
|
157 |
+
phis = np.random.rand(num_pts)*2*np.pi
|
158 |
+
radius = np.random.rand(num_pts) + opt.radius_range[-1]*3
|
159 |
+
# We create random points inside the bounds of sphere
|
160 |
+
xyz = np.stack([
|
161 |
+
radius * np.sin(thetas) * np.sin(phis),
|
162 |
+
radius * np.sin(thetas) * np.cos(phis),
|
163 |
+
radius * np.cos(thetas),
|
164 |
+
], axis=-1) # [B, 3]
|
165 |
+
else:
|
166 |
+
raise NotImplementedError()
|
167 |
+
print(f"Generating random point cloud ({num_pts})...")
|
168 |
+
|
169 |
+
shs = np.random.random((num_pts, 3)) / 255.0
|
170 |
+
|
171 |
+
if opt.init_shape == 'pointe' and opt.use_pointe_rgb:
|
172 |
+
pcd = BasicPointCloud(points=xyz, colors=rgb, normals=np.zeros((num_pts, 3)))
|
173 |
+
storePly(ply_path, xyz, rgb * 255)
|
174 |
+
else:
|
175 |
+
pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
|
176 |
+
storePly(ply_path, xyz, SH2RGB(shs) * 255)
|
177 |
+
try:
|
178 |
+
pcd = fetchPly(ply_path)
|
179 |
+
except:
|
180 |
+
pcd = None
|
181 |
+
|
182 |
+
scene_info = RSceneInfo(point_cloud=pcd,
|
183 |
+
test_cameras=test_cam_infos,
|
184 |
+
ply_path=ply_path)
|
185 |
+
return scene_info
|
186 |
+
#borrow from https://github.com/ashawkey/stable-dreamfusion
|
187 |
+
|
188 |
+
def safe_normalize(x, eps=1e-20):
|
189 |
+
return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
|
190 |
+
|
191 |
+
# def circle_poses(radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), angle_overhead=30, angle_front=60):
|
192 |
+
|
193 |
+
# theta = theta / 180 * np.pi
|
194 |
+
# phi = phi / 180 * np.pi
|
195 |
+
# angle_overhead = angle_overhead / 180 * np.pi
|
196 |
+
# angle_front = angle_front / 180 * np.pi
|
197 |
+
|
198 |
+
# centers = torch.stack([
|
199 |
+
# radius * torch.sin(theta) * torch.sin(phi),
|
200 |
+
# radius * torch.cos(theta),
|
201 |
+
# radius * torch.sin(theta) * torch.cos(phi),
|
202 |
+
# ], dim=-1) # [B, 3]
|
203 |
+
|
204 |
+
# # lookat
|
205 |
+
# forward_vector = safe_normalize(centers)
|
206 |
+
# up_vector = torch.FloatTensor([0, 1, 0]).unsqueeze(0).repeat(len(centers), 1)
|
207 |
+
# right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
|
208 |
+
# up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
|
209 |
+
|
210 |
+
# poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(len(centers), 1, 1)
|
211 |
+
# poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
|
212 |
+
# poses[:, :3, 3] = centers
|
213 |
+
|
214 |
+
# return poses.numpy()
|
215 |
+
|
216 |
+
def circle_poses(radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), angle_overhead=30, angle_front=60):
|
217 |
+
|
218 |
+
theta = theta / 180 * np.pi
|
219 |
+
phi = phi / 180 * np.pi
|
220 |
+
angle_overhead = angle_overhead / 180 * np.pi
|
221 |
+
angle_front = angle_front / 180 * np.pi
|
222 |
+
|
223 |
+
centers = torch.stack([
|
224 |
+
radius * torch.sin(theta) * torch.sin(phi),
|
225 |
+
radius * torch.sin(theta) * torch.cos(phi),
|
226 |
+
radius * torch.cos(theta),
|
227 |
+
], dim=-1) # [B, 3]
|
228 |
+
|
229 |
+
# lookat
|
230 |
+
forward_vector = safe_normalize(centers)
|
231 |
+
up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(len(centers), 1)
|
232 |
+
right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
|
233 |
+
up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
|
234 |
+
|
235 |
+
poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(len(centers), 1, 1)
|
236 |
+
poses[:, :3, :3] = torch.stack((-right_vector, up_vector, forward_vector), dim=-1)
|
237 |
+
poses[:, :3, 3] = centers
|
238 |
+
|
239 |
+
return poses.numpy()
|
240 |
+
|
241 |
+
def gen_random_pos(size, param_range, gamma=1):
|
242 |
+
lower, higher = param_range[0], param_range[1]
|
243 |
+
|
244 |
+
mid = lower + (higher - lower) * 0.5
|
245 |
+
radius = (higher - lower) * 0.5
|
246 |
+
|
247 |
+
rand_ = torch.rand(size) # 0, 1
|
248 |
+
sign = torch.where(torch.rand(size) > 0.5, torch.ones(size) * -1., torch.ones(size))
|
249 |
+
rand_ = sign * (rand_ ** gamma)
|
250 |
+
|
251 |
+
return (rand_ * radius) + mid
|
252 |
+
|
253 |
+
|
254 |
+
def rand_poses(size, opt, radius_range=[1, 1.5], theta_range=[0, 120], phi_range=[0, 360], angle_overhead=30, angle_front=60, uniform_sphere_rate=0.5, rand_cam_gamma=1):
|
255 |
+
''' generate random poses from an orbit camera
|
256 |
+
Args:
|
257 |
+
size: batch size of generated poses.
|
258 |
+
device: where to allocate the output.
|
259 |
+
radius: camera radius
|
260 |
+
theta_range: [min, max], should be in [0, pi]
|
261 |
+
phi_range: [min, max], should be in [0, 2 * pi]
|
262 |
+
Return:
|
263 |
+
poses: [size, 4, 4]
|
264 |
+
'''
|
265 |
+
|
266 |
+
theta_range = np.array(theta_range) / 180 * np.pi
|
267 |
+
phi_range = np.array(phi_range) / 180 * np.pi
|
268 |
+
angle_overhead = angle_overhead / 180 * np.pi
|
269 |
+
angle_front = angle_front / 180 * np.pi
|
270 |
+
|
271 |
+
# radius = torch.rand(size) * (radius_range[1] - radius_range[0]) + radius_range[0]
|
272 |
+
radius = gen_random_pos(size, radius_range)
|
273 |
+
|
274 |
+
if random.random() < uniform_sphere_rate:
|
275 |
+
unit_centers = F.normalize(
|
276 |
+
torch.stack([
|
277 |
+
torch.randn(size),
|
278 |
+
torch.abs(torch.randn(size)),
|
279 |
+
torch.randn(size),
|
280 |
+
], dim=-1), p=2, dim=1
|
281 |
+
)
|
282 |
+
thetas = torch.acos(unit_centers[:,1])
|
283 |
+
phis = torch.atan2(unit_centers[:,0], unit_centers[:,2])
|
284 |
+
phis[phis < 0] += 2 * np.pi
|
285 |
+
centers = unit_centers * radius.unsqueeze(-1)
|
286 |
+
else:
|
287 |
+
# thetas = torch.rand(size) * (theta_range[1] - theta_range[0]) + theta_range[0]
|
288 |
+
# phis = torch.rand(size) * (phi_range[1] - phi_range[0]) + phi_range[0]
|
289 |
+
# phis[phis < 0] += 2 * np.pi
|
290 |
+
|
291 |
+
# centers = torch.stack([
|
292 |
+
# radius * torch.sin(thetas) * torch.sin(phis),
|
293 |
+
# radius * torch.cos(thetas),
|
294 |
+
# radius * torch.sin(thetas) * torch.cos(phis),
|
295 |
+
# ], dim=-1) # [B, 3]
|
296 |
+
# thetas = torch.rand(size) * (theta_range[1] - theta_range[0]) + theta_range[0]
|
297 |
+
# phis = torch.rand(size) * (phi_range[1] - phi_range[0]) + phi_range[0]
|
298 |
+
thetas = gen_random_pos(size, theta_range, rand_cam_gamma)
|
299 |
+
phis = gen_random_pos(size, phi_range, rand_cam_gamma)
|
300 |
+
phis[phis < 0] += 2 * np.pi
|
301 |
+
|
302 |
+
centers = torch.stack([
|
303 |
+
radius * torch.sin(thetas) * torch.sin(phis),
|
304 |
+
radius * torch.sin(thetas) * torch.cos(phis),
|
305 |
+
radius * torch.cos(thetas),
|
306 |
+
], dim=-1) # [B, 3]
|
307 |
+
|
308 |
+
targets = 0
|
309 |
+
|
310 |
+
# jitters
|
311 |
+
if opt.jitter_pose:
|
312 |
+
jit_center = opt.jitter_center # 0.015 # was 0.2
|
313 |
+
jit_target = opt.jitter_target
|
314 |
+
centers += torch.rand_like(centers) * jit_center - jit_center/2.0
|
315 |
+
targets += torch.randn_like(centers) * jit_target
|
316 |
+
|
317 |
+
# lookat
|
318 |
+
forward_vector = safe_normalize(centers - targets)
|
319 |
+
up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1)
|
320 |
+
#up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1)
|
321 |
+
right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
|
322 |
+
|
323 |
+
if opt.jitter_pose:
|
324 |
+
up_noise = torch.randn_like(up_vector) * opt.jitter_up
|
325 |
+
else:
|
326 |
+
up_noise = 0
|
327 |
+
|
328 |
+
up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise) #forward_vector
|
329 |
+
|
330 |
+
poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(size, 1, 1)
|
331 |
+
poses[:, :3, :3] = torch.stack((-right_vector, up_vector, forward_vector), dim=-1) #up_vector
|
332 |
+
poses[:, :3, 3] = centers
|
333 |
+
|
334 |
+
|
335 |
+
# back to degree
|
336 |
+
thetas = thetas / np.pi * 180
|
337 |
+
phis = phis / np.pi * 180
|
338 |
+
|
339 |
+
return poses.numpy(), thetas.numpy(), phis.numpy(), radius.numpy()
|
340 |
+
|
341 |
+
def GenerateCircleCameras(opt, size=8, render45 = False):
|
342 |
+
# random focal
|
343 |
+
fov = opt.default_fovy
|
344 |
+
cam_infos = []
|
345 |
+
#generate specific data structure
|
346 |
+
for idx in range(size):
|
347 |
+
thetas = torch.FloatTensor([opt.default_polar])
|
348 |
+
phis = torch.FloatTensor([(idx / size) * 360])
|
349 |
+
radius = torch.FloatTensor([opt.default_radius])
|
350 |
+
# random pose on the fly
|
351 |
+
poses = circle_poses(radius=radius, theta=thetas, phi=phis, angle_overhead=opt.angle_overhead, angle_front=opt.angle_front)
|
352 |
+
matrix = np.linalg.inv(poses[0])
|
353 |
+
R = -np.transpose(matrix[:3,:3])
|
354 |
+
R[:,0] = -R[:,0]
|
355 |
+
T = -matrix[:3, 3]
|
356 |
+
fovy = focal2fov(fov2focal(fov, opt.image_h), opt.image_w)
|
357 |
+
FovY = fovy
|
358 |
+
FovX = fov
|
359 |
+
|
360 |
+
# delta polar/azimuth/radius to default view
|
361 |
+
delta_polar = thetas - opt.default_polar
|
362 |
+
delta_azimuth = phis - opt.default_azimuth
|
363 |
+
delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
|
364 |
+
delta_radius = radius - opt.default_radius
|
365 |
+
cam_infos.append(RandCameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,width=opt.image_w,
|
366 |
+
height = opt.image_h, delta_polar = delta_polar,delta_azimuth = delta_azimuth, delta_radius = delta_radius))
|
367 |
+
if render45:
|
368 |
+
for idx in range(size):
|
369 |
+
thetas = torch.FloatTensor([opt.default_polar*2//3])
|
370 |
+
phis = torch.FloatTensor([(idx / size) * 360])
|
371 |
+
radius = torch.FloatTensor([opt.default_radius])
|
372 |
+
# random pose on the fly
|
373 |
+
poses = circle_poses(radius=radius, theta=thetas, phi=phis, angle_overhead=opt.angle_overhead, angle_front=opt.angle_front)
|
374 |
+
matrix = np.linalg.inv(poses[0])
|
375 |
+
R = -np.transpose(matrix[:3,:3])
|
376 |
+
R[:,0] = -R[:,0]
|
377 |
+
T = -matrix[:3, 3]
|
378 |
+
fovy = focal2fov(fov2focal(fov, opt.image_h), opt.image_w)
|
379 |
+
FovY = fovy
|
380 |
+
FovX = fov
|
381 |
+
|
382 |
+
# delta polar/azimuth/radius to default view
|
383 |
+
delta_polar = thetas - opt.default_polar
|
384 |
+
delta_azimuth = phis - opt.default_azimuth
|
385 |
+
delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
|
386 |
+
delta_radius = radius - opt.default_radius
|
387 |
+
cam_infos.append(RandCameraInfo(uid=idx+size, R=R, T=T, FovY=FovY, FovX=FovX,width=opt.image_w,
|
388 |
+
height = opt.image_h, delta_polar = delta_polar,delta_azimuth = delta_azimuth, delta_radius = delta_radius))
|
389 |
+
return cam_infos
|
390 |
+
|
391 |
+
|
392 |
+
def GenerateRandomCameras(opt, size=2000, SSAA=True):
|
393 |
+
# random pose on the fly
|
394 |
+
poses, thetas, phis, radius = rand_poses(size, opt, radius_range=opt.radius_range, theta_range=opt.theta_range, phi_range=opt.phi_range,
|
395 |
+
angle_overhead=opt.angle_overhead, angle_front=opt.angle_front, uniform_sphere_rate=opt.uniform_sphere_rate,
|
396 |
+
rand_cam_gamma=opt.rand_cam_gamma)
|
397 |
+
# delta polar/azimuth/radius to default view
|
398 |
+
delta_polar = thetas - opt.default_polar
|
399 |
+
delta_azimuth = phis - opt.default_azimuth
|
400 |
+
delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
|
401 |
+
delta_radius = radius - opt.default_radius
|
402 |
+
# random focal
|
403 |
+
fov = random.random() * (opt.fovy_range[1] - opt.fovy_range[0]) + opt.fovy_range[0]
|
404 |
+
|
405 |
+
cam_infos = []
|
406 |
+
|
407 |
+
if SSAA:
|
408 |
+
ssaa = opt.SSAA
|
409 |
+
else:
|
410 |
+
ssaa = 1
|
411 |
+
|
412 |
+
image_h = opt.image_h * ssaa
|
413 |
+
image_w = opt.image_w * ssaa
|
414 |
+
|
415 |
+
#generate specific data structure
|
416 |
+
for idx in range(size):
|
417 |
+
matrix = np.linalg.inv(poses[idx])
|
418 |
+
R = -np.transpose(matrix[:3,:3])
|
419 |
+
R[:,0] = -R[:,0]
|
420 |
+
T = -matrix[:3, 3]
|
421 |
+
# matrix = poses[idx]
|
422 |
+
# R = matrix[:3,:3]
|
423 |
+
# T = matrix[:3, 3]
|
424 |
+
fovy = focal2fov(fov2focal(fov, image_h), image_w)
|
425 |
+
FovY = fovy
|
426 |
+
FovX = fov
|
427 |
+
|
428 |
+
cam_infos.append(RandCameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,width=image_w,
|
429 |
+
height=image_h, delta_polar = delta_polar[idx],
|
430 |
+
delta_azimuth = delta_azimuth[idx], delta_radius = delta_radius[idx]))
|
431 |
+
return cam_infos
|
432 |
+
|
433 |
+
def GeneratePurnCameras(opt, size=300):
|
434 |
+
# random pose on the fly
|
435 |
+
poses, thetas, phis, radius = rand_poses(size, opt, radius_range=[opt.default_radius,opt.default_radius+0.1], theta_range=opt.theta_range, phi_range=opt.phi_range, angle_overhead=opt.angle_overhead, angle_front=opt.angle_front, uniform_sphere_rate=opt.uniform_sphere_rate)
|
436 |
+
# delta polar/azimuth/radius to default view
|
437 |
+
delta_polar = thetas - opt.default_polar
|
438 |
+
delta_azimuth = phis - opt.default_azimuth
|
439 |
+
delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
|
440 |
+
delta_radius = radius - opt.default_radius
|
441 |
+
# random focal
|
442 |
+
#fov = random.random() * (opt.fovy_range[1] - opt.fovy_range[0]) + opt.fovy_range[0]
|
443 |
+
fov = opt.default_fovy
|
444 |
+
cam_infos = []
|
445 |
+
#generate specific data structure
|
446 |
+
for idx in range(size):
|
447 |
+
matrix = np.linalg.inv(poses[idx])
|
448 |
+
R = -np.transpose(matrix[:3,:3])
|
449 |
+
R[:,0] = -R[:,0]
|
450 |
+
T = -matrix[:3, 3]
|
451 |
+
# matrix = poses[idx]
|
452 |
+
# R = matrix[:3,:3]
|
453 |
+
# T = matrix[:3, 3]
|
454 |
+
fovy = focal2fov(fov2focal(fov, opt.image_h), opt.image_w)
|
455 |
+
FovY = fovy
|
456 |
+
FovX = fov
|
457 |
+
|
458 |
+
cam_infos.append(RandCameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,width=opt.image_w,
|
459 |
+
height = opt.image_h, delta_polar = delta_polar[idx],delta_azimuth = delta_azimuth[idx], delta_radius = delta_radius[idx]))
|
460 |
+
return cam_infos
|
461 |
+
|
462 |
+
sceneLoadTypeCallbacks = {
|
463 |
+
# "Colmap": readColmapSceneInfo,
|
464 |
+
# "Blender" : readNerfSyntheticInfo,
|
465 |
+
"RandomCam" : readCircleCamInfo
|
466 |
+
}
|
scene/gaussian_model.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
|
15 |
+
from torch import nn
|
16 |
+
import os
|
17 |
+
from utils.system_utils import mkdir_p
|
18 |
+
from plyfile import PlyData, PlyElement
|
19 |
+
from utils.sh_utils import RGB2SH,SH2RGB
|
20 |
+
from simple_knn._C import distCUDA2
|
21 |
+
from utils.graphics_utils import BasicPointCloud
|
22 |
+
from utils.general_utils import strip_symmetric, build_scaling_rotation
|
23 |
+
# from .resnet import *
|
24 |
+
|
25 |
+
class GaussianModel:
|
26 |
+
|
27 |
+
def setup_functions(self):
|
28 |
+
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
|
29 |
+
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
|
30 |
+
actual_covariance = L @ L.transpose(1, 2)
|
31 |
+
symm = strip_symmetric(actual_covariance)
|
32 |
+
return symm
|
33 |
+
|
34 |
+
self.scaling_activation = torch.exp
|
35 |
+
self.scaling_inverse_activation = torch.log
|
36 |
+
|
37 |
+
self.covariance_activation = build_covariance_from_scaling_rotation
|
38 |
+
|
39 |
+
self.opacity_activation = torch.sigmoid
|
40 |
+
self.inverse_opacity_activation = inverse_sigmoid
|
41 |
+
|
42 |
+
self.rotation_activation = torch.nn.functional.normalize
|
43 |
+
|
44 |
+
|
45 |
+
def __init__(self, sh_degree : int):
|
46 |
+
self.active_sh_degree = 0
|
47 |
+
self.max_sh_degree = sh_degree
|
48 |
+
self._xyz = torch.empty(0)
|
49 |
+
self._features_dc = torch.empty(0)
|
50 |
+
self._features_rest = torch.empty(0)
|
51 |
+
self._scaling = torch.empty(0)
|
52 |
+
self._rotation = torch.empty(0)
|
53 |
+
self._opacity = torch.empty(0)
|
54 |
+
self._background = torch.empty(0)
|
55 |
+
self.max_radii2D = torch.empty(0)
|
56 |
+
self.xyz_gradient_accum = torch.empty(0)
|
57 |
+
self.denom = torch.empty(0)
|
58 |
+
self.optimizer = None
|
59 |
+
self.percent_dense = 0
|
60 |
+
self.spatial_lr_scale = 0
|
61 |
+
self.setup_functions()
|
62 |
+
|
63 |
+
def capture(self):
|
64 |
+
return (
|
65 |
+
self.active_sh_degree,
|
66 |
+
self._xyz,
|
67 |
+
self._features_dc,
|
68 |
+
self._features_rest,
|
69 |
+
self._scaling,
|
70 |
+
self._rotation,
|
71 |
+
self._opacity,
|
72 |
+
self.max_radii2D,
|
73 |
+
self.xyz_gradient_accum,
|
74 |
+
self.denom,
|
75 |
+
self.optimizer.state_dict(),
|
76 |
+
self.spatial_lr_scale,
|
77 |
+
)
|
78 |
+
|
79 |
+
def restore(self, model_args, training_args):
|
80 |
+
(self.active_sh_degree,
|
81 |
+
self._xyz,
|
82 |
+
self._features_dc,
|
83 |
+
self._features_rest,
|
84 |
+
self._scaling,
|
85 |
+
self._rotation,
|
86 |
+
self._opacity,
|
87 |
+
self.max_radii2D,
|
88 |
+
xyz_gradient_accum,
|
89 |
+
denom,
|
90 |
+
opt_dict,
|
91 |
+
self.spatial_lr_scale) = model_args
|
92 |
+
self.training_setup(training_args)
|
93 |
+
self.xyz_gradient_accum = xyz_gradient_accum
|
94 |
+
self.denom = denom
|
95 |
+
self.optimizer.load_state_dict(opt_dict)
|
96 |
+
|
97 |
+
@property
|
98 |
+
def get_scaling(self):
|
99 |
+
return self.scaling_activation(self._scaling)
|
100 |
+
|
101 |
+
@property
|
102 |
+
def get_rotation(self):
|
103 |
+
return self.rotation_activation(self._rotation)
|
104 |
+
|
105 |
+
@property
|
106 |
+
def get_xyz(self):
|
107 |
+
return self._xyz
|
108 |
+
|
109 |
+
@property
|
110 |
+
def get_background(self):
|
111 |
+
return torch.sigmoid(self._background)
|
112 |
+
|
113 |
+
@property
|
114 |
+
def get_features(self):
|
115 |
+
features_dc = self._features_dc
|
116 |
+
features_rest = self._features_rest
|
117 |
+
return torch.cat((features_dc, features_rest), dim=1)
|
118 |
+
|
119 |
+
@property
|
120 |
+
def get_opacity(self):
|
121 |
+
return self.opacity_activation(self._opacity)
|
122 |
+
|
123 |
+
def get_covariance(self, scaling_modifier = 1):
|
124 |
+
return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
|
125 |
+
|
126 |
+
def oneupSHdegree(self):
|
127 |
+
if self.active_sh_degree < self.max_sh_degree:
|
128 |
+
self.active_sh_degree += 1
|
129 |
+
|
130 |
+
def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
|
131 |
+
self.spatial_lr_scale = spatial_lr_scale
|
132 |
+
fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
|
133 |
+
fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors))).float().cuda() #RGB2SH(
|
134 |
+
features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
|
135 |
+
features[:, :3, 0 ] = fused_color
|
136 |
+
features[:, 3:, 1:] = 0.0
|
137 |
+
|
138 |
+
print("Number of points at initialisation : ", fused_point_cloud.shape[0])
|
139 |
+
|
140 |
+
dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
|
141 |
+
scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
|
142 |
+
rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
|
143 |
+
rots[:, 0] = 1
|
144 |
+
|
145 |
+
opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
|
146 |
+
|
147 |
+
self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
|
148 |
+
self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
|
149 |
+
self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
|
150 |
+
self._scaling = nn.Parameter(scales.requires_grad_(True))
|
151 |
+
self._rotation = nn.Parameter(rots.requires_grad_(True))
|
152 |
+
self._opacity = nn.Parameter(opacities.requires_grad_(True))
|
153 |
+
self._background = nn.Parameter(torch.zeros((3,1,1), device="cuda").requires_grad_(True))
|
154 |
+
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
|
155 |
+
|
156 |
+
def training_setup(self, training_args):
|
157 |
+
self.percent_dense = training_args.percent_dense
|
158 |
+
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
159 |
+
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
160 |
+
|
161 |
+
l = [
|
162 |
+
{'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
|
163 |
+
{'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
|
164 |
+
{'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
|
165 |
+
{'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
|
166 |
+
{'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
|
167 |
+
{'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
|
168 |
+
{'params': [self._background], 'lr': training_args.feature_lr, "name": "background"},
|
169 |
+
]
|
170 |
+
|
171 |
+
self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
|
172 |
+
self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
|
173 |
+
lr_final=training_args.position_lr_final*self.spatial_lr_scale,
|
174 |
+
lr_delay_mult=training_args.position_lr_delay_mult,
|
175 |
+
max_steps=training_args.iterations)
|
176 |
+
|
177 |
+
|
178 |
+
self.rotation_scheduler_args = get_expon_lr_func(lr_init=training_args.rotation_lr,
|
179 |
+
lr_final=training_args.rotation_lr_final,
|
180 |
+
lr_delay_mult=training_args.position_lr_delay_mult,
|
181 |
+
max_steps=training_args.iterations)
|
182 |
+
|
183 |
+
self.scaling_scheduler_args = get_expon_lr_func(lr_init=training_args.scaling_lr,
|
184 |
+
lr_final=training_args.scaling_lr_final,
|
185 |
+
lr_delay_mult=training_args.position_lr_delay_mult,
|
186 |
+
max_steps=training_args.iterations)
|
187 |
+
|
188 |
+
self.feature_scheduler_args = get_expon_lr_func(lr_init=training_args.feature_lr,
|
189 |
+
lr_final=training_args.feature_lr_final,
|
190 |
+
lr_delay_mult=training_args.position_lr_delay_mult,
|
191 |
+
max_steps=training_args.iterations)
|
192 |
+
def update_learning_rate(self, iteration):
|
193 |
+
''' Learning rate scheduling per step '''
|
194 |
+
for param_group in self.optimizer.param_groups:
|
195 |
+
if param_group["name"] == "xyz":
|
196 |
+
lr = self.xyz_scheduler_args(iteration)
|
197 |
+
param_group['lr'] = lr
|
198 |
+
return lr
|
199 |
+
|
200 |
+
def update_feature_learning_rate(self, iteration):
|
201 |
+
''' Learning rate scheduling per step '''
|
202 |
+
for param_group in self.optimizer.param_groups:
|
203 |
+
if param_group["name"] == "f_dc":
|
204 |
+
lr = self.feature_scheduler_args(iteration)
|
205 |
+
param_group['lr'] = lr
|
206 |
+
return lr
|
207 |
+
|
208 |
+
def update_rotation_learning_rate(self, iteration):
|
209 |
+
''' Learning rate scheduling per step '''
|
210 |
+
for param_group in self.optimizer.param_groups:
|
211 |
+
if param_group["name"] == "rotation":
|
212 |
+
lr = self.rotation_scheduler_args(iteration)
|
213 |
+
param_group['lr'] = lr
|
214 |
+
return lr
|
215 |
+
|
216 |
+
def update_scaling_learning_rate(self, iteration):
|
217 |
+
''' Learning rate scheduling per step '''
|
218 |
+
for param_group in self.optimizer.param_groups:
|
219 |
+
if param_group["name"] == "scaling":
|
220 |
+
lr = self.scaling_scheduler_args(iteration)
|
221 |
+
param_group['lr'] = lr
|
222 |
+
return lr
|
223 |
+
|
224 |
+
|
225 |
+
def construct_list_of_attributes(self):
|
226 |
+
l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
|
227 |
+
# All channels except the 3 DC
|
228 |
+
for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
|
229 |
+
l.append('f_dc_{}'.format(i))
|
230 |
+
for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
|
231 |
+
l.append('f_rest_{}'.format(i))
|
232 |
+
l.append('opacity')
|
233 |
+
for i in range(self._scaling.shape[1]):
|
234 |
+
l.append('scale_{}'.format(i))
|
235 |
+
for i in range(self._rotation.shape[1]):
|
236 |
+
l.append('rot_{}'.format(i))
|
237 |
+
return l
|
238 |
+
|
239 |
+
def save_ply(self, path):
|
240 |
+
mkdir_p(os.path.dirname(path))
|
241 |
+
|
242 |
+
xyz = self._xyz.detach().cpu().numpy()
|
243 |
+
normals = np.zeros_like(xyz)
|
244 |
+
f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
|
245 |
+
f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
|
246 |
+
opacities = self._opacity.detach().cpu().numpy()
|
247 |
+
scale = self._scaling.detach().cpu().numpy()
|
248 |
+
rotation = self._rotation.detach().cpu().numpy()
|
249 |
+
|
250 |
+
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
|
251 |
+
|
252 |
+
elements = np.empty(xyz.shape[0], dtype=dtype_full)
|
253 |
+
attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
|
254 |
+
elements[:] = list(map(tuple, attributes))
|
255 |
+
el = PlyElement.describe(elements, 'vertex')
|
256 |
+
PlyData([el]).write(path)
|
257 |
+
np.savetxt(os.path.join(os.path.split(path)[0],"point_cloud_rgb.txt"),np.concatenate((xyz, SH2RGB(f_dc)), axis=1))
|
258 |
+
|
259 |
+
def reset_opacity(self):
|
260 |
+
opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
|
261 |
+
optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
|
262 |
+
self._opacity = optimizable_tensors["opacity"]
|
263 |
+
|
264 |
+
def load_ply(self, path):
|
265 |
+
plydata = PlyData.read(path)
|
266 |
+
|
267 |
+
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
|
268 |
+
np.asarray(plydata.elements[0]["y"]),
|
269 |
+
np.asarray(plydata.elements[0]["z"])), axis=1)
|
270 |
+
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
|
271 |
+
|
272 |
+
features_dc = np.zeros((xyz.shape[0], 3, 1))
|
273 |
+
features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
|
274 |
+
features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
|
275 |
+
features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
|
276 |
+
|
277 |
+
extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
|
278 |
+
extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
|
279 |
+
assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
|
280 |
+
features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
|
281 |
+
for idx, attr_name in enumerate(extra_f_names):
|
282 |
+
features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
283 |
+
# Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
|
284 |
+
features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
|
285 |
+
|
286 |
+
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
|
287 |
+
scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
|
288 |
+
scales = np.zeros((xyz.shape[0], len(scale_names)))
|
289 |
+
for idx, attr_name in enumerate(scale_names):
|
290 |
+
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
291 |
+
|
292 |
+
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
|
293 |
+
rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
|
294 |
+
rots = np.zeros((xyz.shape[0], len(rot_names)))
|
295 |
+
for idx, attr_name in enumerate(rot_names):
|
296 |
+
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
297 |
+
|
298 |
+
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
|
299 |
+
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
|
300 |
+
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
|
301 |
+
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
|
302 |
+
self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
|
303 |
+
self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
|
304 |
+
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
|
305 |
+
self.active_sh_degree = self.max_sh_degree
|
306 |
+
|
307 |
+
def replace_tensor_to_optimizer(self, tensor, name):
|
308 |
+
optimizable_tensors = {}
|
309 |
+
for group in self.optimizer.param_groups:
|
310 |
+
if group["name"] not in ['background']:
|
311 |
+
if group["name"] == name:
|
312 |
+
stored_state = self.optimizer.state.get(group['params'][0], None)
|
313 |
+
stored_state["exp_avg"] = torch.zeros_like(tensor)
|
314 |
+
stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
|
315 |
+
|
316 |
+
del self.optimizer.state[group['params'][0]]
|
317 |
+
group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
|
318 |
+
self.optimizer.state[group['params'][0]] = stored_state
|
319 |
+
|
320 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
321 |
+
return optimizable_tensors
|
322 |
+
|
323 |
+
def _prune_optimizer(self, mask):
|
324 |
+
optimizable_tensors = {}
|
325 |
+
for group in self.optimizer.param_groups:
|
326 |
+
stored_state = self.optimizer.state.get(group['params'][0], None)
|
327 |
+
if group["name"] not in ['background']:
|
328 |
+
if stored_state is not None:
|
329 |
+
stored_state["exp_avg"] = stored_state["exp_avg"][mask]
|
330 |
+
stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
|
331 |
+
|
332 |
+
del self.optimizer.state[group['params'][0]]
|
333 |
+
group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
|
334 |
+
self.optimizer.state[group['params'][0]] = stored_state
|
335 |
+
|
336 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
337 |
+
else:
|
338 |
+
group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
|
339 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
340 |
+
return optimizable_tensors
|
341 |
+
|
342 |
+
def prune_points(self, mask):
|
343 |
+
valid_points_mask = ~mask
|
344 |
+
optimizable_tensors = self._prune_optimizer(valid_points_mask)
|
345 |
+
|
346 |
+
self._xyz = optimizable_tensors["xyz"]
|
347 |
+
self._features_dc = optimizable_tensors["f_dc"]
|
348 |
+
self._features_rest = optimizable_tensors["f_rest"]
|
349 |
+
self._opacity = optimizable_tensors["opacity"]
|
350 |
+
self._scaling = optimizable_tensors["scaling"]
|
351 |
+
self._rotation = optimizable_tensors["rotation"]
|
352 |
+
|
353 |
+
self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
|
354 |
+
|
355 |
+
self.denom = self.denom[valid_points_mask]
|
356 |
+
self.max_radii2D = self.max_radii2D[valid_points_mask]
|
357 |
+
|
358 |
+
def cat_tensors_to_optimizer(self, tensors_dict):
|
359 |
+
optimizable_tensors = {}
|
360 |
+
for group in self.optimizer.param_groups:
|
361 |
+
if group["name"] not in ['background']:
|
362 |
+
assert len(group["params"]) == 1
|
363 |
+
extension_tensor = tensors_dict[group["name"]]
|
364 |
+
stored_state = self.optimizer.state.get(group['params'][0], None)
|
365 |
+
if stored_state is not None:
|
366 |
+
stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
|
367 |
+
stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
|
368 |
+
|
369 |
+
del self.optimizer.state[group['params'][0]]
|
370 |
+
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
|
371 |
+
self.optimizer.state[group['params'][0]] = stored_state
|
372 |
+
|
373 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
374 |
+
else:
|
375 |
+
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
|
376 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
377 |
+
|
378 |
+
return optimizable_tensors
|
379 |
+
|
380 |
+
def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
|
381 |
+
d = {"xyz": new_xyz,
|
382 |
+
"f_dc": new_features_dc,
|
383 |
+
"f_rest": new_features_rest,
|
384 |
+
"opacity": new_opacities,
|
385 |
+
"scaling" : new_scaling,
|
386 |
+
"rotation" : new_rotation}
|
387 |
+
|
388 |
+
optimizable_tensors = self.cat_tensors_to_optimizer(d)
|
389 |
+
self._xyz = optimizable_tensors["xyz"]
|
390 |
+
self._features_dc = optimizable_tensors["f_dc"]
|
391 |
+
self._features_rest = optimizable_tensors["f_rest"]
|
392 |
+
self._opacity = optimizable_tensors["opacity"]
|
393 |
+
self._scaling = optimizable_tensors["scaling"]
|
394 |
+
self._rotation = optimizable_tensors["rotation"]
|
395 |
+
|
396 |
+
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
397 |
+
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
398 |
+
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
|
399 |
+
|
400 |
+
def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
|
401 |
+
n_init_points = self.get_xyz.shape[0]
|
402 |
+
# Extract points that satisfy the gradient condition
|
403 |
+
padded_grad = torch.zeros((n_init_points), device="cuda")
|
404 |
+
padded_grad[:grads.shape[0]] = grads.squeeze()
|
405 |
+
selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
|
406 |
+
selected_pts_mask = torch.logical_and(selected_pts_mask,
|
407 |
+
torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
|
408 |
+
|
409 |
+
stds = self.get_scaling[selected_pts_mask].repeat(N,1)
|
410 |
+
means =torch.zeros((stds.size(0), 3),device="cuda")
|
411 |
+
samples = torch.normal(mean=means, std=stds)
|
412 |
+
rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
|
413 |
+
new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
|
414 |
+
new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
|
415 |
+
new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
|
416 |
+
new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
|
417 |
+
new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
|
418 |
+
new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
|
419 |
+
|
420 |
+
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
|
421 |
+
|
422 |
+
prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
|
423 |
+
self.prune_points(prune_filter)
|
424 |
+
|
425 |
+
def densify_and_clone(self, grads, grad_threshold, scene_extent):
|
426 |
+
# Extract points that satisfy the gradient condition
|
427 |
+
selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
|
428 |
+
selected_pts_mask = torch.logical_and(selected_pts_mask,
|
429 |
+
torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
|
430 |
+
|
431 |
+
new_xyz = self._xyz[selected_pts_mask]
|
432 |
+
new_features_dc = self._features_dc[selected_pts_mask]
|
433 |
+
new_features_rest = self._features_rest[selected_pts_mask]
|
434 |
+
new_opacities = self._opacity[selected_pts_mask]
|
435 |
+
new_scaling = self._scaling[selected_pts_mask]
|
436 |
+
new_rotation = self._rotation[selected_pts_mask]
|
437 |
+
|
438 |
+
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
|
439 |
+
|
440 |
+
def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
|
441 |
+
grads = self.xyz_gradient_accum / self.denom
|
442 |
+
grads[grads.isnan()] = 0.0
|
443 |
+
|
444 |
+
self.densify_and_clone(grads, max_grad, extent)
|
445 |
+
self.densify_and_split(grads, max_grad, extent)
|
446 |
+
|
447 |
+
prune_mask = (self.get_opacity < min_opacity).squeeze()
|
448 |
+
if max_screen_size:
|
449 |
+
big_points_vs = self.max_radii2D > max_screen_size
|
450 |
+
big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
|
451 |
+
prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
|
452 |
+
self.prune_points(prune_mask)
|
453 |
+
|
454 |
+
torch.cuda.empty_cache()
|
455 |
+
|
456 |
+
def add_densification_stats(self, viewspace_point_tensor, update_filter):
|
457 |
+
self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
|
458 |
+
self.denom[update_filter] += 1
|
train.py
ADDED
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import random
|
13 |
+
import imageio
|
14 |
+
import os
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from random import randint
|
18 |
+
from utils.loss_utils import l1_loss, ssim, tv_loss
|
19 |
+
from gaussian_renderer import render, network_gui
|
20 |
+
import sys
|
21 |
+
from scene import Scene, GaussianModel
|
22 |
+
from utils.general_utils import safe_state
|
23 |
+
import uuid
|
24 |
+
from tqdm import tqdm
|
25 |
+
from utils.image_utils import psnr
|
26 |
+
from argparse import ArgumentParser, Namespace
|
27 |
+
from arguments import ModelParams, PipelineParams, OptimizationParams, GenerateCamParams, GuidanceParams
|
28 |
+
import math
|
29 |
+
import yaml
|
30 |
+
from torchvision.utils import save_image
|
31 |
+
import torchvision.transforms as T
|
32 |
+
|
33 |
+
try:
|
34 |
+
from torch.utils.tensorboard import SummaryWriter
|
35 |
+
TENSORBOARD_FOUND = True
|
36 |
+
except ImportError:
|
37 |
+
TENSORBOARD_FOUND = False
|
38 |
+
|
39 |
+
sys.path.append('/root/yangxin/codebase/3D_Playground/GSDF')
|
40 |
+
|
41 |
+
|
42 |
+
def adjust_text_embeddings(embeddings, azimuth, guidance_opt):
|
43 |
+
#TODO: add prenerg functions
|
44 |
+
text_z_list = []
|
45 |
+
weights_list = []
|
46 |
+
K = 0
|
47 |
+
#for b in range(azimuth):
|
48 |
+
text_z_, weights_ = get_pos_neg_text_embeddings(embeddings, azimuth, guidance_opt)
|
49 |
+
K = max(K, weights_.shape[0])
|
50 |
+
text_z_list.append(text_z_)
|
51 |
+
weights_list.append(weights_)
|
52 |
+
|
53 |
+
# Interleave text_embeddings from different dirs to form a batch
|
54 |
+
text_embeddings = []
|
55 |
+
for i in range(K):
|
56 |
+
for text_z in text_z_list:
|
57 |
+
# if uneven length, pad with the first embedding
|
58 |
+
text_embeddings.append(text_z[i] if i < len(text_z) else text_z[0])
|
59 |
+
text_embeddings = torch.stack(text_embeddings, dim=0) # [B * K, 77, 768]
|
60 |
+
|
61 |
+
# Interleave weights from different dirs to form a batch
|
62 |
+
weights = []
|
63 |
+
for i in range(K):
|
64 |
+
for weights_ in weights_list:
|
65 |
+
weights.append(weights_[i] if i < len(weights_) else torch.zeros_like(weights_[0]))
|
66 |
+
weights = torch.stack(weights, dim=0) # [B * K]
|
67 |
+
return text_embeddings, weights
|
68 |
+
|
69 |
+
def get_pos_neg_text_embeddings(embeddings, azimuth_val, opt):
|
70 |
+
if azimuth_val >= -90 and azimuth_val < 90:
|
71 |
+
if azimuth_val >= 0:
|
72 |
+
r = 1 - azimuth_val / 90
|
73 |
+
else:
|
74 |
+
r = 1 + azimuth_val / 90
|
75 |
+
start_z = embeddings['front']
|
76 |
+
end_z = embeddings['side']
|
77 |
+
# if random.random() < 0.3:
|
78 |
+
# r = r + random.gauss(0, 0.08)
|
79 |
+
pos_z = r * start_z + (1 - r) * end_z
|
80 |
+
text_z = torch.cat([pos_z, embeddings['front'], embeddings['side']], dim=0)
|
81 |
+
if r > 0.8:
|
82 |
+
front_neg_w = 0.0
|
83 |
+
else:
|
84 |
+
front_neg_w = math.exp(-r * opt.front_decay_factor) * opt.negative_w
|
85 |
+
if r < 0.2:
|
86 |
+
side_neg_w = 0.0
|
87 |
+
else:
|
88 |
+
side_neg_w = math.exp(-(1-r) * opt.side_decay_factor) * opt.negative_w
|
89 |
+
|
90 |
+
weights = torch.tensor([1.0, front_neg_w, side_neg_w])
|
91 |
+
else:
|
92 |
+
if azimuth_val >= 0:
|
93 |
+
r = 1 - (azimuth_val - 90) / 90
|
94 |
+
else:
|
95 |
+
r = 1 + (azimuth_val + 90) / 90
|
96 |
+
start_z = embeddings['side']
|
97 |
+
end_z = embeddings['back']
|
98 |
+
# if random.random() < 0.3:
|
99 |
+
# r = r + random.gauss(0, 0.08)
|
100 |
+
pos_z = r * start_z + (1 - r) * end_z
|
101 |
+
text_z = torch.cat([pos_z, embeddings['side'], embeddings['front']], dim=0)
|
102 |
+
front_neg_w = opt.negative_w
|
103 |
+
if r > 0.8:
|
104 |
+
side_neg_w = 0.0
|
105 |
+
else:
|
106 |
+
side_neg_w = math.exp(-r * opt.side_decay_factor) * opt.negative_w / 2
|
107 |
+
|
108 |
+
weights = torch.tensor([1.0, side_neg_w, front_neg_w])
|
109 |
+
return text_z, weights.to(text_z.device)
|
110 |
+
|
111 |
+
def prepare_embeddings(guidance_opt, guidance):
|
112 |
+
embeddings = {}
|
113 |
+
# text embeddings (stable-diffusion) and (IF)
|
114 |
+
embeddings['default'] = guidance.get_text_embeds([guidance_opt.text])
|
115 |
+
embeddings['uncond'] = guidance.get_text_embeds([guidance_opt.negative])
|
116 |
+
|
117 |
+
for d in ['front', 'side', 'back']:
|
118 |
+
embeddings[d] = guidance.get_text_embeds([f"{guidance_opt.text}, {d} view"])
|
119 |
+
embeddings['inverse_text'] = guidance.get_text_embeds(guidance_opt.inverse_text)
|
120 |
+
return embeddings
|
121 |
+
|
122 |
+
def guidance_setup(guidance_opt):
|
123 |
+
if guidance_opt.guidance=="SD":
|
124 |
+
from guidance.sd_utils import StableDiffusion
|
125 |
+
guidance = StableDiffusion(guidance_opt.g_device, guidance_opt.fp16, guidance_opt.vram_O,
|
126 |
+
guidance_opt.t_range, guidance_opt.max_t_range,
|
127 |
+
num_train_timesteps=guidance_opt.num_train_timesteps,
|
128 |
+
ddim_inv=guidance_opt.ddim_inv,
|
129 |
+
textual_inversion_path = guidance_opt.textual_inversion_path,
|
130 |
+
LoRA_path = guidance_opt.LoRA_path,
|
131 |
+
guidance_opt=guidance_opt)
|
132 |
+
else:
|
133 |
+
raise ValueError(f'{guidance_opt.guidance} not supported.')
|
134 |
+
if guidance is not None:
|
135 |
+
for p in guidance.parameters():
|
136 |
+
p.requires_grad = False
|
137 |
+
embeddings = prepare_embeddings(guidance_opt, guidance)
|
138 |
+
return guidance, embeddings
|
139 |
+
|
140 |
+
|
141 |
+
def training(dataset, opt, pipe, gcams, guidance_opt, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, save_video):
|
142 |
+
first_iter = 0
|
143 |
+
tb_writer = prepare_output_and_logger(dataset)
|
144 |
+
gaussians = GaussianModel(dataset.sh_degree)
|
145 |
+
scene = Scene(dataset, gcams, gaussians)
|
146 |
+
gaussians.training_setup(opt)
|
147 |
+
if checkpoint:
|
148 |
+
(model_params, first_iter) = torch.load(checkpoint)
|
149 |
+
gaussians.restore(model_params, opt)
|
150 |
+
|
151 |
+
bg_color = [1, 1, 1] if dataset._white_background else [0, 0, 0]
|
152 |
+
background = torch.tensor(bg_color, dtype=torch.float32, device=dataset.data_device)
|
153 |
+
iter_start = torch.cuda.Event(enable_timing = True)
|
154 |
+
iter_end = torch.cuda.Event(enable_timing = True)
|
155 |
+
|
156 |
+
#
|
157 |
+
save_folder = os.path.join(dataset._model_path,"train_process/")
|
158 |
+
if not os.path.exists(save_folder):
|
159 |
+
os.makedirs(save_folder) # makedirs
|
160 |
+
print('train_process is in :', save_folder)
|
161 |
+
#controlnet
|
162 |
+
use_control_net = False
|
163 |
+
#set up pretrain diffusion models and text_embedings
|
164 |
+
guidance, embeddings = guidance_setup(guidance_opt)
|
165 |
+
viewpoint_stack = None
|
166 |
+
viewpoint_stack_around = None
|
167 |
+
ema_loss_for_log = 0.0
|
168 |
+
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
|
169 |
+
first_iter += 1
|
170 |
+
|
171 |
+
if opt.save_process:
|
172 |
+
save_folder_proc = os.path.join(scene.args._model_path,"process_videos/")
|
173 |
+
if not os.path.exists(save_folder_proc):
|
174 |
+
os.makedirs(save_folder_proc) # makedirs
|
175 |
+
process_view_points = scene.getCircleVideoCameras(batch_size=opt.pro_frames_num,render45=opt.pro_render_45).copy()
|
176 |
+
save_process_iter = opt.iterations // len(process_view_points)
|
177 |
+
pro_img_frames = []
|
178 |
+
|
179 |
+
for iteration in range(first_iter, opt.iterations + 1):
|
180 |
+
#TODO: DEBUG NETWORK_GUI
|
181 |
+
if network_gui.conn == None:
|
182 |
+
network_gui.try_connect()
|
183 |
+
while network_gui.conn != None:
|
184 |
+
try:
|
185 |
+
net_image_bytes = None
|
186 |
+
custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
|
187 |
+
if custom_cam != None:
|
188 |
+
net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
|
189 |
+
net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
|
190 |
+
network_gui.send(net_image_bytes, guidance_opt.text)
|
191 |
+
if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
|
192 |
+
break
|
193 |
+
except Exception as e:
|
194 |
+
network_gui.conn = None
|
195 |
+
|
196 |
+
iter_start.record()
|
197 |
+
|
198 |
+
gaussians.update_learning_rate(iteration)
|
199 |
+
gaussians.update_feature_learning_rate(iteration)
|
200 |
+
gaussians.update_rotation_learning_rate(iteration)
|
201 |
+
gaussians.update_scaling_learning_rate(iteration)
|
202 |
+
# Every 500 its we increase the levels of SH up to a maximum degree
|
203 |
+
if iteration % 500 == 0:
|
204 |
+
gaussians.oneupSHdegree()
|
205 |
+
|
206 |
+
# progressively relaxing view range
|
207 |
+
if not opt.use_progressive:
|
208 |
+
if iteration >= opt.progressive_view_iter and iteration % opt.scale_up_cameras_iter == 0:
|
209 |
+
scene.pose_args.fovy_range[0] = max(scene.pose_args.max_fovy_range[0], scene.pose_args.fovy_range[0] * opt.fovy_scale_up_factor[0])
|
210 |
+
scene.pose_args.fovy_range[1] = min(scene.pose_args.max_fovy_range[1], scene.pose_args.fovy_range[1] * opt.fovy_scale_up_factor[1])
|
211 |
+
|
212 |
+
scene.pose_args.radius_range[1] = max(scene.pose_args.max_radius_range[1], scene.pose_args.radius_range[1] * opt.scale_up_factor)
|
213 |
+
scene.pose_args.radius_range[0] = max(scene.pose_args.max_radius_range[0], scene.pose_args.radius_range[0] * opt.scale_up_factor)
|
214 |
+
|
215 |
+
scene.pose_args.theta_range[1] = min(scene.pose_args.max_theta_range[1], scene.pose_args.theta_range[1] * opt.phi_scale_up_factor)
|
216 |
+
scene.pose_args.theta_range[0] = max(scene.pose_args.max_theta_range[0], scene.pose_args.theta_range[0] * 1/opt.phi_scale_up_factor)
|
217 |
+
|
218 |
+
# opt.reset_resnet_iter = max(500, opt.reset_resnet_iter // 1.25)
|
219 |
+
scene.pose_args.phi_range[0] = max(scene.pose_args.max_phi_range[0] , scene.pose_args.phi_range[0] * opt.phi_scale_up_factor)
|
220 |
+
scene.pose_args.phi_range[1] = min(scene.pose_args.max_phi_range[1], scene.pose_args.phi_range[1] * opt.phi_scale_up_factor)
|
221 |
+
|
222 |
+
print('scale up theta_range to:', scene.pose_args.theta_range)
|
223 |
+
print('scale up radius_range to:', scene.pose_args.radius_range)
|
224 |
+
print('scale up phi_range to:', scene.pose_args.phi_range)
|
225 |
+
print('scale up fovy_range to:', scene.pose_args.fovy_range)
|
226 |
+
|
227 |
+
# Pick a random Camera
|
228 |
+
if not viewpoint_stack:
|
229 |
+
viewpoint_stack = scene.getRandTrainCameras().copy()
|
230 |
+
|
231 |
+
C_batch_size = guidance_opt.C_batch_size
|
232 |
+
viewpoint_cams = []
|
233 |
+
images = []
|
234 |
+
text_z_ = []
|
235 |
+
weights_ = []
|
236 |
+
depths = []
|
237 |
+
alphas = []
|
238 |
+
scales = []
|
239 |
+
|
240 |
+
text_z_inverse =torch.cat([embeddings['uncond'],embeddings['inverse_text']], dim=0)
|
241 |
+
|
242 |
+
for i in range(C_batch_size):
|
243 |
+
try:
|
244 |
+
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
|
245 |
+
except:
|
246 |
+
viewpoint_stack = scene.getRandTrainCameras().copy()
|
247 |
+
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
|
248 |
+
|
249 |
+
#pred text_z
|
250 |
+
azimuth = viewpoint_cam.delta_azimuth
|
251 |
+
text_z = [embeddings['uncond']]
|
252 |
+
|
253 |
+
|
254 |
+
if guidance_opt.perpneg:
|
255 |
+
text_z_comp, weights = adjust_text_embeddings(embeddings, azimuth, guidance_opt)
|
256 |
+
text_z.append(text_z_comp)
|
257 |
+
weights_.append(weights)
|
258 |
+
|
259 |
+
else:
|
260 |
+
if azimuth >= -90 and azimuth < 90:
|
261 |
+
if azimuth >= 0:
|
262 |
+
r = 1 - azimuth / 90
|
263 |
+
else:
|
264 |
+
r = 1 + azimuth / 90
|
265 |
+
start_z = embeddings['front']
|
266 |
+
end_z = embeddings['side']
|
267 |
+
else:
|
268 |
+
if azimuth >= 0:
|
269 |
+
r = 1 - (azimuth - 90) / 90
|
270 |
+
else:
|
271 |
+
r = 1 + (azimuth + 90) / 90
|
272 |
+
start_z = embeddings['side']
|
273 |
+
end_z = embeddings['back']
|
274 |
+
text_z.append(r * start_z + (1 - r) * end_z)
|
275 |
+
|
276 |
+
text_z = torch.cat(text_z, dim=0)
|
277 |
+
text_z_.append(text_z)
|
278 |
+
|
279 |
+
# Render
|
280 |
+
if (iteration - 1) == debug_from:
|
281 |
+
pipe.debug = True
|
282 |
+
render_pkg = render(viewpoint_cam, gaussians, pipe, background,
|
283 |
+
sh_deg_aug_ratio = dataset.sh_deg_aug_ratio,
|
284 |
+
bg_aug_ratio = dataset.bg_aug_ratio,
|
285 |
+
shs_aug_ratio = dataset.shs_aug_ratio,
|
286 |
+
scale_aug_ratio = dataset.scale_aug_ratio)
|
287 |
+
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
|
288 |
+
depth, alpha = render_pkg["depth"], render_pkg["alpha"]
|
289 |
+
|
290 |
+
scales.append(render_pkg["scales"])
|
291 |
+
images.append(image)
|
292 |
+
depths.append(depth)
|
293 |
+
alphas.append(alpha)
|
294 |
+
viewpoint_cams.append(viewpoint_cams)
|
295 |
+
|
296 |
+
images = torch.stack(images, dim=0)
|
297 |
+
depths = torch.stack(depths, dim=0)
|
298 |
+
alphas = torch.stack(alphas, dim=0)
|
299 |
+
|
300 |
+
# Loss
|
301 |
+
warm_up_rate = 1. - min(iteration/opt.warmup_iter,1.)
|
302 |
+
guidance_scale = guidance_opt.guidance_scale
|
303 |
+
_aslatent = False
|
304 |
+
if iteration < opt.geo_iter or random.random()< opt.as_latent_ratio:
|
305 |
+
_aslatent=True
|
306 |
+
if iteration > opt.use_control_net_iter and (random.random() < guidance_opt.controlnet_ratio):
|
307 |
+
use_control_net = True
|
308 |
+
if guidance_opt.perpneg:
|
309 |
+
loss = guidance.train_step_perpneg(torch.stack(text_z_, dim=1), images,
|
310 |
+
pred_depth=depths, pred_alpha=alphas,
|
311 |
+
grad_scale=guidance_opt.lambda_guidance,
|
312 |
+
use_control_net = use_control_net ,save_folder = save_folder, iteration = iteration, warm_up_rate=warm_up_rate,
|
313 |
+
weights = torch.stack(weights_, dim=1), resolution=(gcams.image_h, gcams.image_w),
|
314 |
+
guidance_opt=guidance_opt,as_latent=_aslatent, embedding_inverse = text_z_inverse)
|
315 |
+
else:
|
316 |
+
loss = guidance.train_step(torch.stack(text_z_, dim=1), images,
|
317 |
+
pred_depth=depths, pred_alpha=alphas,
|
318 |
+
grad_scale=guidance_opt.lambda_guidance,
|
319 |
+
use_control_net = use_control_net ,save_folder = save_folder, iteration = iteration, warm_up_rate=warm_up_rate,
|
320 |
+
resolution=(gcams.image_h, gcams.image_w),
|
321 |
+
guidance_opt=guidance_opt,as_latent=_aslatent, embedding_inverse = text_z_inverse)
|
322 |
+
#raise ValueError(f'original version not supported.')
|
323 |
+
scales = torch.stack(scales, dim=0)
|
324 |
+
|
325 |
+
loss_scale = torch.mean(scales,dim=-1).mean()
|
326 |
+
loss_tv = tv_loss(images) + tv_loss(depths)
|
327 |
+
# loss_bin = torch.mean(torch.min(alphas - 0.0001, 1 - alphas))
|
328 |
+
|
329 |
+
loss = loss + opt.lambda_tv * loss_tv + opt.lambda_scale * loss_scale #opt.lambda_tv * loss_tv + opt.lambda_bin * loss_bin + opt.lambda_scale * loss_scale +
|
330 |
+
loss.backward()
|
331 |
+
iter_end.record()
|
332 |
+
|
333 |
+
with torch.no_grad():
|
334 |
+
# Progress bar
|
335 |
+
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
|
336 |
+
if opt.save_process:
|
337 |
+
if iteration % save_process_iter == 0 and len(process_view_points) > 0:
|
338 |
+
viewpoint_cam_p = process_view_points.pop(0)
|
339 |
+
render_p = render(viewpoint_cam_p, gaussians, pipe, background, test=True)
|
340 |
+
img_p = torch.clamp(render_p["render"], 0.0, 1.0)
|
341 |
+
img_p = img_p.detach().cpu().permute(1,2,0).numpy()
|
342 |
+
img_p = (img_p * 255).round().astype('uint8')
|
343 |
+
pro_img_frames.append(img_p)
|
344 |
+
|
345 |
+
if iteration % 10 == 0:
|
346 |
+
progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
|
347 |
+
progress_bar.update(10)
|
348 |
+
if iteration == opt.iterations:
|
349 |
+
progress_bar.close()
|
350 |
+
|
351 |
+
# Log and save
|
352 |
+
training_report(tb_writer, iteration, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
|
353 |
+
if (iteration in testing_iterations):
|
354 |
+
if save_video:
|
355 |
+
video_path = video_inference(iteration, scene, render, (pipe, background))
|
356 |
+
|
357 |
+
if (iteration in saving_iterations):
|
358 |
+
print("\n[ITER {}] Saving Gaussians".format(iteration))
|
359 |
+
scene.save(iteration)
|
360 |
+
|
361 |
+
# Densification
|
362 |
+
if iteration < opt.densify_until_iter:
|
363 |
+
# Keep track of max radii in image-space for pruning
|
364 |
+
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
|
365 |
+
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
|
366 |
+
|
367 |
+
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
|
368 |
+
size_threshold = 20 if iteration > opt.opacity_reset_interval else None
|
369 |
+
gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
|
370 |
+
|
371 |
+
if iteration % opt.opacity_reset_interval == 0: #or (dataset._white_background and iteration == opt.densify_from_iter)
|
372 |
+
gaussians.reset_opacity()
|
373 |
+
|
374 |
+
# Optimizer step
|
375 |
+
if iteration < opt.iterations:
|
376 |
+
gaussians.optimizer.step()
|
377 |
+
gaussians.optimizer.zero_grad(set_to_none = True)
|
378 |
+
|
379 |
+
if (iteration in checkpoint_iterations):
|
380 |
+
print("\n[ITER {}] Saving Checkpoint".format(iteration))
|
381 |
+
torch.save((gaussians.capture(), iteration), scene._model_path + "/chkpnt" + str(iteration) + ".pth")
|
382 |
+
|
383 |
+
if opt.save_process:
|
384 |
+
imageio.mimwrite(os.path.join(save_folder_proc, "video_rgb.mp4"), pro_img_frames, fps=30, quality=8)
|
385 |
+
return video_path
|
386 |
+
|
387 |
+
|
388 |
+
def prepare_output_and_logger(args):
|
389 |
+
if not args._model_path:
|
390 |
+
if os.getenv('OAR_JOB_ID'):
|
391 |
+
unique_str=os.getenv('OAR_JOB_ID')
|
392 |
+
else:
|
393 |
+
unique_str = str(uuid.uuid4())
|
394 |
+
args._model_path = os.path.join("./output/", args.workspace)
|
395 |
+
|
396 |
+
# Set up output folder
|
397 |
+
print("Output folder: {}".format(args._model_path))
|
398 |
+
os.makedirs(args._model_path, exist_ok = True)
|
399 |
+
|
400 |
+
# copy configs
|
401 |
+
if args.opt_path is not None:
|
402 |
+
os.system(' '.join(['cp', args.opt_path, os.path.join(args._model_path, 'config.yaml')]))
|
403 |
+
|
404 |
+
with open(os.path.join(args._model_path, "cfg_args"), 'w') as cfg_log_f:
|
405 |
+
cfg_log_f.write(str(Namespace(**vars(args))))
|
406 |
+
|
407 |
+
# Create Tensorboard writer
|
408 |
+
tb_writer = None
|
409 |
+
if TENSORBOARD_FOUND:
|
410 |
+
tb_writer = SummaryWriter(args._model_path)
|
411 |
+
else:
|
412 |
+
print("Tensorboard not available: not logging progress")
|
413 |
+
return tb_writer
|
414 |
+
|
415 |
+
def training_report(tb_writer, iteration, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
|
416 |
+
if tb_writer:
|
417 |
+
tb_writer.add_scalar('iter_time', elapsed, iteration)
|
418 |
+
# Report test and samples of training set
|
419 |
+
if iteration in testing_iterations:
|
420 |
+
save_folder = os.path.join(scene.args._model_path,"test_six_views/{}_iteration".format(iteration))
|
421 |
+
if not os.path.exists(save_folder):
|
422 |
+
os.makedirs(save_folder) # makedirs 创建文件时如果路径不存在会创建这个路径
|
423 |
+
print('test views is in :', save_folder)
|
424 |
+
torch.cuda.empty_cache()
|
425 |
+
config = ({'name': 'test', 'cameras' : scene.getTestCameras()})
|
426 |
+
if config['cameras'] and len(config['cameras']) > 0:
|
427 |
+
for idx, viewpoint in enumerate(config['cameras']):
|
428 |
+
render_out = renderFunc(viewpoint, scene.gaussians, *renderArgs, test=True)
|
429 |
+
rgb, depth = render_out["render"],render_out["depth"]
|
430 |
+
if depth is not None:
|
431 |
+
depth_norm = depth/depth.max()
|
432 |
+
save_image(depth_norm,os.path.join(save_folder,"render_depth_{}.png".format(viewpoint.uid)))
|
433 |
+
|
434 |
+
image = torch.clamp(rgb, 0.0, 1.0)
|
435 |
+
save_image(image,os.path.join(save_folder,"render_view_{}.png".format(viewpoint.uid)))
|
436 |
+
if tb_writer:
|
437 |
+
tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.uid), image[None], global_step=iteration)
|
438 |
+
print("\n[ITER {}] Eval Done!".format(iteration))
|
439 |
+
if tb_writer:
|
440 |
+
tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
|
441 |
+
tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
|
442 |
+
torch.cuda.empty_cache()
|
443 |
+
|
444 |
+
def video_inference(iteration, scene : Scene, renderFunc, renderArgs):
|
445 |
+
sharp = T.RandomAdjustSharpness(3, p=1.0)
|
446 |
+
|
447 |
+
save_folder = os.path.join(scene.args._model_path,"videos/{}_iteration".format(iteration))
|
448 |
+
if not os.path.exists(save_folder):
|
449 |
+
os.makedirs(save_folder) # makedirs
|
450 |
+
print('videos is in :', save_folder)
|
451 |
+
torch.cuda.empty_cache()
|
452 |
+
config = ({'name': 'test', 'cameras' : scene.getCircleVideoCameras()})
|
453 |
+
if config['cameras'] and len(config['cameras']) > 0:
|
454 |
+
img_frames = []
|
455 |
+
depth_frames = []
|
456 |
+
print("Generating Video using", len(config['cameras']), "different view points")
|
457 |
+
for idx, viewpoint in enumerate(config['cameras']):
|
458 |
+
render_out = renderFunc(viewpoint, scene.gaussians, *renderArgs, test=True)
|
459 |
+
rgb,depth = render_out["render"],render_out["depth"]
|
460 |
+
if depth is not None:
|
461 |
+
depth_norm = depth/depth.max()
|
462 |
+
depths = torch.clamp(depth_norm, 0.0, 1.0)
|
463 |
+
depths = depths.detach().cpu().permute(1,2,0).numpy()
|
464 |
+
depths = (depths * 255).round().astype('uint8')
|
465 |
+
depth_frames.append(depths)
|
466 |
+
|
467 |
+
image = torch.clamp(rgb, 0.0, 1.0)
|
468 |
+
image = image.detach().cpu().permute(1,2,0).numpy()
|
469 |
+
image = (image * 255).round().astype('uint8')
|
470 |
+
img_frames.append(image)
|
471 |
+
#save_image(image,os.path.join(save_folder,"lora_view_{}.jpg".format(viewpoint.uid)))
|
472 |
+
# Img to Numpy
|
473 |
+
imageio.mimwrite(os.path.join(save_folder, "video_rgb_{}.mp4".format(iteration)), img_frames, fps=30, quality=8)
|
474 |
+
if len(depth_frames) > 0:
|
475 |
+
imageio.mimwrite(os.path.join(save_folder, "video_depth_{}.mp4".format(iteration)), depth_frames, fps=30, quality=8)
|
476 |
+
print("\n[ITER {}] Video Save Done!".format(iteration))
|
477 |
+
torch.cuda.empty_cache()
|
478 |
+
return os.path.join(save_folder, "video_rgb_{}.mp4".format(iteration))
|
479 |
+
|
480 |
+
def args_parser(default_opt=None):
|
481 |
+
# Set up command line argument parser
|
482 |
+
parser = ArgumentParser(description="Training script parameters")
|
483 |
+
|
484 |
+
parser.add_argument('--opt', type=str, default=default_opt)
|
485 |
+
parser.add_argument('--ip', type=str, default="127.0.0.1")
|
486 |
+
parser.add_argument('--port', type=int, default=6009)
|
487 |
+
parser.add_argument('--debug_from', type=int, default=-1)
|
488 |
+
parser.add_argument('--seed', type=int, default=0)
|
489 |
+
parser.add_argument('--detect_anomaly', action='store_true', default=False)
|
490 |
+
parser.add_argument("--test_ratio", type=int, default=5) # [2500,5000,7500,10000,12000]
|
491 |
+
parser.add_argument("--save_ratio", type=int, default=2) # [10000,12000]
|
492 |
+
parser.add_argument("--save_video", type=bool, default=False)
|
493 |
+
parser.add_argument("--quiet", action="store_true")
|
494 |
+
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
|
495 |
+
parser.add_argument("--start_checkpoint", type=str, default = None)
|
496 |
+
parser.add_argument("--cuda", type=str, default='0')
|
497 |
+
|
498 |
+
lp = ModelParams(parser)
|
499 |
+
op = OptimizationParams(parser)
|
500 |
+
pp = PipelineParams(parser)
|
501 |
+
gcp = GenerateCamParams(parser)
|
502 |
+
gp = GuidanceParams(parser)
|
503 |
+
|
504 |
+
args = parser.parse_args(sys.argv[1:])
|
505 |
+
|
506 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
|
507 |
+
if args.opt is not None:
|
508 |
+
with open(args.opt) as f:
|
509 |
+
opts = yaml.load(f, Loader=yaml.FullLoader)
|
510 |
+
lp.load_yaml(opts.get('ModelParams', None))
|
511 |
+
op.load_yaml(opts.get('OptimizationParams', None))
|
512 |
+
pp.load_yaml(opts.get('PipelineParams', None))
|
513 |
+
gcp.load_yaml(opts.get('GenerateCamParams', None))
|
514 |
+
gp.load_yaml(opts.get('GuidanceParams', None))
|
515 |
+
|
516 |
+
lp.opt_path = args.opt
|
517 |
+
args.port = opts['port']
|
518 |
+
args.save_video = opts.get('save_video', True)
|
519 |
+
args.seed = opts.get('seed', 0)
|
520 |
+
args.device = opts.get('device', 'cuda')
|
521 |
+
|
522 |
+
# override device
|
523 |
+
gp.g_device = args.device
|
524 |
+
lp.data_device = args.device
|
525 |
+
gcp.device = args.device
|
526 |
+
return args, lp, op, pp, gcp, gp
|
527 |
+
|
528 |
+
def start_training(args, lp, op, pp, gcp, gp):
|
529 |
+
# save iterations
|
530 |
+
test_iter = [1] + [k * op.iterations // args.test_ratio for k in range(1, args.test_ratio)] + [op.iterations]
|
531 |
+
args.test_iterations = test_iter
|
532 |
+
|
533 |
+
save_iter = [k * op.iterations // args.save_ratio for k in range(1, args.save_ratio)] + [op.iterations]
|
534 |
+
args.save_iterations = save_iter
|
535 |
+
|
536 |
+
print('Test iter:', args.test_iterations)
|
537 |
+
print('Save iter:', args.save_iterations)
|
538 |
+
|
539 |
+
print("Optimizing " + lp._model_path)
|
540 |
+
|
541 |
+
# Initialize system state (RNG)
|
542 |
+
safe_state(args.quiet, seed=args.seed)
|
543 |
+
# Start GUI server, configure and run training
|
544 |
+
network_gui.init(args.ip, args.port)
|
545 |
+
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
546 |
+
video_path = training(lp, op, pp, gcp, gp, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args.save_video)
|
547 |
+
# All done
|
548 |
+
print("\nTraining complete.")
|
549 |
+
return video_path
|
550 |
+
|
551 |
+
if __name__ == "__main__":
|
552 |
+
args, lp, op, pp, gcp, gp = args_parser()
|
553 |
+
start_training(args, lp, op, pp, gcp, gp)
|
train.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python train.py --opt 'configs/bagel.yaml' --cuda 4
|
utils/camera_utils.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
from scene.cameras import Camera, RCamera
|
13 |
+
import numpy as np
|
14 |
+
from utils.general_utils import PILtoTorch
|
15 |
+
from utils.graphics_utils import fov2focal
|
16 |
+
|
17 |
+
WARNED = False
|
18 |
+
|
19 |
+
def loadCam(args, id, cam_info, resolution_scale):
|
20 |
+
orig_w, orig_h = cam_info.image.size
|
21 |
+
|
22 |
+
if args.resolution in [1, 2, 4, 8]:
|
23 |
+
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
|
24 |
+
else: # should be a type that converts to float
|
25 |
+
if args.resolution == -1:
|
26 |
+
if orig_w > 1600:
|
27 |
+
global WARNED
|
28 |
+
if not WARNED:
|
29 |
+
print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
|
30 |
+
"If this is not desired, please explicitly specify '--resolution/-r' as 1")
|
31 |
+
WARNED = True
|
32 |
+
global_down = orig_w / 1600
|
33 |
+
else:
|
34 |
+
global_down = 1
|
35 |
+
else:
|
36 |
+
global_down = orig_w / args.resolution
|
37 |
+
|
38 |
+
scale = float(global_down) * float(resolution_scale)
|
39 |
+
resolution = (int(orig_w / scale), int(orig_h / scale))
|
40 |
+
|
41 |
+
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
|
42 |
+
|
43 |
+
gt_image = resized_image_rgb[:3, ...]
|
44 |
+
loaded_mask = None
|
45 |
+
|
46 |
+
if resized_image_rgb.shape[1] == 4:
|
47 |
+
loaded_mask = resized_image_rgb[3:4, ...]
|
48 |
+
|
49 |
+
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
|
50 |
+
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
|
51 |
+
image=gt_image, gt_alpha_mask=loaded_mask,
|
52 |
+
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
|
53 |
+
|
54 |
+
|
55 |
+
def loadRandomCam(opt, id, cam_info, resolution_scale, SSAA=False):
|
56 |
+
return RCamera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
|
57 |
+
FoVx=cam_info.FovX, FoVy=cam_info.FovY, delta_polar=cam_info.delta_polar,
|
58 |
+
delta_azimuth=cam_info.delta_azimuth , delta_radius=cam_info.delta_radius, opt=opt,
|
59 |
+
uid=id, data_device=opt.device, SSAA=SSAA)
|
60 |
+
|
61 |
+
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
|
62 |
+
camera_list = []
|
63 |
+
|
64 |
+
for id, c in enumerate(cam_infos):
|
65 |
+
camera_list.append(loadCam(args, id, c, resolution_scale))
|
66 |
+
|
67 |
+
return camera_list
|
68 |
+
|
69 |
+
|
70 |
+
def cameraList_from_RcamInfos(cam_infos, resolution_scale, opt, SSAA=False):
|
71 |
+
camera_list = []
|
72 |
+
|
73 |
+
for id, c in enumerate(cam_infos):
|
74 |
+
camera_list.append(loadRandomCam(opt, id, c, resolution_scale, SSAA=SSAA))
|
75 |
+
|
76 |
+
return camera_list
|
77 |
+
|
78 |
+
def camera_to_JSON(id, camera : Camera):
|
79 |
+
Rt = np.zeros((4, 4))
|
80 |
+
Rt[:3, :3] = camera.R.transpose()
|
81 |
+
Rt[:3, 3] = camera.T
|
82 |
+
Rt[3, 3] = 1.0
|
83 |
+
|
84 |
+
W2C = np.linalg.inv(Rt)
|
85 |
+
pos = W2C[:3, 3]
|
86 |
+
rot = W2C[:3, :3]
|
87 |
+
serializable_array_2d = [x.tolist() for x in rot]
|
88 |
+
camera_entry = {
|
89 |
+
'id' : id,
|
90 |
+
'img_name' : id,
|
91 |
+
'width' : camera.width,
|
92 |
+
'height' : camera.height,
|
93 |
+
'position': pos.tolist(),
|
94 |
+
'rotation': serializable_array_2d,
|
95 |
+
'fy' : fov2focal(camera.FovY, camera.height),
|
96 |
+
'fx' : fov2focal(camera.FovX, camera.width)
|
97 |
+
}
|
98 |
+
return camera_entry
|
utils/general_utils.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import sys
|
14 |
+
from datetime import datetime
|
15 |
+
import numpy as np
|
16 |
+
import random
|
17 |
+
|
18 |
+
def inverse_sigmoid(x):
|
19 |
+
return torch.log(x/(1-x))
|
20 |
+
|
21 |
+
def inverse_sigmoid_np(x):
|
22 |
+
return np.log(x/(1-x))
|
23 |
+
|
24 |
+
def PILtoTorch(pil_image, resolution):
|
25 |
+
resized_image_PIL = pil_image.resize(resolution)
|
26 |
+
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
|
27 |
+
if len(resized_image.shape) == 3:
|
28 |
+
return resized_image.permute(2, 0, 1)
|
29 |
+
else:
|
30 |
+
return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
|
31 |
+
|
32 |
+
def get_expon_lr_func(
|
33 |
+
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
|
34 |
+
):
|
35 |
+
"""
|
36 |
+
Copied from Plenoxels
|
37 |
+
|
38 |
+
Continuous learning rate decay function. Adapted from JaxNeRF
|
39 |
+
The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
|
40 |
+
is log-linearly interpolated elsewhere (equivalent to exponential decay).
|
41 |
+
If lr_delay_steps>0 then the learning rate will be scaled by some smooth
|
42 |
+
function of lr_delay_mult, such that the initial learning rate is
|
43 |
+
lr_init*lr_delay_mult at the beginning of optimization but will be eased back
|
44 |
+
to the normal learning rate when steps>lr_delay_steps.
|
45 |
+
:param conf: config subtree 'lr' or similar
|
46 |
+
:param max_steps: int, the number of steps during optimization.
|
47 |
+
:return HoF which takes step as input
|
48 |
+
"""
|
49 |
+
|
50 |
+
def helper(step):
|
51 |
+
if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
|
52 |
+
# Disable this parameter
|
53 |
+
return 0.0
|
54 |
+
if lr_delay_steps > 0:
|
55 |
+
# A kind of reverse cosine decay.
|
56 |
+
delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
|
57 |
+
0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
|
58 |
+
)
|
59 |
+
else:
|
60 |
+
delay_rate = 1.0
|
61 |
+
t = np.clip(step / max_steps, 0, 1)
|
62 |
+
log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
|
63 |
+
return delay_rate * log_lerp
|
64 |
+
|
65 |
+
return helper
|
66 |
+
|
67 |
+
def strip_lowerdiag(L):
|
68 |
+
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
|
69 |
+
|
70 |
+
uncertainty[:, 0] = L[:, 0, 0]
|
71 |
+
uncertainty[:, 1] = L[:, 0, 1]
|
72 |
+
uncertainty[:, 2] = L[:, 0, 2]
|
73 |
+
uncertainty[:, 3] = L[:, 1, 1]
|
74 |
+
uncertainty[:, 4] = L[:, 1, 2]
|
75 |
+
uncertainty[:, 5] = L[:, 2, 2]
|
76 |
+
return uncertainty
|
77 |
+
|
78 |
+
def strip_symmetric(sym):
|
79 |
+
return strip_lowerdiag(sym)
|
80 |
+
|
81 |
+
def build_rotation(r):
|
82 |
+
norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
|
83 |
+
|
84 |
+
q = r / norm[:, None]
|
85 |
+
|
86 |
+
R = torch.zeros((q.size(0), 3, 3), device='cuda')
|
87 |
+
|
88 |
+
r = q[:, 0]
|
89 |
+
x = q[:, 1]
|
90 |
+
y = q[:, 2]
|
91 |
+
z = q[:, 3]
|
92 |
+
|
93 |
+
R[:, 0, 0] = 1 - 2 * (y*y + z*z)
|
94 |
+
R[:, 0, 1] = 2 * (x*y - r*z)
|
95 |
+
R[:, 0, 2] = 2 * (x*z + r*y)
|
96 |
+
R[:, 1, 0] = 2 * (x*y + r*z)
|
97 |
+
R[:, 1, 1] = 1 - 2 * (x*x + z*z)
|
98 |
+
R[:, 1, 2] = 2 * (y*z - r*x)
|
99 |
+
R[:, 2, 0] = 2 * (x*z - r*y)
|
100 |
+
R[:, 2, 1] = 2 * (y*z + r*x)
|
101 |
+
R[:, 2, 2] = 1 - 2 * (x*x + y*y)
|
102 |
+
return R
|
103 |
+
|
104 |
+
def build_scaling_rotation(s, r):
|
105 |
+
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
|
106 |
+
R = build_rotation(r)
|
107 |
+
|
108 |
+
L[:,0,0] = s[:,0]
|
109 |
+
L[:,1,1] = s[:,1]
|
110 |
+
L[:,2,2] = s[:,2]
|
111 |
+
|
112 |
+
L = R @ L
|
113 |
+
return L
|
114 |
+
|
115 |
+
def safe_state(silent, seed=0):
|
116 |
+
old_f = sys.stdout
|
117 |
+
class F:
|
118 |
+
def __init__(self, silent):
|
119 |
+
self.silent = silent
|
120 |
+
|
121 |
+
def write(self, x):
|
122 |
+
if not self.silent:
|
123 |
+
if x.endswith("\n"):
|
124 |
+
old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
|
125 |
+
else:
|
126 |
+
old_f.write(x)
|
127 |
+
|
128 |
+
def flush(self):
|
129 |
+
old_f.flush()
|
130 |
+
|
131 |
+
sys.stdout = F(silent)
|
132 |
+
random.seed(seed)
|
133 |
+
np.random.seed(seed)
|
134 |
+
torch.manual_seed(seed)
|
135 |
+
torch.cuda.manual_seed_all(seed)
|
136 |
+
|
137 |
+
# if seed == 0:
|
138 |
+
torch.backends.cudnn.deterministic = True
|
139 |
+
torch.backends.cudnn.benchmark = False
|
140 |
+
|
141 |
+
# torch.cuda.set_device(torch.device("cuda:0"))
|
utils/graphics_utils.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import math
|
14 |
+
import numpy as np
|
15 |
+
from typing import NamedTuple
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch.functional import norm
|
19 |
+
|
20 |
+
|
21 |
+
class BasicPointCloud(NamedTuple):
|
22 |
+
points : np.array
|
23 |
+
colors : np.array
|
24 |
+
normals : np.array
|
25 |
+
|
26 |
+
def geom_transform_points(points, transf_matrix):
|
27 |
+
P, _ = points.shape
|
28 |
+
ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
|
29 |
+
points_hom = torch.cat([points, ones], dim=1)
|
30 |
+
points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
|
31 |
+
|
32 |
+
denom = points_out[..., 3:] + 0.0000001
|
33 |
+
return (points_out[..., :3] / denom).squeeze(dim=0)
|
34 |
+
|
35 |
+
def getWorld2View(R, t):
|
36 |
+
Rt = np.zeros((4, 4))
|
37 |
+
Rt[:3, :3] = R.transpose()
|
38 |
+
Rt[:3, 3] = t
|
39 |
+
Rt[3, 3] = 1.0
|
40 |
+
return np.float32(Rt)
|
41 |
+
|
42 |
+
def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
|
43 |
+
Rt = np.zeros((4, 4))
|
44 |
+
Rt[:3, :3] = R.transpose()
|
45 |
+
Rt[:3, 3] = t
|
46 |
+
Rt[3, 3] = 1.0
|
47 |
+
|
48 |
+
C2W = np.linalg.inv(Rt)
|
49 |
+
cam_center = C2W[:3, 3]
|
50 |
+
cam_center = (cam_center + translate) * scale
|
51 |
+
C2W[:3, 3] = cam_center
|
52 |
+
Rt = np.linalg.inv(C2W)
|
53 |
+
return np.float32(Rt)
|
54 |
+
|
55 |
+
def getProjectionMatrix(znear, zfar, fovX, fovY):
|
56 |
+
tanHalfFovY = math.tan((fovY / 2))
|
57 |
+
tanHalfFovX = math.tan((fovX / 2))
|
58 |
+
|
59 |
+
top = tanHalfFovY * znear
|
60 |
+
bottom = -top
|
61 |
+
right = tanHalfFovX * znear
|
62 |
+
left = -right
|
63 |
+
|
64 |
+
P = torch.zeros(4, 4)
|
65 |
+
|
66 |
+
z_sign = 1.0
|
67 |
+
|
68 |
+
P[0, 0] = 2.0 * znear / (right - left)
|
69 |
+
P[1, 1] = 2.0 * znear / (top - bottom)
|
70 |
+
P[0, 2] = (right + left) / (right - left)
|
71 |
+
P[1, 2] = (top + bottom) / (top - bottom)
|
72 |
+
P[3, 2] = z_sign
|
73 |
+
P[2, 2] = z_sign * zfar / (zfar - znear)
|
74 |
+
P[2, 3] = -(zfar * znear) / (zfar - znear)
|
75 |
+
return P
|
76 |
+
|
77 |
+
def fov2focal(fov, pixels):
|
78 |
+
return pixels / (2 * math.tan(fov / 2))
|
79 |
+
|
80 |
+
def focal2fov(focal, pixels):
|
81 |
+
return 2*math.atan(pixels/(2*focal))
|
utils/image_utils.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
|
14 |
+
def mse(img1, img2):
|
15 |
+
return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
|
16 |
+
|
17 |
+
def psnr(img1, img2):
|
18 |
+
mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
|
19 |
+
return 20 * torch.log10(1.0 / torch.sqrt(mse))
|