Spaces:
Runtime error
Runtime error
File size: 4,550 Bytes
5a64744 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import argparse
from pathlib import Path
from typing import Dict
import safetensors.torch
import torch
import json
import shutil
def load_text_encoder(index_path: Path) -> Dict:
    """Merge every safetensors shard referenced by an index file into one dict.

    The index is a JSON file whose optional ``weight_map`` maps tensor names to
    shard file names; shard names are resolved relative to the index file's
    directory and each distinct shard is loaded once onto the CPU.

    Args:
        index_path: Path to the JSON index file.

    Returns:
        Dict of tensor name -> tensor gathered from all referenced shards
        (empty when ``weight_map`` is missing or empty).
    """
    with open(index_path, "r") as index_file:
        weight_map = json.load(index_file).get("weight_map", {})
    merged = {}
    for shard_name in set(weight_map.values()):
        shard = safetensors.torch.load_file(
            index_path.parent / shard_name, device="cpu"
        )
        merged.update(shard)
    return merged
def convert_unet(unet: Dict, add_prefix=True) -> Dict:
    """Optionally namespace UNet state-dict keys under ``model.diffusion_model.``.

    Args:
        unet: State dict of the UNet.
        add_prefix: When true, return a new dict with every key prefixed;
            when false, return ``unet`` itself unchanged (same object).

    Returns:
        The (possibly re-keyed) state dict.
    """
    if not add_prefix:
        return unet
    namespace = "model.diffusion_model."
    return dict((namespace + name, tensor) for name, tensor in unet.items())
def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
    """Load a VAE checkpoint directory into a single flat dict of tensors.

    Reads ``autoencoder.pth`` from *vae_path*. If ``per_channel_statistics.json``
    is also present, its ``data`` rows are transposed so that each name in
    ``columns`` becomes one 1-D tensor holding that column's values.

    Args:
        vae_path: Directory containing ``autoencoder.pth`` and, optionally,
            ``per_channel_statistics.json``. A ``str`` is also accepted.
        add_prefix: When true, every key (checkpoint and statistics) is
            prefixed with ``"vae."``.

    Returns:
        Dict of key -> tensor. Statistics entries are keyed
        ``[vae.]per_channel_statistics.<column>`` and override any checkpoint
        entry with the same key.
    """
    vae_path = Path(vae_path)
    prefix = "vae." if add_prefix else ""

    # map_location="cpu" lets checkpoints saved from a GPU be converted on a
    # CPU-only machine (consistent with load_text_encoder's device="cpu").
    state_dict = torch.load(
        vae_path / "autoencoder.pth", map_location="cpu", weights_only=True
    )
    result = {prefix + key: value for key, value in state_dict.items()}

    stats_path = vae_path / "per_channel_statistics.json"
    if stats_path.exists():
        with open(stats_path, "r", encoding="utf-8") as f:
            stats = json.load(f)
        # `data` is stored row-wise; transpose so each column's values can be
        # paired with its name from `columns`.
        column_values = zip(*stats["data"])
        for col, vals in zip(stats["columns"], column_values):
            result[f"{prefix}per_channel_statistics.{col}"] = torch.tensor(vals)
    return result
def convert_encoder(encoder: Dict) -> Dict:
    """Return a copy of *encoder* with keys moved under the T5-XXL namespace.

    Every key is prefixed with ``"text_encoders.t5xxl.transformer."``; values
    are carried over unchanged.
    """
    namespace = "text_encoders.t5xxl.transformer."
    renamed = {}
    for name, tensor in encoder.items():
        renamed[namespace + name] = tensor
    return renamed
def save_config(config_src: str, config_dst: str):
    """Copy the config file at *config_src* to *config_dst* using ``shutil.copy``."""
    shutil.copy(src=config_src, dst=config_dst)
def load_vae_config(vae_path: Path) -> str:
    """Locate ``config.json`` inside the VAE directory.

    Args:
        vae_path: VAE directory.

    Returns:
        The config file path as a string.

    Raises:
        FileNotFoundError: If ``<vae_path>/config.json`` does not exist.
    """
    candidate = vae_path / "config.json"
    if candidate.exists():
        return str(candidate)
    raise FileNotFoundError(f"VAE config file {candidate} not found.")
