File size: 4,467 Bytes
1f34ab8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import streamlit as st
import torchaudio
import torch
import librosa
import librosa.display
import matplotlib.pyplot as plt
from semanticodec import SemantiCodec
import numpy as np
import tempfile
import os
# Set default parameters
DEFAULT_TOKEN_RATE = 100
DEFAULT_SEMANTIC_VOCAB_SIZE = 16384
DEFAULT_SAMPLE_RATE = 16000
DEFAULT_DDIM_STEPS = 50
DEFAULT_CFG_SCALE = 2.0
# Choices offered in the sidebar. The matching DEFAULT_* value above must be a
# member of each list, because the selectbox's initial index is derived from it.
TOKEN_RATE_OPTIONS = [25, 50, 100]
SEMANTIC_VOCAB_SIZE_OPTIONS = [4096, 8192, 16384, 32768]
# Run the codec on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Title and Description
st.title("SemantiCodec: Ultra-Low Bitrate Neural Audio Codec")
st.write("""
Upload your audio file, adjust the codec parameters, and compare the original and reconstructed audio.
SemantiCodec achieves high-quality audio reconstruction with ultra-low bitrates!
""")
# Sidebar: Parameters
st.sidebar.title("Codec Parameters")
token_rate = st.sidebar.selectbox(
    "Token Rate (tokens/sec)",
    TOKEN_RATE_OPTIONS,
    index=TOKEN_RATE_OPTIONS.index(DEFAULT_TOKEN_RATE),
)
semantic_vocab_size = st.sidebar.selectbox(
    "Semantic Vocabulary Size",
    SEMANTIC_VOCAB_SIZE_OPTIONS,
    index=SEMANTIC_VOCAB_SIZE_OPTIONS.index(DEFAULT_SEMANTIC_VOCAB_SIZE),
)
ddim_steps = st.sidebar.slider("DDIM Sampling Steps", 10, 100, DEFAULT_DDIM_STEPS, step=5)
guidance_scale = st.sidebar.slider("CFG Guidance Scale", 0.5, 5.0, DEFAULT_CFG_SCALE, step=0.1)
# Upload Audio File
uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=["wav"])
# Helper function: Plot spectrogram
def plot_spectrogram(waveform, sample_rate, title):
    """Render a 128-band mel spectrogram of ``waveform`` into the Streamlit page.

    Args:
        waveform: Audio samples; anything ``np.asarray`` accepts (NumPy array,
            CPU torch tensor). Callers in this app pass a single channel.
        sample_rate: Sampling rate of ``waveform`` in Hz.
        title: Title drawn above the plot.
    """
    # librosa requires a floating-point NumPy array.
    samples = np.asarray(waveform, dtype=np.float32)
    # Nyquist frequency: the top edge of the mel filter bank.
    fmax = sample_rate // 2
    mel = librosa.feature.melspectrogram(y=samples, sr=sample_rate, n_mels=128, fmax=fmax)
    # Convert power to dB relative to the loudest bin, so the colour scale peaks at 0 dB.
    mel_db = librosa.power_to_db(mel, ref=np.max)

    # Use an explicit figure rather than pyplot's global state, and hand that
    # figure to st.pyplot (passing the pyplot module itself is deprecated).
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        img = librosa.display.specshow(
            mel_db,
            sr=sample_rate,
            x_axis="time",
            y_axis="mel",
            fmax=fmax,
            cmap="viridis",
            ax=ax,
        )
        fig.colorbar(img, ax=ax, format="%+2.0f dB")
        ax.set_title(title)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # Each call creates a new figure; close it so reruns don't accumulate
        # open figures in Matplotlib's figure manager.
        plt.close(fig)
# Process Audio


def _to_float32_array(audio):
    """Return ``audio`` as a contiguous float32 NumPy array.

    Accepts a NumPy array or a torch tensor (detached and moved to the CPU
    first). The result is valid input for both librosa and ``torch.from_numpy``.
    """
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()
    return np.ascontiguousarray(audio, dtype=np.float32)


