|
import logging |
|
import json |
|
import torch |
|
from torch import nn |
|
from .config import InternVideo2Config, EasyDict |
|
from .internvideo2 import pretrain_internvideo2_1b_patch14_224, pretrain_internvideo2_6b_patch14_224 |
|
from transformers.utils import logging |
|
import warnings |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
class InternVideo2_Stage2(nn.Module):
    """Stage-2 pretraining wrapper around an InternVideo2 vision encoder.

    Builds the configured vision backbone, optionally freezes it, and adds a
    linear projection head plus a learnable contrastive temperature.
    """

    def __init__(self, config, is_pretrain=True):
        """
        Args:
            config: EasyDict-style config with `model` and `criterion`
                sections (see `.config`).
            is_pretrain (bool): Whether the model is used for pretraining.
        """
        super().__init__()

        self.config = config
        self.is_pretrain = is_pretrain
        # Width of encoder outputs and target joint-embedding size.
        self.vision_width = config.model.vision_encoder.clip_embed_dim
        self.embed_dim = config.model.embed_dim

        # Build (and optionally freeze) the vision backbone.
        self.vision_encoder = self.build_vision_encoder()
        if config.model.get("freeze_vision", False):
            self.freeze_vision()

        self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)

        # Learnable temperature for contrastive logits.
        self.temp = nn.Parameter(torch.ones([]) * config.model.temp)
        self.uta_image_only = config.criterion.get('uta_image_only', False)

    def freeze_vision(self):
        """Freeze all parameters of the vision encoder."""
        for p in self.vision_encoder.parameters():
            p.requires_grad = False

    def no_weight_decay(self):
        """Return parameter names to exclude from weight decay.

        Returns:
            set[str]: `"temp"` plus the encoder's own no-decay names,
            prefixed with `"vision_encoder."`.
        """
        ret = {"temp"}
        ret.update(
            {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
        )
        return ret

    @property
    def dtype(self):
        # dtype of the patch-embedding weights; handy for casting inputs.
        return self.vision_encoder.patch_embed.proj.weight.dtype

    def encode_vision(self, image):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): Input frames of shape (B, T, C, H, W).
                T == 1 is treated as the single-image path.

        Returns:
            tuple:
                - vision_embeds (torch.Tensor): Output features, shape [B, N, C].
                - pooled_vision_embeds (torch.Tensor): Pooled features,
                  shape [B, 1, C].
        """
        T = image.shape[1]
        use_image = T == 1  # single frame => image path in the encoder
        # Encoder expects channels-first video: (B, C, T, H, W).
        image = image.permute(0, 2, 1, 3, 4)

        # Encoder also returns alignment/clip outputs; unused here.
        vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder(
            image, None, use_image)
        return vision_embeds, pooled_vision_embeds

    def build_vision_encoder(self):
        """Instantiate the vision encoder named in the config.

        Returns:
            nn.Module: The InternVideo2 vision encoder.

        Raises:
            ValueError: If the configured encoder name is not supported.
        """
        encoder_name = self.config.model.vision_encoder.name

        if encoder_name == 'pretrain_internvideo2_1b_patch14_224':
            vision_encoder = pretrain_internvideo2_1b_patch14_224(self.config.model)
        elif encoder_name == 'pretrain_internvideo2_6b_patch14_224':
            vision_encoder = pretrain_internvideo2_6b_patch14_224(self.config.model)
        else:
            raise ValueError(f"Not implemented: {encoder_name}")
        return vision_encoder
|
|
|
|