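"""Discrete image tokenization models: Q-Former + vector-quantizer + distiller wrappers,
trained with cosine distillation and in-batch image-text contrastive losses."""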
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import pyrootutils

pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)

from src.train.dist_utils import concat_all_gather


def cosine_loss(rec, target):
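    """Cosine reconstruction loss: 1 - cos(rec, target) over the last dim, averaged over all tokens."""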
    target = target / target.norm(dim=-1, keepdim=True)
    rec = rec / rec.norm(dim=-1, keepdim=True)
    rec_loss = (1 - (target * rec).sum(-1)).mean()
    return rec_loss


def contrastive_loss(image_feats, text_feats, logit_scale):
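    """In-batch image-text contrastive (InfoNCE) loss with features gathered across GPUs.

    Expects `image_feats` and `text_feats` of shape [batch_size, embed_dim] and a scalar
    `logit_scale` used as a temperature (similarities are divided by it). Targets are offset
    by the process rank so each GPU matches its local batch against the gathered features;
    torch.distributed must be initialized.
    """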
    image_feats = image_feats.unsqueeze(1).contiguous()
    image_feats_all = concat_all_gather(image_feats)  # [batch_size*num_gpu, 1, embed_dim]
    text_feats_all = concat_all_gather(text_feats)  # [batch_size*num_gpu, embed_dim]

    # image-to-text similarity: [batch_size, batch_size*num_gpu]
    sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feats_all.unsqueeze(-1)).squeeze()
    # with multiple query tokens the similarity could be aggregated instead, e.g.:
    # sim_i2t, _ = sim_q2t.max(-1)
    # sim_i2t = sim_q2t.mean(-1)
    sim_i2t = sim_q2t / logit_scale

    # text-to-image similarity: [batch_size, batch_size*num_gpu]
    sim_t2q = torch.matmul(text_feats.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)).squeeze()
    # sim_t2i, _ = sim_t2q.max(-1)
    # sim_t2i = sim_t2q.mean(-1)
    sim_t2i = sim_t2q / logit_scale

    # each rank matches its local batch against the globally gathered features
    rank = dist.get_rank()
    bs = image_feats.size(0)
    targets = torch.arange(rank * bs, rank * bs + bs, dtype=torch.long, device=image_feats.device)

    loss_itc = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) +
                F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2
    i2t_acc = (sim_i2t.argmax(-1) == targets).sum() / len(sim_i2t)
    t2i_acc = (sim_t2i.argmax(-1) == targets).sum() / len(sim_t2i)
    return loss_itc, i2t_acc, t2i_acc


class DiscreteModleOnlyDistill(nn.Module):
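    """Q-Former + vector quantizer trained with a distillation objective only.

    The quantized Q-Former tokens are decoded by `distiller` and pulled towards the
    original `image_embeds` with a cosine loss; the quantizer's commitment loss is
    added with weight `scale_commit_loss`.
    """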
    def __init__(self,
                 qformer,
                 quantizer,
                 distiller=None,
                 loss_type='cosine',
                 scale_commit_loss=1.0,
                 freeze_qformer=False) -> None:
        super().__init__()
        self.qformer = qformer
        self.quantizer = quantizer
        self.distiller = distiller
        self.loss_type = loss_type
        self.scale_commit_loss = scale_commit_loss
        self.freeze_qformer = freeze_qformer
        if freeze_qformer:
            self.qformer.requires_grad_(False)
    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        if self.freeze_qformer:
            with torch.no_grad():
                qformer_embeds = self.qformer(image_embeds=image_embeds)
        else:
            qformer_embeds = self.qformer(image_embeds=image_embeds)

        quantizer_output = self.quantizer(qformer_embeds)
        recon_embeds = self.distiller(quantizer_output['quant_embeds'])

        if self.loss_type == 'cosine':
            distill_loss = cosine_loss(recon_embeds, image_embeds)
        else:
            raise NotImplementedError

        total_loss = distill_loss + self.scale_commit_loss * quantizer_output['commit_loss']
        return {
            'total_loss': total_loss,
            'distill_loss': distill_loss,
            'commit_loss': quantizer_output['commit_loss'],
            'indices': quantizer_output['indices']
        }
    def encode_image_embeds(self, image_embeds):
        qformer_embeds = self.qformer(image_embeds=image_embeds)
        quantizer_output = self.quantizer(qformer_embeds)
        output_embeds = quantizer_output['quant_embeds']
        if self.distiller is not None:
            output_embeds = self.distiller(output_embeds)
        return output_embeds
    @classmethod
    def from_pretrained(cls, qformer, quantizer, distiller=None, pretrained_model_path=None, **kwargs):
        model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys:', len(missing), 'unexpected keys:', len(unexpected))
        return model


class DiscreteModleIdentity(nn.Module):
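    """Pass-through model: no quantization and no training losses; `encode_image_embeds` returns its input unchanged."""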
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Identity()

    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        # nothing to train for the identity model
        return None

    def encode_image_embeds(self, image_embeds):
        return self.model(image_embeds)


class DiscreteModleStageOneContrastive(nn.Module):
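    """Stage-one model: trains the Q-Former with an image-text contrastive loss only.

    The image representation is the last query token of the Q-Former output (CLS-like),
    the text representation is the first token of the text branch; both are projected
    to `projection_dim` and compared with `contrastive_loss`.
    """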
    def __init__(self, qformer, quantizer=None, distiller=None, projection_dim=1024,
                 image_cls_token_type='last') -> None:
        super().__init__()
        self.qformer = qformer
        self.quantizer = quantizer
        self.distiller = distiller
        self.image_cls_token_type = image_cls_token_type
        self.logit_scale = nn.Parameter(0.07 * torch.ones([]))
        self.image_proj = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
        self.text_proj = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        image_embeds = self.qformer(image_embeds=image_embeds)
        if self.image_cls_token_type == 'last':
            image_embeds = image_embeds[:, -1, :]
        else:
            raise NotImplementedError
        text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
        text_embeds = text_embeds[:, 0, :]

        image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
        text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)
        contrast_loss, i2t_acc, t2i_acc = contrastive_loss(image_feats=image_embeds,
                                                           text_feats=text_embeds,
                                                           logit_scale=self.logit_scale)
        return {
            'total_loss': contrast_loss,
            'i2t_acc': i2t_acc,
            't2i_acc': t2i_acc,
        }
    def encode_image_embeds(self, image_embeds):
        image_embeds = self.qformer(image_embeds=image_embeds)
        return image_embeds
    @classmethod
    def from_pretrained(cls, qformer, quantizer, distiller=None, pretrained_model_path=None, **kwargs):
        model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys:', len(missing), 'unexpected keys:', len(unexpected))
        return model


class DiscreteModleStageTwoContrastiveDistill(nn.Module):
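    """Stage-two model: quantizes the (optionally frozen) Q-Former output.

    Trains with the quantizer's commitment loss, an optional cosine distillation loss back
    to `image_embeds`, and an optional contrastive loss computed from a separate
    `contrast_head` applied to the quantized tokens.
    """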
    def __init__(self,
                 qformer,
                 quantizer=None,
                 distiller=None,
                 contrast_head=None,
                 projection_dim=1024,
                 distill_loss_type='cosine',
                 freeze_qformer=True,
                 image_cls_token_type='last',
                 scale_commit_loss=1.0,
                 scale_contrast_loss=1.0,
                 scale_distill_loss=1.0) -> None:
        super().__init__()
        self.qformer = qformer
        self.quantizer = quantizer
        self.distiller = distiller
        self.contrast_head = contrast_head
        self.distill_loss_type = distill_loss_type
        self.image_cls_token_type = image_cls_token_type
        if self.contrast_head is not None:
            self.logit_scale = nn.Parameter(0.07 * torch.ones([]))
            self.image_proj = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
            self.text_proj = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
        self.freeze_qformer = freeze_qformer
        if freeze_qformer:
            self.qformer.requires_grad_(False)
        self.scale_commit_loss = scale_commit_loss
        self.scale_contrast_loss = scale_contrast_loss
        self.scale_distill_loss = scale_distill_loss
    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        if self.freeze_qformer:
            with torch.no_grad():
                qformer_embeds = self.qformer(image_embeds=image_embeds)
        else:
            qformer_embeds = self.qformer(image_embeds=image_embeds)

        quantizer_output = self.quantizer(qformer_embeds)
        output_state = {}
        output_state['indices'] = quantizer_output['indices']
        output_state['commit_loss'] = quantizer_output['commit_loss']
        output_state['total_loss'] = self.scale_commit_loss * quantizer_output['commit_loss']

        if self.distiller is not None:
            recon_embeds = self.distiller(quantizer_output['quant_embeds'])
            if self.distill_loss_type == 'cosine':
                distill_loss = cosine_loss(recon_embeds, image_embeds)
            else:
                raise NotImplementedError
            output_state['distill_loss'] = distill_loss
            output_state['total_loss'] += self.scale_distill_loss * distill_loss

        if self.contrast_head is not None:
            text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
            text_embeds = text_embeds[:, 0, :]
            image_embeds = self.contrast_head(quantizer_output['quant_embeds'])
            if self.image_cls_token_type == 'last':
                image_embeds = image_embeds[:, -1, :]
            else:
                raise NotImplementedError
            image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
            text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)
            contrast_loss, i2t_acc, t2i_acc = contrastive_loss(image_feats=image_embeds,
                                                               text_feats=text_embeds,
                                                               logit_scale=self.logit_scale)
            output_state['contrast_loss'] = contrast_loss
            output_state['total_loss'] += self.scale_contrast_loss * contrast_loss
            output_state['i2t_acc'] = i2t_acc
            output_state['t2i_acc'] = t2i_acc
        return output_state
    def encode_image_embeds(self, image_embeds):
        # not implemented for this stage-two model
        pass
    @classmethod
    def from_pretrained(cls, qformer, quantizer, distiller=None, contrast_head=None, pretrained_model_path=None,
                        **kwargs):
        model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys:', len(missing), 'unexpected keys:', len(unexpected))
        return model


class DiscreteModleDistillWithDoubleContrastive(nn.Module):
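    """Quantized Q-Former with up to two contrastive objectives plus distillation.

    A contrastive loss can be applied to the Q-Former CLS token itself (when the Q-Former
    is trainable) and to a CLS token produced after quantization, either by a dedicated
    `contrast_head` or by reusing the `distiller` output (`share_contrast_head=True`).
    The distillation target is either the raw `image_embeds` or the Q-Former output
    (`rec_qformer=True`).
    """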
    def __init__(
            self,
            qformer,
            quantizer=None,
            distiller=None,
            contrast_head=None,
            projection_dim=1024,
            distill_loss_type='cosine',
            share_contrast_head=True,  # share the contrastive head with the distiller
            quantize_cls_token=False,
            rec_qformer=False,
            has_contrast=False,
            freeze_qformer=False,
            scale_commit_loss=1.0,
            scale_contrast_loss=1.0,
            scale_distill_loss=1.0) -> None:
        super().__init__()
        self.qformer = qformer
        self.quantizer = quantizer
        self.distiller = distiller
        self.contrast_head = contrast_head
        self.distill_loss_type = distill_loss_type
        self.quantize_cls_token = quantize_cls_token
        self.rec_qformer = rec_qformer
        self.has_contrast = has_contrast

        if freeze_qformer:
            self.qformer.requires_grad_(False)
        else:
            self.logit_scale_qformer = nn.Parameter(0.07 * torch.ones([]))
            self.image_proj_qformer = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
            self.text_proj_qformer = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
            self.cls_norm_qformer = nn.LayerNorm(qformer.perceiver.config.projection_dim)

        if self.contrast_head is not None:
            self.logit_scale_head = nn.Parameter(0.07 * torch.ones([]))
            self.image_proj_head = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
            self.text_proj_head = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
            self.cls_norm_head = nn.LayerNorm(contrast_head.perceiver.config.projection_dim)

        if share_contrast_head and distiller is not None:
            self.logit_scale_head = nn.Parameter(0.07 * torch.ones([]))
            self.image_proj_head = nn.Linear(distiller.perceiver.config.projection_dim, projection_dim, bias=False)
            self.text_proj_head = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
            self.cls_norm_head = nn.LayerNorm(distiller.perceiver.config.projection_dim)

        self.scale_commit_loss = scale_commit_loss
        self.scale_contrast_loss = scale_contrast_loss
        self.scale_distill_loss = scale_distill_loss
        self.share_contrast_head = share_contrast_head
        self.freeze_qformer = freeze_qformer
        # a dedicated contrast head and a shared (distiller) contrast head are mutually exclusive
        assert int(self.share_contrast_head) + int(contrast_head is not None) <= 1
    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        if self.freeze_qformer:
            with torch.no_grad():
                qformer_embeds = self.qformer(image_embeds=image_embeds)
        else:
            qformer_embeds = self.qformer(image_embeds=image_embeds)

        qformer_cls_embeds = qformer_embeds[:, -1, :]
        if not self.quantize_cls_token:
            qformer_embeds = qformer_embeds[:, :-1, :]

        if self.has_contrast:
            text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
            text_cls_embeds = text_embeds[:, 0, :]

        output_state = {}
        output_state['total_loss'] = 0.0

        # contrastive loss on the Q-Former CLS token itself (only when the Q-Former is trainable)
        if not self.freeze_qformer and self.has_contrast:
            qformer_cls_embeds = self.cls_norm_qformer(qformer_cls_embeds)
            qformer_image_embeds = F.normalize(self.image_proj_qformer(qformer_cls_embeds), dim=-1)
            qformer_text_embeds = F.normalize(self.text_proj_qformer(text_cls_embeds), dim=-1)
            qformer_contrast_loss, \
                qformer_i2t_acc, \
                qformer_t2i_acc = contrastive_loss(image_feats=qformer_image_embeds,
                                                   text_feats=qformer_text_embeds,
                                                   logit_scale=self.logit_scale_qformer)
            output_state['qformer_contrast_loss'] = qformer_contrast_loss
            output_state['total_loss'] += self.scale_contrast_loss * qformer_contrast_loss
            output_state['qformer_i2t_acc'] = qformer_i2t_acc
            output_state['qformer_t2i_acc'] = qformer_t2i_acc

        if self.quantizer is not None and self.distiller is not None:
            quantizer_output = self.quantizer(qformer_embeds)
            recon_embeds = self.distiller(quantizer_output['quant_embeds'])

            if self.share_contrast_head:
                # the last reconstructed token doubles as the CLS token for the contrastive head
                contrast_head_cls_embeds = recon_embeds[:, -1, :]
                contrast_head_cls_embeds = self.cls_norm_head(contrast_head_cls_embeds)
                recon_embeds = recon_embeds[:, :-1, :]
            if self.contrast_head is not None:
                contrast_head_embeds = self.contrast_head(quantizer_output['quant_embeds'])
                contrast_head_cls_embeds = contrast_head_embeds[:, -1, :]
                contrast_head_cls_embeds = self.cls_norm_head(contrast_head_cls_embeds)

            output_state['indices'] = quantizer_output['indices']
            output_state['commit_loss'] = quantizer_output['commit_loss']
            output_state['total_loss'] += self.scale_commit_loss * quantizer_output['commit_loss']

            # distillation target: either the Q-Former output or the raw image embeddings
            target_embeds = qformer_embeds if self.rec_qformer else image_embeds
            if self.distill_loss_type == 'cosine':
                distill_loss = cosine_loss(recon_embeds, target_embeds)
            else:
                raise NotImplementedError
            output_state['distill_loss'] = distill_loss
            output_state['total_loss'] += self.scale_distill_loss * distill_loss

            if self.contrast_head is not None or self.share_contrast_head:
                head_image_embeds = F.normalize(self.image_proj_head(contrast_head_cls_embeds), dim=-1)
                head_text_embeds = F.normalize(self.text_proj_head(text_cls_embeds), dim=-1)
                head_contrast_loss, head_i2t_acc, head_t2i_acc = contrastive_loss(image_feats=head_image_embeds,
                                                                                  text_feats=head_text_embeds,
                                                                                  logit_scale=self.logit_scale_head)
                output_state['head_contrast_loss'] = head_contrast_loss
                output_state['total_loss'] += self.scale_contrast_loss * head_contrast_loss
                output_state['head_i2t_acc'] = head_i2t_acc
                output_state['head_t2i_acc'] = head_t2i_acc
        return output_state
    def encode_image_embeds(self, image_embeds):
        qformer_embeds = self.qformer(image_embeds=image_embeds)
        return qformer_embeds
    @classmethod
    def from_pretrained(cls, qformer, quantizer=None, distiller=None, contrast_head=None, pretrained_model_path=None,
                        **kwargs):
        model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys:', len(missing), 'unexpected keys:', len(unexpected))
        return model
    @classmethod
    def from_pretrained_stage1_yuying(cls,
                                      qformer,
                                      quantizer=None,
                                      distiller=None,
                                      contrast_head=None,
                                      pretrained_model_path=None,
                                      **kwargs):
        model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            ckpt = ckpt['model']
            # remap BLIP-2-style checkpoint keys onto this module's layout
            new_ckpt = {}
            new_ckpt['qformer.embed_module.query'] = ckpt['query_tokens'].squeeze(0)
            new_ckpt['qformer.norm.weight'] = ckpt['ln_vision.weight']
            new_ckpt['qformer.norm.bias'] = ckpt['ln_vision.bias']
            for key in ckpt.keys():
                if key.startswith('Qformer'):
                    new_key = key.replace('Qformer', 'qformer.perceiver')
                    new_ckpt[new_key] = ckpt[key]
            del ckpt
            missing, unexpected = model.load_state_dict(new_ckpt, strict=False)
            print('missing keys:', len(missing), 'unexpected keys:', len(unexpected))
            print(missing)
            print(unexpected)
        return model
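

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original training pipeline). It only
    # exercises DiscreteModleOnlyDistill with dummy stand-ins that mimic the interfaces
    # assumed above: qformer(image_embeds=x) -> [B, T, D] tokens, quantizer(x) -> dict with
    # 'quant_embeds' / 'commit_loss' / 'indices', distiller(quant_embeds) -> reconstruction.
    # The real qformer / quantizer / distiller are built from the project configs; the
    # names and dimensions below are illustrative assumptions only.
    class _DummyQFormer(nn.Module):
        def __init__(self, dim_in, dim_out):
            super().__init__()
            self.proj = nn.Linear(dim_in, dim_out)

        def forward(self, image_embeds):
            return self.proj(image_embeds)

    class _DummyQuantizer(nn.Module):
        def forward(self, embeds):
            return {
                'quant_embeds': embeds,
                'commit_loss': embeds.new_zeros(()),
                'indices': torch.zeros(embeds.shape[:2], dtype=torch.long),
            }

    model = DiscreteModleOnlyDistill(
        qformer=_DummyQFormer(dim_in=1024, dim_out=768),
        quantizer=_DummyQuantizer(),
        distiller=nn.Linear(768, 1024),
    )
    out = model(image_embeds=torch.randn(2, 32, 1024))
    print(out['total_loss'].item(), out['indices'].shape)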