def main(
    unet_path: str,
    vae_path: str,
    out_path: str,
    mode: str,
    unet_config_path: str = None,
    scheduler_config_path: str = None,
) -> None:
    """Convert a UNet checkpoint and a VAE directory to safetensors.

    Args:
        unet_path: UNet state dict file, read with ``torch.load``.
        vae_path: Directory holding ``autoencoder.pth``, optionally
            ``per_channel_statistics.json``, and (required in ``"separate"``
            mode) ``config.json``.
        out_path: In ``"single"`` mode, the output ``.safetensors`` file. In
            ``"separate"`` mode, the root directory that receives ``unet/``,
            ``vae/`` and ``scheduler/`` subdirectories.
        mode: ``"single"`` writes one file whose keys are prefixed with
            ``model.diffusion_model.`` / ``vae.``; ``"separate"`` writes
            unprefixed per-component weight files plus copied config files.
        unet_config_path: Optional UNet config, copied to
            ``unet/config.json`` in ``"separate"`` mode.
        scheduler_config_path: Optional scheduler config, copied to
            ``scheduler/scheduler_config.json`` in ``"separate"`` mode.

    Raises:
        ValueError: If ``mode`` is neither ``"single"`` nor ``"separate"``.
        FileNotFoundError: In ``"separate"`` mode, if ``<vae_path>/config.json``
            does not exist.
    """
    # Reject bad modes up front instead of loading everything and then
    # silently writing nothing.
    if mode not in ("single", "separate"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'single' or 'separate'."
        )
    single = mode == "single"

    # The VAE config is only copied in separate mode, so only require it there;
    # checking it before the heavy checkpoint loads makes the failure fast.
    vae_config_path = None if single else load_vae_config(Path(vae_path))

    # map_location="cpu" lets GPU-saved checkpoints convert on CPU-only hosts.
    unet = convert_unet(
        torch.load(unet_path, map_location="cpu", weights_only=True),
        add_prefix=single,
    )
    vae = convert_vae(Path(vae_path), add_prefix=single)

    if single:
        safetensors.torch.save_file({**unet, **vae}, out_path)
        return

    unet_dir = Path(out_path) / "unet"
    vae_dir = Path(out_path) / "vae"
    scheduler_dir = Path(out_path) / "scheduler"
    for directory in (unet_dir, vae_dir, scheduler_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # Weight files are named "<component>_diffusion_pytorch_model.safetensors"
    # (component-prefixed); presumably what the downstream loader expects --
    # confirm before renaming.
    safetensors.torch.save_file(
        unet, unet_dir / "unet_diffusion_pytorch_model.safetensors"
    )
    safetensors.torch.save_file(
        vae, vae_dir / "vae_diffusion_pytorch_model.safetensors"
    )

    if unet_config_path:
        save_config(unet_config_path, unet_dir / "config.json")
    save_config(vae_config_path, vae_dir / "config.json")
    if scheduler_config_path:
        save_config(scheduler_config_path, scheduler_dir / "scheduler_config.json")
if __name__ == "__main__":
    # Command-line entry point: every option maps 1:1 onto a main() parameter.
    cli = argparse.ArgumentParser()
    cli.add_argument("--unet_path", "-u", type=str, default="unet/ema-002.pt")
    cli.add_argument("--vae_path", "-v", type=str, default="vae/")
    cli.add_argument("--out_path", "-o", type=str, default="xora.safetensors")
    cli.add_argument(
        "--mode",
        "-m",
        type=str,
        choices=["single", "separate"],
        default="single",
        help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
    )
    cli.add_argument(
        "--unet_config_path",
        type=str,
        help="Path to the UNet config file (for separate mode)",
    )
    cli.add_argument(
        "--scheduler_config_path",
        type=str,
        help="Path to the Scheduler config file (for separate mode)",
    )
    cli_options = vars(cli.parse_args())
    main(**cli_options)
|