import os

import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch.nn.functional import normalize, pad, softmax

from manipulate_model.model import Model


def get_config_and_model(model_root="manipulate_model/demo-model/audio"):
    """Load the demo config and the pretrained manipulation-detection model.

    Parameters
    ----------
    model_root : str
        Directory containing config.yaml (and paths to encoder/decoder configs).

    Returns
    -------
    tuple
        (config, model)
    """
    config_path = os.path.join(model_root, "config.yaml")
    config = OmegaConf.load(config_path)

    # Encoder/decoder entries may be paths to separate config files.
    if isinstance(config.model.encoder, str):
        config.model.encoder = OmegaConf.load(config.model.encoder)
    if isinstance(config.model.decoder, str):
        config.model.decoder = OmegaConf.load(config.model.decoder)

    model = Model(config=config)
    model_file = hf_hub_download(
        "arnabdas8901/manipulation_detection_transformer", filename="weights.pt"
    )
    weights = torch.load(model_file, map_location=torch.device("cpu"))
    model.load_state_dict(weights["model_state_dict"])
    print("### Model loaded from:", model_file)

    return config, model


def load_audio(file_path, config):
    """Load an audio file and return preprocessed windows.

    Parameters
    ----------
    file_path : str
        Path to the audio file (.wav and .flac are supported; .mp3/.mp4 are not handled yet).
    config : OmegaConf
        Configuration object providing config.data.sr.

    Returns
    -------
    torch.Tensor
        Stack of normalized audio windows from preprocess_audio.
    """
    audio = None

    if file_path.endswith(".wav") or file_path.endswith(".flac"):
        audio, sample_rate = torchaudio.load(file_path)
        if sample_rate != config.data.sr:
            print("### Resampling from", sample_rate, "Hz to", config.data.sr, "Hz")
            audio = torchaudio.functional.resample(audio, sample_rate, config.data.sr)
    elif file_path.endswith(".mp3"):
        pass  # not supported yet
    elif file_path.endswith(".mp4"):
        # _, audio, _ = read_video(file_path)
        pass  # not supported yet

    if audio is None:
        raise ValueError(f"Unsupported audio format: {file_path}")

    return preprocess_audio(audio, config)


def preprocess_audio(audio, config, step_size=1):
    """Slice an audio signal into overlapping, normalized windows.

    Parameters
    ----------
    audio : torch.Tensor
        Audio signal of shape (channels, samples).
    config : OmegaConf
        Configuration object providing config.data.window_size, config.data.sr
        and config.data.fps.
    step_size : int
        Step between consecutive windows, in frames.

    Returns
    -------
    torch.Tensor
        Stack of normalized audio windows.
    """
    window_size = config.data.window_size
    sr = config.data.sr
    fps = config.data.fps

    if audio.shape[0] > 1:
        print("Warning: multi-channel audio, keeping only the first channel")
        audio = audio[0].unsqueeze(0)

    audio_len = audio.shape[1]
    step_size = step_size * (sr // fps)      # frames -> samples
    window_size = window_size * (sr // fps)  # frames -> samples
    audio = pad(audio, (window_size, window_size), "constant", 0)

    sliced_audio = []
    for i in range(0, audio_len + window_size, step_size):
        audio_slice = audio[:, i : i + window_size]
        if audio_slice.shape[1] < window_size:
            audio_slice = pad(
                audio_slice, (0, window_size - audio_slice.shape[1]), "constant", 0
            )
        audio_slice = normalize(audio_slice, dim=1)
        sliced_audio.append(audio_slice)

    return torch.stack(sliced_audio).squeeze()


def infere(model, x, config, device="cpu", bs=8):
    """Run windowed inference on an audio file and return per-window scores."""
    print("### Running inference on:", x)
    model.eval()
    x = load_audio(x, config)

    # x is a stack of windows; process it in batches of size bs.
    frame_predictions = []
    with torch.no_grad():
        n_iter = x.shape[0]
        for i in range(0, n_iter, bs):
            input_batch = x[i : i + bs].to(device)
            output = softmax(model(input_batch), dim=1)
            frame_predictions.append(output.cpu().numpy())

    # Keep only the probability of class 0 for each window.
    frame_predictions = np.concatenate(frame_predictions, axis=0)[:, 0]
    return frame_predictions


def convert_frame_predictions_to_timestamps(frame_predictions, fps, window_size):
    """Convert frame predictions to timestamps.

    Parameters
    ----------
    frame_predictions : np.ndarray
        1-D array of per-frame class probabilities, as returned by infere.
    fps : int
        Frames per second.
    window_size : int
        Window size in frames, used to strip the padding added during preprocessing.

    Returns
    -------
    list
        Timestamps (in seconds) of frames predicted as positive.
    """
    frame_predictions = (
        frame_predictions[
            int(window_size / 2) : -int(window_size / 2)
        ]  # removes the padding; does not consider step size as of now
        .round()
        .astype(int)
    )

    timestamps = []
    for i, frame_prediction in enumerate(frame_predictions):
        if frame_prediction == 1:
            timestamps.append(i / fps)
    return timestamps
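

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module).
# Assumptions: a local file "sample.wav" (hypothetical path), the demo config
# shipped under manipulate_model/demo-model/audio, and that class index 0
# (the probability selected in infere) corresponds to the manipulated class.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config, model = get_config_and_model()

    # Per-window probability of class 0, one value per frame step.
    predictions = infere(model, "sample.wav", config, device="cpu", bs=8)

    # Map rounded positive frames back to timestamps (seconds).
    timestamps = convert_frame_predictions_to_timestamps(
        predictions, config.data.fps, config.data.window_size
    )
    print("Frames flagged as manipulated (seconds):", timestamps)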