# AIMusicComposer / app.py
import streamlit as st
import torch
import torchaudio
import os
import numpy as np
import base64
from audiocraft.models import MusicGen
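# Requires the audiocraft package (e.g. `pip install audiocraft`) for MusicGen, plus torch and torchaudio.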
# Release any cached GPU memory before loading the model (a no-op when CUDA is unavailable).
torch.cuda.empty_cache()
genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", "Lofi", "Chillpop"]
@st.cache_resource()
def load_model():
    model = MusicGen.get_pretrained('facebook/musicgen-small')
    return model
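# The small checkpoint keeps memory needs modest; the same call should accept other MusicGen
# checkpoints (e.g. 'facebook/musicgen-medium'), assuming the host has enough GPU memory,
# and st.cache_resource ensures the model is only loaded once per session.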
def generate_music_tensors(description, duration: int):
    model = load_model()
    model.set_generation_params(
        use_sampling=True,
        top_k=250,
        duration=duration
    )
    with st.spinner("Generating Music..."):
        output = model.generate(
            descriptions=[description],  # generate() expects a list of prompt strings
            progress=True,
            return_tokens=True
        )
    st.success("Music Generation Complete!")
    return output
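# With return_tokens=True the call above returns a (wav, tokens) tuple; the wav tensor is
# shaped [batch, channels, samples] at the model's native rate (32 kHz for musicgen-small),
# which is why save_audio() below takes samples[0] before iterating over the batch.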
def save_audio(samples: torch.Tensor):
    print("Samples (inside function): ", samples)
    sample_rate = 32000  # musicgen-small generates audio at 32 kHz
    save_path = "audio_output/"
    os.makedirs(save_path, exist_ok=True)
    sample = samples[0]  # first element of the (wav, tokens) tuple
    assert sample.dim() == 2 or sample.dim() == 3
    sample = sample.detach().cpu()
    if sample.dim() == 2:
        sample = sample[None, ...]
    for idx, audio in enumerate(sample):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
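# A less hardcoded variant (a sketch, not wired in here) would pass the rate in from the
# caller, e.g. save_audio(music_tensors, load_model().sample_rate), since MusicGen exposes
# its native rate as the sample_rate property.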
def get_binary_file_downloader_html(bin_file, file_label='File'):
    with open(bin_file, 'rb') as f:
        data = f.read()
    bin_str = base64.b64encode(data).decode()
    href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{os.path.basename(bin_file)}">Download {file_label}</a>'
    return href
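# Recent Streamlit releases also offer st.download_button, which would avoid the hand-rolled
# base64 link, e.g. st.download_button("Download Audio", audio_bytes, file_name="audio_0.wav",
# mime="audio/wav"); the markdown link above is kept as written.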
st.set_page_config(
    page_icon="🎵",  # page_icon expects an emoji or ":musical_note:" shortcode, not a bare name
    page_title="Music Gen"
)
def main():
    st.title("🎧 AI Composer Small-Model 🎧")
    st.subheader("Craft your perfect melody!")
    bpm = st.number_input("Enter Speed in BPM", min_value=2)
    text_area = st.text_area('Ex : 80s rock song with guitar and drums')
    st.text('')
    # Dropdown for genres
    selected_genre = st.selectbox("Select Genre", genres)
    st.subheader("Select time duration (In Seconds)")
    time_slider = st.slider("Select time duration (In Seconds)", 0, 30, 10)
    mood = st.selectbox("Select Mood (Optional)", ["Happy", "Sad", "Angry", "Relaxed", "Energetic"], index=None)
    instrument = st.selectbox("Select Instrument (Optional)", ["Piano", "Guitar", "Flute", "Violin", "Drums"], index=None)
    tempo = st.selectbox("Select Tempo (Optional)", ["Slow", "Moderate", "Fast"], index=None)
    melody = st.text_input("Enter Melody or Chord Progression (Optional) e.g: C D:min G:7 C, Twinkle Twinkle Little Star", "")
    if st.button("Let's Generate 🎢"):
        st.text('\n\n')
        st.subheader("Generated Music")
        # Build the prompt from the free-form description plus any selected options
        description = text_area
        if selected_genre:
            description += f" {selected_genre}"
        if bpm:
            description += f" {bpm} BPM"
        if mood:
            description += f" {mood}"
        if instrument:
            description += f" {instrument}"
        if tempo:
            description += f" {tempo}"
        if melody.strip():
            description += f" {melody.strip()}"
        # Clear CUDA memory cache before generating music
        torch.cuda.empty_cache()
        music_tensors = generate_music_tensors(description, time_slider)
        # Only the first generated clip (audio_0.wav) is played and offered for download
        save_audio(music_tensors)
        audio_filepath = 'audio_output/audio_0.wav'
        with open(audio_filepath, 'rb') as audio_file:
            audio_bytes = audio_file.read()
        # Play the full audio
        st.audio(audio_bytes, format='audio/wav')
        st.markdown(get_binary_file_downloader_html(audio_filepath, 'Audio'), unsafe_allow_html=True)
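# Launch locally with: streamlit run app.py (assumes the dependencies above are installed)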
if __name__ == "__main__":
    main()