SunderAli17 commited on
Commit
5d18928
1 Parent(s): e3a081c

Create util.py

Browse files
Files changed (1) hide show
  1. flux/util.py +156 -0
flux/util.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from safetensors.torch import load_file as load_sft
8
+
9
+ from flux.model import Flux, FluxParams
10
+ from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
11
+ from flux.modules.conditioner import HFEmbedder
12
+
13
+
14
@dataclass
class ModelSpec:
    """Everything needed to construct and load one Flux model variant.

    Instances live in the module-level ``configs`` registry; the loader
    functions (``load_flow_model`` / ``load_ae``) read these fields to build
    the networks and locate their checkpoints.
    """

    params: FluxParams          # hyperparameters for the Flux transformer
    ae_params: AutoEncoderParams  # hyperparameters for the autoencoder
    ckpt_path: str              # local path to the flow checkpoint (may be None when sourced from an env var — see configs)
    ae_path: str                # local path to the autoencoder checkpoint (same caveat)
    repo_id: str                # Hugging Face repo to download from when the local file is missing
    repo_flow: str              # filename of the flow checkpoint inside repo_id
    repo_ae: str                # filename of the autoencoder checkpoint inside repo_id
23
+
24
+
25
# Registry of supported model variants. The two variants share identical
# network shapes; they differ only in checkpoint source and in whether the
# transformer embeds a guidance signal (dev: yes, schnell: no).
configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path='models/flux1-dev.safetensors',
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path='models/ae.safetensors',
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        # NOTE(review): os.getenv returns None when the variable is unset, so
        # ckpt_path/ae_path can be None here — the loaders must tolerate that.
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}
91
+
92
+
93
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Report state-dict keys that failed to match during checkpoint loading.

    Prints the missing keys, the unexpected keys, or both (separated by a
    horizontal rule when both are present); prints nothing when both lists
    are empty.
    """
    if missing:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        if unexpected:
            print("\n" + "-" * 79 + "\n")
    if unexpected:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
102
+
103
+
104
def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True):
    """Build the Flux transformer for `name` and load its checkpoint weights.

    Args:
        name: key into the module-level ``configs`` registry
            ("flux-dev" or "flux-schnell").
        device: device the model and weights are materialized on.
        hf_download: allow fetching the checkpoint from the Hugging Face Hub
            when it is not available locally.

    Returns:
        The ``Flux`` model in bfloat16; weights are loaded when a checkpoint
        path was resolved, otherwise the model keeps its initial parameters.
    """
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    # BUGFIX: ckpt_path can be None (flux-schnell reads it from the
    # FLUX_SCHNELL env var, which may be unset); os.path.exists(None) raises
    # TypeError, so check for None before probing the filesystem.
    if (
        (ckpt_path is None or not os.path.exists(ckpt_path))
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')

    with torch.device(device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint")
        # load_sft doesn't support torch.device, so pass the device as a string
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=False)
        print_load_warning(missing, unexpected)
    return model
126
+
127
+
128
def load_t5(device: str = "cuda", max_length: int = 512) -> HFEmbedder:
    """Load the T5 text encoder in bfloat16 and move it to `device`.

    max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    """
    embedder = HFEmbedder(
        "xlabs-ai/xflux_text_encoders",
        max_length=max_length,
        torch_dtype=torch.bfloat16,
    )
    return embedder.to(device)
131
+
132
+
133
def load_clip(device: str = "cuda") -> HFEmbedder:
    """Load the CLIP text encoder (77-token context) in bfloat16 on `device`."""
    encoder = HFEmbedder(
        "openai/clip-vit-large-patch14",
        max_length=77,
        torch_dtype=torch.bfloat16,
    )
    return encoder.to(device)
135
+
136
+
137
def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEncoder:
    """Build the autoencoder for `name` and load its checkpoint weights.

    Args:
        name: key into the module-level ``configs`` registry
            ("flux-dev" or "flux-schnell").
        device: device the model and weights are materialized on.
        hf_download: allow fetching the checkpoint from the Hugging Face Hub
            when it is not available locally.

    Returns:
        The ``AutoEncoder``; weights are loaded when a checkpoint path was
        resolved, otherwise the model keeps its initial parameters.
    """
    ckpt_path = configs[name].ae_path
    # BUGFIX: ae_path can be None (flux-schnell reads it from the AE env var,
    # which may be unset); os.path.exists(None) raises TypeError, so check for
    # None before probing the filesystem.
    if (
        (ckpt_path is None or not os.path.exists(ckpt_path))
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae, local_dir='models')

    # Loading the autoencoder
    print("Init AE")
    with torch.device(device):
        ae = AutoEncoder(configs[name].ae_params)

    if ckpt_path is not None:
        # load_sft doesn't support torch.device, so pass the device as a string
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False)
        print_load_warning(missing, unexpected)
    return ae