import importlib | |
import requests | |
from collections import OrderedDict | |
from pathlib import Path | |
from os.path import dirname | |
import torch | |
import safetensors | |
import safetensors.torch | |
from omegaconf import OmegaConf | |
from tqdm import tqdm | |
from src.smplfusion import DDIM, share, scheduler | |
from src.utils.convert_diffusers_to_sd import ( | |
convert_vae_state_dict, | |
convert_unet_state_dict, | |
convert_text_enc_state_dict, | |
convert_text_enc_state_dict_v20 | |
) | |
PROJECT_DIR = dirname(dirname(dirname(__file__))) | |
CONFIG_FOLDER = f'{PROJECT_DIR}/config' | |
MODEL_FOLDER = f'{PROJECT_DIR}/checkpoints' | |
def download_file(url, save_path, chunk_size=1024):
    """Stream *url* to *save_path* with a tqdm progress bar.

    Skips the download entirely when the target file already exists;
    parent directories are created as needed.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: destination file path (str or Path).
        chunk_size: streaming chunk size in bytes.

    Raises:
        Exception: wraps any underlying failure, with the original
            exception chained as the cause.
    """
    try:
        save_path = Path(save_path)
        if save_path.exists():
            print(f'{save_path.name} exists')
            return
        save_path.parent.mkdir(exist_ok=True, parents=True)
        resp = requests.get(url, stream=True)
        # Fail fast on HTTP errors (404/403/...) instead of silently
        # writing the server's error page to disk as a "checkpoint".
        resp.raise_for_status()
        total = int(resp.headers.get('content-length', 0))
        with open(save_path, 'wb') as file, tqdm(
            desc=save_path.name,
            total=total,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)
        print(f'{save_path.name} download finished')
    except Exception as e:
        # Chain the cause so the real failure (network, disk, HTTP) is
        # visible in the traceback.
        raise Exception(f"Download failed: {e}") from e
def get_obj_from_str(string):
    """Resolve a dotted path like ``pkg.module.ClassName`` to the object.

    Falls back to prefixing the module path with ``src.`` when the first
    import fails, so config files may reference modules relative to the
    project's ``src`` package.

    Args:
        string: fully qualified dotted name, e.g. ``"collections.OrderedDict"``.

    Returns:
        The attribute named by the final path component.

    Raises:
        ImportError / AttributeError: if neither the plain nor the
            ``src.``-prefixed lookup succeeds.
    """
    module, cls = string.rsplit(".", 1)
    try:
        return getattr(importlib.import_module(module, package=None), cls)
    except (ImportError, AttributeError):
        # Narrowed from a bare `except:` so unrelated errors (e.g.
        # KeyboardInterrupt, SyntaxError in the target module's deps)
        # are not silently retried under the `src.` prefix.
        return getattr(importlib.import_module('src.' + module, package=None), cls)
def load_obj(path):
    """Instantiate the object described by the YAML file at *path*.

    The YAML must provide a ``__class__`` dotted path; an optional
    ``__init__`` mapping supplies constructor keyword arguments.
    """
    spec = OmegaConf.load(path)
    kwargs = spec.get("__init__", {})
    return get_obj_from_str(spec['__class__'])(**kwargs)
def load_state_dict(model_path):
    """Load a model state dict from a ``.safetensors``, ``.ckpt`` or ``.bin`` file.

    ``.ckpt`` files are assumed to be Lightning-style checkpoints holding the
    weights under a ``'state_dict'`` key; the other formats store the weights
    directly.

    Raises:
        Exception: for any other file extension.
    """
    suffix = Path(model_path).suffix
    if suffix == '.safetensors':
        return safetensors.torch.load_file(model_path)
    if suffix == '.ckpt':
        return torch.load(model_path)['state_dict']
    if suffix == '.bin':
        return torch.load(model_path)
    raise Exception(f'Unsupported model extension {suffix}')
