PANH commited on
Commit
9372c94
1 Parent(s): c42a6b4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -77
app.py CHANGED
@@ -4,86 +4,47 @@ from safetensors.torch import save_file
4
  import requests
5
  import os
6
 
7
- def download_ckpt_file(ckpt_url):
8
- """
9
- Downloads the .ckpt file from the provided Hugging Face URL.
10
- """
11
- try:
12
- # Get the filename from the URL
13
- filename = ckpt_url.split("/")[-1]
14
- response = requests.get(ckpt_url)
15
-
16
- # Save the file locally
17
- with open(filename, 'wb') as f:
18
- f.write(response.content)
19
-
20
- return filename
21
- except Exception as e:
22
- return None, f"Error downloading the file: {str(e)}"
23
 
24
- def convert_ckpt_to_safetensors(ckpt_file):
25
- """
26
- Converts a .ckpt file to safetensors format.
27
- """
28
- try:
29
- # Load the checkpoint
30
- checkpoint = torch.load(ckpt_file, map_location='cpu')
31
-
32
- # Ensure the checkpoint contains a 'state_dict'
33
- if 'state_dict' in checkpoint:
34
- state_dict = checkpoint['state_dict']
35
- else:
36
- state_dict = checkpoint
 
 
 
37
 
38
- # Remove any prefixes if necessary (e.g., 'module.')
39
- new_state_dict = {}
40
- for key, value in state_dict.items():
41
- if key.startswith('module.'):
42
- new_key = key[len('module.'):]
43
- else:
44
- new_key = key
45
- new_state_dict[new_key] = value
46
 
47
- # Save to safetensors format
48
- output_file = ckpt_file.replace(".ckpt", ".safetensors")
49
- save_file(new_state_dict, output_file)
50
- return output_file
51
- except Exception as e:
52
- return f"Error converting to safetensors: {str(e)}"
53
 
54
- def handle_conversion(ckpt_url):
55
- """
56
- Handles the entire process of downloading the .ckpt file from the link,
57
- converting it to safetensors, and providing the user with the output file.
58
- """
59
- # Download the .ckpt file
60
- filename = download_ckpt_file(ckpt_url)
61
- if not filename:
62
- return None, "Failed to download the file."
63
-
64
- # Convert the .ckpt file to safetensors
65
- safetensors_file = convert_ckpt_to_safetensors(filename)
66
-
67
- # If the conversion is successful, return the safetensors file
68
- if safetensors_file.endswith(".safetensors"):
69
- return safetensors_file
70
- else:
71
- return None, safetensors_file
72
 
73
- # Gradio Interface
74
- def convert_and_download(ckpt_url):
75
- safetensors_file, message = handle_conversion(ckpt_url)
76
-
77
- if safetensors_file:
78
- return safetensors_file # Provide the converted file for download
79
- else:
80
- return message # Provide the error message
 
 
81
 
82
- # Define the Gradio interface
83
- gr.Interface(
84
- fn=convert_and_download,
85
- inputs=gr.Textbox(label="Hugging Face CKPT File URL", placeholder="Enter the link to a .ckpt file"),
86
- outputs=gr.File(label="Download .safetensors file"),
87
- title="CKPT to Safetensors Converter",
88
- description="Enter the Hugging Face URL for a .ckpt file and convert it to safetensors format."
89
- ).launch()
 
4
  import requests
5
  import os
6
 
7
+ def convert_ckpt_to_safetensors(input_path, output_path):
8
+ # Load the .ckpt file
9
+ state_dict = torch.load(input_path, map_location='cpu')
10
+ # Save as .safetensors
11
+ save_file(state_dict, output_path)
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ def process(url, uploaded_file):
14
+ if url:
15
+ # Download the .ckpt file
16
+ local_filename = 'model.ckpt'
17
+ with requests.get(url, stream=True) as r:
18
+ r.raise_for_status()
19
+ with open(local_filename, 'wb') as f:
20
+ for chunk in r.iter_content(chunk_size=8192):
21
+ f.write(chunk)
22
+ elif uploaded_file is not None:
23
+ # Save uploaded file
24
+ local_filename = uploaded_file.name
25
+ with open(local_filename, 'wb') as f:
26
+ f.write(uploaded_file.read())
27
+ else:
28
+ return "Please provide a URL or upload a .ckpt file."
29
 
30
+ output_filename = local_filename.replace('.ckpt', '.safetensors')
31
+ # Convert the .ckpt to .safetensors
32
+ convert_ckpt_to_safetensors(local_filename, output_filename)
 
 
 
 
 
33
 
34
+ # Clean up the input file
35
+ os.remove(local_filename)
 
 
 
 
36
 
37
+ return output_filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ iface = gr.Interface(
40
+ fn=process,
41
+ inputs=[
42
+ gr.Textbox(label="URL of .ckpt file", placeholder="Enter the URL here"),
43
+ gr.File(label="Or upload a .ckpt file", file_types=['.ckpt'])
44
+ ],
45
+ outputs=gr.File(label="Converted .safetensors file"),
46
+ title="CKPT to SafeTensors Converter",
47
+ description="Convert .ckpt files to .safetensors format. Provide a URL or upload your .ckpt file."
48
+ )
49
 
50
+ iface.launch()