Spaces:
Runtime error
Runtime error
""" | |
Copyright (c) 2023, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import logging | |
import os | |
import torch | |
import torch.distributed as dist | |
import torch.nn as nn | |
from torch.cuda.amp import autocast as autocast | |
from torch.nn import functional as F | |
# from lavis.common.registry import registry | |
# from lavis.models.base_model import all_gather_with_grad, concat_all_gather | |
from lavis.models.blip2_models.blip2 import ( | |
disabled_train, | |
) | |
from lavis.models.blip_models.blip_outputs import BlipOutput | |
from lavis.common.dist_utils import is_dist_avail_and_initialized | |
from model.blip2 import Blip2Base | |
from pytorch_lightning.utilities import distributed | |
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.

    Returns ``tensor`` unchanged when torch.distributed is not available or not
    initialized. Otherwise returns the tensors from every rank concatenated
    along dim 0. The receive buffers are built with ``ones_like(tensor)``, so
    every rank must pass a tensor of the same shape.
    """
    # if use distributed training
    if not is_dist_avail_and_initialized():
        return tensor
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    # Leftover debug print removed: it wrote to stdout on every call on every rank.
    return torch.cat(tensors_gather, dim=0)
def pl_concat_all_gather(tensor):
    """
    Gather ``tensor`` from every rank with Lightning's helper and concatenate
    the results along dim 0.

    Outside distributed training (torch.distributed unavailable or not
    initialized) the input tensor is returned as-is.
    *** Warning ***: torch.distributed.all_gather has no gradient, so the
    gathered copies do not propagate gradients.
    """
    if is_dist_avail_and_initialized():
        gathered = distributed.gather_all_tensors(tensor)
        return torch.cat(gathered, dim=0)
    return tensor