def load_sd_inpainting_model(
    download_url,
    model_path,
    sd_version,
    diffusers_ckpt=False,
    dtype=torch.float16,
    device='cuda:0'
):
    """Download (if needed) and assemble a Stable Diffusion inpainting DDIM pipeline.

    Two calling conventions are supported:
      * ``download_url``/``model_path`` as plain ``str``: a single combined
        checkpoint containing UNet, text encoder and VAE weights.
      * both as ``OrderedDict`` with keys ``'unet'``, ``'encoder'``, ``'vae'``:
        three separate checkpoint files (diffusers layout when
        ``diffusers_ckpt`` is True).

    Args:
        download_url: source URL(s) for the checkpoint file(s).
        model_path: destination path(s), relative to ``MODEL_FOLDER``.
        sd_version: 1 or 2 — selects the matching UNet/encoder configs.
        diffusers_ckpt: whether the weights use the diffusers key layout and
            need conversion to the SD-native layout (only implemented for the
            OrderedDict calling convention).
        dtype: if ``torch.float16``, the UNet/VAE/encoder are converted to fp16.
        device: target device for the submodules.
            NOTE(review): the configs below are first loaded with a hardcoded
            ``.cuda()`` (default GPU) before being moved to ``device`` — confirm
            this is intended for non-``cuda:0`` devices.

    Returns:
        A ``DDIM`` sampler wrapping the loaded VAE, text encoder and UNet.

    Raises:
        Exception: on unsupported argument types, unsupported ``sd_version``,
            or the unimplemented single-file diffusers path.
    """
    if type(download_url) == str and type(model_path) == str:
        # Single combined checkpoint: download, then split the flat state
        # dict into the three submodule state dicts by key prefix.
        model_path = f'{MODEL_FOLDER}/{model_path}'
        download_file(download_url, model_path)
        state_dict = load_state_dict(model_path)
        if diffusers_ckpt:
            raise Exception('Not implemented')
        # Strip the '<prefix>.' from each matching key (len(model)+1 drops the dot).
        extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
        unet_state = extract(state_dict, 'model.diffusion_model')
        encoder_state = extract(state_dict, 'cond_stage_model')
        vae_state = extract(state_dict, 'first_stage_model')
    elif type(download_url) == OrderedDict and type(model_path) == OrderedDict:
        # Separate per-submodule checkpoints keyed by 'unet'/'encoder'/'vae'.
        for key in download_url.keys():
            download_file(download_url[key], f'{MODEL_FOLDER}/{model_path[key]}')
        unet_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["unet"]}')
        encoder_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["encoder"]}')
        vae_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["vae"]}')
        if diffusers_ckpt:
            # Convert diffusers key layout to the SD-native layout.
            unet_state = convert_unet_state_dict(unet_state)
            # Presence of layer 22 distinguishes the deeper SD 2.x text encoder.
            is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in encoder_state
            if is_v20_model:
                # SD 2.x (OpenCLIP): needs 'transformer.' prefix before
                # conversion and 'model.' prefix after.
                encoder_state = {"transformer." + k: v for k, v in encoder_state .items()}
                encoder_state = convert_text_enc_state_dict_v20(encoder_state)
                encoder_state = {"model." + k: v for k, v in encoder_state .items()}
            else:
                # SD 1.x (CLIP): convert, then add the 'transformer.' prefix.
                encoder_state = convert_text_enc_state_dict(encoder_state)
                encoder_state = {"transformer." + k: v for k, v in encoder_state .items()}
            vae_state = convert_vae_state_dict(vae_state)
    else:
        raise Exception('download_url or model_path definition type is not supported')
    # Load common config files
    config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
    vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
    # Load version specific config files
    if sd_version == 1:
        encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()
        unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
    elif sd_version == 2:
        encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda()
        unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v2.yaml').eval().cuda()
    else:
        raise Exception(f'Unsupported SD version {sd_version}.')
    # NOTE(review): this DDIM instance is discarded — it is rebuilt below
    # after the weights are loaded. Confirm the constructor has no required
    # side effects before removing this line.
    ddim = DDIM(config, vae, encoder, unet)
    unet.load_state_dict(unet_state)
    # strict=False: converted encoder state dicts may omit buffers such as
    # position ids — presumably intentional; verify nothing essential is missing.
    encoder.load_state_dict(encoder_state, strict=False)
    vae.load_state_dict(vae_state)
    if dtype == torch.float16:
        unet.convert_to_fp16()
    unet.to(device=device)
    vae.to(dtype=dtype, device=device)
    encoder.to(dtype=dtype, device=device)
    encoder.device = device
    # Inference only: freeze all submodules.
    unet = unet.requires_grad_(False)
    encoder = encoder.requires_grad_(False)
    vae = vae.requires_grad_(False)
    ddim = DDIM(config, vae, encoder, unet)
    # Publish the shared noise schedule used across the project.
    share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
    return ddim