|
import os
import shutil
import tempfile

import gradio as gr
import requests
import torch
from safetensors.torch import save_file
|
|
|
def convert_ckpt_to_safetensors(input_path, output_path):
    """Convert a PyTorch checkpoint file into a ``.safetensors`` file.

    The checkpoint is loaded onto the CPU with ``torch.load``. If it holds a
    whole ``nn.Module``, that module's ``state_dict()`` is used. If it is a
    dict containing a nested dict under the ``"state_dict"`` key (the layout
    used by training-framework checkpoints), that nested dict is used.
    Only tensor values are written, because safetensors can store nothing
    else.

    Args:
        input_path: Path of the checkpoint to read.
        output_path: Destination path of the ``.safetensors`` file.

    Raises:
        ValueError: If the checkpoint is not a dict/module, or contains no
            tensors to save.
    """
    # SECURITY NOTE(review): torch.load unpickles the file, and unpickling
    # can execute arbitrary code. This app passes it user uploads and files
    # downloaded from arbitrary URLs. Consider
    # torch.load(..., weights_only=True) once it is confirmed that the
    # expected checkpoints load under that restriction.
    checkpoint = torch.load(input_path, map_location='cpu')

    if isinstance(checkpoint, torch.nn.Module):
        # A whole pickled model rather than a plain state dict.
        checkpoint = checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        raise ValueError(
            "Unsupported checkpoint contents: expected a dict of tensors, "
            f"got {type(checkpoint).__name__}"
        )

    # Unwrap {"state_dict": {...}, <metadata>...} style checkpoints.
    nested = checkpoint.get("state_dict")
    state_dict = nested if isinstance(nested, dict) else checkpoint

    tensors = {}
    seen_storages = set()
    for name, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            # Non-tensor metadata cannot be stored in safetensors; skip it.
            continue
        value = value.detach()
        storage = (
            value.untyped_storage()
            if hasattr(value, "untyped_storage")  # PyTorch >= 2.0
            else value.storage()
        )
        storage_ptr = storage.data_ptr()
        if storage_ptr in seen_storages:
            # save_file refuses tensors that share memory (tied weights,
            # views of one buffer), so give each later alias its own copy.
            tensors[name] = value.clone(memory_format=torch.contiguous_format)
        else:
            seen_storages.add(storage_ptr)
            # save_file also refuses non-contiguous tensors; this is a no-op
            # for tensors that are already contiguous.
            tensors[name] = value.contiguous()

    if not tensors:
        raise ValueError(f"No tensors found in checkpoint {input_path!r}")

    save_file(tensors, output_path)
|
|
|
def _download(url, dest_path):
    """Stream the resource at ``url`` into ``dest_path``."""
    # A timeout keeps a stalled server from hanging the worker forever.
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(dest_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)


def _upload_path(uploaded_file):
    """Return the on-disk path of a ``gr.File`` input value."""
    # Depending on the Gradio version, gr.File passes either a filepath
    # string or a tempfile-like object whose ``.name`` is the path on disk.
    if isinstance(uploaded_file, (str, os.PathLike)):
        return os.fspath(uploaded_file)
    return uploaded_file.name


def process(url, uploaded_file):
    """Convert a checkpoint given by URL or upload; return the output path.

    A non-blank ``url`` takes precedence over ``uploaded_file``. Each call
    works in its own fresh temporary directory, so concurrent requests never
    overwrite each other's files. A downloaded checkpoint is deleted once the
    conversion succeeds. On failure, the whole working directory is removed.
    An uploaded file is read in place and never modified or deleted.

    Args:
        url: URL of a checkpoint file, or empty/None.
        uploaded_file: Value of the ``gr.File`` input, or None.

    Returns:
        Path of the written ``.safetensors`` file.

    Raises:
        gr.Error: If neither a URL nor a file was provided. The message is
            shown in the UI.
        requests.HTTPError: If the download returns an error status.
    """
    url = (url or "").strip()
    if not url and uploaded_file is None:
        # Returning this text to a gr.File output would be treated as a file
        # path; gr.Error displays it to the user instead.
        raise gr.Error("Please provide a URL or upload a .ckpt file.")

    work_dir = tempfile.mkdtemp(prefix="ckpt2safetensors_")
    downloaded_path = None
    try:
        if url:
            downloaded_path = os.path.join(work_dir, 'model.ckpt')
            _download(url, downloaded_path)
            input_path = downloaded_path
        else:
            input_path = _upload_path(uploaded_file)

        # Derive the output name from the input's stem. The output always
        # goes in work_dir, so it can never clobber the input file, whatever
        # extension the input has.
        stem = os.path.splitext(os.path.basename(input_path))[0]
        output_path = os.path.join(work_dir, stem + '.safetensors')

        convert_ckpt_to_safetensors(input_path, output_path)
    except Exception:
        # Don't leave partial downloads or outputs behind.
        shutil.rmtree(work_dir, ignore_errors=True)
        raise

    if downloaded_path is not None:
        # Only the converted file is needed from here on.
        os.remove(downloaded_path)

    return output_path
|
|
|
# Web UI: a URL textbox and a file upload as inputs; the converted file is
# offered for download as the output.
iface = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="URL of .ckpt file", placeholder="Enter the URL here"),
        gr.File(label="Or upload a .ckpt file", file_types=['.ckpt']),
    ],
    outputs=gr.File(label="Converted .safetensors file"),
    title="CKPT to SafeTensors Converter",
    description="Convert .ckpt files to .safetensors format. Provide a URL or upload your .ckpt file.",
)

# Start the server only when run as a script, so that importing this module
# (e.g. from tests or another app) does not block inside launch().
if __name__ == "__main__":
    iface.launch()
|
|