abc committed on
Commit 8999a58 · 1 Parent(s): 497a25a

Delete locon

locon/__init__.py DELETED
File without changes
locon/kohya_model_utils.py DELETED
@@ -1,1184 +0,0 @@
1
- '''
2
- https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
3
- '''
4
- # v1: split from train_db_fixed.py.
5
- # v2: support safetensors
6
-
7
- import math
8
- import os
9
- import torch
10
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
11
- from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
12
- from safetensors.torch import load_file, save_file
13
-
14
- # model parameters of the Diffusers version of Stable Diffusion
15
- NUM_TRAIN_TIMESTEPS = 1000
16
- BETA_START = 0.00085
17
- BETA_END = 0.0120
18
-
19
- UNET_PARAMS_MODEL_CHANNELS = 320
20
- UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
21
- UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
22
- UNET_PARAMS_IMAGE_SIZE = 32 # unused
23
- UNET_PARAMS_IN_CHANNELS = 4
24
- UNET_PARAMS_OUT_CHANNELS = 4
25
- UNET_PARAMS_NUM_RES_BLOCKS = 2
26
- UNET_PARAMS_CONTEXT_DIM = 768
27
- UNET_PARAMS_NUM_HEADS = 8
28
-
29
- VAE_PARAMS_Z_CHANNELS = 4
30
- VAE_PARAMS_RESOLUTION = 256
31
- VAE_PARAMS_IN_CHANNELS = 3
32
- VAE_PARAMS_OUT_CH = 3
33
- VAE_PARAMS_CH = 128
34
- VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
35
- VAE_PARAMS_NUM_RES_BLOCKS = 2
36
-
37
- # V2
38
- V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
39
- V2_UNET_PARAMS_CONTEXT_DIM = 1024
40
-
41
- # reference models for loading the Diffusers configuration
42
- DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
43
- DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
44
-
45
-
46
- # region StableDiffusion -> Diffusers conversion code
47
- # copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
48
-
49
-
50
- def shave_segments(path, n_shave_prefix_segments=1):
51
- """
52
- Removes segments. Positive values shave the first segments, negative shave the last segments.
53
- """
54
- if n_shave_prefix_segments >= 0:
55
- return ".".join(path.split(".")[n_shave_prefix_segments:])
56
- else:
57
- return ".".join(path.split(".")[:n_shave_prefix_segments])
58
-
59
-
60
- def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
61
- """
62
- Updates paths inside resnets to the new naming scheme (local renaming)
63
- """
64
- mapping = []
65
- for old_item in old_list:
66
- new_item = old_item.replace("in_layers.0", "norm1")
67
- new_item = new_item.replace("in_layers.2", "conv1")
68
-
69
- new_item = new_item.replace("out_layers.0", "norm2")
70
- new_item = new_item.replace("out_layers.3", "conv2")
71
-
72
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
73
- new_item = new_item.replace("skip_connection", "conv_shortcut")
74
-
75
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
76
-
77
- mapping.append({"old": old_item, "new": new_item})
78
-
79
- return mapping
80
-
81
-
82
- def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
83
- """
84
- Updates paths inside resnets to the new naming scheme (local renaming)
85
- """
86
- mapping = []
87
- for old_item in old_list:
88
- new_item = old_item
89
-
90
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
91
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
92
-
93
- mapping.append({"old": old_item, "new": new_item})
94
-
95
- return mapping
96
-
97
-
98
- def renew_attention_paths(old_list, n_shave_prefix_segments=0):
99
- """
100
- Updates paths inside attentions to the new naming scheme (local renaming)
101
- """
102
- mapping = []
103
- for old_item in old_list:
104
- new_item = old_item
105
-
106
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
107
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
108
-
109
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
110
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
111
-
112
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
113
-
114
- mapping.append({"old": old_item, "new": new_item})
115
-
116
- return mapping
117
-
118
-
119
- def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
120
- """
121
- Updates paths inside attentions to the new naming scheme (local renaming)
122
- """
123
- mapping = []
124
- for old_item in old_list:
125
- new_item = old_item
126
-
127
- new_item = new_item.replace("norm.weight", "group_norm.weight")
128
- new_item = new_item.replace("norm.bias", "group_norm.bias")
129
-
130
- new_item = new_item.replace("q.weight", "query.weight")
131
- new_item = new_item.replace("q.bias", "query.bias")
132
-
133
- new_item = new_item.replace("k.weight", "key.weight")
134
- new_item = new_item.replace("k.bias", "key.bias")
135
-
136
- new_item = new_item.replace("v.weight", "value.weight")
137
- new_item = new_item.replace("v.bias", "value.bias")
138
-
139
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
140
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
141
-
142
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
143
-
144
- mapping.append({"old": old_item, "new": new_item})
145
-
146
- return mapping
147
-
148
-
149
- def assign_to_checkpoint(
150
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
151
- ):
152
- """
153
- This does the final conversion step: take locally converted weights and apply a global renaming
154
- to them. It splits attention layers, and takes into account additional replacements
155
- that may arise.
156
-
157
- Assigns the weights to the new checkpoint.
158
- """
159
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
160
-
161
- # Splits the attention layers into three variables.
162
- if attention_paths_to_split is not None:
163
- for path, path_map in attention_paths_to_split.items():
164
- old_tensor = old_checkpoint[path]
165
- channels = old_tensor.shape[0] // 3
166
-
167
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
168
-
169
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
170
-
171
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
172
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
173
-
174
- checkpoint[path_map["query"]] = query.reshape(target_shape)
175
- checkpoint[path_map["key"]] = key.reshape(target_shape)
176
- checkpoint[path_map["value"]] = value.reshape(target_shape)
177
-
178
- for path in paths:
179
- new_path = path["new"]
180
-
181
- # These have already been assigned
182
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
183
- continue
184
-
185
- # Global renaming happens here
186
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
187
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
188
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
189
-
190
- if additional_replacements is not None:
191
- for replacement in additional_replacements:
192
- new_path = new_path.replace(replacement["old"], replacement["new"])
193
-
194
- # proj_attn.weight has to be converted from conv 1D to linear
195
- if "proj_attn.weight" in new_path:
196
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
197
- else:
198
- checkpoint[new_path] = old_checkpoint[path["old"]]
199
-
200
-
201
- def conv_attn_to_linear(checkpoint):
202
- keys = list(checkpoint.keys())
203
- attn_keys = ["query.weight", "key.weight", "value.weight"]
204
- for key in keys:
205
- if ".".join(key.split(".")[-2:]) in attn_keys:
206
- if checkpoint[key].ndim > 2:
207
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
208
- elif "proj_attn.weight" in key:
209
- if checkpoint[key].ndim > 2:
210
- checkpoint[key] = checkpoint[key][:, :, 0]
211
-
212
-
213
- def linear_transformer_to_conv(checkpoint):
214
- keys = list(checkpoint.keys())
215
- tf_keys = ["proj_in.weight", "proj_out.weight"]
216
- for key in keys:
217
- if ".".join(key.split(".")[-2:]) in tf_keys:
218
- if checkpoint[key].ndim == 2:
219
- checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
220
-
221
-
222
- def convert_ldm_unet_checkpoint(v2, checkpoint, config):
223
- """
224
- Takes a state dict and a config, and returns a converted checkpoint.
225
- """
226
-
227
- # extract state_dict for UNet
228
- unet_state_dict = {}
229
- unet_key = "model.diffusion_model."
230
- keys = list(checkpoint.keys())
231
- for key in keys:
232
- if key.startswith(unet_key):
233
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
234
-
235
- new_checkpoint = {}
236
-
237
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
238
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
239
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
240
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
241
-
242
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
243
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
244
-
245
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
246
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
247
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
248
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
249
-
250
- # Retrieves the keys for the input blocks only
251
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
252
- input_blocks = {
253
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
254
- for layer_id in range(num_input_blocks)
255
- }
256
-
257
- # Retrieves the keys for the middle blocks only
258
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
259
- middle_blocks = {
260
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
261
- for layer_id in range(num_middle_blocks)
262
- }
263
-
264
- # Retrieves the keys for the output blocks only
265
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
266
- output_blocks = {
267
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
268
- for layer_id in range(num_output_blocks)
269
- }
270
-
271
- for i in range(1, num_input_blocks):
272
- block_id = (i - 1) // (config["layers_per_block"] + 1)
273
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
274
-
275
- resnets = [
276
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
277
- ]
278
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
279
-
280
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
281
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
282
- f"input_blocks.{i}.0.op.weight"
283
- )
284
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
285
- f"input_blocks.{i}.0.op.bias"
286
- )
287
-
288
- paths = renew_resnet_paths(resnets)
289
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
290
- assign_to_checkpoint(
291
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
292
- )
293
-
294
- if len(attentions):
295
- paths = renew_attention_paths(attentions)
296
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
297
- assign_to_checkpoint(
298
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
299
- )
300
-
301
- resnet_0 = middle_blocks[0]
302
- attentions = middle_blocks[1]
303
- resnet_1 = middle_blocks[2]
304
-
305
- resnet_0_paths = renew_resnet_paths(resnet_0)
306
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
307
-
308
- resnet_1_paths = renew_resnet_paths(resnet_1)
309
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
310
-
311
- attentions_paths = renew_attention_paths(attentions)
312
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
313
- assign_to_checkpoint(
314
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
315
- )
316
-
317
- for i in range(num_output_blocks):
318
- block_id = i // (config["layers_per_block"] + 1)
319
- layer_in_block_id = i % (config["layers_per_block"] + 1)
320
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
321
- output_block_list = {}
322
-
323
- for layer in output_block_layers:
324
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
325
- if layer_id in output_block_list:
326
- output_block_list[layer_id].append(layer_name)
327
- else:
328
- output_block_list[layer_id] = [layer_name]
329
-
330
- if len(output_block_list) > 1:
331
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
332
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
333
-
334
- resnet_0_paths = renew_resnet_paths(resnets)
335
- paths = renew_resnet_paths(resnets)
336
-
337
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
338
- assign_to_checkpoint(
339
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
340
- )
341
-
342
- # original:
343
- # if ["conv.weight", "conv.bias"] in output_block_list.values():
344
- # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
345
-
346
- # avoid depending on the order of bias and weight; there is probably a better way to do this
347
- for l in output_block_list.values():
348
- l.sort()
349
-
350
- if ["conv.bias", "conv.weight"] in output_block_list.values():
351
- index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
352
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
353
- f"output_blocks.{i}.{index}.conv.bias"
354
- ]
355
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
356
- f"output_blocks.{i}.{index}.conv.weight"
357
- ]
358
-
359
- # Clear attentions as they have been attributed above.
360
- if len(attentions) == 2:
361
- attentions = []
362
-
363
- if len(attentions):
364
- paths = renew_attention_paths(attentions)
365
- meta_path = {
366
- "old": f"output_blocks.{i}.1",
367
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
368
- }
369
- assign_to_checkpoint(
370
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
371
- )
372
- else:
373
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
374
- for path in resnet_0_paths:
375
- old_path = ".".join(["output_blocks", str(i), path["old"]])
376
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
377
-
378
- new_checkpoint[new_path] = unet_state_dict[old_path]
379
-
380
- # in SD v2 the 1x1 conv2d layers are replaced by linear layers, so convert linear -> conv
381
- if v2:
382
- linear_transformer_to_conv(new_checkpoint)
383
-
384
- return new_checkpoint
385
-
386
-
387
- def convert_ldm_vae_checkpoint(checkpoint, config):
388
- # extract state dict for VAE
389
- vae_state_dict = {}
390
- vae_key = "first_stage_model."
391
- keys = list(checkpoint.keys())
392
- for key in keys:
393
- if key.startswith(vae_key):
394
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
395
- # if len(vae_state_dict) == 0:
396
- # # the passed checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt
397
- # vae_state_dict = checkpoint
398
-
399
- new_checkpoint = {}
400
-
401
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
402
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
403
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
404
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
405
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
406
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
407
-
408
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
409
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
410
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
411
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
412
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
413
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
414
-
415
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
416
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
417
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
418
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
419
-
420
- # Retrieves the keys for the encoder down blocks only
421
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
422
- down_blocks = {
423
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
424
- }
425
-
426
- # Retrieves the keys for the decoder up blocks only
427
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
428
- up_blocks = {
429
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
430
- }
431
-
432
- for i in range(num_down_blocks):
433
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
434
-
435
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
436
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
437
- f"encoder.down.{i}.downsample.conv.weight"
438
- )
439
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
440
- f"encoder.down.{i}.downsample.conv.bias"
441
- )
442
-
443
- paths = renew_vae_resnet_paths(resnets)
444
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
445
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
446
-
447
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
448
- num_mid_res_blocks = 2
449
- for i in range(1, num_mid_res_blocks + 1):
450
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
451
-
452
- paths = renew_vae_resnet_paths(resnets)
453
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
454
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
455
-
456
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
457
- paths = renew_vae_attention_paths(mid_attentions)
458
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
459
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
460
- conv_attn_to_linear(new_checkpoint)
461
-
462
- for i in range(num_up_blocks):
463
- block_id = num_up_blocks - 1 - i
464
- resnets = [
465
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
466
- ]
467
-
468
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
469
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
470
- f"decoder.up.{block_id}.upsample.conv.weight"
471
- ]
472
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
473
- f"decoder.up.{block_id}.upsample.conv.bias"
474
- ]
475
-
476
- paths = renew_vae_resnet_paths(resnets)
477
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
478
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
479
-
480
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
481
- num_mid_res_blocks = 2
482
- for i in range(1, num_mid_res_blocks + 1):
483
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
484
-
485
- paths = renew_vae_resnet_paths(resnets)
486
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
487
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
488
-
489
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
490
- paths = renew_vae_attention_paths(mid_attentions)
491
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
492
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
493
- conv_attn_to_linear(new_checkpoint)
494
- return new_checkpoint
495
-
496
-
497
- def create_unet_diffusers_config(v2):
498
- """
499
- Creates a config for the diffusers based on the config of the LDM model.
500
- """
501
- # unet_params = original_config.model.params.unet_config.params
502
-
503
- block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
504
-
505
- down_block_types = []
506
- resolution = 1
507
- for i in range(len(block_out_channels)):
508
- block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
509
- down_block_types.append(block_type)
510
- if i != len(block_out_channels) - 1:
511
- resolution *= 2
512
-
513
- up_block_types = []
514
- for i in range(len(block_out_channels)):
515
- block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
516
- up_block_types.append(block_type)
517
- resolution //= 2
518
-
519
- config = dict(
520
- sample_size=UNET_PARAMS_IMAGE_SIZE,
521
- in_channels=UNET_PARAMS_IN_CHANNELS,
522
- out_channels=UNET_PARAMS_OUT_CHANNELS,
523
- down_block_types=tuple(down_block_types),
524
- up_block_types=tuple(up_block_types),
525
- block_out_channels=tuple(block_out_channels),
526
- layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
527
- cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
528
- attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
529
- )
530
-
531
- return config
532
-
533
-
534
- def create_vae_diffusers_config():
535
- """
536
- Creates a config for the diffusers based on the config of the LDM model.
537
- """
538
- # vae_params = original_config.model.params.first_stage_config.params.ddconfig
539
- # _ = original_config.model.params.first_stage_config.params.embed_dim
540
- block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
541
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
542
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
543
-
544
- config = dict(
545
- sample_size=VAE_PARAMS_RESOLUTION,
546
- in_channels=VAE_PARAMS_IN_CHANNELS,
547
- out_channels=VAE_PARAMS_OUT_CH,
548
- down_block_types=tuple(down_block_types),
549
- up_block_types=tuple(up_block_types),
550
- block_out_channels=tuple(block_out_channels),
551
- latent_channels=VAE_PARAMS_Z_CHANNELS,
552
- layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
553
- )
554
- return config
555
-
556
-
557
- def convert_ldm_clip_checkpoint_v1(checkpoint):
558
- keys = list(checkpoint.keys())
559
- text_model_dict = {}
560
- for key in keys:
561
- if key.startswith("cond_stage_model.transformer"):
562
- text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
563
- return text_model_dict
564
-
565
-
566
- def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
567
- # the formats are annoyingly different!
568
- def convert_key(key):
569
- if not key.startswith("cond_stage_model"):
570
- return None
571
-
572
- # common conversion
573
- key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
574
- key = key.replace("cond_stage_model.model.", "text_model.")
575
-
576
- if "resblocks" in key:
577
- # resblocks conversion
578
- key = key.replace(".resblocks.", ".layers.")
579
- if ".ln_" in key:
580
- key = key.replace(".ln_", ".layer_norm")
581
- elif ".mlp." in key:
582
- key = key.replace(".c_fc.", ".fc1.")
583
- key = key.replace(".c_proj.", ".fc2.")
584
- elif '.attn.out_proj' in key:
585
- key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
586
- elif '.attn.in_proj' in key:
587
- key = None # special case, handled separately below
588
- else:
589
- raise ValueError(f"unexpected key in SD: {key}")
590
- elif '.positional_embedding' in key:
591
- key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
592
- elif '.text_projection' in key:
593
- key = None # not used???
594
- elif '.logit_scale' in key:
595
- key = None # not used???
596
- elif '.token_embedding' in key:
597
- key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
598
- elif '.ln_final' in key:
599
- key = key.replace(".ln_final", ".final_layer_norm")
600
- return key
601
-
602
- keys = list(checkpoint.keys())
603
- new_sd = {}
604
- for key in keys:
605
- # remove resblocks 23
606
- if '.resblocks.23.' in key:
607
- continue
608
- new_key = convert_key(key)
609
- if new_key is None:
610
- continue
611
- new_sd[new_key] = checkpoint[key]
612
-
613
- # convert attn weights
614
- for key in keys:
615
- if '.resblocks.23.' in key:
616
- continue
617
- if '.resblocks' in key and '.attn.in_proj_' in key:
618
- # split into three
619
- values = torch.chunk(checkpoint[key], 3)
620
-
621
- key_suffix = ".weight" if "weight" in key else ".bias"
622
- key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
623
- key_pfx = key_pfx.replace("_weight", "")
624
- key_pfx = key_pfx.replace("_bias", "")
625
- key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
626
- new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
627
- new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
628
- new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
629
-
630
- # rename or add position_ids
631
- ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
632
- if ANOTHER_POSITION_IDS_KEY in new_sd:
633
- # waifu diffusion v1.4
634
- position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
635
- del new_sd[ANOTHER_POSITION_IDS_KEY]
636
- else:
637
- position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
638
-
639
- new_sd["text_model.embeddings.position_ids"] = position_ids
640
- return new_sd
641
-
642
- # endregion
643
-
644
-
645
- # region Diffusers -> StableDiffusion conversion code
646
- # copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
647
-
648
- def conv_transformer_to_linear(checkpoint):
649
- keys = list(checkpoint.keys())
650
- tf_keys = ["proj_in.weight", "proj_out.weight"]
651
- for key in keys:
652
- if ".".join(key.split(".")[-2:]) in tf_keys:
653
- if checkpoint[key].ndim > 2:
654
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
655
-
656
-
657
- def convert_unet_state_dict_to_sd(v2, unet_state_dict):
658
- unet_conversion_map = [
659
- # (stable-diffusion, HF Diffusers)
660
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
661
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
662
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
663
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
664
- ("input_blocks.0.0.weight", "conv_in.weight"),
665
- ("input_blocks.0.0.bias", "conv_in.bias"),
666
- ("out.0.weight", "conv_norm_out.weight"),
667
- ("out.0.bias", "conv_norm_out.bias"),
668
- ("out.2.weight", "conv_out.weight"),
669
- ("out.2.bias", "conv_out.bias"),
670
- ]
671
-
672
- unet_conversion_map_resnet = [
673
- # (stable-diffusion, HF Diffusers)
674
- ("in_layers.0", "norm1"),
675
- ("in_layers.2", "conv1"),
676
- ("out_layers.0", "norm2"),
677
- ("out_layers.3", "conv2"),
678
- ("emb_layers.1", "time_emb_proj"),
679
- ("skip_connection", "conv_shortcut"),
680
- ]
681
-
682
- unet_conversion_map_layer = []
683
- for i in range(4):
684
- # loop over downblocks/upblocks
685
-
686
- for j in range(2):
687
- # loop over resnets/attentions for downblocks
688
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
689
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
690
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
691
-
692
- if i < 3:
693
- # no attention layers in down_blocks.3
694
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
695
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
696
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
697
-
698
- for j in range(3):
699
- # loop over resnets/attentions for upblocks
700
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
701
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
702
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
703
-
704
- if i > 0:
705
- # no attention layers in up_blocks.0
706
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
707
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
708
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
709
-
710
- if i < 3:
711
- # no downsample in down_blocks.3
712
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
713
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
714
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
715
-
716
- # no upsample in up_blocks.3
717
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
718
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
719
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
720
-
721
- hf_mid_atn_prefix = "mid_block.attentions.0."
722
- sd_mid_atn_prefix = "middle_block.1."
723
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
724
-
725
- for j in range(2):
726
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
727
- sd_mid_res_prefix = f"middle_block.{2*j}."
728
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
729
-
730
- # buyer beware: this is a *brittle* function,
731
- # and correct output requires that all of these pieces interact in
732
- # the exact order in which I have arranged them.
733
- mapping = {k: k for k in unet_state_dict.keys()}
734
- for sd_name, hf_name in unet_conversion_map:
735
- mapping[hf_name] = sd_name
736
- for k, v in mapping.items():
737
- if "resnets" in k:
738
- for sd_part, hf_part in unet_conversion_map_resnet:
739
- v = v.replace(hf_part, sd_part)
740
- mapping[k] = v
741
- for k, v in mapping.items():
742
- for sd_part, hf_part in unet_conversion_map_layer:
743
- v = v.replace(hf_part, sd_part)
744
- mapping[k] = v
745
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
746
-
747
- if v2:
748
- conv_transformer_to_linear(new_state_dict)
749
-
750
- return new_state_dict
751
-
752
-
753
- # ================#
754
- # VAE Conversion #
755
- # ================#
756
-
757
- def reshape_weight_for_sd(w):
758
- # convert HF linear weights to SD conv2d weights
759
- return w.reshape(*w.shape, 1, 1)
760
-
761
-
762
- def convert_vae_state_dict(vae_state_dict):
763
- vae_conversion_map = [
764
- # (stable-diffusion, HF Diffusers)
765
- ("nin_shortcut", "conv_shortcut"),
766
- ("norm_out", "conv_norm_out"),
767
- ("mid.attn_1.", "mid_block.attentions.0."),
768
- ]
769
-
770
- for i in range(4):
771
- # down_blocks have two resnets
772
- for j in range(2):
773
- hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
774
- sd_down_prefix = f"encoder.down.{i}.block.{j}."
775
- vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
776
-
777
- if i < 3:
778
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
779
- sd_downsample_prefix = f"down.{i}.downsample."
780
- vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
781
-
782
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
783
- sd_upsample_prefix = f"up.{3-i}.upsample."
784
- vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
785
-
786
- # up_blocks have three resnets
787
- # also, up blocks in hf are numbered in reverse from sd
788
- for j in range(3):
789
- hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
790
- sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
791
- vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
792
-
793
- # this part accounts for mid blocks in both the encoder and the decoder
794
- for i in range(2):
795
- hf_mid_res_prefix = f"mid_block.resnets.{i}."
796
- sd_mid_res_prefix = f"mid.block_{i+1}."
797
- vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
798
-
799
- vae_conversion_map_attn = [
800
- # (stable-diffusion, HF Diffusers)
801
- ("norm.", "group_norm."),
802
- ("q.", "query."),
803
- ("k.", "key."),
804
- ("v.", "value."),
805
- ("proj_out.", "proj_attn."),
806
- ]
807
-
808
- mapping = {k: k for k in vae_state_dict.keys()}
809
- for k, v in mapping.items():
810
- for sd_part, hf_part in vae_conversion_map:
811
- v = v.replace(hf_part, sd_part)
812
- mapping[k] = v
813
- for k, v in mapping.items():
814
- if "attentions" in k:
815
- for sd_part, hf_part in vae_conversion_map_attn:
816
- v = v.replace(hf_part, sd_part)
817
- mapping[k] = v
818
- new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
819
- weights_to_convert = ["q", "k", "v", "proj_out"]
820
- for k, v in new_state_dict.items():
821
- for weight_name in weights_to_convert:
822
- if f"mid.attn_1.{weight_name}.weight" in k:
823
- # print(f"Reshaping {k} for SD format")
824
- new_state_dict[k] = reshape_weight_for_sd(v)
825
-
826
- return new_state_dict
827
-
828
-
829
- # endregion
830
-
831
- # region custom model loading/saving helpers
832
-
833
- def is_safetensors(path):
834
- return os.path.splitext(path)[1].lower() == '.safetensors'
835
-
836
-
837
- def load_checkpoint_with_text_encoder_conversion(ckpt_path):
838
- # support models whose text encoder is stored in a different layout (without 'text_model')
839
- TEXT_ENCODER_KEY_REPLACEMENTS = [
840
- ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
841
- ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
842
- ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
843
- ]
844
-
845
- if is_safetensors(ckpt_path):
846
- checkpoint = None
847
- state_dict = load_file(ckpt_path, "cpu")
848
- else:
849
- checkpoint = torch.load(ckpt_path, map_location="cpu")
850
- if "state_dict" in checkpoint:
851
- state_dict = checkpoint["state_dict"]
852
- else:
853
- state_dict = checkpoint
854
- checkpoint = None
855
-
856
- key_reps = []
857
- for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
858
- for key in state_dict.keys():
859
- if key.startswith(rep_from):
860
- new_key = rep_to + key[len(rep_from):]
861
- key_reps.append((key, new_key))
862
-
863
- for key, new_key in key_reps:
864
- state_dict[new_key] = state_dict[key]
865
- del state_dict[key]
866
-
867
- return checkpoint, state_dict
868
-
869
-
870
- # TODO the dtype handling looks questionable and should be checked; untested whether the text_encoder can be created with the specified dtype
871
- def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
872
- _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
873
- if dtype is not None:
874
- for k, v in state_dict.items():
875
- if type(v) is torch.Tensor:
876
- state_dict[k] = v.to(dtype)
877
-
878
- # Convert the UNet2DConditionModel model.
879
- unet_config = create_unet_diffusers_config(v2)
880
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
881
-
882
- unet = UNet2DConditionModel(**unet_config)
883
- info = unet.load_state_dict(converted_unet_checkpoint)
884
- print("loading u-net:", info)
885
-
886
- # Convert the VAE model.
887
- vae_config = create_vae_diffusers_config()
888
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
889
-
890
- vae = AutoencoderKL(**vae_config)
891
- info = vae.load_state_dict(converted_vae_checkpoint)
892
- print("loading vae:", info)
893
-
894
- # convert text_model
895
- if v2:
896
- converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
897
- cfg = CLIPTextConfig(
898
- vocab_size=49408,
899
- hidden_size=1024,
900
- intermediate_size=4096,
901
- num_hidden_layers=23,
902
- num_attention_heads=16,
903
- max_position_embeddings=77,
904
- hidden_act="gelu",
905
- layer_norm_eps=1e-05,
906
- dropout=0.0,
907
- attention_dropout=0.0,
908
- initializer_range=0.02,
909
- initializer_factor=1.0,
910
- pad_token_id=1,
911
- bos_token_id=0,
912
- eos_token_id=2,
913
- model_type="clip_text_model",
914
- projection_dim=512,
915
- torch_dtype="float32",
916
- transformers_version="4.25.0.dev0",
917
- )
918
- text_model = CLIPTextModel._from_config(cfg)
919
- info = text_model.load_state_dict(converted_text_encoder_checkpoint)
920
- else:
921
- converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
922
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
923
- info = text_model.load_state_dict(converted_text_encoder_checkpoint)
924
- print("loading text encoder:", info)
925
-
926
- return text_model, vae, unet
927
-
928
-
929
- def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
930
- def convert_key(key):
931
- # remove position_ids
932
- if ".position_ids" in key:
933
- return None
934
-
935
- # common
936
- key = key.replace("text_model.encoder.", "transformer.")
937
- key = key.replace("text_model.", "")
938
- if "layers" in key:
939
- # resblocks conversion
940
- key = key.replace(".layers.", ".resblocks.")
941
- if ".layer_norm" in key:
942
- key = key.replace(".layer_norm", ".ln_")
943
- elif ".mlp." in key:
944
- key = key.replace(".fc1.", ".c_fc.")
945
- key = key.replace(".fc2.", ".c_proj.")
946
- elif '.self_attn.out_proj' in key:
947
- key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
948
- elif '.self_attn.' in key:
949
- key = None # special case, handled separately below
950
- else:
951
- raise ValueError(f"unexpected key in DiffUsers model: {key}")
952
- elif '.position_embedding' in key:
953
- key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
954
- elif '.token_embedding' in key:
955
- key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
956
- elif 'final_layer_norm' in key:
957
- key = key.replace("final_layer_norm", "ln_final")
958
- return key
959
-
960
- keys = list(checkpoint.keys())
961
- new_sd = {}
962
- for key in keys:
963
- new_key = convert_key(key)
964
- if new_key is None:
965
- continue
966
- new_sd[new_key] = checkpoint[key]
967
-
968
- # convert attn weights
969
- for key in keys:
970
- if 'layers' in key and 'q_proj' in key:
971
- # concatenate the three
972
- key_q = key
973
- key_k = key.replace("q_proj", "k_proj")
974
- key_v = key.replace("q_proj", "v_proj")
975
-
976
- value_q = checkpoint[key_q]
977
- value_k = checkpoint[key_k]
978
- value_v = checkpoint[key_v]
979
- value = torch.cat([value_q, value_k, value_v])
980
-
981
- new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
982
- new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
983
- new_sd[new_key] = value
984
-
985
- # whether to fabricate the final layer and other missing weights
986
- if make_dummy_weights:
987
- print("make dummy weights for resblock.23, text_projection and logit scale.")
988
- keys = list(new_sd.keys())
989
- for key in keys:
990
- if key.startswith("transformer.resblocks.22."):
991
- new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # without cloning, saving with safetensors fails
992
-
993
- # create weights that are not included in the Diffusers model
994
- new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
995
- new_sd['logit_scale'] = torch.tensor(1)
996
-
997
- return new_sd
998
-
999
-
1000
- def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
1001
- if ckpt_path is not None:
1002
- # read epoch/step from it; also reload the checkpoint including the VAE, e.g. when the VAE is not in memory
1003
- checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1004
- if checkpoint is None: # safetensors, or a ckpt that is only a state_dict
1005
- checkpoint = {}
1006
- strict = False
1007
- else:
1008
- strict = True
1009
- if "state_dict" in state_dict:
1010
- del state_dict["state_dict"]
1011
- else:
1012
- # build a new checkpoint
1013
- assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1014
- checkpoint = {}
1015
- state_dict = {}
1016
- strict = False
1017
-
1018
- def update_sd(prefix, sd):
1019
- for k, v in sd.items():
1020
- key = prefix + k
1021
- assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1022
- if save_dtype is not None:
1023
- v = v.detach().clone().to("cpu").to(save_dtype)
1024
- state_dict[key] = v
1025
-
1026
- # Convert the UNet model
1027
- unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1028
- update_sd("model.diffusion_model.", unet_state_dict)
1029
-
1030
- # Convert the text encoder model
1031
- if v2:
1032
- make_dummy = ckpt_path is None # if there is no source checkpoint, insert dummy weights, e.g. by duplicating the final layer from the previous one
1033
- text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1034
- update_sd("cond_stage_model.model.", text_enc_dict)
1035
- else:
1036
- text_enc_dict = text_encoder.state_dict()
1037
- update_sd("cond_stage_model.transformer.", text_enc_dict)
1038
-
1039
- # Convert the VAE
1040
- if vae is not None:
1041
- vae_dict = convert_vae_state_dict(vae.state_dict())
1042
- update_sd("first_stage_model.", vae_dict)
1043
-
1044
- # Put together new checkpoint
1045
- key_count = len(state_dict.keys())
1046
- new_ckpt = {'state_dict': state_dict}
1047
-
1048
- if 'epoch' in checkpoint:
1049
- epochs += checkpoint['epoch']
1050
- if 'global_step' in checkpoint:
1051
- steps += checkpoint['global_step']
1052
-
1053
- new_ckpt['epoch'] = epochs
1054
- new_ckpt['global_step'] = steps
1055
-
1056
- if is_safetensors(output_file):
1057
- # TODO consider removing non-Tensor values from the dict
1058
- save_file(state_dict, output_file)
1059
- else:
1060
- torch.save(new_ckpt, output_file)
1061
-
1062
- return key_count
1063
-
1064
-
1065
- def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1066
- if pretrained_model_name_or_path is None:
1067
- # load default settings for v1/v2
1068
- if v2:
1069
- pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1070
- else:
1071
- pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1072
-
1073
- scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1074
- tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1075
- if vae is None:
1076
- vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1077
-
1078
- pipeline = StableDiffusionPipeline(
1079
- unet=unet,
1080
- text_encoder=text_encoder,
1081
- vae=vae,
1082
- scheduler=scheduler,
1083
- tokenizer=tokenizer,
1084
- safety_checker=None,
1085
- feature_extractor=None,
1086
- requires_safety_checker=None,
1087
- )
1088
- pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1089
-
1090
-
1091
- VAE_PREFIX = "first_stage_model."
1092
-
1093
-
1094
- def load_vae(vae_id, dtype):
1095
- print(f"load VAE: {vae_id}")
1096
- if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1097
- # Diffusers local/remote
1098
- try:
1099
- vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1100
- except EnvironmentError as e:
1101
- print(f"exception occurs in loading vae: {e}")
1102
- print("retry with subfolder='vae'")
1103
- vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1104
- return vae
1105
-
1106
- # local
1107
- vae_config = create_vae_diffusers_config()
1108
-
1109
- if vae_id.endswith(".bin"):
1110
- # SD 1.5 VAE on Huggingface
1111
- converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1112
- else:
1113
- # StableDiffusion
1114
- vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
1115
- else torch.load(vae_id, map_location="cpu"))
1116
- vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
1117
-
1118
- # vae only or full model
1119
- full_model = False
1120
- for vae_key in vae_sd:
1121
- if vae_key.startswith(VAE_PREFIX):
1122
- full_model = True
1123
- break
1124
- if not full_model:
1125
- sd = {}
1126
- for key, value in vae_sd.items():
1127
- sd[VAE_PREFIX + key] = value
1128
- vae_sd = sd
1129
- del sd
1130
-
1131
- # Convert the VAE model.
1132
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1133
-
1134
- vae = AutoencoderKL(**vae_config)
1135
- vae.load_state_dict(converted_vae_checkpoint)
1136
- return vae
1137
-
1138
- # endregion
1139
-
1140
-
1141
- def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1142
- max_width, max_height = max_reso
1143
- max_area = (max_width // divisible) * (max_height // divisible)
1144
-
1145
- resos = set()
1146
-
1147
- size = int(math.sqrt(max_area)) * divisible
1148
- resos.add((size, size))
1149
-
1150
- size = min_size
1151
- while size <= max_size:
1152
- width = size
1153
- height = min(max_size, (max_area // (width // divisible)) * divisible)
1154
- resos.add((width, height))
1155
- resos.add((height, width))
1156
-
1157
- # # make additional resos
1158
- # if width >= height and width - divisible >= min_size:
1159
- # resos.add((width - divisible, height))
1160
- # resos.add((height, width - divisible))
1161
- # if height >= width and height - divisible >= min_size:
1162
- # resos.add((width, height - divisible))
1163
- # resos.add((height - divisible, width))
1164
-
1165
- size += divisible
1166
-
1167
- resos = list(resos)
1168
- resos.sort()
1169
-
1170
- aspect_ratios = [w / h for w, h in resos]
1171
- return resos, aspect_ratios
1172
-
1173
-
1174
- if __name__ == '__main__':
1175
- resos, aspect_ratios = make_bucket_resolutions((512, 768))
1176
- print(len(resos))
1177
- print(resos)
1178
- print(aspect_ratios)
1179
-
1180
- ars = set()
1181
- for ar in aspect_ratios:
1182
- if ar in ars:
1183
- print("error! duplicate ar:", ar)
1184
- ars.add(ar)
locon/kohya_utils.py DELETED
@@ -1,48 +0,0 @@
1
- # part of https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py
2
-
3
- import hashlib
4
- import safetensors
5
- from io import BytesIO
6
-
7
-
8
- def addnet_hash_legacy(b):
9
- """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
10
- m = hashlib.sha256()
11
-
12
- b.seek(0x100000)
13
- m.update(b.read(0x10000))
14
- return m.hexdigest()[0:8]
15
-
16
-
17
- def addnet_hash_safetensors(b):
18
- """New model hash used by sd-webui-additional-networks for .safetensors format files"""
19
- hash_sha256 = hashlib.sha256()
20
- blksize = 1024 * 1024
21
-
22
- b.seek(0)
23
- header = b.read(8)
24
- n = int.from_bytes(header, "little")
25
-
26
- offset = n + 8
27
- b.seek(offset)
28
- for chunk in iter(lambda: b.read(blksize), b""):
29
- hash_sha256.update(chunk)
30
-
31
- return hash_sha256.hexdigest()
32
-
33
-
34
- def precalculate_safetensors_hashes(tensors, metadata):
35
- """Precalculate the model hashes needed by sd-webui-additional-networks to
36
- save time on indexing the model later."""
37
-
38
- # Because writing user metadata to the file can change the result of
39
- # sd_models.model_hash(), only retain the training metadata for purposes of
40
- # calculating the hash, as they are meant to be immutable
41
- metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
42
-
43
- bytes = safetensors.torch.save(tensors, metadata)
44
- b = BytesIO(bytes)
45
-
46
- model_hash = addnet_hash_safetensors(b)
47
- legacy_hash = addnet_hash_legacy(b)
48
- return model_hash, legacy_hash
locon/locon.py DELETED
@@ -1,53 +0,0 @@
1
- import math
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
-
8
- class LoConModule(nn.Module):
9
- """
10
- modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
11
- """
12
-
13
- def __init__(self, lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
14
- """ if alpha == 0 or None, alpha is rank (no scaling). """
15
- super().__init__()
16
- self.lora_name = lora_name
17
- self.lora_dim = lora_dim
18
-
19
- if org_module.__class__.__name__ == 'Conv2d':
20
- # For general LoCon
21
- in_dim = org_module.in_channels
22
- k_size = org_module.kernel_size
23
- stride = org_module.stride
24
- padding = org_module.padding
25
- out_dim = org_module.out_channels
26
- self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
27
- self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
28
- else:
29
- in_dim = org_module.in_features
30
- out_dim = org_module.out_features
31
- self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
32
- self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
33
-
34
- if type(alpha) == torch.Tensor:
35
- alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
36
- alpha = lora_dim if alpha is None or alpha == 0 else alpha
37
- self.scale = alpha / self.lora_dim
38
- self.register_buffer('alpha', torch.tensor(alpha)) # can be treated as a constant
39
-
40
- # same as microsoft's
41
- torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
42
- torch.nn.init.zeros_(self.lora_up.weight)
43
-
44
- self.multiplier = multiplier
45
- self.org_module = org_module # remove in applying
46
-
47
- def apply_to(self):
48
- self.org_forward = self.org_module.forward
49
- self.org_module.forward = self.forward
50
- del self.org_module
51
-
52
- def forward(self, x):
53
- return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
locon/locon_kohya.py DELETED
@@ -1,243 +0,0 @@
1
- # LoCon network module
2
- # reference:
3
- # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
- # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
- # https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
6
-
7
- import math
8
- import os
9
- from typing import List
10
- import torch
11
-
12
- from .kohya_utils import *
13
- from .locon import LoConModule
14
-
15
-
16
- def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
17
- if network_dim is None:
18
- network_dim = 4 # default
19
- conv_dim = kwargs.get('conv_dim', network_dim)
20
- conv_alpha = kwargs.get('conv_alpha', network_alpha)
21
- network = LoRANetwork(
22
- text_encoder, unet,
23
- multiplier=multiplier,
24
- lora_dim=network_dim, conv_lora_dim=conv_dim,
25
- alpha=network_alpha, conv_alpha=conv_alpha
26
- )
27
- return network
28
-
29
-
30
- def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
31
- if os.path.splitext(file)[1] == '.safetensors':
32
- from safetensors.torch import load_file, safe_open
33
- weights_sd = load_file(file)
34
- else:
35
- weights_sd = torch.load(file, map_location='cpu')
36
-
37
- # get dim (rank)
38
- network_alpha = None
39
- network_dim = None
40
- for key, value in weights_sd.items():
41
- if network_alpha is None and 'alpha' in key:
42
- network_alpha = value
43
- if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
44
- network_dim = value.size()[0]
45
-
46
- if network_alpha is None:
47
- network_alpha = network_dim
48
-
49
- network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
50
- network.weights_sd = weights_sd
51
- return network
52
-
53
- torch.nn.Conv2d
54
- class LoRANetwork(torch.nn.Module):
55
- '''
56
- LoRA + LoCon
57
- '''
58
- # Ignore proj_in or proj_out, their channels is only a few.
59
- UNET_TARGET_REPLACE_MODULE = [
60
- "Transformer2DModel",
61
- "Attention",
62
- "ResnetBlock2D",
63
- "Downsample2D",
64
- "Upsample2D"
65
- ]
66
- TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
67
- LORA_PREFIX_UNET = 'lora_unet'
68
- LORA_PREFIX_TEXT_ENCODER = 'lora_te'
69
-
70
- def __init__(
71
- self,
72
- text_encoder, unet,
73
- multiplier=1.0,
74
- lora_dim=4, conv_lora_dim=4,
75
- alpha=1, conv_alpha=1
76
- ) -> None:
77
- super().__init__()
78
- self.multiplier = multiplier
79
- self.lora_dim = lora_dim
80
- self.conv_lora_dim = int(conv_lora_dim)
81
- if self.conv_lora_dim != self.lora_dim:
82
- print('Apply different lora dim for conv layer')
83
- print(f'LoCon Dim: {conv_lora_dim}, LoRA Dim: {lora_dim}')
84
- self.alpha = alpha
85
- self.conv_alpha = float(conv_alpha)
86
- if self.alpha != self.conv_alpha:
87
- print('Apply different alpha value for conv layer')
88
- print(f'LoCon alpha: {conv_alpha}, LoRA alpha: {alpha}')
89
-
90
- # create module instances
91
- def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoConModule]:
92
- print('Create LoCon Module')
93
- loras = []
94
- for name, module in root_module.named_modules():
95
- if module.__class__.__name__ in target_replace_modules:
96
- for child_name, child_module in module.named_modules():
97
- lora_name = prefix + '.' + name + '.' + child_name
98
- lora_name = lora_name.replace('.', '_')
99
- if child_module.__class__.__name__ == 'Linear':
100
- lora = LoConModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
101
- elif child_module.__class__.__name__ == 'Conv2d':
102
- k_size, *_ = child_module.kernel_size
103
- if k_size==1:
104
- lora = LoConModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
105
- else:
106
- lora = LoConModule(lora_name, child_module, self.multiplier, self.conv_lora_dim, self.conv_alpha)
107
- else:
108
- continue
109
- loras.append(lora)
110
- return loras
111
-
112
- self.text_encoder_loras = create_modules(
113
- LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
114
- text_encoder,
115
- LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
116
- )
117
- print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
118
-
119
- self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
120
- print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
121
-
122
- self.weights_sd = None
123
-
124
- # assertion
125
- names = set()
126
- for lora in self.text_encoder_loras + self.unet_loras:
127
- assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
128
- names.add(lora.lora_name)
129
-
130
- def set_multiplier(self, multiplier):
131
- self.multiplier = multiplier
132
- for lora in self.text_encoder_loras + self.unet_loras:
133
- lora.multiplier = self.multiplier
134
-
135
- def load_weights(self, file):
136
- if os.path.splitext(file)[1] == '.safetensors':
137
- from safetensors.torch import load_file, safe_open
138
- self.weights_sd = load_file(file)
139
- else:
140
- self.weights_sd = torch.load(file, map_location='cpu')
141
-
142
- def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
143
- if self.weights_sd:
144
- weights_has_text_encoder = weights_has_unet = False
145
- for key in self.weights_sd.keys():
146
- if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
147
- weights_has_text_encoder = True
148
- elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
149
- weights_has_unet = True
150
-
151
- if apply_text_encoder is None:
152
- apply_text_encoder = weights_has_text_encoder
153
- else:
154
- assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
155
-
156
- if apply_unet is None:
157
- apply_unet = weights_has_unet
158
- else:
159
- assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
160
- else:
161
- assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
162
-
163
- if apply_text_encoder:
164
- print("enable LoRA for text encoder")
165
- else:
166
- self.text_encoder_loras = []
167
-
168
- if apply_unet:
169
- print("enable LoRA for U-Net")
170
- else:
171
- self.unet_loras = []
172
-
173
- for lora in self.text_encoder_loras + self.unet_loras:
174
- lora.apply_to()
175
- self.add_module(lora.lora_name, lora)
176
-
177
- if self.weights_sd:
178
- # missing keys are fine: a freshly initialized LoRA is a no-op (lora_up starts at zero)
179
- info = self.load_state_dict(self.weights_sd, False)
180
- print(f"weights are loaded: {info}")
181
-
182
- def enable_gradient_checkpointing(self):
183
- # not supported
184
- pass
185
-
186
- def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
187
- def enumerate_params(loras):
188
- params = []
189
- for lora in loras:
190
- params.extend(lora.parameters())
191
- return params
192
-
193
- self.requires_grad_(True)
194
- all_params = []
195
-
196
- if self.text_encoder_loras:
197
- param_data = {'params': enumerate_params(self.text_encoder_loras)}
198
- if text_encoder_lr is not None:
199
- param_data['lr'] = text_encoder_lr
200
- all_params.append(param_data)
201
-
202
- if self.unet_loras:
203
- param_data = {'params': enumerate_params(self.unet_loras)}
204
- if unet_lr is not None:
205
- param_data['lr'] = unet_lr
206
- all_params.append(param_data)
207
-
208
- return all_params
209
-
210
- def prepare_grad_etc(self, text_encoder, unet):
211
- self.requires_grad_(True)
212
-
213
- def on_epoch_start(self, text_encoder, unet):
214
- self.train()
215
-
216
- def get_trainable_params(self):
217
- return self.parameters()
218
-
219
- def save_weights(self, file, dtype, metadata):
220
- if metadata is not None and len(metadata) == 0:
221
- metadata = None
222
-
223
- state_dict = self.state_dict()
224
-
225
- if dtype is not None:
226
- for key in list(state_dict.keys()):
227
- v = state_dict[key]
228
- v = v.detach().clone().to("cpu").to(dtype)
229
- state_dict[key] = v
230
-
231
- if os.path.splitext(file)[1] == '.safetensors':
232
- from safetensors.torch import save_file
233
-
234
- # Precalculate model hashes to save time on indexing
235
- if metadata is None:
236
- metadata = {}
237
- model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
238
- metadata["sshs_model_hash"] = model_hash
239
- metadata["sshs_legacy_hash"] = legacy_hash
240
-
241
- save_file(state_dict, file, metadata)
242
- else:
243
- torch.save(state_dict, file)
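
For context, the file deleted above follows the kohya-ss network interface: `create_network` builds a `LoRANetwork` that wraps Linear and 1x1 Conv2d layers with the LoRA rank/alpha and all other Conv2d layers with the separate LoCon rank/alpha, and `apply_to` hooks the resulting modules into the text encoder and U-Net. A minimal usage sketch, assuming `text_encoder` and `unet` come from an already-loaded Stable Diffusion pipeline and that the deleted module were still importable (both are assumptions; neither object is constructed in this file):

import torch

# Sketch only: `create_network` is the factory defined in the deleted file above;
# `text_encoder` and `unet` are assumed to be modules from a loaded pipeline.
network = create_network(
    multiplier=1.0,
    network_dim=8,               # rank for Linear / 1x1 Conv2d layers
    network_alpha=4,             # alpha for those layers
    vae=None,                    # accepted but unused by this implementation
    text_encoder=text_encoder,
    unet=unet,
    conv_dim=4, conv_alpha=1,    # separate rank/alpha for the other convs (LoCon)
)
network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)
optimizer = torch.optim.AdamW(
    network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4)
)

After training, `save_weights` serializes the network state dict, adding the safetensors model hashes to the metadata when the target file is `.safetensors`.
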
locon/utils.py DELETED
@@ -1,148 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- import torch.linalg as linalg
6
-
7
- from tqdm import tqdm
8
-
9
-
10
- def extract_conv(
11
- weight: nn.Parameter|torch.Tensor,
12
- lora_rank = 8
13
- ) -> tuple[nn.Parameter, nn.Parameter]:
14
- out_ch, in_ch, kernel_size, _ = weight.shape
15
- lora_rank = min(out_ch, in_ch, lora_rank)
16
-
17
- U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))
18
-
19
- U = U[:, :lora_rank]
20
- S = S[:lora_rank]
21
- U = U @ torch.diag(S)
22
- Vh = Vh[:lora_rank, :]
23
-
24
- extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).cpu()
25
- extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).cpu()
26
- del U, S, Vh, weight
27
- return extract_weight_A, extract_weight_B
28
-
29
-
30
- def merge_conv(
31
- weight_a: nn.Parameter|torch.Tensor,
32
- weight_b: nn.Parameter|torch.Tensor,
33
- ):
34
- rank, in_ch, kernel_size, k_ = weight_a.shape
35
- out_ch, rank_, _, _ = weight_b.shape
36
-
37
- assert rank == rank_ and kernel_size == k_
38
-
39
- merged = weight_b.reshape(out_ch, -1) @ weight_a.reshape(rank, -1)
40
- weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
41
- return weight
42
-
43
-
44
- def extract_linear(
45
- weight: nn.Parameter|torch.Tensor,
46
- lora_rank = 8
47
- ) -> tuple[nn.Parameter, nn.Parameter]:
48
- out_ch, in_ch = weight.shape
49
- lora_rank = min(out_ch, in_ch, lora_rank)
50
-
51
- U, S, Vh = linalg.svd(weight)
52
-
53
- U = U[:, :lora_rank]
54
- S = S[:lora_rank]
55
- U = U @ torch.diag(S)
56
- Vh = Vh[:lora_rank, :]
57
-
58
- extract_weight_A = Vh.reshape(lora_rank, in_ch).cpu()
59
- extract_weight_B = U.reshape(out_ch, lora_rank).cpu()
60
- del U, S, Vh, weight
61
- return extract_weight_A, extract_weight_B
62
-
63
-
64
- def merge_linear(
65
- weight_a: nn.Parameter|torch.Tensor,
66
- weight_b: nn.Parameter|torch.Tensor,
67
- ):
68
- rank, in_ch = weight_a.shape
69
- out_ch, rank_ = weight_b.shape
70
-
71
- assert rank == rank_
72
-
73
- weight = weight_b @ weight_a
74
- return weight
75
-
76
-
77
- def extract_diff(
78
- base_model,
79
- db_model,
80
- lora_dim=4,
81
- conv_lora_dim=4,
82
- extract_device = 'cuda',
83
- ):
84
- UNET_TARGET_REPLACE_MODULE = [
85
- "Transformer2DModel",
86
- "Attention",
87
- "ResnetBlock2D",
88
- "Downsample2D",
89
- "Upsample2D"
90
- ]
91
- TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
92
- LORA_PREFIX_UNET = 'lora_unet'
93
- LORA_PREFIX_TEXT_ENCODER = 'lora_te'
94
- def make_state_dict(
95
- prefix,
96
- root_module: torch.nn.Module,
97
- target_module: torch.nn.Module,
98
- target_replace_modules
99
- ):
100
- loras = {}
101
- temp = {}
102
-
103
- for name, module in root_module.named_modules():
104
- if module.__class__.__name__ in target_replace_modules:
105
- temp[name] = {}
106
- for child_name, child_module in module.named_modules():
107
- if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
108
- continue
109
- temp[name][child_name] = child_module.weight
110
-
111
- for name, module in tqdm(list(target_module.named_modules())):
112
- if name in temp:
113
- weights = temp[name]
114
- for child_name, child_module in module.named_modules():
115
- lora_name = prefix + '.' + name + '.' + child_name
116
- lora_name = lora_name.replace('.', '_')
117
- if child_module.__class__.__name__ == 'Linear':
118
- extract_a, extract_b = extract_linear(
119
- (child_module.weight - weights[child_name]),
120
- lora_dim
121
- )
122
- elif child_module.__class__.__name__ == 'Conv2d':
123
- extract_a, extract_b = extract_conv(
124
- (child_module.weight - weights[child_name]),
125
- conv_lora_dim
126
- )
127
- else:
128
- continue
129
-
130
- loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().half()
131
- loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().half()
132
- loras[f'{lora_name}.alpha'] = torch.Tensor([int(extract_a.shape[0])]).detach().cpu().half()
133
- del extract_a, extract_b
134
- return loras
135
-
136
- text_encoder_loras = make_state_dict(
137
- LORA_PREFIX_TEXT_ENCODER,
138
- base_model[0], db_model[0],
139
- TEXT_ENCODER_TARGET_REPLACE_MODULE
140
- )
141
-
142
- unet_loras = make_state_dict(
143
- LORA_PREFIX_UNET,
144
- base_model[2], db_model[2],
145
- UNET_TARGET_REPLACE_MODULE
146
- )
147
- print(len(text_encoder_loras), len(unet_loras))
148
- return text_encoder_loras|unet_loras
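
The helpers above implement low-rank extraction by truncated SVD: the weight difference between a fine-tuned model and its base is factored into a `lora_down`/`lora_up` pair of the requested rank, and the `merge_*` functions recompose the product. A self-contained sketch of that round trip on a linear weight (a random matrix stands in for the fine-tuned-minus-base delta; nothing here depends on the deleted module):

import torch

torch.manual_seed(0)
delta = torch.randn(64, 32)      # stand-in for (fine-tuned weight - base weight)
rank = 8

# Truncated SVD, as in extract_linear: keep the top `rank` singular components.
U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
down = Vh[:rank, :]                       # "lora_down": (rank, in_features)
up = U[:, :rank] @ torch.diag(S[:rank])   # "lora_up":   (out_features, rank)

# merge_linear equivalent: the low-rank product approximates the original delta.
approx = up @ down
rel_err = torch.linalg.norm(delta - approx) / torch.linalg.norm(delta)
print(f"relative reconstruction error at rank {rank}: {rel_err:.3f}")

`extract_conv` applies the same factorization to a kernel flattened to (out_ch, in_ch*k*k), reshaping the down factor back into a k-by-k convolution and the up factor into a 1x1 convolution, which is what lets LoCon cover the non-attention convolution layers.
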