File size: 3,382 Bytes
5638c25 9115b9f 5638c25 9372c94 9b45214 4c2ed44 9b45214 4c2ed44 9b45214 5638c25 9372c94 9b45214 9372c94 9b45214 9372c94 9b45214 5638c25 9372c94 9b45214 9372c94 aeaef7e 9b45214 aeaef7e 9b45214 5638c25 9372c94 5638c25 9b45214 5638c25 9372c94 9b45214 9372c94 5638c25 4c2ed44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import os
import shutil
import tempfile

import gradio as gr
import requests
import torch
from safetensors.torch import save_file
from safetensors.torch import save_model
def convert_ckpt_to_safetensors(input_path, output_path):
    """Convert a PyTorch ``.ckpt`` checkpoint into a ``.safetensors`` file.

    The checkpoint is loaded on CPU and its tensors are located by trying,
    in order: a nested ``'state_dict'`` entry, a nested ``'model'`` entry,
    the loaded dict itself, or the ``state_dict()`` of a loaded model object.
    Entries that are not tensors (epoch counters, optimizer state, ...) are
    skipped because the safetensors format stores tensors only.

    Args:
        input_path: Path of the ``.ckpt`` file to read.
        output_path: Path of the ``.safetensors`` file to write.

    Returns:
        ``"Success"`` when the file was written, otherwise a human-readable
        error message. Failures while saving are reported through the return
        value; failures while loading propagate as exceptions from
        ``torch.load``.
    """
    # SECURITY WARNING: torch.load() unpickles the file, and unpickling
    # untrusted data can execute arbitrary code. Only load files from
    # trusted sources.
    # NOTE(review): passing weights_only=True would restrict unpickling, but
    # it would also reject checkpoints that contain pickled model objects.
    obj = torch.load(input_path, map_location='cpu')

    if isinstance(obj, dict):
        if 'state_dict' in obj:
            state_dict = obj['state_dict']
        elif 'model' in obj:
            state_dict = obj['model']
        else:
            # Assume the loaded dict is itself the state dict.
            state_dict = obj
    else:
        state_dict = obj

    # A whole model object may have been pickled, either at the top level or
    # under the 'model' key; ask it for its state dict.
    if not isinstance(state_dict, dict) and hasattr(state_dict, 'state_dict'):
        state_dict = state_dict.state_dict()

    if not isinstance(state_dict, dict):
        return "Unsupported checkpoint format."

    # save_file() refuses tensors that share storage or are non-contiguous.
    # Cloning into contiguous memory gives every entry its own buffer, at the
    # cost of temporarily holding a second copy of the weights.
    tensors = {
        str(name): value.detach().clone(memory_format=torch.contiguous_format)
        for name, value in state_dict.items()
        if isinstance(value, torch.Tensor)
    }
    if not tensors:
        return "No tensors found in checkpoint."

    try:
        # save_file() takes a name -> tensor mapping; save_model() expects an
        # nn.Module and therefore cannot be used on a state dict.
        save_file(tensors, output_path)
    except Exception as e:
        return f"An error occurred during saving: {e}"
    return "Success"
def _download_checkpoint(url, destination, timeout=(10, 60)):
    """Stream ``url`` to ``destination``, raising on HTTP or I/O errors."""
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)


def process(url, uploaded_file):
    """Gradio handler: fetch or accept a ``.ckpt`` file and convert it.

    A non-empty ``url`` takes precedence over ``uploaded_file``. Each call
    works in its own temporary directory so concurrent requests cannot
    overwrite each other's files.

    Args:
        url: URL of the ``.ckpt`` file to download; may be empty or ``None``.
        uploaded_file: Value of the upload component. Either a file path
            string or an object exposing the path as ``.name``; ``None`` when
            nothing was uploaded.

    Returns:
        Path of the converted ``.safetensors`` file, for the ``gr.File``
        output component.

    Raises:
        gr.Error: If no input was given, the download failed, or the
            conversion failed. Gradio shows the message to the user.
    """
    url = (url or "").strip()
    if not url and uploaded_file is None:
        raise gr.Error("Please provide a URL or upload a .ckpt file.")

    work_dir = tempfile.mkdtemp(prefix="ckpt2safetensors_")
    downloaded_path = None
    succeeded = False
    try:
        if url:
            downloaded_path = os.path.join(work_dir, "model.ckpt")
            try:
                _download_checkpoint(url, downloaded_path)
            except (requests.RequestException, OSError) as e:
                raise gr.Error(f"Failed to download file: {e}") from e
            local_filename = downloaded_path
        else:
            # The upload already lives on disk, so it is read in place
            # instead of being copied.
            local_filename = getattr(uploaded_file, "name", uploaded_file)

        # splitext swaps only the final extension; str.replace would rewrite
        # every '.ckpt' occurrence in the name.
        stem = os.path.splitext(os.path.basename(local_filename))[0] or "model"
        output_filename = os.path.join(work_dir, stem + ".safetensors")

        try:
            result = convert_ckpt_to_safetensors(local_filename, output_filename)
        except Exception as e:
            raise gr.Error(f"An exception occurred: {e}") from e
        if result != "Success":
            raise gr.Error(f"An error occurred during conversion: {result}")

        succeeded = True
        return output_filename
    finally:
        if not succeeded:
            # Nothing to hand back: drop partial downloads and outputs.
            shutil.rmtree(work_dir, ignore_errors=True)
        elif downloaded_path is not None and os.path.exists(downloaded_path):
            # Keep only the converted file; the downloaded input is not needed.
            os.remove(downloaded_path)
# Components are created in the same order as they appear in the UI:
# the two inputs first, then the output.
url_input = gr.Textbox(
    label="URL of .ckpt file",
    placeholder="Enter the URL here",
)
ckpt_upload = gr.File(
    label="Or upload a .ckpt file",
    file_types=['.ckpt'],
)
converted_output = gr.File(label="Converted .safetensors file")

# Markdown shown under the title; it repeats the torch.load security caveat.
app_description = """
Convert .ckpt files to .safetensors format. Provide a URL or upload your .ckpt file.
**Security Warning:** Loading .ckpt files can execute arbitrary code. Only use files from trusted sources.
"""

# Wire the `process` handler to the widgets and start the web app.
iface = gr.Interface(
    fn=process,
    inputs=[url_input, ckpt_upload],
    outputs=converted_output,
    title="CKPT to SafeTensors Converter",
    description=app_description,
)
iface.launch()