File size: 4,467 Bytes
1f34ab8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
import torchaudio
import torch
import librosa
import librosa.display
import matplotlib.pyplot as plt
from semanticodec import SemantiCodec
import numpy as np
import tempfile
import os

# Set default parameters
# NOTE(review): DEFAULT_TOKEN_RATE and DEFAULT_SEMANTIC_VOCAB_SIZE are not
# referenced below — the sidebar selectboxes hard-code the same defaults via
# index=2. Kept for documentation value; consider wiring them to the widgets.
DEFAULT_TOKEN_RATE = 100
DEFAULT_SEMANTIC_VOCAB_SIZE = 16384
DEFAULT_SAMPLE_RATE = 16000  # SemantiCodec operates on 16 kHz audio
# Prefer GPU when available; the codec's encoder/decoder are moved here later.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Title and Description
st.title("SemantiCodec: Ultra-Low Bitrate Neural Audio Codec")
st.write("""
Upload your audio file, adjust the codec parameters, and compare the original and reconstructed audio. 
SemantiCodec achieves high-quality audio reconstruction with ultra-low bitrates!
""")

# Sidebar: Parameters
st.sidebar.title("Codec Parameters")
# index=2 selects 100 tokens/sec as the default choice.
token_rate = st.sidebar.selectbox("Token Rate (tokens/sec)", [25, 50, 100], index=2)
# index=2 selects 16384 as the default vocabulary size.
semantic_vocab_size = st.sidebar.selectbox(
    "Semantic Vocabulary Size",
    [4096, 8192, 16384, 32768],
    index=2,
)
# Diffusion-decoder knobs: more DDIM steps → slower but higher quality;
# cfg_scale controls classifier-free-guidance strength at decode time.
ddim_steps = st.sidebar.slider("DDIM Sampling Steps", 10, 100, 50, step=5)
guidance_scale = st.sidebar.slider("CFG Guidance Scale", 0.5, 5.0, 2.0, step=0.1)

# Upload Audio File
uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=["wav"])

# Helper function: Plot spectrogram
def plot_spectrogram(waveform, sample_rate, title):
    """Render a mel spectrogram of ``waveform`` into the Streamlit app.

    Args:
        waveform: 1-D numpy array of mono audio samples.
        sample_rate: Sampling rate of ``waveform`` in Hz; also bounds the
            spectrogram's frequency axis (fmax = Nyquist).
        title: Title drawn above the plot.
    """
    # Use an explicit Figure instead of the implicit pyplot state:
    # st.pyplot(plt) (passing the module) is deprecated in Streamlit, and
    # never closing figures leaks memory across reruns — matplotlib keeps
    # every open figure alive and warns after 20.
    fig = plt.figure(figsize=(10, 4))
    S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=128, fmax=sample_rate // 2)
    S_dB = librosa.power_to_db(S, ref=np.max)  # convert power to dB, 0 dB at the peak
    librosa.display.specshow(S_dB, sr=sample_rate, x_axis='time', y_axis='mel', cmap='viridis')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # free the figure so reruns don't accumulate open figures

# Process Audio: runs only after a file is uploaded AND the button is clicked.
if uploaded_file and st.button("Run SemantiCodec"):
    # TemporaryDirectory cleans up both the input and reconstructed WAVs
    # when the block exits, even on error.
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save uploaded file to disk — SemantiCodec.encode takes a file path.
        input_path = os.path.join(temp_dir, "input.wav")
        with open(input_path, "wb") as f:
            f.write(uploaded_file.read())

        # Load audio (shape: [channels, samples])
        waveform, sample_rate = torchaudio.load(input_path)
        
        # Check if resampling is needed — the codec expects 16 kHz input.
        if sample_rate != DEFAULT_SAMPLE_RATE:
            st.write(f"Resampling audio from {sample_rate} Hz to {DEFAULT_SAMPLE_RATE} Hz...")
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=DEFAULT_SAMPLE_RATE)
            waveform = resampler(waveform)
            sample_rate = DEFAULT_SAMPLE_RATE  # Update sample rate to 16kHz
        
        # Convert first channel to numpy for librosa compatibility
        # (multi-channel input is reduced to channel 0).
        waveform = waveform[0].numpy()

        # Plot Original Spectrogram (16kHz resampled)
        st.write("Original Audio Spectrogram (Resampled to 16kHz):")
        plot_spectrogram(waveform, sample_rate, "Original Audio Spectrogram (Resampled to 16kHz)")

        # Initialize SemantiCodec with the sidebar-selected parameters.
        # NOTE(review): the model is re-instantiated (weights re-loaded) on
        # every click — consider st.cache_resource keyed on the parameters.
        st.write("Initializing SemantiCodec...")
        semanticodec = SemantiCodec(
            token_rate=token_rate,
            semantic_vocab_size=semantic_vocab_size,
            ddim_sample_step=ddim_steps,
            cfg_scale=guidance_scale,
        )
        semanticodec.device = device
        semanticodec.encoder = semanticodec.encoder.to(device)
        semanticodec.decoder = semanticodec.decoder.to(device)

        # Encode and Decode. encode() reads the file path directly; decode()
        # presumably returns a [batch, channel, samples] array — [0, 0]
        # extracts the mono waveform (TODO confirm against SemantiCodec docs).
        st.write("Encoding and Decoding Audio...")
        tokens = semanticodec.encode(input_path)
        reconstructed_waveform = semanticodec.decode(tokens)[0, 0]

        # Save reconstructed audio; wrap in a [1, samples] tensor because
        # torchaudio.save expects a channels-first 2-D tensor.
        reconstructed_path = os.path.join(temp_dir, "reconstructed.wav")
        torchaudio.save(reconstructed_path, torch.tensor([reconstructed_waveform]), sample_rate)

        # Plot Reconstructed Spectrogram
        st.write("Reconstructed Audio Spectrogram:")
        plot_spectrogram(reconstructed_waveform, sample_rate, "Reconstructed Audio Spectrogram")

        # Display latent code shape
        st.write(f"Shape of Latent Code: {tokens.shape}")

        # Audio Players
        st.audio(input_path, format="audio/wav")
        st.write("Original Audio")
        st.audio(reconstructed_path, format="audio/wav")
        st.write("Reconstructed Audio")
        
        # Download Button for Reconstructed Audio.
        # Read via a context manager — the original open(...).read() leaked
        # the file handle (never closed, only reclaimed by GC).
        with open(reconstructed_path, "rb") as audio_file:
            reconstructed_bytes = audio_file.read()
        st.download_button(
            "Download Reconstructed Audio",
            data=reconstructed_bytes,
            file_name="reconstructed_audio.wav",
            mime="audio/wav",
        )


# Footer: rendered unconditionally at the bottom of the page on every rerun.
st.write("Built with [Streamlit](https://streamlit.io) and SemantiCodec")