musicgen-small / app.py
echons's picture
Updated layout and added code comments
3312835
raw
history blame
6.69 kB
# Import necessary libraries
from transformers import MusicgenForConditionalGeneration
from transformers import AutoProcessor
import scipy
import streamlit as st
# PyTorch, used below to set a random seed for reproducibility
import torch
# Seed PyTorch's global RNG. Generation samples (do_sample=True), so a fixed
# seed is intended to make results repeatable across script reruns.
_RANDOM_SEED = 2912
torch.manual_seed(_RANDOM_SEED)

# Browser-tab title and favicon for the Streamlit page. Kept ahead of every
# other Streamlit element call, as several Streamlit versions require.
st.set_page_config(page_title="Plant Orchestra", page_icon="🎵")
# Pre-trained MusicGen checkpoint used for both the processor and the model.
MODEL_CHECKPOINT = "facebook/musicgen-small"


@st.cache_resource
def _load_pretrained(checkpoint):
    """Load the processor and model for ``checkpoint`` and cache them.

    Exceptions are deliberately NOT caught here: ``st.cache_resource`` only
    stores values that are returned, so a failed load (for example a network
    error while downloading weights) is retried on the next rerun instead of
    a ``(None, None)`` failure being cached for the life of the server.
    """
    processor = AutoProcessor.from_pretrained(checkpoint)
    model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
    return processor, model


def initialise_model(checkpoint=MODEL_CHECKPOINT):
    """
    Initialize the model for generating music using Hugging Face Transformers.

    Loads a pre-trained processor and model from ``checkpoint`` (by default
    "facebook/musicgen-small"). Successful loads are cached by Streamlit, so
    repeated script reruns reuse the same objects. Failed loads are not cached
    and are retried on the next rerun.

    Args:
        checkpoint (str): Hugging Face model id or local path to load.

    Returns:
        Tuple[Optional[AutoProcessor], Optional[MusicgenForConditionalGeneration]]:
        The processor and model if initialization succeeds. If an error occurs,
        an error message is shown in the Streamlit UI and ``(None, None)`` is
        returned.

    Example:
        processor, model = initialise_model()
        if processor is not None and model is not None:
            # Model is successfully initialized, and you can use it for music generation.
            pass
        else:
            # Handle initialization error.
            pass
    """
    try:
        return _load_pretrained(checkpoint)
    except Exception as e:
        # Shown on every rerun while loading keeps failing, since failures are not cached.
        st.error(f"Error initializing the model: {str(e)}")
        return None, None


# Set up the processor and model (served from the Streamlit cache after the first success).
processor, model = initialise_model()
# Generate audio with given prompt
def generate_audio(processor, model, prompt, *, guidance_scale=3, max_new_tokens=256):
    """
    Generate audio based on a given prompt using a pre-trained model.

    Args:
        processor (AutoProcessor): Pre-trained processor that tokenizes the prompt.
        model (MusicgenForConditionalGeneration): Pre-trained model that generates music.
        prompt (str): User-provided text prompt to guide music generation.
        guidance_scale (float): Guidance scale forwarded to ``model.generate``.
            Defaults to 3, the value previously hard-coded.
        max_new_tokens (int): Maximum number of tokens to generate, which limits
            the length of the generated audio. Defaults to 256.

    Returns:
        Union[None, torch.Tensor]: The audio returned by ``model.generate`` on
        success. Returns None if ``processor`` or ``model`` is None, or if
        generation raised; in the latter case an error is shown in the UI.

    Example:
        processor, model = initialise_model()
        if processor is not None and model is not None:
            audio_data = generate_audio(processor, model, "Sunflower, temperature: 32.5 C, UV light intensity: 50%, Soil water level: 3cm/h")
            if audio_data is not None:
                # Use the generated audio for further processing or display.
                pass
    """
    if processor is None or model is None:
        # Nothing to generate with (e.g. model initialisation failed).
        return None
    try:
        # Tokenize the prompt and convert it to PyTorch tensors.
        inputs = processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        )
        audio_values = model.generate(
            **inputs.to("cpu"),  # Ensure computation on the CPU
            do_sample=True,  # Sample rather than decode greedily, for varied output
            guidance_scale=guidance_scale,
            max_new_tokens=max_new_tokens,  # Bounds the length of generated audio
        )
    except Exception as e:
        # Surface the failure in the UI instead of crashing the script run.
        st.error(f"Error generating audio: {str(e)}")
        return None
    return audio_values
