""" Copyright (c) 2023, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import logging import os import torch import torch.distributed as dist import torch.nn as nn from torch.cuda.amp import autocast as autocast from torch.nn import functional as F # from lavis.common.registry import registry # from lavis.models.base_model import all_gather_with_grad, concat_all_gather from lavis.models.blip2_models.blip2 import ( disabled_train, ) from lavis.models.blip_models.blip_outputs import BlipOutput from lavis.common.dist_utils import is_dist_avail_and_initialized from model.blip2 import Blip2Base from pytorch_lightning.utilities import distributed @torch.no_grad() def concat_all_gather(tensor): """ Performs all_gather operation on the provided tensors. *** Warning ***: torch.distributed.all_gather has no gradient. """ # if use distributed training if not is_dist_avail_and_initialized(): return tensor tensors_gather = [ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) ] torch.distributed.all_gather(tensors_gather, tensor, async_op=False) output = torch.cat(tensors_gather, dim=0) print('running here') return output @torch.no_grad() def pl_concat_all_gather(tensor): """ Performs all_gather operation on the provided tensors. *** Warning ***: torch.distributed.all_gather has no gradient. """ # if use distributed training if not is_dist_avail_and_initialized(): return tensor tensors_gather = distributed.gather_all_tensors(tensor) output = torch.cat(tensors_gather, dim=0) return output # @registry.register_model("blip2") # @registry.register_model("blip2_feature_extractor") class Blip2Qformer(Blip2Base): """ BLIP2 first-stage model with Q-former and ViT. Supported model types: - pretrained: pretrained model with vit-g - pretrain_vitL: pretrained model with vit-large - coco: fintuned model on coco Usage: >>> from lavis.models import load_model >>> model = load_model("blip2", "pretrain") """ def __init__( self, gtm, lm, bert_name, temperature, gin_num_layers, gin_hidden_dim, gin_drop_ratio, tune_gnn=False, num_query_token=32, cross_attention_freq=2, embed_dim=256, ): super().__init__() self.gtm = gtm self.lm = lm self.tokenizer = self.init_tokenizer() self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio) self.tune_gnn = tune_gnn if not tune_gnn: for name, param in self.graph_encoder.named_parameters(): param.requires_grad = False self.graph_encoder = self.graph_encoder.eval() self.graph_encoder.train = disabled_train logging.info("freeze graph encoder") self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq) self.Qformer.resize_token_embeddings(len(self.tokenizer)) state_dict = self.Qformer.state_dict() for name, param in self.Qformer.named_parameters(): if "_query" in name: key_orig = name.replace("_query", "") param.data.copy_(state_dict[key_orig]) self.graph_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim) self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim) self.gtm_head = nn.Linear(self.Qformer.config.hidden_size, 2) self.temperature = temperature def contrast(self, features_graph, features_text, return_sim=False): ''' features_graph: shape = [B, num_qs, D] features_text: shape = [B, D] ''' batch_size = features_graph.size(0) # normalized features features_graph = F.normalize(features_graph, dim=-1) features_text = F.normalize(features_text, dim=-1) # cosine similarity as logits sim_q2t = (features_graph.unsqueeze(1) @ features_text.unsqueeze(-1)).squeeze() # shape = [B, 1, num_qs, D]; shape = [B, D, 1]; output shape = [B, B, num_qs] sim_g2t, _ = sim_q2t.max(-1) # shape = [B, B] logits_per_graph = sim_g2t / self.temperature logits_per_text = logits_per_graph.t() labels = torch.arange(batch_size, dtype=torch.long, device=self.device) # 大小为B loss_graph = F.cross_entropy(logits_per_graph, labels) loss_text = F.cross_entropy(logits_per_text, labels) loss = (loss_graph + loss_text) / 2 if return_sim: return logits_per_graph, logits_per_text, loss else: return loss def contrast_global(self, features_graph, features_text, features_graph_all, features_text_all, return_sim=False): ''' features_graph: shape = [B, num_qs, D] features_text: shape = [B, D] features_text_all: shape = [B * num_gpus, D] features_graph_all: shape = [B * num_gpus, num_qs, D] ''' bs = features_graph.size(0) # cosine similarity as logits sim_q2t = (features_graph.unsqueeze(1) @ features_text_all.unsqueeze(-1)).squeeze() # shape = [B, 1, num_qs, D]; shape = [B * num_gpus, D, 1]; output shape = [B, B * num_gpus, num_qs] sim_g2t, _ = sim_q2t.max(-1) # shape = [B, B * num_gpus] logits_per_graph = sim_g2t / self.temperature sim_t2q = (features_text.unsqueeze(1).unsqueeze(1) @ features_graph_all.permute(0, 2, 1)).squeeze() # shape = [B, 1, 1, D]; [B*num_gpus, D, num_qs]; output shape = [B, B*num_gpus, 1, num_qs] sim_t2g, _ = sim_t2q.max(-1) logits_per_text = sim_t2g / self.temperature # labels = torch.arange(bs, dtype=torch.long, device=self.device) rank = dist.get_rank() labels = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(self.device) loss_graph = F.cross_entropy(logits_per_graph, labels) loss_text = F.cross_entropy(logits_per_text, labels) loss = (loss_graph + loss_text) / 2 if return_sim: return logits_per_graph[:, rank*bs:rank*bs+bs], logits_per_text[:, rank*bs:rank*bs+bs], loss else: return loss def forward_old(self, batch): ## v1: not gather results from all gpus ###============== Image-text Contrastive ===================### graph, text, mask = batch batch_node, batch_mask = self.graph_encoder(graph) batch_node = batch_node.detach() batch_size = batch_node.shape[0] batch_node = self.ln_graph(batch_node, batch_mask) query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1) query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=batch_node, encoder_attention_mask=batch_mask, # fixme: check whether this mask is correct use_cache=True, return_dict=True, ) graph_feats = self.graph_proj(query_output.last_hidden_state) # shape = [B, num_q, D] text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True) # shape = [B, n_max, D] text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :]) sim_g2t, sim_t2g, loss_gtc = self.contrast(graph_feats, text_feats, return_sim=True) ###============== Image-text Matching ===================### loss_gtm = 0 if self.gtm: g_emb = batch_node g_mask = batch_mask text_ids = text.clone() with torch.no_grad(): weights_t2g = F.softmax(sim_t2g, dim=1) + 1e-4 weights_t2g.fill_diagonal_(0) weights_g2t = F.softmax(sim_g2t, dim=1) + 1e-4 weights_g2t.fill_diagonal_(0) # select a negative graph for each text graph_embeds_neg = [] graph_mask_neg = [] for b in range(batch_size): neg_idx = torch.multinomial(weights_t2g[b], 1).item() graph_embeds_neg.append(g_emb[neg_idx]) graph_mask_neg.append(g_mask[neg_idx]) graph_embeds_neg = torch.stack(graph_embeds_neg, dim=0) graph_mask_neg = torch.stack(graph_mask_neg, dim=0) # select a negative text for each image text_ids_neg = [] text_atts_neg = [] for b in range(batch_size): neg_idx = torch.multinomial(weights_g2t[b], 1).item() text_ids_neg.append(text_ids[neg_idx]) text_atts_neg.append(mask[neg_idx]) text_ids_neg = torch.stack(text_ids_neg, dim=0) text_atts_neg = torch.stack(text_atts_neg, dim=0) text_ids_all = torch.cat( [text_ids, text_ids, text_ids_neg], dim=0 ) # pos, pos, neg text_atts_all = torch.cat( [mask, mask, text_atts_neg], dim=0, ) query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1) query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long, device=text.device) attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1) graph_embeds_all = torch.cat([g_emb, graph_embeds_neg, g_emb], dim=0) # pos, neg, pos graph_atts_all = torch.cat([g_mask, graph_mask_neg, g_mask], dim=0) output_itm = self.Qformer.bert( text_ids_all, query_embeds=query_tokens_itm, attention_mask=attention_mask_all, encoder_hidden_states=graph_embeds_all, encoder_attention_mask=graph_atts_all, return_dict=True, ) vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] # keep query tokens only vl_output = self.gtm_head(vl_embeddings) logits = vl_output.mean(dim=1) itm_labels = torch.cat( [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)], dim=0, ).to(text.device) loss_gtm = F.cross_entropy(logits, itm_labels) ##================= Image Captioning ========================## loss_lm = 0 if self.lm: decoder_input_ids = text.clone() decoder_input_ids[:, 0] = self.tokenizer.bos_token_id labels = decoder_input_ids.masked_fill( decoder_input_ids == self.tokenizer.pad_token_id, -100 ) query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text.device) attention_mask = torch.cat([query_atts, mask], dim=1) lm_output = self.Qformer( decoder_input_ids, attention_mask=attention_mask, past_key_values=query_output.past_key_values, return_dict=True, labels=labels, ) loss_lm = lm_output.loss return BlipOutput( loss=loss_gtc + loss_gtm + loss_lm, loss_itc=loss_gtc, loss_itm=loss_gtm, loss_lm=loss_lm, ) def forward(self, batch): ## v2: gather results from all gpus ###============== Image-text Contrastive ===================### graph, text, mask = batch batch_node, batch_mask = self.graph_encoder(graph) if not self.tune_gnn: batch_node = batch_node.detach() batch_size = batch_node.shape[0] batch_node = self.ln_graph(batch_node, batch_mask) query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1) query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=batch_node, encoder_attention_mask=batch_mask, # fixme: check whether this mask is correct use_cache=True, return_dict=True, ) graph_feats = self.graph_proj(query_output.last_hidden_state) # shape = [B, num_q, D] text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True) # shape = [B, n_max, D] text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :]) text_feats, graph_feats = F.normalize(text_feats, p=2, dim=-1), F.normalize(graph_feats, p=2, dim=-1) text_feats_all, graph_feats_all = pl_concat_all_gather(text_feats), pl_concat_all_gather(graph_feats) # shape = [B * num_gpus, D] sim_g2t, sim_t2g, loss_gtc = self.contrast_global(graph_feats, text_feats, graph_feats_all, text_feats_all, return_sim=True) ###============== Image-text Matching ===================### loss_gtm = 0 if self.gtm: ## not aggregate global tensor because of their different shapes g_emb_world = batch_node g_mask_world = batch_mask text_ids_world = text text_mask_world = mask with torch.no_grad(): weights_t2g = F.softmax(sim_t2g, dim=1) + 1e-4 weights_t2g.fill_diagonal_(0) weights_g2t = F.softmax(sim_g2t, dim=1) + 1e-4 weights_g2t.fill_diagonal_(0) # select a negative graph for each text graph_embeds_neg = [] graph_mask_neg = [] for b in range(batch_size): neg_idx = torch.multinomial(weights_t2g[b], 1).item() graph_embeds_neg.append(g_emb_world[neg_idx]) graph_mask_neg.append(g_mask_world[neg_idx]) graph_embeds_neg = torch.stack(graph_embeds_neg, dim=0) graph_mask_neg = torch.stack(graph_mask_neg, dim=0) # select a negative text for each image text_ids_neg = [] text_atts_neg = [] for b in range(batch_size): neg_idx = torch.multinomial(weights_g2t[b], 1).item() text_ids_neg.append(text_ids_world[neg_idx]) text_atts_neg.append(text_mask_world[neg_idx]) text_ids_neg = torch.stack(text_ids_neg, dim=0) text_atts_neg = torch.stack(text_atts_neg, dim=0) text_ids_all = torch.cat( [text, text, text_ids_neg], dim=0 ) # pos, pos, neg text_atts_all = torch.cat( [mask, mask, text_atts_neg], dim=0, ) query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1) query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long, device=text.device) attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1) graph_embeds_all = torch.cat([batch_node, graph_embeds_neg, batch_node], dim=0) # pos, neg, pos graph_atts_all = torch.cat([batch_mask, graph_mask_neg, batch_mask], dim=0) output_itm = self.Qformer.bert( text_ids_all, query_embeds=query_tokens_itm, attention_mask=attention_mask_all, encoder_hidden_states=graph_embeds_all, encoder_attention_mask=graph_atts_all, return_dict=True, ) vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] # keep query tokens only vl_output = self.gtm_head(vl_embeddings) logits = vl_output.mean(dim=1) itm_labels = torch.cat( [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)], dim=0, ).to(text.device) loss_gtm = F.cross_entropy(logits, itm_labels) ##================= Image Captioning ========================## loss_lm = 0 if self.lm: decoder_input_ids = text.clone() decoder_input_ids[:, 0] = self.tokenizer.bos_token_id labels = decoder_input_ids.masked_fill( decoder_input_ids == self.tokenizer.pad_token_id, -100 ) query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text.device) attention_mask = torch.cat([query_atts, mask], dim=1) lm_output = self.Qformer( decoder_input_ids, attention_mask=attention_mask, past_key_values=query_output.past_key_values, return_dict=True, labels=labels, ) loss_lm = lm_output.loss return BlipOutput( loss=loss_gtc + loss_gtm + loss_lm, loss_itc=loss_gtc, loss_itm=loss_gtm, loss_lm=loss_lm, ) def forward_v3(self, batch): ## v3: use smiles instruction ###============== Image-text Contrastive ===================### graphs, text_tokens, prompt_tokens = batch graph_embeds, graph_masks = self.graph_encoder(graphs) if not self.tune_gnn: graph_embeds = graph_embeds.detach() graph_embeds = self.ln_graph(graph_embeds, graph_masks) device = text_tokens.input_ids.device batch_size = graph_embeds.shape[0] ## query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1) query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=device) attention_mask_gtc = torch.cat([query_atts, prompt_tokens.attention_mask], dim=1) query_output = self.Qformer.bert( input_ids=prompt_tokens, query_embeds=query_tokens, attention_mask=attention_mask_gtc, encoder_hidden_states=graph_embeds, encoder_attention_mask=graph_masks, # fixme: check whether this mask is correct use_cache=True, return_dict=True, ) query_output = query_output.last_hidden_state[:, : query_tokens.size(1), :] # keep query tokens only graph_feats = self.graph_proj(query_output) # shape = [B, num_q, D] text_output = self.Qformer.bert(text_tokens.input_ids, attention_mask=text_tokens.attention_mask, return_dict=True) # shape = [B, n_max, D] text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :]) text_feats, graph_feats = F.normalize(text_feats, p=2, dim=-1), F.normalize(graph_feats, p=2, dim=-1) text_feats_all, graph_feats_all = pl_concat_all_gather(text_feats), pl_concat_all_gather(graph_feats) # shape = [B * num_gpus, D] sim_g2t, sim_t2g, loss_gtc = self.contrast_global(graph_feats, text_feats, graph_feats_all, text_feats_all, return_sim=True) ###============== Image-text Matching ===================### loss_gtm = 0 if self.gtm: ## not aggregate global tensor because of their different shapes g_emb_world = graph_embeds g_mask_world = graph_masks text_ids_world = text_tokens.input_ids text_mask_world = text_tokens.attention_mask with torch.no_grad(): weights_t2g = F.softmax(sim_t2g, dim=1) + 1e-4 weights_t2g.fill_diagonal_(0) weights_g2t = F.softmax(sim_g2t, dim=1) + 1e-4 weights_g2t.fill_diagonal_(0) # select a negative graph for each text graph_embeds_neg = [] graph_mask_neg = [] for b in range(batch_size): neg_idx = torch.multinomial(weights_t2g[b], 1).item() graph_embeds_neg.append(g_emb_world[neg_idx]) graph_mask_neg.append(g_mask_world[neg_idx]) graph_embeds_neg = torch.stack(graph_embeds_neg, dim=0) graph_mask_neg = torch.stack(graph_mask_neg, dim=0) # select a negative text for each image text_ids_neg = [] text_atts_neg = [] for b in range(batch_size): neg_idx = torch.multinomial(weights_g2t[b], 1).item() text_ids_neg.append(text_ids_world[neg_idx]) text_atts_neg.append(text_mask_world[neg_idx]) text_ids_neg = torch.stack(text_ids_neg, dim=0) text_atts_neg = torch.stack(text_atts_neg, dim=0) text_ids_all = torch.cat( [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0 ) # pos, pos, neg text_atts_all = torch.cat( [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg], dim=0, ) query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1) query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long, device=text_tokens.input_ids.device) attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1) graph_embeds_all = torch.cat([graph_embeds, graph_embeds_neg, graph_embeds], dim=0) # pos, neg, pos graph_atts_all = torch.cat([graph_masks, graph_mask_neg, graph_masks], dim=0) output_itm = self.Qformer.bert( text_ids_all, query_embeds=query_tokens_itm, attention_mask=attention_mask_all, encoder_hidden_states=graph_embeds_all, encoder_attention_mask=graph_atts_all, return_dict=True, ) vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] # keep query tokens only vl_output = self.gtm_head(vl_embeddings) logits = vl_output.mean(dim=1) itm_labels = torch.cat( [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)], dim=0, ).to(text_tokens.input_ids.device) loss_gtm = F.cross_entropy(logits, itm_labels) ##================= Image Captioning ========================## loss_lm = 0 if self.lm: decoder_input_ids = text_tokens.input_ids.clone() decoder_input_ids[:, 0] = self.tokenizer.bos_token_id labels = decoder_input_ids.masked_fill( decoder_input_ids == self.tokenizer.pad_token_id, -100 ) query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text_tokens.input_ids.device) attention_mask = torch.cat([query_atts, prompt_tokens.attention_mask, text_tokens.attention_mask], dim=1) lm_output = self.Qformer( decoder_input_ids, attention_mask=attention_mask, past_key_values=query_output.past_key_values, return_dict=True, labels=labels, ) loss_lm = lm_output.loss return BlipOutput( loss=loss_gtc + loss_gtm + loss_lm, loss_itc=loss_gtc, loss_itm=loss_gtm, loss_lm=loss_lm, ) def graph_forward(self, graph): batch_node, batch_mask = self.graph_encoder(graph) batch_node = self.ln_graph(batch_node, batch_mask) query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1) query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=batch_node, encoder_attention_mask=batch_mask, # fixme: check whether this mask is correct use_cache=False, return_dict=True, ) graph_feats = self.graph_proj(query_output.last_hidden_state) # shape = [B, num_q, D] graph_feats = F.normalize(graph_feats, p=2, dim=-1) return graph_feats, batch_node, batch_mask def text_forward(self, text, mask): text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True) # shape = [B, n_max, D] text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :] ) text_feats = F.normalize(text_feats, dim=-1, p=2) return text_feats def compute_gtm(self, batch_node, batch_mask, text_ids, text_atts): ''' batch_node shape = [B, N, D] batch_mask shape = [B, N] text_ids shape = [B, N] text_atts shape = [B, N] ''' query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1) # shape = [B, Nq, D] query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( batch_node.device ) # shape = [B, Nq] attention_mask = torch.cat([query_atts, text_atts], dim=1) # shape = [B, Nq + N] output_gtm = self.Qformer.bert( text_ids, query_embeds=query_tokens, attention_mask=attention_mask, encoder_hidden_states=batch_node, encoder_attention_mask=batch_mask, return_dict=True, ) gl_embeddings = output_gtm.last_hidden_state[:, : query_tokens.size(1), :] # shape = [B, Nq, D] gtm_logit = self.gtm_head(gl_embeddings).mean(dim=1) # shape = [B, Nq, 2] # gtm_logit = F.softmax(gtm_logit, dim=-1)[:, 1] # select the axis of the positive class gtm_logit = gtm_logit[:, 1] # select the axis of the positive class return gtm_logit