atsushieee's picture
Upload folder using huggingface_hub
29e2c21 verified
"""Module for handling audio input through Gradio interface."""
from typing import Callable
import numpy as np
from scipy import signal
from improvisation_lab.infrastructure.audio.audio_processor import \
AudioProcessor
class WebAudioProcessor(AudioProcessor):
"""Handle audio input from Gradio interface."""
def __init__(
self,
sample_rate: int,
callback: Callable[[np.ndarray], None] | None = None,
buffer_duration: float = 0.3,
):
"""Initialize GradioAudioInput.
Args:
sample_rate: Audio sample rate in Hz
callback: Optional callback function to process audio data
buffer_duration: Duration of audio buffer in seconds
"""
super().__init__(sample_rate, callback, buffer_duration)
def _resample_audio(
self, audio_data: np.ndarray, original_sr: int, target_sr: int
) -> np.ndarray:
"""Resample audio data to target sample rate.
In the case of Gradio,
the sample rate of the audio data may not match the target sample rate.
Args:
audio_data: numpy array of audio samples
original_sr: Original sample rate in Hz
target_sr: Target sample rate in Hz
Returns:
Resampled audio data with target sample rate
"""
number_of_samples = round(len(audio_data) * float(target_sr) / original_sr)
resampled_data = signal.resample(audio_data, number_of_samples)
return resampled_data
def _normalize_audio(self, audio_data: np.ndarray) -> np.ndarray:
"""Normalize audio data to range [-1, 1] by dividing by maximum absolute value.
Args:
audio_data: numpy array of audio samples
Returns:
Normalized audio data with values between -1 and 1
"""
if len(audio_data) == 0:
return audio_data
max_abs = np.max(np.abs(audio_data))
return audio_data if max_abs == 0 else audio_data / max_abs
def _remove_low_amplitude_noise(self, audio_data: np.ndarray) -> np.ndarray:
"""Remove low amplitude noise from audio data.
Applies a threshold to remove low amplitude signals that are likely noise.
Args:
audio_data: Audio data as numpy array
Returns:
Audio data with low amplitude noise removed
"""
# [TODO] Set appropriate threshold
threshold = 20.0
audio_data[np.abs(audio_data) < threshold] = 0
return audio_data
def process_audio(self, audio_input: tuple[int, np.ndarray]) -> None:
"""Process incoming audio data from Gradio.
Args:
audio_input: Tuple of (sample_rate, audio_data)
where audio_data is a (samples, channels) array
"""
if not self.is_recording:
return
input_sample_rate, audio_data = audio_input
if input_sample_rate != self.sample_rate:
audio_data = self._resample_audio(
audio_data, input_sample_rate, self.sample_rate
)
audio_data = self._remove_low_amplitude_noise(audio_data)
audio_data = self._normalize_audio(audio_data)
self._append_to_buffer(audio_data)
self._process_buffer()
def start_recording(self):
"""Start accepting audio input from Gradio."""
if self.is_recording:
raise RuntimeError("Recording is already in progress")
self.is_recording = True
def stop_recording(self):
"""Stop accepting audio input from Gradio."""
if not self.is_recording:
raise RuntimeError("Recording is not in progress")
self.is_recording = False
self._buffer = np.array([], dtype=np.float32)