sd-to-diffusers / app.py
anzorq's picture
Update app.py
de84b65
raw
history blame
3.22 kB
import os
import subprocess
import sys

import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
# Fetch the diffusers repository for its checkpoint-conversion script (used by
# `convert`). Skip the clone when the checkout already exists (e.g. when the
# app restarts in the same working directory) instead of letting `git clone`
# fail on the existing directory; a shallow clone avoids downloading history.
if not os.path.isdir("diffs"):
    subprocess.run(
        ["git", "clone", "--depth", "1", "https://github.com/huggingface/diffusers.git", "diffs"]
    )
def error_str(error, title="Error"):
    """Render *error* as Markdown: a level-4 heading *title* followed by the message."""
    heading = f"#### {title}"
    return f"{heading}\n{error}"
def url_to_model_id(model_id_str):
    """Normalize a Hub model reference to a ``user/model`` repo id.

    A plain repo id is returned unchanged. For a huggingface.co URL (https or
    http), the query string and fragment are dropped and only the first two
    non-empty path segments are kept. URLs with a trailing slash or sub-paths
    such as ``/tree/main`` therefore still resolve to the repo id.

    Args:
        model_id_str: A repo id such as ``"user/model"`` or a model page URL.

    Returns:
        The repo id string.
    """
    for prefix in ("https://huggingface.co/", "http://huggingface.co/"):
        if model_id_str.startswith(prefix):
            path = model_id_str[len(prefix):].split("?", 1)[0].split("#", 1)[0]
            # Filter empty segments so trailing/double slashes don't leak in.
            segments = [segment for segment in path.split("/") if segment]
            return "/".join(segments[:2])
    return model_id_str
def get_ckpt_names(model_id = "nitrosocke/mo-di-diffusion"):
    """List the ``.ckpt`` files of a Hub model repo for the checkpoint picker.

    Args:
        model_id: Hub repo id (``"user/model"``) or a huggingface.co URL.

    Returns:
        A 3-tuple matching the Gradio outputs
        ``[error_output, radio_ckpts, col_convert]``:

        - On success: ``(None, <radio update with the checkpoint choices>,
          <column made visible>)``.
        - On failure: ``(<error markdown>, <radio hidden>, <column hidden>)``.
    """
    model_id = (model_id or "").strip()
    if not model_id:
        return (
            error_str("Please enter a model name.", title="Invalid input"),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    try:
        api = HfApi()
        # list_repo_files yields plain path strings, so filter them directly.
        ckpt_files = [
            f for f in api.list_repo_files(url_to_model_id(model_id)) if f.endswith(".ckpt")
        ]
    except Exception as e:  # UI boundary: show any Hub/network error to the user.
        return error_str(e), gr.update(visible=False), gr.update(visible=False)
    if not ckpt_files:
        # Hide the picker so choices from a previous lookup don't linger.
        return (
            error_str("No checkpoint files found in the model repo."),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    return None, gr.update(choices=ckpt_files, visible=True), gr.update(visible=True)
def convert(model_id, ckpt_name, token=None):
    """Download a checkpoint from the Hub and convert it to the Diffusers format.

    Args:
        model_id: Hub repo id (``"user/model"``) or a huggingface.co URL. The
            normalized id is also used as the conversion's dump path.
        ckpt_name: File name of the ``.ckpt`` inside the repo to convert.
        token: Optional Hugging Face access token for the download. An empty
            string (what an empty Gradio textbox yields) means "no token".

    Returns:
        Markdown for the output panel. This is an error message if the input
        is missing, the download fails, or the conversion script exits with a
        non-zero status. Otherwise it is a listing of the files in the current
        directory.
    """
    # NOTE(security): the default used to be a hard-coded access token. Secrets
    # must not live in source; that token should be revoked on the Hub.
    model_id = url_to_model_id((model_id or "").strip())
    if not model_id:
        return error_str("Please enter a model name.", title="Invalid input")
    if not ckpt_name:
        return error_str("Please choose a checkpoint to convert.", title="Invalid input")

    # 1. Download the checkpoint file (using the caller's token, if any)
    try:
        ckpt_path = hf_hub_download(repo_id=model_id, filename=ckpt_name, token=token or None)
    except Exception as e:  # UI boundary: report instead of crashing the event.
        return error_str(e)

    # 2. Run the conversion script with the interpreter running this app, so it
    # sees the same installed packages.
    result = subprocess.run(
        [
            sys.executable,
            "./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py",
            "--checkpoint_path",
            ckpt_path,
            "--dump_path",
            model_id,
        ],
        stderr=subprocess.PIPE,
        text=True,
    )
    # Forward the captured stderr so it still shows up in the app's logs.
    sys.stderr.write(result.stderr)
    if result.returncode != 0:
        # Show only the tail: the end of a traceback is the informative part.
        details = result.stderr.strip()[-2000:] or f"exit code {result.returncode}"
        return error_str(details, title="Conversion failed")

    # List files in the current directory and return them as a list
    files = [f for f in os.listdir(".") if os.path.isfile(f)]
    return f"files in current directory:\n{files}"
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=11):
            with gr.Group():
                gr.Markdown("## 1. Load model info")
                # type="password" masks the access token instead of showing it in plain text.
                input_token = gr.Textbox(
                    max_lines=1,
                    label="Hugging Face token",
                    placeholder="hf_...",
                    type="password",
                )
                gr.Markdown("Get your token [here](https://huggingface.co/settings/tokens).")
                input_model = gr.Textbox(
                    max_lines=1,
                    label="Model name or URL",
                    placeholder="username/model_name",
                )
                btn_get_ckpts = gr.Button("Load")
        # Hidden until the "Load" handler returns a visibility update for it.
        with gr.Column(scale=10, visible=False) as col_convert:
            gr.Markdown("## 2. Convert to Diffusers🧨")
            radio_ckpts = gr.Radio(label="Choose a checkpoint to convert", visible=False)
            btn_convert = gr.Button("Convert")
    error_output = gr.Markdown(label="Output")

    btn_get_ckpts.click(
        fn=get_ckpt_names,
        inputs=[input_model],
        outputs=[error_output, radio_ckpts, col_convert],
        scroll_to_output=True,
    )
    btn_convert.click(
        fn=convert,
        inputs=[input_model, radio_ckpts, input_token],
        outputs=error_output,
        scroll_to_output=True,
    )

# Download + conversion can take a long time. Enabling the queue runs events as
# queued jobs rather than plain requests, which may otherwise hit timeouts
# (NOTE(review): confirm against the deployed Gradio version's defaults).
demo.queue()
demo.launch()