# Copyright (2023) Tsinghua University, Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import gradio as gr
import spaces

from model_zero import SALMONN


class ff:
    """Stand-in for the real model that answers every request with a canned reply.

    It exposes a ``generate`` method so it can be assigned to ``model`` in
    place of ``SALMONN`` when exercising the UI without loading any weights.
    """

    def generate(self, wav_path, prompt, prompt_pattern=None, num_beams=4,
                 temperature=1.0, top_p=0.9):
        """Log the request parameters and return a fixed apology string.

        Args:
            wav_path: Path of the audio file (only echoed to stdout).
            prompt: User question (only echoed to stdout).
            prompt_pattern: Accepted for signature compatibility and ignored.
                It is optional because ``gradio_answer`` in this file calls
                ``generate`` without it.
            num_beams: Only echoed to stdout.
            temperature: Only echoed to stdout.
            top_p: Only echoed to stdout.

        Returns:
            A constant ``str``.
        """
        # NOTE(review): gradio_answer returns ``llm_message[0]``; with this
        # stub that is just the first character of the string below. Confirm
        # whether the stub should return a one-element list instead.
        print(f'wav_path: {wav_path}, prompt: {prompt}, temperature: {temperature}, num_beams: {num_beams}, top_p: {top_p}')
        return "I'm sorry, but I cannot answer that question as it is not clear what you are asking. Can you please provide more context or clarify your question?"
# ---------------------------------------------------------------------------
# Command-line configuration
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--ckpt_path", type=str, default="./salmonn_7b_v0.pth")
parser.add_argument("--whisper_path", type=str, default="./whisper_large_v2")
parser.add_argument("--beats_path", type=str, default="./beats/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt")
parser.add_argument("--vicuna_path", type=str, default="./vicuna-7b-v1.5")
parser.add_argument("--low_resource", action='store_true', default=False)
# Parsed as int so a value given on the command line has the same type as the
# default.
parser.add_argument("--port", type=int, default=9527)

args = parser.parse_args()
# The flag is forced on regardless of the command line.
args.low_resource = True  # for huggingface A10 7b demo

# ---------------------------------------------------------------------------
# Model: built with device='cpu', then moved to the requested device.
# ---------------------------------------------------------------------------
# model = ff()
model = SALMONN(
    ckpt=args.ckpt_path,
    whisper_path=args.whisper_path,
    beats_path=args.beats_path,
    vicuna_path=args.vicuna_path,
    low_resource=args.low_resource,
    lora_alpha=28,
    device='cpu'
)
model.to(args.device)
model.eval()


@spaces.GPU(enable_queue=True)
def gradio_answer(speech, text_input, num_beams, temperature, top_p):
    """Ask the model one question about one audio clip.

    Args:
        speech: Value of the ``speech`` audio component (created with
            ``type='filepath'``); forwarded as ``wav_path``.
        text_input: The user's question; forwarded as ``prompt``.
        num_beams: Forwarded unchanged to ``model.generate``.
        temperature: Forwarded unchanged to ``model.generate``.
        top_p: Forwarded unchanged to ``model.generate``.

    Returns:
        The first element of whatever ``model.generate`` returns.
    """
    llm_message = model.generate(
        wav_path=speech,
        prompt=text_input,
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
    )
    # The full result is logged; only its first element reaches the UI.
    print(llm_message)
    return llm_message[0]


# ---------------------------------------------------------------------------
# Static page content
# ---------------------------------------------------------------------------
title = """

SALMONN: Speech Audio Language Music Open Neural Network

"""

# Only referenced by the commented-out gr.Markdown call in the layout below.
image_src = """

SALMONN

"""

description = """

This is a simplified gradio demo for SALMONN-7B.
To experience SALMONN-13B, you can go to https://bytedance.github.io/SALMONN.
Upload your audio and ask a question!

"""

css = """ div#col-container { margin: 0 auto; max-width: 840px; } """

# ---------------------------------------------------------------------------
# UI layout
# NOTE(review): the file's indentation was lost, so the container nesting
# below is reconstructed from the order of the widgets - confirm against the
# deployed layout.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        #gr.Markdown(image_src)
        gr.HTML(description)

        with gr.Row():
            with gr.Column():
                speech = gr.Audio(label="Audio", type='filepath')

        with gr.Row():
            text_input = gr.Textbox(label='User question', placeholder='Please upload your audio first', interactive=True)
            submit_btn = gr.Button("Submit", scale=0)

        answer = gr.Textbox(label="Salmonn answer")

        with gr.Accordion("Advanced Settings", open=False):
            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=4,
                step=1,
                interactive=True,
                label="beam search numbers",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.1,
                interactive=True,
                label="top p",
            )
            # interactive=False: the slider is displayed but cannot be moved,
            # so every request is sent with temperature 1.0.
            temperature = gr.Slider(
                minimum=0.8,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=False,
                label="temperature",
            )

        with gr.Row():
            examples = gr.Examples(
                examples = [
                    ["resource/audio_demo/gunshots.wav", "Recognize the speech and give me the transcription."],
                    ["resource/audio_demo/gunshots.wav", "Listen to the speech and translate it into German."],
                    ["resource/audio_demo/gunshots.wav", "Provide the phonetic transcription for the speech."],
                    ["resource/audio_demo/gunshots.wav", "Please describe the audio."],
                    ["resource/audio_demo/gunshots.wav", "Recognize what the speaker says and describe the background audio at the same time."],
                    ["resource/audio_demo/gunshots.wav", "Use your strong reasoning skills to answer the speaker's question in detail based on the background sound."],
                    ["resource/audio_demo/duck.wav", "Please list each event in the audio in order."],
                    ["resource/audio_demo/duck.wav", "Based on the audio, write a story in detail. Your story should be highly related to the audio."],
                    # This prompt was split across two physical lines inside
                    # the quotes (an unterminated string); rejoined here.
                    ["resource/audio_demo/duck.wav", "How many speakers did you hear in this audio? Who are they?"],
                    ["resource/audio_demo/excitement.wav", "Describe the emotion of the speaker."],
                    ["resource/audio_demo/mountain.wav", "Please answer the question in detail."],
                    ["resource/audio_demo/jobs.wav", "Give me only three keywords of the text. Explain your reason."],
                    ["resource/audio_demo/2_30.wav", "What is the time mentioned in the speech?"],
                    ["resource/audio_demo/music.wav", "Please describe the music in detail."],
                    ["resource/audio_demo/music.wav", "What is the emotion of the music? Explain the reason in detail."],
                    ["resource/audio_demo/music.wav", "Can you write some lyrics of the song?"],
                    ["resource/audio_demo/music.wav", "Give me a title of the music based on its rhythm and emotion."]
                ],
                inputs=[speech, text_input]
            )

    # Pressing Enter in the question box and clicking Submit run the same
    # handler with the same inputs and output.
    text_input.submit(
        gradio_answer, [speech, text_input, num_beams, temperature, top_p], [answer]
    )
    submit_btn.click(
        gradio_answer, [speech, text_input, num_beams, temperature, top_p], [answer]
    )

# demo.launch(share=True, enable_queue=True, server_port=int(args.port))
demo.queue(max_size=20).launch(share=False, show_error=True)