import os
import json
from typing import Union

import torch
import safetensors.torch

from modules import shared, sd_hijack, sd_models
from modules.sd_models import load_model

try:
    from modules import sd_models_xl  # noqa: F401 -- only used to detect SDXL support
    xl = True
except ImportError:
    xl = False

def prune_model(model, isxl=False):
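    # Keep only UNet ("diffusion_model."), VAE ("first_stage_model.") and
    # text-encoder tensors; everything else is dropped from the checkpoint.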
    keys = list(model.keys())
    base_prefix = "conditioner." if isxl else "cond_stage_model."
    keep = ("diffusion_model.", "first_stage_model.", base_prefix)
    for k in keys:
        if not any(p in k for p in keep):
            model.pop(k, None)
    return model

def to_half(sd):
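    # Cast float32/float64/bfloat16 "model" tensors down to fp16.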
    for key in sd.keys():
        if 'model' in key and sd[key].dtype in {torch.float32, torch.float64, torch.bfloat16}:
            sd[key] = sd[key].half()
    return sd

def savemodel(state_dict, currentmodel, fname, savesets, metadata=None):
    # savesets is a collection of option strings: "safetensors", "fp16",
    # "prune", "overwrite", "save metadata"
    if metadata is None:  # avoid sharing a mutable default dict between calls
        metadata = {}
    other_dict = {}
    if state_dict is None:
        if shared.sd_model and shared.sd_model.sd_checkpoint_info:
            metadata = shared.sd_model.sd_checkpoint_info.metadata.copy()
        else:
            return "Current model is not a valid merged model"

        checkpoint_info = shared.sd_model.sd_checkpoint_info
        # check if current merged model is a fake checkpoint_info
        if checkpoint_info is not None:
            filename = checkpoint_info.filename
            name = os.path.basename(filename)
            info = sd_models.get_closet_checkpoint_match(name)
            if info == checkpoint_info:
                # this checkpoint already exists on disk; nothing to save
                return "Current model is not a merged model, or it has already been saved"

        # prepare metadata
        save_metadata = "save metadata" in savesets
        if save_metadata:
            if "sd_merge_models" in metadata:
                metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
        else:
            metadata = {"format": "pt"}

        if shared.sd_model is not None:
            print("Loading state_dict from shared.sd_model...")

            # restore the text encoder
            sd_hijack.model_hijack.undo_hijack(shared.sd_model)

            # undo any applied LoRA/network weights before reading the state_dict
            for name, module in shared.sd_model.named_modules():
                if hasattr(module, "network_weights_backup"):
                    network_restore_weights_from_backup(module)

            state_dict = shared.sd_model.state_dict()
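            # set aside the DDPM scheduler buffers listed in POPKEYS; they are
            # restored after saving so they don't end up in the checkpoint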
            for key in list(state_dict.keys()):
                if key in POPKEYS:
                    other_dict[key] = state_dict[key]
                    del state_dict[key]

            sd_hijack.model_hijack.hijack(shared.sd_model)
        else:
            return "No currently loaded model found"

        # the (fake) checkpoint_info stores the merge name in name_for_extra
        currentmodel = checkpoint_info.name_for_extra

    if "fp16" in savesets:
        pre = ".fp16"
    else:pre = ""
    ext = ".safetensors" if "safetensors" in savesets else ".ckpt"

    # is it an inpainting or instruct-pix2pix model? (detected via the number of
    # input channels of the first UNet conv: 9 = inpainting, 8 = instruct-pix2pix)
    if "model.diffusion_model.input_blocks.0.0.weight" in state_dict.keys():
        shape = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape
        if shape[1] == 9:
            pre += "-inpainting"
        if shape[1] == 8:
            pre += "-instruct-pix2pix"

    if not fname:
        fname = currentmodel.replace(" ", "").replace(",", "_").replace("(", "_").replace(")", "_") + pre + ext
        if fname.startswith("_"):
            fname = fname[1:]
    else:
        fname = fname if ext in fname else fname + pre + ext

    fname = os.path.join(sd_models.model_path, fname)
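    # appears to undo the sanitizing above in case a "Program Files (x86)" path component was mangled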
    fname = fname.replace("ProgramFiles_x86_","Program Files (x86)")

    if len(fname) > 255:
        fname = fname.replace(ext, "")
        fname = fname[:240] + ext

    # check if the output file already exists
    if os.path.isfile(fname) and "overwrite" not in savesets:
        _err_msg = f"Output file ({fname}) already exists and was not saved"
        print(_err_msg)
        return _err_msg

    print("Saving...")
    isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in state_dict
    if isxl:
        # prune share memory tensors, "cond_stage_model." prefixed base tensors are share memory with "conditioner." prefixed tensors
        for i, key in enumerate(state_dict.keys()):
            if "cond_stage_model." in key:
                del state_dict[key]

    if "fp16" in savesets:
        state_dict = to_half(state_dict)
    if "prune" in savesets:
        state_dict = prune_model(state_dict, isxl)

    # safetensors refuses to serialize non-contiguous tensors
    print("Check contiguous...")
    for key in state_dict.keys():
        state_dict[key] = state_dict[key].contiguous()

    try:
        if ext == ".safetensors":
            safetensors.torch.save_file(state_dict, fname, metadata=metadata)
        else:
            torch.save(state_dict, fname)
    except Exception as e:
        print(f"ERROR: Couldn't save {fname}: {e}")
        return f"ERROR: Couldn't save {fname}: {e}"
    print("Done!")
    if other_dict:
        # put the scheduler buffers back and reload the still-merged model
        state_dict.update(other_dict)
        del other_dict
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
    return "Merged model saved in " + fname

def filenamecutter(name, model_a=False):
    if name == "" or name == []:
        return
    checkpoint_info = sd_models.get_closet_checkpoint_match(name)
    name = os.path.splitext(checkpoint_info.filename)[0]

    if not model_a:
        name = os.path.basename(name)
    return name

def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
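    # Put the original (pre-LoRA/network) weights back from the backups that the
    # networks extension stashed on the module, so applied LoRA deltas are not
    # baked into the saved checkpoint.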
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return self

    with torch.no_grad():
        if weights_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.in_proj_weight = torch.nn.Parameter(weights_backup[0].detach().requires_grad_(self.in_proj_weight.requires_grad))
                self.out_proj.weight = torch.nn.Parameter(weights_backup[1].detach().requires_grad_(self.out_proj.weight.requires_grad))
            else:
                self.weight = torch.nn.Parameter(weights_backup.detach().requires_grad_(self.weight.requires_grad))

        if bias_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.out_proj.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.out_proj.bias.requires_grad))
            else:
                self.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.bias.requires_grad))
        else:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.out_proj.bias = None
            else:
                self.bias = None
    return self

def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
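    # Forget any cached network state so weights are recomputed on the next application.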
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None

# DDPM scheduler buffers registered on the model; savemodel() sets these aside
# before writing the checkpoint and restores them afterwards.
POPKEYS = [
    "betas",
    "alphas_cumprod",
    "alphas_cumprod_prev",
    "sqrt_alphas_cumprod",
    "sqrt_one_minus_alphas_cumprod",
    "log_one_minus_alphas_cumprod",
    "sqrt_recip_alphas_cumprod",
    "sqrt_recipm1_alphas_cumprod",
    "posterior_variance",
    "posterior_log_variance_clipped",
    "posterior_mean_coef1",
    "posterior_mean_coef2",
]
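
# Minimal usage sketch (hypothetical values): with state_dict=None, savemodel()
# serializes the currently loaded shared.sd_model using the option strings above.
#
#   result = savemodel(None, "", "my-merge", ["safetensors", "fp16", "prune"])
#   print(result)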