John6666 commited on
Commit
d0ab6a4
1 Parent(s): a6a09d9

Upload 9 files

Browse files
Files changed (9) hide show
  1. README.md +3 -2
  2. app.py +220 -109
  3. convert_url_to_diffusers_multi_gr.py +463 -0
  4. packages.txt +1 -0
  5. presets.py +134 -0
  6. requirements.txt +8 -5
  7. sdutils.py +170 -0
  8. stkey.py +122 -0
  9. utils.py +70 -42
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
- title: Download and Convert SDXL To Diffusers V2
3
  emoji: 🎨➡️🧨
4
  colorFrom: indigo
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.0.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Download safetensors and convert to HF🤗 Diffusers format (SDXL / SD 1.5 / FLUX.1 / SD 3.5) Alpha
3
  emoji: 🎨➡️🧨
4
  colorFrom: indigo
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: Convert SDXL/1.5/3.5/FLUX.1 safetensors to HF🤗 Diffusers
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,54 +1,16 @@
1
  import gradio as gr
2
- from convert_url_to_diffusers_sdxl_gr import (
3
- convert_url_to_diffusers_repo,
4
- SCHEDULER_CONFIG_MAP,
5
- )
 
6
 
7
- vaes = [
8
- "",
9
- "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
10
- "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/blob/main/sdxl_vae-fp16fix-blessed.safetensors",
11
- "https://huggingface.co/John6666/safetensors_converting_test/blob/main/xlVAEC_e7.safetensors",
12
- "https://huggingface.co/John6666/safetensors_converting_test/blob/main/xlVAEC_f1.safetensors",
13
- ]
14
- clips = [
15
- "",
16
- "openai/clip-vit-large-patch14",
17
- ]
18
- loras = [
19
- "",
20
- "https://huggingface.co/SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep_LoRA/blob/main/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors",
21
- "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_2step_converted.safetensors",
22
- "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_4step_converted.safetensors",
23
- "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_8step_converted.safetensors",
24
- "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_normalcfg_8step_converted.safetensors",
25
- "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_normalcfg_16step_converted.safetensors",
26
- "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-1step-lora.safetensors",
27
- "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-2steps-lora.safetensors",
28
- "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-4steps-lora.safetensors",
29
- "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-8steps-CFG-lora.safetensors",
30
- "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-12steps-CFG-lora.safetensors",
31
- "https://huggingface.co/latent-consistency/lcm-lora-sdxl/blob/main/pytorch_lora_weights.safetensors",
32
- ]
33
- schedulers = list(SCHEDULER_CONFIG_MAP.keys())
34
-
35
- preset_dict = {
36
- "Default": [True, "", "Euler a", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0],
37
- "Bake in standard VAE": [True, "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
38
- "Euler a", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0],
39
- "Hyper-SDXL / SPO": [True, "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
40
- "TCD", "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-8steps-CFG-lora.safetensors", 1.0,
41
- "https://huggingface.co/SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep_LoRA/blob/main/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors",
42
- 1.0, "", 1.0, "", 1.0, "", 1.0],
43
- }
44
-
45
-
46
- def set_presets(preset: str="Default"):
47
- p = []
48
- if preset in preset_dict.keys(): p = preset_dict[preset]
49
- else: p = preset_dict["Default"]
50
- return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12]
51
 
 
 
 
 
 
52
 
