Spaces:
Paused
Paused
File size: 3,298 Bytes
f1cc496 b35c416 f1cc496 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
from collections import OrderedDict
import torch
from .common import MODEL_FOLDER, load_sd_inpainting_model, download_file
# Registry of supported inpainting checkpoints, keyed by model id.
#
# Every field of an entry is forwarded as a keyword argument to
# `load_sd_inpainting_model`, and `download_url` / `model_path` are also read
# by `pre_download_inpainting_models`:
#   sd_version     -- Stable Diffusion version of the checkpoint (1 or 2 here).
#   diffusers_ckpt -- True for entries whose weights are split into
#                     per-component files (unet / text encoder / vae);
#                     False for a single checkpoint file.
#   model_path     -- destination path(s), relative to MODEL_FOLDER. An
#                     OrderedDict keyed by component for split checkpoints,
#                     otherwise one string.
#   download_url   -- source URL(s), shaped exactly like `model_path`
#                     (same component keys, same order).
model_dict = {
    'sd15_inp': dict(
        sd_version=1,
        diffusers_ckpt=True,
        model_path=OrderedDict(
            unet='sd-1-5-inpainting/unet.fp16.safetensors',
            encoder='sd-1-5-inpainting/encoder.fp16.safetensors',
            vae='sd-1-5-inpainting/vae.fp16.safetensors',
        ),
        download_url=OrderedDict(
            unet='https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors?download=true',
            encoder='https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/text_encoder/model.fp16.safetensors?download=true',
            vae='https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors?download=true',
        ),
    ),
    'ds8_inp': dict(
        sd_version=1,
        diffusers_ckpt=True,
        model_path=OrderedDict(
            unet='ds-8-inpainting/unet.fp16.safetensors',
            encoder='ds-8-inpainting/encoder.fp16.safetensors',
            vae='ds-8-inpainting/vae.fp16.safetensors',
        ),
        download_url=OrderedDict(
            unet='https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors?download=true',
            encoder='https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/text_encoder/model.fp16.safetensors?download=true',
            vae='https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors?download=true',
        ),
    ),
    # Single-file (non-diffusers) checkpoint: path and URL are plain strings.
    'sd2_inp': dict(
        sd_version=2,
        diffusers_ckpt=False,
        model_path='sd-2-0-inpainting/512-inpainting-ema.safetensors',
        download_url='https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.safetensors?download=true',
    ),
}

# In-process store of already-loaded models, filled by `load_inpainting_model`
# when it is called with cache=True.
model_cache = {}
def pre_download_inpainting_models():
    """Download every file referenced by `model_dict` into MODEL_FOLDER.

    For each registry entry, ``download_url`` and ``model_path`` must have the
    same shape:

    * both strings: one file, fetched with
      ``download_file(url, f'{MODEL_FOLDER}/{model_path}')``;
    * both dicts (e.g. ``OrderedDict``) keyed by component name: each URL is
      saved to the ``model_path`` entry with the same key, in
      ``download_url`` order.

    Raises:
        KeyError: if a component in ``download_url`` has no matching
            ``model_path`` entry. This is checked before any of that entry's
            files are downloaded.
        Exception: if ``download_url`` / ``model_path`` are not both strings
            or both dicts.
    """
    for model_id, model_details in model_dict.items():
        download_url = model_details['download_url']
        model_path = model_details['model_path']

        if isinstance(download_url, str) and isinstance(model_path, str):
            download_file(download_url, f'{MODEL_FOLDER}/{model_path}')
        elif isinstance(download_url, dict) and isinstance(model_path, dict):
            # Validate before downloading so a mis-keyed entry fails with a clear
            # message instead of a bare KeyError partway through its downloads.
            missing = [key for key in download_url if key not in model_path]
            if missing:
                raise KeyError(
                    f'model {model_id!r}: model_path has no entry for '
                    f'download_url component(s) {missing}'
                )
            for component, url in download_url.items():
                download_file(url, f'{MODEL_FOLDER}/{model_path[component]}')
        else:
            raise Exception(
                f'download_url/model_path definition type is not supported for '
                f'model {model_id!r}: expected both str or both dict, got '
                f'{type(download_url).__name__} and {type(model_path).__name__}'
            )
def load_inpainting_model(model_id, dtype=torch.float16, device='cuda:0', cache=False):
    """Load the inpainting model registered under ``model_id`` in `model_dict`.

    Args:
        model_id: key into `model_dict`.
        dtype: passed through to `load_sd_inpainting_model`.
        device: passed through to `load_sd_inpainting_model`.
        cache: when True, return a model previously loaded with the same
            (model_id, dtype, device) if one exists. Otherwise store the newly
            loaded model in `model_cache`.

    Returns:
        Whatever `load_sd_inpainting_model` returns for this registry entry.

    Raises:
        ValueError: if ``model_id`` is not a key of `model_dict`.
    """
    if model_id not in model_dict:
        raise ValueError(f'Unsupported model-id. Choose one from {list(model_dict.keys())}.')

    # Key on dtype and device as well as the id. Caching by id alone would hand
    # back a model loaded with a different dtype or on a different device.
    # str(device) lets 'cuda:0' and torch.device('cuda:0') share one entry.
    cache_key = (model_id, dtype, str(device))
    if cache and cache_key in model_cache:
        return model_cache[cache_key]

    # Every registry field (sd_version, diffusers_ckpt, model_path,
    # download_url) is forwarded as a keyword argument.
    model = load_sd_inpainting_model(
        **model_dict[model_id],
        dtype=dtype,
        device=device,
    )
    if cache:
        model_cache[cache_key] = model
    return model
|