"""Streamlit app: upload a WAV recording and download its piano transcription as MIDI."""

import base64
import html
from io import BytesIO

import streamlit as st
from piano_transcription_inference import (
    PianoTranscription,
    load_audio_from_memory,
    sample_rate,
)

st.title('Transcribe piano')

audiofile = st.file_uploader(
    'Upload audio file', type=['.wav'], accept_multiple_files=False
)

# Progress bar widget; created only once a file has been uploaded, and
# updated by print_progress() while the transcription is running.
my_bar = None


def print_progress(current, total):
    """Update the transcription progress bar.

    Passed as the progress callback to ``transcriptor.transcribe``.

    Args:
        current: Position of the segment being processed.
        total: Upper bound for ``current``.
            NOTE(review): the "+ 1" in the label suggests both values are
            zero-based indices -- confirm against the transcriber.
    """
    if my_bar is None:
        # No progress bar has been created yet; nothing to update.
        return

    # Avoid ZeroDivisionError when total is 0, and clamp the ratio because
    # st.progress only accepts floats within [0.0, 1.0].
    fraction = current / total if total > 0 else 1.0
    fraction = min(max(fraction, 0.0), 1.0)

    my_bar.progress(
        fraction,
        text=f'Transcribing ({current + 1} / {total + 1} segments)...',
    )


if audiofile is not None:
    audio_bytes = audiofile.read()
    st.text('Uploaded file')
    st.audio(audio_bytes, format='audio/wav')

    # Decode the upload straight from memory at the model's sample rate;
    # the second element of the returned tuple is not needed.
    with st.spinner('Resampling...'):
        (audio, _) = load_audio_from_memory(audio_bytes, sr=sample_rate, mono=True)
    st.success('Resampling complete.')

    my_bar = st.progress(0, text='Transcribing...')

    # device: 'cuda' | 'cpu'
    transcriptor = PianoTranscription(
        device='cpu',
        checkpoint_path='CRNN_note_F1=0.9677_pedal_F1=0.9186.pth',
    )

    # The buffer is handed to transcribe() and read back below as the MIDI
    # payload, so nothing is written to disk here.
    buf = BytesIO()
    transcriptor.transcribe(audio, None, print_progress, buf)

    filename = f'transcribed_{audiofile.name}.mid'
    b64 = base64.b64encode(buf.getvalue()).decode()

    # Offer the MIDI as an inline data-URI link. The filename comes from the
    # user's upload, so escape it before placing it in an HTML attribute.
    # NOTE(review): presumably a plain link is preferred over
    # st.download_button so that clicking it does not rerun the script (and
    # the transcription) -- confirm.
    safe_filename = html.escape(filename, quote=True)
    st.markdown(
        f'<a href="data:audio/midi;base64,{b64}" '
        f'download="{safe_filename}">Download MIDI</a>',
        unsafe_allow_html=True,
    )
    st.balloons()