import os
import torch
import torchaudio
import numpy as np
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from torch.nn.functional import pad, normalize, softmax
from manipulate_model.model import Model
def get_config_and_model(model_root="manipulate_model/demo-model/audio"):
    """Load the model config from model_root and its pretrained weights from the Hugging Face Hub."""
config_path = os.path.join(model_root, "config.yaml")
config = OmegaConf.load(config_path)
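    # The encoder/decoder entries may be paths to separate config files; load them if so.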
if isinstance(config.model.encoder, str):
config.model.encoder = OmegaConf.load(config.model.encoder)
if isinstance(config.model.decoder, str):
config.model.decoder = OmegaConf.load(config.model.decoder)
model = Model(config=config)
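    # Download the pretrained weights from the Hugging Face Hub (reuses the local cache if present).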
    model_file = hf_hub_download("arnabdas8901/manipulation_detection_transformer", filename="weights.pt")
weights = torch.load(model_file, map_location=torch.device("cpu"))
model.load_state_dict(weights["model_state_dict"])
    print("### Model loaded from:", model_file)
return config, model
def load_audio(file_path, config):
    """Load an audio file, resample it to the configured rate, and window it.

    Parameters
    ----------
    file_path : str
        Path to the audio file (.wav or .flac; .mp3/.mp4 handling is not implemented).
    config : OmegaConf
        Configuration object providing data.sr, data.fps and data.window_size.

    Returns
    -------
    torch.Tensor
        Stack of normalised audio windows (see preprocess_audio).
    """
    audio = None
    if file_path.endswith(".wav") or file_path.endswith(".flac"):
        audio, sample_rate = torchaudio.load(file_path)
        if sample_rate != config.data.sr:
            print("requires resampling")
            audio = torchaudio.functional.resample(audio, sample_rate, config.data.sr)
    elif file_path.endswith(".mp3"):
        pass
    elif file_path.endswith(".mp4"):
        # _, audio, _ = read_video(file_path)
        pass
    if audio is None:
        # .mp3/.mp4 handling is not implemented; fail early instead of passing None on.
        raise ValueError(f"Unsupported or unimplemented audio format: {file_path}")
    return preprocess_audio(audio, config)
def preprocess_audio(audio, config, step_size=1):
    """Cut the audio signal into overlapping, L2-normalised windows.

    Parameters
    ----------
    audio : torch.Tensor
        Audio signal of shape (channels, samples).
    config : OmegaConf
        Configuration object providing data.window_size, data.sr and data.fps.
    step_size : int
        Hop between consecutive windows, in frames.

    Returns
    -------
    torch.Tensor
        Stack of normalised windows, one row per window position.
    """
window_size = config.data.window_size
sr = config.data.sr
fps = config.data.fps
if audio.shape[0] > 1:
print("Warning: multi channel audio")
audio = audio[0].unsqueeze(0)
audio_len = audio.shape[1]
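    # step_size and window_size are given in frames; convert them to audio samples (sr // fps samples per frame).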
step_size = step_size * (sr // fps)
window_size = window_size * (sr // fps)
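    # Zero-pad window_size samples on both sides of the signal; the corresponding predictions are trimmed again in convert_frame_predictions_to_timestamps.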
audio = pad(audio, (window_size, window_size), "constant", 0)
sliced_audio = []
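    # Slide a window over the padded signal in steps of step_size, zero-padding the final slice and L2-normalising each window.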
for i in range(0, audio_len + window_size, step_size):
audio_slice = audio[:, i : i + window_size]
if audio_slice.shape[1] < window_size:
audio_slice = pad(
audio_slice, (0, window_size - audio_slice.shape[1]), "constant", 0
)
audio_slice = normalize(audio_slice, dim=1)
sliced_audio.append(audio_slice)
sliced_audio = torch.stack(sliced_audio).squeeze()
return sliced_audio
def infere(model, x, config, device="cpu", bs=8):
    """Run windowed inference on the audio file at path x and return per-window class-0 probabilities."""
    print("### Running inference on:", x)
    model.eval()
    model = model.to(device)
    x = load_audio(x, config)
    # Inference (x is a stack of windows)
frame_predictions = []
with torch.no_grad():
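        # Process the stack of windows in mini-batches of size bs.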
n_iter = x.shape[0]
for i in range(0, n_iter, bs):
input_batch = x[i: i + bs]
input_batch = input_batch.to(device)
output = softmax(model(input_batch), dim=1)
frame_predictions.append(output.cpu().numpy())
frame_predictions = np.concatenate(frame_predictions, axis=0)[:,0]
return frame_predictions
def convert_frame_predictions_to_timestamps(frame_predictions, fps, window_size):
    """Convert per-frame predictions into timestamps of flagged frames.

    Parameters
    ----------
    frame_predictions : np.ndarray
        Per-frame scores as returned by infere.
    fps : int
        Frames per second.
    window_size : int
        Window size (in frames) used during preprocessing; its padding is trimmed here.

    Returns
    -------
    list of float
        Timestamps in seconds of frames whose prediction rounds to 1.
    """
    # infere already returns the class-0 column, so frame_predictions is 1-D here
    frame_predictions = (
        frame_predictions[
            int(window_size / 2) : -int(window_size / 2)
        ]  # removes the padding, does not consider step size as of now
        .round()
        .astype(int)
    )
timestamps = []
for i, frame_prediction in enumerate(frame_predictions):
if frame_prediction == 1:
timestamps.append(i / fps)
return timestamps
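

if __name__ == "__main__":
    # Example usage sketch: "sample.wav" is a placeholder path, and the default
    # model_root under manipulate_model/demo-model/audio is assumed to exist locally.
    config, model = get_config_and_model()
    frame_predictions = infere(model, "sample.wav", config)
    timestamps = convert_frame_predictions_to_timestamps(
        frame_predictions, config.data.fps, config.data.window_size
    )
    print("Flagged timestamps (seconds):", timestamps)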