multimodalart HF staff commited on
Commit
722e096
·
1 Parent(s): 57558ed

Update lora.py

Browse files
Files changed (1) hide show
  1. lora.py +27 -84
lora.py CHANGED
@@ -5,16 +5,12 @@
5
 
6
  import math
7
  import os
8
- from typing import Dict, List, Optional, Tuple, Type, Union
9
- from diffusers import AutoencoderKL
10
- from transformers import CLIPTextModel
11
  import numpy as np
12
  import torch
13
  import re
14
 
15
 
16
- RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
17
-
18
  RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
19
 
20
 
@@ -219,13 +215,7 @@ class LoRAInfModule(LoRAModule):
219
 
220
  def default_forward(self, x):
221
  # print("default_forward", self.lora_name, x.size())
222
- org_forward = self.org_forward(x)
223
- lora_down = self.lora_down(x)
224
- lora_up_down = self.lora_up(lora_down)
225
- print(org_forward)
226
- print(lora_up_down)
227
- print(self.multiplier)
228
- return org_forward + lora_up_down * self.multiplier #* self.scale
229
 
230
  def forward(self, x):
231
  if not self.enabled:
@@ -410,16 +400,7 @@ def parse_block_lr_kwargs(nw_kwargs):
410
  return down_lr_weight, mid_lr_weight, up_lr_weight
411
 
412
 
413
- def create_network(
414
- multiplier: float,
415
- network_dim: Optional[int],
416
- network_alpha: Optional[float],
417
- vae: AutoencoderKL,
418
- text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
419
- unet,
420
- neuron_dropout: Optional[float] = None,
421
- **kwargs,
422
- ):
423
  if network_dim is None:
424
  network_dim = 4 # default
425
  if network_alpha is None:
@@ -738,36 +719,33 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
738
  class LoRANetwork(torch.nn.Module):
739
  NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
740
 
741
- UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
 
742
  UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
743
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
744
  LORA_PREFIX_UNET = "lora_unet"
745
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
746
 
747
- # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
748
- LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
749
- LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
750
-
751
  def __init__(
752
  self,
753
- text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
754
  unet,
755
- multiplier: float = 1.0,
756
- lora_dim: int = 4,
757
- alpha: float = 1,
758
- dropout: Optional[float] = None,
759
- rank_dropout: Optional[float] = None,
760
- module_dropout: Optional[float] = None,
761
- conv_lora_dim: Optional[int] = None,
762
- conv_alpha: Optional[float] = None,
763
- block_dims: Optional[List[int]] = None,
764
- block_alphas: Optional[List[float]] = None,
765
- conv_block_dims: Optional[List[int]] = None,
766
- conv_block_alphas: Optional[List[float]] = None,
767
- modules_dim: Optional[Dict[str, int]] = None,
768
- modules_alpha: Optional[Dict[str, int]] = None,
769
- module_class: Type[object] = LoRAModule,
770
- varbose: Optional[bool] = False,
771
  ) -> None:
