import streamlit as st
import torchaudio
import torch
import librosa
import librosa.display
import matplotlib.pyplot as plt
from semanticodec import SemantiCodec
import numpy as np
import tempfile
import os
# Set default parameters
DEFAULT_TOKEN_RATE = 100
DEFAULT_SEMANTIC_VOCAB_SIZE = 16384
DEFAULT_SAMPLE_RATE = 16000
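# Note: DEFAULT_TOKEN_RATE and DEFAULT_SEMANTIC_VOCAB_SIZE mirror the initial
# sidebar selections below (index=2 -> 100 tokens/sec and a 16384-entry semantic
# vocabulary); the codec is configured from the sidebar widgets, while
# DEFAULT_SAMPLE_RATE is the 16 kHz rate the audio is resampled to before encoding.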
device = "cuda" if torch.cuda.is_available() else "cpu"
# Title and Description
st.title("SemantiCodec: Ultra-Low Bitrate Neural Audio Codec")
st.write("""
Upload your audio file, adjust the codec parameters, and compare the original and reconstructed audio.
SemantiCodec achieves high-quality audio reconstruction with ultra-low bitrates!
""")
# Sidebar: Parameters
st.sidebar.title("Codec Parameters")
token_rate = st.sidebar.selectbox("Token Rate (tokens/sec)", [25, 50, 100], index=2)
semantic_vocab_size = st.sidebar.selectbox(
"Semantic Vocabulary Size",
[4096, 8192, 16384, 32768],
index=2,
)
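# Decoder sampling controls: more DDIM steps generally trade inference speed for
# reconstruction quality, and the classifier-free guidance (CFG) scale sets how
# strongly the diffusion decoder is steered by the encoded tokens.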
ddim_steps = st.sidebar.slider("DDIM Sampling Steps", 10, 100, 50, step=5)
guidance_scale = st.sidebar.slider("CFG Guidance Scale", 0.5, 5.0, 2.0, step=0.1)
# Upload Audio File
uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=["wav"])
# Helper function: Plot spectrogram
def plot_spectrogram(waveform, sample_rate, title):
    fig = plt.figure(figsize=(10, 4))
    S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=128, fmax=sample_rate // 2)
    S_dB = librosa.power_to_db(S, ref=np.max)
    librosa.display.specshow(S_dB, sr=sample_rate, x_axis='time', y_axis='mel', cmap='viridis')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.tight_layout()
    st.pyplot(fig)
    plt.close(fig)
# Process Audio
if uploaded_file and st.button("Run SemantiCodec"):
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save uploaded file
        input_path = os.path.join(temp_dir, "input.wav")
        with open(input_path, "wb") as f:
            f.write(uploaded_file.read())
        # Load audio
        waveform, sample_rate = torchaudio.load(input_path)
        # Check if resampling is needed
        if sample_rate != DEFAULT_SAMPLE_RATE:
            st.write(f"Resampling audio from {sample_rate} Hz to {DEFAULT_SAMPLE_RATE} Hz...")
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=DEFAULT_SAMPLE_RATE)
            waveform = resampler(waveform)
            sample_rate = DEFAULT_SAMPLE_RATE  # Update sample rate to 16 kHz
        # Convert to numpy for librosa compatibility (first channel only)
        waveform = waveform[0].numpy()
        # Plot Original Spectrogram (16 kHz resampled)
        st.write("Original Audio Spectrogram (Resampled to 16 kHz):")
        plot_spectrogram(waveform, sample_rate, "Original Audio Spectrogram (Resampled to 16 kHz)")
        # Initialize SemantiCodec
        st.write("Initializing SemantiCodec...")
        semanticodec = SemantiCodec(
            token_rate=token_rate,
            semantic_vocab_size=semantic_vocab_size,
            ddim_sample_step=ddim_steps,
            cfg_scale=guidance_scale,
        )
        semanticodec.device = device
        semanticodec.encoder = semanticodec.encoder.to(device)
        semanticodec.decoder = semanticodec.decoder.to(device)
        # Encode and Decode
        st.write("Encoding and Decoding Audio...")
        tokens = semanticodec.encode(input_path)
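        # decode() returns an array indexed as [batch, channel, samples];
        # [0, 0] below selects the single mono waveform.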
        reconstructed_waveform = semanticodec.decode(tokens)[0, 0]
        # Save reconstructed audio
        reconstructed_path = os.path.join(temp_dir, "reconstructed.wav")
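        # torchaudio.save expects a [channels, frames] tensor, hence the extra
        # list wrapping around the 1-D reconstructed waveform.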
        torchaudio.save(reconstructed_path, torch.tensor([reconstructed_waveform]), sample_rate)
        # Plot Reconstructed Spectrogram
        st.write("Reconstructed Audio Spectrogram:")
        plot_spectrogram(reconstructed_waveform, sample_rate, "Reconstructed Audio Spectrogram")
        # Display latent code shape
        st.write(f"Shape of Latent Code: {tokens.shape}")
        # Audio Players
        st.audio(input_path, format="audio/wav")
        st.write("Original Audio")
        st.audio(reconstructed_path, format="audio/wav")
        st.write("Reconstructed Audio")
        # Download Button for Reconstructed Audio
        with open(reconstructed_path, "rb") as f:
            reconstructed_bytes = f.read()
        st.download_button(
            "Download Reconstructed Audio",
            data=reconstructed_bytes,
            file_name="reconstructed_audio.wav",
        )
# Footer
st.write("Built with [Streamlit](https://streamlit.io) and SemantiCodec")