multimodalart HF staff commited on
Commit
f424501
·
1 Parent(s): d41c21d

Attempt gc again for faster speeds

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -7,6 +7,7 @@ import lora
7
  from time import sleep
8
  import copy
9
  import json
 
10
 
11
  with open("sdxl_loras.json", "r") as file:
12
  data = json.load(file)
@@ -35,11 +36,14 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
35
  "stabilityai/stable-diffusion-xl-base-1.0",
36
  vae=vae,
37
  torch_dtype=torch.float16,
38
- ).to(device)
 
 
39
 
40
  last_lora = ""
41
  last_merged = False
42
 
 
43
  def update_selection(selected_state: gr.SelectData):
44
  lora_repo = sdxl_loras[selected_state.index]["repo"]
45
  instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
@@ -129,11 +133,10 @@ def run_lora(prompt, negative, lora_scale, selected_state):
129
  cross_attention_kwargs = None
130
  if last_lora != repo_name:
131
  if last_merged:
132
- pipe = StableDiffusionXLPipeline.from_pretrained(
133
- "stabilityai/stable-diffusion-xl-base-1.0",
134
- vae=vae,
135
- torch_dtype=torch.float16,
136
- ).to(device)
137
  else:
138
  pipe.unload_lora_weights()
139
  is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
@@ -260,4 +263,4 @@ with gr.Blocks(css="custom.css") as demo:
260
  share_button.click(None, [], [], _js=share_js)
261
 
262
  demo.queue(max_size=20)
263
- demo.launch()
 
7
  from time import sleep
8
  import copy
9
  import json
10
+ import gc
11
 
12
  with open("sdxl_loras.json", "r") as file:
13
  data = json.load(file)
 
36
  "stabilityai/stable-diffusion-xl-base-1.0",
37
  vae=vae,
38
  torch_dtype=torch.float16,
39
+ ).to("cpu")
40
+ original_pipe = copy.deepcopy(pipe)
41
+ pipe.to(device)
42
 
43
  last_lora = ""
44
  last_merged = False
45
 
46
+
47
  def update_selection(selected_state: gr.SelectData):
48
  lora_repo = sdxl_loras[selected_state.index]["repo"]
49
  instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
 
133
  cross_attention_kwargs = None
134
  if last_lora != repo_name:
135
  if last_merged:
136
+ del pipe
137
+ gc.collect()
138
+ pipe = copy.deepcopy(original_pipe)
139
+ pipe.to(device)
 
140
  else:
141
  pipe.unload_lora_weights()
142
  is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
 
263
  share_button.click(None, [], [], _js=share_js)
264
 
265
  demo.queue(max_size=20)
266
+ demo.launch()