# Copyright (2023) Tsinghua University, Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gradio as gr
import spaces
import argparse
from model_zero import SALMONN
class ff:
    """Lightweight stand-in for SALMONN that lets the UI run without model weights.

    Swap it in via the commented-out ``model = ff()`` line to exercise the
    Gradio front end without loading any checkpoints.
    """

    def generate(self, wav_path, prompt, prompt_pattern=None, num_beams=4,
                 temperature=1.0, top_p=0.9):
        """Print the request parameters and return a fixed canned reply.

        Args:
            wav_path: Path of the uploaded audio; only printed.
            prompt: The user's question; only printed.
            prompt_pattern: Accepted for signature compatibility and ignored.
                It defaults to None because ``gradio_answer`` calls
                ``generate`` without it; previously it was required, so using
                this stub from the UI raised ``TypeError``.
            num_beams: Beam count; only printed. Default matches the UI slider.
            temperature: Temperature; only printed. Default matches the UI slider.
            top_p: Top-p value; only printed. Default matches the UI slider.

        Returns:
            A constant apology string.

        NOTE(review): ``gradio_answer`` returns ``llm_message[0]``. With this
        stub that is only the first character of the string, so the stub does
        not fully mimic the real model's return shape — confirm before relying
        on it for UI testing.
        """
        print(f'wav_path: {wav_path}, prompt: {prompt}, temperature: {temperature}, num_beams: {num_beams}, top_p: {top_p}')
        return "I'm sorry, but I cannot answer that question as it is not clear what you are asking. Can you please provide more context or clarify your question?"
# Command-line configuration: target device and local checkpoint/model paths.
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--ckpt_path", type=str, default="./salmonn_7b_v0.pth")
parser.add_argument("--whisper_path", type=str, default="./whisper_large_v2")
parser.add_argument("--beats_path", type=str, default="./beats/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt")
parser.add_argument("--vicuna_path", type=str, default="./vicuna-7b-v1.5")
parser.add_argument("--low_resource", action='store_true', default=False)
# type=int so a value given on the command line is an int, like the default
# (previously `--port 8000` produced the string "8000").
parser.add_argument("--port", type=int, default=9527)
args = parser.parse_args()
# for huggingface A10 7b demo.
# NOTE: this unconditionally overrides --low_resource, so the flag has no effect.
args.low_resource = True
# model = ff()
# Build the model on CPU first, then move it to the requested device.
model = SALMONN(
    ckpt=args.ckpt_path,
    whisper_path=args.whisper_path,
    beats_path=args.beats_path,
    vicuna_path=args.vicuna_path,
    low_resource=args.low_resource,
    lora_alpha=28,
    device='cpu'
)
model.to(args.device)
# Inference only: put the model in evaluation mode.
model.eval()
@spaces.GPU(enable_queue=True)
def gradio_answer(speech, text_input, num_beams, temperature, top_p):
    """Run the model on one audio/question pair and return the answer text.

    Used as the handler for both the Submit button and pressing Enter in the
    question box. Arguments arrive in the order of those events' input lists.

    Args:
        speech: File path of the uploaded audio (the Audio component uses
            ``type='filepath'``), or None when nothing has been uploaded.
        text_input: The user's question, passed to the model as the prompt.
        num_beams: Value of the "beam search numbers" slider.
        temperature: Value of the (non-interactive) "temperature" slider.
        top_p: Value of the "top p" slider.

    Returns:
        The first element of ``model.generate``'s result.

    Raises:
        gr.Error: If no audio was provided. The UI then shows a clear message
            instead of an opaque failure from inside the model.
    """
    if speech is None:
        raise gr.Error("Please upload your audio first.")
    llm_message = model.generate(
        wav_path=speech,
        prompt=text_input,
        # Cast defensively in case the slider delivers a float (e.g. 4.0);
        # a beam count must be an integer.
        num_beams=int(num_beams),
        temperature=temperature,
        top_p=top_p,
    )
    print(llm_message)
    return llm_message[0]
