Spaces:
Running
Running
# Import necessary libraries | |
from transformers import MusicgenForConditionalGeneration | |
from transformers import AutoProcessor | |
import scipy | |
import streamlit as st | |
# Set a random seed for reproducibility | |
import torch | |
# Seed torch's global RNG so sampled generations are reproducible, then give
# the browser tab a custom title and icon.
RANDOM_SEED = 2912
torch.manual_seed(RANDOM_SEED)

PAGE_CONFIG = {"page_title": "Plant Orchestra", "page_icon": "🎵"}
st.set_page_config(**PAGE_CONFIG)
# Hugging Face Hub checkpoint that provides both the processor and the model.
MUSICGEN_CHECKPOINT = "facebook/musicgen-small"


@st.cache_resource
def _load_musicgen(checkpoint):
    """Load (once) and return the MusicGen processor and model for *checkpoint*.

    Decorated with ``st.cache_resource`` so the weights are loaded a single
    time and reused on every Streamlit rerun instead of being re-read each
    time the user interacts with a widget. Errors are deliberately NOT caught
    here: a call that raises is not stored in the cache, so a transient
    failure (e.g. a download hiccup) is retried on the next rerun rather than
    being remembered forever.
    """
    processor = AutoProcessor.from_pretrained(checkpoint)
    model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
    return processor, model


# Initialize the model for generating music
def initialise_model(checkpoint=MUSICGEN_CHECKPOINT):
    """
    Initialize the model for generating music using Hugging Face Transformers.

    Loads a pre-trained processor and model from ``checkpoint`` (by default the
    "facebook/musicgen-small" checkpoint). Loading goes through a Streamlit
    resource cache, so repeated calls -- including on every script rerun --
    reuse the already-loaded objects instead of loading them again.

    Args:
        checkpoint (str): Hub id or local path handed to ``from_pretrained``.
            Defaults to "facebook/musicgen-small".

    Returns:
        Tuple[Optional[AutoProcessor], Optional[MusicgenForConditionalGeneration]]:
        A tuple containing the processor and model if initialization is
        successful. If an error occurs during initialization, an error message
        is shown in the app and ``(None, None)`` is returned.

    Example:
        processor, model = initialise_model()
        if processor is not None and model is not None:
            # Model is successfully initialized, and you can use it for music generation.
            pass
        else:
            # Handle initialization error.
            pass
    """
    try:
        return _load_musicgen(checkpoint)
    except Exception as e:
        # Surface the failure in the UI; callers test for (None, None).
        st.error(f"Error initializing the model: {str(e)}")
        return None, None


# Call the 'initialise_model' function to set up the processor and model.
processor, model = initialise_model()
# Generate audio with given prompt
def generate_audio(processor, model, prompt, *, guidance_scale=3, max_new_tokens=256):
    """
    Generate audio based on a given prompt using a pre-trained model.

    The prompt is tokenized with ``processor`` and passed to
    ``model.generate`` with sampling enabled.

    Args:
        processor (AutoProcessor): A pre-trained processor for text-to-sequence tasks.
        model (MusicgenForConditionalGeneration): A pre-trained model for generating music.
        prompt (str): The user-provided text prompt to guide music generation.
        guidance_scale (float): Forwarded to ``model.generate``. Defaults to 3.
        max_new_tokens (int): Upper bound on generated tokens, which limits
            the length of the generated audio. Defaults to 256.

    Returns:
        Union[None, torch.Tensor]: The tensor returned by ``model.generate``
        on success. ``None`` if ``processor`` or ``model`` is None, or if
        generation raised (in which case an error is shown in the app).

    Example:
        processor, model = initialise_model()
        if processor is not None and model is not None:
            audio_data = generate_audio(processor, model, "Sunflower, temperature: 32.5 C, UV light intensity: 50%, Soil water level: 3cm/h")
            if audio_data is not None:
                # Use the generated audio for further processing or display.
                pass
            else:
                # Handle audio generation error.
                pass
        else:
            # Handle model initialization error.
            pass
    """
    if processor is None or model is None:
        # Nothing to generate with; make the None result explicit.
        return None

    try:
        # Tokenize the prompt and convert it to PyTorch tensors.
        inputs = processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        )
        return model.generate(
            **inputs.to("cpu"),  # Ensure computation on the CPU
            do_sample=True,  # Sampling gives varied ("creative") output
            guidance_scale=guidance_scale,
            # NOTE(review): 256 tokens is roughly 5 s of audio at MusicGen's
            # frame rate per the Transformers docs -- confirm for other checkpoints.
            max_new_tokens=max_new_tokens,
        )
    except Exception as e:
        # Report the failure in the UI instead of crashing the app.
        st.error(f"Error generating audio: {str(e)}")
        return None
