from internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
from config import InternVideo2Config as config
import warnings
import torch

warnings.filterwarnings("ignore")


class InternVideo2Stage2VideoEncoder(PreTrainedModel):
    """HF-compatible wrapper exposing the InternVideo2 stage-2 vision tower.

    Wraps ``InternVideo2_Stage2`` so it can be used through the
    ``transformers`` ``PreTrainedModel`` interface. The underlying model is
    cast to fp16 and moved to ``config.device`` at construction time.
    """

    config_class = config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # fp16 + device placement happen once here; inputs passed to
        # forward() must already be on config.device with a matching dtype.
        self.model = IV2S2(config).half().to(config.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of videos into pooled vision embeddings.

        Args:
            x (torch.Tensor): Video frames of shape ``(B, N, C, H, W)`` or
                ``(N, C, H, W)`` (treated as batch size 1), where ``N`` is
                the number of frames and ``C`` the channel count.

        Returns:
            torch.Tensor: Pooled vision embeddings as returned by
                ``encode_vision`` (second element of its output tuple).
                NOTE(review): exact shape depends on ``encode_vision`` —
                the original comments disagreed ((B*N, hidden) vs
                (B, N*98, hidden)); confirm against InternVideo2_Stage2.
        """
        # Promote a single un-batched clip to a batch of one.
        if x.dim() == 4:
            x = x.unsqueeze(0)
        B, N, C, H, W = x.shape
        # encode_vision expects channels-first over time: (B, C, N, H, W).
        x = x.permute(0, 2, 1, 3, 4)
        output = self.model.encode_vision(x)
        pooled_vision_embeds = output[1]
        return pooled_vision_embeds


if __name__ == "__main__":
    model_config = config()
    model = InternVideo2Stage2VideoEncoder(model_config)
    # (B, N, C, H, W) = 2 clips of 8 RGB frames at 224x224, matching the
    # layout forward() documents. (The previous (2, 3, 8, 224, 224) input
    # was already permuted and produced an 8-channel tensor after the
    # internal permute.)
    x = torch.randn(2, 8, 3, 224, 224, dtype=torch.float16).to(model_config.device)
    output = model(x)