File size: 5,553 Bytes
f1cc496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import importlib
import requests
from collections import OrderedDict
from pathlib import Path
from os.path import dirname

import torch
import safetensors
import safetensors.torch
from omegaconf import OmegaConf
from tqdm import tqdm

from src.smplfusion import DDIM, share, scheduler
from src.utils.convert_diffusers_to_sd import (
    convert_vae_state_dict,
    convert_unet_state_dict,
    convert_text_enc_state_dict,
    convert_text_enc_state_dict_v20
)


# Project root: the directory three ``dirname`` steps above this module's path.
PROJECT_DIR = dirname(dirname(dirname(__file__)))
# YAML configs read via ``OmegaConf.load`` / ``load_obj`` by the model loader.
CONFIG_FOLDER = PROJECT_DIR + '/config'
# Where checkpoint files are downloaded to and loaded from.
MODEL_FOLDER = PROJECT_DIR + '/checkpoints'


def download_file(url, save_path, chunk_size=1024, timeout=60):
    """Stream ``url`` to ``save_path`` unless that file already exists.

    The payload is first written to a sibling ``<name>.part`` file. It is
    renamed onto ``save_path`` only after the transfer has completed. A failed
    or interrupted download therefore never leaves a truncated file that a
    later call would mistake for a finished one (the existence check below
    would otherwise skip it forever).

    Args:
        url: Address fetched with ``requests.get``.
        save_path: Destination file path (str or ``Path``). Missing parent
            directories are created.
        chunk_size: Bytes requested per ``iter_content`` chunk.
        timeout: Passed to ``requests.get`` as its connect/read timeout in
            seconds; ``None`` waits indefinitely.

    Raises:
        Exception: "Download failed: ..." chained from the underlying error,
            including non-2xx HTTP responses.
    """
    partial_path = None
    try:
        save_path = Path(save_path)
        if save_path.exists():
            print(f'{save_path.name} exists')
            return
        save_path.parent.mkdir(exist_ok=True, parents=True)
        partial_path = save_path.with_name(save_path.name + '.part')
        # ``with`` releases the streamed connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as resp:
            # Without this, an HTTP error page would be saved as the file.
            resp.raise_for_status()
            total = int(resp.headers.get('content-length', 0))
            with open(partial_path, 'wb') as file, tqdm(
                desc=save_path.name,
                total=total,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in resp.iter_content(chunk_size=chunk_size):
                    bar.update(file.write(data))
        # Same directory, so this is an os.replace-style rename: save_path
        # only ever appears once the data is complete.
        partial_path.replace(save_path)
        print(f'{save_path.name} download finished')
    except Exception as e:
        if partial_path is not None:
            try:
                partial_path.unlink()
            except OSError:
                pass  # best-effort cleanup; the download error is what matters
        raise Exception(f"Download failed: {e}") from e


def get_obj_from_str(string):
    """Resolve a dotted ``"package.module.Name"`` path to the named attribute.

    The module part is imported as written first. If that import fails, or
    the module lacks the attribute, the lookup is retried with a ``src.``
    prefix. This lets config files reference project modules without the
    package root.

    Args:
        string: Dotted path whose last component is the attribute name.

    Returns:
        The attribute (typically a class) named by ``string``.

    Raises:
        ValueError: if ``string`` contains no ``.``.
        ImportError / AttributeError: if neither lookup succeeds.
    """
    module, cls = string.rsplit(".", 1)
    try:
        return getattr(importlib.import_module(module), cls)
    except (ImportError, AttributeError):
        # Fall back only on "not found"-style failures. Anything else raised
        # while importing the module (a bug in it, KeyboardInterrupt, ...)
        # propagates instead of being masked by a second, unrelated import error.
        return getattr(importlib.import_module('src.' + module), cls)


def load_obj(path):
    """Instantiate an object described by a YAML spec file.

    The file is read with ``OmegaConf.load``. Its ``__class__`` entry is a
    dotted path resolved with ``get_obj_from_str``. The optional ``__init__``
    entry holds keyword arguments for the call; when it is absent, the target
    is called with no arguments.
    """
    spec = OmegaConf.load(path)
    target = get_obj_from_str(spec['__class__'])
    init_kwargs = spec.get("__init__", {})
    return target(**init_kwargs)


def load_state_dict(model_path):
    """Load a checkpoint file into a ``{name: tensor}`` state dict on the CPU.

    The format is chosen by file extension:
      * ``.safetensors``: read with ``safetensors.torch.load_file``.
      * ``.ckpt``: a ``torch.save`` archive. Its ``'state_dict'`` entry is
        returned when present; otherwise the archive itself is treated as
        the state dict.
      * ``.bin``: a ``torch.save``'d state dict, returned as-is.

    ``torch.load`` is given ``map_location='cpu'``, so checkpoints saved
    from a GPU neither require one nor consume GPU memory here. This matches
    the safetensors path.

    NOTE(review): ``torch.load`` unpickles the file. Only load checkpoints
    from trusted sources.

    Raises:
        Exception: for any other extension.
    """
    model_ext = Path(model_path).suffix
    if model_ext == '.safetensors':
        return safetensors.torch.load_file(model_path)
    if model_ext == '.ckpt':
        checkpoint = torch.load(model_path, map_location='cpu')
        # Some .ckpt files are saved as a bare state dict without the wrapper.
        return checkpoint.get('state_dict', checkpoint)
    if model_ext == '.bin':
        return torch.load(model_path, map_location='cpu')
    raise Exception(f'Unsupported model extension {model_ext}')


def _strip_prefix(state_dict, prefix):
    """Return the entries under ``prefix``, re-keyed with ``prefix + '.'`` removed."""
    head = prefix + '.'
    # startswith (not substring) so the slice below always removes exactly the prefix.
    return {key[len(head):]: value for key, value in state_dict.items() if key.startswith(head)}


def _convert_diffusers_text_encoder(encoder_state):
    """Convert a diffusers text-encoder state dict with the project's converters.

    If the key ``text_model.encoder.layers.22.layer_norm2.bias`` is present,
    the v2.0 converter is used and the result is prefixed with ``model.``.
    Otherwise the v1 converter is used and the result is prefixed with
    ``transformer.``.
    """
    if "text_model.encoder.layers.22.layer_norm2.bias" in encoder_state:
        encoder_state = {"transformer." + k: v for k, v in encoder_state.items()}
        encoder_state = convert_text_enc_state_dict_v20(encoder_state)
        return {"model." + k: v for k, v in encoder_state.items()}
    encoder_state = convert_text_enc_state_dict(encoder_state)
    return {"transformer." + k: v for k, v in encoder_state.items()}


def _load_component_states(download_url, model_path, diffusers_ckpt):
    """Download the checkpoint file(s) and return ``(unet, encoder, vae)`` state dicts.

    ``download_url`` and ``model_path`` must be either:
      * both strings: one combined checkpoint, split into components by key
        prefix; or
      * both dicts keyed by component: ``model_path`` must provide
        ``'unet'``, ``'encoder'`` and ``'vae'``.
    """
    if isinstance(download_url, str) and isinstance(model_path, str):
        if diffusers_ckpt:
            # Checked before downloading so callers don't wait for the whole
            # transfer only to hit this error.
            raise NotImplementedError('Not implemented')
        checkpoint_path = f'{MODEL_FOLDER}/{model_path}'
        download_file(download_url, checkpoint_path)
        state_dict = load_state_dict(checkpoint_path)
        return (
            _strip_prefix(state_dict, 'model.diffusion_model'),
            _strip_prefix(state_dict, 'cond_stage_model'),
            _strip_prefix(state_dict, 'first_stage_model'),
        )

    if isinstance(download_url, dict) and isinstance(model_path, dict):
        for key, url in download_url.items():
            download_file(url, f'{MODEL_FOLDER}/{model_path[key]}')
        unet_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["unet"]}')
        encoder_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["encoder"]}')
        vae_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["vae"]}')
        if diffusers_ckpt:
            unet_state = convert_unet_state_dict(unet_state)
            encoder_state = _convert_diffusers_text_encoder(encoder_state)
            vae_state = convert_vae_state_dict(vae_state)
        return unet_state, encoder_state, vae_state

    raise Exception('download_url or model_path definition type is not supported')


