raman07 commited on
Commit
a2abddc
·
verified ·
1 Parent(s): 8d56d36

using instructions changed

Browse files
Files changed (1) hide show
  1. README.md +68 -3
README.md CHANGED
@@ -32,9 +32,74 @@ import os
32
  from safetensors.torch import load_file
33
  from diffusers.pipelines import StableDiffusionPipeline
34
 
35
- pipe = StableDiffusionPipeline.from_pretrained(sd_folder_path, revision="fp16")
36
- exp_path = os.path.join('unet', 'diffusion_pytorch_model.safetensors')
37
- state_dict = load_file(exp_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # Load the adapted U-Net
40
  pipe.unet.load_state_dict(state_dict, strict=False)
 
32
  from safetensors.torch import load_file
33
  from diffusers.pipelines import StableDiffusionPipeline
34
 
35
+
36
+ #### Defining loading function
37
+
38
+ def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs):
39
+
40
+ print(pretrained_model_name_or_path)
41
+ config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs)
42
+ original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
43
+ state_dict = original_model.state_dict()
44
+ with accelerate.init_empty_weights():
45
+ model = UNet2DConditionModelForSVDiff.from_config(config)
46
+ # load pre-trained weights
47
+ param_device = "cpu"
48
+ torch_dtype = kwargs["torch_dtype"] if "torch_dtype" in kwargs else None
49
+ spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
50
+ state_dict.update(spectral_shifts_weights)
51
+ # move the params from meta device to cpu
52
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
53
+ if len(missing_keys) > 0:
54
+ raise ValueError(
55
+ f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
56
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
57
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
58
+ " those weights or else make sure your checkpoint file is correct."
59
+ )
60
+
61
+ for param_name, param in state_dict.items():
62
+ accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
63
+ if accepts_dtype:
64
+ set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
65
+ else:
66
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
67
+
68
+ if spectral_shifts_ckpt:
69
+ if os.path.isdir(spectral_shifts_ckpt):
70
+ spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors")
71
+ elif not os.path.exists(spectral_shifts_ckpt):
72
+ # download from hub
73
+ hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
74
+ spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs)
75
+ assert os.path.exists(spectral_shifts_ckpt)
76
+
77
+ with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
78
+ for key in f.keys():
79
+ # spectral_shifts_weights[key] = f.get_tensor(key)
80
+ accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
81
+ if accepts_dtype:
82
+ set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
83
+ else:
84
+ set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
85
+ print(f"Resumed from {spectral_shifts_ckpt}")
86
+ if "torch_dtype"in kwargs:
87
+ model = model.to(kwargs["torch_dtype"])
88
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
89
+ # Set model in evaluation mode to deactivate DropOut modules by default
90
+ model.eval()
91
+ del original_model
92
+ torch.cuda.empty_cache()
93
+ return model
94
+
95
+ pipe.unet = load_unet_for_svdiff(
96
+ "runwayml/stable-diffusion-v1-5",
97
+ spectral_shifts_ckpt=os.path.join('unet', "spectral_shifts.safetensors"),
98
+ subfolder="unet",
99
+ )
100
+ for module in pipe.unet.modules():
101
+ if hasattr(module, "perform_svd"):
102
+ module.perform_svd()
103
 
104
  # Load the adapted U-Net
105
  pipe.unet.load_state_dict(state_dict, strict=False)