# InternVideo2 Stage 2 video encoder wrapper.
# Uploaded to the Hugging Face Hub by WishArdently
# (commit fd3be3e, "Upload InternVideo2Stage2VideoEncoder", 1.37 kB).
from internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
from config import InternVideo2Config as config
import warnings
import torch
warnings.filterwarnings("ignore")
# model_config = config()
# model = IV2S2(model_config)
# print(model)
class InternVideo2Stage2VideoEncoder(PreTrainedModel):
    """``PreTrainedModel`` wrapper exposing the vision side of ``InternVideo2_Stage2``.

    ``forward`` runs ``InternVideo2_Stage2.encode_vision`` and returns only the
    pooled vision embeddings (element 1 of its output).
    """

    # ``config`` at class scope is the module-level alias for InternVideo2Config.
    config_class = config

    def __init__(self, config):
        """Build the wrapped ``InternVideo2_Stage2`` model.

        Args:
            config: an ``InternVideo2Config`` instance. It is passed straight to
                ``InternVideo2_Stage2``, and its ``device`` attribute selects
                where the model is placed.
        """
        super().__init__(config)
        self.config = config
        # The whole model is cast to float16 and moved to config.device.
        # forward() does not cast or move its input, so callers presumably
        # need to supply float16 tensors on that device.
        self.model = IV2S2(config).half().to(config.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode video clips and return their pooled vision embeddings.

        Args:
            x (torch.Tensor): clips laid out as (B, N, C, H, W), or a single
                clip as (N, C, H, W), where N is the number of frames. The
                tensor is not cast or moved here.

        Returns:
            torch.Tensor: element 1 of ``self.model.encode_vision``'s output
            (the pooled vision embeddings). NOTE(review): the original
            docstring gave its shape as (B*N, hidden_size), while an inline
            comment said (B, N*98, hidden_size); confirm against
            ``encode_vision``.

        Raises:
            ValueError: if ``x`` is neither 4-D nor 5-D.
        """
        if x.dim() == 4:
            # A single clip without a batch axis: add a batch of size 1.
            x = x.unsqueeze(0)
        if x.dim() != 5:
            raise ValueError(
                "expected x of shape (B, N, C, H, W) or (N, C, H, W), "
                f"got {tuple(x.shape)}"
            )
        # encode_vision is fed channels before frames:
        # (B, N, C, H, W) -> (B, C, N, H, W).
        x = x.permute(0, 2, 1, 3, 4)
        output = self.model.encode_vision(x)
        pooled_vision_embeds = output[1]
        return pooled_vision_embeds
if __name__ == "__main__":
    # Smoke test: build the encoder from the default config and encode
    # random clips.
    model_config = config()
    model = InternVideo2Stage2VideoEncoder(model_config)
    model.eval()  # inference mode (e.g. disables dropout)
    # forward() expects (B, N, C, H, W): 2 clips of 8 three-channel
    # 224x224 frames. The input is built in float16 on the config's device
    # to match the model, which __init__ casts to half precision and moves
    # to config.device.
    x = torch.randn(2, 8, 3, 224, 224, dtype=torch.float16).to(model_config.device)
    with torch.no_grad():  # no autograd graph is needed for a forward-only check
        output = model(x)
    print(output.shape)