# Heading text, rendered with gr.HTML at the top of the page.
title = """
SALMONN: Speech Audio Language Music Open Neural Network
"""
# Currently unused: its only reference is a commented-out gr.Markdown call in the UI.
image_src = """
"""
# Introductory text, rendered with gr.HTML below the title.
description = """This is a simplified gradio demo for SALMONN-7B.
To experience SALMONN-13B, you can go to https://bytedance.github.io/SALMONN.
Upload your audio and ask a question!
"""
# Page CSS passed to gr.Blocks. Targets the column created with
# elem_id="col-container": centers it and caps its width at 840px.
css = """
div#col-container {
margin: 0 auto;
max-width: 840px;
}
"""
# Build the web UI. Components appear in the order they are created inside
# each `with` layout context, so statement order here defines the page layout.
with gr.Blocks(css=css) as demo:
    # Main column; elem_id matches the `div#col-container` rule in `css`.
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        #gr.Markdown(image_src)
        gr.HTML(description)
        with gr.Row():
            with gr.Column():
                # type='filepath': the handler receives a file path, forwarded as wav_path.
                speech = gr.Audio(label="Audio", type='filepath')
                with gr.Row():
                    text_input = gr.Textbox(label='User question', placeholder='Please upload your audio first', interactive=True)
                    submit_btn = gr.Button("Submit", scale=0)
                answer = gr.Textbox(label="Salmonn answer")
                # Decoding settings, collapsed by default and passed to gradio_answer.
                with gr.Accordion("Advanced Settings", open=False):
                    num_beams = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=4,
                        step=1,
                        interactive=True,
                        label="beam search numbers",
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        interactive=True,
                        label="top p",
                    )
                    # interactive=False: shown to the user but fixed at value=1.0.
                    temperature = gr.Slider(
                        minimum=0.8,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        interactive=False,
                        label="temperature",
                    )
        with gr.Row():
            # Each example is an [audio path, question] pair that fills `speech`
            # and `text_input`. No `fn` is given, so examples only populate the inputs.
            examples = gr.Examples(
                examples = [
                    ["resource/audio_demo/gunshots.wav", "Recognize the speech and give me the transcription."],
                    ["resource/audio_demo/gunshots.wav", "Listen to the speech and translate it into German."],
                    ["resource/audio_demo/gunshots.wav", "Provide the phonetic transcription for the speech."],
                    ["resource/audio_demo/gunshots.wav", "Please describe the audio."],
                    ["resource/audio_demo/gunshots.wav", "Recognize what the speaker says and describe the background audio at the same time."],
                    ["resource/audio_demo/gunshots.wav", "Use your strong reasoning skills to answer the speaker's question in detail based on the background sound."],
                    ["resource/audio_demo/duck.wav", "Please list each event in the audio in order."],
                    ["resource/audio_demo/duck.wav", "Based on the audio, write a story in detail. Your story should be highly related to the audio."],
                    ["resource/audio_demo/duck.wav", "How many speakers did you hear in this audio? Who are they?"],
                    ["resource/audio_demo/excitement.wav", "Describe the emotion of the speaker."],
                    ["resource/audio_demo/mountain.wav", "Please answer the question in detail."],
                    ["resource/audio_demo/jobs.wav", "Give me only three keywords of the text. Explain your reason."],
                    ["resource/audio_demo/2_30.wav", "What is the time mentioned in the speech?"],
                    ["resource/audio_demo/music.wav", "Please describe the music in detail."],
                    ["resource/audio_demo/music.wav", "What is the emotion of the music? Explain the reason in detail."],
                    ["resource/audio_demo/music.wav", "Can you write some lyrics of the song?"],
                    ["resource/audio_demo/music.wav", "Give me a title of the music based on its rhythm and emotion."]
                ],
                inputs=[speech, text_input]
            )
    # Pressing Enter in the question box and clicking Submit run the same handler.
    # The input list order must match gradio_answer's parameter order.
    text_input.submit(
        gradio_answer, [speech, text_input, num_beams, temperature, top_p], [answer]
    )
    submit_btn.click(
        gradio_answer, [speech, text_input, num_beams, temperature, top_p], [answer]
    )
# Enable request queuing (at most 20 queued events), then start the server
# with errors surfaced in the UI. No server_port is passed, so the parsed
# --port value is not used here.
demo.queue(max_size=20)
demo.launch(share=False, show_error=True)