def load_sd_inpainting_model(
    download_url,
    model_path,
    sd_version,
    diffusers_ckpt=False,
    dtype=torch.float16,
    device='cuda:0'
):
    """Build a ``DDIM`` sampler around a Stable Diffusion inpainting model.

    Args:
        download_url: URL of a single combined checkpoint (str), or a dict
            (e.g. ``OrderedDict``) of per-component URLs.
        model_path: File name(s) under ``MODEL_FOLDER`` matching
            ``download_url``. Either a str, or a dict with the same keys that
            provides ``'unet'``, ``'encoder'`` and ``'vae'``.
        sd_version: ``1`` (loads ``encoders/clip.yaml`` and
            ``unet/inpainting/v1.yaml``) or ``2`` (loads
            ``encoders/openclip.yaml`` and ``unet/inpainting/v2.yaml``).
        diffusers_ckpt: Whether the per-component checkpoints are in
            diffusers format and must be converted. Not supported for a
            single combined checkpoint.
        dtype: Dtype for the VAE and text encoder. ``torch.float16`` also
            calls the UNet's ``convert_to_fp16``.
        device: Device all three models are placed on.

    Returns:
        A ``DDIM`` wrapping the loaded VAE, text encoder and UNet. As a side
        effect, ``share.schedule`` is set to a linear schedule built from
        ``config/ddpm/v1.yaml``.

    Raises:
        NotImplementedError: ``diffusers_ckpt`` with a single checkpoint.
        Exception: unsupported ``sd_version`` or argument types, or a failed
            download.
    """
    # Validated before any download so a bad version fails immediately.
    if sd_version == 1:
        encoder_config, unet_config = 'encoders/clip.yaml', 'unet/inpainting/v1.yaml'
    elif sd_version == 2:
        encoder_config, unet_config = 'encoders/openclip.yaml', 'unet/inpainting/v2.yaml'
    else:
        raise Exception(f'Unsupported SD version {sd_version}.')

    unet_state, encoder_state, vae_state = _load_component_states(
        download_url, model_path, diffusers_ckpt)

    config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
    # Place models on the requested device rather than the default CUDA device.
    vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().to(device)
    encoder = load_obj(f'{CONFIG_FOLDER}/{encoder_config}').eval().to(device)
    unet = load_obj(f'{CONFIG_FOLDER}/{unet_config}').eval().to(device)

    unet.load_state_dict(unet_state)
    # NOTE(review): strict=False silently ignores missing/unexpected encoder
    # keys; presumably intentional for the converted layouts — confirm.
    encoder.load_state_dict(encoder_state, strict=False)
    vae.load_state_dict(vae_state)

    if dtype == torch.float16:
        unet.convert_to_fp16()
    unet.to(device=device)
    vae.to(dtype=dtype, device=device)
    encoder.to(dtype=dtype, device=device)
    # Presumably read by the encoder to place its own tensors — TODO confirm.
    encoder.device = device

    unet = unet.requires_grad_(False)
    encoder = encoder.requires_grad_(False)
    vae = vae.requires_grad_(False)

    # Built once, after the models are fully loaded and placed.
    ddim = DDIM(config, vae, encoder, unet)
    share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)

    return ddim