# NOTE: the original file carried a Hugging Face Spaces page-header scrape
# artifact here ("Spaces: / Sleeping / Sleeping"); kept as a comment so the
# file parses as Python.
import argparse
import ast
import json
import os
import re

import gradio as gr
from audiocraft.models import AudioGen, MAGNeT, MusicGen

# from gradio_components.image import generate_caption, improve_prompt
from gradio_components.image import generate_caption_gpt4
from gradio_components.model_cards import (
    MELODY_CONDITIONED_MODELS,
    MELODY_CONTINUATION_MODELS,
    MODEL_CARDS,
    TEXT_TO_MIDI_MODELS,
    TEXT_TO_MUSIC_MODELS,
    TEXT_TO_SOUND_MODELS,
)
from gradio_components.prediction import predict, transcribe
# App-wide Gradio theme: Glass base with fuchsia/indigo accents and a
# Source Sans Pro font stack.  The `.set(...)` overrides flatten the look
# (no shadows, 0-width panel borders) and tune dark-mode fills.
# All values are Gradio theme tokens; `*name` references another token.
theme = gr.themes.Glass(
    primary_hue="fuchsia",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=[
        gr.themes.GoogleFont("Source Sans Pro"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
).set(
    body_background_fill_dark="*background_fill_primary",
    embed_radius="*table_radius",
    background_fill_primary="*neutral_50",
    background_fill_primary_dark="*neutral_950",
    background_fill_secondary_dark="*neutral_900",
    border_color_accent="*neutral_600",
    border_color_accent_subdued="*color_accent",
    border_color_primary_dark="*neutral_700",
    block_background_fill="*background_fill_primary",
    block_background_fill_dark="*neutral_800",
    block_border_width="1px",
    block_label_background_fill="*background_fill_primary",
    block_label_background_fill_dark="*background_fill_secondary",
    block_label_text_color="*neutral_500",
    block_label_text_size="*text_sm",
    block_label_text_weight="400",
    block_shadow="none",
    block_shadow_dark="none",
    block_title_text_color="*neutral_500",
    block_title_text_weight="400",
    panel_border_width="0",
    panel_border_width_dark="0",
    checkbox_background_color_dark="*neutral_800",
    checkbox_border_width="*input_border_width",
    checkbox_label_border_width="*input_border_width",
    input_background_fill="*neutral_100",
    input_background_fill_dark="*neutral_700",
    input_border_color_focus_dark="*neutral_700",
    input_border_width="0px",
    input_border_width_dark="0px",
    slider_color="#2563eb",
    slider_color_dark="#2563eb",
    table_even_background_fill_dark="*neutral_950",
    table_odd_background_fill_dark="*neutral_900",
    button_border_width="*input_border_width",
    button_shadow_active="none",
    button_primary_background_fill="*primary_200",
    button_primary_background_fill_dark="*primary_700",
    button_primary_background_fill_hover="*button_primary_background_fill",
    button_primary_background_fill_hover_dark="*button_primary_background_fill",
    button_secondary_background_fill="*neutral_200",
    button_secondary_background_fill_dark="*neutral_600",
    button_secondary_background_fill_hover="*button_secondary_background_fill",
    button_secondary_background_fill_hover_dark="*button_secondary_background_fill",
    button_cancel_background_fill="*button_secondary_background_fill",
    button_cancel_background_fill_dark="*button_secondary_background_fill",
    button_cancel_background_fill_hover="*button_cancel_background_fill",
    button_cancel_background_fill_hover_dark="*button_cancel_background_fill",
)
def generate_prompt(prompt, style):
    """Append the selected genre tags to the free-text prompt.

    Parameters
    ----------
    prompt : str
        Free-text song description.
    style : list[str] | None
        Genre labels from the CheckboxGroup; the widget's default value is
        None, and nothing may be selected.

    Returns
    -------
    str
        ``prompt`` followed by each genre, comma-separated.
    """
    # Guard against style=None / empty: the original `[prompt] + style`
    # raised TypeError when the CheckboxGroup was untouched (value=None).
    return ",".join([prompt] + list(style or []))
def UI(share=False):
    """Build and launch the three-tab Gradio demo.

    Tabs: text -> music, melody -> music (continuation or conditioning),
    and image -> music (via GPT-4 image captioning).

    Parameters
    ----------
    share : bool
        Forwarded to ``launch``; when True, Gradio creates a public link.
    """
    # FIX: the module-level `theme` was built but never applied to Blocks.
    with gr.Blocks(theme=theme) as demo:
        # ---------------- Tab 1: text -> music ----------------
        with gr.Tab("Generate Music by text"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        model_path = gr.Dropdown(
                            choices=TEXT_TO_MUSIC_MODELS,
                            label="Select the model",
                            value="facebook/musicgen-large",
                        )
                    with gr.Row():
                        text_prompt = gr.Textbox(
                            label="Let's make a song about ...",
                            value="First day learning music generation in Stanford university",
                            interactive=True,
                            visible=True,
                        )
                        num_outputs = gr.Number(
                            label="Number of outputs",
                            value=1,
                            minimum=1,
                            maximum=10,
                            interactive=True,
                        )
                    with gr.Row():
                        style = gr.CheckboxGroup(
                            [
                                "Jazz",
                                "Classical Music",
                                "Hip Hop",
                                "Ragga Jungle",
                                "Dark Jazz",
                                "Soul",
                                "Blues",
                                "80s Rock N Roll",
                            ],
                            value=None,
                            label="music genre",
                            interactive=True,
                        )

                    def update_prompt(style):
                        # NOTE(review): not wired to any event; also reads
                        # text_prompt.value (the construction-time default),
                        # not the live textbox content — confirm intent.
                        return generate_prompt(text_prompt.value, style)

                    config_output_textbox = gr.Textbox(label="Model Configs", visible=False)

                    def show_config_options(model_path):
                        # NOTE(review): never called, so the accordion below is
                        # never rendered; kept verbatim pending a decision on
                        # how generation configs should be collected.
                        print(model_path)
                        with gr.Accordion("Model Generation Configs"):
                            if "magnet" in model_path:
                                # MAGNeT exposes span arrangement instead of duration.
                                with gr.Row():
                                    top_k = gr.Number(label="Top-k", value=300, interactive=True)
                                    top_p = gr.Number(label="Top-p", value=0, interactive=True)
                                    temperature = gr.Number(
                                        label="Temperature", value=1.0, interactive=True
                                    )
                                    span_arrangement = gr.Radio(
                                        ["nonoverlap", "stride1"],
                                        value="nonoverlap",
                                        label="span arrangment",
                                        info=" Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1') ",
                                    )

                                    def return_model_configs(top_k, top_p, temperature, span_arrangement):
                                        return {
                                            "top_k": top_k,
                                            "top_p": top_p,
                                            "temperature": temperature,
                                            "span_arrangement": span_arrangement,
                                        }
                            else:
                                with gr.Row():
                                    duration = gr.Slider(
                                        minimum=10,
                                        maximum=30,
                                        value=30,
                                        label="Duration",
                                        interactive=True,
                                    )
                                    use_sampling = gr.Checkbox(label="Use Sampling", interactive=True, value=True)
                                    top_k = gr.Number(label="Top-k", value=300, interactive=True)
                                    top_p = gr.Number(label="Top-p", value=0, interactive=True)
                                    temperature = gr.Number(
                                        label="Temperature", value=1.0, interactive=True
                                    )

                                    def return_model_configs(duration, use_sampling, top_k, top_p, temperature):
                                        return {
                                            "duration": duration,
                                            "use_sampling": use_sampling,
                                            "top_k": top_k,
                                            "top_p": top_p,
                                            "temperature": temperature,
                                        }

                with gr.Column():
                    with gr.Row():
                        melody = gr.Audio(
                            sources=["upload"],
                            type="numpy",
                            label="File",
                            interactive=True,
                            elem_id="melody-input",
                            visible=False,
                        )
                        submit = gr.Button("Generate Music")
                        result_text = gr.Textbox(label="Generated Music (text)", type="text", interactive=False)
                        print(result_text)
                    # Collected Audio players for generated clips (shared by
                    # the helper below across all tabs).
                    output_audios = []

                    def show_output_audio(tmp_paths):
                        # Build one Audio player per generated file path.
                        # `tmp_paths` arrives as the string repr of a list,
                        # hence the literal_eval.  Originally this function was
                        # duplicated verbatim in every tab; deduplicated here.
                        # NOTE(review): not wired to any event.
                        if tmp_paths:
                            tmp_paths = ast.literal_eval(tmp_paths)
                            print(tmp_paths)
                            for i, tmp_path in enumerate(tmp_paths):
                                _audio = gr.Audio(
                                    value=tmp_path,
                                    label=f"Generated Music {i}",
                                    type="filepath",
                                    interactive=False,
                                    visible=True,
                                )
                                output_audios.append(_audio)

                    submit.click(
                        fn=predict,
                        inputs=[model_path, config_output_textbox, text_prompt, melody, num_outputs],
                        outputs=result_text,
                        queue=True,
                    )

        # ---------------- Tab 2: melody -> music ----------------
        with gr.Tab("Generate Music by melody"):
            with gr.Column():
                with gr.Row():
                    # FIX: "Muisc" typo corrected consistently (choice,
                    # comparison and example below all use the same string).
                    radio_melody_condition = gr.Radio(
                        ["Music Continuation", "Music Conditioning"],
                        value=None,
                        label="Select the condition",
                    )
                    model_path2 = gr.Dropdown(label="model")

                    def model_selection(radio_melody_condition):
                        # Repopulate the model dropdown for the chosen
                        # conditioning mode; hidden when nothing is selected.
                        if radio_melody_condition == "Music Continuation":
                            return gr.Dropdown(
                                choices=MELODY_CONTINUATION_MODELS,
                                label="Select the model",
                                value="facebook/musicgen-large",
                                interactive=True,
                                visible=True,
                            )
                        if radio_melody_condition == "Music Conditioning":
                            return gr.Dropdown(
                                choices=MELODY_CONDITIONED_MODELS,
                                label="Select the model",
                                value="facebook/musicgen-melody-large",
                                interactive=True,
                                visible=True,
                            )
                        return gr.Dropdown(
                            choices=TEXT_TO_SOUND_MODELS,
                            label="Select the model",
                            value="facebook/musicgen-large",
                            interactive=True,
                            visible=False,
                        )

                    # FIX: the handler above was defined but never attached,
                    # so the dropdown stayed permanently empty.
                    radio_melody_condition.change(
                        fn=model_selection,
                        inputs=radio_melody_condition,
                        outputs=model_path2,
                    )

                upload_melody = gr.Audio(sources=["upload", "microphone"], type="filepath", label="File")
                prompt_text2 = gr.Textbox(
                    label="Let's make a song about ...",
                    value=None,
                    interactive=True,
                    visible=True,
                )
                with gr.Row():
                    config_output_textbox2 = gr.Textbox(
                        label="Model Configs",
                        visible=True)
                with gr.Row():
                    duration2 = gr.Number(10, label="Duration", interactive=True)
                    num_outputs2 = gr.Number(1, label="Number of outputs", interactive=True)

                def return_model_configs2(duration):
                    # NOTE(review): not wired to any event; fixed sampling
                    # defaults with a user-set duration.
                    return {"duration": duration, "use_sampling": True, "top_k": 300, "top_p": 0, "temperature": 1}

                submit2 = gr.Button("Generate Music")
                result_text2 = gr.Textbox(label="Generated Music (melody)", type="text", interactive=False, visible=True)
                submit2.click(
                    fn=predict,
                    inputs=[model_path2, config_output_textbox2, prompt_text2, upload_melody, num_outputs2],
                    outputs=result_text2,
                    queue=True,
                )
                gr.Examples(
                    examples=[
                        [
                            os.path.join(
                                os.path.dirname(__file__), "./data/audio/Suri's Improv.mp3"
                            ),
                            30,
                            "facebook/musicgen-large",
                            "Music Continuation",
                        ],
                        [
                            os.path.join(
                                os.path.dirname(__file__), "./data/audio/lie_no_tomorrow_20sec.wav"
                            ),
                            40,
                            "facebook/musicgen-melody-large",
                            "Music Conditioning",
                        ],
                    ],
                    inputs=[upload_melody, duration2, model_path2, radio_melody_condition],
                )

        # ---------------- Tab 3: image -> music ----------------
        with gr.Tab("Generate Music by image"):
            with gr.Column():
                with gr.Row():
                    image_input = gr.Image("Upload an image", type="filepath")
                    with gr.Accordion("Image Captioning", open=False):
                        image_description = gr.Textbox(label='image description', visible=True, interactive=False)
                        image_caption = gr.Textbox(label='generated text prompt', visible=True, interactive=True)

                    def generate_image_text_prompt(image_input):
                        # NOTE(review): not wired to any event, and it passes
                        # `model_path` (the tab-1 Dropdown *component*, not its
                        # value) to the captioner — confirm intended signature.
                        if image_input:
                            image_description, image_caption = generate_caption_gpt4(image_input, model_path)
                            # meesage_object, description, prompt = generate_caption_claude3(image_input, model_path)
                            return image_description, image_caption
                        return "", ""

                with gr.Row():
                    melody3 = gr.Audio(sources=["upload", "microphone"], type="filepath", label="File", visible=True)
                    with gr.Column():
                        model_path3 = gr.Dropdown(
                            choices=TEXT_TO_SOUND_MODELS + TEXT_TO_MUSIC_MODELS + MELODY_CONDITIONED_MODELS,
                            label="Select the model",
                            value="facebook/musicgen-large",
                        )
                        duration3 = gr.Number(30, visible=False, label="Duration")
                submit3 = gr.Button("Generate Music")
                result_text3 = gr.Textbox(label="Generated Music (image)", type="text", interactive=False, visible=True)

                def predict_image_music(model_path3, image_caption, duration3, melody3):
                    # Fixed sampling configs; only the duration varies.
                    model_configs = {"duration": duration3, "use_sampling": True, "top_k": 250, "top_p": 0, "temperature": 1}
                    return predict(
                        model_version=model_path3,
                        generation_configs=model_configs,
                        prompt_text=image_caption,
                        prompt_wav=melody3,
                    )

                submit3.click(
                    fn=predict_image_music,
                    inputs=[model_path3, image_caption, duration3, melody3],
                    outputs=result_text3,
                    queue=True,
                )

                def show_transcribt_audio(tmp_paths):
                    # NOTE(review): not wired to any event; transcribe()'s
                    # return value is discarded.
                    transcribe(tmp_paths)

                gr.Examples(
                    examples=[
                        [
                            os.path.join(os.path.dirname(__file__), "./data/image/beach.jpeg"),
                            "facebook/musicgen-large",
                            30,
                            None,
                        ],
                        [
                            os.path.join(os.path.dirname(__file__), "./data/image/beach.jpeg"),
                            "facebook/audiogen-medium",
                            15,
                            None,
                        ],
                        [
                            os.path.join(os.path.dirname(__file__), "./data/image/beach.jpeg"),
                            "facebook/musicgen-melody-large",
                            30,
                            os.path.join(os.path.dirname(__file__), "./data/audio/Suri's Improv.mp3"),
                        ],
                        [
                            os.path.join(os.path.dirname(__file__), "./data/image/cat.jpeg"),
                            "facebook/musicgen-large",
                            30,
                            None,
                        ],
                    ],
                    inputs=[image_input, model_path3, duration3, melody3],
                )
    demo.queue().launch(share=share)
if __name__ == "__main__":
    # Command-line entry point: `python app.py --share` exposes a public
    # Gradio share link; otherwise the app serves locally only.
    parser = argparse.ArgumentParser()
    parser.add_argument('--share', action='store_true', help='Enable sharing.')
    args = parser.parse_args()
    UI(share=args.share)