|
import os
import shutil
import tempfile

import gradio as gr
import requests
import torch
from safetensors.torch import save_model
|
|
|
def _storage_ptr(tensor):
    """Return the address of the storage backing *tensor*.

    Two tensors reporting the same non-zero address alias the same memory,
    which safetensors refuses to serialize.
    """
    try:
        return tensor.untyped_storage().data_ptr()
    except AttributeError:  # torch < 2.0 has no untyped_storage()
        return tensor.storage().data_ptr()


def _extract_state_dict(obj):
    """Return the weight mapping inside a loaded checkpoint, or None if unrecognised."""
    if isinstance(obj, dict):
        if 'state_dict' in obj:
            return obj['state_dict']
        if 'model' in obj:
            return obj['model']
        # Assume the dict itself is the state dict.
        return obj
    if hasattr(obj, 'state_dict'):
        # A pickled model object rather than a plain state dict.
        return obj.state_dict()
    return None


def _prepare_tensors(state_dict):
    """Keep only tensor entries, made contiguous and with no two sharing storage."""
    tensors = {}
    seen_storages = set()
    for name, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            # safetensors stores tensors only (drops e.g. step counters).
            continue
        tensor = value.detach()
        ptr = _storage_ptr(tensor)
        if ptr != 0 and ptr in seen_storages:
            # Tied/aliased weights: give this entry its own memory.
            tensor = tensor.clone(memory_format=torch.contiguous_format)
        else:
            seen_storages.add(ptr)
            tensor = tensor.contiguous()
        tensors[name] = tensor
    return tensors


def convert_ckpt_to_safetensors(input_path, output_path):
    """Convert the PyTorch checkpoint at *input_path* to a .safetensors file.

    The weights are taken from ``obj['state_dict']``, ``obj['model']``, the
    loaded dict itself, or ``obj.state_dict()`` for module-like objects.
    Non-tensor entries are dropped. Non-contiguous tensors are made
    contiguous. Tensors sharing storage are copied, because safetensors
    rejects both of those.

    Args:
        input_path: Path of the checkpoint to load with ``torch.load``.
        output_path: Destination path for the .safetensors file.

    Returns:
        ``"Success"`` if the file was written, otherwise a human-readable
        error message. Exceptions raised by ``torch.load`` propagate.
    """
    # SECURITY: torch.load unpickles the file. Unless weights_only=True is in
    # effect, a malicious checkpoint can execute arbitrary code here.
    obj = torch.load(input_path, map_location='cpu')

    state_dict = _extract_state_dict(obj)
    if not isinstance(state_dict, dict):
        return "Unsupported checkpoint format."

    tensors = _prepare_tensors(state_dict)
    if not tensors:
        return "No tensors found in checkpoint."

    # save_file takes a dict of tensors. save_model expects an nn.Module and
    # cannot accept a state dict.
    from safetensors.torch import save_file

    try:
        save_file(tensors, output_path)
    except Exception as e:
        return f"An error occurred during saving: {e}"
    return "Success"
|
|
|
def _download(url, destination, timeout=(10, 60)):
    """Stream the resource at *url* into the file *destination*.

    *timeout* is the (connect, read) timeout in seconds passed to requests,
    so a stalled server cannot hang the request forever. Raises
    ``requests.RequestException`` on network/HTTP errors and ``OSError`` on
    write errors.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)


def _uploaded_path(uploaded_file):
    """Return the on-disk path of a Gradio upload.

    Depending on the Gradio version and ``gr.File(type=...)``, the value is
    either a path string or a tempfile-like object whose ``.name`` is the path.
    """
    return uploaded_file if isinstance(uploaded_file, str) else uploaded_file.name


def _convert_into(work_dir, url, uploaded_file):
    """Fetch the checkpoint, convert it into *work_dir* and return the output path.

    A non-empty *url* takes precedence over *uploaded_file*. Raises
    ``gr.Error`` with a user-facing message on any failure.
    """
    if url:
        # NOTE(review): the server fetches an arbitrary user-supplied URL
        # (possible SSRF); restrict allowed hosts if deployed publicly.
        input_path = os.path.join(work_dir, 'model.ckpt')
        output_path = os.path.join(work_dir, 'model.safetensors')
        try:
            _download(url, input_path)
        except Exception as e:
            raise gr.Error(f"Failed to download file: {e}") from e
    else:
        # Convert straight from Gradio's stored upload. This avoids reading a
        # possibly multi-GB file into memory just to write it out again.
        input_path = _uploaded_path(uploaded_file)
        output_path = os.path.join(work_dir, 'uploaded_model.safetensors')

    try:
        result = convert_ckpt_to_safetensors(input_path, output_path)
    except Exception as e:
        raise gr.Error(f"An exception occurred: {e}") from e
    if result != "Success":
        raise gr.Error(f"An error occurred during conversion: {result}")

    if url:
        os.remove(input_path)  # the downloaded .ckpt is no longer needed
    return output_path


def process(url, uploaded_file):
    """Gradio callback: convert a .ckpt from *url* or *uploaded_file* to .safetensors.

    Args:
        url: URL of a .ckpt file; blank or whitespace-only means "not given".
        uploaded_file: Gradio upload (path string or tempfile-like object), or None.

    Returns:
        Path of the converted .safetensors file, for the ``gr.File`` output.

    Raises:
        gr.Error: With a user-facing message when input is missing or any step
            fails. The output component is a file, so error text must not be
            returned as a value.
    """
    url = (url or '').strip()
    if not url and uploaded_file is None:
        raise gr.Error("Please provide a URL or upload a .ckpt file.")

    # A fresh directory per request keeps concurrent requests from
    # overwriting each other's files.
    work_dir = tempfile.mkdtemp(prefix='ckpt2safetensors_')
    try:
        # On success the directory is kept: Gradio reads the returned file
        # after this function returns. NOTE(review): these directories
        # accumulate in the system temp dir; consider periodic cleanup.
        return _convert_into(work_dir, url, uploaded_file)
    except BaseException:
        # Remove partial downloads/outputs from a failed request.
        shutil.rmtree(work_dir, ignore_errors=True)
        raise
|
|
|
# Gradio UI. Inputs: a .ckpt URL or an uploaded .ckpt. Output: the converted file.
iface = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="URL of .ckpt file", placeholder="Enter the URL here"),
        gr.File(label="Or upload a .ckpt file", file_types=['.ckpt']),
    ],
    outputs=gr.File(label="Converted .safetensors file"),
    title="CKPT to SafeTensors Converter",
    # Built from explicit paragraphs so source indentation cannot leak into
    # the Markdown (4+ leading spaces would render as a code block).
    description=(
        "Convert .ckpt files to .safetensors format. "
        "Provide a URL or upload your .ckpt file.\n\n"
        "**Security Warning:** Loading .ckpt files can execute arbitrary code. "
        "Only use files from trusted sources."
    ),
)


if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    iface.launch()