File size: 4,278 Bytes
d248698
6c780ba
d248698
 
 
 
2319c67
 
 
 
d248698
 
 
 
2319c67
d248698
 
 
 
 
2319c67
d248698
 
 
 
 
 
 
 
 
 
 
a22a294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d248698
 
a22a294
 
 
d248698
 
 
 
 
 
a22a294
d248698
 
 
 
2319c67
d248698
 
 
a22a294
d248698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a22a294
 
 
d248698
 
 
 
 
a22a294
 
 
 
d248698
 
 
a22a294
d248698
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import functools
import os
import shutil

import gradio as gr
import huggingface_hub
import numpy as np
import pandas as pd
import torch
from audiocraft.data.audio import audio_write
import audiocraft.models


# Fetch the two MusiConGen checkpoint files (compression model weights and
# main model weights) from the Hugging Face Hub into ./ckpt/musicongen.
for _ckpt_file in ('compression_state_dict.bin', 'state_dict.bin'):
    huggingface_hub.hf_hub_download(
        repo_id='Cyan0731/MusiConGen',
        filename=_ckpt_file,
        local_dir='./ckpt/musicongen',
    )

def print_directory_contents(path):
    """Print an indented tree of every directory and file under *path*.

    Each directory is printed as ``name/`` and the files it contains one
    level deeper, using four spaces per nesting level. Debugging aid only.

    Args:
        path: root directory to walk (a trailing separator is accepted).
    """
    for root, _dirs, files in os.walk(path):
        # Depth relative to *path*. relpath strips only the leading prefix and
        # tolerates a trailing separator on *path*; str.replace did neither.
        rel = os.path.relpath(root, path)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath drops a trailing separator so basename is never empty.
        print(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print(f"{subindent}{f}")

def check_outputs_folder(folder_path):
    """Empty *folder_path* in place, keeping the folder itself.

    Files and symlinks are unlinked; subdirectories are removed recursively
    with ``shutil.rmtree``. A failure on one entry is reported and skipped so
    the remaining entries are still cleaned up.

    Args:
        folder_path: directory to empty. If it does not exist (or is not a
            directory) a message is printed and nothing is changed.
    """
    if not os.path.isdir(folder_path):
        print(f'The folder {folder_path} does not exist.')
        return

    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        try:
            # Check islink before isdir so a symlink to a directory is
            # unlinked rather than having its target's contents deleted.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            # Best-effort cleanup: report and keep going.
            print(f'Failed to delete {file_path}. Reason: {e}')

def check_for_wav_in_outputs(outputs_folder='./example_1'):
    """Return the path of a .wav file in *outputs_folder*, or None.

    Args:
        outputs_folder: directory to search; defaults to './example_1'.

    Returns:
        ``os.path.join(outputs_folder, name)`` for the alphabetically first
        regular file whose name ends in '.wav' (sorted so the choice is
        deterministic), or None when the folder is missing, is not a
        directory, or holds no such file.
    """
    if not os.path.isdir(outputs_folder):
        return None

    wav_files = sorted(
        name
        for name in os.listdir(outputs_folder)
        if name.endswith('.wav')
        and os.path.isfile(os.path.join(outputs_folder, name))
    )
    return os.path.join(outputs_folder, wav_files[0]) if wav_files else None

# Text prompt used when the textbox is left empty.
DEFAULT_DESCRIPTION = (
    "A laid-back blues shuffle with a relaxed tempo, warm guitar tones, and a "
    "comfortable groove, perfect for a slow dance or a night in. Instruments: "
    "electric guitar, bass, drums."
)


@functools.lru_cache(maxsize=1)
def _load_musicgen(ckpt_dir='./ckpt/musicongen'):
    """Load the MusiConGen checkpoint once and reuse it for every request."""
    return audiocraft.models.MusicGen.get_pretrained(ckpt_dir)


def infer(text):
    """Generate a 30 s clip conditioned on text, chords, tempo and meter.

    Args:
        text: music description from the UI textbox. When it is empty or
            whitespace-only, ``DEFAULT_DESCRIPTION`` is used instead.

    Returns:
        Path to the generated .wav file inside ./example_1, or None if no
        .wav file was found there after generation.
    """
    # Clear, write and search one folder. Before, output went to
    # ./output_samples/example_1 while ./example_1 was searched, so the
    # generated file was never returned.
    output_dir = './example_1'
    check_outputs_folder(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    duration = 30
    num_samples = 1
    bs = 1

    musicgen = _load_musicgen()
    musicgen.set_generation_params(duration=duration, extend_stride=duration // 2, top_k=250)

    description = text.strip() if text and text.strip() else DEFAULT_DESCRIPTION
    chords = ['C G A:min F'] * num_samples
    descriptions = [description] * num_samples
    bpms = [120] * num_samples
    meters = [4] * num_samples

    wav = []
    for i in range(num_samples // bs):
        print(f"starting {i} batch...")
        batch = slice(i * bs, (i + 1) * bs)
        temp = musicgen.generate_with_chords_and_beats(
            descriptions[batch],
            chords[batch],
            bpms[batch],
            meters[batch],
        )
        wav.extend(temp.cpu())

    for idx, one_wav in enumerate(wav):
        # Fixed file stem: the description is user input and must never become
        # part of a filesystem path ('/', '..', over-long names).
        sav_path = os.path.join(output_dir, f"sample_{idx}")
        audio_write(sav_path, one_wav.cpu(), musicgen.sample_rate, strategy='loudness', loudness_compressor=True)

    print_directory_contents(output_dir)
    wav_file_path = check_for_wav_in_outputs()
    print(wav_file_path)
    return wav_file_path

# Build the UI: a textbox and Submit button on the left, the generated audio
# on the right. Clicking Submit runs infer(text) and shows the returned file.
with gr.Blocks() as demo:
    with gr.Column():
        # Markdown ATX headings need a space after '#'; "#MusiConGen" would
        # render as literal text instead of a title.
        gr.Markdown("# MusiConGen")
        with gr.Row():
            with gr.Column():
                text_in = gr.Textbox()
                submit_btn = gr.Button("Submit")
            wav_out = gr.Audio(label="Wav Result")
    submit_btn.click(
        fn=infer,
        inputs=[text_in],
        outputs=[wav_out],
    )
demo.launch()