FoodDesert commited on
Commit
fc06c9f
·
verified ·
1 Parent(s): 987ab0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -1,13 +1,22 @@
1
  import torch
2
- from safetensors.torch import save_file
3
  import gradio as gr
4
  import os
5
 
6
- def convert_embedding(sd15_embedding):
7
  output_path = "embedding.safetensors"
 
8
 
9
- sd15_embedding = torch.load(sd15_embedding.name, map_location=torch.device('cpu'), weights_only=True)
10
- sd15_tensor = sd15_embedding['string_to_param']['*']
 
 
 
 
 
 
 
 
11
 
12
  num_vectors = sd15_tensor.shape[0]
13
  clip_g_shape = (num_vectors, 1280)
@@ -30,3 +39,4 @@ iface = gr.Interface(
30
 
31
  if __name__ == "__main__":
32
  iface.launch()
 
 
1
  import torch
2
+ from safetensors.torch import save_file, load_file
3
  import gradio as gr
4
  import os
5
 
6
+ def convert_embedding(uploaded_file):
7
  output_path = "embedding.safetensors"
8
+ file_extension = os.path.splitext(uploaded_file.name)[1]
9
 
10
+ #The sample files are probably structured in these ways because the pt files were probably all created with automatic1111, and the safetensors files were probably created with kohya_ss
11
+ #If we learn of other programs that structure the embedding file differently, we'll have to adjust the logic.
12
+ if file_extension == '.pt':
13
+ sd15_embedding = torch.load(uploaded_file.name, map_location=torch.device('cpu'), weights_only=True)
14
+ sd15_tensor = sd15_embedding['string_to_param']['*']
15
+ elif file_extension == '.safetensors':
16
+ loaded_tensors = load_file(uploaded_file.name)
17
+ sd15_tensor = loaded_tensors['emb_params']
18
+ else:
19
+ raise ValueError("Unsupported file format")
20
 
21
  num_vectors = sd15_tensor.shape[0]
22
  clip_g_shape = (num_vectors, 1280)
 
39
 
40
  if __name__ == "__main__":
41
  iface.launch()
42
+