primerz committed
Commit 0db314e · verified
1 Parent(s): b3c8e03

Update cog_sdxl_dataset_and_utils.py

Files changed (1)
  1. cog_sdxl_dataset_and_utils.py +325 -86
cog_sdxl_dataset_and_utils.py CHANGED
@@ -1,4 +1,4 @@
- # dataset_and_utils.py - Optimized and Improved Version
  import os
  from typing import Dict, List, Optional, Tuple
 
@@ -15,29 +15,28 @@ from torch.utils.data import Dataset
  from transformers import AutoTokenizer, PretrainedConfig
 
 
- def prepare_image(image: PIL.Image.Image, width: int = 512, height: int = 512) -> torch.Tensor:
-     """
-     Prepares an image for model input by resizing and normalizing it.
-     """
-     image = image.resize((width, height), resample=Image.BICUBIC, reducing_gap=1)
-     arr = np.array(image.convert("RGB"), dtype=np.float32) / 127.5 - 1
-     return torch.from_numpy(np.transpose(arr, (2, 0, 1))).unsqueeze(0)
 
 
- def prepare_mask(mask: PIL.Image.Image, width: int = 512, height: int = 512) -> torch.Tensor:
-     """
-     Prepares a mask image for model input by resizing and normalizing it.
-     """
-     mask = mask.resize((width, height), resample=Image.BICUBIC, reducing_gap=1)
-     arr = np.array(mask.convert("L"), dtype=np.float32) / 255.0
-     return torch.from_numpy(np.expand_dims(arr, 0)).unsqueeze(0)
 
 
- class TokenEmbeddingsHandler:
-     def __init__(self, text_encoders, tokenizers):
-         self.text_encoders = text_encoders
-         self.tokenizers = tokenizers
-
  class PreprocessedDataset(Dataset):
      def __init__(
          self,
@@ -51,133 +50,373 @@ class PreprocessedDataset(Dataset):
          size: int = 512,
          text_dropout: float = 0.0,
          scale_vae_latents: bool = True,
-         substitute_caption_map: Dict[str, str] = None,
      ):
-         """
-         Dataset class that pre-processes images, masks, and text data for training.
-         """
          super().__init__()
          self.data = pd.read_csv(csv_path)
-         self.size = size
-         self.scale_vae_latents = scale_vae_latents
-         self.text_dropout = text_dropout
          self.csv_path = csv_path
-         self.tokenizer_1 = tokenizer_1
-         self.tokenizer_2 = tokenizer_2
-         self.vae_encoder = vae_encoder
-         self.do_cache = do_cache
 
-         self.caption = self.data["caption"].str.lower()
-
-         if substitute_caption_map:
-             for key, value in substitute_caption_map.items():
-                 self.caption = self.caption.str.replace(key.lower(), value)
 
          self.image_path = self.data["image_path"]
-         self.mask_path = self.data["mask_path"] if "mask_path" in self.data.columns else None
 
-         if text_encoder_1:
              self.text_encoder_1 = text_encoder_1
              self.text_encoder_2 = text_encoder_2
              self.return_text_embeddings = True
-             raise NotImplementedError("Preprocessing for text encoder is not implemented yet.")
-         else:
-             self.return_text_embeddings = False
 
-         if self.do_cache:
              self.vae_latents = []
              self.tokens_tuple = []
              self.masks = []
-             print("Caching dataset...")
             for idx in range(len(self.data)):
                 token, vae_latent, mask = self._process(idx)
-                 self.tokens_tuple.append(token)
                 self.vae_latents.append(vae_latent)
                 self.masks.append(mask)
-             del self.vae_encoder  # Free up memory
 
      @torch.no_grad()
-     def _process(self, idx: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
-         """
-         Internal function to process images, text, and masks for a given index.
-         """
-         image_path = os.path.join(os.path.dirname(self.csv_path), self.image_path[idx])
-         image = prepare_image(Image.open(image_path).convert("RGB"), self.size, self.size).to(
              dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
          )
 
          caption = self.caption[idx]
-         ti1 = self.tokenizer_1(caption, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids
-         ti2 = self.tokenizer_2(caption, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids
 
          vae_latent = self.vae_encoder.encode(image).latent_dist.sample()
          if self.scale_vae_latents:
-             vae_latent *= self.vae_encoder.config.scaling_factor
 
          if self.mask_path is None:
-             mask = torch.ones_like(vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device)
          else:
-             mask_path = os.path.join(os.path.dirname(self.csv_path), self.mask_path[idx])
-             mask = prepare_mask(Image.open(mask_path), self.size, self.size).to(
                  dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
              )
-             mask = torch.nn.functional.interpolate(mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest")
              mask = mask.repeat(1, vae_latent.shape[1], 1, 1)
 
-         assert mask.shape == vae_latent.shape, "Mask and latent dimensions must match."
 
          return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze()
 
      def __len__(self) -> int:
          return len(self.data)
 
-     def __getitem__(self, idx: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
-         return self.atidx(idx)
 
-     def atidx(self, idx: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
-         return self._process(idx) if not self.do_cache else (self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx])
 
 
- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
-     """
-     Dynamically imports a model class based on configuration.
-     """
-     config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, revision=revision)
-     model_class = config.architectures[0]
 
      if model_class == "CLIPTextModel":
          from transformers import CLIPTextModel
          return CLIPTextModel
      elif model_class == "CLIPTextModelWithProjection":
          from transformers import CLIPTextModelWithProjection
          return CLIPTextModelWithProjection
      else:
-         raise ValueError(f"Unsupported model class: {model_class}")
 
 
  def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
      """
-     Loads required models from a given pretrained path.
      """
-     tokenizer_1 = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", revision=revision, use_fast=False)
-     tokenizer_2 = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2", revision=revision, use_fast=False)
 
-     noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
 
-     text_encoder_cls_one = import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision)
-     text_encoder_cls_two = import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision, subfolder="text_encoder_2")
 
-     text_encoder_1 = text_encoder_cls_one.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", revision=revision)
-     text_encoder_2 = text_encoder_cls_two.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision)
 
-     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
-     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", revision=revision)
 
-     for model in [vae, text_encoder_1, text_encoder_2]:
-         model.requires_grad_(False)
-         model.to(device, dtype=weight_dtype)
 
-     unet.to(device, dtype=weight_dtype)
 
-     return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet
-
+ # dataset_and_utils.py file taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
  import os
  from typing import Dict, List, Optional, Tuple
 
  from transformers import AutoTokenizer, PretrainedConfig
 
 
+ def prepare_image(
+     pil_image: PIL.Image.Image, w: int = 512, h: int = 512
+ ) -> torch.Tensor:
+     pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+     arr = np.array(pil_image.convert("RGB"))
+     arr = arr.astype(np.float32) / 127.5 - 1
+     arr = np.transpose(arr, [2, 0, 1])
+     image = torch.from_numpy(arr).unsqueeze(0)
+     return image
 
 
+ def prepare_mask(
+     pil_image: PIL.Image.Image, w: int = 512, h: int = 512
+ ) -> torch.Tensor:
+     pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+     arr = np.array(pil_image.convert("L"))
+     arr = arr.astype(np.float32) / 255.0
+     arr = np.expand_dims(arr, 0)
+     image = torch.from_numpy(arr).unsqueeze(0)
+     return image
 
 
  class PreprocessedDataset(Dataset):
      def __init__(
          self,
          size: int = 512,
          text_dropout: float = 0.0,
          scale_vae_latents: bool = True,
+         substitute_caption_map: Dict[str, str] = {},
      ):
          super().__init__()
+
          self.data = pd.read_csv(csv_path)
          self.csv_path = csv_path
 
+         self.caption = self.data["caption"]
+         # make it lowercase
+         self.caption = self.caption.str.lower()
+         for key, value in substitute_caption_map.items():
+             self.caption = self.caption.str.replace(key.lower(), value)
 
          self.image_path = self.data["image_path"]
 
+         if "mask_path" not in self.data.columns:
+             self.mask_path = None
+         else:
+             self.mask_path = self.data["mask_path"]
+
+         if text_encoder_1 is None:
+             self.return_text_embeddings = False
+         else:
              self.text_encoder_1 = text_encoder_1
              self.text_encoder_2 = text_encoder_2
              self.return_text_embeddings = True
+             assert (
+                 NotImplementedError
+             ), "Preprocessing Text Encoder is not implemented yet"
 
+         self.tokenizer_1 = tokenizer_1
+         self.tokenizer_2 = tokenizer_2
+
+         self.vae_encoder = vae_encoder
+         self.scale_vae_latents = scale_vae_latents
+         self.text_dropout = text_dropout
+
+         self.size = size
+
+         if do_cache:
              self.vae_latents = []
              self.tokens_tuple = []
              self.masks = []
+
+             self.do_cache = True
+
+             print("Captions to train on: ")
             for idx in range(len(self.data)):
                 token, vae_latent, mask = self._process(idx)
                 self.vae_latents.append(vae_latent)
+                 self.tokens_tuple.append(token)
                 self.masks.append(mask)
+
+             del self.vae_encoder
+
+         else:
+             self.do_cache = False
 
      @torch.no_grad()
+     def _process(
+         self, idx: int
+     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+         image_path = self.image_path[idx]
+         image_path = os.path.join(os.path.dirname(self.csv_path), image_path)
+
+         image = PIL.Image.open(image_path).convert("RGB")
+         image = prepare_image(image, self.size, self.size).to(
              dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
          )
 
          caption = self.caption[idx]
+
+         print(caption)
+
+         # tokenizer_1
+         ti1 = self.tokenizer_1(
+             caption,
+             padding="max_length",
+             max_length=77,
+             truncation=True,
+             add_special_tokens=True,
+             return_tensors="pt",
+         ).input_ids
+
+         ti2 = self.tokenizer_2(
+             caption,
+             padding="max_length",
+             max_length=77,
+             truncation=True,
+             add_special_tokens=True,
+             return_tensors="pt",
+         ).input_ids
 
          vae_latent = self.vae_encoder.encode(image).latent_dist.sample()
+
          if self.scale_vae_latents:
+             vae_latent = vae_latent * self.vae_encoder.config.scaling_factor
 
          if self.mask_path is None:
+             mask = torch.ones_like(
+                 vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
+             )
+
          else:
+             mask_path = self.mask_path[idx]
+             mask_path = os.path.join(os.path.dirname(self.csv_path), mask_path)
+
+             mask = PIL.Image.open(mask_path)
+             mask = prepare_mask(mask, self.size, self.size).to(
                  dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
              )
+
+             mask = torch.nn.functional.interpolate(
+                 mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest"
+             )
              mask = mask.repeat(1, vae_latent.shape[1], 1, 1)
 
+         assert len(mask.shape) == 4 and len(vae_latent.shape) == 4
 
          return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze()
 
      def __len__(self) -> int:
          return len(self.data)
 
+     def atidx(
+         self, idx: int
+     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+         if self.do_cache:
+             return self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx]
+         else:
+             return self._process(idx)
 
+     def __getitem__(
+         self, idx: int
+     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+         token, vae_latent, mask = self.atidx(idx)
+         return token, vae_latent, mask
 
 
+ def import_model_class_from_model_name_or_path(
+     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+ ):
+     text_encoder_config = PretrainedConfig.from_pretrained(
+         pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+     )
+     model_class = text_encoder_config.architectures[0]
 
      if model_class == "CLIPTextModel":
          from transformers import CLIPTextModel
+
          return CLIPTextModel
      elif model_class == "CLIPTextModelWithProjection":
          from transformers import CLIPTextModelWithProjection
+
          return CLIPTextModelWithProjection
      else:
+         raise ValueError(f"{model_class} is not supported.")
 
 
  def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
+     tokenizer_one = AutoTokenizer.from_pretrained(
+         pretrained_model_name_or_path,
+         subfolder="tokenizer",
+         revision=revision,
+         use_fast=False,
+     )
+     tokenizer_two = AutoTokenizer.from_pretrained(
+         pretrained_model_name_or_path,
+         subfolder="tokenizer_2",
+         revision=revision,
+         use_fast=False,
+     )
+
+     # Load scheduler and models
+     noise_scheduler = DDPMScheduler.from_pretrained(
+         pretrained_model_name_or_path, subfolder="scheduler"
+     )
+     # import correct text encoder classes
+     text_encoder_cls_one = import_model_class_from_model_name_or_path(
+         pretrained_model_name_or_path, revision
+     )
+     text_encoder_cls_two = import_model_class_from_model_name_or_path(
+         pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
+     )
+     text_encoder_one = text_encoder_cls_one.from_pretrained(
+         pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
+     )
+     text_encoder_two = text_encoder_cls_two.from_pretrained(
+         pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
+     )
+
+     vae = AutoencoderKL.from_pretrained(
+         pretrained_model_name_or_path, subfolder="vae", revision=revision
+     )
+     unet = UNet2DConditionModel.from_pretrained(
+         pretrained_model_name_or_path, subfolder="unet", revision=revision
+     )
+
+     vae.requires_grad_(False)
+     text_encoder_one.requires_grad_(False)
+     text_encoder_two.requires_grad_(False)
+
+     unet.to(device, dtype=weight_dtype)
+     vae.to(device, dtype=torch.float32)
+     text_encoder_one.to(device, dtype=weight_dtype)
+     text_encoder_two.to(device, dtype=weight_dtype)
+
+     return (
+         tokenizer_one,
+         tokenizer_two,
+         noise_scheduler,
+         text_encoder_one,
+         text_encoder_two,
+         vae,
+         unet,
+     )
+
+
+ def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
      """
+     Returns:
+         a state dict containing just the attention processor parameters.
      """
+     attn_processors = unet.attn_processors
+
+     attn_processors_state_dict = {}
+
+     for attn_processor_key, attn_processor in attn_processors.items():
+         for parameter_key, parameter in attn_processor.state_dict().items():
+             attn_processors_state_dict[
+                 f"{attn_processor_key}.{parameter_key}"
+             ] = parameter
 
+     return attn_processors_state_dict
 
 
+ class TokenEmbeddingsHandler:
+     def __init__(self, text_encoders, tokenizers):
+         self.text_encoders = text_encoders
+         self.tokenizers = tokenizers
+
+         self.train_ids: Optional[torch.Tensor] = None
+         self.inserting_toks: Optional[List[str]] = None
+         self.embeddings_settings = {}
+
+     def initialize_new_tokens(self, inserting_toks: List[str]):
+         idx = 0
+         for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
+             assert isinstance(
+                 inserting_toks, list
+             ), "inserting_toks should be a list of strings."
+             assert all(
+                 isinstance(tok, str) for tok in inserting_toks
+             ), "All elements in inserting_toks should be strings."
+
+             self.inserting_toks = inserting_toks
+             special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+             tokenizer.add_special_tokens(special_tokens_dict)
+             text_encoder.resize_token_embeddings(len(tokenizer))
+
+             self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+
+             # random initialization of new tokens
+
+             std_token_embedding = (
+                 text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+             )
+
+             print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}")
+
+             text_encoder.text_model.embeddings.token_embedding.weight.data[
+                 self.train_ids
+             ] = (
+                 torch.randn(
+                     len(self.train_ids), text_encoder.text_model.config.hidden_size
+                 )
+                 .to(device=self.device)
+                 .to(dtype=self.dtype)
+                 * std_token_embedding
+             )
+             self.embeddings_settings[
+                 f"original_embeddings_{idx}"
+             ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
+
+             inu = torch.ones((len(tokenizer),), dtype=torch.bool)
+             inu[self.train_ids] = False
+
+             self.embeddings_settings[f"index_no_updates_{idx}"] = inu
+
+             print(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
+
+             idx += 1
+
+     def save_embeddings(self, file_path: str):
+         assert (
+             self.train_ids is not None
+         ), "Initialize new tokens before saving embeddings."
+         tensors = {}
+         for idx, text_encoder in enumerate(self.text_encoders):
+             assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[
+                 0
+             ] == len(self.tokenizers[0]), "Tokenizers should be the same."
+             new_token_embeddings = (
+                 text_encoder.text_model.embeddings.token_embedding.weight.data[
+                     self.train_ids
+                 ]
+             )
+             tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+         save_file(tensors, file_path)
+
+     @property
+     def dtype(self):
+         return self.text_encoders[0].dtype
+
+     @property
+     def device(self):
+         return self.text_encoders[0].device
+
+     def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
+         # Assuming new tokens are of the format <s_i>
+         self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
+         special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+         tokenizer.add_special_tokens(special_tokens_dict)
+         text_encoder.resize_token_embeddings(len(tokenizer))
+
+         self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+         assert self.train_ids is not None, "New tokens could not be converted to IDs."
+         text_encoder.text_model.embeddings.token_embedding.weight.data[
+             self.train_ids
+         ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
+
+     @torch.no_grad()
+     def retract_embeddings(self):
+         for idx, text_encoder in enumerate(self.text_encoders):
+             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
+             text_encoder.text_model.embeddings.token_embedding.weight.data[
+                 index_no_updates
+             ] = (
+                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
+                 .to(device=text_encoder.device)
+                 .to(dtype=text_encoder.dtype)
+             )
+
+             # for the parts that were updated, we need to normalize them
+             # to have the same std as before
+             std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
+
+             index_updates = ~index_no_updates
+             new_embeddings = (
+                 text_encoder.text_model.embeddings.token_embedding.weight.data[
+                     index_updates
+                 ]
+             )
+             off_ratio = std_token_embedding / new_embeddings.std()
+
+             new_embeddings = new_embeddings * (off_ratio**0.1)
+             text_encoder.text_model.embeddings.token_embedding.weight.data[
+                 index_updates
+             ] = new_embeddings
+
+     def load_embeddings(self, file_path: str):
+         with safe_open(file_path, framework="pt", device=self.device.type) as f:
+             for idx in range(len(self.text_encoders)):
+                 text_encoder = self.text_encoders[idx]
+                 tokenizer = self.tokenizers[idx]
+
+                 loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
+                 self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
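
For context on how the utilities in the new file fit together, below is a minimal usage sketch (not part of the commit). The checkpoint id, device, dtype, placeholder tokens, and the layout of data.csv (a "caption" and an "image_path" column) are assumptions for illustration; the function and class names come from the file above.

# Illustrative sketch only, assuming an SDXL base checkpoint and a local data.csv.
import torch
from torch.utils.data import DataLoader

(
    tokenizer_one,
    tokenizer_two,
    noise_scheduler,
    text_encoder_one,
    text_encoder_two,
    vae,
    unet,
) = load_models(
    "stabilityai/stable-diffusion-xl-base-1.0", None, "cuda", torch.float16
)

# Register new placeholder tokens in both SDXL text encoders.
embedding_handler = TokenEmbeddingsHandler(
    [text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]
)
embedding_handler.initialize_new_tokens(["<s0>", "<s1>"])

# Pre-encode images to VAE latents and tokenize captions once (do_cache=True),
# then iterate over ((token_ids_1, token_ids_2), latent, mask) tuples.
dataset = PreprocessedDataset(
    csv_path="data.csv",
    tokenizer_1=tokenizer_one,
    tokenizer_2=tokenizer_two,
    vae_encoder=vae,
    do_cache=True,
)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
(ti1, ti2), latent, mask = next(iter(loader))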