multimodalart HF staff commited on
Commit
be2828d
·
1 Parent(s): 09baf03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -58
app.py CHANGED
@@ -93,49 +93,54 @@ def check_selected(selected_state):
93
  if not selected_state:
94
  raise gr.Error("You must select a LoRA")
95
 
96
- def get_cross_attention_kwargs(scale, repo_name, is_compatible):
97
- if repo_name != last_lora and is_compatible:
98
- return {"scale": scale}
99
- return None
 
 
 
100
 
101
- def load_lora_model(pipe, repo_name, full_path_lora, lora_scale, selected_state):
102
- if repo_name == last_lora:
103
- return
104
-
105
- if last_merged:
106
- pipe = copy.deepcopy(original_pipe)
107
- pipe.to(device)
108
- else:
109
- pipe.unload_lora_weights()
 
 
110
 
111
- is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
112
- if is_compatible:
113
- pipe.load_lora_weights(full_path_lora)
114
- else:
115
- load_incompatible_lora(pipe, full_path_lora, lora_scale)
116
 
117
- def load_incompatible_lora(pipe, full_path_lora, lora_scale):
118
- for weights_file in [full_path_lora]:
119
- if ";" in weights_file:
120
- weights_file, multiplier = weights_file.split(";")
121
- multiplier = float(multiplier)
122
- else:
123
- multiplier = lora_scale
124
 
125
- lora_model, weights_sd = lora.create_network_from_weights(
126
- multiplier,
127
- full_path_lora,
128
- pipe.vae,
129
- pipe.text_encoder,
130
- pipe.unet,
131
- for_inference=True,
132
- )
133
- lora_model.merge_to(
134
- pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
135
- )
 
 
 
 
 
 
 
 
 
136
 
137
- def generate_image(pipe, prompt, negative, cross_attention_kwargs):
138
- return pipe(
139
  prompt=prompt,
140
  negative_prompt=negative,
141
  width=768,
@@ -144,26 +149,6 @@ def generate_image(pipe, prompt, negative, cross_attention_kwargs):
144
  guidance_scale=7.5,
145
  cross_attention_kwargs=cross_attention_kwargs,
146
  ).images[0]
147
-
148
- def run_lora(prompt, negative, lora_scale, selected_state):
149
- global last_lora, last_merged, pipe
150
-
151
- if not selected_state:
152
- raise gr.Error("You must select a LoRA")
153
-
154
- if negative == "":
155
- negative = None
156
-
157
- repo_name = sdxl_loras[selected_state.index]["repo"]
158
- full_path_lora = saved_names[selected_state.index]
159
-
160
- cross_attention_kwargs = get_cross_attention_kwargs(
161
- lora_scale, repo_name, sdxl_loras[selected_state.index]["is_compatible"])
162
-
163
- load_lora_model(pipe, repo_name, full_path_lora, lora_scale, selected_state)
164
-
165
- image = generate_image(pipe, prompt, negative, cross_attention_kwargs)
166
-
167
  last_lora = repo_name
168
  return image, gr.update(visible=True)
169
 
 
93
  if not selected_state:
94
  raise gr.Error("You must select a LoRA")
95
 
96
+ def merge_incompatible_lora(full_path_lora, lora_scale):
97
+ for weights_file in [full_path_lora]:
98
+ if ";" in weights_file:
99
+ weights_file, multiplier = weights_file.split(";")
100
+ multiplier = float(multiplier)
101
+ else:
102
+ multiplier = lora_scale
103
 
104
+ lora_model, weights_sd = lora.create_network_from_weights(
105
+ multiplier,
106
+ full_path_lora,
107
+ pipe.vae,
108
+ pipe.text_encoder,
109
+ pipe.unet,
110
+ for_inference=True,
111
+ )
112
+ lora_model.merge_to(
113
+ pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
114
+ )
115
 
116
+ def run_lora(prompt, negative, lora_scale, selected_state):
117
+ global last_lora, last_merged, pipe
 
 
 
118
 
119
+ if negative == "":
120
+ negative = None
 
 
 
 
 
121
 
122
+ if not selected_state:
123
+ raise gr.Error("You must select a LoRA")
124
+ repo_name = sdxl_loras[selected_state.index]["repo"]
125
+ weight_name = sdxl_loras[selected_state.index]["weights"]
126
+ full_path_lora = saved_names[selected_state.index]
127
+ cross_attention_kwargs = None
128
+ if last_lora != repo_name:
129
+ if last_merged:
130
+ pipe = copy.deepcopy(original_pipe)
131
+ pipe.to(device)
132
+ else:
133
+ pipe.unload_lora_weights()
134
+ is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
135
+
136
+ if is_compatible:
137
+ pipe.load_lora_weights(full_path_lora)
138
+ cross_attention_kwargs = {"scale": lora_scale}
139
+ else:
140
+ merge_incompatible_lora(full_path_lora, lora_scale)
141
+ last_merged = True
142
 
143
+ image = pipe(
 
144
  prompt=prompt,
145
  negative_prompt=negative,
146
  width=768,
 
149
  guidance_scale=7.5,
150
  cross_attention_kwargs=cross_attention_kwargs,
151
  ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  last_lora = repo_name
153
  return image, gr.update(visible=True)
154