from .internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
from .config import InternVideo2Config as config
import warnings
import torch

# from transformers.utils import logging

warnings.filterwarnings("ignore")
# logging.set_verbosity_error()


class InternVideo2Stage2VideoEncoder(PreTrainedModel):
    config_class = config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Build the stage-2 model on CPU and cast it to float16 at construction time.
        self.model = IV2S2(self.config).to('cpu').to(torch.float16)

    def forward(self, x: torch.Tensor):
        """Forward pass.

        Args:
            x (torch.Tensor): Video tensor of shape (B, N, C, H, W) or image tensor of shape (B, C, H, W).

        Returns:
            torch.Tensor: Frame embeddings of shape (B, N, hidden_size) for videos,
                or (B, hidden_size) for images.
        """
        if len(x.shape) == 5 and x.shape[1] > 8:
            # The pretrained weights only support clips of at most 8 frames, so longer inputs
            # are split into chunks of up to 8 frames, encoded independently, and the resulting
            # embeddings are concatenated along the frame dimension.
            T = x.shape[1]
            embs = torch.cat([self.forward(x[:, i:i + 8, :, :, :]) for i in range(0, T, 8)], dim=1)
            return embs

        image = False
        if len(x.shape) == 4:
            # Treat a single image as a one-frame video.
            x = x.unsqueeze(1)
            image = True

        B, N, C, H, W = x.shape
        output = self.model.encode_vision(x)
        pooled_vision_embeds = output[1]                # Shape (B, N*256 + 1, hidden_size)
        output = pooled_vision_embeds[:, :256 * N, :]   # keep the 256 patch tokens per frame, Shape (B, N*256, hidden_size)
        output = output.reshape(B, N, 256, -1)          # Shape (B, N, 256, hidden_size)
        output = output.mean(dim=2)                     # average patch tokens per frame, Shape (B, N, hidden_size)
        if image:
            output = output.squeeze(1)                  # Shape (B, hidden_size)
        return output


if __name__ == "__main__":
    model_config = config()
    model = InternVideo2Stage2VideoEncoder(model_config)
    x = torch.randn(2, 8, 3, 224, 224, dtype=torch.float16).to(model_config.device)  # (B, N, C, H, W)
    output = model(x)
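

# The AutoModel/AutoConfig imports above suggest this encoder is intended to be usable through
# the transformers Auto classes. The helper below is only a sketch of one way to register it:
# register_with_auto_classes is a hypothetical name, and the "internvideo2" model_type string
# is an assumption that must match config.model_type for AutoConfig.register to accept it.
def register_with_auto_classes():
    # Map the assumed model_type string to the config class, then map the config class to this encoder,
    # so AutoConfig.from_pretrained / AutoModel.from_pretrained can resolve saved checkpoints.
    AutoConfig.register("internvideo2", config)
    AutoModel.register(config, InternVideo2Stage2VideoEncoder)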