# @registry.register_model("blip2") | |
# @registry.register_model("blip2_feature_extractor") | |
class Blip2Qformer(Blip2Base):
    """
    BLIP-2 style first-stage model that pairs a Q-Former with a graph encoder.

    Learned query tokens cross-attend to the graph encoder's node embeddings.
    The training objectives are:

    - graph-text contrastive loss (always computed),
    - graph-text matching loss with hard negatives (when ``gtm`` is truthy),
    - query-conditioned text generation loss (when ``lm`` is truthy).

    Derived from the LAVIS BLIP-2 Q-Former; the LAVIS registry decorators are
    commented out in this file.
    """

    def __init__(
        self,
        gtm,
        lm,
        bert_name,
        temperature,
        gin_num_layers,
        gin_hidden_dim,
        gin_drop_ratio,
        tune_gnn=False,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
    ):
        """
        Args:
            gtm: enable the graph-text matching loss.
            lm: enable the text generation loss.
            bert_name: checkpoint name forwarded to ``init_Qformer``.
            temperature: divisor applied to every contrastive similarity.
            gin_num_layers, gin_hidden_dim, gin_drop_ratio: forwarded to
                ``init_graph_encoder``.
            tune_gnn: when False the graph encoder's parameters are frozen and
                the encoder is put in eval mode.
            num_query_token: number of query tokens requested from ``init_Qformer``.
            cross_attention_freq: forwarded to ``init_Qformer``.
            embed_dim: output size of the graph/text projection heads.
        """
        super().__init__()
        self.gtm = gtm
        self.lm = lm
        self.tokenizer = self.init_tokenizer()

        self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
        self.tune_gnn = tune_gnn
        if not tune_gnn:
            for param in self.graph_encoder.parameters():
                param.requires_grad = False
            self.graph_encoder = self.graph_encoder.eval()
            # NOTE(review): LAVIS' disabled_train replaces .train(), presumably so a
            # later model.train() leaves the frozen encoder in eval mode.
            self.graph_encoder.train = disabled_train
            logging.info("freeze graph encoder")

        self.Qformer, self.query_tokens = self.init_Qformer(
            bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq
        )
        self.Qformer.resize_token_embeddings(len(self.tokenizer))
        # Initialise every parameter whose name contains "_query" from the
        # parameter with the same name minus "_query".
        state_dict = self.Qformer.state_dict()
        for name, param in self.Qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        self.graph_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.gtm_head = nn.Linear(self.Qformer.config.hidden_size, 2)
        self.temperature = temperature

    def contrast(self, features_graph, features_text, return_sim=False):
        """
        Symmetric graph-text contrastive loss over the local batch.

        Args:
            features_graph: query features, shape [B, num_qs, D].
            features_text: text features, shape [B, D].
            return_sim: also return the similarity logits when True.

        Returns:
            ``loss``, or ``(logits_per_graph, logits_per_text, loss)`` with both
            logit matrices of shape [B, B].
        """
        batch_size = features_graph.size(0)
        # normalized features
        features_graph = F.normalize(features_graph, dim=-1)
        features_text = F.normalize(features_text, dim=-1)
        # [B, 1, num_qs, D] @ [B, D, 1] -> [B, B, num_qs, 1]. Squeeze only the
        # trailing singleton so B == 1 or num_qs == 1 keeps the [B, B, num_qs] layout.
        sim_q2t = (features_graph.unsqueeze(1) @ features_text.unsqueeze(-1)).squeeze(-1)
        # A (graph, text) pair is scored by its best-matching query token.
        sim_g2t, _ = sim_q2t.max(-1)  # [B, B]
        logits_per_graph = sim_g2t / self.temperature
        logits_per_text = logits_per_graph.t()
        labels = torch.arange(batch_size, dtype=torch.long, device=features_graph.device)  # size B
        loss_graph = F.cross_entropy(logits_per_graph, labels)
        loss_text = F.cross_entropy(logits_per_text, labels)
        loss = (loss_graph + loss_text) / 2
        if return_sim:
            return logits_per_graph, logits_per_text, loss
        return loss

    def contrast_global(self, features_graph, features_text, features_graph_all, features_text_all, return_sim=False):
        """
        Symmetric graph-text contrastive loss of the local batch against the
        features gathered from all ranks.

        Args:
            features_graph: local query features, [B, num_qs, D].
            features_text: local text features, [B, D].
            features_graph_all: gathered query features, [B * num_gpus, num_qs, D].
            features_text_all: gathered text features, [B * num_gpus, D].
            return_sim: also return this rank's [B, B] logit blocks when True.

        This method does not normalise, so callers pass L2-normalised features.
        When torch.distributed is not initialized the rank is taken as 0, which
        matches ``pl_concat_all_gather`` returning the local tensors unchanged.
        """
        bs = features_graph.size(0)
        # [B, 1, num_qs, D] @ [B*G, D, 1] -> [B, B*G, num_qs, 1] -> [B, B*G, num_qs]
        sim_q2t = (features_graph.unsqueeze(1) @ features_text_all.unsqueeze(-1)).squeeze(-1)
        sim_g2t, _ = sim_q2t.max(-1)  # [B, B*G]
        logits_per_graph = sim_g2t / self.temperature

        # [B, 1, 1, D] @ [B*G, D, num_qs] -> [B, B*G, 1, num_qs] -> [B, B*G, num_qs]
        sim_t2q = (features_text.unsqueeze(1).unsqueeze(1) @ features_graph_all.permute(0, 2, 1)).squeeze(-2)
        sim_t2g, _ = sim_t2q.max(-1)  # [B, B*G]
        logits_per_text = sim_t2g / self.temperature

        # This rank's positives sit at offset rank * bs in the gathered tensors.
        # NOTE(review): assumes every rank contributes exactly bs rows; uneven
        # final batches across ranks would shift these offsets.
        rank = dist.get_rank() if is_dist_avail_and_initialized() else 0
        labels = torch.arange(rank * bs, rank * bs + bs, dtype=torch.long, device=features_graph.device)
        loss_graph = F.cross_entropy(logits_per_graph, labels)
        loss_text = F.cross_entropy(logits_per_text, labels)
        loss = (loss_graph + loss_text) / 2
        if return_sim:
            local = slice(rank * bs, rank * bs + bs)
            return logits_per_graph[:, local], logits_per_text[:, local], loss
        return loss

    def _gtm_loss(self, sim_g2t, sim_t2g, graph_embeds, graph_atts, text_ids, text_atts):
        """
        Graph-text matching loss with similarity-weighted hard negatives.

        One negative graph per text and one negative text per graph are drawn
        from the local batch. The draw is weighted by softmax similarity and
        excludes the positive pair. The matching head then classifies 3B pairs:
        (graph, text) -> 1, (neg graph, text) -> 0 and (graph, neg text) -> 0.

        Args:
            sim_g2t, sim_t2g: local [B, B] similarity logits.
            graph_embeds, graph_atts: encoder states [B, N, D] and mask [B, N].
            text_ids, text_atts: token ids and attention mask, [B, L].

        With B == 1 every sampling weight is zero and torch.multinomial raises,
        so this loss needs at least two samples per device.
        """
        batch_size = graph_embeds.shape[0]
        device = text_ids.device
        with torch.no_grad():
            weights_t2g = F.softmax(sim_t2g, dim=1) + 1e-4
            weights_t2g.fill_diagonal_(0)  # never sample the positive pair
            weights_g2t = F.softmax(sim_g2t, dim=1) + 1e-4
            weights_g2t.fill_diagonal_(0)

        # select a negative graph for each text
        neg_graph_idx = [torch.multinomial(weights_t2g[b], 1).item() for b in range(batch_size)]
        graph_embeds_neg = torch.stack([graph_embeds[i] for i in neg_graph_idx], dim=0)
        graph_atts_neg = torch.stack([graph_atts[i] for i in neg_graph_idx], dim=0)
        # select a negative text for each graph
        neg_text_idx = [torch.multinomial(weights_g2t[b], 1).item() for b in range(batch_size)]
        text_ids_neg = torch.stack([text_ids[i] for i in neg_text_idx], dim=0)
        text_atts_neg = torch.stack([text_atts[i] for i in neg_text_idx], dim=0)

        text_ids_all = torch.cat([text_ids, text_ids, text_ids_neg], dim=0)  # pos, pos, neg
        text_atts_all = torch.cat([text_atts, text_atts, text_atts_neg], dim=0)
        graph_embeds_all = torch.cat([graph_embeds, graph_embeds_neg, graph_embeds], dim=0)  # pos, neg, pos
        graph_atts_all = torch.cat([graph_atts, graph_atts_neg, graph_atts], dim=0)

        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long, device=device)
        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

        output_itm = self.Qformer.bert(
            text_ids_all,
            query_embeds=query_tokens_itm,
            attention_mask=attention_mask_all,
            encoder_hidden_states=graph_embeds_all,
            encoder_attention_mask=graph_atts_all,
            return_dict=True,
        )
        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]  # keep query tokens only
        logits = self.gtm_head(vl_embeddings).mean(dim=1)  # [3B, 2]
        itm_labels = torch.cat(
            [torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)],
            dim=0,
        ).to(device)
        return F.cross_entropy(logits, itm_labels)

    def _lm_loss(self, text_ids, text_atts, prefix_atts, past_key_values):
        """
        Text generation loss: decode ``text_ids`` conditioned on cached prefix
        key/values.

        Args:
            text_ids, text_atts: target token ids and their mask, [B, L].
            prefix_atts: attention mask of the cached prefix (query tokens,
                plus prompt tokens in v3), [B, P].
            past_key_values: cache returned by the prefix ``Qformer.bert`` call.
        """
        decoder_input_ids = text_ids.clone()
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
        # Padding positions are excluded from the loss.
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == self.tokenizer.pad_token_id, -100
        )
        attention_mask = torch.cat([prefix_atts, text_atts], dim=1)
        lm_output = self.Qformer(
            decoder_input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            return_dict=True,
            labels=labels,
        )
        return lm_output.loss

    def forward_old(self, batch):
        """
        v1 training step. The contrastive loss uses the local batch only, with
        no cross-GPU gathering, and the graph encoder output is always detached.

        Args:
            batch: ``(graph, text, mask)``: graph encoder input, text token ids
                and the text attention mask.

        Returns:
            BlipOutput with the summed loss plus the contrastive, matching and
            generation terms (each 0 when disabled).
        """
        graph, text, mask = batch
        batch_node, batch_mask = self.graph_encoder(graph)
        batch_node = batch_node.detach()
        batch_node = self.ln_graph(batch_node, batch_mask)
        query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=batch_node,
            encoder_attention_mask=batch_mask,  # fixme: check whether this mask is correct
            use_cache=True,
            return_dict=True,
        )
        graph_feats = self.graph_proj(query_output.last_hidden_state)  # [B, num_q, D]
        text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True)  # [B, n_max, D]
        text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
        sim_g2t, sim_t2g, loss_gtc = self.contrast(graph_feats, text_feats, return_sim=True)

        loss_gtm = 0
        if self.gtm:
            loss_gtm = self._gtm_loss(sim_g2t, sim_t2g, batch_node, batch_mask, text, mask)

        loss_lm = 0
        if self.lm:
            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text.device)
            loss_lm = self._lm_loss(text, mask, query_atts, query_output.past_key_values)

        return BlipOutput(
            loss=loss_gtc + loss_gtm + loss_lm,
            loss_itc=loss_gtc,
            loss_itm=loss_gtm,
            loss_lm=loss_lm,
        )

    def forward(self, batch):
        """
        v2 training step. The contrastive loss compares the local batch against
        features gathered from all GPUs. Matching negatives are still drawn from
        the local batch, because encoder states differ in shape across ranks.

        Args:
            batch: ``(graph, text, mask)``: graph encoder input, text token ids
                and the text attention mask.

        Returns:
            BlipOutput with the summed loss plus the contrastive, matching and
            generation terms (each 0 when disabled).
        """
        graph, text, mask = batch
        batch_node, batch_mask = self.graph_encoder(graph)
        if not self.tune_gnn:
            batch_node = batch_node.detach()
        batch_node = self.ln_graph(batch_node, batch_mask)
        query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=batch_node,
            encoder_attention_mask=batch_mask,  # fixme: check whether this mask is correct
            use_cache=True,
            return_dict=True,
        )
        graph_feats = self.graph_proj(query_output.last_hidden_state)  # [B, num_q, D]
        text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True)  # [B, n_max, D]
        text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
        text_feats, graph_feats = F.normalize(text_feats, p=2, dim=-1), F.normalize(graph_feats, p=2, dim=-1)
        text_feats_all, graph_feats_all = pl_concat_all_gather(text_feats), pl_concat_all_gather(graph_feats)
        sim_g2t, sim_t2g, loss_gtc = self.contrast_global(
            graph_feats, text_feats, graph_feats_all, text_feats_all, return_sim=True
        )

        loss_gtm = 0
        if self.gtm:
            loss_gtm = self._gtm_loss(sim_g2t, sim_t2g, batch_node, batch_mask, text, mask)

        loss_lm = 0
        if self.lm:
            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=text.device)
            loss_lm = self._lm_loss(text, mask, query_atts, query_output.past_key_values)

        return BlipOutput(
            loss=loss_gtc + loss_gtm + loss_lm,
            loss_itc=loss_gtc,
            loss_itm=loss_gtm,
            loss_lm=loss_lm,
        )

    def forward_v3(self, batch):
        """
        v3 training step ("use smiles instruction"). It works like
        :meth:`forward`, but the query tokens also attend to an instruction
        prompt when producing graph features.

        Args:
            batch: ``(graphs, text_tokens, prompt_tokens)``. ``text_tokens`` and
                ``prompt_tokens`` expose ``input_ids`` and ``attention_mask``
                (presumably tokenizer outputs).

        Returns:
            BlipOutput with the summed loss plus the contrastive, matching and
            generation terms (each 0 when disabled).
        """
        graphs, text_tokens, prompt_tokens = batch
        graph_embeds, graph_masks = self.graph_encoder(graphs)
        if not self.tune_gnn:
            graph_embeds = graph_embeds.detach()
        graph_embeds = self.ln_graph(graph_embeds, graph_masks)
        device = text_tokens.input_ids.device

        query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=device)
        prefix_atts = torch.cat([query_atts, prompt_tokens.attention_mask], dim=1)
        query_output = self.Qformer.bert(
            input_ids=prompt_tokens.input_ids,  # previously the whole prompt_tokens container
            query_embeds=query_tokens,
            attention_mask=prefix_atts,
            encoder_hidden_states=graph_embeds,
            encoder_attention_mask=graph_masks,  # fixme: check whether this mask is correct
            use_cache=True,
            return_dict=True,
        )
        # Keep `query_output` intact: its past_key_values are needed by the LM loss.
        query_hidden = query_output.last_hidden_state[:, : query_tokens.size(1), :]  # keep query tokens only
        graph_feats = self.graph_proj(query_hidden)  # [B, num_q, D]
        text_output = self.Qformer.bert(
            text_tokens.input_ids, attention_mask=text_tokens.attention_mask, return_dict=True
        )  # [B, n_max, D]
        text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
        text_feats, graph_feats = F.normalize(text_feats, p=2, dim=-1), F.normalize(graph_feats, p=2, dim=-1)
        text_feats_all, graph_feats_all = pl_concat_all_gather(text_feats), pl_concat_all_gather(graph_feats)
        sim_g2t, sim_t2g, loss_gtc = self.contrast_global(
            graph_feats, text_feats, graph_feats_all, text_feats_all, return_sim=True
        )

        loss_gtm = 0
        if self.gtm:
            loss_gtm = self._gtm_loss(
                sim_g2t, sim_t2g, graph_embeds, graph_masks,
                text_tokens.input_ids, text_tokens.attention_mask,
            )

        loss_lm = 0
        if self.lm:
            # The cache covers query tokens + prompt tokens, so the prefix mask does too.
            loss_lm = self._lm_loss(
                text_tokens.input_ids, text_tokens.attention_mask, prefix_atts, query_output.past_key_values
            )

        return BlipOutput(
            loss=loss_gtc + loss_gtm + loss_lm,
            loss_itc=loss_gtc,
            loss_itm=loss_gtm,
            loss_lm=loss_lm,
        )

    def graph_forward(self, graph):
        """
        Encode graphs into normalised query features.

        Returns:
            ``(graph_feats, batch_node, batch_mask)``. ``graph_feats`` is the
            L2-normalised projection of the query outputs, [B, num_q, D].
            ``batch_node`` / ``batch_mask`` are the layer-normed encoder states
            and their mask, which can be reused by :meth:`compute_gtm`.
        """
        batch_node, batch_mask = self.graph_encoder(graph)
        batch_node = self.ln_graph(batch_node, batch_mask)
        query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=batch_node,
            encoder_attention_mask=batch_mask,  # fixme: check whether this mask is correct
            use_cache=False,
            return_dict=True,
        )
        graph_feats = self.graph_proj(query_output.last_hidden_state)  # [B, num_q, D]
        graph_feats = F.normalize(graph_feats, p=2, dim=-1)
        return graph_feats, batch_node, batch_mask

    def text_forward(self, text, mask):
        """Encode text and return the L2-normalised projection of its first-token state, [B, D]."""
        text_output = self.Qformer.bert(text, attention_mask=mask, return_dict=True)  # [B, n_max, D]
        text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
        return F.normalize(text_feats, dim=-1, p=2)

    def compute_gtm(self, batch_node, batch_mask, text_ids, text_atts):
        """
        Score graph-text pairs with the matching head.

        Args:
            batch_node: encoder states, [B, N, D].
            batch_mask: encoder mask, [B, N].
            text_ids: token ids, [B, L].
            text_atts: text attention mask, [B, L].

        Returns:
            [B] raw (pre-softmax) logit of the positive "match" class.
        """
        query_tokens = self.query_tokens.expand(batch_node.shape[0], -1, -1)  # [B, Nq, D]
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=batch_node.device)  # [B, Nq]
        attention_mask = torch.cat([query_atts, text_atts], dim=1)  # [B, Nq + L]
        output_gtm = self.Qformer.bert(
            text_ids,
            query_embeds=query_tokens,
            attention_mask=attention_mask,
            encoder_hidden_states=batch_node,
            encoder_attention_mask=batch_mask,
            return_dict=True,
        )
        gl_embeddings = output_gtm.last_hidden_state[:, : query_tokens.size(1), :]  # [B, Nq, D]
        gtm_logit = self.gtm_head(gl_embeddings).mean(dim=1)  # [B, 2], averaged over query tokens
        # gtm_logit = F.softmax(gtm_logit, dim=-1)[:, 1]  # (probability variant)
        return gtm_logit[:, 1]  # select the axis of the positive class