Futuresony committed on
Commit ca69a74 · verified · 1 Parent(s): f3ac400

Update app.py

Files changed (1)
  1. app.py +240 -63
app.py CHANGED
@@ -1,64 +1,241 @@
 
 
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("Futuresony/future_ai_12_10_2024.gguf")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
+ !pip install utils
+ !pip install gradio
  import gradio as gr
+ from huggingface_hub import snapshot_download
+ from threading import Thread
+ import time
+ import base64
+ import numpy as np
+ import requests
+ import traceback
+ from dataclasses import dataclass, field
+ import io
+ from pydub import AudioSegment
+ import librosa
+ from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
+ import tempfile
+
+
+ from server import serve
+
+ repo_id = "gpt-omni/mini-omni"
+ snapshot_download(repo_id, local_dir="./checkpoint", revision="main")
+
+ IP = "0.0.0.0"
+ PORT = 60808
+
+ thread = Thread(target=serve, daemon=True)
+ thread.start()
+
+ API_URL = "http://0.0.0.0:60808/chat"
+
+ # recording parameters
+ IN_CHANNELS = 1
+ IN_RATE = 24000
+ IN_CHUNK = 1024
+ IN_SAMPLE_WIDTH = 2
+ VAD_STRIDE = 0.5
+
+ # playing parameters
+ OUT_CHANNELS = 1
+ OUT_RATE = 24000
+ OUT_SAMPLE_WIDTH = 2
+ OUT_CHUNK = 5760
+
+
+ OUT_CHUNK = 20 * 4096
+ OUT_RATE = 24000
+ OUT_CHANNELS = 1
+
+
+ def run_vad(ori_audio, sr):
+     _st = time.time()
+     try:
+         audio = ori_audio
+         audio = audio.astype(np.float32) / 32768.0
+         sampling_rate = 16000
+         if sr != sampling_rate:
+             audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
+
+         vad_parameters = {}
+         vad_parameters = VadOptions(**vad_parameters)
+         speech_chunks = get_speech_timestamps(audio, vad_parameters)
+         audio = collect_chunks(audio, speech_chunks)
+         duration_after_vad = audio.shape[0] / sampling_rate
+
+         if sr != sampling_rate:
+             # resample to original sampling rate
+             vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
+         else:
+             vad_audio = audio
+         vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
+         vad_audio_bytes = vad_audio.tobytes()
+
+         return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
+     except Exception as e:
+         msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
+         print(msg)
+         return -1, ori_audio, round(time.time() - _st, 4)
+
+
+ def warm_up():
+     frames = b"\x00\x00" * 1024 * 2  # 1024 frames of 2 bytes each
+     dur, frames, tcost = run_vad(frames, 16000)
+     print(f"warm up done, time_cost: {tcost:.3f} s")
+
+
+ warm_up()
+
+
+ @dataclass
+ class AppState:
+     stream: np.ndarray | None = None
+     sampling_rate: int = 0
+     pause_detected: bool = False
+     started_talking: bool = False
+     stopped: bool = False
+     conversation: list = field(default_factory=list)
+
+
+ def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
+     """Take in the stream, determine if a pause happened"""
+
+     temp_audio = audio
+
+     dur_vad, _, time_vad = run_vad(temp_audio, sampling_rate)
+     duration = len(audio) / sampling_rate
+
+     if dur_vad > 0.5 and not state.started_talking:
+         print("started talking")
+         state.started_talking = True
+         return False
+
+     print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
+
+     return (duration - dur_vad) > 1
+
+
+ def speaking(audio_bytes: str):
+
+     base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
+     files = {"audio": base64_encoded}
+     with requests.post(API_URL, json=files, stream=True) as response:
+         try:
+             for chunk in response.iter_content(chunk_size=OUT_CHUNK):
+                 if chunk:
+                     # Create an audio segment from the numpy array
+                     audio_segment = AudioSegment(
+                         chunk,
+                         frame_rate=OUT_RATE,
+                         sample_width=OUT_SAMPLE_WIDTH,
+                         channels=OUT_CHANNELS,
+                     )
+
+                     # Export the audio segment to MP3 bytes - use a high bitrate to maximise quality
+                     mp3_io = io.BytesIO()
+                     audio_segment.export(mp3_io, format="mp3", bitrate="320k")
+
+                     # Get the MP3 bytes
+                     mp3_bytes = mp3_io.getvalue()
+                     mp3_io.close()
+                     yield mp3_bytes
+
+         except Exception as e:
+             raise gr.Error(f"Error during audio streaming: {e}")
+
+
+
+
+ def process_audio(audio: tuple, state: AppState):
+     if state.stream is None:
+         state.stream = audio[1]
+         state.sampling_rate = audio[0]
+     else:
+         state.stream = np.concatenate((state.stream, audio[1]))
+
+     pause_detected = determine_pause(state.stream, state.sampling_rate, state)
+     state.pause_detected = pause_detected
+
+     if state.pause_detected and state.started_talking:
+         return gr.Audio(recording=False), state
+     return None, state
+
+
+ def response(state: AppState):
+     if not state.pause_detected and not state.started_talking:
+         return None, AppState()
+
+     audio_buffer = io.BytesIO()
+
+     segment = AudioSegment(
+         state.stream.tobytes(),
+         frame_rate=state.sampling_rate,
+         sample_width=state.stream.dtype.itemsize,
+         channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
+     )
+     segment.export(audio_buffer, format="wav")
+
+     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+         f.write(audio_buffer.getvalue())
+
+     state.conversation.append({"role": "user",
+                                "content": {"path": f.name,
+                                            "mime_type": "audio/wav"}})
+
+     output_buffer = b""
+
+     for mp3_bytes in speaking(audio_buffer.getvalue()):
+         output_buffer += mp3_bytes
+         yield mp3_bytes, state
+
+     with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
+         f.write(output_buffer)
+
+     state.conversation.append({"role": "assistant",
+                                "content": {"path": f.name,
+                                            "mime_type": "audio/mp3"}})
+     yield None, AppState(conversation=state.conversation)
+
+
+
+
+ def start_recording_user(state: AppState):
+     if not state.stopped:
+         return gr.Audio(recording=True)
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             input_audio = gr.Audio(
+                 label="Input Audio", sources="microphone", type="numpy"
+             )
+         with gr.Column():
+             chatbot = gr.Chatbot(label="Conversation", type="messages")
+             output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)
+     state = gr.State(value=AppState())
+
+     stream = input_audio.stream(
+         process_audio,
+         [input_audio, state],
+         [input_audio, state],
+         stream_every=0.5,
+         time_limit=30,
+     )
+     respond = input_audio.stop_recording(
+         response,
+         [state],
+         [output_audio, state]
+     )
+     respond.then(lambda s: s.conversation, [state], [chatbot])
+
+     restart = output_audio.stop(
+         start_recording_user,
+         [state],
+         [input_audio]
+     )
+     cancel = gr.Button("Stop Conversation", variant="stop")
+     cancel.click(lambda: (AppState(stopped=True), gr.Audio(recording=False)), None,
+                  [state, input_audio], cancels=[respond, restart])
+
+
+ demo.launch()