File size: 3,901 Bytes
b356ee6
18b406b
 
 
3d2ca89
 
2dcfa57
b356ee6
d0d0a63
2dcfa57
18b406b
 
3d2ca89
 
18b406b
 
3d2ca89
 
 
 
 
 
 
 
 
 
 
 
2dcfa57
 
 
3d2ca89
18b406b
 
3d2ca89
18b406b
d0d0a63
18b406b
3d2ca89
18b406b
 
 
 
 
 
 
d0d0a63
18b406b
3d2ca89
18b406b
 
 
 
 
 
 
 
 
 
 
 
 
 
3d2ca89
d0d0a63
 
 
 
 
 
 
 
2dcfa57
d0d0a63
 
 
 
18b406b
3d2ca89
 
 
ae84c6e
 
 
3d2ca89
 
d0d0a63
 
 
 
 
3d2ca89
d0d0a63
3d2ca89
 
 
 
 
d0d0a63
3d2ca89
 
18b406b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import base64
import html
import os

import numpy as np
import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen

# Genre choices offered in the "Select Genre" dropdown. The selected entry is
# inserted verbatim into the text prompt, so spelling matters ("R&B", not the
# previous "R&G" typo).
genres = [
    "Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
    "Lofi", "Chillpop", "Country", "R&B", "Folk", "EDM", "Disco", "House", "Techno",
]

@st.cache_resource()
def load_model(model_name):
    """Return the pretrained MusicGen checkpoint identified by ``model_name``.

    The ``st.cache_resource`` decorator lets Streamlit keep the loaded model
    per ``model_name`` rather than rebuilding it on every script rerun.
    """
    return MusicGen.get_pretrained(model_name)

def generate_music_tensors(description, duration: int, batch_size=1, models=None):
    """Generate audio for a text prompt with every model in ``models``.

    Args:
        description: Text prompt describing the music to generate.
        duration: Length of each generated clip, in seconds.
        batch_size: Number of clips to generate per model; the same prompt is
            repeated ``batch_size`` times.
        models: Mapping of display name -> loaded MusicGen model.

    Returns:
        dict mapping each model name to the audio tensor returned by
        ``model.generate`` (one clip per batch entry).

    Raises:
        ValueError: if ``models`` is empty/None or ``batch_size`` < 1.
    """
    if not models:
        raise ValueError("generate_music_tensors requires at least one model")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # MusicGen.generate takes a *list* of prompts, one per batch entry; a bare
    # string would be iterated character by character.
    descriptions = [description] * batch_size

    outputs = {}
    for model_name, model in models.items():
        model.set_generation_params(
            use_sampling=True,
            top_k=250,
            duration=duration,
        )

        with st.spinner(f"Generating Music with {model_name}..."):
            # return_tokens=False: return only the audio tensor (what
            # save_audio writes), not an (audio, tokens) tuple.
            outputs[model_name] = model.generate(
                descriptions=descriptions,
                progress=True,
                return_tokens=False,
            )

    st.success("Music Generation Complete!")
    return outputs

