Commit 66dc63e by sayakpaul (HF staff), 1 parent ce07462: Create README.md

Files changed: README.md (added, +103 lines)
---
license: other
language:
- en
library_name: diffusers
---
# FLUX.1-merged

This repository provides the merged params for [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev)
and [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell). Please be aware of the licenses of both
models before using these params commercially.

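Both checkpoints share the same transformer architecture, so the merge below reduces to an element-wise arithmetic mean of corresponding parameters (Dev's guidance-specific parameters are kept as-is). On toy tensors, the core operation looks like this (the tensor names are illustrative, not actual FLUX parameter names):

```python
import torch

# Stand-ins for one parameter tensor from each checkpoint.
dev_param = torch.tensor([1.0, 2.0, 3.0])
schnell_param = torch.tensor([3.0, 2.0, 1.0])

# The merge rule: a plain 50/50 average of the two checkpoints.
merged_param = (dev_param + schnell_param) / 2
print(merged_param)  # tensor([2., 2., 2.])
```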
<table>
<thead>
  <tr>
    <th>Dev (50 steps)</th>
    <th>Dev (4 steps)</th>
    <th>Dev + Schnell (4 steps)</th>
  </tr>
</thead>
<tbody>
  <tr>
    <td>
      <img src="./assets/flux.png" alt="Dev 50 Steps">
    </td>
    <td>
      <img src="./assets/flux_4.png" alt="Dev 4 Steps">
    </td>
    <td>
      <img src="./assets/merged_flux.png" alt="Dev + Schnell 4 Steps">
    </td>
  </tr>
</tbody>
</table>

## Sub-memory-efficient merging code

```python
from diffusers import FluxTransformer2DModel
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from huggingface_hub import snapshot_download
from accelerate import init_empty_weights
import safetensors.torch
import glob
import torch


# Initialize the model on the "meta" device so no weight memory is allocated upfront.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)

dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")

dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))

merged_state_dict = {}
guidance_state_dict = {}

# Process one shard pair at a time to keep peak memory low.
for i in range(len(dev_shards)):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shards[i])
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shards[i])

    keys = list(state_dict_dev_temp.keys())
    for k in keys:
        if "guidance" not in k:
            # Average the Dev and Schnell parameters.
            merged_state_dict[k] = (state_dict_dev_temp.pop(k) + state_dict_schnell_temp.pop(k)) / 2
        else:
            # Schnell has no guidance embedder, so keep Dev's guidance params as-is.
            guidance_state_dict[k] = state_dict_dev_temp.pop(k)

    if len(state_dict_dev_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_dev_temp.keys())}.")
    if len(state_dict_schnell_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")

merged_state_dict.update(guidance_state_dict)
load_model_dict_into_meta(model, merged_state_dict)

model.to(torch.bfloat16).save_pretrained("/raid/.cache/huggingface/merged-flux")
```

## Inference code

```python
from diffusers import FluxPipeline, FluxTransformer2DModel
import torch

transformer = FluxTransformer2DModel.from_pretrained("sayakpaul/FLUX.1-merged", torch_dtype=torch.bfloat16)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline(
    prompt="a tiny astronaut hatching from an egg on the moon",
    guidance_scale=3.5,
    num_inference_steps=50,
    height=880,
    width=1184,
    max_sequence_length=512,
    generator=torch.manual_seed(0),
).images[0]
image.save("merged_flux.png")
```
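The `generator` argument above pins the sampling noise, so repeated runs with the same seed produce the same latents and hence the same image. A quick, standalone illustration of that determinism:

```python
import torch

# Two draws from identically seeded generators are bit-for-bit identical,
# which is what makes the pipeline call above reproducible.
a = torch.randn(4, generator=torch.manual_seed(0))
b = torch.randn(4, generator=torch.manual_seed(0))
print(torch.equal(a, b))  # True
```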