# InternVideo2Stage2-VisionEncoder / internvideo2_stage2.py
import warnings

import torch
from torch import nn
from transformers.utils import logging

from .config import InternVideo2Config, EasyDict
from .internvideo2 import (
    pretrain_internvideo2_1b_patch14_224,
    pretrain_internvideo2_6b_patch14_224,
)

warnings.filterwarnings("ignore")

logger = logging.get_logger(__name__)

class InternVideo2_Stage2(nn.Module):
    """Stage-2 InternVideo2 model wrapping the vision encoder and its
    projection head for image/video feature extraction."""

    def __init__(self, config, is_pretrain=True):
        super().__init__()
        self.config = config
self.is_pretrain = is_pretrain
self.vision_width = config.model.vision_encoder.clip_embed_dim
self.embed_dim = config.model.embed_dim
# create modules.
self.vision_encoder = self.build_vision_encoder()
if config.model.get("freeze_vision", False):
self.freeze_vision()
        # Projection into the joint embedding space used for contrastive
        # alignment, plus a learnable temperature for the contrastive logits.
        self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
        self.temp = nn.Parameter(torch.ones([]) * config.model.temp)

        self.uta_image_only = config.criterion.get('uta_image_only', False)
        logger.info(f"uta_image_only={self.uta_image_only}")
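
        # How these pieces are typically combined downstream (a sketch only;
        # the contrastive loss itself lives outside this file, and
        # `text_feats` is assumed to come from a matching text encoder):
        #   v = F.normalize(self.vision_proj(pooled_vision_embeds), dim=-1)
        #   sim = v @ text_feats.t() / self.temp
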
def freeze_vision(self):
"""freeze vision encoder"""
for p in self.vision_encoder.parameters():
p.requires_grad = False

    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        ret = {"temp"}
        ret.update(
            {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
        )
        return ret
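
    # A typical optimizer setup routes these names into a no-weight-decay
    # parameter group (a sketch, mirroring the usual timm-style pattern):
    #   skip = model.no_weight_decay()
    #   decay = [p for n, p in model.named_parameters() if n not in skip]
    #   no_decay = [p for n, p in model.named_parameters() if n in skip]
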
    @property
    def dtype(self):
        """Dtype of the patch-embedding weights; inputs should be cast to this
        (relevant for fp16/bf16 checkpoints) before calling `encode_vision`."""
        return self.vision_encoder.patch_embed.proj.weight.dtype

    def encode_vision(self, image):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): The input images or video frames.
                Shape: [B,T,C,H,W]; T == 1 means image input.

        Returns: tuple.
            - vision_embeds (torch.Tensor): The output features. Shape: [B,N,C].
            - pooled_vision_embeds (torch.Tensor): The pooled output features.
              Shape: [B,1,C].
        """
        T = image.shape[1]
        use_image = T == 1  # a single frame is treated as an image
        image = image.permute(0, 2, 1, 3, 4)  # [B,T,C,H,W] -> [B,C,T,H,W]
        # The encoder also returns alignment and CLIP-distillation outputs,
        # which are only needed for the pretraining losses; drop them here.
        vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder(
            image, None, use_image)
        return vision_embeds, pooled_vision_embeds

    def build_vision_encoder(self):
        """Build the vision encoder selected by the config.

        Returns: vision_encoder (nn.Module).
        """
        encoder_name = self.config.model.vision_encoder.name
        logger.info(f"Build vision_encoder: {encoder_name}")
if encoder_name == 'pretrain_internvideo2_1b_patch14_224':
vision_encoder = pretrain_internvideo2_1b_patch14_224(self.config.model)
elif encoder_name == 'pretrain_internvideo2_6b_patch14_224':
vision_encoder = pretrain_internvideo2_6b_patch14_224(self.config.model)
else:
raise ValueError(f"Not implemented: {encoder_name}")
return vision_encoder
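

# Usage sketch (illustrative, not part of the uploaded file). It assumes
# `InternVideo2Config` is an HF-style config whose nested fields match what
# the class reads (`model.vision_encoder`, `model.embed_dim`, `model.temp`,
# `criterion`); run it via `python -m <package>.internvideo2_stage2` so the
# relative imports above resolve.
if __name__ == "__main__":
    import torch.nn.functional as F

    config = InternVideo2Config.from_pretrained(".")  # assumption: HF-style loading
    model = InternVideo2_Stage2(config, is_pretrain=False).eval()

    # One 8-frame 224x224 RGB clip: [B, T, C, H, W], cast to the encoder dtype.
    video = torch.randn(1, 8, 3, 224, 224).to(model.dtype)
    with torch.no_grad():
        vision_embeds, pooled = model.encode_vision(video)
        # Project pooled features into the joint embedding space and normalize,
        # as a retrieval pipeline typically would.
        feats = F.normalize(model.vision_proj(pooled), dim=-1)
    print(vision_embeds.shape, feats.shape)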