fffiloni commited on
Commit
b0c18ca
·
verified ·
1 Parent(s): da5de54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -45
app.py CHANGED
@@ -44,45 +44,17 @@ model = SALMONN(
44
  model.to(args.device)
45
  model.eval()
46
 
47
- # gradio
48
- def gradio_reset(chat_state):
49
 
50
- chat_state = []
51
- return (None,
52
- gr.update(value=None, interactive=True),
53
- gr.update(placeholder='Please upload your wav first', interactive=False),
54
- gr.update(value="Upload & Start Chat", interactive=True),
55
- chat_state)
56
-
57
- def upload_speech(gr_speech, text_input, chat_state):
58
-
59
- if gr_speech is None:
60
- return None, None, gr.update(interactive=True), chat_state, None
61
- chat_state.append(gr_speech)
62
- return (gr.update(interactive=False),
63
- gr.update(interactive=True, placeholder='Type and press Enter'),
64
- gr.update(value="Start Chatting", interactive=False),
65
- chat_state)
66
-
67
- def gradio_ask(user_message, chatbot, chat_state):
68
-
69
- if len(user_message) == 0:
70
- return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
71
- chat_state.append(user_message)
72
- chatbot.append([user_message, None])
73
- #
74
- return gr.update(interactive=False, placeholder='Currently only single round conversations are supported.'), chatbot, chat_state
75
-
76
- def gradio_answer(chatbot, chat_state, num_beams, temperature, top_p):
77
  llm_message = model.generate(
78
- wav_path=chat_state[0],
79
- prompt=chat_state[1],
80
  num_beams=num_beams,
81
  temperature=temperature,
82
  top_p=top_p,
83
  )
84
- chatbot[-1][1] = llm_message[0]
85
- return chatbot, chat_state
86
 
87
  title = """<h1 align="center">SALMONN: Speech Audio Language Music Open Neural Network</h1>"""
88
  image_src = """<h1 align="center"><a href="https://github.com/bytedance/SALMONN"><img src="https://raw.githubusercontent.com/bytedance/SALMONN/main/resource/salmon.png", alt="SALMONN" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>"""
@@ -97,9 +69,7 @@ with gr.Blocks() as demo:
97
  with gr.Row():
98
  with gr.Column():
99
  speech = gr.Audio(label="Audio", type='filepath')
100
- upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
101
- clear = gr.Button("Restart")
102
-
103
  num_beams = gr.Slider(
104
  minimum=1,
105
  maximum=10,
@@ -128,10 +98,9 @@ with gr.Blocks() as demo:
128
  )
129
 
130
  with gr.Column():
131
- chat_state = gr.State([])
132
 
133
- chatbot = gr.Chatbot(label='SALMONN-7B')
134
- text_input = gr.Textbox(label='User', placeholder='Please upload your audio first', interactive=False)
135
 
136
  with gr.Row():
137
  examples = gr.Examples(
@@ -157,12 +126,11 @@ with gr.Blocks() as demo:
157
  inputs=[speech, text_input]
158
  )
159
 
160
- upload_button.click(upload_speech, [speech, text_input, chat_state], [speech, text_input, upload_button, chat_state])
161
-
162
- text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
163
- gradio_answer, [chatbot, chat_state, num_beams, temperature, top_p], [chatbot, chat_state]
164
  )
165
- clear.click(gradio_reset, [chat_state], [chatbot, speech, text_input, upload_button, chat_state], queue=False)
166
 
167
  # demo.launch(share=True, enable_queue=True, server_port=int(args.port))
168
- demo.launch(share=True)
 
44
  model.to(args.device)
45
  model.eval()
46
 
47
+ def gradio_answer(speech, text_input, num_beams, temperature, top_p):
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  llm_message = model.generate(
50
+ wav_path=speech,
51
+ prompt=text_input,
52
  num_beams=num_beams,
53
  temperature=temperature,
54
  top_p=top_p,
55
  )
56
+
57
+ return llm_message
58
 
59
  title = """<h1 align="center">SALMONN: Speech Audio Language Music Open Neural Network</h1>"""
60
  image_src = """<h1 align="center"><a href="https://github.com/bytedance/SALMONN"><img src="https://raw.githubusercontent.com/bytedance/SALMONN/main/resource/salmon.png", alt="SALMONN" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>"""
 
69
  with gr.Row():
70
  with gr.Column():
71
  speech = gr.Audio(label="Audio", type='filepath')
72
+
 
 
73
  num_beams = gr.Slider(
74
  minimum=1,
75
  maximum=10,
 
98
  )
99
 
100
  with gr.Column():
 
101
 
102
+ text_input = gr.Textbox(label='User', placeholder='Please upload your audio first', interactive=True)
103
+ answer = gr.Textbox(label="Salmonn answer")
104
 
105
  with gr.Row():
106
  examples = gr.Examples(
 
126
  inputs=[speech, text_input]
127
  )
128
 
129
+
130
+ text_input.submit(
131
+ gradio_answer, [speech, text_input, num_beams, temperature, top_p], [answer]
 
132
  )
133
+
134
 
135
  # demo.launch(share=True, enable_queue=True, server_port=int(args.port))
136
+ demo.launch(share=False)