LIU, Zichen commited on
Commit
1a1aace
1 Parent(s): 78fe60c

update missing files

Browse files
MagicQuill/comfy/ldm/models/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (8.43 kB). View file
 
MagicQuill/comfy/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from contextlib import contextmanager
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
6
+
7
+ from comfy.ldm.util import instantiate_from_config
8
+ from comfy.ldm.modules.ema import LitEma
9
+ import comfy.ops
10
+
11
+ class DiagonalGaussianRegularizer(torch.nn.Module):
12
+ def __init__(self, sample: bool = True):
13
+ super().__init__()
14
+ self.sample = sample
15
+
16
+ def get_trainable_parameters(self) -> Any:
17
+ yield from ()
18
+
19
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
20
+ log = dict()
21
+ posterior = DiagonalGaussianDistribution(z)
22
+ if self.sample:
23
+ z = posterior.sample()
24
+ else:
25
+ z = posterior.mode()
26
+ kl_loss = posterior.kl()
27
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
28
+ log["kl_loss"] = kl_loss
29
+ return z, log
30
+
31
+
32
+ class AbstractAutoencoder(torch.nn.Module):
33
+ """
34
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
35
+ unCLIP models, etc. Hence, it is fairly general, and specific features
36
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ ema_decay: Union[None, float] = None,
42
+ monitor: Union[None, str] = None,
43
+ input_key: str = "jpg",
44
+ **kwargs,
45
+ ):
46
+ super().__init__()
47
+
48
+ self.input_key = input_key
49
+ self.use_ema = ema_decay is not None
50
+ if monitor is not None:
51
+ self.monitor = monitor
52
+
53
+ if self.use_ema:
54
+ self.model_ema = LitEma(self, decay=ema_decay)
55
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
56
+
57
+ def get_input(self, batch) -> Any:
58
+ raise NotImplementedError()
59
+
60
+ def on_train_batch_end(self, *args, **kwargs):
61
+ # for EMA computation
62
+ if self.use_ema:
63
+ self.model_ema(self)
64
+
65
+ @contextmanager
66
+ def ema_scope(self, context=None):
67
+ if self.use_ema:
68
+ self.model_ema.store(self.parameters())
69
+ self.model_ema.copy_to(self)
70
+ if context is not None:
71
+ logpy.info(f"{context}: Switched to EMA weights")
72
+ try:
73
+ yield None
74
+ finally:
75
+ if self.use_ema:
76
+ self.model_ema.restore(self.parameters())
77
+ if context is not None:
78
+ logpy.info(f"{context}: Restored training weights")
79
+
80
+ def encode(self, *args, **kwargs) -> torch.Tensor:
81
+ raise NotImplementedError("encode()-method of abstract base class called")
82
+
83
+ def decode(self, *args, **kwargs) -> torch.Tensor:
84
+ raise NotImplementedError("decode()-method of abstract base class called")
85
+
86
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
87
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
88
+ return get_obj_from_str(cfg["target"])(
89
+ params, lr=lr, **cfg.get("params", dict())
90
+ )
91
+
92
+ def configure_optimizers(self) -> Any:
93
+ raise NotImplementedError()
94
+
95
+
96
+ class AutoencodingEngine(AbstractAutoencoder):
97
+ """
98
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
99
+ (we also restore them explicitly as special cases for legacy reasons).
100
+ Regularizations such as KL or VQ are moved to the regularizer class.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ *args,
106
+ encoder_config: Dict,
107
+ decoder_config: Dict,
108
+ regularizer_config: Dict,
109
+ **kwargs,
110
+ ):
111
+ super().__init__(*args, **kwargs)
112
+
113
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
114
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
115
+ self.regularization: AbstractRegularizer = instantiate_from_config(
116
+ regularizer_config
117
+ )
118
+
119
+ def get_last_layer(self):
120
+ return self.decoder.get_last_layer()
121
+
122
+ def encode(
123
+ self,
124
+ x: torch.Tensor,
125
+ return_reg_log: bool = False,
126
+ unregularized: bool = False,
127
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
128
+ z = self.encoder(x)
129
+ if unregularized:
130
+ return z, dict()
131
+ z, reg_log = self.regularization(z)
132
+ if return_reg_log:
133
+ return z, reg_log
134
+ return z
135
+
136
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
137
+ x = self.decoder(z, **kwargs)
138
+ return x
139
+
140
+ def forward(
141
+ self, x: torch.Tensor, **additional_decode_kwargs
142
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
143
+ z, reg_log = self.encode(x, return_reg_log=True)
144
+ dec = self.decode(z, **additional_decode_kwargs)
145
+ return z, dec, reg_log
146
+
147
+
148
+ class AutoencodingEngineLegacy(AutoencodingEngine):
149
+ def __init__(self, embed_dim: int, **kwargs):
150
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
151
+ ddconfig = kwargs.pop("ddconfig")
152
+ super().__init__(
153
+ encoder_config={
154
+ "target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
155
+ "params": ddconfig,
156
+ },
157
+ decoder_config={
158
+ "target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
159
+ "params": ddconfig,
160
+ },
161
+ **kwargs,
162
+ )
163
+ self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
164
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
165
+ (1 + ddconfig["double_z"]) * embed_dim,
166
+ 1,
167
+ )
168
+ self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
169
+ self.embed_dim = embed_dim
170
+
171
+ def get_autoencoder_params(self) -> list:
172
+ params = super().get_autoencoder_params()
173
+ return params
174
+
175
+ def encode(
176
+ self, x: torch.Tensor, return_reg_log: bool = False
177
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
178
+ if self.max_batch_size is None:
179
+ z = self.encoder(x)
180
+ z = self.quant_conv(z)
181
+ else:
182
+ N = x.shape[0]
183
+ bs = self.max_batch_size
184
+ n_batches = int(math.ceil(N / bs))
185
+ z = list()
186
+ for i_batch in range(n_batches):
187
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
188
+ z_batch = self.quant_conv(z_batch)
189
+ z.append(z_batch)
190
+ z = torch.cat(z, 0)
191
+
192
+ z, reg_log = self.regularization(z)
193
+ if return_reg_log:
194
+ return z, reg_log
195
+ return z
196
+
197
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
198
+ if self.max_batch_size is None:
199
+ dec = self.post_quant_conv(z)
200
+ dec = self.decoder(dec, **decoder_kwargs)
201
+ else:
202
+ N = z.shape[0]
203
+ bs = self.max_batch_size
204
+ n_batches = int(math.ceil(N / bs))
205
+ dec = list()
206
+ for i_batch in range(n_batches):
207
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
208
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
209
+ dec.append(dec_batch)
210
+ dec = torch.cat(dec, 0)
211
+
212
+ return dec
213
+
214
+
215
+ class AutoencoderKL(AutoencodingEngineLegacy):
216
+ def __init__(self, **kwargs):
217
+ if "lossconfig" in kwargs:
218
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
219
+ super().__init__(
220
+ regularizer_config={
221
+ "target": (
222
+ "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
223
+ )
224
+ },
225
+ **kwargs,
226
+ )