PANH commited on
Commit
9115b9f
1 Parent(s): aeaef7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -23
app.py CHANGED
@@ -1,38 +1,21 @@
1
  import gradio as gr
2
  import torch
3
- from safetensors.torch import save_file
4
  import requests
5
  import os
6
 
7
- def resolve_shared_tensors(state_dict):
8
- storage_map = {}
9
- for k, v in state_dict.items():
10
- if not isinstance(v, torch.Tensor):
11
- continue
12
- storage_data_ptr = v.storage().data_ptr()
13
- if storage_data_ptr in storage_map:
14
- # Clone the tensor to create independent storage
15
- state_dict[k] = v.clone()
16
- else:
17
- storage_map[storage_data_ptr] = k
18
- return state_dict
19
-
20
  def convert_ckpt_to_safetensors(input_path, output_path):
21
  # Load the .ckpt file
22
  state_dict = torch.load(input_path, map_location='cpu')
23
 
24
- # Check for nested 'state_dict' key
25
  if 'state_dict' in state_dict:
26
  state_dict = state_dict['state_dict']
 
 
27
 
28
- # Filter out non-tensor entries
29
- tensor_state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
30
-
31
- # Resolve shared tensors
32
- tensor_state_dict = resolve_shared_tensors(tensor_state_dict)
33
-
34
- # Save as .safetensors
35
- save_file(tensor_state_dict, output_path)
36
 
37
  def process(url, uploaded_file):
38
  if url:
 
1
  import gradio as gr
2
  import torch
3
+ from safetensors.torch import save_model
4
  import requests
5
  import os
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def convert_ckpt_to_safetensors(input_path, output_path):
8
  # Load the .ckpt file
9
  state_dict = torch.load(input_path, map_location='cpu')
10
 
11
+ # If the checkpoint has a 'state_dict' key, extract it
12
  if 'state_dict' in state_dict:
13
  state_dict = state_dict['state_dict']
14
+ elif 'model' in state_dict:
15
+ state_dict = state_dict['model']
16
 
17
+ # Save the entire state dictionary, including non-tensor entries
18
+ save_model(state_dict, output_path)
 
 
 
 
 
 
19
 
20
  def process(url, uploaded_file):
21
  if url: