abc committed
Commit 0380cfd · 1 Parent(s): 72f0b64

Upload 7 files

Files changed (3)
  1. lycoris/kohya.py +37 -1
  2. lycoris/locon.py +8 -19
  3. lycoris/utils.py +72 -4
lycoris/kohya.py CHANGED
@@ -70,6 +70,12 @@ class LycorisNetwork(torch.nn.Module):
         "Downsample2D",
         "Upsample2D"
     ]
+    UNET_TARGET_REPLACE_NAME = [
+        "conv_in",
+        "conv_out",
+        "time_embedding.linear_1",
+        "time_embedding.linear_2",
+    ]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = 'lora_unet'
     LORA_PREFIX_TEXT_ENCODER = 'lora_te'
@@ -102,7 +108,12 @@ class LycorisNetwork(torch.nn.Module):
         self.dropout = dropout
 
         # create module instances
-        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[network_module]:
+        def create_modules(
+            prefix,
+            root_module: torch.nn.Module,
+            target_replace_modules,
+            target_replace_names = []
+        ) -> List[network_module]:
             print('Create LyCORIS Module')
             loras = []
             for name, module in root_module.named_modules():
@@ -132,6 +143,31 @@ class LycorisNetwork(torch.nn.Module):
                         else:
                             continue
                         loras.append(lora)
+                elif name in target_replace_names:
+                    lora_name = prefix + '.' + name
+                    lora_name = lora_name.replace('.', '_')
+                    if module.__class__.__name__ == 'Linear' and lora_dim>0:
+                        lora = network_module(
+                            lora_name, module, self.multiplier,
+                            self.lora_dim, self.alpha, self.dropout, use_cp
+                        )
+                    elif module.__class__.__name__ == 'Conv2d':
+                        k_size, *_ = module.kernel_size
+                        if k_size==1 and lora_dim>0:
+                            lora = network_module(
+                                lora_name, module, self.multiplier,
+                                self.lora_dim, self.alpha, self.dropout, use_cp
+                            )
+                        elif conv_lora_dim>0:
+                            lora = network_module(
+                                lora_name, module, self.multiplier,
+                                self.conv_lora_dim, self.conv_alpha, self.dropout, use_cp
+                            )
+                        else:
+                            continue
+                    else:
+                        continue
+                    loras.append(lora)
             return loras
 
         self.text_encoder_loras = create_modules(
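
The create_modules change adds name-based targeting alongside the existing class-based matching: layers whose fully qualified names appear in UNET_TARGET_REPLACE_NAME (standalone modules such as conv_in that sit outside the targeted block classes) now get wrapped too. A minimal sketch of that lookup, with a toy model standing in for the UNet (the model and prints below are illustrative, not part of the patch):

import torch.nn as nn

# Toy stand-in for the UNet; the real target names come from UNET_TARGET_REPLACE_NAME.
model = nn.Module()
model.conv_in = nn.Conv2d(4, 320, 3, padding=1)
model.time_embedding = nn.Module()
model.time_embedding.linear_1 = nn.Linear(320, 1280)

target_names = ["conv_in", "time_embedding.linear_1"]

for name, module in model.named_modules():
    if name in target_names:
        # Same naming scheme as the patch: prefix + '.' + name, dots replaced by underscores.
        lora_name = ('lora_unet.' + name).replace('.', '_')
        print(lora_name, module.__class__.__name__)
# lora_unet_conv_in Conv2d
# lora_unet_time_embedding_linear_1 Linear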
lycoris/locon.py CHANGED
@@ -38,18 +38,11 @@ class LoConModule(nn.Module):
             else:
                 self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
             self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
-            self.op = F.conv2d
-            self.extra_args = {
-                'stride': stride,
-                'padding': padding
-            }
         else:
             in_dim = org_module.in_features
             out_dim = org_module.out_features
             self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
             self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
-            self.op = F.linear
-            self.extra_args = {}
         self.shape = org_module.weight.shape
 
         if dropout:
@@ -66,6 +59,8 @@ class LoConModule(nn.Module):
         # same as microsoft's
         torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
         torch.nn.init.zeros_(self.lora_up.weight)
+        if self.cp:
+            torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
 
         self.multiplier = multiplier
         self.org_module = [org_module]
@@ -81,16 +76,10 @@ class LoConModule(nn.Module):
 
     def forward(self, x):
         if self.cp:
-            return self.dropout(
-                self.org_forward(x)
-                + self.lora_up(self.lora_mid(self.lora_down(x)))
-            ) * self.multiplier * self.scale
+            return self.org_forward(x) + self.dropout(
+                self.lora_up(self.lora_mid(self.lora_down(x)))* self.multiplier * self.scale
+            )
         else:
-            bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
-            return self.op(
-                x,
-                (self.org_module[0].weight.data
-                 + self.dropout(self.make_weight()) * self.multiplier * self.scale),
-                bias,
-                **self.extra_args,
-            )
+            return self.org_forward(x) + self.dropout(
+                self.lora_up(self.lora_down(x))* self.multiplier * self.scale
+            )
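
The forward rewrite drops the merged-weight path (self.op / self.extra_args with weights combined ahead of the call) in favor of a plain residual form: the original module runs unchanged, and dropout now wraps only the scaled LoRA delta rather than the whole output. A minimal sketch of the new path, assuming a Linear 16→16 module at rank 4 (none of these toy names are from the patch):

import math
import torch
import torch.nn as nn

org = nn.Linear(16, 16)
lora_down = nn.Linear(16, 4, bias=False)
lora_up = nn.Linear(4, 16, bias=False)
nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))  # same init as the module
nn.init.zeros_(lora_up.weight)                              # delta starts at zero
dropout = nn.Dropout(0.1)
multiplier, scale = 1.0, 1.0  # scale = alpha / lora_dim in the real module

x = torch.randn(2, 16)
# Dropout wraps only the scaled LoRA delta; the original output passes through untouched.
y = org(x) + dropout(lora_up(lora_down(x)) * multiplier * scale)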
 
 
 
 
 
lycoris/utils.py CHANGED
@@ -164,6 +164,12 @@ def extract_diff(
         "Downsample2D",
         "Upsample2D"
     ]
+    UNET_TARGET_REPLACE_NAME = [
+        "conv_in",
+        "conv_out",
+        "time_embedding.linear_1",
+        "time_embedding.linear_2",
+    ]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = 'lora_unet'
     LORA_PREFIX_TEXT_ENCODER = 'lora_te'
@@ -171,10 +177,12 @@ def extract_diff(
         prefix,
         root_module: torch.nn.Module,
         target_module: torch.nn.Module,
-        target_replace_modules
+        target_replace_modules,
+        target_replace_names = []
     ):
         loras = {}
         temp = {}
+        temp_name = {}
 
         for name, module in root_module.named_modules():
             if module.__class__.__name__ in target_replace_modules:
@@ -183,6 +191,8 @@ def extract_diff(
                 if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
                     continue
                 temp[name][child_name] = child_module.weight
+            elif name in target_replace_names:
+                temp_name[name] = module.weight
 
         for name, module in tqdm(list(target_module.named_modules())):
             if name in temp:
@@ -221,7 +231,7 @@ def extract_diff(
                     diff = child_module.weight - torch.einsum(
                         'i j k l, j r, p i -> p r k l',
                         extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
-                    )
+                    ).detach().cpu().contiguous()
                     del extract_c
                 else:
                     continue
@@ -231,7 +241,7 @@ def extract_diff(
 
                 if use_bias:
                     diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
-                    sparse_diff = make_sparse(diff, sparsity).to_sparse()
+                    sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
 
                     indices = sparse_diff.indices().to(torch.int16)
                     values = sparse_diff.values().half()
@@ -239,6 +249,63 @@ def extract_diff(
                     loras[f'{lora_name}.bias_values'] = values
                     loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
                 del extract_a, extract_b, diff
+            elif name in temp_name:
+                weight = temp_name[name]
+                lora_name = prefix + '.' + name
+                lora_name = lora_name.replace('.', '_')
+
+                if weight.size(0)<32 or weight.size(1)<32:
+                    loras[f'{lora_name}.diff'] = module.weight - weight
+                    continue
+
+                layer = module.__class__.__name__
+                if layer == 'Linear':
+                    extract_a, extract_b, diff = extract_linear(
+                        (module.weight - weight),
+                        mode,
+                        linear_mode_param,
+                        device = extract_device,
+                    )
+                elif layer == 'Conv2d':
+                    is_linear = (module.weight.shape[2] == 1
+                                 and module.weight.shape[3] == 1)
+                    extract_a, extract_b, diff = extract_conv(
+                        (module.weight - weight),
+                        mode,
+                        linear_mode_param if is_linear else conv_mode_param,
+                        device = extract_device,
+                    )
+                    if small_conv and not is_linear:
+                        dim = extract_a.size(0)
+                        extract_c, extract_a, _ = extract_conv(
+                            extract_a.transpose(0, 1),
+                            'fixed', dim,
+                            extract_device
+                        )
+                        extract_a = extract_a.transpose(0, 1)
+                        extract_c = extract_c.transpose(0, 1)
+                        loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
+                        diff = module.weight - torch.einsum(
+                            'i j k l, j r, p i -> p r k l',
+                            extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
+                        ).detach().cpu().contiguous()
+                        del extract_c
+                else:
+                    continue
+                loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
+                loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
+                loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
+
+                if use_bias:
+                    diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
+                    sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
+
+                    indices = sparse_diff.indices().to(torch.int16)
+                    values = sparse_diff.values().half()
+                    loras[f'{lora_name}.bias_indices'] = indices
+                    loras[f'{lora_name}.bias_values'] = values
+                    loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
+                del extract_a, extract_b, diff
         return loras
 
     text_encoder_loras = make_state_dict(
@@ -250,7 +317,8 @@ def extract_diff(
     unet_loras = make_state_dict(
         LORA_PREFIX_UNET,
         base_model[2], db_model[2],
-        UNET_TARGET_REPLACE_MODULE
+        UNET_TARGET_REPLACE_MODULE,
+        UNET_TARGET_REPLACE_NAME
    )
    print(len(text_encoder_loras), len(unet_loras))
    return text_encoder_loras|unet_loras
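
Two of the smaller edits here matter in practice: the einsum reconstruction is moved off-device eagerly (.detach().cpu().contiguous()), and the sparse bias diff is now .coalesce()d, since PyTorch raises on .indices()/.values() for uncoalesced COO tensors. A sketch of the sparse-bias round trip, with threshold_sparse standing in for make_sparse (assumed behavior: zero out all but the largest-magnitude entries):

import torch

# Stand-in for make_sparse: keep only entries above the sparsity quantile.
def threshold_sparse(diff, sparsity=0.98):
    thresh = diff.abs().flatten().quantile(sparsity)
    return torch.where(diff.abs() >= thresh, diff, torch.zeros_like(diff))

diff = torch.randn(320, 64)
sparse_diff = threshold_sparse(diff).to_sparse().coalesce()

# .indices()/.values() require a coalesced COO tensor, hence the added .coalesce().
indices = sparse_diff.indices().to(torch.int16)
values = sparse_diff.values().half()
size = torch.tensor(diff.shape).to(torch.int16)

# Load-time reconstruction: widen the stored dtypes back before rebuilding.
rebuilt = torch.sparse_coo_tensor(
    indices.to(torch.long), values.float(), tuple(size.tolist())
).to_dense()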