# Save audio file with scipy
def save_file(model, audio_values, filename):
    """
    Save generated audio data as a WAV file using SciPy.

    Args:
        model: The pre-trained model used for generation. Only
            ``model.config.audio_encoder.sampling_rate`` is read, to set the
            WAV sample rate.
        audio_values (torch.Tensor): Generated audio. Element ``[0, 0]`` is
            written; this assumes the (batch, channels, samples) layout that
            MusicGen's ``generate`` returns. TODO confirm for other models.
        filename (str | file-like): Destination path or open binary file handle.

    Returns:
        None

    Example:
        save_file(model, audio_data, "generated_audio.wav")
    """
    # Import the submodule explicitly. Whether a bare ``import scipy`` exposes
    # ``scipy.io.wavfile`` as an attribute depends on the SciPy version.
    from scipy.io import wavfile

    sampling_rate = model.config.audio_encoder.sampling_rate
    # wavfile.write needs a NumPy array, so move the first clip to CPU and convert it.
    samples = audio_values[0, 0].cpu().numpy()
    wavfile.write(filename, rate=sampling_rate, data=samples)
# Main Code: title, prompt input, and the generate button.
st.title("Plant Orchestra 🌿")
st.markdown("Generate music based on your own terrarium plants.")
prompt = st.text_input(label='User input:', value='baby tears')
if st.button("Generate Music"):
    if processor is None or model is None:
        # Previously a click did nothing at all when the model failed to load.
        st.error("The music model is not available, so no audio can be generated.")
    elif not prompt.strip():
        # Avoid running a slow generation for a blank prompt.
        st.warning("Please enter a plant (and optional conditions) before generating.")
    else:
        with st.spinner("Generating audio..."):
            results = generate_audio(processor, model, prompt)
        if results is not None:
            sampling_rate = model.config.audio_encoder.sampling_rate
            st.write("Listen to the generated music:")
            # Play the first generated clip, as save_file would write it.
            st.audio(sample_rate=sampling_rate, data=results[0, 0].cpu().numpy(), format="audio/wav")
# Sidebar: usage instructions followed by pre-generated sample clips.
_HOW_TO_STEPS = (
    "1. Enter a plant and condition (optional) in the text input. E.g. moss, 30C",
    "2. Click the 'Generate Music' button to create music based on the provided input.",
    "3. You can listen to the generated music and download it.",
)

# (caption, path to bundled WAV file), in display order.
_SAMPLE_CLIPS = (
    ('Holland moss', 'sound/holland_moss.wav'),
    ('Nerve plant', 'sound/nerve_plant.wav'),
    ('Artillery plant', 'sound/artillery_plant.wav'),
    ('Malayan moss', 'sound/malayan_moss.wav'),
    ('Pilea', 'sound/pilea.wav'),
    ('Hydrocotyle tripartita', 'sound/hydrocotyle_tripartita.wav'),
    ('Oak leaf fig', 'sound/oak_leaf_fig.wav'),
)

st.sidebar.header("How to Use:")
for _step in _HOW_TO_STEPS:
    st.sidebar.write(_step)
st.sidebar.write()  # argument-less call kept as-is; presumably meant as a spacer

st.sidebar.header('Samples 🎵')
for _caption, _clip_path in _SAMPLE_CLIPS:
    st.sidebar.write(_caption)
    st.sidebar.audio(_clip_path)

# Footer: an empty level-2 heading (presumably for spacing), a rule, then credits.
for _footer_md in ('##', "---", "Created with ❤️ by HS2912 W4 Group 2"):
    st.markdown(_footer_md)