nyanko7 commited on
Commit
92ee4a3
1 Parent(s): 81222fb

fix: minor fix

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -35,8 +35,8 @@ alt_models = [
35
  ("Stable Diffusion V1.5", "runwayml/stable-diffusion-v1-5", 1),
36
  ("Anything V3.0", "Linaqruf/anything-v3.0", 2),
37
  ("Open Journey", "prompthero/openjourney", 1),
38
- ("Eimis AnimeDiffusion", "eimiss/EimisAnimeDiffusion_1.0v", 2)
39
- ("Dreamlike Photoreal 2.0", "dreamlike-art/dreamlike-photoreal-2.0", 1)
40
  ("Redshift Diffusion", "nitrosocke/redshift-diffusion", 1)
41
  ]
42
 
@@ -115,13 +115,14 @@ lora_cache = {
115
  }
116
 
117
  def get_model(name):
118
- keys = [k[0] for k in models]
 
119
  if name not in unet_cache:
120
  if name not in keys:
121
  raise ValueError(name)
122
  else:
123
  unet = UNet2DConditionModel.from_pretrained(
124
- models[keys.index(name)][1],
125
  subfolder="unet",
126
  torch_dtype=torch.float16,
127
  )
@@ -133,10 +134,11 @@ def get_model(name):
133
  g_lora = lora_cache[name]
134
  g_unet.set_attn_processor(CrossAttnProcessor())
135
  g_lora.reset()
 
136
  if torch.cuda.is_available():
137
  g_unet.to("cuda")
138
  g_lora.to("cuda")
139
- return g_unet, g_lora, models[keys.index(name)][2]
140
 
141
  # precache on huggingface
142
  for model in models:
 
35
  ("Stable Diffusion V1.5", "runwayml/stable-diffusion-v1-5", 1),
36
  ("Anything V3.0", "Linaqruf/anything-v3.0", 2),
37
  ("Open Journey", "prompthero/openjourney", 1),
38
+ ("Eimis AnimeDiffusion", "eimiss/EimisAnimeDiffusion_1.0v", 2),
39
+ ("Dreamlike Photoreal 2.0", "dreamlike-art/dreamlike-photoreal-2.0", 1),
40
  ("Redshift Diffusion", "nitrosocke/redshift-diffusion", 1)
41
  ]
42
 
 
115
  }
116
 
117
  def get_model(name):
118
+ local_models = models + alt_models
119
+ keys = [k[0] for k in local_models]
120
  if name not in unet_cache:
121
  if name not in keys:
122
  raise ValueError(name)
123
  else:
124
  unet = UNet2DConditionModel.from_pretrained(
125
+ local_models[keys.index(name)][1],
126
  subfolder="unet",
127
  torch_dtype=torch.float16,
128
  )
 
134
  g_lora = lora_cache[name]
135
  g_unet.set_attn_processor(CrossAttnProcessor())
136
  g_lora.reset()
137
+ clip_skip = local_models[keys.index(name)][2]
138
  if torch.cuda.is_available():
139
  g_unet.to("cuda")
140
  g_lora.to("cuda")
141
+ return g_unet, g_lora, clip_skip
142
 
143
  # precache on huggingface
144
  for model in models: