FoodDesert commited on
Commit
72d9282
1 Parent(s): 26c6087

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -15
app.py CHANGED
@@ -4,17 +4,11 @@ import gradio as gr
4
  import os
5
 
6
  def convert_embedding(sd15_embedding):
7
- # Temporary file paths
8
- input_path = "temp_input.pt"
9
- output_path = "temp_output.safetensors"
10
-
11
- # Save uploaded file to disk to be processed
12
- with open(input_path, "wb") as f:
13
- f.write(sd15_embedding)
14
 
15
- # Your existing conversion logic
16
- sd15_embedding = torch.load(input_path, map_location=torch.device('cpu'))
17
  sd15_tensor = sd15_embedding['string_to_param']['*']
 
18
  num_vectors = sd15_tensor.shape[0]
19
  clip_g_shape = (num_vectors, 1280)
20
  clip_l_shape = (num_vectors, 768)
@@ -23,18 +17,15 @@ def convert_embedding(sd15_embedding):
23
  clip_l[:sd15_tensor.shape[0], :sd15_tensor.shape[1]] = sd15_tensor.to(dtype=torch.float16)
24
  save_file({"clip_g": clip_g, "clip_l": clip_l}, output_path)
25
 
26
- # Remove the temporary input file
27
- os.remove(input_path)
28
-
29
  # Return the path to the converted file for download
30
  return output_path
31
 
32
  iface = gr.Interface(
33
  fn=convert_embedding,
34
- inputs=gr.File(label="Upload SD1.5 Embedding"),
35
- outputs=gr.File(label="Download Converted SDXL Embedding"),
36
  title="SD1.5 to SDXL Embedding Converter",
37
- description="Upload an SD1.5 embedding file to convert it to SDXL format."
38
  )
39
 
40
  if __name__ == "__main__":
 
4
  import os
5
 
6
  def convert_embedding(sd15_embedding):
7
+ output_path = "embedding.safetensors"
 
 
 
 
 
 
8
 
9
+ sd15_embedding = torch.load(sd15_embedding.name, weights_only=True)
 
10
  sd15_tensor = sd15_embedding['string_to_param']['*']
11
+
12
  num_vectors = sd15_tensor.shape[0]
13
  clip_g_shape = (num_vectors, 1280)
14
  clip_l_shape = (num_vectors, 768)
 
17
  clip_l[:sd15_tensor.shape[0], :sd15_tensor.shape[1]] = sd15_tensor.to(dtype=torch.float16)
18
  save_file({"clip_g": clip_g, "clip_l": clip_l}, output_path)
19
 
 
 
 
20
  # Return the path to the converted file for download
21
  return output_path
22
 
23
  iface = gr.Interface(
24
  fn=convert_embedding,
25
+ inputs=gr.File(label="Upload SD1.5 pt Embedding"),
26
+ outputs=gr.File(label="Download Converted SDXL safetensors Embedding"),
27
  title="SD1.5 to SDXL Embedding Converter",
28
+ description="Upload an SD1.5 embedding file in pt format to convert it to SDXL."
29
  )
30
 
31
  if __name__ == "__main__":