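"""Streamlit app that turns a text prompt (genre, mood, instrument, tempo, BPM,
and melody hints) into short music clips using Meta's MusicGen models from the
`audiocraft` library.

Run locally with: streamlit run app.py
"""
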
import streamlit as st
import torch
import os
import base64
import torchaudio
import numpy as np
from audiocraft.models import MusicGen
genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
          "Lofi", "Chillpop", "Country", "R&B", "Folk", "EDM", "Disco", "House", "Techno"]
@st.cache_resource()
def load_model(model_name):
    model = MusicGen.get_pretrained(model_name)
    return model
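# Runs every loaded model on the same text prompt and returns a dict of
# model name -> generated waveform tensor of shape [batch, channels, samples].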
def generate_music_tensors(description, duration: int, batch_size=1, models=None):
    outputs = {}
    for model_name, model in models.items():
        model.set_generation_params(
            use_sampling=True,
            top_k=250,
            duration=duration
        )
        with st.spinner(f"Generating Music with {model_name}..."):
            # MusicGen.generate expects a list of prompts; repeat the prompt
            # batch_size times to get batch_size clips from this model.
            # return_tokens is left at its default (False) so `output` is just
            # the audio tensor.
            output = model.generate(
                descriptions=[description] * batch_size,
                progress=True
            )
        outputs[model_name] = output
    st.success("Music Generation Complete!")
    return outputs
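# Writes each clip in the batch to audio_output/<filename>_<idx>.wav and
# returns the path of the last file written.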
def save_audio(samples: torch.Tensor, filename):
    # MusicGen generates audio at 32 kHz.
    sample_rate = 32000
    save_path = "audio_output"
    os.makedirs(save_path, exist_ok=True)
    assert samples.dim() == 2 or samples.dim() == 3
    samples = samples.detach().cpu()
    if samples.dim() == 2:
        samples = samples[None, ...]
    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"{filename}_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
    return audio_path
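# Embeds the WAV file as a base64 data URI inside an HTML download link;
# Streamlit renders it via st.markdown with unsafe_allow_html=True.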
def get_binary_file_downloader_html(bin_file, file_label='File'):
    with open(bin_file, 'rb') as f:
        data = f.read()
    bin_str = base64.b64encode(data).decode()
    href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{os.path.basename(bin_file)}">Download {file_label}</a>'
    return href
st.set_page_config(
    page_icon="🎵",
    page_title="Music Gen"
)
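# UI: collect the prompt details, load the models, and on button press
# generate, play, and offer a download link for each model's clip.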
def main():
    st.title("🎧 AI Composer 🎧")
    st.subheader("Generate Music")
    st.write("Craft your perfect melody! Fill in the blanks below to create your music masterpiece:")
    bpm = st.number_input("Enter Speed in BPM", min_value=60)
    text_area = st.text_area('Example: 80s rock song with guitar and drums')
    selected_genre = st.selectbox("Select Genre", genres)
    time_slider = st.slider("Select time duration (In Seconds)", 1, 30, 10)
    mood = st.selectbox("Select Mood", ["Happy", "Sad", "Angry", "Relaxed", "Energetic"])
    instrument = st.selectbox("Select Instrument", ["Piano", "Guitar", "Flute", "Violin", "Drums"])
    tempo = st.selectbox("Select Tempo", ["Slow", "Moderate", "Fast"])
    # Use a placeholder (not a default value) so the hint text is never
    # concatenated into the prompt.
    melody = st.text_input("Enter Melody or Chord Progression",
                           placeholder="e.g., C D:min G:7 C, Twinkle Twinkle Little Star")
    # Each entry loads (and caches) a full MusicGen checkpoint, so trim this
    # dict on memory-constrained machines.
    models = {
        'Medium': load_model('facebook/musicgen-medium'),
        'Large': load_model('facebook/musicgen-large'),
        'Melody': load_model('facebook/musicgen-melody'),
        'Small': load_model('facebook/musicgen-small'),
        # Add more models here as needed
    }
    if st.button('Let\'s Generate 🎢'):
        st.text('\n\n')
        st.subheader("Generated Music")
        # Combine every user choice into a single text prompt.
        description = f"{text_area} {selected_genre} {bpm} BPM {mood} {instrument} {tempo} {melody}"
        # One clip per model; only that clip is displayed below.
        music_outputs = generate_music_tensors(description, time_slider, batch_size=1, models=models)
        for model_name, output in music_outputs.items():
            idx = 0  # first (and only) generated clip for this model
            audio_filepath = save_audio(output, f'audio_{model_name}_{idx}')
            with open(audio_filepath, 'rb') as audio_file:
                audio_bytes = audio_file.read()
            st.write(f'{model_name} Model')
            st.audio(audio_bytes, format='audio/wav')
            st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{model_name}_{idx}'), unsafe_allow_html=True)

if __name__ == "__main__":
    main()