if uploaded_file and st.button("Run SemantiCodec"):
    # getvalue() returns the whole upload regardless of the stream position,
    # unlike read(), which yields b"" once the buffer has been consumed.
    original_bytes = uploaded_file.getvalue()

    with tempfile.TemporaryDirectory() as temp_dir:
        # SemantiCodec.encode() is given a file path, so persist the upload first.
        input_path = os.path.join(temp_dir, "input.wav")
        with open(input_path, "wb") as f:
            f.write(original_bytes)

        # Load audio and bring it to the sample rate used for display.
        waveform, sample_rate = torchaudio.load(input_path)
        if sample_rate != DEFAULT_SAMPLE_RATE:
            st.write(f"Resampling audio from {sample_rate} Hz to {DEFAULT_SAMPLE_RATE} Hz...")
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=DEFAULT_SAMPLE_RATE)
            waveform = resampler(waveform)
            sample_rate = DEFAULT_SAMPLE_RATE  # Update sample rate to 16kHz

        # Only the first channel is plotted; any additional channels are ignored.
        waveform = waveform[0].numpy()

        # Plot Original Spectrogram (16kHz resampled)
        st.write("Original Audio Spectrogram (Resampled to 16kHz):")
        plot_spectrogram(waveform, sample_rate, "Original Audio Spectrogram (Resampled to 16kHz)")

        # Initialize SemantiCodec with the sidebar settings and move it to `device`.
        st.write("Initializing SemantiCodec...")
        semanticodec = SemantiCodec(
            token_rate=token_rate,
            semantic_vocab_size=semantic_vocab_size,
            ddim_sample_step=ddim_steps,
            cfg_scale=guidance_scale,
        )
        semanticodec.device = device
        semanticodec.encoder = semanticodec.encoder.to(device)
        semanticodec.decoder = semanticodec.decoder.to(device)

        # Encode and Decode. Note that encode() reads the saved file itself,
        # not the resampled `waveform` above.
        st.write("Encoding and Decoding Audio...")
        tokens = semanticodec.encode(input_path)
        # [0, 0] selects a single 1-D signal from decode()'s output. Normalise it to
        # float32 so librosa and torchaudio both get a well-defined dtype.
        reconstructed_waveform = _to_float32_array(semanticodec.decode(tokens)[0, 0])

        # Save reconstructed audio. torchaudio expects (channels, frames), hence
        # unsqueeze(0). The output is written at `sample_rate`, which is
        # DEFAULT_SAMPLE_RATE at this point.
        reconstructed_path = os.path.join(temp_dir, "reconstructed.wav")
        torchaudio.save(
            reconstructed_path,
            torch.from_numpy(reconstructed_waveform).unsqueeze(0),
            sample_rate,
        )
        # Read the file back while the temp dir still exists; everything below uses
        # in-memory bytes, so nothing depends on the temp dir's lifetime.
        with open(reconstructed_path, "rb") as f:
            reconstructed_bytes = f.read()

    # Plot Reconstructed Spectrogram
    st.write("Reconstructed Audio Spectrogram:")
    plot_spectrogram(reconstructed_waveform, sample_rate, "Reconstructed Audio Spectrogram")

    # Display latent code shape
    st.write(f"Shape of Latent Code: {tokens.shape}")

    # Audio Players (each caption precedes the player it labels)
    st.write("Original Audio")
    st.audio(original_bytes, format="audio/wav")
    st.write("Reconstructed Audio")
    st.audio(reconstructed_bytes, format="audio/wav")

    # Download Button for Reconstructed Audio
    st.download_button(
        "Download Reconstructed Audio",
        data=reconstructed_bytes,
        file_name="reconstructed_audio.wav",
    )
# Footer: a Markdown attribution line, rendered each time the script runs.
_footer_markdown = "Built with [Streamlit](https://streamlit.io) and SemantiCodec"
st.write(_footer_markdown)
|