import torch
from transformers import PreTrainedModel
from typing import Optional, Tuple, Union, Dict
from synchformer_config import SynchformerConfig
import os
from pathlib import Path
import tempfile
import shutil
import subprocess
from omegaconf import OmegaConf
import importlib.util
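
# NOTE: beyond the imports above, this wrapper depends on modules from the
# original Synchformer code base (scripts.train_utils, dataset.*, utils.utils).
# They are imported lazily inside the methods below and are made importable by
# the module_loader.py shipped inside model_files.zip (see from_pretrained).
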
class SynchformerModel(PreTrainedModel):
    config_class = SynchformerConfig
    base_model_prefix = "synchformer"

    def __init__(self, config: SynchformerConfig):
        super().__init__(config)
        self.config = config
        # The wrapped network is built lazily in _init_model() once the
        # original OmegaConf config is available (see from_pretrained).
        self.model = None
        self._is_initialized = False

    def _init_model(self, cfg):
        """Initialize the model from the original config."""
        from scripts.train_utils import get_model

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        _, model = get_model(cfg, device)
        self.model = model
        self._is_initialized = True

    def forward(
        self,
        video: torch.Tensor,
        audio: torch.Tensor,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Args:
            video: Video tensor of shape (batch_size, num_frames, channels, height, width)
            audio: Audio tensor of shape (batch_size, num_channels, num_frames)
            return_dict: Whether to return a dictionary instead of a tuple

        Returns:
            Tuple of (features, logits), or a dict with keys "features" and
            "logits" if return_dict is True.
        """
        if not self._is_initialized:
            raise RuntimeError("Model has not been properly initialized. Please use from_pretrained.")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Inference only: gradients are disabled, and autocast is used when
        # half precision is enabled in the config.
        with torch.no_grad():
            with torch.autocast('cuda', enabled=self.config.use_half_precision):
                features, logits = self.model(video, audio)

        if return_dict:
            return {"features": features, "logits": logits}
        return features, logits
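
    # Input preparation follows the preprocessing of the original Synchformer
    # code base: read -> (optionally re-encode) -> transform -> collate a batch of one.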
    def prepare_inputs(self, video_path, offset_sec=0.0, v_start_i_sec=0.0):
        """
        Prepare inputs for the model from a video file.

        Args:
            video_path: Path to the video file
            offset_sec: Offset in seconds for audio
            v_start_i_sec: Start time in seconds for video

        Returns:
            Tuple of (audio, video, targets, batch)
        """
        import torchvision
        from dataset.dataset_utils import get_video_and_audio
        from scripts.train_utils import get_transforms, prepare_inputs

        vfps = self.config.vfps
        afps = self.config.afps
        in_size = self.config.in_size

        v, _, info = torchvision.io.read_video(video_path, pts_unit='sec')
        _, H, W, _ = v.shape

        # Re-encode if the clip does not match the expected fps / resolution.
        if info['video_fps'] != vfps or info['audio_fps'] != afps or min(H, W) != in_size:
            print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=' ')
            print(f'afps: {info["audio_fps"]} -> {afps};', end=' ')
            print(f'{(H, W)} -> min(H, W)={in_size}')
            video_path = self._reencode_video(video_path, vfps, afps, in_size)

        rgb, audio, meta = get_video_and_audio(video_path, get_meta=True)

        item = dict(
            video=rgb, audio=audio, meta=meta, path=video_path, split='test',
            targets={'v_start_i_sec': v_start_i_sec, 'offset_sec': offset_sec},
        )

        cfg = self._get_original_config()
        transforms = get_transforms(cfg, ['test'])['test']
        item = transforms(item)

        # Collate a single-item batch and move it to the model's device.
        batch = torch.utils.data.default_collate([item])
        aud, vid, targets = prepare_inputs(batch, self.device)

        return aud, vid, targets, batch

    def predict_offset(self, video_path, offset_sec=0.0, v_start_i_sec=0.0):
        """
        Predict the audio-visual offset for a video.

        Args:
            video_path: Path to the video file
            offset_sec: Ground-truth offset in seconds (for evaluation)
            v_start_i_sec: Start time in seconds for video

        Returns:
            Dictionary with prediction results
        """
        from dataset.transforms import make_class_grid

        aud, vid, targets, batch = self.prepare_inputs(video_path, offset_sec, v_start_i_sec)

        features, logits = self.forward(vid, aud, return_dict=False)

        # Rebuild the discrete offset grid the classification head was trained on.
        cfg = self._get_original_config()
        max_off_sec = cfg.data.max_off_sec
        num_cls = cfg.model.params.transformer.params.off_head_cfg.params.out_features
        grid = make_class_grid(-max_off_sec, max_off_sec, num_cls)

        off_probs = torch.softmax(logits, dim=-1)
        k = min(off_probs.shape[-1], 5)
        topk_logits, topk_preds = torch.topk(logits, k)

        # Drop the batch dimension (the batch size is 1).
        topk_logits = topk_logits[0]
        topk_preds = topk_preds[0]
        off_logits = logits[0]
        off_probs = off_probs[0]

        results = {
            "grid": grid.cpu().numpy().tolist(),
            "predictions": [],
            "ground_truth": None,
        }

        for i, target_hat in enumerate(topk_preds):
            idx = target_hat.item()
            results["predictions"].append({
                "probability": off_probs[idx].item(),
                "logit": off_logits[idx].item(),
                "offset_sec": grid[idx].item(),
                "class_idx": idx,
                "rank": i,
            })

        if offset_sec != 0.0:
            from dataset.transforms import quantize_offset
            label = targets['offset_label'].item()
            results["ground_truth"] = {
                "offset_sec": label,
                "class_idx": quantize_offset(grid, label)[-1].item(),
            }

        return results
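
    # Illustrative shape of the predict_offset() output (values are made up):
    # {
    #     "grid": [-2.0, ..., 2.0],
    #     "predictions": [
    #         {"rank": 0, "class_idx": 10, "offset_sec": 0.0,
    #          "probability": 0.93, "logit": 7.1},
    #         ...
    #     ],
    #     "ground_truth": {"offset_sec": 0.4, "class_idx": 12},  # None unless offset_sec != 0.0
    # }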

    def _reencode_video(self, path, vfps=25, afps=16000, in_size=256):
        """Re-encode the video to the required fps, sample rate, and resolution."""
        from utils.utils import which_ffmpeg

        assert which_ffmpeg() != '', 'Is ffmpeg installed? Check if the conda environment is activated.'
        new_path = Path.cwd() / 'vis' / f'{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4'
        new_path.parent.mkdir(exist_ok=True)
        new_path = str(new_path)

        # Resize so the shorter side equals in_size, then crop to even dimensions.
        cmd = f'{which_ffmpeg()}'
        cmd += ' -hide_banner -loglevel panic'
        cmd += f' -y -i {path}'
        cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2"
        cmd += f' -ar {afps}'
        cmd += f' {new_path}'
        subprocess.call(cmd.split())

        # Also extract a mono 16-bit PCM wav next to the re-encoded mp4.
        cmd = f'{which_ffmpeg()}'
        cmd += ' -hide_banner -loglevel panic'
        cmd += f' -y -i {new_path}'
        cmd += ' -acodec pcm_s16le -ac 1'
        cmd += f' {new_path.replace(".mp4", ".wav")}'
        subprocess.call(cmd.split())

        return new_path

    def _get_original_config(self):
        """Get the original OmegaConf config from the model."""
        if not hasattr(self, "_original_config"):
            raise RuntimeError("Original config not found. Please use from_pretrained.")
        return self._original_config
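
    # from_pretrained() expects pretrained_model_name_or_path (a local directory
    # or a Hub repo) to provide "model_files.zip" containing:
    #   module_loader.py -- makes the original Synchformer sources importable
    #   cfg.yaml         -- the original OmegaConf training config
    #   model.pt         -- checkpoint with the weights under the "model" key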
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        **kwargs
    ):
        """
        Load a Synchformer model from a pretrained checkpoint.

        Args:
            pretrained_model_name_or_path: Path to the pretrained model or its name on the Hub

        Returns:
            SynchformerModel: The loaded model
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            # Local directory: use the archive written by save_pretrained().
            model_path = os.path.join(pretrained_model_name_or_path, "model_files.zip")
        else:
            # Hub repo: download the archive.
            from huggingface_hub import hf_hub_download
            model_path = hf_hub_download(
                pretrained_model_name_or_path,
                filename="model_files.zip",
                cache_dir=kwargs.get("cache_dir", None),
            )

        with tempfile.TemporaryDirectory() as temp_dir:
            import zipfile
            with zipfile.ZipFile(model_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)

            model_path = temp_dir

            # Make the original Synchformer modules importable.
            module_loader_path = os.path.join(temp_dir, "module_loader.py")
            spec = importlib.util.spec_from_file_location("module_loader", module_loader_path)
            module_loader = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module_loader)
            module_loader.setup_modules()

            cfg_path = os.path.join(model_path, "cfg.yaml")
            ckpt_path = os.path.join(model_path, "model.pt")

            cfg = OmegaConf.load(cfg_path)
            cfg = model._patch_config(cfg)
            model._original_config = cfg

            model._init_model(cfg)

            ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'), weights_only=False)
            model.model.load_state_dict(ckpt['model'])
            model.model.eval()

        return model

    def _patch_config(self, cfg):
        """Patch the config as in the original code."""
        # The extractor checkpoints are not loaded separately; the full state
        # dict is restored from model.pt in from_pretrained().
        cfg.model.params.afeat_extractor.params.ckpt_path = None
        cfg.model.params.vfeat_extractor.params.ckpt_path = None
        # Point the transformer target at its current module path.
        cfg.model.params.transformer.target = cfg.model.params.transformer.target \
            .replace('.modules.feature_selector.', '.sync_model.')
        return cfg

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        push_to_hub: bool = False,
        **kwargs
    ):
        """
        Save the model and its configuration file to a directory.

        Args:
            save_directory: Directory to save the model to
            push_to_hub: Whether to push the model to the hub

        Returns:
            List of files saved
        """
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        self.config.save_pretrained(save_directory)

        # Stage the original config and checkpoint in a temporary subdirectory.
        model_files_dir = os.path.join(save_directory, "model_files")
        os.makedirs(model_files_dir, exist_ok=True)

        cfg = self._get_original_config()
        cfg_path = os.path.join(model_files_dir, "cfg.yaml")
        OmegaConf.save(cfg, cfg_path)

        # Save the checkpoint in the format expected by from_pretrained().
        ckpt_path = os.path.join(model_files_dir, "model.pt")
        torch.save({"model": self.model.state_dict()}, ckpt_path)

        # Zip everything up and remove the staging directory.
        import zipfile
        zip_path = os.path.join(save_directory, "model_files.zip")
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, dirs, files in os.walk(model_files_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, model_files_dir)
                    zipf.write(file_path, arcname)

        shutil.rmtree(model_files_dir)

        if push_to_hub:
            self.push_to_hub(save_directory, **kwargs)

        return [
            os.path.join(save_directory, "config.json"),
            zip_path,
        ]
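

# A minimal usage sketch (illustrative, not part of the class): the checkpoint
# directory and video path below are placeholders -- point them at a packaged
# checkpoint (config.json + model_files.zip) and a real clip.
if __name__ == "__main__":
    model = SynchformerModel.from_pretrained("path/to/synchformer_checkpoint")

    results = model.predict_offset("sample.mp4", offset_sec=0.0)
    for pred in results["predictions"]:
        print(f'rank {pred["rank"]}: offset {pred["offset_sec"]:+.2f}s '
              f'(p={pred["probability"]:.3f})')

    # Round-trip: writes config.json plus a fresh model_files.zip.
    model.save_pretrained("path/to/exported_checkpoint")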