from typing import Optional

from torch import Tensor
from transformers import PreTrainedModel

from .configuration_act_estimator import ActEstimatorConfig
from .model import VideoActionEstimator


class ActEstimator(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`VideoActionEstimator`.

    The wrapped estimator is built from the fields of an
    :class:`ActEstimatorConfig`. Calls to :meth:`forward` are delegated to it
    unchanged.
    """

    # Configuration class associated with this model by the
    # ``PreTrainedModel`` machinery.
    config_class = ActEstimatorConfig

    def __init__(self, config: ActEstimatorConfig) -> None:
        """Build the wrapped estimator from ``config``.

        Args:
            config: Model configuration. Every entry of ``config.to_dict()``
                is forwarded as a keyword argument to
                :class:`VideoActionEstimator`.
        """
        super().__init__(config)
        # NOTE(review): ``to_dict()`` on a ``PretrainedConfig`` subclass usually
        # also contains generic Hugging Face fields, e.g. ``model_type`` or
        # ``transformers_version``. Presumably VideoActionEstimator accepts or
        # ignores such extra kwargs; confirm against its signature.
        self.model = VideoActionEstimator(**config.to_dict())

    def forward(
        self, frames: Tensor, timestamps: Optional[Tensor] = None
    ) -> dict[str, Tensor]:
        """Run the wrapped estimator.

        Args:
            frames: Input tensor, passed positionally to the wrapped model.
            timestamps: Optional tensor passed positionally after ``frames``.
                It may be ``None``, which is forwarded as-is.

        Returns:
            Whatever the wrapped :class:`VideoActionEstimator` returns. It is
            annotated as a mapping from output names to tensors.
        """
        return self.model(frames, timestamps)