Singularity666 committed
Commit 267daea
Parent: afa2fb8

Update app.py

Files changed (1):
  app.py: +107 -35
app.py CHANGED
@@ -1,11 +1,19 @@
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
-from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM, CLIPVisionModel, CLIPImageProcessor
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-import os, diffusers

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
@@ -20,40 +28,50 @@ class LlavaLlamaModel(LlamaModel):

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel, self).__init__(config)
        if hasattr(config, "mm_vision_tower"):
-            self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
        if hasattr(config, "use_mm_proj"):
            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
-        if type(vision_tower) is list:
-            vision_tower = vision_tower[0]
        return vision_tower

-    def initialize_vision_modules(self, vision_tower, mm_vision_select_layer, pretrain_mm_mlp_adapter=None, fsdp=None):
        self.config.mm_vision_tower = vision_tower
        image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
        if not hasattr(self, 'vision_tower'):
            vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
        else:
-            vision_tower = self.vision_tower[0]
        vision_tower.requires_grad_(False)
-        if fsdp is not None and len(fsdp) > 0:
-            self.vision_tower = [vision_tower]
-        else:
-            self.vision_tower = vision_tower
        vision_config = vision_tower.config
        num_patches = (vision_config.image_size // vision_config.patch_size) ** 2
        self.config.use_mm_proj = True
        self.config.mm_hidden_size = vision_config.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        if not hasattr(self, 'mm_projector'):
            self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)
        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
-        return dict(image_processor=image_processor, image_token_len=num_patches, vision_config=vision_config)

    def forward(
        self,
@@ -67,9 +85,16 @@ class LlavaLlamaModel(LlamaModel):
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        vision_tower = self.get_vision_tower()
        if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
            with torch.no_grad():
@@ -92,6 +117,7 @@ class LlavaLlamaModel(LlamaModel):
                image_features = self.mm_projector(image_features)
            dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            dummy_image_features = self.mm_projector(dummy_image_features)
            new_input_embeds = []
            cur_image_idx = 0
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
@@ -133,6 +159,7 @@ class LlavaLlamaModel(LlamaModel):
                    new_input_embeds.append(cur_new_input_embeds)
                    cur_image_idx += 1
            inputs_embeds = torch.stack(new_input_embeds, dim=0)
        return super(LlavaLlamaModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
@@ -143,6 +170,7 @@ class LlavaLlamaModel(LlamaModel):
class EditMapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.llm2hid = nn.Linear(4096, 512)
        self.query = nn.Parameter(torch.randn(1, 77, 512))
        self.mapper = nn.Transformer(batch_first=True, norm_first=True,
@@ -154,6 +182,7 @@ class EditMapper(nn.Module):
        hid = self.llm2hid(llm+emb)
        hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
        feat = self.hid2feat(hid)
        return feat

class LlavaLlamaForCausalLM(LlamaForCausalLM):
@@ -162,13 +191,15 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.edit_head = EditMapper()
-        self.scheduler, self.vae, self.unet = [
-            diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
-            diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
-            diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')
-        ]
        self.vae.requires_grad_(False)
        self.unet.register_to_config(in_channels=8)
        with torch.no_grad():
@@ -176,6 +207,8 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
            conv.weight.zero_()
            conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
            self.unet.conv_in = conv
        self.post_init()

    def get_model(self):
@@ -184,8 +217,6 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
    def get_vision_tower(self):
        model = self.get_model()
        vision_tower = model.vision_tower
-        if type(vision_tower) is list:
-            vision_tower = vision_tower[0]
        return vision_tower

    def forward(
@@ -207,6 +238,8 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
@@ -218,58 +251,82 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
            return_dict=return_dict,
            images=images
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        if labels is not None:
            llm = []
            for i in range(labels.shape[0]):
-                try: p = labels[i].data.cpu().tolist().index(32003)-1
-                except: p = len(labels[i])-9
                p = min(len(hidden_states[i])-9, p)
                llm.append(hidden_states[i][p:p+8].unsqueeze(0))
            llm = torch.cat(llm, dim=0)
            hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
            B, DROP = labels.shape[0], 0.05
-            hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device), self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
            with torch.no_grad():
                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
-                lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device), torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
            noise = torch.randn_like(lat_ans)
-            ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B,), device=noise.device).long()
            lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)
            prob = torch.rand(B, device=lat_ans.device)
-            mask = (prob < (DROP*2)).reshape(B, 1, 1)
            hid_edit = torch.where(mask, hid_null, hid_edit)
-            mask = (1.0 - ((prob >= DROP).to(lat_inp.dtype) * (prob < (DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
            lat_inp *= mask
            # Progressive Feature Blending
            beta_1, beta_2 = 0.7, 0.3
-            visual_features = lat_inp
            B_1 = beta_1 * hid_edit + (1 - beta_1) * visual_features
            B_2 = beta_2 * hid_edit + (1 - beta_2) * visual_features
            # Cross-Attention Masking
            attention_scores = torch.matmul(hid_edit, hid_edit.transpose(-1, -2))
            mask = torch.zeros_like(hid_edit)
-            mask[:, 3:5] = 1.0
            masked_attention_scores = attention_scores * mask
            hid_edit = torch.matmul(F.softmax(masked_attention_scores, dim=-1), hid_edit)
            hid_edit = B_1 + B_2
            out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
            loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
-            if int(os.environ['LOCAL_RANK']) == 0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
-            loss = loss_ce + loss_edit * 0.5
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
@@ -278,13 +335,18 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
            attentions=outputs.attentions,
        )

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        if past_key_values:
            input_ids = input_ids[:, -1:]
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "past_key_values": past_key_values,
@@ -295,28 +357,37 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
        )
        return model_inputs

-    def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device, tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
        vision_config = self.get_vision_tower().config
        vision_config.use_im_start_end = mm_use_im_start_end
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))
        if mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data
-                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
-                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg
            if tune_mm_mlp_adapter:
                self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
            if pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
@@ -326,7 +397,8 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
-                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]

AutoConfig.register("llava", LlavaConfig)
 
+# mgie_llava.py
from typing import List, Optional, Tuple, Union
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+                         LlamaConfig, LlamaModel, LlamaForCausalLM, \
+                         CLIPVisionModel, CLIPImageProcessor
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
+import os
+import diffusers

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
 

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel, self).__init__(config)
+
        if hasattr(config, "mm_vision_tower"):
+            self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)
+
        if hasattr(config, "use_mm_proj"):
            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        return vision_tower

+    def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
+                                  pretrain_mm_mlp_adapter=None):
        self.config.mm_vision_tower = vision_tower
+
        image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
+
        if not hasattr(self, 'vision_tower'):
            vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
        else:
+            vision_tower = self.vision_tower
        vision_tower.requires_grad_(False)
+
+        self.vision_tower = vision_tower
+
        vision_config = vision_tower.config
        num_patches = (vision_config.image_size // vision_config.patch_size) ** 2
+
        self.config.use_mm_proj = True
        self.config.mm_hidden_size = vision_config.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
+
        if not hasattr(self, 'mm_projector'):
            self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)
+
        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
+
+        return dict(
+            image_processor=image_processor,
+            image_token_len=num_patches,
+            vision_config=vision_config
+        )
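        # Callers typically use `image_token_len` (the number of CLIP patch tokens,
        # e.g. (224 // 14) ** 2 = 256 for ViT-L/14 at 224px) to expand <image> into
        # that many <im_patch> tokens.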

    def forward(
        self,
 
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
+
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
+        if orig_embeds_params is not None:
+            orig_embeds_params = orig_embeds_params[0]
+            with torch.no_grad():
+                self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data
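            # orig_embeds_params holds a frozen copy of the token embeddings; every row except
            # the last two (the newly added <im_start>/<im_end> tokens) is reset here, so only
            # those new rows are effectively trained.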
+
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
+
        vision_tower = self.get_vision_tower()
        if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
            with torch.no_grad():
 
                image_features = self.mm_projector(image_features)
            dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            dummy_image_features = self.mm_projector(dummy_image_features)
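            # 256 x 1024 matches CLIP ViT-L/14 at 224px (16 x 16 patch tokens, hidden size 1024);
            # the dummy features presumably keep the projector in the computation graph for
            # samples without image patches.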
+
            new_input_embeds = []
            cur_image_idx = 0
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
 
                    new_input_embeds.append(cur_new_input_embeds)
                    cur_image_idx += 1
            inputs_embeds = torch.stack(new_input_embeds, dim=0)
+
        return super(LlavaLlamaModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
 
class EditMapper(nn.Module):
    def __init__(self):
        super().__init__()
+
        self.llm2hid = nn.Linear(4096, 512)
        self.query = nn.Parameter(torch.randn(1, 77, 512))
        self.mapper = nn.Transformer(batch_first=True, norm_first=True,
 
        hid = self.llm2hid(llm+emb)
        hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
        feat = self.hid2feat(hid)
+
        return feat
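        # llm is a batch of 8 LLM hidden states per sample (4096-d), taken at the positions of
        # the 8 special edit tokens (selected via token id 32003 in LlavaLlamaForCausalLM.forward),
        # and emb is the embedding of those same tokens. The 77 learned queries map them to a
        # 77-token sequence, which hid2feat presumably projects to the width expected by the
        # diffusion UNet's cross-attention.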

class LlavaLlamaForCausalLM(LlamaForCausalLM):
 
    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
+
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
        self.edit_head = EditMapper()
+
+        self.scheduler = diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler')
+        self.vae = diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae')
+        self.unet = diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')
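        # Stable Diffusion v1.5 components: DDPM noise scheduler, image VAE (frozen below),
        # and the denoising UNet that serves as the editing decoder.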
+
        self.vae.requires_grad_(False)
        self.unet.register_to_config(in_channels=8)
        with torch.no_grad():
 
            conv.weight.zero_()
            conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
            self.unet.conv_in = conv
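            # The stock SD-1.5 UNet takes 4 latent channels; widening conv_in to 8 lets it see the
            # noisy target latent concatenated with the input-image latent (InstructPix2Pix-style
            # conditioning). The extra input channels start at zero, so pretrained behaviour is
            # unchanged at initialization.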
+
+        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
 
    def get_vision_tower(self):
        model = self.get_model()
        vision_tower = model.vision_tower
        return vision_tower

    def forward(
 
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
 
            return_dict=return_dict,
            images=images
        )
+
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
+
        loss = None
        if labels is not None:
+            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
+            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
+
        if labels is not None:
            llm = []
            for i in range(labels.shape[0]):
+                try:
+                    p = labels[i].data.cpu().tolist().index(32003)-1
+                except:
+                    p = len(labels[i])-9
                p = min(len(hidden_states[i])-9, p)
                llm.append(hidden_states[i][p:p+8].unsqueeze(0))
            llm = torch.cat(llm, dim=0)
            hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
+
            B, DROP = labels.shape[0], 0.05
+
+            hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device),
+                                      self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
+
            with torch.no_grad():
                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
+                lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
+                                    torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
+
            noise = torch.randn_like(lat_ans)
+            ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
            lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)
+
            prob = torch.rand(B, device=lat_ans.device)
+            mask = (prob<(DROP*2)).reshape(B, 1, 1)
            hid_edit = torch.where(mask, hid_null, hid_edit)
+            mask = (1.0-((prob>=DROP).to(lat_inp.dtype)*(prob<(DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
            lat_inp *= mask
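            # Conditioning dropout for classifier-free guidance: the edit embedding is replaced by
            # hid_null when prob < 2*DROP, and the input-image latent is zeroed when
            # DROP <= prob < 3*DROP, so with DROP = 0.05 each condition is dropped ~10% of the
            # time and both are dropped on the overlap.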
+
            # Progressive Feature Blending
            beta_1, beta_2 = 0.7, 0.3
+            visual_features = lat_inp # Assuming lat_inp represents the visual features
            B_1 = beta_1 * hid_edit + (1 - beta_1) * visual_features
            B_2 = beta_2 * hid_edit + (1 - beta_2) * visual_features
+
            # Cross-Attention Masking
            attention_scores = torch.matmul(hid_edit, hid_edit.transpose(-1, -2))
            mask = torch.zeros_like(hid_edit)
+            mask[:, 3:5] = 1.0 # Emphasize central elements (e.g., "hat", "blue")
            masked_attention_scores = attention_scores * mask
            hid_edit = torch.matmul(F.softmax(masked_attention_scores, dim=-1), hid_edit)
+
+            # Use blended features in subsequent processing
            hid_edit = B_1 + B_2
+
            out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
+
            loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
+            if int(os.environ['LOCAL_RANK'])==0:
+                print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
+            loss = loss_ce+loss_edit*0.5
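            # Total objective: language-modeling cross-entropy plus 0.5x the diffusion denoising MSE.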
+
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
+
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
 
            attentions=outputs.attentions,
        )

+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
+
        model_inputs.update(
            {
                "past_key_values": past_key_values,
 
        )
        return model_inputs

+    def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
+                                    tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
        vision_config = self.get_vision_tower().config
        vision_config.use_im_start_end = mm_use_im_start_end
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))
+
        if mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
+
            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data
+
+                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+                    dim=0, keepdim=True)
+                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+                    dim=0, keepdim=True)
+
                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
            if tune_mm_mlp_adapter:
                self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
+
            if pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
 
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
+                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
+
        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]

AutoConfig.register("llava", LlavaConfig)
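
A minimal usage sketch of the classes above (assumptions: the checkpoint path is hypothetical, and the CLIP tower name and select layer follow common LLaVA defaults rather than anything pinned down by this commit):

import torch
from transformers import AutoTokenizer

ckpt = 'path/to/mgie-llava-checkpoint'  # hypothetical local checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = LlavaLlamaForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)  # class defined above

# Attach the CLIP vision tower and multimodal projector; -2 is the layer commonly
# selected by LLaVA-style models (an assumption, not something this diff specifies).
vision = model.get_model().initialize_vision_modules('openai/clip-vit-large-patch14',
                                                     mm_vision_select_layer=-2)
image_processor = vision['image_processor']
image_token_len = vision['image_token_len']  # 256 patch tokens for ViT-L/14 at 224px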