"""Streamlit front end for AudioSR audio super-resolution.

The page lets a user upload a WAV file, truncates it to the first few
seconds, runs the AudioSR model on it, and shows spectrograms and audio
players for both the truncated input and the enhanced output.
"""

import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import streamlit as st
import torch
import torchaudio
from audiosr import build_model, super_resolution, save_wave

# Only this many seconds from the start of the upload are processed.
MAX_INPUT_SECONDS = 10
# Sample rate passed to save_wave and used to scale the enhanced spectrogram.
OUTPUT_SAMPLE_RATE = 48000
# STFT analysis parameters used for the spectrogram plots.
STFT_N_FFT = 2048
STFT_HOP_LENGTH = 512


def _select_device():
    """Return the best available torch device name: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps does not exist on older PyTorch builds, so probe for it.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


# Set device (CUDA if available, MPS on Apple silicon, otherwise CPU)
device = _select_device()

# Title and Description
st.title("AudioSR: Versatile Audio Super-Resolution")
st.write("""
Upload your low-resolution audio files, and AudioSR will enhance them to high fidelity!
Supports all types of audio (music, speech, sound effects, etc.) with arbitrary sampling rates.
Only the first 10 seconds of the audio will be processed.
""")

# Upload audio file
uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=["wav"])

# Sidebar: Model Parameters
st.sidebar.title("Model Parameters")
model_name = st.sidebar.selectbox("Select Model", ["basic", "speech"], index=0)
ddim_steps = st.sidebar.slider("DDIM Steps", min_value=10, max_value=100, value=50)
guidance_scale = st.sidebar.slider("Guidance Scale", min_value=1.0, max_value=10.0, value=3.5)
random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step=1)
latent_t_per_second = 12.8


# Helper function: Plot linear STFT spectrogram
def plot_spectrogram(waveform, sample_rate, title):
    """Draw a linear-frequency STFT magnitude spectrogram on the Streamlit page.

    Args:
        waveform: Audio samples as a NumPy array or torch tensor. Singleton
            dimensions are squeezed away so a mono signal becomes 1-D.
        sample_rate: Sampling rate in Hz, used to scale both plot axes.
        title: Title shown above the plot.
    """
    # as_tensor accepts arrays and tensors alike without the copy-construct
    # warning that torch.tensor() emits when handed an existing tensor.
    signal = torch.as_tensor(waveform, dtype=torch.float32).detach().cpu()
    if signal.dim() > 1:
        signal = signal.squeeze()  # Remove extra dimensions

    # An explicit Hann window avoids the spectral leakage (and the PyTorch
    # warning) of the implicit rectangular window used when none is given.
    magnitude = torch.stft(
        signal,
        n_fft=STFT_N_FFT,
        hop_length=STFT_HOP_LENGTH,
        win_length=STFT_N_FFT,
        window=torch.hann_window(STFT_N_FFT),
        return_complex=True,
    ).abs().numpy()

    # Use an explicit figure instead of pyplot's global state so it can be
    # handed to Streamlit and then released.
    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(
        np.log1p(magnitude),
        aspect="auto",
        origin="lower",
        extent=[0, signal.shape[-1] / sample_rate, 0, sample_rate / 2],
        cmap="viridis",
    )
    # The plotted values are log(1 + magnitude), not decibels, so label them as such.
    fig.colorbar(image, ax=ax, label="log(1 + magnitude)")
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")
    fig.tight_layout()
    st.pyplot(fig)
    # Streamlit has rendered the figure; close it so figures do not pile up
    # across script reruns.
    plt.close(fig)


# Process Button
if uploaded_file and st.button("Enhance Audio"):
    st.write("Processing audio...")

    with tempfile.TemporaryDirectory() as temp_dir:
        input_path = os.path.join(temp_dir, "input.wav")
        truncated_path = os.path.join(temp_dir, "truncated.wav")
        output_path = os.path.join(temp_dir, "output.wav")

        # Save uploaded file locally
        with open(input_path, "wb") as f:
            f.write(uploaded_file.read())

        # Load and truncate the first 10 seconds
        waveform, sample_rate = torchaudio.load(input_path)
        max_samples = sample_rate * MAX_INPUT_SECONDS
        if waveform.size(1) > max_samples:
            waveform = waveform[:, :max_samples]
            st.write("Truncated audio to the first 10 seconds.")
        # Only the first channel is written out and processed.
        sf.write(truncated_path, waveform[0].numpy(), sample_rate)

        # Plot truncated spectrogram
        st.write("Truncated Input Audio Spectrogram (First 10 seconds):")
        plot_spectrogram(waveform[0].numpy(), sample_rate, title="Truncated Input Audio Spectrogram")

        # Build and load the model
        audiosr = build_model(model_name=model_name, device=device)

        # Perform super-resolution
        waveform_sr = super_resolution(
            audiosr,
            truncated_path,
            seed=random_seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            latent_t_per_second=latent_t_per_second,
        )

        # Save enhanced audio; with name="output" the result is expected at
        # output_path inside temp_dir.
        output_waveform = waveform_sr
        save_wave(
            torch.as_tensor(output_waveform),
            inputpath=truncated_path,
            savepath=temp_dir,
            name="output",
            samplerate=OUTPUT_SAMPLE_RATE,
        )

        # Plot enhanced spectrogram
        st.write("Enhanced Audio Spectrogram:")
        plot_spectrogram(output_waveform, OUTPUT_SAMPLE_RATE, title="Enhanced Audio Spectrogram")

        # Display audio players and download link. Each caption precedes its
        # player, and everything is read while the temporary files still exist.
        st.write("Truncated Original Audio (First 10 seconds):")
        st.audio(truncated_path, format="audio/wav")
        st.write("Enhanced Audio:")
        st.audio(output_path, format="audio/wav")
        with open(output_path, "rb") as enhanced_file:
            enhanced_bytes = enhanced_file.read()
        st.download_button("Download Enhanced Audio", data=enhanced_bytes, file_name="enhanced_audio.wav")

# Footer
st.write("Built with [Streamlit](https://streamlit.io) and [AudioSR](https://audioldm.github.io/audiosr)")