Spaces:
Paused
Paused
from collections import OrderedDict | |
import torch | |
from .common import MODEL_FOLDER, load_sd_inpainting_model, download_file | |
def _per_component(unet, encoder, vae):
    """Return an OrderedDict mapping 'unet', 'encoder', 'vae' (in that order) to the given values.

    Used for the diffusers-style entries below, whose weights are stored as one
    file per pipeline component.
    """
    return OrderedDict([('unet', unet), ('encoder', encoder), ('vae', vae)])


# Registry of supported inpainting models, keyed by model id. Each entry is
# forwarded as keyword arguments to load_sd_inpainting_model by
# load_inpainting_model, and read by pre_download_inpainting_models.
#
#   sd_version     -- Stable Diffusion generation of the weights (1 or 2).
#   diffusers_ckpt -- True for entries split into per-component files
#                     (unet / text encoder / vae), False for one checkpoint file.
#   model_path     -- destination path(s) relative to MODEL_FOLDER: an
#                     OrderedDict of component -> path, or a single str.
#   download_url   -- source URL(s) on the Hugging Face Hub, same shape as model_path.
model_dict = {}

model_dict['sd15_inp'] = {
    'sd_version': 1,
    'diffusers_ckpt': True,
    'model_path': _per_component(
        'sd-1-5-inpainting/unet.fp16.safetensors',
        'sd-1-5-inpainting/encoder.fp16.safetensors',
        'sd-1-5-inpainting/vae.fp16.safetensors',
    ),
    'download_url': _per_component(
        'https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors?download=true',
        'https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/text_encoder/model.fp16.safetensors?download=true',
        'https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors?download=true',
    ),
}

model_dict['ds8_inp'] = {
    'sd_version': 1,
    'diffusers_ckpt': True,
    'model_path': _per_component(
        'ds-8-inpainting/unet.fp16.safetensors',
        'ds-8-inpainting/encoder.fp16.safetensors',
        'ds-8-inpainting/vae.fp16.safetensors',
    ),
    'download_url': _per_component(
        'https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors?download=true',
        'https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/text_encoder/model.fp16.safetensors?download=true',
        'https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors?download=true',
    ),
}

model_dict['sd2_inp'] = {
    'sd_version': 2,
    'diffusers_ckpt': False,
    'model_path': 'sd-2-0-inpainting/512-inpainting-ema.safetensors',
    'download_url': 'https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.safetensors?download=true',
}

# Models loaded with cache=True by load_inpainting_model, keyed by model id.
model_cache = dict()
def pre_download_inpainting_models():
    """Download the weight files of every model registered in ``model_dict``.

    Each file is fetched with ``download_file`` into
    ``f'{MODEL_FOLDER}/{model_path}'``. For every entry, ``download_url`` and
    ``model_path`` must either both be strings (a single checkpoint file) or
    both be dicts (e.g. ``OrderedDict``) with the same component keys.

    Raises:
        Exception: if an entry mixes a string with a dict or uses another type,
            or if its URL dict and path dict do not have the same keys.
    """
    for model_id, model_details in model_dict.items():
        download_url = model_details['download_url']
        model_path = model_details['model_path']
        if isinstance(download_url, str) and isinstance(model_path, str):
            download_file(download_url, f'{MODEL_FOLDER}/{model_path}')
        elif isinstance(download_url, dict) and isinstance(model_path, dict):
            # Validate before downloading anything for this entry: a component
            # with a URL but no destination (or vice versa) used to raise a bare
            # KeyError part-way through, or be silently skipped.
            if download_url.keys() != model_path.keys():
                raise Exception(
                    f'download_url and model_path components differ for model {model_id!r}: '
                    f'{sorted(download_url)} vs {sorted(model_path)}'
                )
            for component, url in download_url.items():
                download_file(url, f'{MODEL_FOLDER}/{model_path[component]}')
        else:
            raise Exception(
                f'download_url / model_path definition type is not supported for model {model_id!r}'
            )
# (dtype, device) that each entry of ``model_cache`` was loaded with, keyed by
# model id, so a cache hit can be rejected when the caller asks for different
# settings. Kept separate so ``model_cache`` stays a plain model_id -> model map.
_model_cache_settings = {}


def load_inpainting_model(model_id, dtype=torch.float16, device='cuda:0', cache=False):
    """Load the inpainting model registered under ``model_id`` in ``model_dict``.

    The registry entry's fields are forwarded as keyword arguments to
    ``load_sd_inpainting_model``, together with ``dtype`` and ``device``.

    Args:
        model_id: key of ``model_dict`` (e.g. ``'sd15_inp'``).
        dtype: torch dtype forwarded to the loader.
        device: device forwarded to the loader.
        cache: when True, return the model previously cached for ``model_id``
            if it was loaded with the same ``dtype`` and ``device``; otherwise
            load it and store it in ``model_cache`` (replacing any entry for
            ``model_id`` that was loaded with other settings).

    Returns:
        Whatever ``load_sd_inpainting_model`` returns for this entry.

    Raises:
        Exception: if ``model_id`` is not a key of ``model_dict``.
    """
    # str() so that e.g. 'cuda:0' and torch.device('cuda:0') count as the same device.
    requested = (dtype, str(device))
    if (
        cache
        and model_id in model_cache
        and _model_cache_settings.get(model_id) == requested
    ):
        return model_cache[model_id]

    if model_id not in model_dict:
        raise Exception(f'Unsupported model-id. Choose one from {list(model_dict.keys())}.')

    model = load_sd_inpainting_model(
        **model_dict[model_id],
        dtype=dtype,
        device=device,
    )
    if cache:
        # Keep one cached model per id; a request with new settings replaces it.
        model_cache[model_id] = model
        _model_cache_settings[model_id] = requested
    return model