abc committed on
Commit ac2ea1d · 1 Parent(s): 0380cfd

Upload 10 files
lycoris/dylora.py ADDED
@@ -0,0 +1,175 @@
+ import math
+ import random
+
+ from collections import OrderedDict, abc as container_abcs
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class DyLoraModule(nn.Module):
+     """
+     Hadamard product Implementation for Dynamic Low Rank adaptation
+     """
+
+     def __init__(
+         self,
+         lora_name,
+         org_module: nn.Module,
+         multiplier=1.0,
+         lora_dim=4, alpha=1,
+         dropout=0.,
+         use_cp=False,
+         block_size=1,
+         **kwargs,
+     ):
+         """ if alpha == 0 or None, alpha is rank (no scaling). """
+         super().__init__()
+         self.lora_name = lora_name
+         self.lora_dim = lora_dim
+         assert lora_dim % block_size == 0, 'lora_dim must be a multiple of block_size'
+         self.block_count = lora_dim//block_size
+         self.block_size = block_size
+
+         self.shape = org_module.weight.shape
+         if org_module.__class__.__name__ == 'Conv2d':
+             in_dim = org_module.in_channels
+             k_size = org_module.kernel_size
+             out_dim = org_module.out_channels
+             shape = (out_dim, in_dim*k_size[0]*k_size[1])
+             self.op = F.conv2d
+             self.extra_args = {
+                 "stride": org_module.stride,
+                 "padding": org_module.padding,
+                 "dilation": org_module.dilation,
+                 "groups": org_module.groups
+             }
+         else:
+             in_dim = org_module.in_features
+             out_dim = org_module.out_features
+             shape = (out_dim, in_dim)
+             self.op = F.linear
+             self.extra_args = {}
+
+         self.lora_dim = lora_dim
+         self.up_list = nn.ParameterList([
+             torch.empty(shape[0], 1)
+             for i in range(lora_dim)
+         ])
+         self.up_list.requires_grad_(False)
+         self.up_update = [
+             torch.zeros_like(self.up_list[i])
+             for i in range(lora_dim)
+         ]
+
+         self.down_list = nn.ParameterList([
+             torch.empty(1, shape[1])
+             for i in range(lora_dim)
+         ])
+         self.down_list.requires_grad_(False)
+         self.down_update = [
+             torch.zeros_like(self.down_list[i])
+             for i in range(lora_dim)
+         ]
+
+         self.index = 0
+
+         if type(alpha) == torch.Tensor:
+             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+         alpha = lora_dim if alpha is None or alpha == 0 else alpha
+         self.scale = alpha / self.lora_dim
+         self.register_buffer('alpha', torch.tensor(alpha))  # can be treated as a constant
+
+         # Need more experiences on init method
+
+         for v in self.down_list:
+             torch.nn.init.kaiming_uniform_(v, a=math.sqrt(5))
+         for v in self.up_list:
+             torch.nn.init.zeros_(v)
+         for i, v in enumerate(self.up_update):
+             v.copy_(self.up_list[i])
+         for i, v in enumerate(self.down_update):
+             v.copy_(self.down_list[i])
+
+         self.multiplier = multiplier
+         self.org_module = [org_module]  # remove in applying
+         self.grad_ckpt = False
+
+         self.apply_train(0)
+
+     def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
+         # TODO: Remove `args` and the parsing logic when BC allows.
+         if len(args) > 0:
+             if destination is None:
+                 destination = args[0]
+             if len(args) > 1 and prefix == '':
+                 prefix = args[1]
+             if len(args) > 2 and keep_vars is False:
+                 keep_vars = args[2]
+             # DeprecationWarning is ignored by default
+
+         if destination is None:
+             destination = OrderedDict()
+             destination._metadata = OrderedDict()
+
+         local_metadata = dict(version=self._version)
+         if hasattr(destination, "_metadata"):
+             destination._metadata[prefix[:-1]] = local_metadata
+
+         destination[f'{prefix}alpha'] = self.alpha
+         destination[f'{prefix}lora_up.weight'] = nn.Parameter(
+             torch.concat(self.up_update, dim=1)
+         )
+         destination[f'{prefix}lora_down.weight'] = nn.Parameter(
+             torch.concat(self.down_update)
+         )
+         return destination
+
+     def apply_to(self):
+         self.org_module[0].forward = self.forward
+
+     def apply_train(self, b: int):
+         self.up_list.requires_grad_(False)
+         self.down_list.requires_grad_(False)
+
+         for i in range(self.index*self.block_size, (self.index+1)*self.block_size):
+             self.up_update[i].copy_(self.up_list[i])
+             self.down_update[i].copy_(self.down_list[i])
+
+         for i in range(b*self.block_size, (b+1)*self.block_size):
+             self.up_list[i].copy_(self.up_update[i])
+             self.down_list[i].copy_(self.down_update[i])
+
+         self.up_list.requires_grad_(True)
+         self.down_list.requires_grad_(True)
+         self.index = b
+
+     @torch.enable_grad()
+     def forward(self, x):
+         b = random.randint(0, self.block_count-1)
+         if self.up_update[b].device != self.up_list[b].device:
+             device = self.up_list[b].device
+             for i in range(self.lora_dim):
+                 self.up_update[i] = self.up_update[i].to(device)
+                 self.down_update[i] = self.down_update[i].to(device)
+
+         if self.training:
+             self.apply_train(b)
+         down = torch.concat(
+             list(self.down_update[:b*self.block_size])
+             + list(self.down_list[b*self.block_size:(b+1)*self.block_size])
+         )
+         up = torch.concat(
+             list(self.up_update[:b*self.block_size])
+             + list(self.up_list[b*self.block_size:(b+1)*self.block_size]),
+             dim=1
+         )
+
+         bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
+         return self.op(
+             x,
+             self.org_module[0].weight + (up@down).view(self.shape) * self.alpha/(b+1),
+             bias,
+             **self.extra_args
+         )
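Note (not part of the committed file): a minimal sketch of how DyLoraModule above can be attached to a single layer, assuming the lycoris package is importable and a recent PyTorch where nn.ParameterList accepts plain tensors, as the module relies on. The layer size, lora_name and hyperparameters here are illustrative only.

import torch
import torch.nn as nn
from lycoris.dylora import DyLoraModule

# hypothetical host layer; the module also accepts Conv2d
base = nn.Linear(768, 768)

# lora_dim must be divisible by block_size (asserted in __init__)
dylora = DyLoraModule('lora_unet_example', base, multiplier=1.0,
                      lora_dim=8, alpha=4, block_size=4)
dylora.apply_to()        # replaces base.forward with the DyLoRA forward

x = torch.randn(2, 768)
y = base(x)              # each forward pass samples a rank block at random
print(tuple(y.shape))    # the wrapped layer keeps its output shape

Since lora_up is initialized to zeros, the wrapped layer initially reproduces the original output; only the sampled block of rows/columns receives gradient updates per step.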
lycoris/ia3.py ADDED
@@ -0,0 +1,68 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class IA3Module(nn.Module):
+     """
+     Hadamard product Implementation for Low Rank Adaptation
+     """
+
+     def __init__(
+         self,
+         lora_name,
+         org_module: nn.Module,
+         multiplier=1.0,
+         train_on_input=False,
+         **kwargs
+     ):
+         """ if alpha == 0 or None, alpha is rank (no scaling). """
+         super().__init__()
+         self.lora_name = lora_name
+         self.cp = False
+
+         self.shape = org_module.weight.shape
+         if org_module.__class__.__name__ == 'Conv2d':
+             in_dim = org_module.in_channels
+             out_dim = org_module.out_channels
+             if train_on_input:
+                 train_dim = in_dim
+             else:
+                 train_dim = out_dim
+             self.weight = nn.Parameter(torch.empty(1, train_dim, 1, 1))
+         else:
+             in_dim = org_module.in_features
+             out_dim = org_module.out_features
+             if train_on_input:
+                 train_dim = in_dim
+             else:
+                 train_dim = out_dim
+
+             self.weight = nn.Parameter(torch.empty(train_dim))
+
+         # Need more experiences on init method
+         torch.nn.init.constant_(self.weight, 0)
+
+         self.multiplier = multiplier
+         self.org_forward = None
+         self.org_module = [org_module]  # remove in applying
+         self.grad_ckpt = False
+         self.train_input = train_on_input
+         self.register_buffer('on_input', torch.tensor(int(train_on_input)))
+
+     def apply_to(self):
+         self.org_forward = self.org_module[0].forward
+         self.org_module[0].forward = self.forward
+
+     @torch.enable_grad()
+     def forward(self, x):
+         if self.train_input:
+             x = x * (1 + self.weight * self.multiplier)
+         out = self.org_forward(x)
+         dtype = out.dtype
+         if not self.train_input:
+             out = out * (1 + self.weight * self.multiplier)
+         out = out.to(dtype)
+         return out
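As a quick illustration (again, not part of the commit), IA3Module keeps the wrapped layer's original forward and rescales either its input or its output with the learned vector. The layer and dimensions below are hypothetical.

import torch
import torch.nn as nn
from lycoris.ia3 import IA3Module

proj = nn.Linear(320, 320)      # hypothetical k_proj-style target layer
ia3 = IA3Module('lora_te_example', proj, multiplier=1.0, train_on_input=False)
ia3.apply_to()                  # stores proj's original forward, then wraps it

x = torch.randn(2, 77, 320)
out = proj(x)                   # output scaled by (1 + weight * multiplier)

Because the scaling vector is initialized to zero, the wrapper is an identity at the start of training; train_on_input=True moves the scaling to the input side, which the network code below uses for the MLP output projections.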
lycoris/kohya.py CHANGED
@@ -13,6 +13,9 @@ import torch
  from .kohya_utils import *
  from .locon import LoConModule
  from .loha import LohaModule
+ from .ia3 import IA3Module
+ from .lokr import LokrModule
+ from .dylora import DyLoraModule


  def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
@@ -21,39 +24,55 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
      conv_dim = int(kwargs.get('conv_dim', network_dim))
      conv_alpha = float(kwargs.get('conv_alpha', network_alpha))
      dropout = float(kwargs.get('dropout', 0.))
-     algo = kwargs.get('algo', 'lora')
-     disable_cp = kwargs.get('disable_conv_cp', False)
+     algo = kwargs.get('algo', 'lora').lower()
+     use_cp = (not kwargs.get('disable_conv_cp', True)
+               or kwargs.get('use_conv_cp', False))
+     block_size = int(kwargs.get('block_size', 4))
      network_module = {
          'lora': LoConModule,
+         'locon': LoConModule,
          'loha': LohaModule,
+         'ia3': IA3Module,
+         'lokr': LokrModule,
+         'dylora': DyLoraModule,
      }[algo]

      print(f'Using rank adaptation algo: {algo}')

-     if (algo == 'loha'
+     if ((algo == 'loha' or algo == 'lokr')
          and not kwargs.get('no_dim_warn', False)
          and (network_dim>64 or conv_dim>64)):
          print('='*20 + 'WARNING' + '='*20)
-         warn(
-             (
-                 "You are not supposed to use dim>64 (64*64 = 4096, it already has enough rank)"
-                 "in Hadamard Product representation!\n"
-                 "Please consider use lower dim or disable this warning with --network_args no_dim_warn=True\n"
-                 "If you just want to use high dim loha, please consider use lower lr."
-             ),
-             stacklevel=2,
-         )
+         warning_type = {
+             'loha': "Hadamard Product representation",
+             'lokr': "Kronecker Product representation"
+         }
+         warning_msg = f"""You are not supposed to use dim>64 (64*64 = 4096, it already has enough rank)\n
+         in {warning_type[algo]}!\n
+         Please consider use lower dim or disable this warning with --network_args no_dim_warn=True\n
+         If you just want to use high dim {algo}, please consider use lower lr.
+         """
+         warn(warning_msg, stacklevel=2)
          print('='*20 + 'WARNING' + '='*20)

-     network = LycorisNetwork(
-         text_encoder, unet,
-         multiplier=multiplier,
-         lora_dim=network_dim, conv_lora_dim=conv_dim,
-         alpha=network_alpha, conv_alpha=conv_alpha,
-         dropout=dropout,
-         use_cp=(not bool(disable_cp)),
-         network_module=network_module
-     )
+     if algo == 'ia3':
+         network = IA3Network(
+             text_encoder, unet,
+             multiplier = multiplier,
+         )
+     else:
+         network = LycorisNetwork(
+             text_encoder, unet,
+             multiplier=multiplier,
+             lora_dim=network_dim, conv_lora_dim=conv_dim,
+             alpha=network_alpha, conv_alpha=conv_alpha,
+             dropout=dropout,
+             use_cp=use_cp,
+             network_module=network_module,
+             decompose_both=kwargs.get('decompose_both', False),
+             factor=kwargs.get('factor', -1),
+             block_size = block_size
+         )

      return network

@@ -86,8 +105,9 @@ class LycorisNetwork(torch.nn.Module):
          multiplier=1.0,
          lora_dim=4, conv_lora_dim=4,
          alpha=1, conv_alpha=1,
-         use_cp = True,
+         use_cp = False,
          dropout = 0, network_module = LoConModule,
+         **kwargs,
      ) -> None:
          super().__init__()
          self.multiplier = multiplier
@@ -124,19 +144,25 @@ class LycorisNetwork(torch.nn.Module):
                          if child_module.__class__.__name__ == 'Linear' and lora_dim>0:
                              lora = network_module(
                                  lora_name, child_module, self.multiplier,
-                                 self.lora_dim, self.alpha, self.dropout, use_cp
+                                 self.lora_dim, self.alpha,
+                                 self.dropout, use_cp,
+                                 **kwargs
                              )
                          elif child_module.__class__.__name__ == 'Conv2d':
                              k_size, *_ = child_module.kernel_size
                              if k_size==1 and lora_dim>0:
                                  lora = network_module(
                                      lora_name, child_module, self.multiplier,
-                                     self.lora_dim, self.alpha, self.dropout, use_cp
+                                     self.lora_dim, self.alpha,
+                                     self.dropout, use_cp,
+                                     **kwargs
                                  )
                              elif conv_lora_dim>0:
                                  lora = network_module(
                                      lora_name, child_module, self.multiplier,
-                                     self.conv_lora_dim, self.conv_alpha, self.dropout, use_cp
+                                     self.conv_lora_dim, self.conv_alpha,
+                                     self.dropout, use_cp,
+                                     **kwargs
                                  )
                              else:
                                  continue
@@ -149,19 +175,25 @@ class LycorisNetwork(torch.nn.Module):
                      if module.__class__.__name__ == 'Linear' and lora_dim>0:
                          lora = network_module(
                              lora_name, module, self.multiplier,
-                             self.lora_dim, self.alpha, self.dropout, use_cp
+                             self.lora_dim, self.alpha,
+                             self.dropout, use_cp,
+                             **kwargs
                          )
                      elif module.__class__.__name__ == 'Conv2d':
                          k_size, *_ = module.kernel_size
                          if k_size==1 and lora_dim>0:
                              lora = network_module(
                                  lora_name, module, self.multiplier,
-                                 self.lora_dim, self.alpha, self.dropout, use_cp
+                                 self.lora_dim, self.alpha,
+                                 self.dropout, use_cp,
+                                 **kwargs
                              )
                          elif conv_lora_dim>0:
                              lora = network_module(
                                  lora_name, module, self.multiplier,
-                                 self.conv_lora_dim, self.conv_alpha, self.dropout, use_cp
+                                 self.conv_lora_dim, self.conv_alpha,
+                                 self.dropout, use_cp,
+                                 **kwargs
                              )
                          else:
                              continue
@@ -306,3 +338,205 @@ class LycorisNetwork(torch.nn.Module):
              save_file(state_dict, file, metadata)
          else:
              torch.save(state_dict, file)
+
+
+ class IA3Network(torch.nn.Module):
+     '''
+     IA3 network
+     '''
+     # Ignore proj_in or proj_out, their channels is only a few.
+     UNET_TARGET_REPLACE_MODULE = []
+     UNET_TARGET_REPLACE_NAME = ["to_k", "to_v", "ff.net.2"]
+     TEXT_ENCODER_TARGET_REPLACE_MODULE = []
+     TEXT_ENCODER_TARGET_REPLACE_NAME = ["k_proj", "v_proj", "mlp.fc2"]
+     TRAIN_INPUT = ["mlp.fc2", "ff.net.2"]
+     LORA_PREFIX_UNET = 'lora_unet'
+     LORA_PREFIX_TEXT_ENCODER = 'lora_te'
+
+     def __init__(
+         self,
+         text_encoder, unet,
+         multiplier=1.0,
+         **kwargs,
+     ) -> None:
+         super().__init__()
+         self.multiplier = multiplier
+
+         # create module instances
+         def create_modules(
+             prefix,
+             root_module: torch.nn.Module,
+             target_replace_modules,
+             target_replace_names = [],
+             target_train_input = []
+         ) -> List[IA3Module]:
+             print('Create LyCORIS Module')
+             loras = []
+             for name, module in root_module.named_modules():
+                 if module.__class__.__name__ in target_replace_modules:
+                     for child_name, child_module in module.named_modules():
+                         lora_name = prefix + '.' + name + '.' + child_name
+                         lora_name = lora_name.replace('.', '_')
+                         if child_module.__class__.__name__ in {'Linear', 'Conv2d'}:
+                             lora = IA3Module(
+                                 lora_name, child_module, self.multiplier,
+                                 name in target_train_input,
+                                 **kwargs,
+                             )
+                             loras.append(lora)
+                 elif any(i in name for i in target_replace_names):
+                     lora_name = prefix + '.' + name
+                     lora_name = lora_name.replace('.', '_')
+                     if module.__class__.__name__ in {'Linear', 'Conv2d'}:
+                         lora = IA3Module(
+                             lora_name, module, self.multiplier,
+                             name in target_train_input,
+                             **kwargs,
+                         )
+                         loras.append(lora)
+             return loras
+
+         self.text_encoder_loras = create_modules(
+             IA3Network.LORA_PREFIX_TEXT_ENCODER,
+             text_encoder,
+             IA3Network.TEXT_ENCODER_TARGET_REPLACE_MODULE,
+             IA3Network.TEXT_ENCODER_TARGET_REPLACE_NAME,
+             IA3Network.TRAIN_INPUT
+         )
+         print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+         self.unet_loras = create_modules(
+             IA3Network.LORA_PREFIX_UNET,
+             unet,
+             IA3Network.UNET_TARGET_REPLACE_MODULE,
+             IA3Network.UNET_TARGET_REPLACE_NAME,
+             IA3Network.TRAIN_INPUT
+         )
+         print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
+
+         self.weights_sd = None
+
+         # assertion
+         names = set()
+         for lora in self.text_encoder_loras + self.unet_loras:
+             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+             names.add(lora.lora_name)
+
+     def set_multiplier(self, multiplier):
+         self.multiplier = multiplier
+         for lora in self.text_encoder_loras + self.unet_loras:
+             lora.multiplier = self.multiplier
+
+     def load_weights(self, file):
+         if os.path.splitext(file)[1] == '.safetensors':
+             from safetensors.torch import load_file, safe_open
+             self.weights_sd = load_file(file)
+         else:
+             self.weights_sd = torch.load(file, map_location='cpu')
+
+     def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
+         if self.weights_sd:
+             weights_has_text_encoder = weights_has_unet = False
+             for key in self.weights_sd.keys():
+                 if key.startswith(LycorisNetwork.LORA_PREFIX_TEXT_ENCODER):
+                     weights_has_text_encoder = True
+                 elif key.startswith(LycorisNetwork.LORA_PREFIX_UNET):
+                     weights_has_unet = True
+
+             if apply_text_encoder is None:
+                 apply_text_encoder = weights_has_text_encoder
+             else:
+                 assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / the weights and the text encoder flag contradict each other"
+
+             if apply_unet is None:
+                 apply_unet = weights_has_unet
+             else:
+                 assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / the weights and the U-Net flag contradict each other"
+         else:
+             assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
+
+         if apply_text_encoder:
+             print("enable LyCORIS for text encoder")
+         else:
+             self.text_encoder_loras = []
+
+         if apply_unet:
+             print("enable LyCORIS for U-Net")
+         else:
+             self.unet_loras = []
+
+         for lora in self.text_encoder_loras + self.unet_loras:
+             lora.apply_to()
+             self.add_module(lora.lora_name, lora)
+
+         if self.weights_sd:
+             # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
+             info = self.load_state_dict(self.weights_sd, False)
+             print(f"weights are loaded: {info}")
+
+     def enable_gradient_checkpointing(self):
+         # not supported
+         def make_ckpt(module):
+             if isinstance(module, torch.nn.Module):
+                 module.grad_ckpt = True
+         self.apply(make_ckpt)
+         pass
+
+     def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
+         def enumerate_params(loras):
+             params = []
+             for lora in loras:
+                 params.extend(lora.parameters())
+             return params
+
+         self.requires_grad_(True)
+         all_params = []
+
+         if self.text_encoder_loras:
+             param_data = {'params': enumerate_params(self.text_encoder_loras)}
+             if text_encoder_lr is not None:
+                 param_data['lr'] = text_encoder_lr
+             all_params.append(param_data)
+
+         if self.unet_loras:
+             param_data = {'params': enumerate_params(self.unet_loras)}
+             if unet_lr is not None:
+                 param_data['lr'] = unet_lr
+             all_params.append(param_data)
+
+         return all_params
+
+     def prepare_grad_etc(self, text_encoder, unet):
+         self.requires_grad_(True)
+
+     def on_epoch_start(self, text_encoder, unet):
+         self.train()
+
+     def get_trainable_params(self):
+         return self.parameters()
+
+     def save_weights(self, file, dtype, metadata):
+         if metadata is not None and len(metadata) == 0:
+             metadata = None
+
+         state_dict = self.state_dict()
+
+         if dtype is not None:
+             for key in list(state_dict.keys()):
+                 v = state_dict[key]
+                 v = v.detach().clone().to("cpu").to(dtype)
+                 state_dict[key] = v
+
+         if os.path.splitext(file)[1] == '.safetensors':
+             from safetensors.torch import save_file
+
+             # Precalculate model hashes to save time on indexing
+             if metadata is None:
+                 metadata = {}
+             model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
+             metadata["sshs_model_hash"] = model_hash
+             metadata["sshs_legacy_hash"] = legacy_hash
+
+             save_file(state_dict, file, metadata)
+         else:
+             torch.save(state_dict, file)
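For reference, a hedged sketch (not part of the diff) of how the kwargs surfaced above (algo, block_size, use_conv_cp, and so on) could be passed to create_network. The vae, text_encoder and unet arguments are placeholders; the kohya-ss training scripts normally construct them and call this factory for you.

from lycoris.kohya import create_network

def build_lycoris(vae, text_encoder, unet):
    # positional arguments: multiplier, network_dim, network_alpha
    return create_network(
        1.0, 8, 4,
        vae, text_encoder, unet,
        algo='dylora',      # any of: lora, locon, loha, lokr, ia3, dylora
        conv_dim=4, conv_alpha=1,
        block_size=4,       # DyLoRA only; lora_dim must be divisible by it
        use_conv_cp=True,   # opt in to CP decomposition for 3x3 conv layers
    )

When algo='ia3' the factory returns the new IA3Network instead of LycorisNetwork, so dim/alpha related kwargs are ignored in that case.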
lycoris/kohya_model_utils.py CHANGED
@@ -1,13 +1,10 @@
- '''
- https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
- '''
  # v1: split from train_db_fixed.py.
  # v2: support safetensors

  import math
  import os
  import torch
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
  from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
  from safetensors.torch import load_file, save_file

@@ -19,7 +16,7 @@ BETA_END = 0.0120
  UNET_PARAMS_MODEL_CHANNELS = 320
  UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
  UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
- UNET_PARAMS_IMAGE_SIZE = 32  # unused
  UNET_PARAMS_IN_CHANNELS = 4
  UNET_PARAMS_OUT_CHANNELS = 4
  UNET_PARAMS_NUM_RES_BLOCKS = 2
@@ -48,596 +45,574 @@ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
48
 
49
 
50
  def shave_segments(path, n_shave_prefix_segments=1):
51
- """
52
- Removes segments. Positive values shave the first segments, negative shave the last segments.
53
- """
54
- if n_shave_prefix_segments >= 0:
55
- return ".".join(path.split(".")[n_shave_prefix_segments:])
56
- else:
57
- return ".".join(path.split(".")[:n_shave_prefix_segments])
58
 
59
 
60
  def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
61
- """
62
- Updates paths inside resnets to the new naming scheme (local renaming)
63
- """
64
- mapping = []
65
- for old_item in old_list:
66
- new_item = old_item.replace("in_layers.0", "norm1")
67
- new_item = new_item.replace("in_layers.2", "conv1")
68
 
69
- new_item = new_item.replace("out_layers.0", "norm2")
70
- new_item = new_item.replace("out_layers.3", "conv2")
71
 
72
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
73
- new_item = new_item.replace("skip_connection", "conv_shortcut")
74
 
75
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
76
 
77
- mapping.append({"old": old_item, "new": new_item})
78
 
79
- return mapping
80
 
81
 
82
  def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
83
- """
84
- Updates paths inside resnets to the new naming scheme (local renaming)
85
- """
86
- mapping = []
87
- for old_item in old_list:
88
- new_item = old_item
89
 
90
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
91
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
92
 
93
- mapping.append({"old": old_item, "new": new_item})
94
 
95
- return mapping
96
 
97
 
98
  def renew_attention_paths(old_list, n_shave_prefix_segments=0):
99
- """
100
- Updates paths inside attentions to the new naming scheme (local renaming)
101
- """
102
- mapping = []
103
- for old_item in old_list:
104
- new_item = old_item
105
 
106
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
107
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
108
 
109
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
110
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
111
 
112
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
113
 
114
- mapping.append({"old": old_item, "new": new_item})
115
 
116
- return mapping
117
 
118
 
119
  def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
120
- """
121
- Updates paths inside attentions to the new naming scheme (local renaming)
122
- """
123
- mapping = []
124
- for old_item in old_list:
125
- new_item = old_item
126
 
127
- new_item = new_item.replace("norm.weight", "group_norm.weight")
128
- new_item = new_item.replace("norm.bias", "group_norm.bias")
129
 
130
- new_item = new_item.replace("q.weight", "query.weight")
131
- new_item = new_item.replace("q.bias", "query.bias")
132
 
133
- new_item = new_item.replace("k.weight", "key.weight")
134
- new_item = new_item.replace("k.bias", "key.bias")
135
 
136
- new_item = new_item.replace("v.weight", "value.weight")
137
- new_item = new_item.replace("v.bias", "value.bias")
138
 
139
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
140
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
141
 
142
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
143
 
144
- mapping.append({"old": old_item, "new": new_item})
145
 
146
- return mapping
147
 
148
 
149
  def assign_to_checkpoint(
150
  paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
151
  ):
152
- """
153
- This does the final conversion step: take locally converted weights and apply a global renaming
154
- to them. It splits attention layers, and takes into account additional replacements
155
- that may arise.
156
 
157
- Assigns the weights to the new checkpoint.
158
- """
159
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
160
 
161
- # Splits the attention layers into three variables.
162
- if attention_paths_to_split is not None:
163
- for path, path_map in attention_paths_to_split.items():
164
- old_tensor = old_checkpoint[path]
165
- channels = old_tensor.shape[0] // 3
166
 
167
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
168
 
169
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
170
 
171
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
172
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
173
 
174
- checkpoint[path_map["query"]] = query.reshape(target_shape)
175
- checkpoint[path_map["key"]] = key.reshape(target_shape)
176
- checkpoint[path_map["value"]] = value.reshape(target_shape)
177
 
178
- for path in paths:
179
- new_path = path["new"]
180
 
181
- # These have already been assigned
182
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
183
- continue
184
 
185
- # Global renaming happens here
186
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
187
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
188
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
189
 
190
- if additional_replacements is not None:
191
- for replacement in additional_replacements:
192
- new_path = new_path.replace(replacement["old"], replacement["new"])
193
 
194
- # proj_attn.weight has to be converted from conv 1D to linear
195
- if "proj_attn.weight" in new_path:
196
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
197
- else:
198
- checkpoint[new_path] = old_checkpoint[path["old"]]
199
 
200
 
201
  def conv_attn_to_linear(checkpoint):
202
- keys = list(checkpoint.keys())
203
- attn_keys = ["query.weight", "key.weight", "value.weight"]
204
- for key in keys:
205
- if ".".join(key.split(".")[-2:]) in attn_keys:
206
- if checkpoint[key].ndim > 2:
207
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
208
- elif "proj_attn.weight" in key:
209
- if checkpoint[key].ndim > 2:
210
- checkpoint[key] = checkpoint[key][:, :, 0]
211
 
212
 
213
  def linear_transformer_to_conv(checkpoint):
214
- keys = list(checkpoint.keys())
215
- tf_keys = ["proj_in.weight", "proj_out.weight"]
216
- for key in keys:
217
- if ".".join(key.split(".")[-2:]) in tf_keys:
218
- if checkpoint[key].ndim == 2:
219
- checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
220
 
221
 
222
  def convert_ldm_unet_checkpoint(v2, checkpoint, config):
223
- """
224
- Takes a state dict and a config, and returns a converted checkpoint.
225
- """
226
-
227
- # extract state_dict for UNet
228
- unet_state_dict = {}
229
- unet_key = "model.diffusion_model."
230
- keys = list(checkpoint.keys())
231
- for key in keys:
232
- if key.startswith(unet_key):
233
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
234
-
235
- new_checkpoint = {}
236
-
237
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
238
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
239
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
240
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
241
-
242
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
243
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
244
-
245
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
246
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
247
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
248
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
249
-
250
- # Retrieves the keys for the input blocks only
251
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
252
- input_blocks = {
253
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
254
- for layer_id in range(num_input_blocks)
255
- }
256
-
257
- # Retrieves the keys for the middle blocks only
258
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
259
- middle_blocks = {
260
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
261
- for layer_id in range(num_middle_blocks)
262
- }
263
-
264
- # Retrieves the keys for the output blocks only
265
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
266
- output_blocks = {
267
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
268
- for layer_id in range(num_output_blocks)
269
- }
270
-
271
- for i in range(1, num_input_blocks):
272
- block_id = (i - 1) // (config["layers_per_block"] + 1)
273
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
274
-
275
- resnets = [
276
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
277
- ]
278
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
279
-
280
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
281
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
282
- f"input_blocks.{i}.0.op.weight"
283
- )
284
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
285
- f"input_blocks.{i}.0.op.bias"
286
- )
287
-
288
- paths = renew_resnet_paths(resnets)
289
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
290
- assign_to_checkpoint(
291
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
292
- )
293
-
294
- if len(attentions):
295
- paths = renew_attention_paths(attentions)
296
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
297
- assign_to_checkpoint(
298
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
299
- )
300
-
301
- resnet_0 = middle_blocks[0]
302
- attentions = middle_blocks[1]
303
- resnet_1 = middle_blocks[2]
304
-
305
- resnet_0_paths = renew_resnet_paths(resnet_0)
306
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
307
-
308
- resnet_1_paths = renew_resnet_paths(resnet_1)
309
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
310
-
311
- attentions_paths = renew_attention_paths(attentions)
312
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
313
- assign_to_checkpoint(
314
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
315
- )
316
-
317
- for i in range(num_output_blocks):
318
- block_id = i // (config["layers_per_block"] + 1)
319
- layer_in_block_id = i % (config["layers_per_block"] + 1)
320
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
321
- output_block_list = {}
322
-
323
- for layer in output_block_layers:
324
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
325
- if layer_id in output_block_list:
326
- output_block_list[layer_id].append(layer_name)
327
- else:
328
- output_block_list[layer_id] = [layer_name]
329
-
330
- if len(output_block_list) > 1:
331
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
332
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
333
-
334
- resnet_0_paths = renew_resnet_paths(resnets)
335
- paths = renew_resnet_paths(resnets)
336
-
337
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
338
- assign_to_checkpoint(
339
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
340
- )
341
-
342
- # オリジナル:
343
- # if ["conv.weight", "conv.bias"] in output_block_list.values():
344
- # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
345
-
346
- # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
347
- for l in output_block_list.values():
348
- l.sort()
349
-
350
- if ["conv.bias", "conv.weight"] in output_block_list.values():
351
- index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
352
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
353
- f"output_blocks.{i}.{index}.conv.bias"
354
- ]
355
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
356
- f"output_blocks.{i}.{index}.conv.weight"
357
- ]
358
-
359
- # Clear attentions as they have been attributed above.
360
- if len(attentions) == 2:
361
- attentions = []
362
-
363
- if len(attentions):
364
- paths = renew_attention_paths(attentions)
365
- meta_path = {
366
- "old": f"output_blocks.{i}.1",
367
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
368
- }
369
- assign_to_checkpoint(
370
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
371
- )
372
- else:
373
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
374
- for path in resnet_0_paths:
375
- old_path = ".".join(["output_blocks", str(i), path["old"]])
376
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
377
-
378
- new_checkpoint[new_path] = unet_state_dict[old_path]
379
-
380
- # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
381
- if v2:
382
- linear_transformer_to_conv(new_checkpoint)
383
 
384
- return new_checkpoint
385
 
386
 
387
  def convert_ldm_vae_checkpoint(checkpoint, config):
388
- # extract state dict for VAE
389
- vae_state_dict = {}
390
- vae_key = "first_stage_model."
391
- keys = list(checkpoint.keys())
392
- for key in keys:
393
- if key.startswith(vae_key):
394
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
395
- # if len(vae_state_dict) == 0:
396
- # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
397
- # vae_state_dict = checkpoint
398
-
399
- new_checkpoint = {}
400
-
401
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
402
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
403
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
404
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
405
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
406
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
407
-
408
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
409
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
410
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
411
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
412
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
413
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
414
-
415
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
416
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
417
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
418
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
419
-
420
- # Retrieves the keys for the encoder down blocks only
421
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
422
- down_blocks = {
423
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
424
- }
425
-
426
- # Retrieves the keys for the decoder up blocks only
427
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
428
- up_blocks = {
429
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
430
- }
431
-
432
- for i in range(num_down_blocks):
433
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
434
-
435
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
436
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
437
- f"encoder.down.{i}.downsample.conv.weight"
438
- )
439
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
440
- f"encoder.down.{i}.downsample.conv.bias"
441
- )
442
-
443
- paths = renew_vae_resnet_paths(resnets)
444
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
445
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
446
-
447
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
448
- num_mid_res_blocks = 2
449
- for i in range(1, num_mid_res_blocks + 1):
450
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
451
-
452
- paths = renew_vae_resnet_paths(resnets)
453
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
 
454
  assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
455
-
456
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
457
- paths = renew_vae_attention_paths(mid_attentions)
458
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
459
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
460
- conv_attn_to_linear(new_checkpoint)
461
-
462
- for i in range(num_up_blocks):
463
- block_id = num_up_blocks - 1 - i
464
- resnets = [
465
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
466
- ]
467
-
468
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
469
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
470
- f"decoder.up.{block_id}.upsample.conv.weight"
471
- ]
472
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
473
- f"decoder.up.{block_id}.upsample.conv.bias"
474
- ]
475
-
476
- paths = renew_vae_resnet_paths(resnets)
477
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
 
478
  assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
479
-
480
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
481
- num_mid_res_blocks = 2
482
- for i in range(1, num_mid_res_blocks + 1):
483
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
484
-
485
- paths = renew_vae_resnet_paths(resnets)
486
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
487
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
488
-
489
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
490
- paths = renew_vae_attention_paths(mid_attentions)
491
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
492
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
493
- conv_attn_to_linear(new_checkpoint)
494
- return new_checkpoint
495
 
496
 
497
  def create_unet_diffusers_config(v2):
498
- """
499
- Creates a config for the diffusers based on the config of the LDM model.
500
- """
501
- # unet_params = original_config.model.params.unet_config.params
502
-
503
- block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
504
-
505
- down_block_types = []
506
- resolution = 1
507
- for i in range(len(block_out_channels)):
508
- block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
509
- down_block_types.append(block_type)
510
- if i != len(block_out_channels) - 1:
511
- resolution *= 2
512
-
513
- up_block_types = []
514
- for i in range(len(block_out_channels)):
515
- block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
516
- up_block_types.append(block_type)
517
- resolution //= 2
518
-
519
- config = dict(
520
- sample_size=UNET_PARAMS_IMAGE_SIZE,
521
- in_channels=UNET_PARAMS_IN_CHANNELS,
522
- out_channels=UNET_PARAMS_OUT_CHANNELS,
523
- down_block_types=tuple(down_block_types),
524
- up_block_types=tuple(up_block_types),
525
- block_out_channels=tuple(block_out_channels),
526
- layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
527
- cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
528
- attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
529
- )
530
-
531
- return config
532
 
533
 
534
  def create_vae_diffusers_config():
535
- """
536
- Creates a config for the diffusers based on the config of the LDM model.
537
- """
538
- # vae_params = original_config.model.params.first_stage_config.params.ddconfig
539
- # _ = original_config.model.params.first_stage_config.params.embed_dim
540
- block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
541
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
542
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
543
-
544
- config = dict(
545
- sample_size=VAE_PARAMS_RESOLUTION,
546
- in_channels=VAE_PARAMS_IN_CHANNELS,
547
- out_channels=VAE_PARAMS_OUT_CH,
548
- down_block_types=tuple(down_block_types),
549
- up_block_types=tuple(up_block_types),
550
- block_out_channels=tuple(block_out_channels),
551
- latent_channels=VAE_PARAMS_Z_CHANNELS,
552
- layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
553
- )
554
- return config
555
 
556
 
557
  def convert_ldm_clip_checkpoint_v1(checkpoint):
558
- keys = list(checkpoint.keys())
559
- text_model_dict = {}
560
- for key in keys:
561
- if key.startswith("cond_stage_model.transformer"):
562
- text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
563
- return text_model_dict
564
 
565
 
566
  def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
567
- # 嫌になるくらい違うぞ!
568
- def convert_key(key):
569
- if not key.startswith("cond_stage_model"):
570
- return None
571
-
572
- # common conversion
573
- key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
574
- key = key.replace("cond_stage_model.model.", "text_model.")
575
-
576
- if "resblocks" in key:
577
- # resblocks conversion
578
- key = key.replace(".resblocks.", ".layers.")
579
- if ".ln_" in key:
580
- key = key.replace(".ln_", ".layer_norm")
581
- elif ".mlp." in key:
582
- key = key.replace(".c_fc.", ".fc1.")
583
- key = key.replace(".c_proj.", ".fc2.")
584
- elif '.attn.out_proj' in key:
585
- key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
586
- elif '.attn.in_proj' in key:
587
- key = None # 特殊なので後で処理する
588
- else:
589
- raise ValueError(f"unexpected key in SD: {key}")
590
- elif '.positional_embedding' in key:
591
- key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
592
- elif '.text_projection' in key:
593
- key = None # 使われない???
594
- elif '.logit_scale' in key:
595
- key = None # 使われない???
596
- elif '.token_embedding' in key:
597
- key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
598
- elif '.ln_final' in key:
599
- key = key.replace(".ln_final", ".final_layer_norm")
600
- return key
601
-
602
- keys = list(checkpoint.keys())
603
- new_sd = {}
604
- for key in keys:
605
- # remove resblocks 23
606
- if '.resblocks.23.' in key:
607
- continue
608
- new_key = convert_key(key)
609
- if new_key is None:
610
- continue
611
- new_sd[new_key] = checkpoint[key]
612
-
613
- # attnの変換
614
- for key in keys:
615
- if '.resblocks.23.' in key:
616
- continue
617
- if '.resblocks' in key and '.attn.in_proj_' in key:
618
- # 三つに分割
619
- values = torch.chunk(checkpoint[key], 3)
620
-
621
- key_suffix = ".weight" if "weight" in key else ".bias"
622
- key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
623
- key_pfx = key_pfx.replace("_weight", "")
624
- key_pfx = key_pfx.replace("_bias", "")
625
- key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
626
- new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
627
- new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
628
- new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
629
-
630
- # rename or add position_ids
631
- ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
632
- if ANOTHER_POSITION_IDS_KEY in new_sd:
633
- # waifu diffusion v1.4
634
- position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
635
- del new_sd[ANOTHER_POSITION_IDS_KEY]
636
- else:
637
- position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
638
-
639
- new_sd["text_model.embeddings.position_ids"] = position_ids
640
- return new_sd
 
641
 
642
  # endregion
643
 
@@ -645,540 +620,546 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
645
  # region Diffusers->StableDiffusion の変換コード
646
  # convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
647
 
 
648
  def conv_transformer_to_linear(checkpoint):
649
- keys = list(checkpoint.keys())
650
- tf_keys = ["proj_in.weight", "proj_out.weight"]
651
- for key in keys:
652
- if ".".join(key.split(".")[-2:]) in tf_keys:
653
- if checkpoint[key].ndim > 2:
654
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
655
 
656
 
657
  def convert_unet_state_dict_to_sd(v2, unet_state_dict):
658
- unet_conversion_map = [
659
- # (stable-diffusion, HF Diffusers)
660
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
661
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
662
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
663
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
664
- ("input_blocks.0.0.weight", "conv_in.weight"),
665
- ("input_blocks.0.0.bias", "conv_in.bias"),
666
- ("out.0.weight", "conv_norm_out.weight"),
667
- ("out.0.bias", "conv_norm_out.bias"),
668
- ("out.2.weight", "conv_out.weight"),
669
- ("out.2.bias", "conv_out.bias"),
670
- ]
671
-
672
- unet_conversion_map_resnet = [
673
- # (stable-diffusion, HF Diffusers)
674
- ("in_layers.0", "norm1"),
675
- ("in_layers.2", "conv1"),
676
- ("out_layers.0", "norm2"),
677
- ("out_layers.3", "conv2"),
678
- ("emb_layers.1", "time_emb_proj"),
679
- ("skip_connection", "conv_shortcut"),
680
- ]
681
-
682
- unet_conversion_map_layer = []
683
- for i in range(4):
684
- # loop over downblocks/upblocks
685
 
686
  for j in range(2):
687
- # loop over resnets/attentions for downblocks
688
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
689
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
690
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
691
-
692
- if i < 3:
693
- # no attention layers in down_blocks.3
694
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
695
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
696
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
697
-
698
- for j in range(3):
699
- # loop over resnets/attentions for upblocks
700
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
701
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
702
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
703
-
704
- if i > 0:
705
- # no attention layers in up_blocks.0
706
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
707
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
708
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
709
-
710
- if i < 3:
711
- # no downsample in down_blocks.3
712
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
713
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
714
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
715
-
716
- # no upsample in up_blocks.3
717
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
718
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
719
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
720
-
721
- hf_mid_atn_prefix = "mid_block.attentions.0."
722
- sd_mid_atn_prefix = "middle_block.1."
723
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
724
-
725
- for j in range(2):
726
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
727
- sd_mid_res_prefix = f"middle_block.{2*j}."
728
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
729
-
730
- # buyer beware: this is a *brittle* function,
731
- # and correct output requires that all of these pieces interact in
732
- # the exact order in which I have arranged them.
733
- mapping = {k: k for k in unet_state_dict.keys()}
734
- for sd_name, hf_name in unet_conversion_map:
735
- mapping[hf_name] = sd_name
736
- for k, v in mapping.items():
737
- if "resnets" in k:
738
- for sd_part, hf_part in unet_conversion_map_resnet:
739
- v = v.replace(hf_part, sd_part)
740
- mapping[k] = v
741
- for k, v in mapping.items():
742
- for sd_part, hf_part in unet_conversion_map_layer:
743
- v = v.replace(hf_part, sd_part)
744
- mapping[k] = v
745
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
746
-
747
- if v2:
748
- conv_transformer_to_linear(new_state_dict)
749
-
750
- return new_state_dict
751
 
752
 
753
  # ================#
754
  # VAE Conversion #
755
  # ================#
756
 
 
757
  def reshape_weight_for_sd(w):
758
  # convert HF linear weights to SD conv2d weights
759
- return w.reshape(*w.shape, 1, 1)
760
 
761
 
762
  def convert_vae_state_dict(vae_state_dict):
763
- vae_conversion_map = [
764
- # (stable-diffusion, HF Diffusers)
765
- ("nin_shortcut", "conv_shortcut"),
766
- ("norm_out", "conv_norm_out"),
767
- ("mid.attn_1.", "mid_block.attentions.0."),
768
- ]
769
-
770
- for i in range(4):
771
- # down_blocks have two resnets
772
- for j in range(2):
773
- hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
774
- sd_down_prefix = f"encoder.down.{i}.block.{j}."
775
- vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
776
-
777
- if i < 3:
778
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
779
- sd_downsample_prefix = f"down.{i}.downsample."
780
- vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
781
-
782
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
783
- sd_upsample_prefix = f"up.{3-i}.upsample."
784
- vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
785
-
786
- # up_blocks have three resnets
787
- # also, up blocks in hf are numbered in reverse from sd
788
- for j in range(3):
789
- hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
790
- sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
791
- vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
792
-
793
- # this part accounts for mid blocks in both the encoder and the decoder
794
- for i in range(2):
795
- hf_mid_res_prefix = f"mid_block.resnets.{i}."
796
- sd_mid_res_prefix = f"mid.block_{i+1}."
797
- vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
798
-
799
- vae_conversion_map_attn = [
800
- # (stable-diffusion, HF Diffusers)
801
- ("norm.", "group_norm."),
802
- ("q.", "query."),
803
- ("k.", "key."),
804
- ("v.", "value."),
805
- ("proj_out.", "proj_attn."),
806
- ]
807
-
808
- mapping = {k: k for k in vae_state_dict.keys()}
809
- for k, v in mapping.items():
810
- for sd_part, hf_part in vae_conversion_map:
811
- v = v.replace(hf_part, sd_part)
812
- mapping[k] = v
813
- for k, v in mapping.items():
814
- if "attentions" in k:
815
- for sd_part, hf_part in vae_conversion_map_attn:
816
- v = v.replace(hf_part, sd_part)
817
- mapping[k] = v
818
- new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
819
- weights_to_convert = ["q", "k", "v", "proj_out"]
820
- for k, v in new_state_dict.items():
821
- for weight_name in weights_to_convert:
822
- if f"mid.attn_1.{weight_name}.weight" in k:
823
- # print(f"Reshaping {k} for SD format")
824
- new_state_dict[k] = reshape_weight_for_sd(v)
825
-
826
- return new_state_dict
827
 
828
 
829
  # endregion
830
 
831
  # region 自作のモデル読み書きなど
832
 
 
833
  def is_safetensors(path):
834
- return os.path.splitext(path)[1].lower() == '.safetensors'
835
-
836
-
837
- def load_checkpoint_with_text_encoder_conversion(ckpt_path):
838
- # text encoderの格納形式が違うモデルに対応する ('text_model'がない)
839
- TEXT_ENCODER_KEY_REPLACEMENTS = [
840
- ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
841
- ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
842
- ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
843
- ]
844
-
845
- if is_safetensors(ckpt_path):
846
- checkpoint = None
847
- state_dict = load_file(ckpt_path, "cpu")
848
- else:
849
- checkpoint = torch.load(ckpt_path, map_location="cpu")
850
- if "state_dict" in checkpoint:
851
- state_dict = checkpoint["state_dict"]
852
  else:
853
- state_dict = checkpoint
854
- checkpoint = None
 
 
 
 
855
 
856
- key_reps = []
857
- for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
858
- for key in state_dict.keys():
859
- if key.startswith(rep_from):
860
- new_key = rep_to + key[len(rep_from):]
861
- key_reps.append((key, new_key))
862
 
863
- for key, new_key in key_reps:
864
- state_dict[new_key] = state_dict[key]
865
- del state_dict[key]
866
 
867
- return checkpoint, state_dict
868
 
869
 
870
# TODO dtype handling looks unreliable and needs checking; building the text_encoder in the requested dtype is unverified
871
- def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
872
- _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
873
- if dtype is not None:
874
- for k, v in state_dict.items():
875
- if type(v) is torch.Tensor:
876
- state_dict[k] = v.to(dtype)
877
-
878
- # Convert the UNet2DConditionModel model.
879
- unet_config = create_unet_diffusers_config(v2)
880
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
881
-
882
- unet = UNet2DConditionModel(**unet_config)
883
- info = unet.load_state_dict(converted_unet_checkpoint)
884
- print("loading u-net:", info)
885
-
886
- # Convert the VAE model.
887
- vae_config = create_vae_diffusers_config()
888
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
889
-
890
- vae = AutoencoderKL(**vae_config)
891
- info = vae.load_state_dict(converted_vae_checkpoint)
892
- print("loading vae:", info)
893
-
894
- # convert text_model
895
- if v2:
896
- converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
897
- cfg = CLIPTextConfig(
898
- vocab_size=49408,
899
- hidden_size=1024,
900
- intermediate_size=4096,
901
- num_hidden_layers=23,
902
- num_attention_heads=16,
903
- max_position_embeddings=77,
904
- hidden_act="gelu",
905
- layer_norm_eps=1e-05,
906
- dropout=0.0,
907
- attention_dropout=0.0,
908
- initializer_range=0.02,
909
- initializer_factor=1.0,
910
- pad_token_id=1,
911
- bos_token_id=0,
912
- eos_token_id=2,
913
- model_type="clip_text_model",
914
- projection_dim=512,
915
- torch_dtype="float32",
916
- transformers_version="4.25.0.dev0",
917
- )
918
- text_model = CLIPTextModel._from_config(cfg)
919
- info = text_model.load_state_dict(converted_text_encoder_checkpoint)
920
- else:
921
- converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
922
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
923
- info = text_model.load_state_dict(converted_text_encoder_checkpoint)
924
- print("loading text encoder:", info)
925
 
926
- return text_model, vae, unet
 
 
927
 
 
 
 
928
 
929
- def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
930
- def convert_key(key):
931
- # remove position_ids
932
- if ".position_ids" in key:
933
- return None
934
-
935
- # common
936
- key = key.replace("text_model.encoder.", "transformer.")
937
- key = key.replace("text_model.", "")
938
- if "layers" in key:
939
- # resblocks conversion
940
- key = key.replace(".layers.", ".resblocks.")
941
- if ".layer_norm" in key:
942
- key = key.replace(".layer_norm", ".ln_")
943
- elif ".mlp." in key:
944
- key = key.replace(".fc1.", ".c_fc.")
945
- key = key.replace(".fc2.", ".c_proj.")
946
- elif '.self_attn.out_proj' in key:
947
- key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
948
- elif '.self_attn.' in key:
949
- key = None # special case, handled separately below
950
- else:
951
- raise ValueError(f"unexpected key in DiffUsers model: {key}")
952
- elif '.position_embedding' in key:
953
- key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
954
- elif '.token_embedding' in key:
955
- key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
956
- elif 'final_layer_norm' in key:
957
- key = key.replace("final_layer_norm", "ln_final")
958
- return key
959
-
960
- keys = list(checkpoint.keys())
961
- new_sd = {}
962
- for key in keys:
963
- new_key = convert_key(key)
964
- if new_key is None:
965
- continue
966
- new_sd[new_key] = checkpoint[key]
967
-
968
- # convert the attention weights
969
- for key in keys:
970
- if 'layers' in key and 'q_proj' in key:
971
- # concatenate the three projections
972
- key_q = key
973
- key_k = key.replace("q_proj", "k_proj")
974
- key_v = key.replace("q_proj", "v_proj")
975
-
976
- value_q = checkpoint[key_q]
977
- value_k = checkpoint[key_k]
978
- value_v = checkpoint[key_v]
979
- value = torch.cat([value_q, value_k, value_v])
980
-
981
- new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
982
- new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
983
- new_sd[new_key] = value
984
-
985
- # whether to fabricate the last layer and other missing weights
986
- if make_dummy_weights:
987
- print("make dummy weights for resblock.23, text_projection and logit scale.")
988
- keys = list(new_sd.keys())
989
- for key in keys:
990
- if key.startswith("transformer.resblocks.22."):
991
- new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # without clone(), saving with safetensors fails
992
 
993
- # create weights that the Diffusers model does not include
994
- new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
995
- new_sd['logit_scale'] = torch.tensor(1)
996
 
997
- return new_sd
 
 
 
 
 
998
 
 
 
 
999
 
1000
- def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
1001
- if ckpt_path is not None:
1002
- # reuse its epoch/step; also reload the checkpoint (including the VAE), e.g. when the VAE is not in memory
1003
- checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1004
- if checkpoint is None: # safetensors file or a state_dict-only ckpt
1005
- checkpoint = {}
1006
- strict = False
1007
- else:
1008
- strict = True
1009
- if "state_dict" in state_dict:
1010
- del state_dict["state_dict"]
1011
- else:
1012
- # build a new one from scratch
1013
- assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1014
- checkpoint = {}
1015
- state_dict = {}
1016
- strict = False
1017
-
1018
- def update_sd(prefix, sd):
1019
- for k, v in sd.items():
1020
- key = prefix + k
1021
- assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1022
- if save_dtype is not None:
1023
- v = v.detach().clone().to("cpu").to(save_dtype)
1024
- state_dict[key] = v
1025
-
1026
- # Convert the UNet model
1027
- unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1028
- update_sd("model.diffusion_model.", unet_state_dict)
1029
-
1030
- # Convert the text encoder model
1031
- if v2:
1032
- make_dummy = ckpt_path is None # without a reference checkpoint, insert dummy weights (e.g. duplicate the last layer from the one before it)
1033
- text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1034
- update_sd("cond_stage_model.model.", text_enc_dict)
1035
- else:
1036
- text_enc_dict = text_encoder.state_dict()
1037
- update_sd("cond_stage_model.transformer.", text_enc_dict)
1038
-
1039
- # Convert the VAE
1040
- if vae is not None:
1041
- vae_dict = convert_vae_state_dict(vae.state_dict())
1042
- update_sd("first_stage_model.", vae_dict)
1043
-
1044
- # Put together new checkpoint
1045
- key_count = len(state_dict.keys())
1046
- new_ckpt = {'state_dict': state_dict}
1047
-
1048
- if 'epoch' in checkpoint:
1049
- epochs += checkpoint['epoch']
1050
- if 'global_step' in checkpoint:
1051
- steps += checkpoint['global_step']
1052
-
1053
- new_ckpt['epoch'] = epochs
1054
- new_ckpt['global_step'] = steps
1055
-
1056
- if is_safetensors(output_file):
1057
- # TODO should non-Tensor values be removed from the dict?
1058
- save_file(state_dict, output_file)
1059
- else:
1060
- torch.save(new_ckpt, output_file)
1061
-
1062
- return key_count
1063
 
 
1064
 
1065
- def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1066
- if pretrained_model_name_or_path is None:
1067
- # load default settings for v1/v2
1068
- if v2:
1069
- pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1070
- else:
1071
- pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1072
-
1073
- scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1074
- tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1075
- if vae is None:
1076
- vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1077
-
1078
- pipeline = StableDiffusionPipeline(
1079
- unet=unet,
1080
- text_encoder=text_encoder,
1081
- vae=vae,
1082
- scheduler=scheduler,
1083
- tokenizer=tokenizer,
1084
- safety_checker=None,
1085
- feature_extractor=None,
1086
- requires_safety_checker=None,
1087
- )
1088
- pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1089
 
 
1090
 
1091
- VAE_PREFIX = "first_stage_model."
 
 
 
 
 
 
1092
 
 
 
 
 
1093
 
1094
- def load_vae(vae_id, dtype):
1095
- print(f"load VAE: {vae_id}")
1096
- if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1097
- # Diffusers local/remote
1098
- try:
1099
- vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1100
- except EnvironmentError as e:
1101
- print(f"exception occurs in loading vae: {e}")
1102
- print("retry with subfolder='vae'")
1103
- vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1104
- return vae
1105
 
1106
- # local
1107
- vae_config = create_vae_diffusers_config()
1108
-
1109
- if vae_id.endswith(".bin"):
1110
- # SD 1.5 VAE on Huggingface
1111
- converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1112
- else:
1113
- # StableDiffusion
1114
- vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
1115
- else torch.load(vae_id, map_location="cpu"))
1116
- vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
1117
-
1118
- # vae only or full model
1119
- full_model = False
1120
- for vae_key in vae_sd:
1121
- if vae_key.startswith(VAE_PREFIX):
1122
- full_model = True
1123
- break
1124
- if not full_model:
1125
- sd = {}
1126
- for key, value in vae_sd.items():
1127
- sd[VAE_PREFIX + key] = value
1128
- vae_sd = sd
1129
- del sd
1130
 
1131
- # Convert the VAE model.
1132
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
 
1133
 
1134
- vae = AutoencoderKL(**vae_config)
1135
- vae.load_state_dict(converted_vae_checkpoint)
1136
- return vae
1137
 
1138
- # endregion
1139
 
 
1140
 
1141
- def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1142
- max_width, max_height = max_reso
1143
- max_area = (max_width // divisible) * (max_height // divisible)
 
1144
 
1145
- resos = set()
 
 
1146
 
1147
- size = int(math.sqrt(max_area)) * divisible
1148
- resos.add((size, size))
 
 
1149
 
1150
- size = min_size
1151
- while size <= max_size:
1152
- width = size
1153
- height = min(max_size, (max_area // (width // divisible)) * divisible)
1154
- resos.add((width, height))
1155
- resos.add((height, width))
1156
 
1157
- # # make additional resos
1158
- # if width >= height and width - divisible >= min_size:
1159
- # resos.add((width - divisible, height))
1160
- # resos.add((height, width - divisible))
1161
- # if height >= width and height - divisible >= min_size:
1162
- # resos.add((width, height - divisible))
1163
- # resos.add((height - divisible, width))
1164
 
1165
- size += divisible
 
 
 
1166
 
1167
- resos = list(resos)
1168
- resos.sort()
1169
 
1170
- aspect_ratios = [w / h for w, h in resos]
1171
- return resos, aspect_ratios
1172
 
1173
 
1174
- if __name__ == '__main__':
1175
- resos, aspect_ratios = make_bucket_resolutions((512, 768))
1176
- print(len(resos))
1177
- print(resos)
1178
- print(aspect_ratios)
 
 
 
1179
 
1180
- ars = set()
1181
- for ar in aspect_ratios:
1182
- if ar in ars:
1183
- print("error! duplicate ar:", ar)
1184
- ars.add(ar)
 
 
 
 
1
  # v1: split from train_db_fixed.py.
2
  # v2: support safetensors
3
 
4
  import math
5
  import os
6
  import torch
7
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
8
  from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
9
  from safetensors.torch import load_file, save_file
10
 
 
16
  UNET_PARAMS_MODEL_CHANNELS = 320
17
  UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
18
  UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
19
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
20
  UNET_PARAMS_IN_CHANNELS = 4
21
  UNET_PARAMS_OUT_CHANNELS = 4
22
  UNET_PARAMS_NUM_RES_BLOCKS = 2
 
45
 
46
 
47
  def shave_segments(path, n_shave_prefix_segments=1):
48
+ """
49
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
50
+ """
51
+ if n_shave_prefix_segments >= 0:
52
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
53
+ else:
54
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
55
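A quick illustration of shave_segments (a minimal sketch that assumes this module's functions are importable; the key path below is invented for the example):
# drop the first two dot-separated segments of a key path
p = "model.diffusion_model.input_blocks.1.0.in_layers.0.weight"
assert shave_segments(p, 2) == "input_blocks.1.0.in_layers.0.weight"
# a negative count drops segments from the end instead
assert shave_segments(p, -1) == "model.diffusion_model.input_blocks.1.0.in_layers.0"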
 
56
 
57
  def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
58
+ """
59
+ Updates paths inside resnets to the new naming scheme (local renaming)
60
+ """
61
+ mapping = []
62
+ for old_item in old_list:
63
+ new_item = old_item.replace("in_layers.0", "norm1")
64
+ new_item = new_item.replace("in_layers.2", "conv1")
65
 
66
+ new_item = new_item.replace("out_layers.0", "norm2")
67
+ new_item = new_item.replace("out_layers.3", "conv2")
68
 
69
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
70
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
71
 
72
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
73
 
74
+ mapping.append({"old": old_item, "new": new_item})
75
 
76
+ return mapping
77
 
78
 
79
  def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
80
+ """
81
+ Updates paths inside resnets to the new naming scheme (local renaming)
82
+ """
83
+ mapping = []
84
+ for old_item in old_list:
85
+ new_item = old_item
86
 
87
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
88
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
89
 
90
+ mapping.append({"old": old_item, "new": new_item})
91
 
92
+ return mapping
93
 
94
 
95
  def renew_attention_paths(old_list, n_shave_prefix_segments=0):
96
+ """
97
+ Updates paths inside attentions to the new naming scheme (local renaming)
98
+ """
99
+ mapping = []
100
+ for old_item in old_list:
101
+ new_item = old_item
102
 
103
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
104
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
105
 
106
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
107
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
108
 
109
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
110
 
111
+ mapping.append({"old": old_item, "new": new_item})
112
 
113
+ return mapping
114
 
115
 
116
  def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
117
+ """
118
+ Updates paths inside attentions to the new naming scheme (local renaming)
119
+ """
120
+ mapping = []
121
+ for old_item in old_list:
122
+ new_item = old_item
123
 
124
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
125
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
126
 
127
+ new_item = new_item.replace("q.weight", "query.weight")
128
+ new_item = new_item.replace("q.bias", "query.bias")
129
 
130
+ new_item = new_item.replace("k.weight", "key.weight")
131
+ new_item = new_item.replace("k.bias", "key.bias")
132
 
133
+ new_item = new_item.replace("v.weight", "value.weight")
134
+ new_item = new_item.replace("v.bias", "value.bias")
135
 
136
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
137
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
138
 
139
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
140
 
141
+ mapping.append({"old": old_item, "new": new_item})
142
 
143
+ return mapping
144
 
145
 
146
  def assign_to_checkpoint(
147
  paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
148
  ):
149
+ """
150
+ This does the final conversion step: take locally converted weights and apply a global renaming
151
+ to them. It splits attention layers, and takes into account additional replacements
152
+ that may arise.
153
 
154
+ Assigns the weights to the new checkpoint.
155
+ """
156
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
157
 
158
+ # Splits the attention layers into three variables.
159
+ if attention_paths_to_split is not None:
160
+ for path, path_map in attention_paths_to_split.items():
161
+ old_tensor = old_checkpoint[path]
162
+ channels = old_tensor.shape[0] // 3
163
 
164
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
165
 
166
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
167
 
168
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
169
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
170
 
171
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
172
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
173
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
174
 
175
+ for path in paths:
176
+ new_path = path["new"]
177
 
178
+ # These have already been assigned
179
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
180
+ continue
181
 
182
+ # Global renaming happens here
183
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
184
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
185
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
186
 
187
+ if additional_replacements is not None:
188
+ for replacement in additional_replacements:
189
+ new_path = new_path.replace(replacement["old"], replacement["new"])
190
 
191
+ # proj_attn.weight has to be converted from conv 1D to linear
192
+ if "proj_attn.weight" in new_path:
193
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
194
+ else:
195
+ checkpoint[new_path] = old_checkpoint[path["old"]]
196
 
197
 
198
  def conv_attn_to_linear(checkpoint):
199
+ keys = list(checkpoint.keys())
200
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
201
+ for key in keys:
202
+ if ".".join(key.split(".")[-2:]) in attn_keys:
203
+ if checkpoint[key].ndim > 2:
204
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
205
+ elif "proj_attn.weight" in key:
206
+ if checkpoint[key].ndim > 2:
207
+ checkpoint[key] = checkpoint[key][:, :, 0]
208
 
209
 
210
  def linear_transformer_to_conv(checkpoint):
211
+ keys = list(checkpoint.keys())
212
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
213
+ for key in keys:
214
+ if ".".join(key.split(".")[-2:]) in tf_keys:
215
+ if checkpoint[key].ndim == 2:
216
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
217
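A shape-only check of the two helpers above (tensor sizes are illustrative, not taken from a real checkpoint):
import torch
sd = {"proj_in.weight": torch.randn(320, 320)}
linear_transformer_to_conv(sd)                      # 2-D linear weight gains two trailing 1x1 dims
assert sd["proj_in.weight"].shape == (320, 320, 1, 1)
attn = {"query.weight": torch.randn(512, 512, 1, 1)}
conv_attn_to_linear(attn)                           # the attention helper goes the other way
assert attn["query.weight"].shape == (512, 512)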
 
218
 
219
  def convert_ldm_unet_checkpoint(v2, checkpoint, config):
220
+ """
221
+ Takes a state dict and a config, and returns a converted checkpoint.
222
+ """
223
+
224
+ # extract state_dict for UNet
225
+ unet_state_dict = {}
226
+ unet_key = "model.diffusion_model."
227
+ keys = list(checkpoint.keys())
228
+ for key in keys:
229
+ if key.startswith(unet_key):
230
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
231
+
232
+ new_checkpoint = {}
233
+
234
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
235
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
236
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
237
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
238
+
239
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
240
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
241
+
242
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
243
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
244
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
245
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
246
+
247
+ # Retrieves the keys for the input blocks only
248
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
249
+ input_blocks = {
250
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
251
+ }
252
+
253
+ # Retrieves the keys for the middle blocks only
254
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
255
+ middle_blocks = {
256
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
257
+ }
258
+
259
+ # Retrieves the keys for the output blocks only
260
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
261
+ output_blocks = {
262
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
263
+ }
264
+
265
+ for i in range(1, num_input_blocks):
266
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
267
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
268
+
269
+ resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
270
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
271
+
272
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
273
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
274
+ f"input_blocks.{i}.0.op.weight"
275
+ )
276
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
277
+
278
+ paths = renew_resnet_paths(resnets)
279
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
280
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
281
+
282
+ if len(attentions):
283
+ paths = renew_attention_paths(attentions)
284
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
285
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
286
+
287
+ resnet_0 = middle_blocks[0]
288
+ attentions = middle_blocks[1]
289
+ resnet_1 = middle_blocks[2]
290
+
291
+ resnet_0_paths = renew_resnet_paths(resnet_0)
292
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
293
+
294
+ resnet_1_paths = renew_resnet_paths(resnet_1)
295
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
296
+
297
+ attentions_paths = renew_attention_paths(attentions)
298
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
299
+ assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
300
+
301
+ for i in range(num_output_blocks):
302
+ block_id = i // (config["layers_per_block"] + 1)
303
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
304
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
305
+ output_block_list = {}
306
+
307
+ for layer in output_block_layers:
308
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
309
+ if layer_id in output_block_list:
310
+ output_block_list[layer_id].append(layer_name)
311
+ else:
312
+ output_block_list[layer_id] = [layer_name]
313
+
314
+ if len(output_block_list) > 1:
315
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
316
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
317
+
318
+ resnet_0_paths = renew_resnet_paths(resnets)
319
+ paths = renew_resnet_paths(resnets)
320
+
321
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
322
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
323
+
324
+ # original:
325
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
326
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
327
+
328
+ # avoid depending on the ordering of bias and weight; there is probably a nicer way
329
+ for l in output_block_list.values():
330
+ l.sort()
331
+
332
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
333
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
334
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
335
+ f"output_blocks.{i}.{index}.conv.bias"
336
+ ]
337
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
338
+ f"output_blocks.{i}.{index}.conv.weight"
339
+ ]
340
+
341
+ # Clear attentions as they have been attributed above.
342
+ if len(attentions) == 2:
343
+ attentions = []
344
+
345
+ if len(attentions):
346
+ paths = renew_attention_paths(attentions)
347
+ meta_path = {
348
+ "old": f"output_blocks.{i}.1",
349
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
350
+ }
351
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
352
+ else:
353
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
354
+ for path in resnet_0_paths:
355
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
356
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
357
+
358
+ new_checkpoint[new_path] = unet_state_dict[old_path]
359
+
360
+ # in SD v2 the 1x1 conv2d layers became linear layers, so convert linear -> conv here
361
+ if v2:
362
+ linear_transformer_to_conv(new_checkpoint)
 
 
 
363
 
364
+ return new_checkpoint
365
 
366
 
367
  def convert_ldm_vae_checkpoint(checkpoint, config):
368
+ # extract state dict for VAE
369
+ vae_state_dict = {}
370
+ vae_key = "first_stage_model."
371
+ keys = list(checkpoint.keys())
372
+ for key in keys:
373
+ if key.startswith(vae_key):
374
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
375
+ # if len(vae_state_dict) == 0:
376
+ # # the checkpoint passed in is a VAE state_dict rather than one loaded from a .ckpt
377
+ # vae_state_dict = checkpoint
378
+
379
+ new_checkpoint = {}
380
+
381
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
382
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
383
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
384
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
385
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
386
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
387
+
388
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
389
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
390
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
391
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
392
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
393
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
394
+
395
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
396
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
397
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
398
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
399
+
400
+ # Retrieves the keys for the encoder down blocks only
401
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
402
+ down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
403
+
404
+ # Retrieves the keys for the decoder up blocks only
405
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
406
+ up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
407
+
408
+ for i in range(num_down_blocks):
409
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
410
+
411
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
412
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
413
+ f"encoder.down.{i}.downsample.conv.weight"
414
+ )
415
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
416
+ f"encoder.down.{i}.downsample.conv.bias"
417
+ )
418
+
419
+ paths = renew_vae_resnet_paths(resnets)
420
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
421
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
422
+
423
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
424
+ num_mid_res_blocks = 2
425
+ for i in range(1, num_mid_res_blocks + 1):
426
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
427
+
428
+ paths = renew_vae_resnet_paths(resnets)
429
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
430
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
431
+
432
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
433
+ paths = renew_vae_attention_paths(mid_attentions)
434
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
435
  assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
436
+ conv_attn_to_linear(new_checkpoint)
437
+
438
+ for i in range(num_up_blocks):
439
+ block_id = num_up_blocks - 1 - i
440
+ resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
441
+
442
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
443
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
444
+ f"decoder.up.{block_id}.upsample.conv.weight"
445
+ ]
446
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
447
+ f"decoder.up.{block_id}.upsample.conv.bias"
448
+ ]
449
+
450
+ paths = renew_vae_resnet_paths(resnets)
451
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
452
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
453
+
454
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
455
+ num_mid_res_blocks = 2
456
+ for i in range(1, num_mid_res_blocks + 1):
457
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
458
+
459
+ paths = renew_vae_resnet_paths(resnets)
460
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
461
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
462
+
463
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
464
+ paths = renew_vae_attention_paths(mid_attentions)
465
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
466
  assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
467
+ conv_attn_to_linear(new_checkpoint)
468
+ return new_checkpoint
 
 
 
469
 
470
 
471
  def create_unet_diffusers_config(v2):
472
+ """
473
+ Creates a config for the diffusers based on the config of the LDM model.
474
+ """
475
+ # unet_params = original_config.model.params.unet_config.params
476
+
477
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
478
+
479
+ down_block_types = []
480
+ resolution = 1
481
+ for i in range(len(block_out_channels)):
482
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
483
+ down_block_types.append(block_type)
484
+ if i != len(block_out_channels) - 1:
485
+ resolution *= 2
486
+
487
+ up_block_types = []
488
+ for i in range(len(block_out_channels)):
489
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
490
+ up_block_types.append(block_type)
491
+ resolution //= 2
492
+
493
+ config = dict(
494
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
495
+ in_channels=UNET_PARAMS_IN_CHANNELS,
496
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
497
+ down_block_types=tuple(down_block_types),
498
+ up_block_types=tuple(up_block_types),
499
+ block_out_channels=tuple(block_out_channels),
500
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
501
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
502
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
503
+ )
504
+
505
+ return config
506
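A small sanity check of the layout this config produces, derived only from the UNET_PARAMS_* constants defined earlier in this file (cross-attention at resolutions 1, 2 and 4, none at the deepest level):
cfg = create_unet_diffusers_config(v2=False)
assert cfg["block_out_channels"] == (320, 640, 1280, 1280)
assert cfg["down_block_types"] == ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",)
assert cfg["up_block_types"] == ("UpBlock2D",) + ("CrossAttnUpBlock2D",) * 3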
 
507
 
508
  def create_vae_diffusers_config():
509
+ """
510
+ Creates a config for the diffusers based on the config of the LDM model.
511
+ """
512
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
513
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
514
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
515
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
516
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
517
+
518
+ config = dict(
519
+ sample_size=VAE_PARAMS_RESOLUTION,
520
+ in_channels=VAE_PARAMS_IN_CHANNELS,
521
+ out_channels=VAE_PARAMS_OUT_CH,
522
+ down_block_types=tuple(down_block_types),
523
+ up_block_types=tuple(up_block_types),
524
+ block_out_channels=tuple(block_out_channels),
525
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
526
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
527
+ )
528
+ return config
529
 
530
 
531
  def convert_ldm_clip_checkpoint_v1(checkpoint):
532
+ keys = list(checkpoint.keys())
533
+ text_model_dict = {}
534
+ for key in keys:
535
+ if key.startswith("cond_stage_model.transformer"):
536
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
537
+ return text_model_dict
538
 
539
 
540
  def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
541
+ # the key layouts differ to an annoying degree!
542
+ def convert_key(key):
543
+ if not key.startswith("cond_stage_model"):
544
+ return None
545
+
546
+ # common conversion
547
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
548
+ key = key.replace("cond_stage_model.model.", "text_model.")
549
+
550
+ if "resblocks" in key:
551
+ # resblocks conversion
552
+ key = key.replace(".resblocks.", ".layers.")
553
+ if ".ln_" in key:
554
+ key = key.replace(".ln_", ".layer_norm")
555
+ elif ".mlp." in key:
556
+ key = key.replace(".c_fc.", ".fc1.")
557
+ key = key.replace(".c_proj.", ".fc2.")
558
+ elif ".attn.out_proj" in key:
559
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
560
+ elif ".attn.in_proj" in key:
561
+ key = None # special case, handled separately below
562
+ else:
563
+ raise ValueError(f"unexpected key in SD: {key}")
564
+ elif ".positional_embedding" in key:
565
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
566
+ elif ".text_projection" in key:
567
+ key = None # not used???
568
+ elif ".logit_scale" in key:
569
+ key = None # not used???
570
+ elif ".token_embedding" in key:
571
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
572
+ elif ".ln_final" in key:
573
+ key = key.replace(".ln_final", ".final_layer_norm")
574
+ return key
575
+
576
+ keys = list(checkpoint.keys())
577
+ new_sd = {}
578
+ for key in keys:
579
+ # remove resblocks 23
580
+ if ".resblocks.23." in key:
581
+ continue
582
+ new_key = convert_key(key)
583
+ if new_key is None:
584
+ continue
585
+ new_sd[new_key] = checkpoint[key]
586
+
587
+ # convert the attention weights
588
+ for key in keys:
589
+ if ".resblocks.23." in key:
590
+ continue
591
+ if ".resblocks" in key and ".attn.in_proj_" in key:
592
+ # split into the three projections
593
+ values = torch.chunk(checkpoint[key], 3)
594
+
595
+ key_suffix = ".weight" if "weight" in key else ".bias"
596
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
597
+ key_pfx = key_pfx.replace("_weight", "")
598
+ key_pfx = key_pfx.replace("_bias", "")
599
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
600
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
601
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
602
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
603
+
604
+ # rename or add position_ids
605
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
606
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
607
+ # waifu diffusion v1.4
608
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
609
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
610
+ else:
611
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
612
+
613
+ new_sd["text_model.embeddings.position_ids"] = position_ids
614
+ return new_sd
615
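For reference, the in_proj handling above is purely a shape split: the SD v2 text encoder stores q/k/v as one fused projection, which torch.chunk slices back into three. A minimal sketch with the v2 hidden size of 1024:
import torch
in_proj_weight = torch.randn(3 * 1024, 1024)   # fused q/k/v as stored in the SD checkpoint
q_w, k_w, v_w = torch.chunk(in_proj_weight, 3)
assert q_w.shape == k_w.shape == v_w.shape == (1024, 1024)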
+
616
 
617
  # endregion
618
 
 
620
# region Diffusers->StableDiffusion conversion code
621
# copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
622
 
623
+
624
  def conv_transformer_to_linear(checkpoint):
625
+ keys = list(checkpoint.keys())
626
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
627
+ for key in keys:
628
+ if ".".join(key.split(".")[-2:]) in tf_keys:
629
+ if checkpoint[key].ndim > 2:
630
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
631
 
632
 
633
  def convert_unet_state_dict_to_sd(v2, unet_state_dict):
634
+ unet_conversion_map = [
635
+ # (stable-diffusion, HF Diffusers)
636
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
637
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
638
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
639
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
640
+ ("input_blocks.0.0.weight", "conv_in.weight"),
641
+ ("input_blocks.0.0.bias", "conv_in.bias"),
642
+ ("out.0.weight", "conv_norm_out.weight"),
643
+ ("out.0.bias", "conv_norm_out.bias"),
644
+ ("out.2.weight", "conv_out.weight"),
645
+ ("out.2.bias", "conv_out.bias"),
646
+ ]
647
+
648
+ unet_conversion_map_resnet = [
649
+ # (stable-diffusion, HF Diffusers)
650
+ ("in_layers.0", "norm1"),
651
+ ("in_layers.2", "conv1"),
652
+ ("out_layers.0", "norm2"),
653
+ ("out_layers.3", "conv2"),
654
+ ("emb_layers.1", "time_emb_proj"),
655
+ ("skip_connection", "conv_shortcut"),
656
+ ]
657
+
658
+ unet_conversion_map_layer = []
659
+ for i in range(4):
660
+ # loop over downblocks/upblocks
661
+
662
+ for j in range(2):
663
+ # loop over resnets/attentions for downblocks
664
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
665
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
666
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
667
+
668
+ if i < 3:
669
+ # no attention layers in down_blocks.3
670
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
671
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
672
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
673
+
674
+ for j in range(3):
675
+ # loop over resnets/attentions for upblocks
676
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
677
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
678
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
679
+
680
+ if i > 0:
681
+ # no attention layers in up_blocks.0
682
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
683
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
684
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
685
+
686
+ if i < 3:
687
+ # no downsample in down_blocks.3
688
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
689
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
690
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
691
+
692
+ # no upsample in up_blocks.3
693
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
694
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
695
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
696
+
697
+ hf_mid_atn_prefix = "mid_block.attentions.0."
698
+ sd_mid_atn_prefix = "middle_block.1."
699
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
700
 
701
  for j in range(2):
702
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
703
+ sd_mid_res_prefix = f"middle_block.{2*j}."
704
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
705
+
706
+ # buyer beware: this is a *brittle* function,
707
+ # and correct output requires that all of these pieces interact in
708
+ # the exact order in which I have arranged them.
709
+ mapping = {k: k for k in unet_state_dict.keys()}
710
+ for sd_name, hf_name in unet_conversion_map:
711
+ mapping[hf_name] = sd_name
712
+ for k, v in mapping.items():
713
+ if "resnets" in k:
714
+ for sd_part, hf_part in unet_conversion_map_resnet:
715
+ v = v.replace(hf_part, sd_part)
716
+ mapping[k] = v
717
+ for k, v in mapping.items():
718
+ for sd_part, hf_part in unet_conversion_map_layer:
719
+ v = v.replace(hf_part, sd_part)
720
+ mapping[k] = v
721
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
722
+
723
+ if v2:
724
+ conv_transformer_to_linear(new_state_dict)
725
+
726
+ return new_state_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
727
 
728
 
729
  # ================#
730
  # VAE Conversion #
731
  # ================#
732
 
733
+
734
  def reshape_weight_for_sd(w):
735
  # convert HF linear weights to SD conv2d weights
736
+ return w.reshape(*w.shape, 1, 1)
737
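A shape-only example of the reshape above (the size 512 is illustrative):
import torch
w = torch.randn(512, 512)                                    # HF linear weight [out, in]
assert reshape_weight_for_sd(w).shape == (512, 512, 1, 1)    # SD-style 1x1 conv weight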
 
738
 
739
  def convert_vae_state_dict(vae_state_dict):
740
+ vae_conversion_map = [
741
+ # (stable-diffusion, HF Diffusers)
742
+ ("nin_shortcut", "conv_shortcut"),
743
+ ("norm_out", "conv_norm_out"),
744
+ ("mid.attn_1.", "mid_block.attentions.0."),
745
+ ]
746
+
747
+ for i in range(4):
748
+ # down_blocks have two resnets
749
+ for j in range(2):
750
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
751
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
752
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
753
+
754
+ if i < 3:
755
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
756
+ sd_downsample_prefix = f"down.{i}.downsample."
757
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
758
+
759
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
760
+ sd_upsample_prefix = f"up.{3-i}.upsample."
761
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
762
+
763
+ # up_blocks have three resnets
764
+ # also, up blocks in hf are numbered in reverse from sd
765
+ for j in range(3):
766
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
767
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
768
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
769
+
770
+ # this part accounts for mid blocks in both the encoder and the decoder
771
+ for i in range(2):
772
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
773
+ sd_mid_res_prefix = f"mid.block_{i+1}."
774
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
775
+
776
+ vae_conversion_map_attn = [
777
+ # (stable-diffusion, HF Diffusers)
778
+ ("norm.", "group_norm."),
779
+ ("q.", "query."),
780
+ ("k.", "key."),
781
+ ("v.", "value."),
782
+ ("proj_out.", "proj_attn."),
783
+ ]
784
+
785
+ mapping = {k: k for k in vae_state_dict.keys()}
786
+ for k, v in mapping.items():
787
+ for sd_part, hf_part in vae_conversion_map:
788
+ v = v.replace(hf_part, sd_part)
789
+ mapping[k] = v
790
+ for k, v in mapping.items():
791
+ if "attentions" in k:
792
+ for sd_part, hf_part in vae_conversion_map_attn:
793
+ v = v.replace(hf_part, sd_part)
794
+ mapping[k] = v
795
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
796
+ weights_to_convert = ["q", "k", "v", "proj_out"]
797
+ for k, v in new_state_dict.items():
798
+ for weight_name in weights_to_convert:
799
+ if f"mid.attn_1.{weight_name}.weight" in k:
800
+ # print(f"Reshaping {k} for SD format")
801
+ new_state_dict[k] = reshape_weight_for_sd(v)
802
+
803
+ return new_state_dict
804
 
805
 
806
  # endregion
807
 
808
# region custom model load/save code
809
 
810
+
811
  def is_safetensors(path):
812
+ return os.path.splitext(path)[1].lower() == ".safetensors"
813
+
814
+
815
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
816
+ # handle models that store the text encoder keys differently (no 'text_model' prefix)
817
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
818
+ ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
819
+ ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
820
+ ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
821
+ ]
822
+
823
+ if is_safetensors(ckpt_path):
824
+ checkpoint = None
825
+ state_dict = load_file(ckpt_path) # , device) # may cause an error
 
 
 
 
826
  else:
827
+ checkpoint = torch.load(ckpt_path, map_location=device)
828
+ if "state_dict" in checkpoint:
829
+ state_dict = checkpoint["state_dict"]
830
+ else:
831
+ state_dict = checkpoint
832
+ checkpoint = None
833
 
834
+ key_reps = []
835
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
836
+ for key in state_dict.keys():
837
+ if key.startswith(rep_from):
838
+ new_key = rep_to + key[len(rep_from) :]
839
+ key_reps.append((key, new_key))
840
 
841
+ for key, new_key in key_reps:
842
+ state_dict[new_key] = state_dict[key]
843
+ del state_dict[key]
844
 
845
+ return checkpoint, state_dict
846
 
847
 
848
# TODO dtype handling looks unreliable and needs checking; building the text_encoder in the requested dtype is unverified
849
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
850
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
 
 
 
 
 
 
 
851
 
852
+ # Convert the UNet2DConditionModel model.
853
+ unet_config = create_unet_diffusers_config(v2)
854
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
855
 
856
+ unet = UNet2DConditionModel(**unet_config).to(device)
857
+ info = unet.load_state_dict(converted_unet_checkpoint)
858
+ print("loading u-net:", info)
859
 
860
+ # Convert the VAE model.
861
+ vae_config = create_vae_diffusers_config()
862
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
 
 
 
 
 
 
 
863
 
864
+ vae = AutoencoderKL(**vae_config).to(device)
865
+ info = vae.load_state_dict(converted_vae_checkpoint)
866
+ print("loading vae:", info)
867
 
868
+ # convert text_model
869
+ if v2:
870
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
871
+ cfg = CLIPTextConfig(
872
+ vocab_size=49408,
873
+ hidden_size=1024,
874
+ intermediate_size=4096,
875
+ num_hidden_layers=23,
876
+ num_attention_heads=16,
877
+ max_position_embeddings=77,
878
+ hidden_act="gelu",
879
+ layer_norm_eps=1e-05,
880
+ dropout=0.0,
881
+ attention_dropout=0.0,
882
+ initializer_range=0.02,
883
+ initializer_factor=1.0,
884
+ pad_token_id=1,
885
+ bos_token_id=0,
886
+ eos_token_id=2,
887
+ model_type="clip_text_model",
888
+ projection_dim=512,
889
+ torch_dtype="float32",
890
+ transformers_version="4.25.0.dev0",
891
+ )
892
+ text_model = CLIPTextModel._from_config(cfg)
893
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
894
+ else:
895
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
896
 
897
+ logging.set_verbosity_error() # don't show annoying warning
898
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
899
+ logging.set_verbosity_warning()
900
 
901
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
902
+ print("loading text encoder:", info)
 
 
 
 
 
 
 
 
903
 
904
+ return text_model, vae, unet
905
 
 
 
 
906
 
907
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
908
+ def convert_key(key):
909
+ # remove position_ids
910
+ if ".position_ids" in key:
911
+ return None
912
+
913
+ # common
914
+ key = key.replace("text_model.encoder.", "transformer.")
915
+ key = key.replace("text_model.", "")
916
+ if "layers" in key:
917
+ # resblocks conversion
918
+ key = key.replace(".layers.", ".resblocks.")
919
+ if ".layer_norm" in key:
920
+ key = key.replace(".layer_norm", ".ln_")
921
+ elif ".mlp." in key:
922
+ key = key.replace(".fc1.", ".c_fc.")
923
+ key = key.replace(".fc2.", ".c_proj.")
924
+ elif ".self_attn.out_proj" in key:
925
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
926
+ elif ".self_attn." in key:
927
+ key = None # special case, handled separately below
928
+ else:
929
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
930
+ elif ".position_embedding" in key:
931
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
932
+ elif ".token_embedding" in key:
933
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
934
+ elif "final_layer_norm" in key:
935
+ key = key.replace("final_layer_norm", "ln_final")
936
+ return key
937
+
938
+ keys = list(checkpoint.keys())
939
+ new_sd = {}
940
+ for key in keys:
941
+ new_key = convert_key(key)
942
+ if new_key is None:
943
+ continue
944
+ new_sd[new_key] = checkpoint[key]
945
 
946
+ # convert the attention weights
947
+ for key in keys:
948
+ if "layers" in key and "q_proj" in key:
949
+ # concatenate the three projections
950
+ key_q = key
951
+ key_k = key.replace("q_proj", "k_proj")
952
+ key_v = key.replace("q_proj", "v_proj")
953
 
954
+ value_q = checkpoint[key_q]
955
+ value_k = checkpoint[key_k]
956
+ value_v = checkpoint[key_v]
957
+ value = torch.cat([value_q, value_k, value_v])
958
 
959
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
960
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
961
+ new_sd[new_key] = value
 
 
 
 
 
 
 
 
962
 
963
+ # whether to fabricate the last layer and other missing weights
964
+ if make_dummy_weights:
965
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
966
+ keys = list(new_sd.keys())
967
+ for key in keys:
968
+ if key.startswith("transformer.resblocks.22."):
969
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # without clone(), saving with safetensors fails
 
 
 
 
970
 
971
+ # create weights that the Diffusers model does not include
972
+ new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
973
+ new_sd["logit_scale"] = torch.tensor(1)
974
 
975
+ return new_sd
 
 
976
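The attention handling above is the inverse of the split done at load time: the three 1024x1024 q/k/v projections are stacked into a single in_proj weight. A shape-only sketch:
import torch
q = k = v = torch.randn(1024, 1024)
assert torch.cat([q, k, v]).shape == (3072, 1024)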
 
 
977
 
978
+ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
979
+ if ckpt_path is not None:
980
+ # reuse its epoch/step; also reload the checkpoint (including the VAE), e.g. when the VAE is not in memory
981
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
982
+ if checkpoint is None: # safetensors file or a state_dict-only ckpt
983
+ checkpoint = {}
984
+ strict = False
985
+ else:
986
+ strict = True
987
+ if "state_dict" in state_dict:
988
+ del state_dict["state_dict"]
989
+ else:
990
+ # build a new one from scratch
991
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
992
+ checkpoint = {}
993
+ state_dict = {}
994
+ strict = False
995
+
996
+ def update_sd(prefix, sd):
997
+ for k, v in sd.items():
998
+ key = prefix + k
999
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1000
+ if save_dtype is not None:
1001
+ v = v.detach().clone().to("cpu").to(save_dtype)
1002
+ state_dict[key] = v
1003
+
1004
+ # Convert the UNet model
1005
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1006
+ update_sd("model.diffusion_model.", unet_state_dict)
1007
+
1008
+ # Convert the text encoder model
1009
+ if v2:
1010
+ make_dummy = ckpt_path is None # without a reference checkpoint, insert dummy weights (e.g. duplicate the last layer from the one before it)
1011
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1012
+ update_sd("cond_stage_model.model.", text_enc_dict)
1013
+ else:
1014
+ text_enc_dict = text_encoder.state_dict()
1015
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1016
 
1017
+ # Convert the VAE
1018
+ if vae is not None:
1019
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1020
+ update_sd("first_stage_model.", vae_dict)
1021
 
1022
+ # Put together new checkpoint
1023
+ key_count = len(state_dict.keys())
1024
+ new_ckpt = {"state_dict": state_dict}
1025
 
1026
+ # epoch and global_step are sometimes not int
1027
+ try:
1028
+ if "epoch" in checkpoint:
1029
+ epochs += checkpoint["epoch"]
1030
+ if "global_step" in checkpoint:
1031
+ steps += checkpoint["global_step"]
1032
+ except:
1033
+ pass
1034
+
1035
+ new_ckpt["epoch"] = epochs
1036
+ new_ckpt["global_step"] = steps
1037
+
1038
+ if is_safetensors(output_file):
1039
+ # TODO should non-Tensor values be removed from the dict?
1040
+ save_file(state_dict, output_file)
1041
+ else:
1042
+ torch.save(new_ckpt, output_file)
1043
 
1044
+ return key_count
 
 
 
 
 
1045
 
 
 
 
 
 
 
 
1046
 
1047
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1048
+ if pretrained_model_name_or_path is None:
1049
+ # load default settings for v1/v2
1050
+ if v2:
1051
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1052
+ else:
1053
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1054
+
1055
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1056
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1057
+ if vae is None:
1058
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1059
+
1060
+ pipeline = StableDiffusionPipeline(
1061
+ unet=unet,
1062
+ text_encoder=text_encoder,
1063
+ vae=vae,
1064
+ scheduler=scheduler,
1065
+ tokenizer=tokenizer,
1066
+ safety_checker=None,
1067
+ feature_extractor=None,
1068
+ requires_safety_checker=None,
1069
+ )
1070
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1071
 
 
 
1072
 
1073
+ VAE_PREFIX = "first_stage_model."
 
1074
 
1075
 
1076
+ def load_vae(vae_id, dtype):
1077
+ print(f"load VAE: {vae_id}")
1078
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1079
+ # Diffusers local/remote
1080
+ try:
1081
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1082
+ except EnvironmentError as e:
1083
+ print(f"exception occurs in loading vae: {e}")
1084
+ print("retry with subfolder='vae'")
1085
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1086
+ return vae
1087
+
1088
+ # local
1089
+ vae_config = create_vae_diffusers_config()
1090
+
1091
+ if vae_id.endswith(".bin"):
1092
+ # SD 1.5 VAE on Huggingface
1093
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1094
+ else:
1095
+ # StableDiffusion
1096
+ vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
1097
+ vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
1098
+
1099
+ # vae only or full model
1100
+ full_model = False
1101
+ for vae_key in vae_sd:
1102
+ if vae_key.startswith(VAE_PREFIX):
1103
+ full_model = True
1104
+ break
1105
+ if not full_model:
1106
+ sd = {}
1107
+ for key, value in vae_sd.items():
1108
+ sd[VAE_PREFIX + key] = value
1109
+ vae_sd = sd
1110
+ del sd
1111
+
1112
+ # Convert the VAE model.
1113
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1114
+
1115
+ vae = AutoencoderKL(**vae_config)
1116
+ vae.load_state_dict(converted_vae_checkpoint)
1117
+ return vae
1118
 
1119
+
1120
+ # endregion
1121
+
1122
+
1123
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1124
+ max_width, max_height = max_reso
1125
+ max_area = (max_width // divisible) * (max_height // divisible)
1126
+
1127
+ resos = set()
1128
+
1129
+ size = int(math.sqrt(max_area)) * divisible
1130
+ resos.add((size, size))
1131
+
1132
+ size = min_size
1133
+ while size <= max_size:
1134
+ width = size
1135
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
1136
+ resos.add((width, height))
1137
+ resos.add((height, width))
1138
+
1139
+ # # make additional resos
1140
+ # if width >= height and width - divisible >= min_size:
1141
+ # resos.add((width - divisible, height))
1142
+ # resos.add((height, width - divisible))
1143
+ # if height >= width and height - divisible >= min_size:
1144
+ # resos.add((width, height - divisible))
1145
+ # resos.add((height - divisible, width))
1146
+
1147
+ size += divisible
1148
+
1149
+ resos = list(resos)
1150
+ resos.sort()
1151
+ return resos
1152
+
1153
+
1154
+ if __name__ == "__main__":
1155
+ resos = make_bucket_resolutions((512, 768))
1156
+ print(len(resos))
1157
+ print(resos)
1158
+ aspect_ratios = [w / h for w, h in resos]
1159
+ print(aspect_ratios)
1160
+
1161
+ ars = set()
1162
+ for ar in aspect_ratios:
1163
+ if ar in ars:
1164
+ print("error! duplicate ar:", ar)
1165
+ ars.add(ar)
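
Note (not part of the diff above): the checkpoint assembly in save_stable_diffusion_checkpoint boils down to copying each sub-model's keys under its layout prefix, casting to the save dtype, and moving to CPU. A minimal sketch of that pattern, with made-up tensor names standing in for the converted unet / text_encoder / vae state dicts:

import torch

state_dict = {}

def update_sd(prefix, sd, save_dtype=torch.float16):
    # copy each key under its checkpoint-layout prefix, cast and move to CPU
    for k, v in sd.items():
        state_dict[prefix + k] = v.detach().clone().to("cpu").to(save_dtype)

update_sd("model.diffusion_model.", {"input_blocks.0.0.weight": torch.randn(4, 4)})
update_sd("cond_stage_model.transformer.", {"text_model.embeddings.token_embedding.weight": torch.randn(8, 4)})
update_sd("first_stage_model.", {"encoder.conv_in.weight": torch.randn(8, 3, 3, 3)})

print(sorted(state_dict))  # keys now follow the SD checkpoint layout
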
lycoris/locon.py CHANGED
@@ -16,7 +16,8 @@ class LoConModule(nn.Module):
16
  multiplier=1.0,
17
  lora_dim=4, alpha=1,
18
  dropout=0.,
19
- use_cp=True,
 
20
  ):
21
  """ if alpha == 0 or None, alpha is rank (no scaling). """
22
  super().__init__()
 
16
  multiplier=1.0,
17
  lora_dim=4, alpha=1,
18
  dropout=0.,
19
+ use_cp=False,
20
+ **kwargs,
21
  ):
22
  """ if alpha == 0 or None, alpha is rank (no scaling). """
23
  super().__init__()
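
Note on this signature change (the same change appears in loha.py below): defaulting use_cp to False and accepting **kwargs presumably lets one network factory pass algorithm-specific options (e.g. factor, decompose_both for LoKr) to every module type, with modules silently ignoring options they do not use. A tiny illustrative sketch with a hypothetical class, not from the repo:

import torch.nn as nn

class DummyModule(nn.Module):
    def __init__(self, lora_name, lora_dim=4, use_cp=False, **kwargs):  # unknown options are absorbed by **kwargs
        super().__init__()
        self.lora_name = lora_name

# 'factor' and 'decompose_both' are only meaningful for LoKr, but passing them here is harmless
m = DummyModule("lora_te_dummy", lora_dim=8, use_cp=True, factor=4, decompose_both=True)
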
lycoris/loha.py CHANGED
@@ -92,7 +92,8 @@ class LohaModule(nn.Module):
92
  lora_name,
93
  org_module: nn.Module,
94
  multiplier=1.0, lora_dim=4, alpha=1, dropout=0.,
95
- use_cp=True,
 
96
  ):
97
  """ if alpha == 0 or None, alpha is rank (no scaling). """
98
  super().__init__()
 
92
  lora_name,
93
  org_module: nn.Module,
94
  multiplier=1.0, lora_dim=4, alpha=1, dropout=0.,
95
+ use_cp=False,
96
+ **kwargs,
97
  ):
98
  """ if alpha == 0 or None, alpha is rank (no scaling). """
99
  super().__init__()
lycoris/lokr.py ADDED
@@ -0,0 +1,220 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ # 4, build custom backward function
8
+ # -
9
+
10
+
11
+ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
12
+ '''
13
+ return a tuple of two value of input dimension decomposed by the number closest to factor
14
+ second value is higher or equal than first value.
15
+
16
+ In LoRA with Kroneckor Product, first value is a value for weight scale.
17
+ secon value is a value for weight.
18
+
19
+ Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
20
+
21
+ examples)
22
+ factor
23
+ -1 2 4 8 16 ...
24
+ 127 -> 127, 1 127 -> 127, 1 127 -> 127, 1 127 -> 127, 1 127 -> 127, 1
25
+ 128 -> 16, 8 128 -> 64, 2 128 -> 32, 4 128 -> 16, 8 128 -> 16, 8
26
+ 250 -> 125, 2 250 -> 125, 2 250 -> 125, 2 250 -> 125, 2 250 -> 125, 2
27
+ 360 -> 45, 8 360 -> 180, 2 360 -> 90, 4 360 -> 45, 8 360 -> 45, 8
28
+ 512 -> 32, 16 512 -> 256, 2 512 -> 128, 4 512 -> 64, 8 512 -> 32, 16
29
+ 1024 -> 32, 32 1024 -> 512, 2 1024 -> 256, 4 1024 -> 128, 8 1024 -> 64, 16
30
+ '''
31
+
32
+ if factor > 0 and (dimension % factor) == 0:
33
+ m = factor
34
+ n = dimension // factor
35
+ return m, n
36
+ if factor == -1:
37
+ factor = dimension
38
+ m, n = 1, dimension
39
+ length = m + n
40
+ while m<n:
41
+ new_m = m + 1
42
+ while dimension%new_m != 0:
43
+ new_m += 1
44
+ new_n = dimension // new_m
45
+ if new_m + new_n > length or new_m>factor:
46
+ break
47
+ else:
48
+ m, n = new_m, new_n
49
+ if m > n:
50
+ n, m = m, n
51
+ return m, n
52
+
53
+
54
+ def make_weight_cp(t, wa, wb):
55
+ rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb) # [c, d, k1, k2]
56
+ return rebuild2
57
+
58
+
59
+ def make_kron(orig_weight, w1, w2, scale):
60
+ if len(w2.shape) == 4:
61
+ w1 = w1.unsqueeze(2).unsqueeze(2)
62
+ w2 = w2.contiguous()
63
+ return orig_weight + torch.kron(w1, w2).reshape(orig_weight.shape)*scale
64
+
65
+
66
+ class LokrModule(nn.Module):
67
+ """
68
+ modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
69
+ and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
70
+ and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ lora_name, org_module: nn.Module,
76
+ multiplier=1.0,
77
+ lora_dim=4, alpha=1,
78
+ dropout=0.,
79
+ use_cp=False,
80
+ decompose_both = False,
81
+ factor:int=-1, # factorization factor
82
+ **kwargs,
83
+ ):
84
+ """ if alpha == 0 or None, alpha is rank (no scaling). """
85
+ super().__init__()
86
+ factor = int(factor)
87
+ self.lora_name = lora_name
88
+ self.lora_dim = lora_dim
89
+ self.cp = False
90
+ self.use_w1 = False
91
+ self.use_w2 = False
92
+
93
+ self.shape = org_module.weight.shape
94
+ if org_module.__class__.__name__ == 'Conv2d':
95
+ in_dim = org_module.in_channels
96
+ k_size = org_module.kernel_size
97
+ out_dim = org_module.out_channels
98
+
99
+ in_m, in_n = factorization(in_dim, factor)
100
+ out_l, out_k = factorization(out_dim, factor)
101
+ shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size)
102
+
103
+ self.cp = use_cp and k_size!=(1, 1)
104
+ if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
105
+ self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
106
+ self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
107
+ else:
108
+ self.use_w1 = True
109
+ self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
110
+
111
+ if lora_dim >= max(shape[0][1], shape[1][1])/2:
112
+ self.use_w2 = True
113
+ self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
114
+ elif self.cp:
115
+ self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
116
+ self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1])) # b, 1-mode
117
+ self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) # d, 2-mode
118
+ else: # Conv2d not cp
119
+ # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
120
+ self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
121
+ self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]*shape[2]*shape[3]))
122
+ # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)
123
+
124
+ self.op = F.conv2d
125
+ self.extra_args = {
126
+ "stride": org_module.stride,
127
+ "padding": org_module.padding,
128
+ "dilation": org_module.dilation,
129
+ "groups": org_module.groups
130
+ }
131
+
132
+ else: # Linear
133
+ in_dim = org_module.in_features
134
+ out_dim = org_module.out_features
135
+
136
+ in_m, in_n = factorization(in_dim, factor)
137
+ out_l, out_k = factorization(out_dim, factor)
138
+ shape = ((out_l, out_k), (in_m, in_n)) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
139
+
140
+ # smaller part. weight scale
141
+ if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
142
+ self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
143
+ self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
144
+ else:
145
+ self.use_w1 = True
146
+ self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
147
+
148
+ if lora_dim < max(shape[0][1], shape[1][1])/2:
149
+ # bigger part. weight and LoRA. [b, dim] x [dim, d]
150
+ self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
151
+ self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
152
+ # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
153
+ else:
154
+ self.use_w2 = True
155
+ self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
156
+
157
+ self.op = F.linear
158
+ self.extra_args = {}
159
+
160
+ if dropout:
161
+ self.dropout = nn.Dropout(dropout)
162
+ else:
163
+ self.dropout = nn.Identity()
164
+
165
+ if isinstance(alpha, torch.Tensor):
166
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
167
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
168
+ if self.use_w2 and self.use_w1:
169
+ #use scale = 1
170
+ alpha = lora_dim
171
+ self.scale = alpha / self.lora_dim
172
+ self.register_buffer('alpha', torch.tensor(alpha)) # can be treated as a constant
173
+
174
+ if self.use_w2:
175
+ torch.nn.init.constant_(self.lokr_w2, 0)
176
+ else:
177
+ if self.cp:
178
+ torch.nn.init.normal_(self.lokr_t2, std=0.1)
179
+ torch.nn.init.normal_(self.lokr_w2_a, std=1)
180
+ torch.nn.init.constant_(self.lokr_w2_b, 0)
181
+
182
+ if self.use_w1:
183
+ torch.nn.init.normal_(self.lokr_w1, std=1)
184
+ else:
185
+ torch.nn.init.normal_(self.lokr_w1_a, std=1)
186
+ torch.nn.init.normal_(self.lokr_w1_b, std=0.1)
187
+
188
+ self.multiplier = multiplier
189
+ self.org_module = [org_module]
190
+ weight = make_kron(
191
+ self.org_module[0].weight.data,
192
+ self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
193
+ (self.lokr_w2 if self.use_w2
194
+ else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
195
+ else self.lokr_w2_a@self.lokr_w2_b),
196
+ torch.tensor(self.multiplier * self.scale)
197
+ )
198
+ assert torch.sum(torch.isnan(weight)) == 0, "weight is nan"
199
+
200
+ # Same as locon.py
201
+ def apply_to(self):
202
+ self.org_forward = self.org_module[0].forward
203
+ self.org_module[0].forward = self.forward
204
+
205
+ def forward(self, x):
206
+ weight = make_kron(
207
+ self.org_module[0].weight.data,
208
+ self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
209
+ (self.lokr_w2 if self.use_w2
210
+ else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
211
+ else self.lokr_w2_a@self.lokr_w2_b),
212
+ torch.tensor(self.multiplier * self.scale)
213
+ )
214
+ bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
215
+ return self.op(
216
+ x,
217
+ weight.view(self.shape),
218
+ bias,
219
+ **self.extra_args
220
+ )
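
As a quick sanity check of the reconstruction LokrModule performs (a sketch with assumed shapes, not part of the commit): for a Linear weight factorized as out_dim = a*c and in_dim = b*d, torch.kron of an (a, b) matrix with a (c, d) matrix has shape (a*c, b*d), so it can be reshaped onto the original weight, and a zero-initialized lokr_w2 leaves the merged weight equal to the original at the start of training.

import torch

a, b, c, d = 4, 8, 16, 32            # hypothetical factorization of a (64, 256) Linear weight
orig = torch.randn(a * c, b * d)     # stands in for org_module.weight
w1 = torch.randn(a, b)               # small factor (weight-scale part)
w2 = torch.zeros(c, d)               # large factor, zero-initialized like lokr_w2

merged = orig + torch.kron(w1, w2).reshape(orig.shape) * 1.0
assert torch.equal(merged, orig)     # zero-init w2 => no change before training
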
lycoris/utils.py CHANGED
@@ -24,6 +24,7 @@ def extract_conv(
24
  mode = 'fixed',
25
  mode_param = 0,
26
  device = 'cpu',
 
27
  ) -> Tuple[nn.Parameter, nn.Parameter]:
28
  weight = weight.to(device)
29
  out_ch, in_ch, kernel_size, _ = weight.shape
@@ -48,6 +49,8 @@ def extract_conv(
48
  raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
49
  lora_rank = max(1, lora_rank)
50
  lora_rank = min(out_ch, in_ch, lora_rank)
 
 
51
 
52
  U = U[:, :lora_rank]
53
  S = S[:lora_rank]
@@ -58,29 +61,7 @@ def extract_conv(
58
  extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
59
  extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
60
  del U, S, Vh, weight
61
- return extract_weight_A, extract_weight_B, diff
62
-
63
-
64
- def merge_conv(
65
- weight_a: Union[torch.Tensor, nn.Parameter],
66
- weight_b: Union[torch.Tensor, nn.Parameter],
67
- device = 'cpu'
68
- ):
69
- rank, in_ch, kernel_size, k_ = weight_a.shape
70
- out_ch, rank_, _, _ = weight_b.shape
71
- assert rank == rank_ and kernel_size == k_
72
-
73
- wa = weight_a.to(device)
74
- wb = weight_b.to(device)
75
-
76
- if device == 'cpu':
77
- wa = wa.float()
78
- wb = wb.float()
79
-
80
- merged = wb.reshape(out_ch, -1) @ wa.reshape(rank, -1)
81
- weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
82
- del wb, wa
83
- return weight
84
 
85
 
86
  def extract_linear(
@@ -112,6 +93,8 @@ def extract_linear(
112
  raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
113
  lora_rank = max(1, lora_rank)
114
  lora_rank = min(out_ch, in_ch, lora_rank)
 
 
115
 
116
  U = U[:, :lora_rank]
117
  S = S[:lora_rank]
@@ -122,28 +105,7 @@ def extract_linear(
122
  extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
123
  extract_weight_B = U.reshape(out_ch, lora_rank).detach()
124
  del U, S, Vh, weight
125
- return extract_weight_A, extract_weight_B, diff
126
-
127
-
128
- def merge_linear(
129
- weight_a: Union[torch.Tensor, nn.Parameter],
130
- weight_b: Union[torch.Tensor, nn.Parameter],
131
- device = 'cpu'
132
- ):
133
- rank, in_ch = weight_a.shape
134
- out_ch, rank_ = weight_b.shape
135
- assert rank == rank_
136
-
137
- wa = weight_a.to(device)
138
- wb = weight_b.to(device)
139
-
140
- if device == 'cpu':
141
- wa = wa.float()
142
- wb = wb.float()
143
-
144
- weight = wb @ wa
145
- del wb, wa
146
- return weight
147
 
148
 
149
  def extract_diff(
@@ -200,30 +162,38 @@ def extract_diff(
200
  for child_name, child_module in module.named_modules():
201
  lora_name = prefix + '.' + name + '.' + child_name
202
  lora_name = lora_name.replace('.', '_')
203
-
204
  layer = child_module.__class__.__name__
 
 
 
 
 
205
  if layer == 'Linear':
206
- extract_a, extract_b, diff = extract_linear(
207
  (child_module.weight - weights[child_name]),
208
  mode,
209
  linear_mode_param,
210
  device = extract_device,
211
  )
 
 
212
  elif layer == 'Conv2d':
213
  is_linear = (child_module.weight.shape[2] == 1
214
  and child_module.weight.shape[3] == 1)
215
- extract_a, extract_b, diff = extract_conv(
216
  (child_module.weight - weights[child_name]),
217
  mode,
218
  linear_mode_param if is_linear else conv_mode_param,
219
  device = extract_device,
220
  )
221
- if small_conv and not is_linear:
 
 
222
  dim = extract_a.size(0)
223
- extract_c, extract_a, _ = extract_conv(
224
  extract_a.transpose(0, 1),
225
  'fixed', dim,
226
- extract_device
227
  )
228
  extract_a = extract_a.transpose(0, 1)
229
  extract_c = extract_c.transpose(0, 1)
@@ -235,77 +205,92 @@ def extract_diff(
235
  del extract_c
236
  else:
237
  continue
238
- loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
239
- loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
240
- loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
241
-
242
- if use_bias:
243
- diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
244
- sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
245
-
246
- indices = sparse_diff.indices().to(torch.int16)
247
- values = sparse_diff.values().half()
248
- loras[f'{lora_name}.bias_indices'] = indices
249
- loras[f'{lora_name}.bias_values'] = values
250
- loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
251
- del extract_a, extract_b, diff
 
 
 
 
252
  elif name in temp_name:
253
- weight = temp_name[name]
254
  lora_name = prefix + '.' + name
255
  lora_name = lora_name.replace('.', '_')
 
256
 
257
- if weight.size(0)<32 or weight.size(1)<32:
258
- loras[f'{lora_name}.diff'] = module.weight - weight
259
- continue
 
260
 
261
- layer = module.__class__.__name__
262
  if layer == 'Linear':
263
- extract_a, extract_b, diff = extract_linear(
264
- (module.weight - weight),
265
  mode,
266
  linear_mode_param,
267
  device = extract_device,
268
  )
 
 
269
  elif layer == 'Conv2d':
270
- is_linear = (module.weight.shape[2] == 1
271
- and module.weight.shape[3] == 1)
272
- extract_a, extract_b, diff = extract_conv(
273
- (module.weight - weight),
 
 
274
  mode,
275
  linear_mode_param if is_linear else conv_mode_param,
276
  device = extract_device,
277
  )
278
- if small_conv and not is_linear:
 
 
279
  dim = extract_a.size(0)
280
- extract_c, extract_a, _ = extract_conv(
281
  extract_a.transpose(0, 1),
282
  'fixed', dim,
283
- extract_device
284
  )
285
  extract_a = extract_a.transpose(0, 1)
286
  extract_c = extract_c.transpose(0, 1)
287
  loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
288
- diff = module.weight - torch.einsum(
289
  'i j k l, j r, p i -> p r k l',
290
  extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
291
  ).detach().cpu().contiguous()
292
  del extract_c
293
  else:
294
  continue
295
- loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
296
- loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
297
- loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
298
-
299
- if use_bias:
300
- diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
301
- sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
302
-
303
- indices = sparse_diff.indices().to(torch.int16)
304
- values = sparse_diff.values().half()
305
- loras[f'{lora_name}.bias_indices'] = indices
306
- loras[f'{lora_name}.bias_values'] = values
307
- loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
308
- del extract_a, extract_b, diff
 
 
 
 
309
  return loras
310
 
311
  text_encoder_loras = make_state_dict(
@@ -324,70 +309,125 @@ def extract_diff(
324
  return text_encoder_loras|unet_loras
325
 
326
 
327
- def merge_locon(
328
- base_model,
329
- locon_state_dict: Dict[str, torch.TensorType],
330
- scale: float = 1.0,
331
- device = 'cpu'
332
  ):
333
- UNET_TARGET_REPLACE_MODULE = [
334
- "Transformer2DModel",
335
- "Attention",
336
- "ResnetBlock2D",
337
- "Downsample2D",
338
- "Upsample2D"
339
- ]
340
- TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
341
- LORA_PREFIX_UNET = 'lora_unet'
342
- LORA_PREFIX_TEXT_ENCODER = 'lora_te'
343
- def merge(
344
- prefix,
345
- root_module: torch.nn.Module,
346
- target_replace_modules
347
- ):
348
- temp = {}
349
-
350
- for name, module in tqdm(list(root_module.named_modules())):
351
- if module.__class__.__name__ in target_replace_modules:
352
- temp[name] = {}
353
- for child_name, child_module in module.named_modules():
354
- layer = child_module.__class__.__name__
355
- if layer not in {'Linear', 'Conv2d'}:
356
- continue
357
- lora_name = prefix + '.' + name + '.' + child_name
358
- lora_name = lora_name.replace('.', '_')
359
-
360
- down = locon_state_dict[f'{lora_name}.lora_down.weight'].float()
361
- up = locon_state_dict[f'{lora_name}.lora_up.weight'].float()
362
- alpha = locon_state_dict[f'{lora_name}.alpha'].float()
363
- rank = down.shape[0]
364
-
365
- if layer == 'Conv2d':
366
- delta = merge_conv(down, up, device)
367
- child_module.weight.requires_grad_(False)
368
- child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
369
- del delta
370
- elif layer == 'Linear':
371
- delta = merge_linear(down, up, device)
372
- child_module.weight.requires_grad_(False)
373
- child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
374
- del delta
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
- merge(
377
- LORA_PREFIX_TEXT_ENCODER,
378
- base_model[0],
379
- TEXT_ENCODER_TARGET_REPLACE_MODULE
380
- )
381
- merge(
382
- LORA_PREFIX_UNET,
383
- base_model[2],
384
- UNET_TARGET_REPLACE_MODULE
385
- )
386
 
387
 
388
- def merge_loha(
389
  base_model,
390
- loha_state_dict: Dict[str, torch.TensorType],
391
  scale: float = 1.0,
392
  device = 'cpu'
393
  ):
@@ -398,51 +438,67 @@ def merge_loha(
398
  "Downsample2D",
399
  "Upsample2D"
400
  ]
 
 
 
 
 
 
401
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
402
  LORA_PREFIX_UNET = 'lora_unet'
403
  LORA_PREFIX_TEXT_ENCODER = 'lora_te'
404
- def merge(
 
405
  prefix,
406
  root_module: torch.nn.Module,
407
- target_replace_modules
 
 
408
  ):
409
- temp = {}
410
-
411
- for name, module in tqdm(list(root_module.named_modules())):
412
  if module.__class__.__name__ in target_replace_modules:
413
- temp[name] = {}
414
  for child_name, child_module in module.named_modules():
415
- layer = child_module.__class__.__name__
416
- if layer not in {'Linear', 'Conv2d'}:
417
  continue
418
  lora_name = prefix + '.' + name + '.' + child_name
419
  lora_name = lora_name.replace('.', '_')
420
 
421
- w1a = loha_state_dict[f'{lora_name}.hada_w1_a'].float().to(device)
422
- w1b = loha_state_dict[f'{lora_name}.hada_w1_b'].float().to(device)
423
- w2a = loha_state_dict[f'{lora_name}.hada_w2_a'].float().to(device)
424
- w2b = loha_state_dict[f'{lora_name}.hada_w2_b'].float().to(device)
425
- alpha = loha_state_dict[f'{lora_name}.alpha'].float().to(device)
426
- dim = w1b.shape[0]
427
-
428
- delta = (w1a @ w1b) * (w2a @ w2b)
429
- delta = delta.reshape(child_module.weight.shape)
 
430
 
431
- if layer == 'Conv2d':
432
- child_module.weight.requires_grad_(False)
433
- child_module.weight += (alpha.to(device)/dim * scale * delta).cpu()
434
- elif layer == 'Linear':
435
- child_module.weight.requires_grad_(False)
436
- child_module.weight += (alpha.to(device)/dim * scale * delta).cpu()
437
- del delta
438
 
439
- merge(
440
- LORA_PREFIX_TEXT_ENCODER,
441
- base_model[0],
442
- TEXT_ENCODER_TARGET_REPLACE_MODULE
 
 
 
 
 
 
443
  )
444
- merge(
445
  LORA_PREFIX_UNET,
446
- base_model[2],
447
- UNET_TARGET_REPLACE_MODULE
448
- )
 
 
 
 
24
  mode = 'fixed',
25
  mode_param = 0,
26
  device = 'cpu',
27
+ is_cp = False,
28
  ) -> Tuple[nn.Parameter, nn.Parameter]:
29
  weight = weight.to(device)
30
  out_ch, in_ch, kernel_size, _ = weight.shape
 
49
  raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
50
  lora_rank = max(1, lora_rank)
51
  lora_rank = min(out_ch, in_ch, lora_rank)
52
+ if lora_rank>=out_ch/2 and not is_cp:
53
+ return weight, 'full'
54
 
55
  U = U[:, :lora_rank]
56
  S = S[:lora_rank]
 
61
  extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
62
  extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
63
  del U, S, Vh, weight
64
+ return (extract_weight_A, extract_weight_B, diff), 'low rank'
65
 
66
 
67
  def extract_linear(
 
93
  raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
94
  lora_rank = max(1, lora_rank)
95
  lora_rank = min(out_ch, in_ch, lora_rank)
96
+ if lora_rank>=out_ch/2:
97
+ return weight, 'full'
98
 
99
  U = U[:, :lora_rank]
100
  S = S[:lora_rank]
 
105
  extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
106
  extract_weight_B = U.reshape(out_ch, lora_rank).detach()
107
  del U, S, Vh, weight
108
+ return (extract_weight_A, extract_weight_B, diff), 'low rank'
109
 
110
 
111
  def extract_diff(
 
162
  for child_name, child_module in module.named_modules():
163
  lora_name = prefix + '.' + name + '.' + child_name
164
  lora_name = lora_name.replace('.', '_')
 
165
  layer = child_module.__class__.__name__
166
+ if layer in {'Linear', 'Conv2d'}:
167
+ root_weight = child_module.weight
168
+ if torch.allclose(root_weight, weights[child_name]):
169
+ continue
170
+
171
  if layer == 'Linear':
172
+ weight, decompose_mode = extract_linear(
173
  (child_module.weight - weights[child_name]),
174
  mode,
175
  linear_mode_param,
176
  device = extract_device,
177
  )
178
+ if decompose_mode == 'low rank':
179
+ extract_a, extract_b, diff = weight
180
  elif layer == 'Conv2d':
181
  is_linear = (child_module.weight.shape[2] == 1
182
  and child_module.weight.shape[3] == 1)
183
+ weight, decompose_mode = extract_conv(
184
  (child_module.weight - weights[child_name]),
185
  mode,
186
  linear_mode_param if is_linear else conv_mode_param,
187
  device = extract_device,
188
  )
189
+ if decompose_mode == 'low rank':
190
+ extract_a, extract_b, diff = weight
191
+ if small_conv and not is_linear and decompose_mode == 'low rank':
192
  dim = extract_a.size(0)
193
+ (extract_c, extract_a, _), _ = extract_conv(
194
  extract_a.transpose(0, 1),
195
  'fixed', dim,
196
+ extract_device, True
197
  )
198
  extract_a = extract_a.transpose(0, 1)
199
  extract_c = extract_c.transpose(0, 1)
 
205
  del extract_c
206
  else:
207
  continue
208
+ if decompose_mode == 'low rank':
209
+ loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
210
+ loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
211
+ loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
212
+ if use_bias:
213
+ diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
214
+ sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
215
+
216
+ indices = sparse_diff.indices().to(torch.int16)
217
+ values = sparse_diff.values().half()
218
+ loras[f'{lora_name}.bias_indices'] = indices
219
+ loras[f'{lora_name}.bias_values'] = values
220
+ loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
221
+ del extract_a, extract_b, diff
222
+ elif decompose_mode == 'full':
223
+ loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
224
+ else:
225
+ raise NotImplementedError
226
  elif name in temp_name:
227
+ weights = temp_name[name]
228
  lora_name = prefix + '.' + name
229
  lora_name = lora_name.replace('.', '_')
230
+ layer = module.__class__.__name__
231
 
232
+ if layer in {'Linear', 'Conv2d'}:
233
+ root_weight = module.weight
234
+ if torch.allclose(root_weight, weights):
235
+ continue
236
 
 
237
  if layer == 'Linear':
238
+ weight, decompose_mode = extract_linear(
239
+ (root_weight - weights),
240
  mode,
241
  linear_mode_param,
242
  device = extract_device,
243
  )
244
+ if decompose_mode == 'low rank':
245
+ extract_a, extract_b, diff = weight
246
  elif layer == 'Conv2d':
247
+ is_linear = (
248
+ root_weight.shape[2] == 1
249
+ and root_weight.shape[3] == 1
250
+ )
251
+ weight, decompose_mode = extract_conv(
252
+ (root_weight - weights),
253
  mode,
254
  linear_mode_param if is_linear else conv_mode_param,
255
  device = extract_device,
256
  )
257
+ if decompose_mode == 'low rank':
258
+ extract_a, extract_b, diff = weight
259
+ if small_conv and not is_linear and decompose_mode == 'low rank':
260
  dim = extract_a.size(0)
261
+ (extract_c, extract_a, _), _ = extract_conv(
262
  extract_a.transpose(0, 1),
263
  'fixed', dim,
264
+ extract_device, True
265
  )
266
  extract_a = extract_a.transpose(0, 1)
267
  extract_c = extract_c.transpose(0, 1)
268
  loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
269
+ diff = root_weight - torch.einsum(
270
  'i j k l, j r, p i -> p r k l',
271
  extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
272
  ).detach().cpu().contiguous()
273
  del extract_c
274
  else:
275
  continue
276
+ if decompose_mode == 'low rank':
277
+ loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
278
+ loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
279
+ loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
280
+ if use_bias:
281
+ diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
282
+ sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
283
+
284
+ indices = sparse_diff.indices().to(torch.int16)
285
+ values = sparse_diff.values().half()
286
+ loras[f'{lora_name}.bias_indices'] = indices
287
+ loras[f'{lora_name}.bias_values'] = values
288
+ loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
289
+ del extract_a, extract_b, diff
290
+ elif decompose_mode == 'full':
291
+ loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
292
+ else:
293
+ raise NotImplementedError
294
  return loras
295
 
296
  text_encoder_loras = make_state_dict(
 
309
  return text_encoder_loras|unet_loras
310
 
311
 
312
+ def get_module(
313
+ lyco_state_dict: Dict,
314
+ lora_name
 
 
315
  ):
316
+ if f'{lora_name}.lora_up.weight' in lyco_state_dict:
317
+ up = lyco_state_dict[f'{lora_name}.lora_up.weight']
318
+ down = lyco_state_dict[f'{lora_name}.lora_down.weight']
319
+ mid = lyco_state_dict.get(f'{lora_name}.lora_mid.weight', None)
320
+ alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
321
+ return 'locon', (up, down, mid, alpha)
322
+ elif f'{lora_name}.hada_w1_a' in lyco_state_dict:
323
+ w1a = lyco_state_dict[f'{lora_name}.hada_w1_a']
324
+ w1b = lyco_state_dict[f'{lora_name}.hada_w1_b']
325
+ w2a = lyco_state_dict[f'{lora_name}.hada_w2_a']
326
+ w2b = lyco_state_dict[f'{lora_name}.hada_w2_b']
327
+ t1 = lyco_state_dict.get(f'{lora_name}.hada_t1', None)
328
+ t2 = lyco_state_dict.get(f'{lora_name}.hada_t2', None)
329
+ alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
330
+ return 'hada', (w1a, w1b, w2a, w2b, t1, t2, alpha)
331
+ elif f'{lora_name}.weight' in lyco_state_dict:
332
+ weight = lyco_state_dict[f'{lora_name}.weight']
333
+ on_input = lyco_state_dict.get(f'{lora_name}.on_input', False)
334
+ return 'ia3', (weight, on_input)
335
+ elif (f'{lora_name}.lokr_w1' in lyco_state_dict
336
+ or f'{lora_name}.lokr_w1_a' in lyco_state_dict):
337
+ w1 = lyco_state_dict.get(f'{lora_name}.lokr_w1', None)
338
+ w1a = lyco_state_dict.get(f'{lora_name}.lokr_w1_a', None)
339
+ w1b = lyco_state_dict.get(f'{lora_name}.lokr_w1_b', None)
340
+ w2 = lyco_state_dict.get(f'{lora_name}.lokr_w2', None)
341
+ w2a = lyco_state_dict.get(f'{lora_name}.lokr_w2_a', None)
342
+ w2b = lyco_state_dict.get(f'{lora_name}.lokr_w2_b', None)
343
+ t1 = lyco_state_dict.get(f'{lora_name}.lokr_t1', None)
344
+ t2 = lyco_state_dict.get(f'{lora_name}.lokr_t2', None)
345
+ alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
346
+ return 'kron', (w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha)
347
+ elif f'{lora_name}.diff' in lyco_state_dict:
348
+ return 'full', lyco_state_dict[f'{lora_name}.diff']
349
+ else:
350
+ return 'None', ()
351
+
352
+
353
+ def cp_weight_from_conv(
354
+ up, down, mid
355
+ ):
356
+ up = up.reshape(up.size(0), up.size(1))
357
+ down = down.reshape(down.size(0), down.size(1))
358
+ return torch.einsum('m n w h, i m, n j -> i j w h', mid, up, down)
359
+
360
+ def cp_weight(
361
+ wa, wb, t
362
+ ):
363
+ temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
364
+ return torch.einsum('i j k l, i r -> r j k l', temp, wa)
365
+
366
+
367
+ @torch.no_grad()
368
+ def rebuild_weight(module_type, params, orig_weight, scale=1):
369
+ if orig_weight is None:
370
+ return orig_weight
371
+ merged = orig_weight
372
+ if module_type == 'locon':
373
+ up, down, mid, alpha = params
374
+ if alpha is not None:
375
+ scale *= alpha/up.size(1)
376
+ if mid is not None:
377
+ rebuild = cp_weight_from_conv(up, down, mid)
378
+ else:
379
+ rebuild = up.reshape(up.size(0),-1) @ down.reshape(down.size(0), -1)
380
+ merged = orig_weight + rebuild.reshape(orig_weight.shape) * scale
381
+ del up, down, mid, alpha, params, rebuild
382
+ elif module_type == 'hada':
383
+ w1a, w1b, w2a, w2b, t1, t2, alpha = params
384
+ if alpha is not None:
385
+ scale *= alpha / w1b.size(0)
386
+ if t1 is not None:
387
+ rebuild1 = cp_weight(w1a, w1b, t1)
388
+ else:
389
+ rebuild1 = w1a @ w1b
390
+ if t2 is not None:
391
+ rebuild2 = cp_weight(w2a, w2b, t2)
392
+ else:
393
+ rebuild2 = w2a @ w2b
394
+ rebuild = (rebuild1 * rebuild2).reshape(orig_weight.shape)
395
+ merged = orig_weight + rebuild * scale
396
+ del w1a, w1b, w2a, w2b, t1, t2, alpha, params, rebuild, rebuild1, rebuild2
397
+ elif module_type == 'ia3':
398
+ weight, on_input = params
399
+ if not on_input:
400
+ weight = weight.reshape(-1, 1)
401
+ merged = orig_weight + weight * orig_weight * scale
402
+ del weight, on_input, params
403
+ elif module_type == 'kron':
404
+ w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha = params
405
+ if alpha is not None and (w1b is not None or w2b is not None):
406
+ scale *= alpha / (w1b.size(0) if w1b is not None else w2b.size(0))
407
+ if w1a is not None and w1b is not None:
408
+ if t1 is not None:
409
+ w1 = cp_weight(w1a, w1b, t1)
410
+ else:
411
+ w1 = w1a @ w1b
412
+ if w2a is not None and w2b is not None:
413
+ if t2 is not None:
414
+ w2 = cp_weight(w2a, w2b, t2)
415
+ else:
416
+ w2 = w2a @ w2b
417
+ rebuild = torch.kron(w1, w2).reshape(orig_weight.shape)
418
+ merged = orig_weight + rebuild* scale
419
+ del w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha, params, rebuild
420
+ elif module_type == 'full':
421
+ rebuild = params.reshape(orig_weight.shape)
422
+ merged = orig_weight + rebuild * scale
423
+ del params, rebuild
424
 
425
+ return merged
426
 
427
 
428
+ def merge(
429
  base_model,
430
+ lyco_state_dict,
431
  scale: float = 1.0,
432
  device = 'cpu'
433
  ):
 
438
  "Downsample2D",
439
  "Upsample2D"
440
  ]
441
+ UNET_TARGET_REPLACE_NAME = [
442
+ "conv_in",
443
+ "conv_out",
444
+ "time_embedding.linear_1",
445
+ "time_embedding.linear_2",
446
+ ]
447
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
448
  LORA_PREFIX_UNET = 'lora_unet'
449
  LORA_PREFIX_TEXT_ENCODER = 'lora_te'
450
+ merged = 0
451
+ def merge_state_dict(
452
  prefix,
453
  root_module: torch.nn.Module,
454
+ lyco_state_dict: Dict[str,torch.Tensor],
455
+ target_replace_modules,
456
+ target_replace_names = []
457
  ):
458
+ nonlocal merged
459
+ for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'):
 
460
  if module.__class__.__name__ in target_replace_modules:
 
461
  for child_name, child_module in module.named_modules():
462
+ if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
 
463
  continue
464
  lora_name = prefix + '.' + name + '.' + child_name
465
  lora_name = lora_name.replace('.', '_')
466
 
467
+ result = rebuild_weight(*get_module(
468
+ lyco_state_dict, lora_name
469
+ ), getattr(child_module, 'weight'), scale)
470
+ if result is not None:
471
+ merged += 1
472
+ child_module.requires_grad_(False)
473
+ child_module.weight.copy_(result)
474
+ elif name in target_replace_names:
475
+ lora_name = prefix + '.' + name
476
+ lora_name = lora_name.replace('.', '_')
477
 
478
+ result = rebuild_weight(*get_module(
479
+ lyco_state_dict, lora_name
480
+ ), getattr(module, 'weight'), scale)
481
+ if result is not None:
482
+ merged += 1
483
+ module.requires_grad_(False)
484
+ module.weight.copy_(result)
485
 
486
+ if device == 'cpu':
487
+ for k, v in tqdm(list(lyco_state_dict.items()), desc='Converting Dtype'):
488
+ lyco_state_dict[k] = v.float()
489
+
490
+ merge_state_dict(
491
+ LORA_PREFIX_TEXT_ENCODER,
492
+ base_model[0],
493
+ lyco_state_dict,
494
+ TEXT_ENCODER_TARGET_REPLACE_MODULE,
495
+ UNET_TARGET_REPLACE_NAME
496
  )
497
+ merge_state_dict(
498
  LORA_PREFIX_UNET,
499
+ base_model[2],
500
+ lyco_state_dict,
501
+ UNET_TARGET_REPLACE_MODULE,
502
+ UNET_TARGET_REPLACE_NAME
503
+ )
504
+ print(f'{merged} modules were merged')
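
To illustrate the merge path above with concrete shapes (a sketch, not part of the commit): for module_type 'locon' without a mid/CP weight, rebuild_weight scales up @ down by alpha / rank and adds it onto the original weight, which merge_state_dict then copies back into the module in place.

import torch

out_dim, in_dim, rank = 32, 64, 4
orig = torch.randn(out_dim, in_dim)   # stands in for child_module.weight
up = torch.randn(out_dim, rank)       # {lora_name}.lora_up.weight
down = torch.randn(rank, in_dim)      # {lora_name}.lora_down.weight
alpha = torch.tensor(4.0)

scale = float(alpha) / up.size(1)     # alpha / rank, as in rebuild_weight()
rebuild = up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1)
merged = orig + rebuild.reshape(orig.shape) * scale
print(merged.shape)                   # torch.Size([32, 64])
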