nyanko7 commited on
Commit
81222fb
1 Parent(s): a7a1fa6

fix: device sync issue

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -125,9 +125,7 @@ def get_model(name):
125
  subfolder="unet",
126
  torch_dtype=torch.float16,
127
  )
128
- if torch.cuda.is_available():
129
- unet.to("cuda")
130
-
131
  unet_cache[name] = unet
132
  lora_cache[name] = LoRANetwork(lora_cache[base_name].text_encoder_loras, unet)
133
 
@@ -135,6 +133,9 @@ def get_model(name):
135
  g_lora = lora_cache[name]
136
  g_unet.set_attn_processor(CrossAttnProcessor())
137
  g_lora.reset()
 
 
 
138
  return g_unet, g_lora, models[keys.index(name)][2]
139
 
140
  # precache on huggingface
 
125
  subfolder="unet",
126
  torch_dtype=torch.float16,
127
  )
128
+ unet.to("cuda")
 
 
129
  unet_cache[name] = unet
130
  lora_cache[name] = LoRANetwork(lora_cache[base_name].text_encoder_loras, unet)
131
 
 
133
  g_lora = lora_cache[name]
134
  g_unet.set_attn_processor(CrossAttnProcessor())
135
  g_lora.reset()
136
+ if torch.cuda.is_available():
137
+ g_unet.to("cuda")
138
+ g_lora.to("cuda")
139
  return g_unet, g_lora, models[keys.index(name)][2]
140
 
141
  # precache on huggingface