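"""Utilities for the audio manipulation-detection demo: model loading,
audio windowing, batched inference, and timestamp conversion."""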
import os
import torch
import torchaudio
import numpy as np
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from torch.nn.functional import pad, normalize, softmax

from manipulate_model.model import Model
def get_config_and_model(model_root="manipulate_model/demo-model/audio"):
    """Load the demo config and model, pulling the weights from the Hugging Face Hub."""
    config_path = os.path.join(model_root, "config.yaml")
    config = OmegaConf.load(config_path)

    # Encoder/decoder configs may be given as paths to separate YAML files.
    if isinstance(config.model.encoder, str):
        config.model.encoder = OmegaConf.load(config.model.encoder)
    if isinstance(config.model.decoder, str):
        config.model.decoder = OmegaConf.load(config.model.decoder)

    model = Model(config=config)
    model_file = hf_hub_download(
        "arnabdas8901/manipulation_detection_transformer", filename="weights.pt"
    )
    weights = torch.load(model_file, map_location=torch.device("cpu"))
    model.load_state_dict(weights["model_state_dict"])
    print("### Model loaded from:", model_file)

    return config, model
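# Usage sketch (assumption: the default model_root shipped with the demo
# contains config.yaml, and the Hub repo id above is reachable):
#
#     config, model = get_config_and_model()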
def load_audio(file_path, config):
    """Load an audio file and preprocess it into model-ready windows.

    Parameters
    ----------
    file_path : str
        Path to the audio file.
    config : OmegaConf
        Configuration object (provides the target sample rate).

    Returns
    -------
    torch.Tensor
        Stack of normalized audio windows.
    """
    audio = None
    if file_path.endswith(".wav") or file_path.endswith(".flac"):
        audio, sample_rate = torchaudio.load(file_path)
        if sample_rate != config.data.sr:
            print("Resampling from", sample_rate, "to", config.data.sr)
            audio = torchaudio.functional.resample(audio, sample_rate, config.data.sr)
    elif file_path.endswith(".mp3"):
        pass  # TODO: mp3 support
    elif file_path.endswith(".mp4"):
        # _, audio, _ = read_video(file_path)
        pass  # TODO: mp4 support
    if audio is None:
        raise ValueError(f"Unsupported audio format: {file_path}")
    return preprocess_audio(audio, config)
def preprocess_audio(audio, config, step_size=1):
    """Slice an audio signal into fixed-size, normalized windows.

    Parameters
    ----------
    audio : torch.Tensor
        Audio signal of shape (channels, samples).
    config : OmegaConf
        Configuration object (provides window size, sample rate, and fps).
    step_size : int
        Hop between consecutive windows, in frames.

    Returns
    -------
    torch.Tensor
        Stack of normalized audio windows.
    """
    window_size = config.data.window_size
    sr = config.data.sr
    fps = config.data.fps

    if audio.shape[0] > 1:
        print("Warning: multi-channel audio; keeping only the first channel")
        audio = audio[0].unsqueeze(0)

    audio_len = audio.shape[1]

    # Convert step and window sizes from frames to samples.
    step_size = step_size * (sr // fps)
    window_size = window_size * (sr // fps)

    # Pad so that every original sample is covered by a full window.
    audio = pad(audio, (window_size, window_size), "constant", 0)

    sliced_audio = []
    for i in range(0, audio_len + window_size, step_size):
        audio_slice = audio[:, i : i + window_size]
        if audio_slice.shape[1] < window_size:
            audio_slice = pad(
                audio_slice, (0, window_size - audio_slice.shape[1]), "constant", 0
            )
        audio_slice = normalize(audio_slice, dim=1)  # L2-normalize each window
        sliced_audio.append(audio_slice)

    sliced_audio = torch.stack(sliced_audio).squeeze()
    return sliced_audio
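# Worked example of the frame-to-sample arithmetic above (the concrete values
# are assumptions for illustration, not taken from the shipped config): with
# sr = 16000, fps = 25, and window_size = 25 frames,
#
#     step_size   = 1  * (16000 // 25) = 640 samples
#     window_size = 25 * (16000 // 25) = 16000 samples
#
# so each window covers one second of signal and consecutive windows overlap
# by 15360 samples.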
def infere(model, x, config, device="cpu", bs=8):
    """Run windowed inference on an audio file and return per-window scores."""
    model.eval()
    x = load_audio(x, config)  # x is a stack of windows

    frame_predictions = []
    with torch.no_grad():
        n_iter = x.shape[0]
        for i in range(0, n_iter, bs):
            input_batch = x[i : i + bs].to(device)
            output = softmax(model(input_batch), dim=1)
            frame_predictions.append(output.cpu().numpy())

    # Keep only the probability of class 0 for each window.
    frame_predictions = np.concatenate(frame_predictions, axis=0)[:, 0]
    return frame_predictions
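# Note: the scores above are the softmax column for class index 0. Whether
# index 0 means "manipulated" or "bona fide" depends on how the checkpoint was
# trained; this file does not pin that down, so treat the mapping as an
# assumption to verify against the model card.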
def convert_frame_predictions_to_timestamps(frame_predictions, fps, window_size):
    """Convert per-frame predictions to timestamps of detected frames.

    Parameters
    ----------
    frame_predictions : np.ndarray
        1-D array of per-frame scores, as returned by `infere`.
    fps : int
        Frames per second.
    window_size : int
        Window size in frames, used to strip the padding added during
        preprocessing.

    Returns
    -------
    list of float
        Timestamps (in seconds) of frames whose rounded prediction is 1.
    """
    # `infere` already selects a single score per window, so the array is 1-D;
    # slice only along that axis.
    frame_predictions = (
        frame_predictions[
            int(window_size / 2) : -int(window_size / 2)
        ]  # removes the padding; does not consider step size as of now
        .round()
        .astype(int)
    )

    timestamps = []
    for i, frame_prediction in enumerate(frame_predictions):
        if frame_prediction == 1:
            timestamps.append(i / fps)

    return timestamps
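# End-to-end usage sketch. The input file name is an assumption; fps and
# window_size are read from the same config fields the preprocessing uses.
if __name__ == "__main__":
    config, model = get_config_and_model()
    scores = infere(model, "sample.wav", config, device="cpu")
    stamps = convert_frame_predictions_to_timestamps(
        scores, config.data.fps, config.data.window_size
    )
    print("Flagged timestamps (s):", stamps)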