# synchformer-hf / synchformer_model.py
import importlib.util
import os
import shutil
import subprocess
import tempfile
import zipfile
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from transformers import PreTrainedModel

from synchformer_config import SynchformerConfig
class SynchformerModel(PreTrainedModel):
    """Hugging Face wrapper around the original Synchformer audio-visual
    synchronization model. The underlying network is built lazily from the
    original OmegaConf config bundled with the checkpoint, so instances must
    be created via ``from_pretrained`` rather than the constructor alone."""

    config_class = SynchformerConfig
    base_model_prefix = "synchformer"

    def __init__(self, config: SynchformerConfig):
        super().__init__(config)
        self.config = config
        # The actual network is built in from_pretrained, since its
        # architecture is described by the original OmegaConf config that
        # ships with the checkpoint.
        self.model = None
        self._is_initialized = False
def _init_model(self, cfg):
"""Initialize the model from the original config"""
from scripts.train_utils import get_model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_, model = get_model(cfg, device)
self.model = model
self._is_initialized = True
    def forward(
        self,
        video: torch.Tensor,
        audio: torch.Tensor,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Args:
            video: Video tensor of shape (batch_size, num_frames, channels, height, width)
            audio: Audio tensor of shape (batch_size, num_channels, num_frames)
            return_dict: Whether to return a dictionary or a tuple
        Returns:
            A dict with keys "features" and "logits" if return_dict is True,
            otherwise the tuple (features, logits)
        """
        if not self._is_initialized:
            raise RuntimeError("Model has not been properly initialized. Please use from_pretrained.")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Inference only; half precision is requested only when CUDA is
        # actually available, so CPU runs are unaffected.
        with torch.no_grad():
            with torch.autocast('cuda', enabled=self.config.use_half_precision and torch.cuda.is_available()):
                features, logits = self.model(video, audio)
        if return_dict:
            return {"features": features, "logits": logits}
        return features, logits
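
    # A minimal call sketch (illustrative; the tensor shapes below are
    # assumptions taken from the docstring above, not verified against the
    # original Synchformer repo):
    #
    #   vid = torch.randn(1, 125, 3, 224, 224)  # (B, num_frames, C, H, W)
    #   aud = torch.randn(1, 1, 80000)          # (B, num_channels, num_frames)
    #   out = model(vid, aud, return_dict=True)
    #   features, logits = out["features"], out["logits"]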
    def prepare_inputs(self, video_path, offset_sec=0.0, v_start_i_sec=0.0):
        """
        Prepare model inputs from a video file.
        Args:
            video_path: Path to the video file
            offset_sec: Offset in seconds for the audio track
            v_start_i_sec: Start time in seconds for the video track
        Returns:
            Tuple of (audio, video, targets, batch)
        """
        import torchvision
        from dataset.dataset_utils import get_video_and_audio
        from scripts.train_utils import get_transforms, prepare_inputs

        # Re-encode the video if its fps, audio rate, or resolution do not
        # match what the model was trained on
        vfps = self.config.vfps
        afps = self.config.afps
        in_size = self.config.in_size
        v, _, info = torchvision.io.read_video(video_path, pts_unit='sec')
        _, H, W, _ = v.shape
        if info['video_fps'] != vfps or info['audio_fps'] != afps or min(H, W) != in_size:
            print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=' ')
            print(f'afps: {info["audio_fps"]} -> {afps};', end=' ')
            print(f'min(H, W) of {(H, W)} -> {in_size}')
            video_path = self._reencode_video(video_path, vfps, afps, in_size)

        # Load the visual and audio streams
        rgb, audio, meta = get_video_and_audio(video_path, get_meta=True)

        # Wrap everything in the item dict expected by the transforms
        item = dict(
            video=rgb, audio=audio, meta=meta, path=video_path, split='test',
            targets={'v_start_i_sec': v_start_i_sec, 'offset_sec': offset_sec},
        )

        # Apply the test-time transforms from the original config
        cfg = self._get_original_config()
        transforms = get_transforms(cfg, ['test'])['test']
        item = transforms(item)

        # Collate into a batch of one and move it to the model's device
        batch = torch.utils.data.default_collate([item])
        aud, vid, targets = prepare_inputs(batch, self.device)
        return aud, vid, targets, batch
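
    # Illustrative use (the video path is a placeholder): prepare_inputs
    # re-encodes the clip if needed, applies the test transforms, and returns
    # a batch of one already on self.device:
    #
    #   aud, vid, targets, batch = model.prepare_inputs("clip.mp4", offset_sec=0.4)
    #   features, logits = model(vid, aud, return_dict=False)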
def predict_offset(self, video_path, offset_sec=0.0, v_start_i_sec=0.0):
"""
Predict the audio-visual offset for a video
Args:
video_path: Path to the video file
offset_sec: Ground truth offset in seconds (for evaluation)
v_start_i_sec: Start time in seconds for video
Returns:
Dictionary with prediction results
"""
from dataset.transforms import make_class_grid
# Prepare inputs
aud, vid, targets, batch = self.prepare_inputs(video_path, offset_sec, v_start_i_sec)
# Forward pass
features, logits = self.forward(vid, aud, return_dict=False)
# Get grid for interpretation
cfg = self._get_original_config()
max_off_sec = cfg.data.max_off_sec
num_cls = cfg.model.params.transformer.params.off_head_cfg.params.out_features
grid = make_class_grid(-max_off_sec, max_off_sec, num_cls)
# Process results
off_probs = torch.softmax(logits, dim=-1)
k = min(off_probs.shape[-1], 5)
topk_logits, topk_preds = torch.topk(logits, k)
# Remove batch dimension
topk_logits = topk_logits[0]
topk_preds = topk_preds[0]
off_logits = logits[0]
off_probs = off_probs[0]
# Prepare results
results = {
"grid": grid.cpu().numpy().tolist(),
"predictions": [],
"ground_truth": None
}
# Add top-k predictions
for i, target_hat in enumerate(topk_preds):
idx = target_hat.item()
results["predictions"].append({
"probability": off_probs[idx].item(),
"logit": off_logits[idx].item(),
"offset_sec": grid[idx].item(),
"class_idx": idx,
"rank": i
})
        # Add the ground truth only when a non-zero offset was passed
        # (offset_sec defaults to 0.0, which is treated as "not provided")
        if offset_sec != 0.0:
from dataset.transforms import quantize_offset
label = targets['offset_label'].item()
results["ground_truth"] = {
"offset_sec": label,
"class_idx": quantize_offset(grid, label)[-1].item()
}
return results
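
    # The returned dict has this shape (values here are illustrative):
    #
    #   {
    #     "grid": [-2.0, ..., 2.0],           # offset (sec) for each class
    #     "predictions": [
    #       {"probability": 0.91, "logit": 7.3, "offset_sec": 0.2,
    #        "class_idx": 11, "rank": 0},
    #       ...                                # up to 5 entries, best first
    #     ],
    #     "ground_truth": None                 # set only if offset_sec != 0.0
    #   }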
    def _reencode_video(self, path, vfps=25, afps=16000, in_size=256):
        """Re-encode the video to the fps, audio rate, and resolution the model expects."""
        from utils.utils import which_ffmpeg
        assert which_ffmpeg() != '', 'Is ffmpeg installed? Check if the conda environment is activated.'
        new_path = Path.cwd() / 'vis' / f'{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4'
        new_path.parent.mkdir(exist_ok=True)
        new_path = str(new_path)
        # Build the command as an argument list so paths with spaces survive.
        # 1) change fps, 2) resize so min(H, W) == in_size (vertical videos are
        # supported), 3) crop to even dimensions, 4) change the audio rate.
        scale = f"scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)'"
        crop = "crop='trunc(iw/2)'*2:'trunc(ih/2)'*2"
        subprocess.call([
            which_ffmpeg(), '-hide_banner', '-loglevel', 'panic',
            '-y', '-i', path,
            '-vf', f'fps={vfps},{scale},{crop}',
            '-ar', str(afps),
            new_path,
        ])
        # Also extract a mono 16-bit PCM wav next to the re-encoded video
        subprocess.call([
            which_ffmpeg(), '-hide_banner', '-loglevel', 'panic',
            '-y', '-i', new_path,
            '-acodec', 'pcm_s16le', '-ac', '1',
            new_path.replace('.mp4', '.wav'),
        ])
        return new_path
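
    # For reference, the first call above is equivalent to running (shown with
    # the default values vfps=25, afps=16000, in_size=256):
    #
    #   ffmpeg -hide_banner -loglevel panic -y -i in.mp4 \
    #     -vf "fps=25,scale=iw*256/'min(iw,ih)':ih*256/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2" \
    #     -ar 16000 vis/in_25fps_256side_16000hz.mp4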
def _get_original_config(self):
"""Get the original OmegaConf config from the model"""
if not hasattr(self, "_original_config"):
raise RuntimeError("Original config not found. Please use from_pretrained.")
return self._original_config
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        **kwargs
    ):
        """
        Load a Synchformer model from a pretrained checkpoint.
        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id
        Returns:
            SynchformerModel: The loaded model
        """
        # First load the config
        config = kwargs.pop("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Initialize the wrapper with the config
        model = cls(config)

        # Resolve the directory that holds cfg.yaml / model.pt / module_loader.py
        if os.path.isdir(pretrained_model_name_or_path):
            # Local directory
            model_path = pretrained_model_name_or_path
        else:
            # Download the bundled archive from the Hub and unpack it.
            # mkdtemp (rather than a TemporaryDirectory context manager) keeps
            # the extracted files alive until the checkpoint below is loaded.
            zip_path = hf_hub_download(
                repo_id=str(pretrained_model_name_or_path),
                filename="model_files.zip",
                cache_dir=kwargs.get("cache_dir", None),
            )
            model_path = tempfile.mkdtemp()
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(model_path)

        # Load the module loader, which puts the original Synchformer code
        # (scripts/, dataset/, utils/, ...) on the import path
        module_loader_path = os.path.join(model_path, "module_loader.py")
        if os.path.isfile(module_loader_path):
            spec = importlib.util.spec_from_file_location("module_loader", module_loader_path)
            module_loader = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module_loader)
            module_loader.setup_modules()

        # Load the original config and checkpoint
        cfg_path = os.path.join(model_path, "cfg.yaml")
        ckpt_path = os.path.join(model_path, "model.pt")
        cfg = OmegaConf.load(cfg_path)

        # Patch the config (as in the original Synchformer code) and keep it
        cfg = model._patch_config(cfg)
        model._original_config = cfg

        # Build the network and load the weights
        model._init_model(cfg)
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'), weights_only=False)
        model.model.load_state_dict(ckpt['model'])
        model.model.eval()
        return model
    def _patch_config(self, cfg):
        """Patch the config as in the original Synchformer code"""
        # The feature-extractor checkpoints are already folded into the model checkpoint
        cfg.model.params.afeat_extractor.params.ckpt_path = None
        cfg.model.params.vfeat_extractor.params.ckpt_path = None
        # Older checkpoints use a different module path for the transformer
        cfg.model.params.transformer.target = cfg.model.params.transformer.target.replace(
            '.modules.feature_selector.', '.sync_model.')
        return cfg
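
    # Illustrative rename (the module path here is hypothetical): a target
    # such as 'model.modules.feature_selector.SyncTransformer' becomes
    # 'model.sync_model.SyncTransformer'.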
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
push_to_hub: bool = False,
**kwargs
):
"""
Save a model and its configuration file to a directory.
Args:
save_directory: Directory to save the model to
push_to_hub: Whether to push the model to the hub
Returns:
List of files saved
"""
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
# Save the config
self.config.save_pretrained(save_directory)
# Create a directory for model files
model_files_dir = os.path.join(save_directory, "model_files")
os.makedirs(model_files_dir, exist_ok=True)
# Save the original config
cfg = self._get_original_config()
cfg_path = os.path.join(model_files_dir, "cfg.yaml")
OmegaConf.save(cfg, cfg_path)
# Save the model checkpoint
ckpt_path = os.path.join(model_files_dir, "model.pt")
torch.save({"model": self.model.state_dict()}, ckpt_path)
        # Bundle the model files into a single zip archive
        zip_path = os.path.join(save_directory, "model_files.zip")
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, _, files in os.walk(model_files_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, model_files_dir)
                    zipf.write(file_path, arcname)
# Remove the model_files directory
shutil.rmtree(model_files_dir)
# Push to hub if requested
if push_to_hub:
self.push_to_hub(save_directory, **kwargs)
return [
os.path.join(save_directory, "config.json"),
zip_path
]
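

if __name__ == "__main__":
    # Minimal end-to-end sketch; the repo id and video path are placeholders
    # for wherever this model and your test clip actually live.
    model = SynchformerModel.from_pretrained("<repo-id>")
    results = model.predict_offset("<path/to/video>.mp4")
    for pred in results["predictions"]:
        print(f'rank {pred["rank"]}: offset {pred["offset_sec"]:+.2f}s '
              f'(p={pred["probability"]:.3f})')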