772
  """
773
  LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -805,21 +783,8 @@ class LoRANetwork(torch.nn.Module):
805
  print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
806
 
807
  # create module instances
808
- def create_modules(
809
- is_unet: bool,
810
- text_encoder_idx: Optional[int], # None, 1, 2
811
- root_module: torch.nn.Module,
812
- target_replace_modules: List[torch.nn.Module],
813
- ) -> List[LoRAModule]:
814
- prefix = (
815
- self.LORA_PREFIX_UNET
816
- if is_unet
817
- else (
818
- self.LORA_PREFIX_TEXT_ENCODER
819
- if text_encoder_idx is None
820
- else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
821
- )
822
- )
823
  loras = []
824
  skipped = []
825
  for name, module in root_module.named_modules():
@@ -835,14 +800,11 @@ class LoRANetwork(torch.nn.Module):
835
 
836
  dim = None
837
  alpha = None
838
-
839
  if modules_dim is not None:
840
- # モジュール指定あり
841
  if lora_name in modules_dim:
842
  dim = modules_dim[lora_name]
843
  alpha = modules_alpha[lora_name]
844
  elif is_unet and block_dims is not None:
845
- # U-Netでblock_dims指定あり
846
  block_idx = get_block_index(lora_name)
847
  if is_linear or is_conv2d_1x1:
848
  dim = block_dims[block_idx]
@@ -851,7 +813,6 @@ class LoRANetwork(torch.nn.Module):
851
  dim = conv_block_dims[block_idx]
852
  alpha = conv_block_alphas[block_idx]
853
  else:
854
- # 通常、すべて対象とする
855
  if is_linear or is_conv2d_1x1:
856
  dim = self.lora_dim
857
  alpha = self.alpha
@@ -860,7 +821,6 @@ class LoRANetwork(torch.nn.Module):
860
  alpha = self.conv_alpha
861
 
862
  if dim is None or dim == 0:
863
- # skipした情報を出力
864
  if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
865
  skipped.append(lora_name)
866
  continue
@@ -878,24 +838,7 @@ class LoRANetwork(torch.nn.Module):
878
  loras.append(lora)
879
  return loras, skipped
880
 
881
- text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
882
- print(text_encoders)
883
- # create LoRA for text encoder
884
- # 毎回すべてのモジュールを作るのは無駄なので要検討
885
- self.text_encoder_loras = []
886
- skipped_te = []
887
- for i, text_encoder in enumerate(text_encoders):
888
- if len(text_encoders) > 1:
889
- index = i + 1
890
- print(f"create LoRA for Text Encoder {index}:")
891
- else:
892
- index = None
893
- print(f"create LoRA for Text Encoder:")
894
-
895
- print(text_encoder)
896
- text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
897
- self.text_encoder_loras.extend(text_encoder_loras)
898
- skipped_te += skipped
899
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
900
 
901
  # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
@@ -903,7 +846,7 @@ class LoRANetwork(torch.nn.Module):
903
  if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
904
  target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
905
 
906
- self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
907
  print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
908
 
909
  skipped = skipped_te + skipped_un
@@ -937,6 +880,7 @@ class LoRANetwork(torch.nn.Module):
937
  weights_sd = load_file(file)
938
  else:
939
  weights_sd = torch.load(file, map_location="cpu")
 
940
  info = self.load_state_dict(weights_sd, False)
941
  return info
942
 
@@ -1017,7 +961,6 @@ class LoRANetwork(torch.nn.Module):
1017
 
1018
  return lr_weight
1019
 
1020
- # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
1021
  def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
1022
  self.requires_grad_(True)
1023
  all_params = []
 
5
 
6
  import math
7
  import os
8
+ from typing import List, Tuple, Union
 
 
9
  import numpy as np
10
  import torch
11
  import re
12
 
13
 
 
 
14
  RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
15
 
16
 
 
215
 
216
  def default_forward(self, x):
217
  # print("default_forward", self.lora_name, x.size())
218
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
 
 
 
 
 
219
 
220
  def forward(self, x):
221
  if not self.enabled:
 
400
  return down_lr_weight, mid_lr_weight, up_lr_weight
401
 
402
 
403
+ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs):
 
 
 
 
 
 
 
 
 
404
  if network_dim is None:
405
  network_dim = 4 # default
406
  if network_alpha is None:
 
719
  class LoRANetwork(torch.nn.Module):
720
  NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
721
 
722
+ # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
723
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
724
  UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
725
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
726
  LORA_PREFIX_UNET = "lora_unet"
727
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
728
 
 
 
 
 
729
  def __init__(
730
  self,
731
+ text_encoder,
732
  unet,
733
+ multiplier=1.0,
734
+ lora_dim=4,
735
+ alpha=1,
736
+ dropout=None,
737
+ rank_dropout=None,
738
+ module_dropout=None,
739
+ conv_lora_dim=None,
740
+ conv_alpha=None,
741
+ block_dims=None,
742
+ block_alphas=None,
743
+ conv_block_dims=None,
744
+ conv_block_alphas=None,
745
+ modules_dim=None,
746
+ modules_alpha=None,
747
+ module_class=LoRAModule,
748
+ varbose=False,
749
  ) -> None:
750
  """
751
  LoRA network: すごく引数が多いが、パターンは以下の通り
 
783
  print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
784
 
785
  # create module instances
786
+ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
787
+ prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
 
 
 
 
 
 
 
 
 
 
 
 
 
788
  loras = []
789
  skipped = []
790
  for name, module in root_module.named_modules():
 
800
 
801
  dim = None
802
  alpha = None
 
803
  if modules_dim is not None:
 
804
  if lora_name in modules_dim:
805
  dim = modules_dim[lora_name]
806
  alpha = modules_alpha[lora_name]
807
  elif is_unet and block_dims is not None:
 
808
  block_idx = get_block_index(lora_name)
809
  if is_linear or is_conv2d_1x1:
810
  dim = block_dims[block_idx]
 
813
  dim = conv_block_dims[block_idx]
814
  alpha = conv_block_alphas[block_idx]
815
  else:
 
816
  if is_linear or is_conv2d_1x1:
817
  dim = self.lora_dim
818
  alpha = self.alpha
 
821
  alpha = self.conv_alpha
822
 
823
  if dim is None or dim == 0:
 
824
  if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
825
  skipped.append(lora_name)
826
  continue
 
838
  loras.append(lora)
839
  return loras, skipped
840
 
841
+ self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842
  print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
843
 
844
  # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
 
846
  if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
847
  target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
848
 
849
+ self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
850
  print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
851
 
852
  skipped = skipped_te + skipped_un
 
880
  weights_sd = load_file(file)
881
  else:
882
  weights_sd = torch.load(file, map_location="cpu")
883
+
884
  info = self.load_state_dict(weights_sd, False)
885
  return info
886
 
 
961
 
962
  return lr_weight
963
 
 
964
  def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
965
  self.requires_grad_(True)
966
  all_params = []