|
from internvideo2_stage2 import InternVideo2_Stage2 as IV2S2 |
|
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig |
|
from config import InternVideo2Config as config |
|
import warnings |
|
import torch |
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
|
|
|
|
class InternVideo2Stage2VideoEncoder(PreTrainedModel):
    """HF-compatible wrapper exposing the InternVideo2 Stage-2 vision tower.

    Wraps ``InternVideo2_Stage2`` and returns the pooled vision embeddings
    produced by its ``encode_vision`` method.
    """

    config_class = config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Run the backbone in fp16 on the device named in the config.
        self.model = IV2S2(config).half().to(config.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of videos.

        Args:
            x (torch.Tensor): Video frames of shape (B, N, C, H, W) or
                (N, C, H, W); a missing batch dimension is added as B=1.

        Returns:
            torch.Tensor: Pooled vision embeddings from
                ``encode_vision`` (presumably (B, hidden_size) --
                TODO confirm against the InternVideo2 implementation).
        """
        if x.dim() == 4:
            # Promote a single clip (N, C, H, W) to a batch of one.
            x = x.unsqueeze(0)
        # encode_vision expects channels before the frame axis:
        # (B, N, C, H, W) -> (B, C, N, H, W).
        x = x.permute(0, 2, 1, 3, 4)
        output = self.model.encode_vision(x)
        # encode_vision returns a tuple; index 1 holds the pooled embeddings.
        pooled_vision_embeds = output[1]
        return pooled_vision_embeds
|
|
|
if __name__ == "__main__":
    model_config = config()
    model = InternVideo2Stage2VideoEncoder(model_config)
    # Smoke test: B=2 clips, N=8 frames, C=3 channels, 224x224 frames,
    # matching forward's (B, N, C, H, W) contract.
    # NOTE(review): the original passed (2, 3, 8, 224, 224), i.e. 8 "channels"
    # and 3 frames after the permute -- frames/channels were swapped.
    x = torch.randn(2, 8, 3, 224, 224, dtype=torch.float16).to(model_config.device)
    output = model(x)