using instructions changed
Browse files
README.md
CHANGED
@@ -32,9 +32,74 @@ import os
|
|
32 |
from safetensors.torch import load_file
|
33 |
from diffusers.pipelines import StableDiffusionPipeline
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
# Load the adapted U-Net
|
40 |
pipe.unet.load_state_dict(state_dict, strict=False)
|
|
|
32 |
from safetensors.torch import load_file
|
33 |
from diffusers.pipelines import StableDiffusionPipeline
|
34 |
|
35 |
+
|
36 |
+
#### Defining loading function
|
37 |
+
|
38 |
+
def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs):
|
39 |
+
|
40 |
+
print(pretrained_model_name_or_path)
|
41 |
+
config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs)
|
42 |
+
original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
43 |
+
state_dict = original_model.state_dict()
|
44 |
+
with accelerate.init_empty_weights():
|
45 |
+
model = UNet2DConditionModelForSVDiff.from_config(config)
|
46 |
+
# load pre-trained weights
|
47 |
+
param_device = "cpu"
|
48 |
+
torch_dtype = kwargs["torch_dtype"] if "torch_dtype" in kwargs else None
|
49 |
+
spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
|
50 |
+
state_dict.update(spectral_shifts_weights)
|
51 |
+
# move the params from meta device to cpu
|
52 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
53 |
+
if len(missing_keys) > 0:
|
54 |
+
raise ValueError(
|
55 |
+
f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
|
56 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
57 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
|
58 |
+
" those weights or else make sure your checkpoint file is correct."
|
59 |
+
)
|
60 |
+
|
61 |
+
for param_name, param in state_dict.items():
|
62 |
+
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
|
63 |
+
if accepts_dtype:
|
64 |
+
set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
|
65 |
+
else:
|
66 |
+
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
67 |
+
|
68 |
+
if spectral_shifts_ckpt:
|
69 |
+
if os.path.isdir(spectral_shifts_ckpt):
|
70 |
+
spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors")
|
71 |
+
elif not os.path.exists(spectral_shifts_ckpt):
|
72 |
+
# download from hub
|
73 |
+
hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
|
74 |
+
spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs)
|
75 |
+
assert os.path.exists(spectral_shifts_ckpt)
|
76 |
+
|
77 |
+
with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
|
78 |
+
for key in f.keys():
|
79 |
+
# spectral_shifts_weights[key] = f.get_tensor(key)
|
80 |
+
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
|
81 |
+
if accepts_dtype:
|
82 |
+
set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
|
83 |
+
else:
|
84 |
+
set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
|
85 |
+
print(f"Resumed from {spectral_shifts_ckpt}")
|
86 |
+
if "torch_dtype"in kwargs:
|
87 |
+
model = model.to(kwargs["torch_dtype"])
|
88 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
89 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
90 |
+
model.eval()
|
91 |
+
del original_model
|
92 |
+
torch.cuda.empty_cache()
|
93 |
+
return model
|
94 |
+
|
95 |
+
pipe.unet = load_unet_for_svdiff(
|
96 |
+
"runwayml/stable-diffusion-v1-5",
|
97 |
+
spectral_shifts_ckpt=os.path.join('unet', "spectral_shifts.safetensors"),
|
98 |
+
subfolder="unet",
|
99 |
+
)
|
100 |
+
for module in pipe.unet.modules():
|
101 |
+
if hasattr(module, "perform_svd"):
|
102 |
+
module.perform_svd()
|
103 |
|
104 |
# Load the adapted U-Net
|
105 |
pipe.unet.load_state_dict(state_dict, strict=False)
|