Huage001 commited on
Commit
644b57d
1 Parent(s): 7e86263

Create linfusion.py

Browse files
Files changed (1) hide show
  1. src/linfusion/linfusion.py +116 -0
src/linfusion/linfusion.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers.models.attention_processor import Attention
2
+ from diffusers import ModelMixin, ConfigMixin
3
+ import functools
4
+
5
+ from .attention import GeneralizedLinearAttention
6
+
7
+
8
+ model_dict = {
9
+ "runwayml/stable-diffusion-v1-5": "Yuanshi/LinFusion-1-5",
10
+ "stablediffusionapi/realistic-vision-v51": "Yuanshi/LinFusion-1-5",
11
+ "Lykon/dreamshaper-8": "Yuanshi/LinFusion-1-5",
12
+ }
13
+
14
+
15
+ def replace_submodule(model, module_name, new_submodule):
16
+ path, attr = module_name.rsplit(".", 1)
17
+ parent_module = functools.reduce(getattr, path.split("."), model)
18
+ setattr(parent_module, attr, new_submodule)
19
+
20
+
21
+ class LinFusion(ModelMixin, ConfigMixin):
22
+ def __init__(self, modules_list, *args, **kwargs) -> None:
23
+ super().__init__(*args, **kwargs)
24
+
25
+ self.modules_dict = {}
26
+ self.register_to_config(modules_list=modules_list)
27
+
28
+ for i, attention_config in enumerate(modules_list):
29
+ dim_n = attention_config["dim_n"]
30
+ heads = attention_config["heads"]
31
+ projection_mid_dim = attention_config["projection_mid_dim"]
32
+ linear_attention = GeneralizedLinearAttention(
33
+ query_dim=dim_n,
34
+ out_dim=dim_n,
35
+ dim_head=dim_n // heads,
36
+ projection_mid_dim=projection_mid_dim,
37
+ )
38
+ self.add_module(f"{i}", linear_attention)
39
+ self.modules_dict[attention_config["module_name"]] = linear_attention
40
+
41
+ @classmethod
42
+ def get_default_config(
43
+ cls,
44
+ pipeline=None,
45
+ unet=None,
46
+ ):
47
+ """
48
+ Get the default configuration for the LinFusion model.
49
+ (The `projection_mid_dim` is same as the `query_dim` by default.)
50
+ """
51
+ assert unet is not None or pipeline.unet is not None
52
+ unet = unet or pipeline.unet
53
+ modules_list = []
54
+ for module_name, module in unet.named_modules():
55
+ if not isinstance(module, Attention):
56
+ continue
57
+ if "attn1" not in module_name:
58
+ continue
59
+ dim_n = module.to_q.weight.shape[0]
60
+ # modules_list.append((module_name, dim_n, module.heads))
61
+ modules_list.append(
62
+ {
63
+ "module_name": module_name,
64
+ "dim_n": dim_n,
65
+ "heads": module.heads,
66
+ "projection_mid_dim": None,
67
+ }
68
+ )
69
+ return {"modules_list": modules_list}
70
+
71
+ @classmethod
72
+ def construct_for(
73
+ cls,
74
+ pipeline=None,
75
+ unet=None,
76
+ load_pretrained=True,
77
+ pretrained_model_name_or_path=None,
78
+ ) -> "LinFusion":
79
+ """
80
+ Construct a LinFusion object for the given pipeline.
81
+ """
82
+ assert unet is not None or pipeline.unet is not None
83
+ unet = unet or pipeline.unet
84
+ if load_pretrained:
85
+ # Load from pretrained
86
+ pipe_name_path = pipeline._internal_dict._name_or_path
87
+ if not pretrained_model_name_or_path:
88
+ pretrained_model_name_or_path = model_dict.get(pipe_name_path, None)
89
+ if pretrained_model_name_or_path:
90
+ print(
91
+ f"Matching LinFusion '{pretrained_model_name_or_path}' for pipeline '{pipe_name_path}'."
92
+ )
93
+ else:
94
+ raise Exception(
95
+ f"LinFusion not found for pipeline [{pipe_name_path}], please provide the path."
96
+ )
97
+ linfusion = (
98
+ LinFusion.from_pretrained(pretrained_model_name_or_path)
99
+ .to(pipeline.device)
100
+ .to(pipeline.dtype)
101
+ )
102
+ else:
103
+ # Create from scratch without pretrained parameters
104
+ default_config = LinFusion.get_default_config(pipeline)
105
+ linfusion = (
106
+ LinFusion(**default_config).to(pipeline.device).to(pipeline.dtype)
107
+ )
108
+ linfusion.mount_to(unet)
109
+ return linfusion
110
+
111
+ def mount_to(self, unet) -> None:
112
+ """
113
+ Mounts the modules in the `modules_dict` to the given `pipeline`.
114
+ """
115
+ for module_name, module in self.modules_dict.items():
116
+ replace_submodule(unet, module_name, module)