# SEED-Story: src/models/discrete_models.py
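"""Discrete visual-token models used in SEED-Story.

This module defines wrappers that combine a Q-Former, a vector quantizer, an
optional distiller (reconstruction decoder), and optional contrastive heads.
Each variant returns a dict of losses from forward() and exposes
encode_image_embeds() for inference.
"""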
import torch
import torch.nn as nn
import pyrootutils
import torch.distributed as dist
import torch.nn.functional as F
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
from src.train.dist_utils import concat_all_gather
def cosine_loss(rec, target):
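    """Cosine reconstruction loss: 1 - cosine similarity, averaged over the batch.

    Both tensors are L2-normalized along the last dimension before comparison.
    """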
target = target / target.norm(dim=-1, keepdim=True)
rec = rec / rec.norm(dim=-1, keepdim=True)
rec_loss = (1 - (target * rec).sum(-1)).mean()
return rec_loss
def contrastive_loss(image_feats, text_feats, logit_scale):
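    """In-batch image-text contrastive (ITC) loss with cross-GPU negatives.

    Features from all ranks are gathered with concat_all_gather, so each local
    sample is contrasted against batch_size * num_gpu candidates. `logit_scale`
    is used as a temperature (similarities are divided by it). Returns the
    symmetric cross-entropy loss plus image-to-text and text-to-image retrieval
    accuracies. Requires torch.distributed to be initialized.
    """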
    image_feats = image_feats.unsqueeze(1).contiguous()  # [batch_size, 1, embed_dim]
    image_feats_all = concat_all_gather(image_feats)  # [batch_size*num_gpu, 1, embed_dim]
    text_feats_all = concat_all_gather(text_feats)  # [batch_size*num_gpu, embed_dim]

    # image-to-text similarity: each local image feature against every gathered text feature
    sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feats_all.unsqueeze(-1)).squeeze()
    # [batch_size, batch_size*num_gpu]; there is a single query token, so no max/mean
    # aggregation across query tokens is needed here
    sim_i2t = sim_q2t / logit_scale  # logit_scale acts as a temperature (initialized to 0.07)

    # text-to-image similarity: each local text feature against every gathered image feature
    sim_t2q = torch.matmul(text_feats.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)).squeeze()
    sim_t2i = sim_t2q / logit_scale  # [batch_size, batch_size*num_gpu]

    # ground-truth targets: the diagonal entries belonging to this rank's local batch
    rank = dist.get_rank()
    bs = image_feats.size(0)
    targets = torch.arange(rank * bs, rank * bs + bs, dtype=torch.long, device=image_feats.device)

    loss_itc = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) +
                F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2
    i2t_acc = (sim_i2t.argmax(-1) == targets).sum() / len(sim_i2t)
    t2i_acc = (sim_t2i.argmax(-1) == targets).sum() / len(sim_t2i)
return loss_itc, i2t_acc, t2i_acc
class DiscreteModleOnlyDistill(nn.Module):
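    """Q-Former plus quantizer trained only with a distillation (reconstruction) objective.

    The quantized embeddings are decoded by `distiller` and compared against the
    input `image_embeds` with a cosine loss; the quantizer's commitment loss is
    added, weighted by `scale_commit_loss`. Note: the 'Modle' spelling in the
    class name is preserved because configs elsewhere may reference it.
    """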
def __init__(self,
qformer,
quantizer,
distiller=None,
loss_type='cosine',
scale_commit_loss=1.0,
freeze_qformer=False) -> None:
super().__init__()
self.qformer = qformer
self.quantizer = quantizer
self.distiller = distiller
self.loss_type = loss_type
self.scale_commit_loss = scale_commit_loss
self.freeze_qformer = freeze_qformer
if freeze_qformer:
self.qformer.requires_grad_(False)
def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        if self.freeze_qformer:
            with torch.no_grad():
                qformer_embeds = self.qformer(image_embeds=image_embeds)
        else:
            qformer_embeds = self.qformer(image_embeds=image_embeds)
        quantizer_output = self.quantizer(qformer_embeds)
recon_embeds = self.distiller(quantizer_output['quant_embeds'])
if self.loss_type == 'cosine':
distill_loss = cosine_loss(recon_embeds, image_embeds)
else:
raise NotImplementedError
total_loss = distill_loss + self.scale_commit_loss * \
quantizer_output['commit_loss']
return {
'total_loss': total_loss,
'distill_loss': distill_loss,
'commit_loss': quantizer_output['commit_loss'],
'indices': quantizer_output['indices']
}
def encode_image_embeds(self, image_embeds):
        qformer_embeds = self.qformer(image_embeds=image_embeds)
        quantizer_output = self.quantizer(qformer_embeds)
output_embeds = quantizer_output['quant_embeds']
if self.distiller is not None:
output_embeds = self.distiller(output_embeds)
return output_embeds
@classmethod
def from_pretrained(cls, qformer, quantizer, distiller=None, pretrained_model_path=None, **kwargs):
model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, **kwargs)
if pretrained_model_path is not None:
ckpt = torch.load(pretrained_model_path, map_location='cpu')
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
return model
class DiscreteModleIdentity(nn.Module):
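    """Pass-through model: encode_image_embeds returns its input unchanged.

    forward() returns None since there is nothing to train.
    """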
def __init__(self) -> None:
super().__init__()
self.model = nn.Identity()
def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
return
def encode_image_embeds(self, image_embeds):
return self.model(image_embeds)
class DiscreteModleStageOneContrastive(nn.Module):
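    """Stage-one model: Q-Former trained with an image-text contrastive loss only.

    Image and text features come from the Q-Former (last query token for the
    image, first token for the text), are projected to `projection_dim`,
    normalized, and contrasted. The `quantizer` and `distiller` arguments are
    accepted but not used in forward().
    """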
def __init__(self, qformer, quantizer=None, distiller=None, projection_dim=1024,
image_cls_token_type='last') -> None:
super().__init__()
self.qformer = qformer
self.quantizer = quantizer
self.distiller = distiller
self.image_cls_token_type = image_cls_token_type
self.logit_scale = nn.Parameter(0.07 * torch.ones([]))
self.image_proj = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
self.text_proj = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
image_embeds = self.qformer(image_embeds=image_embeds)
if self.image_cls_token_type == 'last':
image_embeds = image_embeds[:, -1, :]
else:
raise NotImplementedError
text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
text_embeds = text_embeds[:, 0, :]
image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)
contrast_loss, i2t_acc, t2i_acc = contrastive_loss(image_feats=image_embeds,
text_feats=text_embeds,
logit_scale=self.logit_scale)
return {
'total_loss': contrast_loss,
'i2t_acc': i2t_acc,
't2i_acc': t2i_acc,
}
def encode_image_embeds(self, image_embeds):
image_embeds = self.qformer(image_embeds=image_embeds)
return image_embeds
@classmethod
def from_pretrained(cls, qformer, quantizer, distiller=None, pretrained_model_path=None, **kwargs):
model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, **kwargs)
if pretrained_model_path is not None:
ckpt = torch.load(pretrained_model_path, map_location='cpu')
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
return model
class DiscreteModleStageTwoContrastiveDistill(nn.Module):
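    """Stage-two model: quantization on top of a (typically frozen) Q-Former.

    Always adds the quantizer's commitment loss; optionally adds a cosine
    distillation loss (if `distiller` is given) and an image-text contrastive
    loss computed by a separate `contrast_head` applied to the quantized
    embeddings (if given). Loss terms are weighted by the corresponding
    `scale_*` arguments.
    """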
def __init__(self,
qformer,
quantizer=None,
distiller=None,
contrast_head=None,
projection_dim=1024,
distill_loss_type='cosine',
freeze_qformer=True,
image_cls_token_type='last',
scale_commit_loss=1.0,
scale_contrast_loss=1.0,
scale_distill_loss=1.0) -> None:
super().__init__()
self.qformer = qformer
self.quantizer = quantizer
self.distiller = distiller
self.contrast_head = contrast_head
self.distill_loss_type = distill_loss_type
self.image_cls_token_type = image_cls_token_type
if self.contrast_head is not None:
self.logit_scale = nn.Parameter(0.07 * torch.ones([]))
self.image_proj = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
self.text_proj = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
self.freeze_qformer = freeze_qformer
if freeze_qformer:
self.qformer.requires_grad_(False)
self.scale_commit_loss = scale_commit_loss
self.scale_contrast_loss = scale_contrast_loss
self.scale_distill_loss = scale_distill_loss
def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        if self.freeze_qformer:
            with torch.no_grad():
                qformer_embeds = self.qformer(image_embeds=image_embeds)
        else:
            qformer_embeds = self.qformer(image_embeds=image_embeds)
        quantizer_output = self.quantizer(qformer_embeds)
output_state = {}
output_state['indices'] = quantizer_output['indices']
output_state['commit_loss'] = quantizer_output['commit_loss']
output_state['total_loss'] = self.scale_commit_loss * quantizer_output['commit_loss']
if self.distiller is not None:
recon_embeds = self.distiller(quantizer_output['quant_embeds'])
if self.distill_loss_type == 'cosine':
distill_loss = cosine_loss(recon_embeds, image_embeds)
else:
raise NotImplementedError
output_state['distill_loss'] = distill_loss
output_state['total_loss'] += self.scale_distill_loss * distill_loss
if self.contrast_head is not None:
text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
text_embeds = text_embeds[:, 0, :]
image_embeds = self.contrast_head(quantizer_output['quant_embeds'])
if self.image_cls_token_type == 'last':
image_embeds = image_embeds[:, -1, :]
else:
raise NotImplementedError
image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)
contrast_loss, i2t_acc, t2i_acc = contrastive_loss(image_feats=image_embeds,
text_feats=text_embeds,
logit_scale=self.logit_scale)
output_state['contrast_loss'] = contrast_loss
output_state['total_loss'] += self.scale_contrast_loss * contrast_loss
output_state['i2t_acc'] = i2t_acc
output_state['t2i_acc'] = t2i_acc
return output_state
    def encode_image_embeds(self, image_embeds):
        # not implemented for this stage; calling it returns None
        pass
@classmethod
def from_pretrained(cls, qformer, quantizer, distiller=None, contrast_head=None, pretrained_model_path=None,
**kwargs):
model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
if pretrained_model_path is not None:
ckpt = torch.load(pretrained_model_path, map_location='cpu')
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
return model
class DiscreteModleDistillWithDoubleContrastive(nn.Module):
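    """Distillation model with up to two contrastive objectives.

    One contrastive loss can be computed directly on the Q-Former's last query
    token (only when the Q-Former is not frozen and `has_contrast` is set), and
    a second one on the distiller / contrast-head output of the quantized
    embeddings. When `share_contrast_head` is True, the distiller's last output
    token doubles as the contrastive feature; `share_contrast_head` and an
    explicit `contrast_head` are mutually exclusive. `rec_qformer` selects
    whether the distillation target is the Q-Former output or the raw
    `image_embeds`.
    """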
def __init__(
self,
qformer,
quantizer=None,
distiller=None,
contrast_head=None,
projection_dim=1024,
distill_loss_type='cosine',
share_contrast_head=True, # share contrastive head with distiller
quantize_cls_token=False,
rec_qformer=False,
has_contrast=False,
freeze_qformer=False,
scale_commit_loss=1.0,
scale_contrast_loss=1.0,
scale_distill_loss=1.0) -> None:
super().__init__()
self.qformer = qformer
self.quantizer = quantizer
self.distiller = distiller
self.contrast_head = contrast_head
self.distill_loss_type = distill_loss_type
self.quantize_cls_token = quantize_cls_token
self.rec_qformer = rec_qformer
self.has_contrast = has_contrast
if freeze_qformer:
self.qformer.requires_grad_(False)
else:
self.logit_scale_qformer = nn.Parameter(0.07 * torch.ones([]))
self.image_proj_qformer = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
self.text_proj_qformer = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
self.cls_norm_qformer = nn.LayerNorm(qformer.perceiver.config.projection_dim)
if self.contrast_head is not None:
self.logit_scale_head = nn.Parameter(0.07 * torch.ones([]))
self.image_proj_head = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
self.text_proj_head = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
self.cls_norm_head = nn.LayerNorm(contrast_head.perceiver.config.projection_dim)
if share_contrast_head and distiller is not None:
self.logit_scale_head = nn.Parameter(0.07 * torch.ones([]))
self.image_proj_head = nn.Linear(distiller.perceiver.config.projection_dim, projection_dim, bias=False)
self.text_proj_head = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
self.cls_norm_head = nn.LayerNorm(distiller.perceiver.config.projection_dim)
self.scale_commit_loss = scale_commit_loss
self.scale_contrast_loss = scale_contrast_loss
self.scale_distill_loss = scale_distill_loss
self.share_contrast_head = share_contrast_head
self.freeze_qformer = freeze_qformer
assert int(self.share_contrast_head) + int(contrast_head is not None) <= 1
def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
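        """Compute the configured losses and return them in a dict.

        Possible keys: 'total_loss', 'qformer_contrast_loss', 'qformer_i2t_acc',
        'qformer_t2i_acc', 'indices', 'commit_loss', 'distill_loss',
        'head_contrast_loss', 'head_i2t_acc', 'head_t2i_acc', depending on which
        components are configured.
        """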
        if self.freeze_qformer:
            with torch.no_grad():
                qformer_embeds = self.qformer(image_embeds=image_embeds)
        else:
            qformer_embeds = self.qformer(image_embeds=image_embeds)
        qformer_cls_embeds = qformer_embeds[:, -1, :]
        if not self.quantize_cls_token:
            qformer_embeds = qformer_embeds[:, :-1, :]
if self.has_contrast:
text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
text_cls_embeds = text_embeds[:, 0, :]
output_state = {}
output_state['total_loss'] = 0.0
if not self.freeze_qformer and self.has_contrast:
            qformer_cls_embeds = self.cls_norm_qformer(qformer_cls_embeds)
            qformer_image_embeds = F.normalize(self.image_proj_qformer(qformer_cls_embeds), dim=-1)
qformer_text_embeds = F.normalize(self.text_proj_qformer(text_cls_embeds), dim=-1)
qformer_contrast_loss, \
qformer_i2t_acc, \
qformer_t2i_acc = contrastive_loss(image_feats=qformer_image_embeds,
text_feats=qformer_text_embeds,
logit_scale=self.logit_scale_qformer)
output_state['qformer_contrast_loss'] = qformer_contrast_loss
output_state['total_loss'] += self.scale_contrast_loss * qformer_contrast_loss
output_state['qformer_i2t_acc'] = qformer_i2t_acc
output_state['qformer_t2i_acc'] = qformer_t2i_acc
if self.quantizer is not None and self.distiller is not None:
            quantizer_output = self.quantizer(qformer_embeds)
recon_embeds = self.distiller(quantizer_output['quant_embeds'])
if self.share_contrast_head:
contrast_head_cls_embeds = recon_embeds[:, -1, :]
contrast_head_cls_embeds = self.cls_norm_head(contrast_head_cls_embeds)
recon_embeds = recon_embeds[:, :-1, :]
if self.contrast_head is not None:
contrast_head_embeds = self.contrast_head(quantizer_output['quant_embeds'])
contrast_head_cls_embeds = contrast_head_embeds[:, -1, :]
contrast_head_cls_embeds = self.cls_norm_head(contrast_head_cls_embeds)
output_state['indices'] = quantizer_output['indices']
output_state['commit_loss'] = quantizer_output['commit_loss']
output_state['total_loss'] += self.scale_commit_loss * quantizer_output['commit_loss']
if self.rec_qformer:
                target_embeds = qformer_embeds
else:
target_embeds = image_embeds
if self.distill_loss_type == 'cosine':
distill_loss = cosine_loss(recon_embeds, target_embeds)
else:
raise NotImplementedError
output_state['distill_loss'] = distill_loss
output_state['total_loss'] += self.scale_distill_loss * distill_loss
if self.contrast_head is not None or self.share_contrast_head:
head_image_embeds = F.normalize(self.image_proj_head(contrast_head_cls_embeds), dim=-1)
head_text_embeds = F.normalize(self.text_proj_head(text_cls_embeds), dim=-1)
head_contrast_loss, head_i2t_acc, head_t2i_acc = contrastive_loss(image_feats=head_image_embeds,
text_feats=head_text_embeds,
logit_scale=self.logit_scale_head)
output_state['head_contrast_loss'] = head_contrast_loss
output_state['total_loss'] += self.scale_contrast_loss * head_contrast_loss
output_state['head_i2t_acc'] = head_i2t_acc
output_state['head_t2i_acc'] = head_t2i_acc
return output_state
def encode_image_embeds(self, image_embeds):
        qformer_embeds = self.qformer(image_embeds=image_embeds)
        return qformer_embeds
@classmethod
def from_pretrained(cls, qformer, quantizer=None, distiller=None, contrast_head=None, pretrained_model_path=None,
**kwargs):
model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
if pretrained_model_path is not None:
ckpt = torch.load(pretrained_model_path, map_location='cpu')
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
return model
@classmethod
def from_pretrained_stage1_yuying(cls,
qformer,
quantizer=None,
distiller=None,
contrast_head=None,
pretrained_model_path=None,
**kwargs):
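        """Load a stage-one checkpoint saved in a BLIP-2-style layout.

        Remaps the 'query_tokens', 'ln_vision.*' and 'Qformer.*' keys found in
        ckpt['model'] into this module's 'qformer.*' naming before loading with
        strict=False.
        """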
model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
if pretrained_model_path is not None:
ckpt = torch.load(pretrained_model_path, map_location='cpu')
ckpt = ckpt['model']
new_ckpt = {}
new_ckpt['qformer.embed_module.query'] = ckpt['query_tokens'].squeeze(0)
new_ckpt['qformer.norm.weight'] = ckpt['ln_vision.weight']
new_ckpt['qformer.norm.bias'] = ckpt['ln_vision.bias']
for key in ckpt.keys():
if key.startswith('Qformer'):
new_key = key.replace('Qformer', 'qformer.perceiver')
new_ckpt[new_key] = ckpt[key]
del ckpt
missing, unexpected = model.load_state_dict(new_ckpt, strict=False)
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
print(missing)
print(unexpected)
return model
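

if __name__ == '__main__':
    # Minimal smoke-test sketch. The real qformer/quantizer/distiller modules are
    # built from the project's configs; the stand-ins below only mimic the call
    # signatures this file relies on, so the module names and tensor shapes here
    # are illustrative assumptions, not the actual SEED-Story components.
    class _StubQFormer(nn.Module):
        def forward(self, image_embeds=None, input_ids=None, text_attention_mask=None):
            # pass image embeddings through unchanged
            return image_embeds

    class _StubQuantizer(nn.Module):
        def forward(self, embeds):
            # return the dict interface expected by the models above
            return {
                'quant_embeds': embeds,
                'indices': torch.zeros(embeds.size(0), dtype=torch.long),
                'commit_loss': embeds.new_zeros(()),
            }

    model = DiscreteModleOnlyDistill(
        qformer=_StubQFormer(),
        quantizer=_StubQuantizer(),
        distiller=nn.Identity(),
        loss_type='cosine',
    )
    out = model(torch.randn(2, 32, 768))  # [batch, num_tokens, embed_dim], illustrative sizes
    print({k: v for k, v in out.items() if k != 'indices'})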