Singularity666 committed
Commit eae59cc
1 Parent(s): 8724709

Update app.py

Files changed (1)
  1. app.py +40 -50
app.py CHANGED
@@ -1,4 +1,3 @@
-#mgie_llava.py:
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -19,11 +18,9 @@ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
 DEFAULT_IM_START_TOKEN = "<im_start>"
 DEFAULT_IM_END_TOKEN = "<im_end>"
 
-
 class LlavaConfig(LlamaConfig):
     model_type = "llava"
 
-
 class LlavaLlamaModel(LlamaModel):
     config_class = LlavaConfig
 
@@ -31,9 +28,7 @@ class LlavaLlamaModel(LlamaModel):
         super(LlavaLlamaModel, self).__init__(config)
 
         if hasattr(config, "mm_vision_tower"):
-            # HACK: for FSDP
             self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
-            # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)
 
         if hasattr(config, "use_mm_proj"):
             self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
@@ -94,22 +89,15 @@ class LlavaLlamaModel(LlamaModel):
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
 
-        # HACK: replace back original embeddings for LLaVA pretraining
         orig_embeds_params = getattr(self, 'orig_embeds_params', None)
-        # if orig_embeds_params is not None:
-        #     orig_embeds_params = orig_embeds_params[0]
-        #     with torch.no_grad():
-        #         self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
         vision_tower = self.get_vision_tower()
         if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
-            # TODO: this is a modified multimodal LLM -- Haotian Liu
             with torch.no_grad():
                 if type(images) is list:
-                    # variable length images
                     image_features = []
                     for image in images:
                         image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
@@ -133,7 +121,6 @@ class LlavaLlamaModel(LlamaModel):
             cur_image_idx = 0
             for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                 if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
-                    # multimodal LLM, but the current sample is not multimodal
                     cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
                     new_input_embeds.append(cur_input_embeds)
                     cur_image_idx += 1
@@ -191,7 +178,7 @@ class EditMapper(nn.Module):
         self.hid2feat = nn.Linear(512, 768)
 
     def forward(self, llm, emb):
-        hid = self.llm2hid(llm+emb)
+        hid = self.llm2hid(llm + emb)
         hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
         feat = self.hid2feat(hid)
 
@@ -208,18 +195,19 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
 
         self.edit_head = EditMapper()
 
-        '''self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
-                                                  diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
-                                                  diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')]
+        self.scheduler, self.vae, self.unet = [
+            diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
+            diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
+            diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')
+        ]
         self.vae.requires_grad_(False)
         self.unet.register_to_config(in_channels=8)
         with torch.no_grad():
             conv = torch.nn.Conv2d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
             conv.weight.zero_()
             conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
-        self.unet.conv_in = conv'''
+        self.unet.conv_in = conv
 
-        # Initialize weights and apply final processing
         self.post_init()
 
     def get_model(self):
@@ -228,13 +216,6 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
     def get_vision_tower(self):
         return self.get_model().get_vision_tower()
 
-    def get_vision_tower(self):
-        model = self.get_model()
-        vision_tower = model.vision_tower
-        if type(vision_tower) is list:
-            vision_tower = vision_tower[0]
-        return vision_tower
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -250,12 +231,9 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
         p2p_inp=None, p2p_ans=None
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -273,24 +251,23 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
 
         loss = None
         if labels is not None:
-            # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
             loss_fct = CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
-            # Enable model/pipeline parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
 
         if labels is not None:
             llm = []
             for i in range(labels.shape[0]):
-                try: p = labels[i].data.cpu().tolist().index(32003)-1
-                except: p = len(labels[i])-9
-                p = min(len(hidden_states[i])-9, p)
-                llm.append(hidden_states[i][p:p+8].unsqueeze(0))
+                try:
+                    p = labels[i].data.cpu().tolist().index(32003) - 1
+                except:
+                    p = len(labels[i]) - 9
+                p = min(len(hidden_states[i]) - 9, p)
+                llm.append(hidden_states[i][p:p + 8].unsqueeze(0))
             llm = torch.cat(llm, dim=0)
             hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
 
@@ -300,25 +277,41 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                                       self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
 
             with torch.no_grad():
-                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
+                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample() * self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
                 lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
                                     torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
 
             noise = torch.randn_like(lat_ans)
-            ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
+            ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B,), device=noise.device).long()
             lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)
 
             prob = torch.rand(B, device=lat_ans.device)
-            mask = (prob<(DROP*2)).reshape(B, 1, 1)
+            mask = (prob < (DROP * 2)).reshape(B, 1, 1)
             hid_edit = torch.where(mask, hid_null, hid_edit)
-            mask = (1.0-((prob>=DROP).to(lat_inp.dtype)*(prob<(DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
+            mask = (1.0 - ((prob >= DROP).to(lat_inp.dtype) * (prob < (DROP * 3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
             lat_inp *= mask
 
+            # Progressive Feature Blending
+            beta_1, beta_2 = 0.7, 0.3
+            visual_features = lat_inp  # Assuming lat_inp represents the visual features
+            B_1 = beta_1 * hid_edit + (1 - beta_1) * visual_features
+            B_2 = beta_2 * hid_edit + (1 - beta_2) * visual_features
+
+            # Cross-Attention Masking
+            attention_scores = torch.matmul(hid_edit, hid_edit.transpose(-1, -2))
+            mask = torch.zeros_like(hid_edit)
+            mask[:, 3:5] = 1.0  # Emphasize central elements (e.g., "hat", "blue")
+            masked_attention_scores = attention_scores * mask
+            hid_edit = torch.matmul(F.softmax(masked_attention_scores, dim=-1), hid_edit)
+
+            # Use blended features in subsequent processing
+            hid_edit = B_1 + B_2
+
             out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
 
             loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
-            if int(os.environ['LOCAL_RANK'])==0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
-            loss = loss_ce+loss_edit*0.5
+            if int(os.environ.get('LOCAL_RANK', 0)) == 0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
+            loss = loss_ce + loss_edit * 0.5
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -338,7 +331,6 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
         if past_key_values:
            input_ids = input_ids[:, -1:]
 
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
@@ -370,10 +362,8 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                 input_embeddings = self.get_input_embeddings().weight.data
                 output_embeddings = self.get_output_embeddings().weight.data
 
-                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
-                    dim=0, keepdim=True)
-                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-                    dim=0, keepdim=True)
+                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
 
                 input_embeddings[-num_new_tokens:] = input_embeddings_avg
                 output_embeddings[-num_new_tokens:] = output_embeddings_avg
@@ -394,9 +384,9 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                 elif embed_tokens_weight.shape[0] == num_new_tokens:
                     input_embeddings[-num_new_tokens:] = embed_tokens_weight
                 else:
-                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
+                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
 
         vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 AutoConfig.register("llava", LlavaConfig)
-AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
+AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
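
For readers following the training hunk above: the probability-based masking applied to hid_edit and lat_inp reads like conditioning dropout, where the edit embedding is replaced by a null embedding for part of the batch and the input-image latents are zeroed for another part. Below is a standalone sketch of just that logic, not part of the commit; B and DROP are defined elsewhere in app.py, and the tensor shapes and the 0.05 value are assumptions made only for illustration.

    import torch

    B, DROP = 4, 0.05                       # batch size and dropout fraction (DROP value assumed)
    hid_edit = torch.randn(B, 8, 768)       # edit-head embeddings (shape assumed)
    hid_null = torch.zeros_like(hid_edit)   # "null" embedding used when the edit condition is dropped
    lat_inp = torch.randn(B, 4, 64, 64)     # VAE latents of the input image (shape assumed)

    prob = torch.rand(B)
    # Drop the edit-instruction conditioning when prob < 2*DROP.
    edit_mask = (prob < DROP * 2).reshape(B, 1, 1)
    hid_edit = torch.where(edit_mask, hid_null, hid_edit)
    # Zero the image conditioning when DROP <= prob < 3*DROP (the overlap drops both conditions).
    keep_image = (1.0 - ((prob >= DROP) & (prob < DROP * 3)).float()).reshape(B, 1, 1, 1)
    lat_inp = lat_inp * keep_image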
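
Separately, the final lines of the diff register the custom classes with the Transformers Auto factory. The snippet below is a minimal usage sketch, not part of the commit: it assumes the registration calls above have already executed (for example, by importing app.py), and the checkpoint path is hypothetical.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # With LlavaConfig / LlavaLlamaForCausalLM registered, a checkpoint whose config
    # declares model_type "llava" resolves to the custom classes defined in app.py.
    model = AutoModelForCausalLM.from_pretrained("path/to/mgie-checkpoint")      # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained("path/to/mgie-checkpoint")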