# Save audio file with scipy
def save_file(model, audio_values, filename):
    """
    Save the first generated waveform as a WAV file using SciPy.

    Args:
        model: The pre-trained model used for audio generation; only
            ``model.config.audio_encoder.sampling_rate`` is read.
        audio_values (torch.Tensor): The generated audio. Only
            ``audio_values[0, 0]`` is written.
        filename (str): Destination path of the WAV file.

    Returns:
        None

    Example:
        save_file(model, audio_data, "generated_audio.wav")
    """
    # Import the submodule explicitly: a bare ``import scipy`` only exposes
    # ``scipy.io.wavfile`` on SciPy versions that lazily load submodules.
    from scipy.io import wavfile

    # Get the sampling rate from the model's configuration.
    sampling_rate = model.config.audio_encoder.sampling_rate
    # NOTE(review): assumes audio_values is (batch, channels, samples) as
    # returned by MusicGen's generate, so [0, 0] is batch item 0, channel 0.
    waveform = audio_values[0, 0].cpu().numpy()
    wavfile.write(filename, rate=sampling_rate, data=waveform)
# Main Code: heading, prompt box and the button that triggers generation.
st.title("Plant Orchestra 🌿")
st.markdown("Generate music based on your own terrarium plants.")
prompt = st.text_input(label='User input:', value='baby tears')

generate_clicked = st.button("Generate Music")
if generate_clicked and processor is not None and model is not None:
    with st.spinner("Generating audio..."):
        audio_values = generate_audio(processor, model, prompt)
    if audio_values is not None:
        st.write("Listen to the generated music:")
        # Play back the first waveform at the model's own sampling rate.
        st.audio(
            data=audio_values[0, 0].cpu().numpy(),
            sample_rate=model.config.audio_encoder.sampling_rate,
            format="audio/wav",
        )
# Sidebar: usage instructions followed by pre-rendered sample tracks.
HOW_TO_STEPS = (
    "1. Enter a plant and condition (optional) in the text input. E.g. moss, 30C",
    "2. Click the 'Generate Music' button to create music based on the provided input.",
    "3. You can listen to the generated music and download it.",
)

# (caption, path to bundled WAV file) for each sample shown in the sidebar.
SAMPLE_TRACKS = (
    ('Holland moss', 'sound/holland_moss.wav'),
    ('Nerve plant', 'sound/nerve_plant.wav'),
    ('Artillery plant', 'sound/artillery_plant.wav'),
    ('Malayan moss', 'sound/malayan_moss.wav'),
    ('Pilea', 'sound/pilea.wav'),
    ('Hydrocotyle tripartita', 'sound/hydrocotyle_tripartita.wav'),
    ('Oak leaf fig', 'sound/oak_leaf_fig.wav'),
)

st.sidebar.header("How to Use:")
for step in HOW_TO_STEPS:
    st.sidebar.write(step)
# NOTE(review): write() with no arguments probably renders nothing; confirm
# whether a visual spacer was intended here.
st.sidebar.write()

st.sidebar.header('Samples 🎵')
for caption, audio_path in SAMPLE_TRACKS:
    st.sidebar.write(caption)
    st.sidebar.audio(audio_path)
# Footer: an empty '##' heading (presumably used as a spacer), a Markdown
# horizontal rule, then the credits line.
FOOTER_LINES = ('##', "---", "Created with ❤️ by HS2912 W4 Group 2")
for footer_line in FOOTER_LINES:
    st.markdown(footer_line)