Spaces:
Runtime error
Runtime error
FoodDesert
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,22 @@
|
|
1 |
import torch
|
2 |
-
from safetensors.torch import save_file
|
3 |
import gradio as gr
|
4 |
import os
|
5 |
|
6 |
-
def convert_embedding(
|
7 |
output_path = "embedding.safetensors"
|
|
|
8 |
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
num_vectors = sd15_tensor.shape[0]
|
13 |
clip_g_shape = (num_vectors, 1280)
|
@@ -30,3 +39,4 @@ iface = gr.Interface(
|
|
30 |
|
31 |
if __name__ == "__main__":
|
32 |
iface.launch()
|
|
|
|
1 |
import torch
|
2 |
+
from safetensors.torch import save_file, load_file
|
3 |
import gradio as gr
|
4 |
import os
|
5 |
|
6 |
+
def convert_embedding(uploaded_file):
|
7 |
output_path = "embedding.safetensors"
|
8 |
+
file_extension = os.path.splitext(uploaded_file.name)[1]
|
9 |
|
10 |
+
#The sample files are probably structured in these ways because the pt files were probably all created with automatic1111, and the safetensors files were probably created with kohya_ss
|
11 |
+
#If we learn of other programs that structure the embedding file differently, we'll have to adjust the logic.
|
12 |
+
if file_extension == '.pt':
|
13 |
+
sd15_embedding = torch.load(uploaded_file.name, map_location=torch.device('cpu'), weights_only=True)
|
14 |
+
sd15_tensor = sd15_embedding['string_to_param']['*']
|
15 |
+
elif file_extension == '.safetensors':
|
16 |
+
loaded_tensors = load_file(uploaded_file.name)
|
17 |
+
sd15_tensor = loaded_tensors['emb_params']
|
18 |
+
else:
|
19 |
+
raise ValueError("Unsupported file format")
|
20 |
|
21 |
num_vectors = sd15_tensor.shape[0]
|
22 |
clip_g_shape = (num_vectors, 1280)
|
|
|
39 |
|
40 |
if __name__ == "__main__":
|
41 |
iface.launch()
|
42 |
+
|