# OneLLM/data/imu_utils.py
import json
import os
import string
from bisect import bisect_left
from collections import defaultdict

import matplotlib.animation as animation
import numpy as np
import torch
import torchaudio
from matplotlib import pyplot as plt

# use the sox_io backend for the torchaudio.info / torchaudio.load calls below
torchaudio.set_audio_backend("sox_io")


def load_json(json_path: str):
"""
Load a json file
"""
with open(json_path, "r", encoding="utf-8") as f_name:
data = json.load(f_name)
return data


def check_window_signal(info_t, w_s, w_e):
    """Return True if the window [w_s, w_e] (in seconds) fits inside the signal."""
    length = w_e - w_s
    frame_offset = int(w_s * info_t.sample_rate)
    num_frames = int(length * info_t.sample_rate)
    return frame_offset + num_frames <= int(info_t.num_frames)


def index_narrations(ann_path):
    """
    Index narration annotations by video uid.

    Returns (narration_dict, summary_dict): narration_dict maps each video uid to a
    list of (timestamp_sec, narration_text, annotation_uid, timestamp_frame) tuples,
    summary_dict to a list of (start_sec, end_sec, summary_text) tuples.
    """
narration_raw = load_json(ann_path)
narration_dict = defaultdict(list)
summary_dict = defaultdict(list)
avg_len = []
for v_id, narr in narration_raw.items():
narr_list = []
summ_list = []
if "narration_pass_1" in narr:
narr_list += narr["narration_pass_1"]["narrations"]
summ_list += narr["narration_pass_1"]["summaries"]
if "narration_pass_2" in narr:
narr_list += narr["narration_pass_2"]["narrations"]
summ_list += narr["narration_pass_2"]["summaries"]
if len(narr_list) > 0:
narration_dict[v_id] = [
(
float(n_t["timestamp_sec"]),
n_t["narration_text"],
n_t["annotation_uid"],
n_t["timestamp_frame"],
)
for n_t in narr_list
]
avg_len.append(len(narration_dict[v_id]))
else:
narration_dict[v_id] = []
if len(summ_list) > 0:
summary_dict[v_id] = [
(
float(s_t["start_sec"]),
float(s_t["end_sec"]),
s_t["summary_text"],
)
for s_t in summ_list
]
else:
summary_dict[v_id] = []
# print(f"Number of Videos with narration {len(narration_dict)}")
# print(f"Avg. narration length {np.mean(avg_len)}")
# print(f"Number of Videos with summaries {len(summary_dict)}")
return narration_dict, summary_dict
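
# Example usage (illustrative sketch; "narration.json" is a placeholder for an
# annotation file with the per-video "narration_pass_1" / "narration_pass_2"
# structure read above):
#
#     narration_dict, summary_dict = index_narrations("narration.json")
#     for v_id, narrs in narration_dict.items():
#         for ts_sec, text, ann_uid, ts_frame in narrs:
#             print(v_id, ts_sec, text)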


def get_signal_info(signal_fn: str):
    """Return torchaudio metadata (sample rate, number of frames, ...) for the file."""
    return torchaudio.info(signal_fn)


def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float):
    """
    Given a signal track, return the frames between video_start_sec and video_end_sec.
    """
info_t = get_signal_info(signal_fn)
length = video_end_sec - video_start_sec
aframes, _ = torchaudio.load(
signal_fn,
normalize=True,
frame_offset=int(video_start_sec * info_t.sample_rate),
num_frames=int(length * info_t.sample_rate),
)
return {"signal": aframes, "meta": info_t}
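
# Example usage (sketch; "signal.wav" is a hypothetical audio file path):
#
#     out = get_signal_frames("signal.wav", video_start_sec=2.0, video_end_sec=5.0)
#     out["signal"]            # tensor of shape (channels, about 3 * sample_rate)
#     out["meta"].sample_rate  # metadata from torchaudio.info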


def tosec(value):
    """Convert milliseconds to seconds."""
    return value / 1000


def toms(value):
    """Convert seconds to milliseconds."""
    return value * 1000


def delta(first_num: float, second_num: float):
"""Compute the absolute value of the difference of two numbers"""
return abs(first_num - second_num)


def padIMU(signal, duration_sec):
    """
    Pad with zeros (or truncate) the signal so that it holds exactly
    round(duration_sec) * 200 samples of 6-axis data (200 Hz).
    """
    expected_elements = round(duration_sec) * 200
    if signal.shape[0] > expected_elements:
        signal = signal[:expected_elements, :]
    elif signal.shape[0] < expected_elements:
        padding = expected_elements - signal.shape[0]
        padded_zeros = np.zeros((padding, 6))
        signal = np.concatenate([signal, padded_zeros], 0)
    return signal
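
# Worked example (sketch): at 200 Hz, a 2 s window should hold round(2.0) * 200 = 400
# samples, so a (380, 6) array is zero-padded to (400, 6) and a (420, 6) array is
# truncated to (400, 6).
#
#     x = np.zeros((380, 6))
#     padIMU(x, 2.0).shape  # (400, 6)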


def resample(
    signals: np.ndarray,
    timestamps: np.ndarray,
    original_sample_rate: int,
    resample_rate: int,
):
    """
    Resample the signals to resample_rate and regenerate evenly spaced timestamps
    (in milliseconds) starting from the original first timestamp.
    """
signals = torch.as_tensor(signals)
timestamps = torch.from_numpy(timestamps).unsqueeze(-1)
signals = torchaudio.functional.resample(
waveform=signals.data.T,
orig_freq=original_sample_rate,
new_freq=resample_rate,
).T.numpy()
nsamples = len(signals)
period = 1 / resample_rate
# timestamps are expected to be shape (N, 1)
    initial_seconds = timestamps[0] / 1e3
    ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initial_seconds
timestamps = (ntimes * 1e3).squeeze().numpy()
return signals, timestamps
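
# Worked example (sketch): 100 Hz data resampled to 200 Hz doubles the number of
# samples, and timestamps are regenerated on an even 5 ms grid starting at the
# original first timestamp (all timestamps in milliseconds).
#
#     sig = np.random.randn(500, 6)          # 5 s of 6-axis data at 100 Hz
#     ts = np.arange(500) * 10.0 + 1000.0    # ms, 10 ms apart
#     sig200, ts200 = resample(sig, ts, 100, 200)
#     sig200.shape         # (1000, 6)
#     ts200[1] - ts200[0]  # 5.0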


def resampleIMU(signal, timestamps):
    """Resample the IMU signal to 200 Hz if it is not already at that rate."""
    # timestamps are in milliseconds, so the rate is 1000 / mean inter-sample gap
    sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
    # resample everything to 200 Hz
    if sampling_rate != 200:
        signal, timestamps = resample(signal, timestamps, sampling_rate, 200)
    return signal, timestamps


def get_imu_frames(
    imu_path,
    uid: str,
    video_start_sec: float,
    video_end_sec: float,
):
    """
    Given an IMU signal, return the frames between video_start_sec and video_end_sec.
    """
signal = np.load(os.path.join(imu_path, f"{uid}.npy"))
signal = signal.transpose()
timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy"))
if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]:
return None
start_id = bisect_left(timestamps, toms(video_start_sec))
end_id = bisect_left(timestamps, toms(video_end_sec))
    # make sure the retrieved window interval is correct, within a max 4-second margin
if (
delta(video_start_sec, tosec(timestamps[start_id])) > 4
or delta(video_end_sec, tosec(timestamps[end_id])) > 4
):
return None
# get the window
if start_id == end_id:
start_id -= 1
end_id += 1
signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id]
if len(signal) < 10 or len(timestamps) < 10:
return None
    # resample the signal at 200 Hz if necessary
signal, timestamps = resampleIMU(signal, timestamps)
# pad the signal if necessary
signal = padIMU(signal, video_end_sec - video_start_sec)
sample_dict = {
"timestamp": timestamps,
"signal": torch.tensor(signal.T),
"sampling_rate": 200,
}
return sample_dict
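
# Example usage (sketch; the path and uid are placeholders -- the function expects
# "<imu_path>/<uid>.npy" with shape (6, T) and "<imu_path>/<uid>_timestamps.npy"
# with millisecond timestamps):
#
#     sample = get_imu_frames("processed_imu", "some-video-uid", 4.0, 9.0)
#     if sample is not None:
#         sample["signal"].shape   # (6, 5 * 200) after resampling and padding
#         sample["sampling_rate"]  # 200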


def display_animation(frames, title, save_path_gif):
    """Save a sequence of image frames as a GIF animation."""
fig, ax = plt.subplots()
frames = [[ax.imshow(frames[i])] for i in range(len(frames))]
plt.title(title)
ani = animation.ArtistAnimation(fig, frames)
ani.save(save_path_gif, writer="imagemagick")
plt.close()


def display_animation_imu(frames, imu, title, save_path_gif):
    """Save video frames together with accelerometer (imu[0:3]) and gyroscope (imu[3:6]) traces as a GIF."""
fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
ax1.set_title(title)
ax2.set_title("Acc.")
ax3.set_title("Gyro.")
frames = [[ax1.imshow(frames[i])] for i in range(len(frames))]
ani = animation.ArtistAnimation(fig, frames)
ax2.plot(imu[0].cpu().numpy(), color="red")
ax2.plot(imu[1].cpu().numpy(), color="blue")
ax2.plot(imu[2].cpu().numpy(), color="green")
ax3.plot(imu[3].cpu().numpy(), color="red")
ax3.plot(imu[4].cpu().numpy(), color="blue")
ax3.plot(imu[5].cpu().numpy(), color="green")
plt.tight_layout()
ani.save(save_path_gif, writer="imagemagick")
plt.close()
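
# Example usage (sketch; "walking.gif" is a placeholder output path, frames is a list
# of image arrays, and the IMU tensor comes from get_imu_frames above; the
# "imagemagick" writer requires ImageMagick to be installed):
#
#     display_animation_imu(frames, sample["signal"], "walking", "walking.gif")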


def filter_narration(narration_text: str) -> bool:
    """Return True if the narration text contains the '#C' (camera wearer) tag."""
    return "#c" in narration_text.lower()


def clean_narration_text(narration_text: str) -> str:
    """Strip narration tags, drop surrounding punctuation, lowercase, and cap at 128 characters."""
return (
narration_text.replace("#C C ", "")
.replace("#C", "")
.replace("#unsure", "something")
.strip()
.strip(string.punctuation)
.lower()[:128]
)
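
# Example (sketch): a typical tagged narration such as "#C C picks up a cup." passes
# filter_narration and is normalized by clean_narration_text:
#
#     filter_narration("#C C picks up a cup.")      # True
#     clean_narration_text("#C C picks up a cup.")  # "picks up a cup"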