53
  css = """
54
  .title { font-size: 3em; align-items: center; text-align: center; }
@@ -57,82 +19,231 @@ css = """
57
  """
58
 
59
  with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
60
- gr.Markdown("# Download and convert any Stable Diffusion XL safetensors to Diffusers and create your repo", elem_classes="title")
61
- gr.Markdown(
62
- f"""
63
- - [A CLI version of this tool (without uploading-related function) is available here](https://huggingface.co/spaces/John6666/sdxl-to-diffusers-v2/tree/main/local).
64
-
65
- **⚠️IMPORTANT NOTICE⚠️**<br>
66
- From an information security standpoint, it is dangerous to expose your access token or key to others.
67
- If you do use it, I recommend that you duplicate this space on your own account before doing so.
68
- Keys and tokens could be set to SECRET (HF_TOKEN, CIVITAI_API_KEY) if it's placed in your own space.
69
  It saves you the trouble of typing them in.<br>
70
- <br>
71
- **The steps are the following**:
72
- - Paste a write-access token from [hf.co/settings/tokens](https://huggingface.co/settings/tokens).
73
- - Input a model download url from the Hub or Civitai or other sites.
74
- - If you want to download a model from Civitai, paste a Civitai API Key.
75
- - Input your HF user ID. e.g. 'yourid'.
76
- - Input your new repo name. If empty, auto-complete. e.g. 'newrepo'.
77
- - Set the parameters. If not sure, just use the defaults.
78
- - Click "Submit".
79
- - Patiently wait until the output changes. It takes approximately 2 to 3 minutes (downloading from HF).
80
- """
81
- )
82
  with gr.Column():
 
 
83
  with gr.Group():
84
- dl_url = gr.Textbox(label="URL to download", placeholder="https://huggingface.co/bluepen5805/blue_pencil-XL/blob/main/blue_pencil-XL-v7.0.0.safetensors", value="", max_lines=1)
85
  with gr.Row():
86
- hf_user = gr.Textbox(label="Your HF user ID", placeholder="username", value="", max_lines=1)
87
- hf_repo = gr.Textbox(label="New repo name", placeholder="reponame", info="If empty, auto-complete", value="", max_lines=1)
88
- with gr.Row():
89
- hf_token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
90
- civitai_key = gr.Textbox(label="Your Civitai API Key (Optional)", info="If you download model from Civitai...", placeholder="", value="", max_lines=1)
 
 
 
 
91
  with gr.Row():
92
  is_upload_sf = gr.Checkbox(label="Upload single safetensors file into new repo", value=False)
93
  is_private = gr.Checkbox(label="Create private repo", value=True)
94
- is_overwrite = gr.Checkbox(label="Overwrite repo", value=False)
95
- presets = gr.Radio(label="Presets", choices=list(preset_dict.keys()), value="Default")
96
- with gr.Accordion("Advanced settings", open=False):
97
- dtype = gr.Radio(label="Output data type", choices=["fp16", "fp32", "bf16", "fp8", "default"], value="fp16")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  with gr.Row():
99
- vae = gr.Dropdown(label="VAE", choices=vaes, value="", allow_custom_value=True)
100
- clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
101
- scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=schedulers, value="Euler a")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  with gr.Row():
103
- with gr.Column():
104
- lora1 = gr.Dropdown(label="LoRA1", choices=loras, value="", allow_custom_value=True, min_width=320)
105
- lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
106
- with gr.Column():
107
- lora2 = gr.Dropdown(label="LoRA2", choices=loras, value="", allow_custom_value=True, min_width=320)
108
- lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
109
- with gr.Column():
110
- lora3 = gr.Dropdown(label="LoRA3", choices=loras, value="", allow_custom_value=True, min_width=320)
111
- lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
112
- with gr.Column():
113
- lora4 = gr.Dropdown(label="LoRA4", choices=loras, value="", allow_custom_value=True, min_width=320)
114
- lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
115
- with gr.Column():
116
- lora5 = gr.Dropdown(label="LoRA5", choices=loras, value="", allow_custom_value=True, min_width=320)
117
- lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
118
- run_button = gr.Button(value="Submit")
119
- repo_urls = gr.CheckboxGroup(visible=False, choices=[], value=None)
120
- output_md = gr.Markdown(label="Output", value="<br><br>", elem_classes="result")
121
- gr.DuplicateButton(value="Duplicate Space")
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  gr.on(
124
- triggers=[run_button.click],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  fn=convert_url_to_diffusers_repo,
126
- inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, is_overwrite, is_upload_sf, repo_urls, dtype, vae, clip, scheduler,
127
- lora1, lora1s, lora2, lora2s, lora3, lora3s, lora4, lora4s, lora5, lora5s],
 
128
  outputs=[repo_urls, output_md],
129
  )
130
- presets.change(
131
- fn=set_presets,
132
- inputs=[presets],
133
- outputs=[dtype, vae, scheduler, lora1, lora1s, lora2, lora2s, lora3, lora3s, lora4, lora4s, lora5, lora5s],
 
134
  queue=False,
135
  )
 
136
 
137
  demo.queue()
138
- demo.launch()
 
1
  import gradio as gr
2
+ from convert_url_to_diffusers_multi_gr import convert_url_to_diffusers_repo, get_dtypes, FLUX_BASE_REPOS, SD35_BASE_REPOS
3
+ from presets import (DEFAULT_DTYPE, schedulers, clips, t5s, sdxl_vaes, sdxl_loras, sdxl_preset_dict, sdxl_set_presets,
4
+ sd15_vaes, sd15_loras, sd15_preset_dict, sd15_set_presets, flux_vaes, flux_loras, flux_preset_dict, flux_set_presets,
5
+ sd35_vaes, sd35_loras, sd35_preset_dict, sd35_set_presets)
6
+ import os
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ HF_USER = os.getenv("HF_USER", "")
10
+ HF_REPO = os.getenv("HF_REPO", "")
11
+ HF_URL = os.getenv("HF_URL", "")
12
+ HF_OW = os.getenv("HF_OW", False)
13
+ HF_PR = os.getenv("HF_PR", False)
14
 
15
  css = """
16
  .title { font-size: 3em; align-items: center; text-align: center; }
 
19
  """
20
 
21
  with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
22
+ gr.Markdown("# Download SDXL / SD 1.5 / SD 3.5 / FLUX.1 safetensors and convert to HF🤗 Diffusers format and create your repo", elem_classes="title")
23
+ gr.Markdown(f"""
24
+ ### ⚠️IMPORTANT NOTICE⚠️<br>
25
+ It's dangerous to expose your access token or key to others.
26
+ If you do use it, I recommend that you duplicate this space on your own HF account in advance.
27
+ Keys and tokens could be set to **Secrets** (`HF_TOKEN`, `CIVITAI_API_KEY`) if it's placed in your own space.
 
 
 
28
  It saves you the trouble of typing them in.<br>
29
+ It barely works in the CPU space, but larger files can be converted if duplicated on the more powerful **Zero GPU** space.
30
+ In particular, conversion of FLUX.1 or SD 3.5 is almost impossible in CPU space.
31
+ ### The steps are the following:
32
+ 1. Paste a write-access token from [hf.co/settings/tokens](https://huggingface.co/settings/tokens).
33
+ 1. Input a model download url of the Hugging Face or Civitai or other sites.
34
+ 1. If you want to download a model from Civitai, paste a Civitai API Key.
35
+ 1. Input your HF user ID. e.g. 'yourid'.
36
+ 1. Input your new repo name. If empty, auto-complete. e.g. 'newrepo'.
37
+ 1. Set the parameters. If not sure, just use the defaults.
38
+ 1. Click "Submit".
39
+ 1. Patiently wait until the output changes. It takes approximately 2 to 3 minutes (on SDXL models downloading from HF).
40
+ """)
41
  with gr.Column():
42
+ dl_url = gr.Textbox(label="URL to download", placeholder="https://huggingface.co/bluepen5805/blue_pencil-XL/blob/main/blue_pencil-XL-v7.0.0.safetensors",
43
+ value=HF_URL, max_lines=1)
44
  with gr.Group():
 
45
  with gr.Row():
46
+ hf_user = gr.Textbox(label="Your HF user ID", placeholder="username", value=HF_USER, max_lines=1)
47
+ hf_repo = gr.Textbox(label="New repo name", placeholder="reponame", info="If empty, auto-complete", value=HF_REPO, max_lines=1)
48
+ with gr.Row(equal_height=True):
49
+ with gr.Column():
50
+ hf_token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
51
+ gr.Markdown("Your token is available at [hf.co/settings/tokens](https://huggingface.co/settings/tokens).", elem_classes="info")
52
+ with gr.Column():
53
+ civitai_key = gr.Textbox(label="Your Civitai API Key (Optional)", info="If you download model from Civitai...", placeholder="", value="", max_lines=1)
54
+ gr.Markdown("Your Civitai API key is available at [https://civitai.com/user/account](https://civitai.com/user/account).", elem_classes="info")
55
  with gr.Row():
56
  is_upload_sf = gr.Checkbox(label="Upload single safetensors file into new repo", value=False)
57
  is_private = gr.Checkbox(label="Create private repo", value=True)
58
+ is_overwrite = gr.Checkbox(label="Overwrite repo", value=HF_OW)
59
+ is_pr = gr.Checkbox(label="Create PR", value=HF_PR)
60
+ with gr.Tab("SDXL"):
61
+ with gr.Group():
62
+ sdxl_presets = gr.Radio(label="Presets", choices=list(sdxl_preset_dict.keys()), value=list(sdxl_preset_dict.keys())[0])
63
+ sdxl_mtype = gr.Textbox(value="SDXL", visible=False)
64
+ with gr.Row():
65
+ sdxl_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value=DEFAULT_DTYPE)
66
+ with gr.Accordion("Advanced settings", open=False):
67
+ with gr.Row():
68
+ sdxl_vae = gr.Dropdown(label="VAE", choices=sdxl_vaes, value="", allow_custom_value=True)
69
+ sdxl_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
70
+ sdxl_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=schedulers, value="Euler a")
71
+ with gr.Row():
72
+ with gr.Column():
73
+ sdxl_lora1 = gr.Dropdown(label="LoRA1", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320)
74
+ sdxl_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
75
+ with gr.Column():
76
+ sdxl_lora2 = gr.Dropdown(label="LoRA2", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320)
77
+ sdxl_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
78
+ with gr.Column():
79
+ sdxl_lora3 = gr.Dropdown(label="LoRA3", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320)
80
+ sdxl_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
81
+ with gr.Column():
82
+ sdxl_lora4 = gr.Dropdown(label="LoRA4", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320)
83
+ sdxl_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
84
+ with gr.Column():
85
+ sdxl_lora5 = gr.Dropdown(label="LoRA5", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320)
86
+ sdxl_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
87
+ sdxl_run_button = gr.Button(value="Submit", variant="primary")
88
+ with gr.Tab("SD 1.5"):
89
+ with gr.Group():
90
+ sd15_presets = gr.Radio(label="Presets", choices=list(sd15_preset_dict.keys()), value=list(sd15_preset_dict.keys())[0])
91
+ sd15_mtype = gr.Textbox(value="SD 1.5", visible=False)
92
+ with gr.Row():
93
+ sd15_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value=DEFAULT_DTYPE)
94
+ sd15_ema = gr.Checkbox(label="Extract EMA", value=True, visible=True)
95
+ sd15_isize = gr.Radio(label="Image size", choices=["768", "512"], value="768")
96
+ sd15_sc = gr.Checkbox(label="Safety checker", value=False)
97
+ with gr.Accordion("Advanced settings", open=False):
98
+ with gr.Row():
99
+ sd15_vae = gr.Dropdown(label="VAE", choices=sd15_vaes, value="", allow_custom_value=True)
100
+ sd15_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
101
+ sd15_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=schedulers, value="Euler")
102
+ with gr.Row():
103
+ with gr.Column():
104
+ sd15_lora1 = gr.Dropdown(label="LoRA1", choices=sd15_loras, value="", allow_custom_value=True, min_width=320)
105
+ sd15_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
106
+ with gr.Column():
107
+ sd15_lora2 = gr.Dropdown(label="LoRA2", choices=sd15_loras, value="", allow_custom_value=True, min_width=320)
108
+ sd15_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
109
+ with gr.Column():
110
+ sd15_lora3 = gr.Dropdown(label="LoRA3", choices=sd15_loras, value="", allow_custom_value=True, min_width=320)
111
+ sd15_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
112
+ with gr.Column():
113
+ sd15_lora4 = gr.Dropdown(label="LoRA4", choices=sd15_loras, value="", allow_custom_value=True, min_width=320)
114
+ sd15_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
115
+ with gr.Column():
116
+ sd15_lora5 = gr.Dropdown(label="LoRA5", choices=sd15_loras, value="", allow_custom_value=True, min_width=320)
117
+ sd15_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
118
+ sd15_run_button = gr.Button(value="Submit", variant="primary")
119
+ with gr.Tab("FLUX.1"):
120
+ with gr.Group():
121
+ flux_presets = gr.Radio(label="Presets", choices=list(flux_preset_dict.keys()), value=list(flux_preset_dict.keys())[0])
122
+ flux_mtype = gr.Textbox(value="FLUX", visible=False)
123
  with gr.Row():
124
+ flux_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value="bf16")
125
+ flux_base_repo = gr.Dropdown(label="Base repo ID", choices=FLUX_BASE_REPOS, value=FLUX_BASE_REPOS[0], allow_custom_value=True, visible=True)
126
+ with gr.Accordion("Advanced settings", open=False):
127
+ with gr.Row():
128
+ flux_vae = gr.Dropdown(label="VAE", choices=flux_vaes, value="", allow_custom_value=True)
129
+ flux_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
130
+ flux_t5 = gr.Dropdown(label="T5", choices=t5s, value="", allow_custom_value=True)
131
+ flux_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=[""], value="", visible=False)
132
+ with gr.Row():
133
+ with gr.Column():
134
+ flux_lora1 = gr.Dropdown(label="LoRA1", choices=flux_loras, value="", allow_custom_value=True, min_width=320)
135
+ flux_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
136
+ with gr.Column():
137
+ flux_lora2 = gr.Dropdown(label="LoRA2", choices=flux_loras, value="", allow_custom_value=True, min_width=320)
138
+ flux_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
139
+ with gr.Column():
140
+ flux_lora3 = gr.Dropdown(label="LoRA3", choices=flux_loras, value="", allow_custom_value=True, min_width=320)
141
+ flux_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
142
+ with gr.Column():
143
+ flux_lora4 = gr.Dropdown(label="LoRA4", choices=flux_loras, value="", allow_custom_value=True, min_width=320)
144
+ flux_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
145
+ with gr.Column():
146
+ flux_lora5 = gr.Dropdown(label="LoRA5", choices=flux_loras, value="", allow_custom_value=True, min_width=320)
147
+ flux_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
148
+ flux_run_button = gr.Button(value="Submit", variant="primary")
149
+ with gr.Tab("SD 3.5"):
150
+ with gr.Group():
151
+ sd35_presets = gr.Radio(label="Presets", choices=list(sd35_preset_dict.keys()), value=list(sd35_preset_dict.keys())[0])
152
+ sd35_mtype = gr.Textbox(value="SD 3.5", visible=False)
153
  with gr.Row():
154
+ sd35_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value="bf16")
155
+ sd35_base_repo = gr.Dropdown(label="Base repo ID", choices=SD35_BASE_REPOS, value=SD35_BASE_REPOS[0], allow_custom_value=True, visible=True)
156
+ with gr.Accordion("Advanced settings", open=False):
157
+ with gr.Row():
158
+ sd35_vae = gr.Dropdown(label="VAE", choices=sd35_vaes, value="", allow_custom_value=True)
159
+ sd35_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
160
+ sd35_t5 = gr.Dropdown(label="T5", choices=t5s, value="", allow_custom_value=True)
161
+ sd35_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=[""], value="", visible=False)
162
+ with gr.Row():
163
+ with gr.Column():
164
+ sd35_lora1 = gr.Dropdown(label="LoRA1", choices=sd35_loras, value="", allow_custom_value=True, min_width=320)
165
+ sd35_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
166
+ with gr.Column():
167
+ sd35_lora2 = gr.Dropdown(label="LoRA2", choices=sd35_loras, value="", allow_custom_value=True, min_width=320)
168
+ sd35_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
169
+ with gr.Column():
170
+ sd35_lora3 = gr.Dropdown(label="LoRA3", choices=sd35_loras, value="", allow_custom_value=True, min_width=320)
171
+ sd35_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
172
+ with gr.Column():
173
+ sd35_lora4 = gr.Dropdown(label="LoRA4", choices=sd35_loras, value="", allow_custom_value=True, min_width=320)
174
+ sd35_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
175
+ with gr.Column():
176
+ sd35_lora5 = gr.Dropdown(label="LoRA5", choices=sd35_loras, value="", allow_custom_value=True, min_width=320)
177
+ sd35_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
178
+ sd35_run_button = gr.Button(value="Submit", variant="primary")
179
+ adv_args = gr.Textbox(label="Advanced arguments", value="", visible=False)
180
+ with gr.Group():
181
+ repo_urls = gr.CheckboxGroup(visible=False, choices=[], value=[])
182
+ output_md = gr.Markdown(label="Output", value="<br><br>", elem_classes="result")
183
+ clear_button = gr.Button(value="Clear Output", variant="secondary")
184
+ gr.DuplicateButton(value="Duplicate Space")
185
 
186
  gr.on(
187
+ triggers=[sdxl_run_button.click],
188
+ fn=convert_url_to_diffusers_repo,
189
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, is_overwrite, is_pr, is_upload_sf, repo_urls,
190
+ sdxl_dtype, sdxl_vae, sdxl_clip, flux_t5, sdxl_scheduler, sd15_ema, sd15_isize, sd15_sc, flux_base_repo, sdxl_mtype,
191
+ sdxl_lora1, sdxl_lora1s, sdxl_lora2, sdxl_lora2s, sdxl_lora3, sdxl_lora3s, sdxl_lora4, sdxl_lora4s, sdxl_lora5, sdxl_lora5s, adv_args],
192
+ outputs=[repo_urls, output_md],
193
+ )
194
+ sdxl_presets.change(
195
+ fn=sdxl_set_presets,
196
+ inputs=[sdxl_presets],
197
+ outputs=[sdxl_dtype, sdxl_vae, sdxl_scheduler, sdxl_lora1, sdxl_lora1s, sdxl_lora2, sdxl_lora2s, sdxl_lora3, sdxl_lora3s,
198
+ sdxl_lora4, sdxl_lora4s, sdxl_lora5, sdxl_lora5s],
199
+ queue=False,
200
+ )
201
+ gr.on(
202
+ triggers=[sd15_run_button.click],
203
+ fn=convert_url_to_diffusers_repo,
204
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, is_overwrite, is_pr, is_upload_sf, repo_urls,
205
+ sd15_dtype, sd15_vae, sd15_clip, flux_t5, sd15_scheduler, sd15_ema, sd15_isize, sd15_sc, flux_base_repo, sd15_mtype,
206
+ sd15_lora1, sd15_lora1s, sd15_lora2, sd15_lora2s, sd15_lora3, sd15_lora3s, sd15_lora4, sd15_lora4s, sd15_lora5, sd15_lora5s, adv_args],
207
+ outputs=[repo_urls, output_md],
208
+ )
209
+ sd15_presets.change(
210
+ fn=sd15_set_presets,
211
+ inputs=[sd15_presets],
212
+ outputs=[sd15_dtype, sd15_vae, sd15_scheduler, sd15_lora1, sd15_lora1s, sd15_lora2, sd15_lora2s, sd15_lora3, sd15_lora3s,
213
+ sd15_lora4, sd15_lora4s, sd15_lora5, sd15_lora5s, sd15_ema],
214
+ queue=False,
215
+ )
216
+ gr.on(
217
+ triggers=[flux_run_button.click],
218
+ fn=convert_url_to_diffusers_repo,
219
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, is_overwrite, is_pr, is_upload_sf, repo_urls,
220
+ flux_dtype, flux_vae, flux_clip, flux_t5, flux_scheduler, sd15_ema, sd15_isize, sd15_sc, flux_base_repo, flux_mtype,
221
+ flux_lora1, flux_lora1s, flux_lora2, flux_lora2s, flux_lora3, flux_lora3s, flux_lora4, flux_lora4s, flux_lora5, flux_lora5s, adv_args],
222
+ outputs=[repo_urls, output_md],
223
+ )
224
+ flux_presets.change(
225
+ fn=flux_set_presets,
226
+ inputs=[flux_presets],
227
+ outputs=[flux_dtype, flux_vae, flux_scheduler, flux_lora1, flux_lora1s, flux_lora2, flux_lora2s, flux_lora3, flux_lora3s,
228
+ flux_lora4, flux_lora4s, flux_lora5, flux_lora5s, flux_base_repo],
229
+ queue=False,
230
+ )
231
+ gr.on(
232
+ triggers=[sd35_run_button.click],
233
  fn=convert_url_to_diffusers_repo,
234
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, is_overwrite, is_pr, is_upload_sf, repo_urls,
235
+ sd35_dtype, sd35_vae, sd35_clip, sd35_t5, sd35_scheduler, sd15_ema, sd15_isize, sd15_sc, sd35_base_repo, sd35_mtype,
236
+ sd35_lora1, sd35_lora1s, sd35_lora2, sd35_lora2s, sd35_lora3, sd35_lora3s, sd35_lora4, sd35_lora4s, sd35_lora5, sd35_lora5s, adv_args],
237
  outputs=[repo_urls, output_md],
238
  )
239
+ sd35_presets.change(
240
+ fn=sd35_set_presets,
241
+ inputs=[sd35_presets],
242
+ outputs=[sd35_dtype, sd35_vae, sd35_scheduler, sd35_lora1, sd35_lora1s, sd35_lora2, sd35_lora2s, sd35_lora3, sd35_lora3s,
243
+ sd35_lora4, sd35_lora4s, sd35_lora5, sd35_lora5s, sd35_base_repo],
244
  queue=False,
245
  )
246
+ clear_button.click(lambda: ([], "<br><br>"), None, [repo_urls, output_md], queue=False, show_api=False)
247
 
248
  demo.queue()
249
+ demo.launch(ssr_mode=False)
convert_url_to_diffusers_multi_gr.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import spaces
3
+ import argparse
4
+ from pathlib import Path
5
+ import os
6
+ import torch
7
+ from diffusers import (DiffusionPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler, StableDiffusionXLPipeline, StableDiffusionPipeline,
8
+ FluxPipeline, FluxTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline)
9
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPFeatureExtractor, AutoTokenizer, T5EncoderModel, BitsAndBytesConfig as TFBitsAndBytesConfig
10
+ from huggingface_hub import save_torch_state_dict, snapshot_download
11
+ from diffusers.loaders.single_file_utils import (convert_flux_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers,
12
+ convert_sd3_t5_checkpoint_to_diffusers)
13
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
14
+ import safetensors.torch
15
+ import gradio as gr
16
+ import shutil
17
+ import gc
18
+ import tempfile
19
+ # also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
20
+ from utils import (get_token, set_token, is_repo_exists, is_repo_name, get_download_file, upload_repo)
21
+ from sdutils import (SCHEDULER_CONFIG_MAP, get_scheduler_config, fuse_loras, DTYPE_DEFAULT, get_dtype, get_dtypes, get_model_type_from_key, get_process_dtype)
22
+
23
+
24
+ @spaces.GPU
25
+ def fake_gpu():
26
+ pass
27
+
28
+
29
+ try:
30
+ from diffusers import BitsAndBytesConfig
31
+ is_nf4 = True
32
+ except Exception:
33
+ is_nf4 = False
34
+
35
+
36
+ FLUX_BASE_REPOS = ["camenduru/FLUX.1-dev-diffusers", "black-forest-labs/FLUX.1-schnell", "John6666/flux1-dev-fp8-flux", "John6666/flux1-schnell-fp8-flux"]
37
+ FLUX_T5_URL = "https://huggingface.co/camenduru/FLUX.1-dev/blob/main/t5xxl_fp8_e4m3fn.safetensors"
38
+ SD35_BASE_REPOS = ["adamo1139/stable-diffusion-3.5-large-ungated", "adamo1139/stable-diffusion-3.5-large-turbo-ungated"]
39
+ SD35_T5_URL = "https://huggingface.co/adamo1139/stable-diffusion-3.5-large-turbo-ungated/blob/main/text_encoders/t5xxl_fp8_e4m3fn.safetensors"
40
+ TEMP_DIR = tempfile.mkdtemp()
41
+ IS_ZERO = os.environ.get("SPACES_ZERO_GPU") is not None
42
+ IS_CUDA = torch.cuda.is_available()
43
+
44
+
45
+ def safe_clean(path: str):
46
+ try:
47
+ if Path(path).exists():
48
+ if Path(path).is_dir(): shutil.rmtree(str(Path(path)))
49
+ else: Path(path).unlink()
50
+ print(f"Deleted: {path}")
51
+ else: print(f"File not found: {path}")
52
+ except Exception as e:
53
+ print(f"Failed to delete: {path} {e}")
54
+
55
+
56
+ def save_readme_md(dir, url):
57
+ orig_url = ""
58
+ orig_name = ""
59
+ if is_repo_name(url):
60
+ orig_name = url
61
+ orig_url = f"https://huggingface.co/{url}/"
62
+ elif "http" in url:
63
+ orig_name = url
64
+ orig_url = url
65
+ if orig_name and orig_url:
66
+ md = f"""---
67
+ license: other
68
+ language:
69
+ - en
70
+ library_name: diffusers
71
+ pipeline_tag: text-to-image
72
+ tags:
73
+ - text-to-image
74
+ ---
75
+ Converted from [{orig_name}]({orig_url}).
76
+ """
77
+ else:
78
+ md = f"""---
79
+ license: other
80
+ language:
81
+ - en
82
+ library_name: diffusers
83
+ pipeline_tag: text-to-image
84
+ tags:
85
+ - text-to-image
86
+ ---
87
+ """
88
+ path = str(Path(dir, "README.md"))
89
+ with open(path, mode='w', encoding="utf-8") as f:
90
+ f.write(md)
91
+
92
+
93
+ def save_module(model, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)): # doesn't work
94
+ if name in ["vae", "transformer", "unet"]: pattern = "diffusion_pytorch_model{suffix}.safetensors"
95
+ else: pattern = "model{suffix}.safetensors"
96
+ if name in ["transformer", "unet"]: size = "10GB"
97
+ else: size = "5GB"
98
+ path = str(Path(f"{dir.removesuffix('/')}/{name}"))
99
+ os.makedirs(path, exist_ok=True)
100
+ progress(0, desc=f"Saving {name} to {dir}...")
101
+ print(f"Saving {name} to {dir}...")
102
+ model.to("cpu")
103
+ sd = dict(model.state_dict())
104
+ new_sd = {}
105
+ for key in list(sd.keys()):
106
+ q = sd.pop(key)
107
+ if dtype == "fp8": new_sd[key] = q if q.dtype == torch.float8_e4m3fn else q.to(torch.float8_e4m3fn)
108
+ else: new_sd[key] = q
109
+ del sd
110
+ gc.collect()
111
+ save_torch_state_dict(state_dict=new_sd, save_directory=path, filename_pattern=pattern, max_shard_size=size)
112
+ del new_sd
113
+ gc.collect()
114
+
115
+
116
+ def save_module_sd(sd: dict, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)):
117
+ if name in ["vae", "transformer", "unet"]: pattern = "diffusion_pytorch_model{suffix}.safetensors"
118
+ else: pattern = "model{suffix}.safetensors"
119
+ if name in ["transformer", "unet"]: size = "10GB"
120
+ else: size = "5GB"
121
+ path = str(Path(f"{dir.removesuffix('/')}/{name}"))
122
+ os.makedirs(path, exist_ok=True)
123
+ progress(0, desc=f"Saving state_dict of {name} to {dir}...")
124
+ print(f"Saving state_dict of {name} to {dir}...")
125
+ new_sd = {}
126
+ for key in list(sd.keys()):
127
+ q = sd.pop(key).to("cpu")
128
+ if dtype == "fp8": new_sd[key] = q if q.dtype == torch.float8_e4m3fn else q.to(torch.float8_e4m3fn)
129
+ else: new_sd[key] = q
130
+ save_torch_state_dict(state_dict=new_sd, save_directory=path, filename_pattern=pattern, max_shard_size=size)
131
+ del new_sd
132
+ gc.collect()
133
+
134
+
135
+ def convert_flux_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
136
+ temp_dir = TEMP_DIR
137
+ down_dir = str(Path(f"{TEMP_DIR}/down"))
138
+ os.makedirs(down_dir, exist_ok=True)
139
+ hf_token = get_token()
140
+ progress(0.25, desc=f"Loading {new_file}...")
141
+ orig_sd = safetensors.torch.load_file(new_file)
142
+ progress(0.3, desc=f"Converting {new_file}...")
143
+ conv_sd = convert_flux_transformer_checkpoint_to_diffusers(orig_sd)
144
+ del orig_sd
145
+ gc.collect()
146
+ progress(0.35, desc=f"Saving {new_file}...")
147
+ save_module_sd(conv_sd, "transformer", new_dir, dtype)
148
+ del conv_sd
149
+ gc.collect()
150
+ progress(0.5, desc=f"Loading text_encoder_2 from {FLUX_T5_URL}...")
151
+ t5_file = get_download_file(temp_dir, FLUX_T5_URL, civitai_key)
152
+ if not t5_file: raise Exception(f"Safetensors file not found: {FLUX_T5_URL}")
153
+ t5_sd = safetensors.torch.load_file(t5_file)
154
+ safe_clean(t5_file)
155
+ save_module_sd(t5_sd, "text_encoder_2", new_dir, dtype)
156
+ del t5_sd
157
+ gc.collect()
158
+ progress(0.6, desc=f"Loading other components from {base_repo}...")
159
+ pipe = FluxPipeline.from_pretrained(base_repo, transformer=None, text_encoder_2=None, use_safetensors=True, **kwargs,
160
+ torch_dtype=torch.bfloat16, token=hf_token)
161
+ pipe.save_pretrained(new_dir)
162
+ progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
163
+ snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
164
+ ignore_patterns=["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.jpeg", "*.png", "*.webp"])
165
+ shutil.copytree(down_dir, new_dir, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp"), dirs_exist_ok=True)
166
+ safe_clean(down_dir)
167
+
168
+
169
+ def convert_sd35_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
170
+ temp_dir = TEMP_DIR
171
+ down_dir = str(Path(f"{TEMP_DIR}/down"))
172
+ os.makedirs(down_dir, exist_ok=True)
173
+ hf_token = get_token()
174
+ progress(0.25, desc=f"Loading {new_file}...")
175
+ orig_sd = safetensors.torch.load_file(new_file)
176
+ progress(0.3, desc=f"Converting {new_file}...")
177
+ conv_sd = convert_sd3_transformer_checkpoint_to_diffusers(orig_sd)
178
+ del orig_sd
179
+ gc.collect()
180
+ progress(0.35, desc=f"Saving {new_file}...")
181
+ save_module_sd(conv_sd, "transformer", new_dir, dtype)
182
+ del conv_sd
183
+ gc.collect()
184
+ progress(0.5, desc=f"Loading text_encoder_3 from {SD35_T5_URL}...")
185
+ t5_file = get_download_file(temp_dir, SD35_T5_URL, civitai_key)
186
+ if not t5_file: raise Exception(f"Safetensors file not found: {SD35_T5_URL}")
187
+ t5_sd = safetensors.torch.load_file(t5_file)
188
+ safe_clean(t5_file)
189
+ conv_t5_sd = convert_sd3_t5_checkpoint_to_diffusers(t5_sd)
190
+ del t5_sd
191
+ gc.collect()
192
+ save_module_sd(conv_t5_sd, "text_encoder_3", new_dir, dtype)
193
+ del conv_t5_sd
194
+ gc.collect()
195
+ progress(0.6, desc=f"Loading other components from {base_repo}...")
196
+ pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=None, text_encoder_3=None, use_safetensors=True, **kwargs,
197
+ torch_dtype=torch.bfloat16, token=hf_token)
198
+ pipe.save_pretrained(new_dir)
199
+ progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
200
+ snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
201
+ ignore_patterns=["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.jpeg", "*.png", "*.webp"])
202
+ shutil.copytree(down_dir, new_dir, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp"), dirs_exist_ok=True)
203
+ safe_clean(down_dir)
204
+
205
+
206
+ #@spaces.GPU(duration=60)
207
+ def load_and_save_pipeline(pipe, model_type: str, url: str, new_file: str, new_dir: str, dtype: str,
208
+ scheduler: str, ema: bool, image_size: str, is_safety_checker: bool, base_repo: str, civitai_key: str, lora_dict: dict,
209
+ my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder,
210
+ kwargs: dict, dkwargs: dict, progress=gr.Progress(track_tqdm=True)):
211
+ try:
212
+ hf_token = get_token()
213
+ temp_dir = TEMP_DIR
214
+ qkwargs = {}
215
+ tfqkwargs = {}
216
+ if is_nf4:
217
+ nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
218
+ bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
219
+ nf4_config_tf = TFBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
220
+ bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
221
+ else:
222
+ nf4_config = None
223
+ nf4_config_tf = None
224
+ if dtype == "NF4" and nf4_config is not None and nf4_config_tf is not None:
225
+ qkwargs["quantization_config"] = nf4_config
226
+ tfqkwargs["quantization_config"] = nf4_config_tf
227
+
228
+ print(f"model_type:{model_type}, dtype:{dtype}, scheduler:{scheduler}, ema:{ema}, base_repo:{base_repo}")
229
+ print("lora_dict:", lora_dict, "kwargs:", kwargs, "dkwargs:", dkwargs)
230
+
231
+ #t5 = None
232
+
233
+ if model_type == "SDXL":
234
+ if is_repo_name(url): pipe = StableDiffusionXLPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
235
+ else: pipe = StableDiffusionXLPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
236
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
237
+ sconf = get_scheduler_config(scheduler)
238
+ pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
239
+ pipe.save_pretrained(new_dir)
240
+ elif model_type == "SD 1.5":
241
+ if is_safety_checker:
242
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
243
+ feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
244
+ kwargs["requires_safety_checker"] = True
245
+ kwargs["safety_checker"] = safety_checker
246
+ kwargs["feature_extractor"] = feature_extractor
247
+ else: kwargs["requires_safety_checker"] = False
248
+ if is_repo_name(url): pipe = StableDiffusionPipeline.from_pretrained(url, extract_ema=ema, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
249
+ else: pipe = StableDiffusionPipeline.from_single_file(new_file, extract_ema=ema, use_safetensors=True, **kwargs, **dkwargs)
250
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
251
+ sconf = get_scheduler_config(scheduler)
252
+ pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
253
+ if image_size != "512": pipe.vae = AutoencoderKL.from_config(pipe.vae.config, sample_size=int(image_size))
254
+ pipe.save_pretrained(new_dir)
255
+ elif model_type == "FLUX":
256
+ if dtype != "fp8":
257
+ if is_repo_name(url):
258
+ transformer = FluxTransformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
259
+ #if my_t5_encoder is None:
260
+ # t5 = T5EncoderModel.from_pretrained(url, subfolder="text_encoder_2", config=base_repo, **dkwargs, **tfqkwargs)
261
+ # kwargs["text_encoder_2"] = t5
262
+ pipe = FluxPipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
263
+ else:
264
+ transformer = FluxTransformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
265
+ #if my_t5_encoder is None:
266
+ # t5 = T5EncoderModel.from_pretrained(base_repo, subfolder="text_encoder_2", config=base_repo, **dkwargs, **tfqkwargs)
267
+ # kwargs["text_encoder_2"] = t5
268
+ pipe = FluxPipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
269
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
270
+ pipe.save_pretrained(new_dir)
271
+ elif not is_repo_name(url): convert_flux_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
272
+ elif model_type == "SD 3.5":
273
+ if dtype != "fp8":
274
+ if is_repo_name(url):
275
+ transformer = SD3Transformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
276
+ #if my_t5_encoder is None:
277
+ # t5 = T5EncoderModel.from_pretrained(url, subfolder="text_encoder_3", config=base_repo, **dkwargs, **tfqkwargs)
278
+ # kwargs["text_encoder_3"] = t5
279
+ pipe = StableDiffusion3Pipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
280
+ else:
281
+ transformer = SD3Transformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
282
+ #if my_t5_encoder is None:
283
+ # t5 = T5EncoderModel.from_pretrained(base_repo, subfolder="text_encoder_3", config=base_repo, **dkwargs, **tfqkwargs)
284
+ # kwargs["text_encoder_3"] = t5
285
+ pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
286
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
287
+ pipe.save_pretrained(new_dir)
288
+ elif not is_repo_name(url): convert_sd35_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
289
+ else: # unknown model type
290
+ if is_repo_name(url): pipe = DiffusionPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
291
+ else: pipe = DiffusionPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
292
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
293
+ pipe.save_pretrained(new_dir)
294
+ except Exception as e:
295
+ print(f"Failed to load pipeline. {e}")
296
+ raise Exception("Failed to load pipeline.") from e
297
+ finally:
298
+ return pipe
299
+
300
+
301
+ def convert_url_to_diffusers(url: str, civitai_key: str="", is_upload_sf: bool=False, dtype: str="fp16", vae: str="", clip: str="", t5: str="",
302
+ scheduler: str="Euler a", ema: bool=True, image_size: str="768", safety_checker: bool=False,
303
+ base_repo: str="", mtype: str="", lora_dict: dict={}, is_local: bool=True, progress=gr.Progress(track_tqdm=True)):
304
+ try:
305
+ hf_token = get_token()
306
+ progress(0, desc="Start converting...")
307
+ temp_dir = TEMP_DIR
308
+
309
+ if is_repo_name(url) and is_repo_exists(url):
310
+ new_file = url
311
+ model_type = mtype
312
+ else:
313
+ new_file = get_download_file(temp_dir, url, civitai_key)
314
+ if not new_file: raise Exception(f"Safetensors file not found: {url}")
315
+ model_type = get_model_type_from_key(new_file)
316
+ new_dir = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_") #
317
+
318
+ kwargs = {}
319
+ dkwargs = {}
320
+ if dtype != DTYPE_DEFAULT: dkwargs["torch_dtype"] = get_process_dtype(dtype, model_type)
321
+ pipe = None
322
+
323
+ print(f"Model type: {model_type} / VAE: {vae} / CLIP: {clip} / T5: {t5} / Scheduler: {scheduler} / dtype: {dtype} / EMA: {ema} / Base repo: {base_repo} / LoRAs: {lora_dict}")
324
+
325
+ my_vae = None
326
+ if vae:
327
+ progress(0, desc=f"Loading VAE: {vae}...")
328
+ if is_repo_name(vae): my_vae = AutoencoderKL.from_pretrained(vae, **dkwargs, token=hf_token)
329
+ else:
330
+ new_vae_file = get_download_file(temp_dir, vae, civitai_key)
331
+ my_vae = AutoencoderKL.from_single_file(new_vae_file, **dkwargs) if new_vae_file else None
332
+ safe_clean(new_vae_file)
333
+ if my_vae: kwargs["vae"] = my_vae
334
+
335
+ my_clip_tokenizer = None
336
+ my_clip_encoder = None
337
+ if clip:
338
+ progress(0, desc=f"Loading CLIP: {clip}...")
339
+ if is_repo_name(clip):
340
+ my_clip_tokenizer = CLIPTokenizer.from_pretrained(clip, token=hf_token)
341
+ if model_type == "SD 3.5": my_clip_encoder = CLIPTextModelWithProjection.from_pretrained(clip, **dkwargs, token=hf_token)
342
+ else: my_clip_encoder = CLIPTextModel.from_pretrained(clip, **dkwargs, token=hf_token)
343
+ else:
344
+ new_clip_file = get_download_file(temp_dir, clip, civitai_key)
345
+ if model_type == "SD 3.5": my_clip_encoder = CLIPTextModelWithProjection.from_single_file(new_clip_file, **dkwargs) if new_clip_file else None
346
+ else: my_clip_encoder = CLIPTextModel.from_single_file(new_clip_file, **dkwargs) if new_clip_file else None
347
+ safe_clean(new_clip_file)
348
+ if model_type == "SD 3.5":
349
+ if my_clip_tokenizer:
350
+ kwargs["tokenizer"] = my_clip_tokenizer
351
+ kwargs["tokenizer_2"] = my_clip_tokenizer
352
+ if my_clip_encoder:
353
+ kwargs["text_encoder"] = my_clip_encoder
354
+ kwargs["text_encoder_2"] = my_clip_encoder
355
+ else:
356
+ if my_clip_tokenizer: kwargs["tokenizer"] = my_clip_tokenizer
357
+ if my_clip_encoder: kwargs["text_encoder"] = my_clip_encoder
358
+
359
+ my_t5_tokenizer = None
360
+ my_t5_encoder = None
361
+ if t5:
362
+ progress(0, desc=f"Loading T5: {t5}...")
363
+ if is_repo_name(t5):
364
+ my_t5_tokenizer = AutoTokenizer.from_pretrained(t5, token=hf_token)
365
+ my_t5_encoder = T5EncoderModel.from_pretrained(t5, **dkwargs, token=hf_token)
366
+ else:
367
+ new_t5_file = get_download_file(temp_dir, t5, civitai_key)
368
+ my_t5_encoder = T5EncoderModel.from_single_file(new_t5_file, **dkwargs) if new_t5_file else None
369
+ safe_clean(new_t5_file)
370
+ if model_type == "SD 3.5":
371
+ if my_t5_tokenizer: kwargs["tokenizer_3"] = my_t5_tokenizer
372
+ if my_t5_encoder: kwargs["text_encoder_3"] = my_t5_encoder
373
+ else:
374
+ if my_t5_tokenizer: kwargs["tokenizer_2"] = my_t5_tokenizer
375
+ if my_t5_encoder: kwargs["text_encoder_2"] = my_t5_encoder
376
+
377
+ pipe = load_and_save_pipeline(pipe, model_type, url, new_file, new_dir, dtype, scheduler, ema, image_size, safety_checker, base_repo, civitai_key, lora_dict,
378
+ my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder, kwargs, dkwargs)
379
+
380
+ if Path(new_dir).exists(): save_readme_md(new_dir, url)
381
+
382
+ if not is_local:
383
+ if not is_repo_name(new_file) and is_upload_sf: shutil.move(str(Path(new_file).resolve()), str(Path(new_dir, Path(new_file).name).resolve()))
384
+ else: safe_clean(new_file)
385
+
386
+ progress(1, desc="Converted.")
387
+ return new_dir
388
+ except Exception as e:
389
+ print(f"Failed to convert. {e}")
390
+ raise Exception("Failed to convert.") from e
391
+ finally:
392
+ del pipe
393
+ torch.cuda.empty_cache()
394
+ gc.collect()
395
+
396
+
397
+ def convert_url_to_diffusers_repo(dl_url: str, hf_user: str, hf_repo: str, hf_token: str, civitai_key="", is_private: bool=True, is_overwrite: bool=False, is_pr: bool=False,
398
+ is_upload_sf: bool=False, urls: list=[], dtype: str="fp16", vae: str="", clip: str="", t5: str="", scheduler: str="Euler a",
399
+ ema: bool=True, image_size: str="768", safety_checker: bool=False,
400
+ base_repo: str="", mtype: str="", lora1: str="", lora1s=1.0, lora2: str="", lora2s=1.0, lora3: str="", lora3s=1.0,
401
+ lora4: str="", lora4s=1.0, lora5: str="", lora5s=1.0, args: str="", progress=gr.Progress(track_tqdm=True)):
402
+ try:
403
+ is_local = False
404
+ if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
405
+ if not hf_token and os.environ.get("HF_TOKEN"): hf_token = os.environ.get("HF_TOKEN") # default HF write token
406
+ if not hf_user: raise gr.Error(f"Invalid user name: {hf_user}")
407
+ set_token(hf_token)
408
+ lora_dict = {lora1: lora1s, lora2: lora2s, lora3: lora3s, lora4: lora4s, lora5: lora5s}
409
+ new_path = convert_url_to_diffusers(dl_url, civitai_key, is_upload_sf, dtype, vae, clip, t5, scheduler, ema, image_size, safety_checker, base_repo, mtype, lora_dict, is_local)
410
+ if not new_path: return ""
411
+ new_repo_id = f"{hf_user}/{Path(new_path).stem}"
412
+ if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
413
+ if not is_repo_name(new_repo_id): raise gr.Error(f"Invalid repo name: {new_repo_id}")
414
+ if not is_overwrite and is_repo_exists(new_repo_id) and not is_pr: raise gr.Error(f"Repo already exists: {new_repo_id}")
415
+ repo_url = upload_repo(new_repo_id, new_path, is_private, is_pr)
416
+ safe_clean(new_path)
417
+ if not urls: urls = []
418
+ urls.append(repo_url)
419
+ md = "### Your new repo:\n"
420
+ for u in urls:
421
+ md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
422
+ return gr.update(value=urls, choices=urls), gr.update(value=md)
423
+ except Exception as e:
424
+ print(f"Error occured. {e}")
425
+ raise gr.Error(f"Error occured. {e}")
426
+
427
+
428
+ if __name__ == "__main__":
429
+ parser = argparse.ArgumentParser()
430
+
431
+ parser.add_argument("--url", type=str, required=True, help="URL of the model to convert.")
432
+ parser.add_argument("--dtype", default="fp16", type=str, choices=get_dtypes(), help='Output data type. (Default: "fp16")')
433
+ parser.add_argument("--scheduler", default="Euler a", type=str, choices=list(SCHEDULER_CONFIG_MAP.keys()), required=False, help="Scheduler name to use.")
434
+ parser.add_argument("--vae", default="", type=str, required=False, help="URL or Repo ID of the VAE to use.")
435
+ parser.add_argument("--clip", default="", type=str, required=False, help="URL or Repo ID of the CLIP to use.")
436
+ parser.add_argument("--t5", default="", type=str, required=False, help="URL or Repo ID of the T5 to use.")
437
+ parser.add_argument("--base", default="", type=str, required=False, help="Repo ID of the base repo.")
438
+ parser.add_argument("--nonema", action="store_true", default=False, help="Don't extract EMA (for SD 1.5).")
439
+ parser.add_argument("--civitai_key", default="", type=str, required=False, help="Civitai API Key (If you want to download file from Civitai).")
440
+ parser.add_argument("--lora1", default="", type=str, required=False, help="URL of the LoRA to use.")
441
+ parser.add_argument("--lora1s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora1.")
442
+ parser.add_argument("--lora2", default="", type=str, required=False, help="URL of the LoRA to use.")
443
+ parser.add_argument("--lora2s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora2.")
444
+ parser.add_argument("--lora3", default="", type=str, required=False, help="URL of the LoRA to use.")
445
+ parser.add_argument("--lora3s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora3.")
446
+ parser.add_argument("--lora4", default="", type=str, required=False, help="URL of the LoRA to use.")
447
+ parser.add_argument("--lora4s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora4.")
448
+ parser.add_argument("--lora5", default="", type=str, required=False, help="URL of the LoRA to use.")
449
+ parser.add_argument("--lora5s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora5.")
450
+ parser.add_argument("--loras", default="", type=str, required=False, help="Folder of the LoRA to use.")
451
+
452
+ args = parser.parse_args()
453
+ assert args.url is not None, "Must provide a URL!"
454
+
455
+ is_local = True
456
+ lora_dict = {args.lora1: args.lora1s, args.lora2: args.lora2s, args.lora3: args.lora3s, args.lora4: args.lora4s, args.lora5: args.lora5s}
457
+ if args.loras and Path(args.loras).exists():
458
+ for p in Path(args.loras).glob('**/*.safetensors'):
459
+ lora_dict[str(p)] = 1.0
460
+ ema = not args.nonema
461
+ mtype = "SDXL"
462
+
463
+ convert_url_to_diffusers(args.url, args.civitai_key, args.dtype, args.vae, args.clip, args.t5, args.scheduler, ema, args.base, mtype, lora_dict, is_local)
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ git-lfs aria2
presets.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sdutils import get_dtypes, SCHEDULER_CONFIG_MAP
2
+
3
+
4
+ DEFAULT_DTYPE = get_dtypes()[0]
5
+ schedulers = list(SCHEDULER_CONFIG_MAP.keys())
6
+
7
+
8
+ clips = [
9
+ "",
10
+ "openai/clip-vit-large-patch14",
11
+ ]
12
+
13
+
14
+ t5s = [
15
+ "",
16
+ "https://huggingface.co/camenduru/FLUX.1-dev/blob/main/t5xxl_fp8_e4m3fn.safetensors",
17
+ ]
18
+
19
+
20
+ sdxl_vaes = [
21
+ "",
22
+ "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
23
+ "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/blob/main/sdxl_vae-fp16fix-blessed.safetensors",
24
+ "https://huggingface.co/John6666/safetensors_converting_test/blob/main/xlVAEC_e7.safetensors",
25
+ "https://huggingface.co/John6666/safetensors_converting_test/blob/main/xlVAEC_f1.safetensors",
26
+ ]
27
+
28
+
29
+ sdxl_loras = [
30
+ "",
31
+ "https://huggingface.co/SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep_LoRA/blob/main/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors",
32
+ "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_2step_converted.safetensors",
33
+ "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_4step_converted.safetensors",
34
+ "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_8step_converted.safetensors",
35
+ "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_normalcfg_8step_converted.safetensors",
36
+ "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_normalcfg_16step_converted.safetensors",
37
+ "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-1step-lora.safetensors",
38
+ "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-2steps-lora.safetensors",
39
+ "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-4steps-lora.safetensors",
40
+ "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-8steps-CFG-lora.safetensors",
41
+ "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-12steps-CFG-lora.safetensors",
42
+ "https://huggingface.co/latent-consistency/lcm-lora-sdxl/blob/main/pytorch_lora_weights.safetensors",
43
+ ]
44
+
45
+
46
+ sdxl_preset_dict = {
47
+ "Default": [DEFAULT_DTYPE, "", "Euler a", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0],
48
+ "Bake in standard VAE": [DEFAULT_DTYPE, "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
49
+ "Euler a", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0],
50
+ "Hyper-SDXL / SPO": [DEFAULT_DTYPE, "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
51
+ "TCD", "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-8steps-CFG-lora.safetensors", 1.0,
52
+ "https://huggingface.co/SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep_LoRA/blob/main/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors",
53
+ 1.0, "", 1.0, "", 1.0, "", 1.0],
54
+ }
55
+
56
+
57
+ def sdxl_set_presets(preset: str="Default"):
58
+ p = []
59
+ if preset in sdxl_preset_dict.keys(): p = sdxl_preset_dict[preset]
60
+ else: p = sdxl_preset_dict["Default"]
61
+ return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12]
62
+
63
+
64
+ sd15_vaes = [
65
+ "",
66
+ "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt",
67
+ "https://huggingface.co/stabilityai/sd-vae-ft-ema-original/resolve/main/vae-ft-ema-560000-ema-pruned.ckpt",
68
+ ]
69
+
70
+
71
+ sd15_loras = [
72
+ "",
73
+ "https://huggingface.co/SPO-Diffusion-Models/SPO-SD-v1-5_4k-p_10ep_LoRA/blob/main/spo-sd-v1-5_4k-p_10ep_lora_diffusers.safetensors",
74
+ ]
75
+
76
+
77
+ sd15_preset_dict = {
78
+ "Default": [DEFAULT_DTYPE, "", "Euler", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, True],
79
+ "Bake in standard VAE": [DEFAULT_DTYPE, "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt",
80
+ "Euler", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, True],
81
+ }
82
+
83
+
84
+ def sd15_set_presets(preset: str="Default"):
85
+ p = []
86
+ if preset in sd15_preset_dict.keys(): p = sd15_preset_dict[preset]
87
+ else: p = sd15_preset_dict["Default"]
88
+ return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12], p[13]
89
+
90
+
91
+ flux_vaes = [
92
+ "",
93
+ ]
94
+
95
+
96
+ flux_loras = [
97
+ "",
98
+ ]
99
+
100
+
101
+ flux_preset_dict = {
102
+ "dev": ["bf16", "", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, "camenduru/FLUX.1-dev-diffusers"],
103
+ "schnell": ["bf16", "", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, "black-forest-labs/FLUX.1-schnell"],
104
+ }
105
+
106
+
107
+ def flux_set_presets(preset: str="dev"):
108
+ p = []
109
+ if preset in flux_preset_dict.keys(): p = flux_preset_dict[preset]
110
+ else: p = flux_preset_dict["dev"]
111
+ return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12], p[13]
112
+
113
+
114
+
115
+ sd35_vaes = [
116
+ "",
117
+ ]
118
+
119
+
120
+ sd35_loras = [
121
+ "",
122
+ ]
123
+
124
+
125
+ sd35_preset_dict = {
126
+ "Default": ["bf16", "", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, "adamo1139/stable-diffusion-3.5-large-ungated"],
127
+ }
128
+
129
+
130
+ def sd35_set_presets(preset: str="dev"):
131
+ p = []
132
+ if preset in sd35_preset_dict.keys(): p = sd35_preset_dict[preset]
133
+ else: p = sd35_preset_dict["Default"]
134
+ return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12], p[13]
requirements.txt CHANGED
@@ -1,9 +1,12 @@
1
  huggingface_hub
2
  safetensors
3
- transformers==4.44.0
4
- accelerate
5
- diffusers==0.30.3
6
- pytorch_lightning
7
  peft
8
- aria2
 
 
9
  gdown
 
 
 
 
1
  huggingface_hub
2
  safetensors
3
+ transformers==4.46.3
4
+ diffusers==0.31.0
 
 
5
  peft
6
+ sentencepiece
7
+ torch==2.5.1
8
+ pytorch_lightning
9
  gdown
10
+ bitsandbytes
11
+ accelerate
12
+ numpy<2
sdutils.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pathlib import Path
3
+ from utils import get_download_file
4
+ from stkey import read_safetensors_key
5
+ try:
6
+ from diffusers import BitsAndBytesConfig
7
+ is_nf4 = True
8
+ except Exception:
9
+ is_nf4 = False
10
+
11
+
12
+ DTYPE_DEFAULT = "default"
13
+ DTYPE_DICT = {
14
+ "fp16": torch.float16,
15
+ "bf16": torch.bfloat16,
16
+ "fp32": torch.float32,
17
+ "fp8": torch.float8_e4m3fn,
18
+ }
19
+ #QTYPES = ["NF4"] if is_nf4 else []
20
+ QTYPES = []
21
+
22
+ def get_dtypes():
23
+ return list(DTYPE_DICT.keys()) + [DTYPE_DEFAULT] + QTYPES
24
+
25
+
26
+ def get_dtype(dtype: str):
27
+ if dtype in set(QTYPES): return torch.bfloat16
28
+ return DTYPE_DICT.get(dtype, torch.float16)
29
+
30
+
31
+ from diffusers import (
32
+ DPMSolverMultistepScheduler,
33
+ DPMSolverSinglestepScheduler,
34
+ KDPM2DiscreteScheduler,
35
+ EulerDiscreteScheduler,
36
+ EulerAncestralDiscreteScheduler,
37
+ HeunDiscreteScheduler,
38
+ LMSDiscreteScheduler,
39
+ DDIMScheduler,
40
+ DEISMultistepScheduler,
41
+ UniPCMultistepScheduler,
42
+ LCMScheduler,
43
+ PNDMScheduler,
44
+ KDPM2AncestralDiscreteScheduler,
45
+ DPMSolverSDEScheduler,
46
+ EDMDPMSolverMultistepScheduler,
47
+ DDPMScheduler,
48
+ EDMEulerScheduler,
49
+ TCDScheduler,
50
+ )
51
+
52
+
53
+ SCHEDULER_CONFIG_MAP = {
54
+ "DPM++ 2M": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": False}),
55
+ "DPM++ 2M Karras": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": True}),
56
+ "DPM++ 2M SDE": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
57
+ "DPM++ 2M SDE Karras": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
58
+ "DPM++ 2S": (DPMSolverSinglestepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": False}),
59
+ "DPM++ 2S Karras": (DPMSolverSinglestepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": True}),
60
+ "DPM++ 1S": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 1}),
61
+ "DPM++ 1S Karras": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 1, "use_karras_sigmas": True}),
62
+ "DPM++ 3M": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 3}),
63
+ "DPM++ 3M Karras": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 3, "use_karras_sigmas": True}),
64
+ "DPM 3M": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver", "final_sigmas_type": "sigma_min", "solver_order": 3}),
65
+ "DPM++ SDE": (DPMSolverSDEScheduler, {"use_karras_sigmas": False}),
66
+ "DPM++ SDE Karras": (DPMSolverSDEScheduler, {"use_karras_sigmas": True}),
67
+ "DPM2": (KDPM2DiscreteScheduler, {}),
68
+ "DPM2 Karras": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
69
+ "DPM2 a": (KDPM2AncestralDiscreteScheduler, {}),
70
+ "DPM2 a Karras": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
71
+ "Euler": (EulerDiscreteScheduler, {}),
72
+ "Euler a": (EulerAncestralDiscreteScheduler, {}),
73
+ "Euler trailing": (EulerDiscreteScheduler, {"timestep_spacing": "trailing", "prediction_type": "sample"}),
74
+ "Euler a trailing": (EulerAncestralDiscreteScheduler, {"timestep_spacing": "trailing"}),
75
+ "Heun": (HeunDiscreteScheduler, {}),
76
+ "Heun Karras": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
77
+ "LMS": (LMSDiscreteScheduler, {}),
78
+ "LMS Karras": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
79
+ "DDIM": (DDIMScheduler, {}),
80
+ "DDIM trailing": (DDIMScheduler, {"timestep_spacing": "trailing"}),
81
+ "DEIS": (DEISMultistepScheduler, {}),
82
+ "UniPC": (UniPCMultistepScheduler, {}),
83
+ "UniPC Karras": (UniPCMultistepScheduler, {"use_karras_sigmas": True}),
84
+ "PNDM": (PNDMScheduler, {}),
85
+ "Euler EDM": (EDMEulerScheduler, {}),
86
+ "Euler EDM Karras": (EDMEulerScheduler, {"use_karras_sigmas": True}),
87
+ "DPM++ 2M EDM": (EDMDPMSolverMultistepScheduler, {"solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"}),
88
+ "DPM++ 2M EDM Karras": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"}),
89
+ "DDPM": (DDPMScheduler, {}),
90
+
91
+ "DPM++ 2M Lu": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_lu_lambdas": True}),
92
+ "DPM++ 2M Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "euler_at_final": True}),
93
+ "DPM++ 2M SDE Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True, "algorithm_type": "sde-dpmsolver++"}),
94
+ "DPM++ 2M SDE Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "euler_at_final": True}),
95
+
96
+ "LCM": (LCMScheduler, {}),
97
+ "TCD": (TCDScheduler, {}),
98
+ "LCM trailing": (LCMScheduler, {"timestep_spacing": "trailing"}),
99
+ "TCD trailing": (TCDScheduler, {"timestep_spacing": "trailing"}),
100
+ "LCM Auto-Loader": (LCMScheduler, {}),
101
+ "TCD Auto-Loader": (TCDScheduler, {}),
102
+
103
+ "EDM": (EDMDPMSolverMultistepScheduler, {}),
104
+ "EDM Karras": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
105
+
106
+ "Euler (V-Prediction)": (EulerDiscreteScheduler, {"prediction_type": "v_prediction", "rescale_betas_zero_snr": True}),
107
+ "Euler a (V-Prediction)": (EulerAncestralDiscreteScheduler, {"prediction_type": "v_prediction", "rescale_betas_zero_snr": True}),
108
+ "Euler EDM (V-Prediction)": (EDMEulerScheduler, {"prediction_type": "v_prediction"}),
109
+ "Euler EDM Karras (V-Prediction)": (EDMEulerScheduler, {"use_karras_sigmas": True, "prediction_type": "v_prediction"}),
110
+ "DPM++ 2M EDM (V-Prediction)": (EDMDPMSolverMultistepScheduler, {"solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++", "prediction_type": "v_prediction"}),
111
+ "DPM++ 2M EDM Karras (V-Prediction)": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++", "prediction_type": "v_prediction"}),
112
+ "EDM (V-Prediction)": (EDMDPMSolverMultistepScheduler, {"prediction_type": "v_prediction"}),
113
+ "EDM Karras (V-Prediction)": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True, "prediction_type": "v_prediction"}),
114
+ }
115
+
116
+
117
+ def get_scheduler_config(name: str):
118
+ if not name in SCHEDULER_CONFIG_MAP.keys(): return SCHEDULER_CONFIG_MAP["Euler a"]
119
+ return SCHEDULER_CONFIG_MAP[name]
120
+
121
+
122
+ def fuse_loras(pipe, lora_dict: dict, temp_dir: str, civitai_key: str="", dkwargs: dict={}):
123
+ if not lora_dict or not isinstance(lora_dict, dict): return pipe
124
+ a_list = []
125
+ w_list = []
126
+ for k, v in lora_dict.items():
127
+ if not k: continue
128
+ new_lora_file = get_download_file(temp_dir, k, civitai_key)
129
+ if not new_lora_file or not Path(new_lora_file).exists():
130
+ print(f"LoRA file not found: {k}")
131
+ continue
132
+ w_name = Path(new_lora_file).name
133
+ a_name = Path(new_lora_file).stem
134
+ pipe.load_lora_weights(new_lora_file, weight_name=w_name, adapter_name=a_name, low_cpu_mem_usage=False, **dkwargs)
135
+ a_list.append(a_name)
136
+ w_list.append(v)
137
+ if Path(new_lora_file).exists(): Path(new_lora_file).unlink()
138
+ if len(a_list) == 0: return pipe
139
+ pipe.set_adapters(a_list, adapter_weights=w_list)
140
+ pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
141
+ pipe.unload_lora_weights()
142
+ return pipe
143
+
144
+
145
+ MODEL_TYPE_KEY = {
146
+ "model.diffusion_model.output_blocks.1.1.norm.bias": "SDXL",
147
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "SD 1.5",
148
+ "double_blocks.0.img_attn.norm.key_norm.scale": "FLUX",
149
+ "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale": "FLUX",
150
+ "model.diffusion_model.joint_blocks.9.x_block.attn.ln_k.weight": "SD 3.5",
151
+ }
152
+
153
+
154
+ def get_model_type_from_key(path: str):
155
+ default = "SDXL"
156
+ try:
157
+ keys = read_safetensors_key(path)
158
+ for k, v in MODEL_TYPE_KEY.items():
159
+ if k in set(keys):
160
+ print(f"Model type is {v}.")
161
+ return v
162
+ print("Model type could not be identified.")
163
+ except Exception:
164
+ return default
165
+ return default
166
+
167
+
168
+ def get_process_dtype(dtype: str, model_type: str):
169
+ if dtype in set(["fp8"] + QTYPES): return torch.bfloat16 if model_type in ["FLUX", "SD 3.5"] else torch.float16
170
+ return DTYPE_DICT.get(dtype, torch.float16)
stkey.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import json
4
+ import re
5
+ import gc
6
+ from safetensors.torch import load_file, save_file
7
+ import torch
8
+
9
+
10
+ SDXL_KEYS_FILE = "keys/sdxl_keys.txt"
11
+
12
+
13
+ def list_uniq(l):
14
+ return sorted(set(l), key=l.index)
15
+
16
+
17
+ def read_safetensors_metadata(path: str):
18
+ with open(path, 'rb') as f:
19
+ header_size = int.from_bytes(f.read(8), 'little')
20
+ header_json = f.read(header_size).decode('utf-8')
21
+ header = json.loads(header_json)
22
+ metadata = header.get('__metadata__', {})
23
+ return metadata
24
+
25
+
26
+ def keys_from_file(path: str):
27
+ keys = []
28
+ try:
29
+ with open(str(Path(path)), encoding='utf-8', mode='r') as f:
30
+ lines = f.readlines()
31
+ for line in lines:
32
+ keys.append(line.strip())
33
+ except Exception as e:
34
+ print(e)
35
+ finally:
36
+ return keys
37
+
38
+
39
+ def validate_keys(keys: list[str], rfile: str=SDXL_KEYS_FILE):
40
+ missing = []
41
+ added = []
42
+ try:
43
+ rkeys = keys_from_file(rfile)
44
+ all_keys = list_uniq(keys + rkeys)
45
+ for key in all_keys:
46
+ if key in set(rkeys) and key not in set(keys): missing.append(key)
47
+ if key in set(keys) and key not in set(rkeys): added.append(key)
48
+ except Exception as e:
49
+ print(e)
50
+ finally:
51
+ return missing, added
52
+
53
+
54
+ def read_safetensors_key(path: str):
55
+ try:
56
+ keys = []
57
+ state_dict = load_file(str(Path(path)))
58
+ for k in list(state_dict.keys()):
59
+ keys.append(k)
60
+ state_dict.pop(k)
61
+ except Exception as e:
62
+ print(e)
63
+ finally:
64
+ del state_dict
65
+ torch.cuda.empty_cache()
66
+ gc.collect()
67
+ return keys
68
+
69
+
70
+ def write_safetensors_key(keys: list[str], path: str, is_validate: bool=True, rpath: str=SDXL_KEYS_FILE):
71
+ if len(keys) == 0: return False
72
+ try:
73
+ with open(str(Path(path)), encoding='utf-8', mode='w') as f:
74
+ f.write("\n".join(keys))
75
+ if is_validate:
76
+ missing, added = validate_keys(keys, rpath)
77
+ with open(str(Path(path).stem + "_missing.txt"), encoding='utf-8', mode='w') as f:
78
+ f.write("\n".join(missing))
79
+ with open(str(Path(path).stem + "_added.txt"), encoding='utf-8', mode='w') as f:
80
+ f.write("\n".join(added))
81
+ return True
82
+ except Exception as e:
83
+ print(e)
84
+ return False
85
+
86
+
87
+ def stkey(input: str, out_filename: str="", is_validate: bool=True, rfile: str=SDXL_KEYS_FILE):
88
+ keys = read_safetensors_key(input)
89
+ if len(keys) != 0 and out_filename: write_safetensors_key(keys, out_filename, is_validate, rfile)
90
+ if len(keys) != 0:
91
+ print("Metadata:")
92
+ print(read_safetensors_metadata(input))
93
+ print("\nKeys:")
94
+ print("\n".join(keys))
95
+ if is_validate:
96
+ missing, added = validate_keys(keys, rfile)
97
+ print("\nMissing Keys:")
98
+ print("\n".join(missing))
99
+ print("\nAdded Keys:")
100
+ print("\n".join(added))
101
+
102
+
103
+ if __name__ == "__main__":
104
+ parser = argparse.ArgumentParser()
105
+ parser.add_argument("input", type=str, help="Input safetensors file.")
106
+ parser.add_argument("-s", "--save", action="store_true", default=False, help="Output to text file.")
107
+ parser.add_argument("-o", "--output", default="", type=str, help="Output to specific text file.")
108
+ parser.add_argument("-v", "--val", action="store_false", default=True, help="Disable key validation.")
109
+ parser.add_argument("-r", "--rfile", default=SDXL_KEYS_FILE, type=str, help="Specify reference file to validate keys.")
110
+
111
+ args = parser.parse_args()
112
+
113
+ if args.save: out_filename = Path(args.input).stem + ".txt"
114
+ out_filename = args.output if args.output else out_filename
115
+
116
+ stkey(args.input, out_filename, args.val, args.rfile)
117
+
118
+
119
+ # Usage:
120
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors
121
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors -s
122
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors -o key.txt
utils.py CHANGED
@@ -8,6 +8,7 @@ import re
8
  import urllib.parse
9
  import subprocess
10
  import time
 
11
 
12
 
13
  def get_token():
@@ -25,6 +26,17 @@ def set_token(token):
25
  print(f"Error: Failed to save token.")
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
28
  def get_user_agent():
29
  return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
30
 
@@ -113,26 +125,28 @@ def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
113
 
114
 
115
  def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
116
- url = url.strip()
117
- if "drive.google.com" in url:
118
- original_dir = os.getcwd()
119
- os.chdir(directory)
120
- os.system(f"gdown --fuzzy {url}")
121
- os.chdir(original_dir)
122
- elif "huggingface.co" in url:
123
- url = url.replace("?download=true", "")
124
- if "/blob/" in url: url = url.replace("/blob/", "/resolve/")
125
- download_hf_file(directory, url)
126
- elif "civitai.com" in url:
127
- if "?" in url:
128
- url = url.split("?")[0]
129
- if civitai_api_key:
130
- url = url + f"?token={civitai_api_key}"
131
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
 
 
132
  else:
133
- print("You need an API key to download Civitai models.")
134
- else:
135
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
136
 
137
 
138
  def get_local_file_list(dir_path):
@@ -145,30 +159,30 @@ def get_local_file_list(dir_path):
145
 
146
 
147
  def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
148
- if not "http" in url and is_repo_name(url) and not Path(url).exists():
149
- print(f"Use HF Repo: {url}")
150
- new_file = url
151
- elif not "http" in url and Path(url).exists():
152
- print(f"Use local file: {url}")
153
- new_file = url
154
- elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
155
- print(f"File to download alreday exists: {url}")
156
- new_file = f"{temp_dir}/{url.split('/')[-1]}"
157
- else:
158
- print(f"Start downloading: {url}")
159
- before = get_local_file_list(temp_dir)
160
- try:
161
  download_thing(temp_dir, url.strip(), civitai_key)
162
- except Exception:
 
 
163
  print(f"Download failed: {url}")
164
  return ""
165
- after = get_local_file_list(temp_dir)
166
- new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
167
- if not new_file:
168
- print(f"Download failed: {url}")
169
  return ""
170
- print(f"Download completed: {url}")
171
- return new_file
172
 
173
 
174
  def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
@@ -183,7 +197,21 @@ def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=T
183
  return False
184
 
185
 
186
- def upload_repo(repo_id: str, dir_path: str, is_private: bool, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  hf_token = get_token()
188
  api = HfApi(token=hf_token)
189
  try:
@@ -191,9 +219,9 @@ def upload_repo(repo_id: str, dir_path: str, is_private: bool, progress=gr.Progr
191
  api.create_repo(repo_id=repo_id, token=hf_token, private=is_private, exist_ok=True)
192
  for path in Path(dir_path).glob("*"):
193
  if path.is_dir():
194
- api.upload_folder(repo_id=repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
195
  elif path.is_file():
196
- api.upload_file(repo_id=repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
197
  progress(1, desc="Uploaded.")
198
  return get_hf_url(repo_id, "model")
199
  except Exception as e:
 
8
  import urllib.parse
9
  import subprocess
10
  import time
11
+ from typing import Any
12
 
13
 
14
  def get_token():
 
26
  print(f"Error: Failed to save token.")
27
 
28
 
29
+ def get_state(state: dict, key: str):
30
+ if key in state.keys(): return state[key]
31
+ else:
32
+ print(f"State '{key}' not found.")
33
+ return None
34
+
35
+
36
+ def set_state(state: dict, key: str, value: Any):
37
+ state[key] = value
38
+
39
+
40
  def get_user_agent():
41
  return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
42
 
 
125
 
126
 
127
  def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
128
+ try:
129
+ url = url.strip()
130
+ if "drive.google.com" in url:
131
+ original_dir = os.getcwd()
132
+ os.chdir(directory)
133
+ subprocess.run(f"gdown --fuzzy {url}", shell=True)
134
+ os.chdir(original_dir)
135
+ elif "huggingface.co" in url:
136
+ url = url.replace("?download=true", "")
137
+ if "/blob/" in url: url = url.replace("/blob/", "/resolve/")
138
+ download_hf_file(directory, url)
139
+ elif "civitai.com" in url:
140
+ if civitai_api_key:
141
+ url = f"'{url}&token={civitai_api_key}'" if "?" in url else f"{url}?token={civitai_api_key}"
142
+ print(f"Downloading {url}")
143
+ subprocess.run(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}", shell=True)
144
+ else:
145
+ print("You need an API key to download Civitai models.")
146
  else:
147
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
148
+ except Exception as e:
149
+ print(f"Failed to download: {e}")
150
 
151
 
152
  def get_local_file_list(dir_path):
 
159
 
160
 
161
  def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
162
+ try:
163
+ if not "http" in url and is_repo_name(url) and not Path(url).exists():
164
+ print(f"Use HF Repo: {url}")
165
+ new_file = url
166
+ elif not "http" in url and Path(url).exists():
167
+ print(f"Use local file: {url}")
168
+ new_file = url
169
+ elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
170
+ print(f"File to download alreday exists: {url}")
171
+ new_file = f"{temp_dir}/{url.split('/')[-1]}"
172
+ else:
173
+ print(f"Start downloading: {url}")
174
+ before = get_local_file_list(temp_dir)
175
  download_thing(temp_dir, url.strip(), civitai_key)
176
+ after = get_local_file_list(temp_dir)
177
+ new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
178
+ if not new_file:
179
  print(f"Download failed: {url}")
180
  return ""
181
+ print(f"Download completed: {url}")
182
+ return new_file
183
+ except Exception as e:
184
+ print(f"Download failed: {url} {e}")
185
  return ""
 
 
186
 
187
 
188
  def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
 
197
  return False
198
 
199
 
200
+ def upload_repo(repo_id: str, dir_path: str, is_private: bool, is_pr: bool=False, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
201
+ hf_token = get_token()
202
+ api = HfApi(token=hf_token)
203
+ try:
204
+ progress(0, desc="Start uploading...")
205
+ api.create_repo(repo_id=repo_id, token=hf_token, private=is_private, exist_ok=True)
206
+ api.upload_folder(repo_id=repo_id, folder_path=dir_path, path_in_repo="", create_pr=is_pr, token=hf_token)
207
+ progress(1, desc="Uploaded.")
208
+ return get_hf_url(repo_id, "model")
209
+ except Exception as e:
210
+ print(f"Error: Failed to upload to {repo_id}. {e}")
211
+ return ""
212
+
213
+
214
+ def upload_repo_old(repo_id: str, dir_path: str, is_private: bool, is_pr: bool=False, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
215
  hf_token = get_token()
216
  api = HfApi(token=hf_token)
217
  try:
 
219
  api.create_repo(repo_id=repo_id, token=hf_token, private=is_private, exist_ok=True)
220
  for path in Path(dir_path).glob("*"):
221
  if path.is_dir():
222
+ api.upload_folder(repo_id=repo_id, folder_path=str(path), path_in_repo=path.name, create_pr=is_pr, token=hf_token)
223
  elif path.is_file():
224
+ api.upload_file(repo_id=repo_id, path_or_fileobj=str(path), path_in_repo=path.name, create_pr=is_pr, token=hf_token)
225
  progress(1, desc="Uploaded.")
226
  return get_hf_url(repo_id, "model")
227
  except Exception as e: