multimodalart HF staff commited on
Commit
62ee77b
·
1 Parent(s): bc47650

Remove old conversion script

Browse files
Files changed (1) hide show
  1. convertosd_ld.py +0 -226
convertosd_ld.py DELETED
@@ -1,226 +0,0 @@
1
- # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
- # *Only* converts the UNet, VAE, and Text Encoder.
3
- # Does not convert optimizer state or any other thing.
4
- # Written by jachiam
5
-
6
- import argparse
7
- import os.path as osp
8
-
9
- import torch
10
- import gc
11
-
12
- # =================#
13
- # UNet Conversion #
14
- # =================#
15
-
16
- unet_conversion_map = [
17
- # (stable-diffusion, HF Diffusers)
18
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
19
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
20
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
21
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
22
- ("input_blocks.0.0.weight", "conv_in.weight"),
23
- ("input_blocks.0.0.bias", "conv_in.bias"),
24
- ("out.0.weight", "conv_norm_out.weight"),
25
- ("out.0.bias", "conv_norm_out.bias"),
26
- ("out.2.weight", "conv_out.weight"),
27
- ("out.2.bias", "conv_out.bias"),
28
- ]
29
-
30
- unet_conversion_map_resnet = [
31
- # (stable-diffusion, HF Diffusers)
32
- ("in_layers.0", "norm1"),
33
- ("in_layers.2", "conv1"),
34
- ("out_layers.0", "norm2"),
35
- ("out_layers.3", "conv2"),
36
- ("emb_layers.1", "time_emb_proj"),
37
- ("skip_connection", "conv_shortcut"),
38
- ]
39
-
40
- unet_conversion_map_layer = []
41
- # hardcoded number of downblocks and resnets/attentions...
42
- # would need smarter logic for other networks.
43
- for i in range(4):
44
- # loop over downblocks/upblocks
45
-
46
- for j in range(2):
47
- # loop over resnets/attentions for downblocks
48
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
49
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
50
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
51
-
52
- if i < 3:
53
- # no attention layers in down_blocks.3
54
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
55
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
56
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
57
-
58
- for j in range(3):
59
- # loop over resnets/attentions for upblocks
60
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
61
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
62
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
63
-
64
- if i > 0:
65
- # no attention layers in up_blocks.0
66
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
67
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
68
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
69
-
70
- if i < 3:
71
- # no downsample in down_blocks.3
72
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
73
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
74
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
75
-
76
- # no upsample in up_blocks.3
77
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
78
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
79
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
80
-
81
- hf_mid_atn_prefix = "mid_block.attentions.0."
82
- sd_mid_atn_prefix = "middle_block.1."
83
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
84
-
85
- for j in range(2):
86
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
87
- sd_mid_res_prefix = f"middle_block.{2*j}."
88
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
89
-
90
-
91
- def convert_unet_state_dict(unet_state_dict):
92
- # buyer beware: this is a *brittle* function,
93
- # and correct output requires that all of these pieces interact in
94
- # the exact order in which I have arranged them.
95
- mapping = {k: k for k in unet_state_dict.keys()}
96
- for sd_name, hf_name in unet_conversion_map:
97
- mapping[hf_name] = sd_name
98
- for k, v in mapping.items():
99
- if "resnets" in k:
100
- for sd_part, hf_part in unet_conversion_map_resnet:
101
- v = v.replace(hf_part, sd_part)
102
- mapping[k] = v
103
- for k, v in mapping.items():
104
- for sd_part, hf_part in unet_conversion_map_layer:
105
- v = v.replace(hf_part, sd_part)
106
- mapping[k] = v
107
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
108
- return new_state_dict
109
-
110
-
111
- # ================#
112
- # VAE Conversion #
113
- # ================#
114
-
115
- vae_conversion_map = [
116
- # (stable-diffusion, HF Diffusers)
117
- ("nin_shortcut", "conv_shortcut"),
118
- ("norm_out", "conv_norm_out"),
119
- ("mid.attn_1.", "mid_block.attentions.0."),
120
- ]
121
-
122
- for i in range(4):
123
- # down_blocks have two resnets
124
- for j in range(2):
125
- hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
126
- sd_down_prefix = f"encoder.down.{i}.block.{j}."
127
- vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
128
-
129
- if i < 3:
130
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
131
- sd_downsample_prefix = f"down.{i}.downsample."
132
- vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
133
-
134
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
135
- sd_upsample_prefix = f"up.{3-i}.upsample."
136
- vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
137
-
138
- # up_blocks have three resnets
139
- # also, up blocks in hf are numbered in reverse from sd
140
- for j in range(3):
141
- hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
142
- sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
143
- vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
144
-
145
- # this part accounts for mid blocks in both the encoder and the decoder
146
- for i in range(2):
147
- hf_mid_res_prefix = f"mid_block.resnets.{i}."
148
- sd_mid_res_prefix = f"mid.block_{i+1}."
149
- vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
150
-
151
-
152
- vae_conversion_map_attn = [
153
- # (stable-diffusion, HF Diffusers)
154
- ("norm.", "group_norm."),
155
- ("q.", "query."),
156
- ("k.", "key."),
157
- ("v.", "value."),
158
- ("proj_out.", "proj_attn."),
159
- ]
160
-
161
-
162
- def reshape_weight_for_sd(w):
163
- # convert HF linear weights to SD conv2d weights
164
- return w.reshape(*w.shape, 1, 1)
165
-
166
-
167
- def convert_vae_state_dict(vae_state_dict):
168
- mapping = {k: k for k in vae_state_dict.keys()}
169
- for k, v in mapping.items():
170
- for sd_part, hf_part in vae_conversion_map:
171
- v = v.replace(hf_part, sd_part)
172
- mapping[k] = v
173
- for k, v in mapping.items():
174
- if "attentions" in k:
175
- for sd_part, hf_part in vae_conversion_map_attn:
176
- v = v.replace(hf_part, sd_part)
177
- mapping[k] = v
178
- new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
179
- weights_to_convert = ["q", "k", "v", "proj_out"]
180
- print("Converting to CKPT ...")
181
- for k, v in new_state_dict.items():
182
- for weight_name in weights_to_convert:
183
- if f"mid.attn_1.{weight_name}.weight" in k:
184
- new_state_dict[k] = reshape_weight_for_sd(v)
185
- return new_state_dict
186
-
187
-
188
- # =========================#
189
- # Text Encoder Conversion #
190
- # =========================#
191
- # pretty much a no-op
192
-
193
-
194
- def convert_text_enc_state_dict(text_enc_dict):
195
- return text_enc_dict
196
-
197
-
198
- def convert(model_path, checkpoint_path):
199
- unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
200
- vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
201
- text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
202
-
203
- # Convert the UNet model
204
- unet_state_dict = torch.load(unet_path, map_location='cpu')
205
- unet_state_dict = convert_unet_state_dict(unet_state_dict)
206
- unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
207
-
208
- # Convert the VAE model
209
- vae_state_dict = torch.load(vae_path, map_location='cpu')
210
- vae_state_dict = convert_vae_state_dict(vae_state_dict)
211
- vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
212
-
213
- # Convert the text encoder model
214
- text_enc_dict = torch.load(text_enc_path, map_location='cpu')
215
- text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
216
- text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
217
-
218
- # Put together new checkpoint
219
- state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
220
-
221
- state_dict = {k:v.half() for k,v in state_dict.items()}
222
- state_dict = {"state_dict": state_dict}
223
- torch.save(state_dict, checkpoint_path)
224
- del state_dict, text_enc_dict, vae_state_dict, unet_state_dict
225
- torch.cuda.empty_cache()
226
- gc.collect()