Upload sd3_diffusers_transformer_to_ckpt.py
Browse files
sd3_diffusers_transformer_to_ckpt.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from safetensors.torch import load_file, save_file
|
3 |
+
|
4 |
+
def swap_scale_shift(weight, dim):
|
5 |
+
shift, scale = weight.chunk(2, dim=0)
|
6 |
+
new_weight = torch.cat([scale, shift], dim=0)
|
7 |
+
return new_weight
|
8 |
+
|
9 |
+
file_path = "./sd3_medium.safetensors"
|
10 |
+
file_path_diff = "./diffusion_pytorch_model.safetensors"
|
11 |
+
loaded = load_file(file_path)
|
12 |
+
diffload = load_file(file_path_diff)
|
13 |
+
|
14 |
+
loaded["model.diffusion_model.context_embedder.bias"] = diffload["context_embedder.bias"]
|
15 |
+
loaded["model.diffusion_model.context_embedder.weight"] = diffload["context_embedder.weight"]
|
16 |
+
loaded["model.diffusion_model.pos_embed"] = diffload["pos_embed.pos_embed"]
|
17 |
+
loaded["model.diffusion_model.x_embedder.proj.bias"] = diffload["pos_embed.proj.bias"]
|
18 |
+
loaded["model.diffusion_model.x_embedder.proj.weight"] = diffload["pos_embed.proj.weight"]
|
19 |
+
loaded["model.diffusion_model.t_embedder.mlp.0.bias"] = diffload["time_text_embed.timestep_embedder.linear_1.bias"]
|
20 |
+
loaded["model.diffusion_model.t_embedder.mlp.0.weight"] = diffload["time_text_embed.timestep_embedder.linear_1.weight"]
|
21 |
+
loaded["model.diffusion_model.t_embedder.mlp.2.bias"] = diffload["time_text_embed.timestep_embedder.linear_2.bias"]
|
22 |
+
loaded["model.diffusion_model.t_embedder.mlp.2.weight"] = diffload["time_text_embed.timestep_embedder.linear_2.weight"]
|
23 |
+
loaded["model.diffusion_model.y_embedder.mlp.0.bias"] = diffload["time_text_embed.text_embedder.linear_1.bias"]
|
24 |
+
loaded["model.diffusion_model.y_embedder.mlp.0.weight"] = diffload["time_text_embed.text_embedder.linear_1.weight"]
|
25 |
+
loaded["model.diffusion_model.y_embedder.mlp.2.bias"] = diffload["time_text_embed.text_embedder.linear_2.bias"]
|
26 |
+
loaded["model.diffusion_model.y_embedder.mlp.2.weight"] = diffload["time_text_embed.text_embedder.linear_2.weight"]
|
27 |
+
loaded["model.diffusion_model.final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(diffload["norm_out.linear.bias"], dim = 1536)
|
28 |
+
loaded["model.diffusion_model.final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(diffload["norm_out.linear.weight"], dim = 1536)
|
29 |
+
loaded["model.diffusion_model.final_layer.linear.bias"] = diffload["proj_out.bias"]
|
30 |
+
loaded["model.diffusion_model.final_layer.linear.weight"] = diffload["proj_out.weight"]
|
31 |
+
for iii in range(0, 23) :
|
32 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".context_block.adaLN_modulation.1.bias"] = diffload["transformer_blocks."+str(iii)+".norm1_context.linear.bias"]
|
33 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".context_block.adaLN_modulation.1.weight"] = diffload["transformer_blocks."+str(iii)+".norm1_context.linear.weight"]
|
34 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".context_block.attn.proj.bias"] = diffload["transformer_blocks."+str(iii)+".attn.to_add_out.bias"]
|
35 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".context_block.attn.proj.weight"] = diffload["transformer_blocks."+str(iii)+".attn.to_add_out.weight"]
|
36 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".context_block.attn.qkv.bias"] = torch.cat((diffload["transformer_blocks."+str(iii)+".attn.add_q_proj.bias"], diffload["transformer_blocks."+str(iii)+".attn.add_k_proj.bias"], diffload["transformer_blocks."+str(iii)+".attn.add_v_proj.bias"]), dim=0)
|
37 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".context_block.attn.qkv.weight"] = torch.cat((diffload["transformer_blocks."+str(iii)+".attn.add_q_proj.weight"], diffload["transformer_blocks."+str(iii)+".attn.add_k_proj.weight"], diffload["transformer_blocks."+str(iii)+".attn.add_v_proj.weight"]), dim=0)
|
38 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".x_block.adaLN_modulation.1.bias"] = diffload["transformer_blocks."+str(iii)+".norm1.linear.bias"]
|
39 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".x_block.adaLN_modulation.1.weight"] = diffload["transformer_blocks."+str(iii)+".norm1.linear.weight"]
|
40 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".x_block.attn.proj.bias"] = diffload["transformer_blocks."+str(iii)+".attn.to_out.0.bias"]
|
41 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".x_block.attn.proj.weight"] = diffload["transformer_blocks."+str(iii)+".attn.to_out.0.weight"]
|
42 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".x_block.attn.qkv.bias"] = torch.cat((diffload["transformer_blocks."+str(iii)+".attn.to_q.bias"], diffload["transformer_blocks."+str(iii)+".attn.to_k.bias"], diffload["transformer_blocks."+str(iii)+".attn.to_v.bias"]), dim=0)
|
43 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".x_block.attn.qkv.weight"] = torch.cat((diffload["transformer_blocks."+str(iii)+".attn.to_q.weight"], diffload["transformer_blocks."+str(iii)+".attn.to_k.weight"], diffload["transformer_blocks."+str(iii)+".attn.to_v.weight"]), dim=0)
|
44 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".context_block.mlp.fc1.bias"] = diffload["transformer_blocks."+str(iii)+".ff_context.net.0.proj.bias"]
|
45 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".context_block.mlp.fc1.weight"] = diffload["transformer_blocks."+str(iii)+".ff_context.net.0.proj.weight"]
|
46 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".context_block.mlp.fc2.bias"] = diffload["transformer_blocks."+str(iii)+".ff_context.net.2.bias"]
|
47 |
+
loaded["model.diffusion_model.joint_blocks."+str(iii)+".context_block.mlp.fc2.weight"] = diffload["transformer_blocks."+str(iii)+".ff_context.net.2.weight"]
|
48 |
+
loaded["model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias"] = swap_scale_shift(diffload["transformer_blocks.23.norm1_context.linear.bias"], dim = 1536)
|
49 |
+
loaded["model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight"] = swap_scale_shift(diffload["transformer_blocks.23.norm1_context.linear.weight"], dim = 1536)
|
50 |
+
loaded["model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias"] = torch.cat((diffload["transformer_blocks.23.attn.add_q_proj.bias"], diffload["transformer_blocks.23.attn.add_k_proj.bias"], diffload["transformer_blocks.23.attn.add_v_proj.bias"]), dim=0)
|
51 |
+
loaded["model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight"] = torch.cat((diffload["transformer_blocks.23.attn.add_q_proj.weight"], diffload["transformer_blocks.23.attn.add_k_proj.weight"], diffload["transformer_blocks.23.attn.add_v_proj.weight"]), dim=0)
|
52 |
+
loaded["model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias"] = diffload["transformer_blocks.23.norm1.linear.bias"]
|
53 |
+
loaded["model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight"] = diffload["transformer_blocks.23.norm1.linear.weight"]
|
54 |
+
loaded["model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias"] = diffload["transformer_blocks.23.attn.to_out.0.bias"]
|
55 |
+
loaded["model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight"] = diffload["transformer_blocks.23.attn.to_out.0.weight"]
|
56 |
+
loaded["model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias"] = torch.cat((diffload["transformer_blocks.23.attn.to_q.bias"], diffload["transformer_blocks.23.attn.to_k.bias"], diffload["transformer_blocks.23.attn.to_v.bias"]), dim=0)
|
57 |
+
loaded["model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight"] = torch.cat((diffload["transformer_blocks.23.attn.to_q.weight"], diffload["transformer_blocks.23.attn.to_k.weight"], diffload["transformer_blocks.23.attn.to_v.weight"]), dim=0)
|
58 |
+
|
59 |
+
save_file(loaded, "sd3-reality-mix.safetensors")
|
60 |
+
# manual surgery
|