import gradio as gr
import torch
from safetensors.torch import save_file
import requests
import os
import shutil

def convert_ckpt_to_safetensors(input_path, output_path):
    # Load the .ckpt file.
    # ⚠️ SECURITY WARNING:
    # Loading untrusted .ckpt files with torch.load() can execute arbitrary code.
    # Only load files from trusted sources. On PyTorch >= 1.13, passing
    # weights_only=True restricts unpickling to plain tensors and primitive
    # types, which greatly reduces this risk.
    obj = torch.load(input_path, map_location='cpu')

    # Determine if obj is a state dict or a model object
    if isinstance(obj, dict):
        # Check for nested 'state_dict' or 'model' keys
        if 'state_dict' in obj:
            state_dict = obj['state_dict']
        elif 'model' in obj:
            state_dict = obj['model']
        else:
            # Assume obj is the state dict
            state_dict = obj
    elif hasattr(obj, 'state_dict'):
        # If obj is a model object
        state_dict = obj.state_dict()
    else:
        return "Unsupported checkpoint format."

    # Save the tensors in safetensors format. save_file expects a flat dict of
    # tensors, so non-tensor entries are dropped and tensors are made
    # contiguous first (safetensors rejects non-contiguous or shared-storage
    # tensors).
    try:
        tensors = {k: v.contiguous() for k, v in state_dict.items()
                   if isinstance(v, torch.Tensor)}
        save_file(tensors, output_path)
    except Exception as e:
        return f"An error occurred during saving: {e}"

    return "Success"
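
# A minimal sanity-check helper (a sketch, not wired into the Gradio UI; the
# name is illustrative): it reloads a converted .safetensors file with the
# standard safetensors load_file API and reports how many tensors it contains,
# which is a quick way to confirm that a conversion round-trips.
def count_safetensors_tensors(path):
    from safetensors.torch import load_file
    tensors = load_file(path)
    return len(tensors)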

def process(url, uploaded_file):
    if url:
        # Download the .ckpt file
        local_filename = 'model.ckpt'
        try:
            with requests.get(url, stream=True) as r:
                r.raise_for_status()
                with open(local_filename, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
        except Exception as e:
            raise gr.Error(f"Failed to download file: {e}")
    elif uploaded_file is not None:
        # Copy the uploaded file locally. Depending on the Gradio version,
        # gr.File yields either a filepath string or a tempfile object with a
        # .name attribute, so handle both cases.
        local_filename = 'uploaded_model.ckpt'
        try:
            src = uploaded_file if isinstance(uploaded_file, str) else uploaded_file.name
            shutil.copyfile(src, local_filename)
        except Exception as e:
            raise gr.Error(f"Failed to save uploaded file: {e}")
    else:
        raise gr.Error("Please provide a URL or upload a .ckpt file.")

    output_filename = local_filename.replace('.ckpt', '.safetensors')

    # Convert the .ckpt to .safetensors; always remove the downloaded or
    # copied input file afterwards.
    try:
        result = convert_ckpt_to_safetensors(local_filename, output_filename)
    except Exception as e:
        raise gr.Error(f"An exception occurred during conversion: {e}")
    finally:
        os.remove(local_filename)

    if result != "Success":
        raise gr.Error(f"An error occurred during conversion: {result}")

    # Return the path of the converted file so the gr.File output offers it
    # for download.
    return output_filename

iface = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="URL of .ckpt file", placeholder="Enter the URL here"),
        gr.File(label="Or upload a .ckpt file", file_types=['.ckpt'])
    ],
    outputs=gr.File(label="Converted .safetensors file"),
    title="CKPT to SafeTensors Converter",
    description=(
        "Convert .ckpt files to .safetensors format. Provide a URL or upload your .ckpt file.\n\n"
        "**Security Warning:** Loading .ckpt files can execute arbitrary code. "
        "Only use files from trusted sources."
    )
)

iface.launch()