def save_audio(samples: torch.Tensor, filename, sample_rate: int = 32000, save_path: str = "audio_output"):
    """Write each clip in ``samples`` to ``<save_path>/<filename>_<idx>.wav``.

    Args:
        samples: Audio tensor, either ``(channels, time)`` for a single clip or
            ``(batch, channels, time)`` for several clips.
        filename: Base file name; ``_<idx>.wav`` is appended per clip.
        sample_rate: Rate written to the WAV header. Defaults to 32000 Hz, the
            output rate of the MusicGen checkpoints (``model.sample_rate``);
            a wrong value makes playback too fast/slow and shifts the pitch.
        save_path: Output directory; created if it does not already exist.

    Returns:
        Path of the last file written.

    Raises:
        ValueError: if ``samples`` is not 2-D/3-D or contains no clips.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if samples.dim() not in (2, 3):
        raise ValueError(
            "expected a 2-D (channels, time) or 3-D (batch, channels, time) "
            f"tensor, got {samples.dim()}-D"
        )

    samples = samples.detach().cpu()
    if samples.dim() == 2:
        samples = samples[None, ...]  # promote a single clip to a batch of one
    if samples.shape[0] == 0:
        raise ValueError("samples contains no audio clips")

    # torchaudio.save does not create missing directories.
    os.makedirs(save_path, exist_ok=True)

    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"{filename}_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
    return audio_path

def get_binary_file_downloader_html(bin_file, file_label='File'):
    """Build an HTML ``<a>`` link that downloads ``bin_file`` from a data URI.

    The file's bytes are base64-embedded in the ``href``, so the link works
    without any server route.

    Args:
        bin_file: Path of the file to embed.
        file_label: Text shown after "Download " in the link.

    Returns:
        The ``<a>`` element as a string.
    """
    with open(bin_file, 'rb') as f:
        bin_str = base64.b64encode(f.read()).decode()
    # The result is rendered as raw HTML, so escape the file name and label:
    # a quote, '&' or '<' would otherwise break the attribute or inject markup.
    download_name = html.escape(os.path.basename(bin_file), quote=True)
    label = html.escape(file_label)
    return (
        f'<a href="data:application/octet-stream;base64,{bin_str}" '
        f'download="{download_name}">Download {label}</a>'
    )

# Page metadata. Emoji icons must be given as a literal emoji or as a
# colon-delimited shortcode (":musical_note:"); a bare name such as
# "musical_note" is not one of the documented page_icon forms.
st.set_page_config(
    page_icon=":musical_note:",
    page_title="Music Gen",
)

def main():
    """Render the composer UI and, on click, generate and present audio.

    Builds a text prompt from the widgets, runs every model in ``models`` on
    it, then for each model saves the audio, shows a player and a download
    link.
    """
    st.title("🎧 AI Composer 🎧")

    st.subheader("Generate Music")
    st.write("Craft your perfect melody! Fill in the blanks below to create your music masterpiece:")

    bpm = st.number_input("Enter Speed in BPM", min_value=60)
    text_area = st.text_area('Example: 80s rock song with guitar and drums')
    selected_genre = st.selectbox("Select Genre", genres)
    # A 0-second clip is meaningless, so the slider starts at 1.
    time_slider = st.slider("Select time duration (In Seconds)", 1, 30, 10)

    mood = st.selectbox("Select Mood", ["Happy", "Sad", "Angry", "Relaxed", "Energetic"])
    instrument = st.selectbox("Select Instrument", ["Piano", "Guitar", "Flute", "Violin", "Drums"])
    tempo = st.selectbox("Select Tempo", ["Slow", "Moderate", "Fast"])
    # The example is a placeholder hint, not the default value; as a value it
    # would be appended to every prompt unless the user cleared it.
    melody = st.text_input(
        "Enter Melody or Chord Progression",
        placeholder="e.g., C D:min G:7 C, Twinkle Twinkle Little Star",
    )

    # Keys must be unique: a repeated key silently replaces the earlier entry
    # (after that model has already been loaded).
    models = {
        'Medium': load_model('facebook/musicgen-medium'),
        'Large': load_model('facebook/musicgen-large'),
        'Melody': load_model('facebook/musicgen-melody'),
        'Small': load_model('facebook/musicgen-small'),
        # Add more models here as needed
    }

    if st.button('Let\'s Generate 🎶'):
        st.text('\n\n')
        st.subheader("Generated Music")

        # Drop empty free-text fields so the prompt has no stray blanks.
        prompt_parts = [text_area, selected_genre, f"{bpm} BPM", mood, instrument, tempo, melody]
        description = " ".join(part.strip() for part in prompt_parts if part and part.strip())
        music_outputs = generate_music_tensors(description, time_slider, batch_size=2, models=models)

        for model_name, output in music_outputs.items():
            # generate(..., return_tokens=True) yields (audio, tokens); only the
            # audio tensor can be written to disk.
            if isinstance(output, tuple):
                output = output[0]
            idx = 0  # tags the file name; save_audio appends its own per-clip index
            # save_audio returns the path of the last clip it wrote.
            audio_filepath = save_audio(output, f'audio_{model_name}_{idx}')
            with open(audio_filepath, 'rb') as audio_file:
                audio_bytes = audio_file.read()

            # st.audio has no label parameter, so caption the player separately.
            st.caption(f'{model_name} Model')
            st.audio(audio_bytes, format='audio/wav')
            st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{model_name}_{idx}'), unsafe_allow_html=True)


if __name__ == "__main__":
    main()