daihui.zhang
commited on
Commit
·
ea9e44a
1
Parent(s):
55cf28e
remove unused codes
Browse files- config.py +3 -2
- main.py +1 -5
- transcribe/client.py +0 -677
- transcribe/helpers/vadprocessor.py +1 -1
- transcribe/pipelines/pipe_vad.py +8 -15
- transcribe/server.py +0 -382
- transcribe/strategy.py +0 -405
- transcribe/transcription.py +0 -334
- transcribe/translatepipes.py +0 -1
- transcribe/whisper_llm_serve.py +21 -107
config.py
CHANGED
@@ -3,10 +3,11 @@ import re
|
|
3 |
import logging
|
4 |
|
5 |
DEBUG = True
|
|
|
6 |
|
7 |
logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
|
8 |
logging.basicConfig(
|
9 |
-
level=
|
10 |
format="%(asctime)s - %(levelname)s - %(message)s",
|
11 |
filename='translator.log',
|
12 |
datefmt="%H:%M:%S"
|
@@ -15,7 +16,7 @@ logging.basicConfig(
|
|
15 |
SAVE_DATA_SAVE = False
|
16 |
# Add terminal log
|
17 |
console_handler = logging.StreamHandler()
|
18 |
-
console_handler.setLevel(
|
19 |
console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
20 |
console_handler.setFormatter(console_formatter)
|
21 |
logging.getLogger().addHandler(console_handler)
|
|
|
3 |
import logging
|
4 |
|
5 |
DEBUG = True
|
6 |
+
LOG_LEVEL = logging.WARNING if DEBUG else logging.INFO
|
7 |
|
8 |
logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
|
9 |
logging.basicConfig(
|
10 |
+
level=LOG_LEVEL,
|
11 |
format="%(asctime)s - %(levelname)s - %(message)s",
|
12 |
filename='translator.log',
|
13 |
datefmt="%H:%M:%S"
|
|
|
16 |
SAVE_DATA_SAVE = False
|
17 |
# Add terminal log
|
18 |
console_handler = logging.StreamHandler()
|
19 |
+
console_handler.setLevel(LOG_LEVEL)
|
20 |
console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
21 |
console_handler.setFormatter(console_formatter)
|
22 |
logging.getLogger().addHandler(console_handler)
|
main.py
CHANGED
@@ -11,6 +11,7 @@ from fastapi.staticfiles import StaticFiles
|
|
11 |
from fastapi.responses import RedirectResponse
|
12 |
import os
|
13 |
from transcribe.utils import pcm_bytes_to_np_array
|
|
|
14 |
logger = getLogger(__name__)
|
15 |
|
16 |
|
@@ -39,9 +40,6 @@ async def lifespan(app:FastAPI):
|
|
39 |
yield
|
40 |
|
41 |
|
42 |
-
# 获取当前文件所在目录的绝对路径
|
43 |
-
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
44 |
-
# 构建frontend目录的绝对路径
|
45 |
FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")
|
46 |
|
47 |
|
@@ -66,9 +64,7 @@ async def translate(websocket: WebSocket):
|
|
66 |
client_uid=f"{uuid1()}",
|
67 |
)
|
68 |
|
69 |
-
|
70 |
if from_lang and to_lang and client:
|
71 |
-
client.set_language(from_lang, to_lang)
|
72 |
logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
|
73 |
await websocket.accept()
|
74 |
try:
|
|
|
11 |
from fastapi.responses import RedirectResponse
|
12 |
import os
|
13 |
from transcribe.utils import pcm_bytes_to_np_array
|
14 |
+
from config import BASE_DIR
|
15 |
logger = getLogger(__name__)
|
16 |
|
17 |
|
|
|
40 |
yield
|
41 |
|
42 |
|
|
|
|
|
|
|
43 |
FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")
|
44 |
|
45 |
|
|
|
64 |
client_uid=f"{uuid1()}",
|
65 |
)
|
66 |
|
|
|
67 |
if from_lang and to_lang and client:
|
|
|
68 |
logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
|
69 |
await websocket.accept()
|
70 |
try:
|
transcribe/client.py
DELETED
@@ -1,677 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
import os
|
3 |
-
import shutil
|
4 |
-
import threading
|
5 |
-
import time
|
6 |
-
import uuid
|
7 |
-
import wave
|
8 |
-
|
9 |
-
import av
|
10 |
-
import numpy as np
|
11 |
-
import pyaudio
|
12 |
-
import websocket
|
13 |
-
|
14 |
-
import transcribe.utils as utils
|
15 |
-
|
16 |
-
|
17 |
-
class Client:
|
18 |
-
"""
|
19 |
-
Handles communication with a server using WebSocket.
|
20 |
-
"""
|
21 |
-
INSTANCES = {}
|
22 |
-
END_OF_AUDIO = "END_OF_AUDIO"
|
23 |
-
|
24 |
-
def __init__(
|
25 |
-
self,
|
26 |
-
host=None,
|
27 |
-
port=None,
|
28 |
-
lang=None,
|
29 |
-
log_transcription=True,
|
30 |
-
max_clients=4,
|
31 |
-
max_connection_time=600,
|
32 |
-
dst_lang='zh',
|
33 |
-
):
|
34 |
-
"""
|
35 |
-
Initializes a Client instance for audio recording and streaming to a server.
|
36 |
-
|
37 |
-
If host and port are not provided, the WebSocket connection will not be established.
|
38 |
-
the audio recording starts immediately upon initialization.
|
39 |
-
|
40 |
-
Args:
|
41 |
-
host (str): The hostname or IP address of the server.
|
42 |
-
port (int): The port number for the WebSocket server.
|
43 |
-
lang (str, optional): The selected language for transcription. Default is None.
|
44 |
-
log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
|
45 |
-
max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
|
46 |
-
max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
|
47 |
-
"""
|
48 |
-
self.recording = False
|
49 |
-
self.uid = str(uuid.uuid4())
|
50 |
-
self.waiting = False
|
51 |
-
self.last_response_received = None
|
52 |
-
self.disconnect_if_no_response_for = 15
|
53 |
-
self.language = lang
|
54 |
-
self.server_error = False
|
55 |
-
self.last_segment = None
|
56 |
-
self.last_received_segment = None
|
57 |
-
self.log_transcription = log_transcription
|
58 |
-
self.max_clients = max_clients
|
59 |
-
self.max_connection_time = max_connection_time
|
60 |
-
self.dst_lang = dst_lang
|
61 |
-
|
62 |
-
self.audio_bytes = None
|
63 |
-
|
64 |
-
if host is not None and port is not None:
|
65 |
-
socket_url = f"ws://{host}:{port}?from={self.language}&to={self.dst_lang}"
|
66 |
-
self.client_socket = websocket.WebSocketApp(
|
67 |
-
socket_url,
|
68 |
-
on_open=lambda ws: self.on_open(ws),
|
69 |
-
on_message=lambda ws, message: self.on_message(ws, message),
|
70 |
-
on_error=lambda ws, error: self.on_error(ws, error),
|
71 |
-
on_close=lambda ws, close_status_code, close_msg: self.on_close(
|
72 |
-
ws, close_status_code, close_msg
|
73 |
-
),
|
74 |
-
)
|
75 |
-
else:
|
76 |
-
print("[ERROR]: No host or port specified.")
|
77 |
-
return
|
78 |
-
|
79 |
-
Client.INSTANCES[self.uid] = self
|
80 |
-
|
81 |
-
# start websocket client in a thread
|
82 |
-
self.ws_thread = threading.Thread(target=self.client_socket.run_forever)
|
83 |
-
self.ws_thread.daemon = True
|
84 |
-
self.ws_thread.start()
|
85 |
-
|
86 |
-
self.transcript = []
|
87 |
-
print("[INFO]: * recording")
|
88 |
-
|
89 |
-
def handle_status_messages(self, message_data):
|
90 |
-
"""Handles server status messages."""
|
91 |
-
status = message_data["status"]
|
92 |
-
if status == "WAIT":
|
93 |
-
self.waiting = True
|
94 |
-
print(f"[INFO]: Server is full. Estimated wait time {round(message_data['message'])} minutes.")
|
95 |
-
elif status == "ERROR":
|
96 |
-
print(f"Message from Server: {message_data['message']}")
|
97 |
-
self.server_error = True
|
98 |
-
elif status == "WARNING":
|
99 |
-
print(f"Message from Server: {message_data['message']}")
|
100 |
-
|
101 |
-
def process_segments(self, segments):
|
102 |
-
"""Processes transcript segments."""
|
103 |
-
text = []
|
104 |
-
for i, seg in enumerate(segments):
|
105 |
-
if not text or text[-1] != seg["text"]:
|
106 |
-
text.append(seg["text"])
|
107 |
-
if i == len(segments) - 1 and not seg.get("completed", False):
|
108 |
-
self.last_segment = seg
|
109 |
-
|
110 |
-
# update last received segment and last valid response time
|
111 |
-
if self.last_received_segment is None or self.last_received_segment != segments[-1]["text"]:
|
112 |
-
self.last_response_received = time.time()
|
113 |
-
self.last_received_segment = segments[-1]["text"]
|
114 |
-
|
115 |
-
if self.log_transcription:
|
116 |
-
# Truncate to last 3 entries for brevity.
|
117 |
-
text = text[-3:]
|
118 |
-
utils.clear_screen()
|
119 |
-
utils.print_transcript(text)
|
120 |
-
|
121 |
-
def on_message(self, ws, message):
|
122 |
-
"""
|
123 |
-
Callback function called when a message is received from the server.
|
124 |
-
|
125 |
-
It updates various attributes of the client based on the received message, including
|
126 |
-
recording status, language detection, and server messages. If a disconnect message
|
127 |
-
is received, it sets the recording status to False.
|
128 |
-
|
129 |
-
Args:
|
130 |
-
ws (websocket.WebSocketApp): The WebSocket client instance.
|
131 |
-
message (str): The received message from the server.
|
132 |
-
|
133 |
-
"""
|
134 |
-
message = json.loads(message)
|
135 |
-
|
136 |
-
# if self.uid != message.get("uid"):
|
137 |
-
# print("[ERROR]: invalid client uid")
|
138 |
-
# return
|
139 |
-
|
140 |
-
if "status" in message.keys():
|
141 |
-
self.handle_status_messages(message)
|
142 |
-
return
|
143 |
-
|
144 |
-
if "message" in message.keys() and message["message"] == "DISCONNECT":
|
145 |
-
print("[INFO]: Server disconnected due to overtime.")
|
146 |
-
self.recording = False
|
147 |
-
|
148 |
-
if "message" in message.keys() and message["message"] == "SERVER_READY":
|
149 |
-
self.last_response_received = time.time()
|
150 |
-
self.recording = True
|
151 |
-
self.server_backend = message["backend"]
|
152 |
-
print(f"[INFO]: Server Running with backend {self.server_backend}")
|
153 |
-
return
|
154 |
-
|
155 |
-
if "language" in message.keys():
|
156 |
-
self.language = message.get("language")
|
157 |
-
lang_prob = message.get("language_prob")
|
158 |
-
print(
|
159 |
-
f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
|
160 |
-
)
|
161 |
-
return
|
162 |
-
|
163 |
-
if "segments" in message.keys():
|
164 |
-
self.process_segments(message["segments"])
|
165 |
-
|
166 |
-
def on_error(self, ws, error):
|
167 |
-
print(f"[ERROR] WebSocket Error: {error}")
|
168 |
-
self.server_error = True
|
169 |
-
self.error_message = error
|
170 |
-
|
171 |
-
def on_close(self, ws, close_status_code, close_msg):
|
172 |
-
print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}")
|
173 |
-
self.recording = False
|
174 |
-
self.waiting = False
|
175 |
-
|
176 |
-
def on_open(self, ws):
|
177 |
-
"""
|
178 |
-
Callback function called when the WebSocket connection is successfully opened.
|
179 |
-
|
180 |
-
Sends an initial configuration message to the server, including client UID,
|
181 |
-
language selection, and task type.
|
182 |
-
|
183 |
-
Args:
|
184 |
-
ws (websocket.WebSocketApp): The WebSocket client instance.
|
185 |
-
|
186 |
-
"""
|
187 |
-
print("[INFO]: Opened connection")
|
188 |
-
ws.send(
|
189 |
-
json.dumps(
|
190 |
-
{
|
191 |
-
"uid": self.uid,
|
192 |
-
"language": self.language,
|
193 |
-
"max_clients": self.max_clients,
|
194 |
-
"max_connection_time": self.max_connection_time,
|
195 |
-
}
|
196 |
-
)
|
197 |
-
)
|
198 |
-
|
199 |
-
def send_packet_to_server(self, message):
|
200 |
-
"""
|
201 |
-
Send an audio packet to the server using WebSocket.
|
202 |
-
|
203 |
-
Args:
|
204 |
-
message (bytes): The audio data packet in bytes to be sent to the server.
|
205 |
-
|
206 |
-
"""
|
207 |
-
try:
|
208 |
-
self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY)
|
209 |
-
except Exception as e:
|
210 |
-
print(e)
|
211 |
-
|
212 |
-
def close_websocket(self):
|
213 |
-
"""
|
214 |
-
Close the WebSocket connection and join the WebSocket thread.
|
215 |
-
|
216 |
-
First attempts to close the WebSocket connection using `self.client_socket.close()`. After
|
217 |
-
closing the connection, it joins the WebSocket thread to ensure proper termination.
|
218 |
-
|
219 |
-
"""
|
220 |
-
try:
|
221 |
-
self.client_socket.close()
|
222 |
-
except Exception as e:
|
223 |
-
print("[ERROR]: Error closing WebSocket:", e)
|
224 |
-
|
225 |
-
try:
|
226 |
-
self.ws_thread.join()
|
227 |
-
except Exception as e:
|
228 |
-
print("[ERROR:] Error joining WebSocket thread:", e)
|
229 |
-
|
230 |
-
def get_client_socket(self):
|
231 |
-
"""
|
232 |
-
Get the WebSocket client socket instance.
|
233 |
-
|
234 |
-
Returns:
|
235 |
-
WebSocketApp: The WebSocket client socket instance currently in use by the client.
|
236 |
-
"""
|
237 |
-
return self.client_socket
|
238 |
-
|
239 |
-
def wait_before_disconnect(self):
|
240 |
-
"""Waits a bit before disconnecting in order to process pending responses."""
|
241 |
-
assert self.last_response_received
|
242 |
-
while time.time() - self.last_response_received < self.disconnect_if_no_response_for:
|
243 |
-
continue
|
244 |
-
|
245 |
-
|
246 |
-
class TranscriptionTeeClient:
|
247 |
-
"""
|
248 |
-
Client for handling audio recording, streaming, and transcription tasks via one or more
|
249 |
-
WebSocket connections.
|
250 |
-
|
251 |
-
Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
|
252 |
-
to send audio data for transcription to one or more servers, and receive transcribed text segments.
|
253 |
-
Args:
|
254 |
-
clients (list): one or more previously initialized Client instances
|
255 |
-
|
256 |
-
Attributes:
|
257 |
-
clients (list): the underlying Client instances responsible for handling WebSocket connections.
|
258 |
-
"""
|
259 |
-
|
260 |
-
def __init__(self, clients, save_output_recording=False, output_recording_filename="./output_recording.wav",
|
261 |
-
mute_audio_playback=False):
|
262 |
-
self.clients = clients
|
263 |
-
if not self.clients:
|
264 |
-
raise Exception("At least one client is required.")
|
265 |
-
self.chunk = 4096
|
266 |
-
self.format = pyaudio.paInt16
|
267 |
-
self.channels = 1
|
268 |
-
self.rate = 16000
|
269 |
-
self.record_seconds = 60000
|
270 |
-
self.save_output_recording = save_output_recording
|
271 |
-
self.output_recording_filename = output_recording_filename
|
272 |
-
self.mute_audio_playback = mute_audio_playback
|
273 |
-
self.frames = b""
|
274 |
-
self.p = pyaudio.PyAudio()
|
275 |
-
try:
|
276 |
-
self.stream = self.p.open(
|
277 |
-
format=self.format,
|
278 |
-
channels=self.channels,
|
279 |
-
rate=self.rate,
|
280 |
-
input=True,
|
281 |
-
frames_per_buffer=self.chunk,
|
282 |
-
)
|
283 |
-
except OSError as error:
|
284 |
-
print(f"[WARN]: Unable to access microphone. {error}")
|
285 |
-
self.stream = None
|
286 |
-
|
287 |
-
def __call__(self, audio=None, rtsp_url=None, hls_url=None, save_file=None):
|
288 |
-
"""
|
289 |
-
Start the transcription process.
|
290 |
-
|
291 |
-
Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server
|
292 |
-
to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it
|
293 |
-
will be played and streamed to the server; otherwise, it will perform live recording.
|
294 |
-
|
295 |
-
Args:
|
296 |
-
audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording.
|
297 |
-
|
298 |
-
"""
|
299 |
-
assert sum(
|
300 |
-
source is not None for source in [audio, rtsp_url, hls_url]
|
301 |
-
) <= 1, 'You must provide only one selected source'
|
302 |
-
|
303 |
-
print("[INFO]: Waiting for server ready ...")
|
304 |
-
for client in self.clients:
|
305 |
-
while not client.recording:
|
306 |
-
if client.waiting or client.server_error:
|
307 |
-
self.close_all_clients()
|
308 |
-
return
|
309 |
-
|
310 |
-
print("[INFO]: Server Ready!")
|
311 |
-
if hls_url is not None:
|
312 |
-
self.process_hls_stream(hls_url, save_file)
|
313 |
-
elif audio is not None:
|
314 |
-
resampled_file = utils.resample(audio)
|
315 |
-
self.play_file(resampled_file)
|
316 |
-
elif rtsp_url is not None:
|
317 |
-
self.process_rtsp_stream(rtsp_url)
|
318 |
-
else:
|
319 |
-
self.record()
|
320 |
-
|
321 |
-
def close_all_clients(self):
|
322 |
-
"""Closes all client websockets."""
|
323 |
-
for client in self.clients:
|
324 |
-
client.close_websocket()
|
325 |
-
|
326 |
-
def multicast_packet(self, packet, unconditional=False):
|
327 |
-
"""
|
328 |
-
Sends an identical packet via all clients.
|
329 |
-
|
330 |
-
Args:
|
331 |
-
packet (bytes): The audio data packet in bytes to be sent.
|
332 |
-
unconditional (bool, optional): If true, send regardless of whether clients are recording. Default is False.
|
333 |
-
"""
|
334 |
-
for client in self.clients:
|
335 |
-
if (unconditional or client.recording):
|
336 |
-
client.send_packet_to_server(packet)
|
337 |
-
|
338 |
-
def play_file(self, filename):
|
339 |
-
"""
|
340 |
-
Play an audio file and send it to the server for processing.
|
341 |
-
|
342 |
-
Reads an audio file, plays it through the audio output, and simultaneously sends
|
343 |
-
the audio data to the server for processing. It uses PyAudio to create an audio
|
344 |
-
stream for playback. The audio data is read from the file in chunks, converted to
|
345 |
-
floating-point format, and sent to the server using WebSocket communication.
|
346 |
-
This method is typically used when you want to process pre-recorded audio and send it
|
347 |
-
to the server in real-time.
|
348 |
-
|
349 |
-
Args:
|
350 |
-
filename (str): The path to the audio file to be played and sent to the server.
|
351 |
-
"""
|
352 |
-
|
353 |
-
# read audio and create pyaudio stream
|
354 |
-
with wave.open(filename, "rb") as wavfile:
|
355 |
-
self.stream = self.p.open(
|
356 |
-
format=self.p.get_format_from_width(wavfile.getsampwidth()),
|
357 |
-
channels=wavfile.getnchannels(),
|
358 |
-
rate=wavfile.getframerate(),
|
359 |
-
input=True,
|
360 |
-
output=True,
|
361 |
-
frames_per_buffer=self.chunk,
|
362 |
-
)
|
363 |
-
chunk_duration = self.chunk / float(wavfile.getframerate())
|
364 |
-
try:
|
365 |
-
while any(client.recording for client in self.clients):
|
366 |
-
data = wavfile.readframes(self.chunk)
|
367 |
-
if data == b"":
|
368 |
-
break
|
369 |
-
|
370 |
-
audio_array = self.bytes_to_float_array(data)
|
371 |
-
self.multicast_packet(audio_array.tobytes())
|
372 |
-
if self.mute_audio_playback:
|
373 |
-
time.sleep(chunk_duration)
|
374 |
-
else:
|
375 |
-
self.stream.write(data)
|
376 |
-
|
377 |
-
wavfile.close()
|
378 |
-
|
379 |
-
for client in self.clients:
|
380 |
-
client.wait_before_disconnect()
|
381 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
382 |
-
self.stream.close()
|
383 |
-
self.close_all_clients()
|
384 |
-
|
385 |
-
except KeyboardInterrupt:
|
386 |
-
wavfile.close()
|
387 |
-
self.stream.stop_stream()
|
388 |
-
self.stream.close()
|
389 |
-
self.p.terminate()
|
390 |
-
self.close_all_clients()
|
391 |
-
print("[INFO]: Keyboard interrupt.")
|
392 |
-
|
393 |
-
def process_rtsp_stream(self, rtsp_url):
|
394 |
-
"""
|
395 |
-
Connect to an RTSP source, process the audio stream, and send it for transcription.
|
396 |
-
|
397 |
-
Args:
|
398 |
-
rtsp_url (str): The URL of the RTSP stream source.
|
399 |
-
"""
|
400 |
-
print("[INFO]: Connecting to RTSP stream...")
|
401 |
-
try:
|
402 |
-
container = av.open(rtsp_url, format="rtsp", options={"rtsp_transport": "tcp"})
|
403 |
-
self.process_av_stream(container, stream_type="RTSP")
|
404 |
-
except Exception as e:
|
405 |
-
print(f"[ERROR]: Failed to process RTSP stream: {e}")
|
406 |
-
finally:
|
407 |
-
for client in self.clients:
|
408 |
-
client.wait_before_disconnect()
|
409 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
410 |
-
self.close_all_clients()
|
411 |
-
print("[INFO]: RTSP stream processing finished.")
|
412 |
-
|
413 |
-
def process_hls_stream(self, hls_url, save_file=None):
|
414 |
-
"""
|
415 |
-
Connect to an HLS source, process the audio stream, and send it for transcription.
|
416 |
-
|
417 |
-
Args:
|
418 |
-
hls_url (str): The URL of the HLS stream source.
|
419 |
-
save_file (str, optional): Local path to save the network stream.
|
420 |
-
"""
|
421 |
-
print("[INFO]: Connecting to HLS stream...")
|
422 |
-
try:
|
423 |
-
container = av.open(hls_url, format="hls")
|
424 |
-
self.process_av_stream(container, stream_type="HLS", save_file=save_file)
|
425 |
-
except Exception as e:
|
426 |
-
print(f"[ERROR]: Failed to process HLS stream: {e}")
|
427 |
-
finally:
|
428 |
-
for client in self.clients:
|
429 |
-
client.wait_before_disconnect()
|
430 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
431 |
-
self.close_all_clients()
|
432 |
-
print("[INFO]: HLS stream processing finished.")
|
433 |
-
|
434 |
-
def process_av_stream(self, container, stream_type, save_file=None):
|
435 |
-
"""
|
436 |
-
Process an AV container stream and send audio packets to the server.
|
437 |
-
|
438 |
-
Args:
|
439 |
-
container (av.container.InputContainer): The input container to process.
|
440 |
-
stream_type (str): The type of stream being processed ("RTSP" or "HLS").
|
441 |
-
save_file (str, optional): Local path to save the stream. Default is None.
|
442 |
-
"""
|
443 |
-
audio_stream = next((s for s in container.streams if s.type == "audio"), None)
|
444 |
-
if not audio_stream:
|
445 |
-
print(f"[ERROR]: No audio stream found in {stream_type} source.")
|
446 |
-
return
|
447 |
-
|
448 |
-
output_container = None
|
449 |
-
if save_file:
|
450 |
-
output_container = av.open(save_file, mode="w")
|
451 |
-
output_audio_stream = output_container.add_stream(codec_name="pcm_s16le", rate=self.rate)
|
452 |
-
|
453 |
-
try:
|
454 |
-
for packet in container.demux(audio_stream):
|
455 |
-
for frame in packet.decode():
|
456 |
-
audio_data = frame.to_ndarray().tobytes()
|
457 |
-
self.multicast_packet(audio_data)
|
458 |
-
|
459 |
-
if save_file:
|
460 |
-
output_container.mux(frame)
|
461 |
-
except Exception as e:
|
462 |
-
print(f"[ERROR]: Error during {stream_type} stream processing: {e}")
|
463 |
-
finally:
|
464 |
-
# Wait for server to send any leftover transcription.
|
465 |
-
time.sleep(5)
|
466 |
-
self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
|
467 |
-
if output_container:
|
468 |
-
output_container.close()
|
469 |
-
container.close()
|
470 |
-
|
471 |
-
def save_chunk(self, n_audio_file):
|
472 |
-
"""
|
473 |
-
Saves the current audio frames to a WAV file in a separate thread.
|
474 |
-
|
475 |
-
Args:
|
476 |
-
n_audio_file (int): The index of the audio file which determines the filename.
|
477 |
-
This helps in maintaining the order and uniqueness of each chunk.
|
478 |
-
"""
|
479 |
-
t = threading.Thread(
|
480 |
-
target=self.write_audio_frames_to_file,
|
481 |
-
args=(self.frames[:], f"chunks/{n_audio_file}.wav",),
|
482 |
-
)
|
483 |
-
t.start()
|
484 |
-
|
485 |
-
def finalize_recording(self, n_audio_file):
|
486 |
-
"""
|
487 |
-
Finalizes the recording process by saving any remaining audio frames,
|
488 |
-
closing the audio stream, and terminating the process.
|
489 |
-
|
490 |
-
Args:
|
491 |
-
n_audio_file (int): The file index to be used if there are remaining audio frames to be saved.
|
492 |
-
This index is incremented before use if the last chunk is saved.
|
493 |
-
"""
|
494 |
-
if self.save_output_recording and len(self.frames):
|
495 |
-
self.write_audio_frames_to_file(
|
496 |
-
self.frames[:], f"chunks/{n_audio_file}.wav"
|
497 |
-
)
|
498 |
-
n_audio_file += 1
|
499 |
-
self.stream.stop_stream()
|
500 |
-
self.stream.close()
|
501 |
-
self.p.terminate()
|
502 |
-
self.close_all_clients()
|
503 |
-
if self.save_output_recording:
|
504 |
-
self.write_output_recording(n_audio_file)
|
505 |
-
|
506 |
-
def record(self):
|
507 |
-
"""
|
508 |
-
Record audio data from the input stream and save it to a WAV file.
|
509 |
-
|
510 |
-
Continuously records audio data from the input stream, sends it to the server via a WebSocket
|
511 |
-
connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when
|
512 |
-
the `RECORD_SECONDS` duration is reached or when the `RECORDING` flag is set to `False`.
|
513 |
-
|
514 |
-
Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file.
|
515 |
-
The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`.
|
516 |
-
The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
|
517 |
-
the method combines all the saved audio chunks into the specified `out_file`.
|
518 |
-
"""
|
519 |
-
n_audio_file = 0
|
520 |
-
if self.save_output_recording:
|
521 |
-
if os.path.exists("chunks"):
|
522 |
-
shutil.rmtree("chunks")
|
523 |
-
os.makedirs("chunks")
|
524 |
-
try:
|
525 |
-
for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
|
526 |
-
if not any(client.recording for client in self.clients):
|
527 |
-
break
|
528 |
-
data = self.stream.read(self.chunk, exception_on_overflow=False)
|
529 |
-
self.frames += data
|
530 |
-
|
531 |
-
audio_array = self.bytes_to_float_array(data)
|
532 |
-
|
533 |
-
self.multicast_packet(audio_array.tobytes())
|
534 |
-
|
535 |
-
# save frames if more than a minute
|
536 |
-
if len(self.frames) > 60 * self.rate:
|
537 |
-
if self.save_output_recording:
|
538 |
-
self.save_chunk(n_audio_file)
|
539 |
-
n_audio_file += 1
|
540 |
-
self.frames = b""
|
541 |
-
|
542 |
-
except KeyboardInterrupt:
|
543 |
-
self.finalize_recording(n_audio_file)
|
544 |
-
|
545 |
-
def write_audio_frames_to_file(self, frames, file_name):
|
546 |
-
"""
|
547 |
-
Write audio frames to a WAV file.
|
548 |
-
|
549 |
-
The WAV file is created or overwritten with the specified name. The audio frames should be
|
550 |
-
in the correct format and match the specified channel, sample width, and sample rate.
|
551 |
-
|
552 |
-
Args:
|
553 |
-
frames (bytes): The audio frames to be written to the file.
|
554 |
-
file_name (str): The name of the WAV file to which the frames will be written.
|
555 |
-
|
556 |
-
"""
|
557 |
-
with wave.open(file_name, "wb") as wavfile:
|
558 |
-
wavfile: wave.Wave_write
|
559 |
-
wavfile.setnchannels(self.channels)
|
560 |
-
wavfile.setsampwidth(2)
|
561 |
-
wavfile.setframerate(self.rate)
|
562 |
-
wavfile.writeframes(frames)
|
563 |
-
|
564 |
-
def write_output_recording(self, n_audio_file):
|
565 |
-
"""
|
566 |
-
Combine and save recorded audio chunks into a single WAV file.
|
567 |
-
|
568 |
-
The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk
|
569 |
-
file, appends its audio data to the final recording, and then deletes the chunk file. After combining
|
570 |
-
and saving, the final recording is stored in the specified `out_file`.
|
571 |
-
|
572 |
-
|
573 |
-
Args:
|
574 |
-
n_audio_file (int): The number of audio chunk files to combine.
|
575 |
-
out_file (str): The name of the output WAV file to save the final recording.
|
576 |
-
|
577 |
-
"""
|
578 |
-
input_files = [
|
579 |
-
f"chunks/{i}.wav"
|
580 |
-
for i in range(n_audio_file)
|
581 |
-
if os.path.exists(f"chunks/{i}.wav")
|
582 |
-
]
|
583 |
-
with wave.open(self.output_recording_filename, "wb") as wavfile:
|
584 |
-
wavfile: wave.Wave_write
|
585 |
-
wavfile.setnchannels(self.channels)
|
586 |
-
wavfile.setsampwidth(2)
|
587 |
-
wavfile.setframerate(self.rate)
|
588 |
-
for in_file in input_files:
|
589 |
-
with wave.open(in_file, "rb") as wav_in:
|
590 |
-
while True:
|
591 |
-
data = wav_in.readframes(self.chunk)
|
592 |
-
if data == b"":
|
593 |
-
break
|
594 |
-
wavfile.writeframes(data)
|
595 |
-
# remove this file
|
596 |
-
os.remove(in_file)
|
597 |
-
wavfile.close()
|
598 |
-
# clean up temporary directory to store chunks
|
599 |
-
if os.path.exists("chunks"):
|
600 |
-
shutil.rmtree("chunks")
|
601 |
-
|
602 |
-
@staticmethod
|
603 |
-
def bytes_to_float_array(audio_bytes):
|
604 |
-
"""
|
605 |
-
Convert audio data from bytes to a NumPy float array.
|
606 |
-
|
607 |
-
It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to
|
608 |
-
have values between -1 and 1.
|
609 |
-
|
610 |
-
Args:
|
611 |
-
audio_bytes (bytes): Audio data in bytes.
|
612 |
-
|
613 |
-
Returns:
|
614 |
-
np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.
|
615 |
-
"""
|
616 |
-
raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
|
617 |
-
return raw_data.astype(np.float32) / 32768.0
|
618 |
-
|
619 |
-
|
620 |
-
class TranscriptionClient(TranscriptionTeeClient):
|
621 |
-
"""
|
622 |
-
Client for handling audio transcription tasks via a single WebSocket connection.
|
623 |
-
|
624 |
-
Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
|
625 |
-
to send audio data for transcription to a server and receive transcribed text segments.
|
626 |
-
|
627 |
-
Args:
|
628 |
-
host (str): The hostname or IP address of the server.
|
629 |
-
port (int): The port number to connect to on the server.
|
630 |
-
lang (str, optional): The primary language for transcription. Default is None, which defaults to English ('en').
|
631 |
-
save_output_recording (bool, optional): Whether to save the microphone recording. Default is False.
|
632 |
-
output_recording_filename (str, optional): Path to save the output recording WAV file. Default is "./output_recording.wav".
|
633 |
-
output_transcription_path (str, optional): File path to save the output transcription (SRT file). Default is "./output.srt".
|
634 |
-
log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
|
635 |
-
max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
|
636 |
-
max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
|
637 |
-
mute_audio_playback (bool, optional): If True, mutes audio playback during file playback. Default is False.
|
638 |
-
|
639 |
-
Attributes:
|
640 |
-
client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.
|
641 |
-
|
642 |
-
Example:
|
643 |
-
To create a TranscriptionClient and start transcription on microphone audio:
|
644 |
-
```python
|
645 |
-
transcription_client = TranscriptionClient(host="localhost", port=9090)
|
646 |
-
transcription_client()
|
647 |
-
```
|
648 |
-
"""
|
649 |
-
|
650 |
-
def __init__(
|
651 |
-
self,
|
652 |
-
host,
|
653 |
-
port,
|
654 |
-
lang=None,
|
655 |
-
save_output_recording=False,
|
656 |
-
output_recording_filename="./output_recording.wav",
|
657 |
-
log_transcription=True,
|
658 |
-
max_clients=4,
|
659 |
-
max_connection_time=600,
|
660 |
-
mute_audio_playback=False,
|
661 |
-
dst_lang='en',
|
662 |
-
):
|
663 |
-
self.client = Client(
|
664 |
-
host, port, lang, log_transcription=log_transcription, max_clients=max_clients,
|
665 |
-
max_connection_time=max_connection_time, dst_lang=dst_lang
|
666 |
-
)
|
667 |
-
|
668 |
-
if save_output_recording and not output_recording_filename.endswith(".wav"):
|
669 |
-
raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
|
670 |
-
|
671 |
-
TranscriptionTeeClient.__init__(
|
672 |
-
self,
|
673 |
-
[self.client],
|
674 |
-
save_output_recording=save_output_recording,
|
675 |
-
output_recording_filename=output_recording_filename,
|
676 |
-
mute_audio_playback=mute_audio_playback,
|
677 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/helpers/vadprocessor.py
CHANGED
@@ -36,7 +36,7 @@ class AdaptiveSilenceController:
|
|
36 |
speed_factor = 0.5
|
37 |
elif avg_speech < 600:
|
38 |
speed_factor = 0.8
|
39 |
-
|
40 |
# 3. silence 的变化趋势也考虑进去
|
41 |
adaptive = self.base * speed_factor + 0.3 * avg_silence
|
42 |
|
|
|
36 |
speed_factor = 0.5
|
37 |
elif avg_speech < 600:
|
38 |
speed_factor = 0.8
|
39 |
+
logging.warning(f"Avg speech :{avg_speech}, Avg silence: {avg_silence}")
|
40 |
# 3. silence 的变化趋势也考虑进去
|
41 |
adaptive = self.base * speed_factor + 0.3 * avg_silence
|
42 |
|
transcribe/pipelines/pipe_vad.py
CHANGED
@@ -3,10 +3,8 @@ from .base import MetaItem, BasePipe
|
|
3 |
from ..helpers.vadprocessor import FixedVADIterator, AdaptiveSilenceController
|
4 |
|
5 |
import numpy as np
|
6 |
-
from silero_vad import get_speech_timestamps
|
7 |
-
from typing import List
|
8 |
import logging
|
9 |
-
|
10 |
# import noisereduce as nr
|
11 |
|
12 |
|
@@ -60,27 +58,22 @@ class VadPipe(BasePipe):
|
|
60 |
|
61 |
def update_silence_ms(self):
|
62 |
min_silence = self.adaptive_ctrl.get_adaptive_silence_ms()
|
63 |
-
<<<<<<< HEAD
|
64 |
min_silence_samples = self.sample_rate * min_silence / 1000
|
65 |
-
self.vac.min_silence_samples
|
66 |
-
logging.warning(f"🫠 update_silence_ms :{
|
67 |
-
|
68 |
-
|
69 |
-
logging.warning(f"🫠 update_silence_ms :{min_silence} => current: {self.vac.min_silence_duration_ms} ")
|
70 |
-
self.vac.min_silence_duration_ms = min_silence
|
71 |
-
|
72 |
-
>>>>>>> efad27a (add log to debug silence ms)
|
73 |
def process(self, in_data: MetaItem) -> MetaItem:
|
74 |
if self._offset == 0:
|
75 |
self.vac.reset_states()
|
76 |
-
|
77 |
# silence_audio_100ms = np.zeros(int(0.1*self.sample_rate))
|
78 |
source_audio = np.frombuffer(in_data.source_audio, dtype=np.float32)
|
79 |
speech_data = self._process_speech_chunk(source_audio)
|
80 |
|
81 |
if speech_data: # 表示有音频的变化点出现
|
82 |
-
|
83 |
-
rel_start_frame, rel_end_frame = speech_data
|
84 |
if rel_start_frame is not None and rel_end_frame is None:
|
85 |
self._status = "START" # 语音开始
|
86 |
target_audio = source_audio[rel_start_frame:]
|
|
|
3 |
from ..helpers.vadprocessor import FixedVADIterator, AdaptiveSilenceController
|
4 |
|
5 |
import numpy as np
|
|
|
|
|
6 |
import logging
|
7 |
+
|
8 |
# import noisereduce as nr
|
9 |
|
10 |
|
|
|
58 |
|
59 |
def update_silence_ms(self):
|
60 |
min_silence = self.adaptive_ctrl.get_adaptive_silence_ms()
|
|
|
61 |
min_silence_samples = self.sample_rate * min_silence / 1000
|
62 |
+
old_silence_samples = self.vac.min_silence_samples
|
63 |
+
logging.warning(f"🫠 update_silence_ms :{old_silence_samples * 1000 / self.sample_rate :.2f}ms => current: {min_silence}ms ")
|
64 |
+
# self.vac.min_silence_samples = min_silence_samples
|
65 |
+
|
|
|
|
|
|
|
|
|
66 |
def process(self, in_data: MetaItem) -> MetaItem:
|
67 |
if self._offset == 0:
|
68 |
self.vac.reset_states()
|
69 |
+
|
70 |
# silence_audio_100ms = np.zeros(int(0.1*self.sample_rate))
|
71 |
source_audio = np.frombuffer(in_data.source_audio, dtype=np.float32)
|
72 |
speech_data = self._process_speech_chunk(source_audio)
|
73 |
|
74 |
if speech_data: # 表示有音频的变化点出现
|
75 |
+
|
76 |
+
rel_start_frame, rel_end_frame = speech_data
|
77 |
if rel_start_frame is not None and rel_end_frame is None:
|
78 |
self._status = "START" # 语音开始
|
79 |
target_audio = source_audio[rel_start_frame:]
|
transcribe/server.py
DELETED
@@ -1,382 +0,0 @@
|
|
1 |
-
|
2 |
-
import json
|
3 |
-
import logging
|
4 |
-
import threading
|
5 |
-
import time
|
6 |
-
import config
|
7 |
-
import librosa
|
8 |
-
import numpy as np
|
9 |
-
import soundfile
|
10 |
-
from pywhispercpp.model import Model
|
11 |
-
|
12 |
-
logging.basicConfig(level=logging.INFO)
|
13 |
-
|
14 |
-
class ServeClientBase(object):
|
15 |
-
RATE = 16000
|
16 |
-
SERVER_READY = "SERVER_READY"
|
17 |
-
DISCONNECT = "DISCONNECT"
|
18 |
-
|
19 |
-
def __init__(self, client_uid, websocket):
|
20 |
-
self.client_uid = client_uid
|
21 |
-
self.websocket = websocket
|
22 |
-
self.frames = b""
|
23 |
-
self.timestamp_offset = 0.0
|
24 |
-
self.frames_np = None
|
25 |
-
self.frames_offset = 0.0
|
26 |
-
self.text = []
|
27 |
-
self.current_out = ''
|
28 |
-
self.prev_out = ''
|
29 |
-
self.t_start = None
|
30 |
-
self.exit = False
|
31 |
-
self.same_output_count = 0
|
32 |
-
self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds
|
33 |
-
self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds
|
34 |
-
self.transcript = []
|
35 |
-
self.send_last_n_segments = 10
|
36 |
-
|
37 |
-
# text formatting
|
38 |
-
self.pick_previous_segments = 2
|
39 |
-
|
40 |
-
# threading
|
41 |
-
self.lock = threading.Lock()
|
42 |
-
|
43 |
-
def speech_to_text(self):
|
44 |
-
raise NotImplementedError
|
45 |
-
|
46 |
-
def transcribe_audio(self):
|
47 |
-
raise NotImplementedError
|
48 |
-
|
49 |
-
def handle_transcription_output(self):
|
50 |
-
raise NotImplementedError
|
51 |
-
|
52 |
-
def add_frames(self, frame_np):
|
53 |
-
"""
|
54 |
-
Add audio frames to the ongoing audio stream buffer.
|
55 |
-
|
56 |
-
This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
|
57 |
-
of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
|
58 |
-
to prevent excessive memory usage.
|
59 |
-
|
60 |
-
If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
|
61 |
-
of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
|
62 |
-
audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.
|
63 |
-
|
64 |
-
Args:
|
65 |
-
frame_np (numpy.ndarray): The audio frame data as a NumPy array.
|
66 |
-
|
67 |
-
"""
|
68 |
-
self.lock.acquire()
|
69 |
-
if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
|
70 |
-
self.frames_offset += 30.0
|
71 |
-
self.frames_np = self.frames_np[int(30 * self.RATE):]
|
72 |
-
# check timestamp offset(should be >= self.frame_offset)
|
73 |
-
# this basically means that there is no speech as timestamp offset hasnt updated
|
74 |
-
# and is less than frame_offset
|
75 |
-
if self.timestamp_offset < self.frames_offset:
|
76 |
-
self.timestamp_offset = self.frames_offset
|
77 |
-
if self.frames_np is None:
|
78 |
-
self.frames_np = frame_np.copy()
|
79 |
-
else:
|
80 |
-
self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
|
81 |
-
self.lock.release()
|
82 |
-
|
83 |
-
def clip_audio_if_no_valid_segment(self):
|
84 |
-
"""
|
85 |
-
Update the timestamp offset based on audio buffer status.
|
86 |
-
Clip audio if the current chunk exceeds 30 seconds, this basically implies that
|
87 |
-
no valid segment for the last 30 seconds from whisper
|
88 |
-
"""
|
89 |
-
with self.lock:
|
90 |
-
if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
|
91 |
-
duration = self.frames_np.shape[0] / self.RATE
|
92 |
-
self.timestamp_offset = self.frames_offset + duration - 5
|
93 |
-
|
94 |
-
def get_audio_chunk_for_processing(self):
|
95 |
-
"""
|
96 |
-
Retrieves the next chunk of audio data for processing based on the current offsets.
|
97 |
-
|
98 |
-
Calculates which part of the audio data should be processed next, based on
|
99 |
-
the difference between the current timestamp offset and the frame's offset, scaled by
|
100 |
-
the audio sample rate (RATE). It then returns this chunk of audio data along with its
|
101 |
-
duration in seconds.
|
102 |
-
|
103 |
-
Returns:
|
104 |
-
tuple: A tuple containing:
|
105 |
-
- input_bytes (np.ndarray): The next chunk of audio data to be processed.
|
106 |
-
- duration (float): The duration of the audio chunk in seconds.
|
107 |
-
"""
|
108 |
-
with self.lock:
|
109 |
-
samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
|
110 |
-
input_bytes = self.frames_np[int(samples_take):].copy()
|
111 |
-
duration = input_bytes.shape[0] / self.RATE
|
112 |
-
return input_bytes, duration
|
113 |
-
|
114 |
-
def prepare_segments(self, last_segment=None):
|
115 |
-
"""
|
116 |
-
Prepares the segments of transcribed text to be sent to the client.
|
117 |
-
|
118 |
-
This method compiles the recent segments of transcribed text, ensuring that only the
|
119 |
-
specified number of the most recent segments are included. It also appends the most
|
120 |
-
recent segment of text if provided (which is considered incomplete because of the possibility
|
121 |
-
of the last word being truncated in the audio chunk).
|
122 |
-
|
123 |
-
Args:
|
124 |
-
last_segment (str, optional): The most recent segment of transcribed text to be added
|
125 |
-
to the list of segments. Defaults to None.
|
126 |
-
|
127 |
-
Returns:
|
128 |
-
list: A list of transcribed text segments to be sent to the client.
|
129 |
-
"""
|
130 |
-
segments = []
|
131 |
-
if len(self.transcript) >= self.send_last_n_segments:
|
132 |
-
segments = self.transcript[-self.send_last_n_segments:].copy()
|
133 |
-
else:
|
134 |
-
segments = self.transcript.copy()
|
135 |
-
if last_segment is not None:
|
136 |
-
segments = segments + [last_segment]
|
137 |
-
logging.info(f"{segments}")
|
138 |
-
return segments
|
139 |
-
|
140 |
-
def get_audio_chunk_duration(self, input_bytes):
|
141 |
-
"""
|
142 |
-
Calculates the duration of the provided audio chunk.
|
143 |
-
|
144 |
-
Args:
|
145 |
-
input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.
|
146 |
-
|
147 |
-
Returns:
|
148 |
-
float: The duration of the audio chunk in seconds.
|
149 |
-
"""
|
150 |
-
return input_bytes.shape[0] / self.RATE
|
151 |
-
|
152 |
-
def send_transcription_to_client(self, segments):
|
153 |
-
"""
|
154 |
-
Sends the specified transcription segments to the client over the websocket connection.
|
155 |
-
|
156 |
-
This method formats the transcription segments into a JSON object and attempts to send
|
157 |
-
this object to the client. If an error occurs during the send operation, it logs the error.
|
158 |
-
|
159 |
-
Returns:
|
160 |
-
segments (list): A list of transcription segments to be sent to the client.
|
161 |
-
"""
|
162 |
-
try:
|
163 |
-
self.websocket.send(
|
164 |
-
json.dumps({
|
165 |
-
"uid": self.client_uid,
|
166 |
-
"segments": segments,
|
167 |
-
})
|
168 |
-
)
|
169 |
-
except Exception as e:
|
170 |
-
logging.error(f"[ERROR]: Sending data to client: {e}")
|
171 |
-
|
172 |
-
def disconnect(self):
|
173 |
-
"""
|
174 |
-
Notify the client of disconnection and send a disconnect message.
|
175 |
-
|
176 |
-
This method sends a disconnect message to the client via the WebSocket connection to notify them
|
177 |
-
that the transcription service is disconnecting gracefully.
|
178 |
-
|
179 |
-
"""
|
180 |
-
self.websocket.send(json.dumps({
|
181 |
-
"uid": self.client_uid,
|
182 |
-
"message": self.DISCONNECT
|
183 |
-
}))
|
184 |
-
|
185 |
-
def cleanup(self):
|
186 |
-
"""
|
187 |
-
Perform cleanup tasks before exiting the transcription service.
|
188 |
-
|
189 |
-
This method performs necessary cleanup tasks, including stopping the transcription thread, marking
|
190 |
-
the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
|
191 |
-
associated with the transcription process.
|
192 |
-
|
193 |
-
"""
|
194 |
-
logging.info("Cleaning up.")
|
195 |
-
self.exit = True
|
196 |
-
|
197 |
-
|
198 |
-
class ServeClientWhisperCPP(ServeClientBase):
|
199 |
-
SINGLE_MODEL = None
|
200 |
-
SINGLE_MODEL_LOCK = threading.Lock()
|
201 |
-
|
202 |
-
def __init__(self, websocket, language=None, client_uid=None,
|
203 |
-
single_model=False):
|
204 |
-
"""
|
205 |
-
Initialize a ServeClient instance.
|
206 |
-
The Whisper model is initialized based on the client's language and device availability.
|
207 |
-
The transcription thread is started upon initialization. A "SERVER_READY" message is sent
|
208 |
-
to the client to indicate that the server is ready.
|
209 |
-
|
210 |
-
Args:
|
211 |
-
websocket (WebSocket): The WebSocket connection for the client.
|
212 |
-
language (str, optional): The language for transcription. Defaults to None.
|
213 |
-
client_uid (str, optional): A unique identifier for the client. Defaults to None.
|
214 |
-
single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
|
215 |
-
|
216 |
-
"""
|
217 |
-
super().__init__(client_uid, websocket)
|
218 |
-
self.language = language
|
219 |
-
self.eos = False
|
220 |
-
|
221 |
-
if single_model:
|
222 |
-
if ServeClientWhisperCPP.SINGLE_MODEL is None:
|
223 |
-
self.create_model()
|
224 |
-
ServeClientWhisperCPP.SINGLE_MODEL = self.transcriber
|
225 |
-
else:
|
226 |
-
self.transcriber = ServeClientWhisperCPP.SINGLE_MODEL
|
227 |
-
else:
|
228 |
-
self.create_model()
|
229 |
-
|
230 |
-
# threading
|
231 |
-
logging.info('Create a thread to process audio.')
|
232 |
-
self.trans_thread = threading.Thread(target=self.speech_to_text)
|
233 |
-
self.trans_thread.start()
|
234 |
-
|
235 |
-
self.websocket.send(json.dumps({
|
236 |
-
"uid": self.client_uid,
|
237 |
-
"message": self.SERVER_READY,
|
238 |
-
"backend": "pywhispercpp"
|
239 |
-
}))
|
240 |
-
|
241 |
-
def create_model(self, warmup=True):
|
242 |
-
"""
|
243 |
-
Instantiates a new model, sets it as the transcriber and does warmup if desired.
|
244 |
-
"""
|
245 |
-
|
246 |
-
self.transcriber = Model(model=config.WHISPER_MODEL, models_dir=config.MODEL_DIR)
|
247 |
-
if warmup:
|
248 |
-
self.warmup()
|
249 |
-
|
250 |
-
def warmup(self, warmup_steps=1):
|
251 |
-
"""
|
252 |
-
Warmup TensorRT since first few inferences are slow.
|
253 |
-
|
254 |
-
Args:
|
255 |
-
warmup_steps (int): Number of steps to warm up the model for.
|
256 |
-
"""
|
257 |
-
logging.info("[INFO:] Warming up whisper.cpp engine..")
|
258 |
-
mel, _, = soundfile.read("assets/jfk.flac")
|
259 |
-
for i in range(warmup_steps):
|
260 |
-
self.transcriber.transcribe(mel, print_progress=False)
|
261 |
-
|
262 |
-
def set_eos(self, eos):
|
263 |
-
"""
|
264 |
-
Sets the End of Speech (EOS) flag.
|
265 |
-
|
266 |
-
Args:
|
267 |
-
eos (bool): The value to set for the EOS flag.
|
268 |
-
"""
|
269 |
-
self.lock.acquire()
|
270 |
-
self.eos = eos
|
271 |
-
self.lock.release()
|
272 |
-
|
273 |
-
def handle_transcription_output(self, last_segment, duration):
|
274 |
-
"""
|
275 |
-
Handle the transcription output, updating the transcript and sending data to the client.
|
276 |
-
|
277 |
-
Args:
|
278 |
-
last_segment (str): The last segment from the whisper output which is considered to be incomplete because
|
279 |
-
of the possibility of word being truncated.
|
280 |
-
duration (float): Duration of the transcribed audio chunk.
|
281 |
-
"""
|
282 |
-
segments = self.prepare_segments({"text": last_segment})
|
283 |
-
self.send_transcription_to_client(segments)
|
284 |
-
if self.eos:
|
285 |
-
self.update_timestamp_offset(last_segment, duration)
|
286 |
-
|
287 |
-
def transcribe_audio(self, input_bytes):
|
288 |
-
"""
|
289 |
-
Transcribe the audio chunk and send the results to the client.
|
290 |
-
|
291 |
-
Args:
|
292 |
-
input_bytes (np.array): The audio chunk to transcribe.
|
293 |
-
"""
|
294 |
-
if ServeClientWhisperCPP.SINGLE_MODEL:
|
295 |
-
ServeClientWhisperCPP.SINGLE_MODEL_LOCK.acquire()
|
296 |
-
logging.info(f"[pywhispercpp:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
|
297 |
-
mel = input_bytes
|
298 |
-
duration = librosa.get_duration(y=input_bytes, sr=self.RATE)
|
299 |
-
|
300 |
-
if self.language == "zh":
|
301 |
-
prompt = '以下是简体中文普通话的句子。'
|
302 |
-
else:
|
303 |
-
prompt = 'The following is an English sentence.'
|
304 |
-
|
305 |
-
segments = self.transcriber.transcribe(
|
306 |
-
mel,
|
307 |
-
language=self.language,
|
308 |
-
initial_prompt=prompt,
|
309 |
-
token_timestamps=True,
|
310 |
-
# max_len=max_len,
|
311 |
-
print_progress=False
|
312 |
-
)
|
313 |
-
text = []
|
314 |
-
for segment in segments:
|
315 |
-
content = segment.text
|
316 |
-
text.append(content)
|
317 |
-
last_segment = ' '.join(text)
|
318 |
-
|
319 |
-
logging.info(f"[pywhispercpp:] Last segment: {last_segment}")
|
320 |
-
|
321 |
-
if ServeClientWhisperCPP.SINGLE_MODEL:
|
322 |
-
ServeClientWhisperCPP.SINGLE_MODEL_LOCK.release()
|
323 |
-
if last_segment:
|
324 |
-
self.handle_transcription_output(last_segment, duration)
|
325 |
-
|
326 |
-
def update_timestamp_offset(self, last_segment, duration):
|
327 |
-
"""
|
328 |
-
Update timestamp offset and transcript.
|
329 |
-
|
330 |
-
Args:
|
331 |
-
last_segment (str): Last transcribed audio from the whisper model.
|
332 |
-
duration (float): Duration of the last audio chunk.
|
333 |
-
"""
|
334 |
-
if not len(self.transcript):
|
335 |
-
self.transcript.append({"text": last_segment + " "})
|
336 |
-
elif self.transcript[-1]["text"].strip() != last_segment:
|
337 |
-
self.transcript.append({"text": last_segment + " "})
|
338 |
-
|
339 |
-
logging.info(f'Transcript list context: {self.transcript}')
|
340 |
-
|
341 |
-
with self.lock:
|
342 |
-
self.timestamp_offset += duration
|
343 |
-
|
344 |
-
def speech_to_text(self):
|
345 |
-
"""
|
346 |
-
Process an audio stream in an infinite loop, continuously transcribing the speech.
|
347 |
-
|
348 |
-
This method continuously receives audio frames, performs real-time transcription, and sends
|
349 |
-
transcribed segments to the client via a WebSocket connection.
|
350 |
-
|
351 |
-
If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
|
352 |
-
It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
|
353 |
-
are sent to the client in real-time, and a history of segments is maintained to provide context.Pauses in speech
|
354 |
-
(no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if
|
355 |
-
there is no speech for a specified duration to indicate a pause.
|
356 |
-
|
357 |
-
Raises:
|
358 |
-
Exception: If there is an issue with audio processing or WebSocket communication.
|
359 |
-
|
360 |
-
"""
|
361 |
-
while True:
|
362 |
-
if self.exit:
|
363 |
-
logging.info("Exiting speech to text thread")
|
364 |
-
break
|
365 |
-
|
366 |
-
if self.frames_np is None:
|
367 |
-
time.sleep(0.02) # wait for any audio to arrive
|
368 |
-
continue
|
369 |
-
|
370 |
-
self.clip_audio_if_no_valid_segment()
|
371 |
-
|
372 |
-
input_bytes, duration = self.get_audio_chunk_for_processing()
|
373 |
-
if duration < 1:
|
374 |
-
continue
|
375 |
-
|
376 |
-
try:
|
377 |
-
input_sample = input_bytes.copy()
|
378 |
-
logging.info(f"[pywhispercpp:] Processing audio with duration: {duration}")
|
379 |
-
self.transcribe_audio(input_sample)
|
380 |
-
|
381 |
-
except Exception as e:
|
382 |
-
logging.error(f"[ERROR]: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/strategy.py
DELETED
@@ -1,405 +0,0 @@
|
|
1 |
-
|
2 |
-
import collections
|
3 |
-
import logging
|
4 |
-
from difflib import SequenceMatcher
|
5 |
-
from itertools import chain
|
6 |
-
from dataclasses import dataclass, field
|
7 |
-
from typing import List, Tuple, Optional, Deque, Any, Iterator,Literal
|
8 |
-
from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE
|
9 |
-
from enum import Enum
|
10 |
-
import wordninja
|
11 |
-
import config
|
12 |
-
import re
|
13 |
-
logger = logging.getLogger("TranscriptionStrategy")
|
14 |
-
|
15 |
-
|
16 |
-
class SplitMode(Enum):
|
17 |
-
PUNCTUATION = "punctuation"
|
18 |
-
PAUSE = "pause"
|
19 |
-
END = "end"
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
@dataclass
|
24 |
-
class TranscriptResult:
|
25 |
-
seg_id: int = 0
|
26 |
-
cut_index: int = 0
|
27 |
-
is_end_sentence: bool = False
|
28 |
-
context: str = ""
|
29 |
-
|
30 |
-
def partial(self):
|
31 |
-
return not self.is_end_sentence
|
32 |
-
|
33 |
-
@dataclass
|
34 |
-
class TranscriptToken:
|
35 |
-
"""表示一个转录片段,包含文本和时间信息"""
|
36 |
-
text: str # 转录的文本内容
|
37 |
-
t0: int # 开始时间(百分之一秒)
|
38 |
-
t1: int # 结束时间(百分之一秒)
|
39 |
-
|
40 |
-
def is_punctuation(self):
|
41 |
-
"""检查文本是否包含标点符号"""
|
42 |
-
return REGEX_MARKERS.search(self.text.strip()) is not None
|
43 |
-
|
44 |
-
def is_end(self):
|
45 |
-
"""检查文本是否为句子结束标记"""
|
46 |
-
return SENTENCE_END_PATTERN.search(self.text.strip()) is not None
|
47 |
-
|
48 |
-
def is_pause(self):
|
49 |
-
"""检查文本是否为暂停标记"""
|
50 |
-
return PAUSEE_END_PATTERN.search(self.text.strip()) is not None
|
51 |
-
|
52 |
-
def buffer_index(self) -> int:
|
53 |
-
return int(self.t1 / 100 * SAMPLE_RATE)
|
54 |
-
|
55 |
-
@dataclass
|
56 |
-
class TranscriptChunk:
|
57 |
-
"""表示一组转录片段,支持分割和比较操作"""
|
58 |
-
separator: str = "" # 用于连接片段的分隔符
|
59 |
-
items: list[TranscriptToken] = field(default_factory=list) # 转录片段列表
|
60 |
-
|
61 |
-
@staticmethod
|
62 |
-
def _calculate_similarity(text1: str, text2: str) -> float:
|
63 |
-
"""计算两段文本的相似度"""
|
64 |
-
return SequenceMatcher(None, text1, text2).ratio()
|
65 |
-
|
66 |
-
def split_by(self, mode: SplitMode) -> list['TranscriptChunk']:
|
67 |
-
"""根据文本中的标点符号分割片段列表"""
|
68 |
-
if mode == SplitMode.PUNCTUATION:
|
69 |
-
indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()]
|
70 |
-
elif mode == SplitMode.PAUSE:
|
71 |
-
indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()]
|
72 |
-
elif mode == SplitMode.END:
|
73 |
-
indexes = [i for i, seg in enumerate(self.items) if seg.is_end()]
|
74 |
-
else:
|
75 |
-
raise ValueError(f"Unsupported mode: {mode}")
|
76 |
-
|
77 |
-
# 每个切分点向后移一个索引,表示“分隔符归前段”
|
78 |
-
cut_points = [0] + sorted(i + 1 for i in indexes) + [len(self.items)]
|
79 |
-
chunks = [
|
80 |
-
TranscriptChunk(items=self.items[start:end], separator=self.separator)
|
81 |
-
for start, end in zip(cut_points, cut_points[1:])
|
82 |
-
]
|
83 |
-
return [
|
84 |
-
ck
|
85 |
-
for ck in chunks
|
86 |
-
if not ck.only_punctuation()
|
87 |
-
]
|
88 |
-
|
89 |
-
|
90 |
-
def get_split_first_rest(self, mode: SplitMode):
|
91 |
-
chunks = self.split_by(mode)
|
92 |
-
fisrt_chunk = chunks[0] if chunks else self
|
93 |
-
rest_chunks = chunks[1:] if chunks else None
|
94 |
-
return fisrt_chunk, rest_chunks
|
95 |
-
|
96 |
-
def puncation_numbers(self) -> int:
|
97 |
-
"""计算片段中标点符号的数量"""
|
98 |
-
return sum(1 for seg in self.items if seg.is_punctuation())
|
99 |
-
|
100 |
-
def length(self) -> int:
|
101 |
-
"""返回片段列表的长度"""
|
102 |
-
return len(self.items)
|
103 |
-
|
104 |
-
def join(self) -> str:
|
105 |
-
"""将片段连接为一个字符串"""
|
106 |
-
return self.separator.join(seg.text for seg in self.items)
|
107 |
-
|
108 |
-
def compare(self, chunk: Optional['TranscriptChunk'] = None) -> float:
|
109 |
-
"""比较当前片段与另一个片段的相似度"""
|
110 |
-
if not chunk:
|
111 |
-
return 0
|
112 |
-
|
113 |
-
score = self._calculate_similarity(self.join(), chunk.join())
|
114 |
-
# logger.debug(f"Compare: {self.join()} vs {chunk.join()} : {score}")
|
115 |
-
return score
|
116 |
-
|
117 |
-
def only_punctuation(self)->bool:
|
118 |
-
return all(seg.is_punctuation() for seg in self.items)
|
119 |
-
|
120 |
-
def has_punctuation(self) -> bool:
|
121 |
-
return any(seg.is_punctuation() for seg in self.items)
|
122 |
-
|
123 |
-
def get_buffer_index(self) -> int:
|
124 |
-
return self.items[-1].buffer_index()
|
125 |
-
|
126 |
-
def is_end_sentence(self) ->bool:
|
127 |
-
return self.items[-1].is_end()
|
128 |
-
|
129 |
-
|
130 |
-
class TranscriptHistory:
|
131 |
-
"""管理转录片段的历史记录"""
|
132 |
-
|
133 |
-
def __init__(self) -> None:
|
134 |
-
self.history = collections.deque(maxlen=2) # 存储最近的两个片段
|
135 |
-
|
136 |
-
def add(self, chunk: TranscriptChunk):
|
137 |
-
"""添加新的片段到历史记录"""
|
138 |
-
self.history.appendleft(chunk)
|
139 |
-
|
140 |
-
def previous_chunk(self) -> Optional[TranscriptChunk]:
|
141 |
-
"""获取上一个片段(如果存在)"""
|
142 |
-
return self.history[1] if len(self.history) == 2 else None
|
143 |
-
|
144 |
-
def lastest_chunk(self):
|
145 |
-
"""获取最后一个片段"""
|
146 |
-
return self.history[-1]
|
147 |
-
|
148 |
-
def clear(self):
|
149 |
-
self.history.clear()
|
150 |
-
|
151 |
-
class TranscriptBuffer:
|
152 |
-
"""
|
153 |
-
管理转录文本的分级结构:临时字符串 -> 短句 -> 完整段落
|
154 |
-
|
155 |
-
|-- 已确认文本 --|-- 观察窗口 --|-- 新输入 --|
|
156 |
-
|
157 |
-
管理 pending -> line -> paragraph 的缓冲逻辑
|
158 |
-
|
159 |
-
"""
|
160 |
-
|
161 |
-
def __init__(self, source_lang:str, separator:str):
|
162 |
-
self._segments: List[str] = collections.deque(maxlen=2) # 确认的完整段落
|
163 |
-
self._sentences: List[str] = collections.deque() # 当前段落中的短句
|
164 |
-
self._buffer: str = "" # 当前缓冲中的文本
|
165 |
-
self._current_seg_id: int = 0
|
166 |
-
self.source_language = source_lang
|
167 |
-
self._separator = separator
|
168 |
-
|
169 |
-
def get_seg_id(self) -> int:
|
170 |
-
return self._current_seg_id
|
171 |
-
|
172 |
-
@property
|
173 |
-
def current_sentences_length(self) -> int:
|
174 |
-
count = 0
|
175 |
-
for item in self._sentences:
|
176 |
-
if self._separator:
|
177 |
-
count += len(item.split(self._separator))
|
178 |
-
else:
|
179 |
-
count += len(item)
|
180 |
-
return count
|
181 |
-
|
182 |
-
def update_pending_text(self, text: str) -> None:
|
183 |
-
"""更新临时缓冲字符串"""
|
184 |
-
self._buffer = text
|
185 |
-
|
186 |
-
def commit_line(self,) -> None:
|
187 |
-
"""将缓冲字符串提交为短句"""
|
188 |
-
if self._buffer:
|
189 |
-
self._sentences.append(self._buffer)
|
190 |
-
self._buffer = ""
|
191 |
-
|
192 |
-
def commit_paragraph(self) -> None:
|
193 |
-
"""
|
194 |
-
提交当前短句为完整段落(如句子结束)
|
195 |
-
|
196 |
-
Args:
|
197 |
-
end_of_sentence: 是否为句子结尾(如检测到句号)
|
198 |
-
"""
|
199 |
-
|
200 |
-
count = 0
|
201 |
-
current_sentences = []
|
202 |
-
while len(self._sentences): # and count < 20:
|
203 |
-
item = self._sentences.popleft()
|
204 |
-
current_sentences.append(item)
|
205 |
-
if self._separator:
|
206 |
-
count += len(item.split(self._separator))
|
207 |
-
else:
|
208 |
-
count += len(item)
|
209 |
-
if current_sentences:
|
210 |
-
self._segments.append("".join(current_sentences))
|
211 |
-
logger.debug(f"=== count to paragraph ===")
|
212 |
-
logger.debug(f"push: {current_sentences}")
|
213 |
-
logger.debug(f"rest: {self._sentences}")
|
214 |
-
# if self._sentences:
|
215 |
-
# self._segments.append("".join(self._sentences))
|
216 |
-
# self._sentences.clear()
|
217 |
-
|
218 |
-
def rebuild(self, text):
|
219 |
-
output = self.split_and_join(
|
220 |
-
text.replace(
|
221 |
-
self._separator, ""))
|
222 |
-
|
223 |
-
logger.debug("==== rebuild string ====")
|
224 |
-
logger.debug(text)
|
225 |
-
logger.debug(output)
|
226 |
-
|
227 |
-
return output
|
228 |
-
|
229 |
-
@staticmethod
|
230 |
-
def split_and_join(text):
|
231 |
-
tokens = []
|
232 |
-
word_buf = ''
|
233 |
-
|
234 |
-
for char in text:
|
235 |
-
if char in ALL_MARKERS:
|
236 |
-
if word_buf:
|
237 |
-
tokens.extend(wordninja.split(word_buf))
|
238 |
-
word_buf = ''
|
239 |
-
tokens.append(char)
|
240 |
-
else:
|
241 |
-
word_buf += char
|
242 |
-
if word_buf:
|
243 |
-
tokens.extend(wordninja.split(word_buf))
|
244 |
-
|
245 |
-
output = ''
|
246 |
-
for i, token in enumerate(tokens):
|
247 |
-
if i == 0:
|
248 |
-
output += token
|
249 |
-
elif token in ALL_MARKERS:
|
250 |
-
output += (token + " ")
|
251 |
-
else:
|
252 |
-
output += ' ' + token
|
253 |
-
return output
|
254 |
-
|
255 |
-
|
256 |
-
def update_and_commit(self, stable_strings: List[str], remaining_strings:List[str], is_end_sentence=False):
|
257 |
-
if self.source_language == "en":
|
258 |
-
stable_strings = [self.rebuild(i) for i in stable_strings]
|
259 |
-
remaining_strings =[self.rebuild(i) for i in remaining_strings]
|
260 |
-
remaining_string = "".join(remaining_strings)
|
261 |
-
|
262 |
-
logger.debug(f"{self.__dict__}")
|
263 |
-
if is_end_sentence:
|
264 |
-
for stable_str in stable_strings:
|
265 |
-
self.update_pending_text(stable_str)
|
266 |
-
self.commit_line()
|
267 |
-
|
268 |
-
current_text_len = len(self.current_not_commit_text.split(self._separator)) if self._separator else len(self.current_not_commit_text)
|
269 |
-
# current_text_len = len(self.current_not_commit_text.split(self._separator))
|
270 |
-
self.update_pending_text(remaining_string)
|
271 |
-
if current_text_len >= config.TEXT_THREHOLD:
|
272 |
-
self.commit_paragraph()
|
273 |
-
self._current_seg_id += 1
|
274 |
-
return True
|
275 |
-
else:
|
276 |
-
for stable_str in stable_strings:
|
277 |
-
self.update_pending_text(stable_str)
|
278 |
-
self.commit_line()
|
279 |
-
self.update_pending_text(remaining_string)
|
280 |
-
return False
|
281 |
-
|
282 |
-
|
283 |
-
@property
|
284 |
-
def un_commit_paragraph(self) -> str:
|
285 |
-
"""当前短句组合"""
|
286 |
-
return "".join([i for i in self._sentences])
|
287 |
-
|
288 |
-
@property
|
289 |
-
def pending_text(self) -> str:
|
290 |
-
"""当前缓冲内容"""
|
291 |
-
return self._buffer
|
292 |
-
|
293 |
-
@property
|
294 |
-
def latest_paragraph(self) -> str:
|
295 |
-
"""最新确认的段落"""
|
296 |
-
return self._segments[-1] if self._segments else ""
|
297 |
-
|
298 |
-
@property
|
299 |
-
def current_not_commit_text(self) -> str:
|
300 |
-
return self.un_commit_paragraph + self.pending_text
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
class TranscriptStabilityAnalyzer:
|
305 |
-
def __init__(self, source_lang, separator) -> None:
|
306 |
-
self._transcript_buffer = TranscriptBuffer(source_lang=source_lang,separator=separator)
|
307 |
-
self._transcript_history = TranscriptHistory()
|
308 |
-
self._separator = separator
|
309 |
-
logger.debug(f"Current separator: {self._separator}")
|
310 |
-
|
311 |
-
def merge_chunks(self, chunks: List[TranscriptChunk])->str:
|
312 |
-
if not chunks:
|
313 |
-
return [""]
|
314 |
-
output = list(r.join() for r in chunks if r)
|
315 |
-
return output
|
316 |
-
|
317 |
-
|
318 |
-
def analysis(self, current: TranscriptChunk, buffer_duration: float) -> Iterator[TranscriptResult]:
|
319 |
-
current = TranscriptChunk(items=current, separator=self._separator)
|
320 |
-
self._transcript_history.add(current)
|
321 |
-
|
322 |
-
prev = self._transcript_history.previous_chunk()
|
323 |
-
self._transcript_buffer.update_pending_text(current.join())
|
324 |
-
if not prev: # 如果没有历史记录 那么就说明是新的语句 直接输出就行
|
325 |
-
yield TranscriptResult(
|
326 |
-
context=self._transcript_buffer.current_not_commit_text,
|
327 |
-
seg_id=self._transcript_buffer.get_seg_id()
|
328 |
-
)
|
329 |
-
return
|
330 |
-
|
331 |
-
# yield from self._handle_short_buffer(current, prev)
|
332 |
-
if buffer_duration <= 4:
|
333 |
-
yield from self._handle_short_buffer(current, prev)
|
334 |
-
else:
|
335 |
-
yield from self._handle_long_buffer(current)
|
336 |
-
|
337 |
-
|
338 |
-
def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
|
339 |
-
curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
|
340 |
-
prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)
|
341 |
-
|
342 |
-
# logger.debug("==== Current cut item ====")
|
343 |
-
# logger.debug(f"{curr.join()} ")
|
344 |
-
# logger.debug(f"{prev.join()}")
|
345 |
-
# logger.debug("==========================")
|
346 |
-
|
347 |
-
if curr_first and prev_first:
|
348 |
-
|
349 |
-
core = curr_first.compare(prev_first)
|
350 |
-
has_punctuation = curr_first.has_punctuation()
|
351 |
-
if core >= 0.8 and has_punctuation:
|
352 |
-
yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
|
353 |
-
return
|
354 |
-
|
355 |
-
yield TranscriptResult(
|
356 |
-
seg_id=self._transcript_buffer.get_seg_id(),
|
357 |
-
context=self._transcript_buffer.current_not_commit_text
|
358 |
-
)
|
359 |
-
|
360 |
-
|
361 |
-
def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
|
362 |
-
chunks = curr.split_by(SplitMode.PUNCTUATION)
|
363 |
-
if len(chunks) > 1:
|
364 |
-
stable, remaining = chunks[:-1], chunks[-1:]
|
365 |
-
# stable_str = self.merge_chunks(stable)
|
366 |
-
# remaining_str = self.merge_chunks(remaining)
|
367 |
-
yield from self._yield_commit_results(
|
368 |
-
stable, remaining, is_end_sentence=True # 暂时硬编码为True
|
369 |
-
)
|
370 |
-
else:
|
371 |
-
yield TranscriptResult(
|
372 |
-
seg_id=self._transcript_buffer.get_seg_id(),
|
373 |
-
context=self._transcript_buffer.current_not_commit_text
|
374 |
-
)
|
375 |
-
|
376 |
-
|
377 |
-
def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
|
378 |
-
stable_str_list = [stable_chunk.join()] if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
|
379 |
-
remaining_str_list = self.merge_chunks(remaining_chunks)
|
380 |
-
frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()
|
381 |
-
|
382 |
-
prev_seg_id = self._transcript_buffer.get_seg_id()
|
383 |
-
commit_paragraph = self._transcript_buffer.update_and_commit(stable_str_list, remaining_str_list, is_end_sentence)
|
384 |
-
logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")
|
385 |
-
|
386 |
-
if commit_paragraph:
|
387 |
-
# 表示生成了一个新段落 换行
|
388 |
-
yield TranscriptResult(
|
389 |
-
seg_id=prev_seg_id,
|
390 |
-
cut_index=frame_cut_index,
|
391 |
-
context=self._transcript_buffer.latest_paragraph,
|
392 |
-
is_end_sentence=True
|
393 |
-
)
|
394 |
-
if (context := self._transcript_buffer.current_not_commit_text.strip()):
|
395 |
-
yield TranscriptResult(
|
396 |
-
seg_id=self._transcript_buffer.get_seg_id(),
|
397 |
-
context=context,
|
398 |
-
)
|
399 |
-
else:
|
400 |
-
yield TranscriptResult(
|
401 |
-
seg_id=self._transcript_buffer.get_seg_id(),
|
402 |
-
cut_index=frame_cut_index,
|
403 |
-
context=self._transcript_buffer.current_not_commit_text,
|
404 |
-
)
|
405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/transcription.py
DELETED
@@ -1,334 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import time
|
3 |
-
import functools
|
4 |
-
import json
|
5 |
-
import logging
|
6 |
-
import time
|
7 |
-
from enum import Enum
|
8 |
-
from typing import List, Optional
|
9 |
-
import numpy as np
|
10 |
-
from .server import ServeClientBase
|
11 |
-
from .whisper_llm_serve import PyWhiperCppServe
|
12 |
-
from .vad import VoiceActivityDetector
|
13 |
-
from urllib.parse import urlparse, parse_qsl
|
14 |
-
from websockets.exceptions import ConnectionClosed
|
15 |
-
from websockets.sync.server import serve
|
16 |
-
from uuid import uuid1
|
17 |
-
|
18 |
-
|
19 |
-
logging.basicConfig(level=logging.INFO)
|
20 |
-
|
21 |
-
|
22 |
-
class ClientManager:
|
23 |
-
def __init__(self, max_clients=4, max_connection_time=600):
|
24 |
-
"""
|
25 |
-
Initializes the ClientManager with specified limits on client connections and connection durations.
|
26 |
-
|
27 |
-
Args:
|
28 |
-
max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4.
|
29 |
-
max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected. Defaults
|
30 |
-
to 600 seconds (10 minutes).
|
31 |
-
"""
|
32 |
-
self.clients = {}
|
33 |
-
self.start_times = {}
|
34 |
-
self.max_clients = max_clients
|
35 |
-
self.max_connection_time = max_connection_time
|
36 |
-
|
37 |
-
def add_client(self, websocket, client):
|
38 |
-
"""
|
39 |
-
Adds a client and their connection start time to the tracking dictionaries.
|
40 |
-
|
41 |
-
Args:
|
42 |
-
websocket: The websocket associated with the client to add.
|
43 |
-
client: The client object to be added and tracked.
|
44 |
-
"""
|
45 |
-
self.clients[websocket] = client
|
46 |
-
self.start_times[websocket] = time.time()
|
47 |
-
|
48 |
-
def get_client(self, websocket):
|
49 |
-
"""
|
50 |
-
Retrieves a client associated with the given websocket.
|
51 |
-
|
52 |
-
Args:
|
53 |
-
websocket: The websocket associated with the client to retrieve.
|
54 |
-
|
55 |
-
Returns:
|
56 |
-
The client object if found, False otherwise.
|
57 |
-
"""
|
58 |
-
if websocket in self.clients:
|
59 |
-
return self.clients[websocket]
|
60 |
-
return False
|
61 |
-
|
62 |
-
def remove_client(self, websocket):
|
63 |
-
"""
|
64 |
-
Removes a client and their connection start time from the tracking dictionaries. Performs cleanup on the
|
65 |
-
client if necessary.
|
66 |
-
|
67 |
-
Args:
|
68 |
-
websocket: The websocket associated with the client to be removed.
|
69 |
-
"""
|
70 |
-
client = self.clients.pop(websocket, None)
|
71 |
-
if client:
|
72 |
-
client.cleanup()
|
73 |
-
self.start_times.pop(websocket, None)
|
74 |
-
|
75 |
-
def get_wait_time(self):
|
76 |
-
"""
|
77 |
-
Calculates the estimated wait time for new clients based on the remaining connection times of current clients.
|
78 |
-
|
79 |
-
Returns:
|
80 |
-
The estimated wait time in minutes for new clients to connect. Returns 0 if there are available slots.
|
81 |
-
"""
|
82 |
-
wait_time = None
|
83 |
-
for start_time in self.start_times.values():
|
84 |
-
current_client_time_remaining = self.max_connection_time - (time.time() - start_time)
|
85 |
-
if wait_time is None or current_client_time_remaining < wait_time:
|
86 |
-
wait_time = current_client_time_remaining
|
87 |
-
return wait_time / 60 if wait_time is not None else 0
|
88 |
-
|
89 |
-
def is_server_full(self, websocket, options):
|
90 |
-
"""
|
91 |
-
Checks if the server is at its maximum client capacity and sends a wait message to the client if necessary.
|
92 |
-
|
93 |
-
Args:
|
94 |
-
websocket: The websocket of the client attempting to connect.
|
95 |
-
options: A dictionary of options that may include the client's unique identifier.
|
96 |
-
|
97 |
-
Returns:
|
98 |
-
True if the server is full, False otherwise.
|
99 |
-
"""
|
100 |
-
if len(self.clients) >= self.max_clients:
|
101 |
-
wait_time = self.get_wait_time()
|
102 |
-
response = {"uid": options["uid"], "status": "WAIT", "message": wait_time}
|
103 |
-
websocket.send(json.dumps(response))
|
104 |
-
return True
|
105 |
-
return False
|
106 |
-
|
107 |
-
def is_client_timeout(self, websocket):
|
108 |
-
"""
|
109 |
-
Checks if a client has exceeded the maximum allowed connection time and disconnects them if so, issuing a warning.
|
110 |
-
|
111 |
-
Args:
|
112 |
-
websocket: The websocket associated with the client to check.
|
113 |
-
|
114 |
-
Returns:
|
115 |
-
True if the client's connection time has exceeded the maximum limit, False otherwise.
|
116 |
-
"""
|
117 |
-
elapsed_time = time.time() - self.start_times[websocket]
|
118 |
-
if elapsed_time >= self.max_connection_time:
|
119 |
-
self.clients[websocket].disconnect()
|
120 |
-
logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected due to overtime.")
|
121 |
-
return True
|
122 |
-
return False
|
123 |
-
|
124 |
-
|
125 |
-
class BackendType(Enum):
|
126 |
-
PYWHISPERCPP = "pywhispercpp"
|
127 |
-
|
128 |
-
@staticmethod
|
129 |
-
def valid_types() -> List[str]:
|
130 |
-
return [backend_type.value for backend_type in BackendType]
|
131 |
-
|
132 |
-
@staticmethod
|
133 |
-
def is_valid(backend: str) -> bool:
|
134 |
-
return backend in BackendType.valid_types()
|
135 |
-
|
136 |
-
def is_pywhispercpp(self) -> bool:
|
137 |
-
return self == BackendType.PYWHISPERCPP
|
138 |
-
|
139 |
-
|
140 |
-
class TranscriptionServer:
|
141 |
-
RATE = 16000
|
142 |
-
|
143 |
-
def __init__(self):
|
144 |
-
self.client_manager = None
|
145 |
-
self.no_voice_activity_chunks = 0
|
146 |
-
self.single_model = False
|
147 |
-
|
148 |
-
def initialize_client(
|
149 |
-
self, websocket, options
|
150 |
-
):
|
151 |
-
client: Optional[ServeClientBase] = None
|
152 |
-
|
153 |
-
if self.backend.is_pywhispercpp():
|
154 |
-
client = PyWhiperCppServe(
|
155 |
-
websocket,
|
156 |
-
language=options["language"],
|
157 |
-
client_uid=options["uid"],
|
158 |
-
)
|
159 |
-
logging.info("Running pywhispercpp backend.")
|
160 |
-
|
161 |
-
if client is None:
|
162 |
-
raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")
|
163 |
-
|
164 |
-
self.client_manager.add_client(websocket, client)
|
165 |
-
|
166 |
-
def get_audio_from_websocket(self, websocket):
|
167 |
-
"""
|
168 |
-
Receives audio buffer from websocket and creates a numpy array out of it.
|
169 |
-
|
170 |
-
Args:
|
171 |
-
websocket: The websocket to receive audio from.
|
172 |
-
|
173 |
-
Returns:
|
174 |
-
A numpy array containing the audio.
|
175 |
-
"""
|
176 |
-
frame_data = websocket.recv()
|
177 |
-
if frame_data == b"END_OF_AUDIO":
|
178 |
-
return False
|
179 |
-
return np.frombuffer(frame_data, dtype=np.int16).astype(np.float32) / 32768.0
|
180 |
-
# return np.frombuffer(frame_data, dtype=np.float32)
|
181 |
-
|
182 |
-
|
183 |
-
def handle_new_connection(self, websocket):
|
184 |
-
query_parameters_dict = dict(parse_qsl(urlparse(websocket.request.path).query))
|
185 |
-
from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
|
186 |
-
|
187 |
-
try:
|
188 |
-
logging.info("New client connected")
|
189 |
-
options = websocket.recv()
|
190 |
-
try:
|
191 |
-
options = json.loads(options)
|
192 |
-
except Exception as e:
|
193 |
-
options = {"language": from_lang, "uid": str(uuid1())}
|
194 |
-
if self.client_manager is None:
|
195 |
-
max_clients = options.get('max_clients', 4)
|
196 |
-
max_connection_time = options.get('max_connection_time', 600)
|
197 |
-
self.client_manager = ClientManager(max_clients, max_connection_time)
|
198 |
-
|
199 |
-
if self.client_manager.is_server_full(websocket, options):
|
200 |
-
websocket.close()
|
201 |
-
return False # Indicates that the connection should not continue
|
202 |
-
|
203 |
-
if self.backend.is_pywhispercpp():
|
204 |
-
self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
|
205 |
-
|
206 |
-
self.initialize_client(websocket, options)
|
207 |
-
if from_lang and to_lang:
|
208 |
-
self.set_lang(websocket, from_lang, to_lang)
|
209 |
-
logging.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
|
210 |
-
return True
|
211 |
-
except json.JSONDecodeError:
|
212 |
-
logging.error("Failed to decode JSON from client")
|
213 |
-
return False
|
214 |
-
except ConnectionClosed:
|
215 |
-
logging.info("Connection closed by client")
|
216 |
-
return False
|
217 |
-
except Exception as e:
|
218 |
-
logging.error(f"Error during new connection initialization: {str(e)}")
|
219 |
-
return False
|
220 |
-
|
221 |
-
def process_audio_frames(self, websocket):
|
222 |
-
frame_np = self.get_audio_from_websocket(websocket)
|
223 |
-
client = self.client_manager.get_client(websocket)
|
224 |
-
|
225 |
-
# TODO Vad has some problem, it will be blocking process loop
|
226 |
-
# if frame_np is False:
|
227 |
-
# if self.backend.is_pywhispercpp():
|
228 |
-
# client.set_eos(True)
|
229 |
-
# return False
|
230 |
-
|
231 |
-
# if self.backend.is_pywhispercpp():
|
232 |
-
# voice_active = self.voice_activity(websocket, frame_np)
|
233 |
-
# if voice_active:
|
234 |
-
# self.no_voice_activity_chunks = 0
|
235 |
-
# client.set_eos(False)
|
236 |
-
# if self.use_vad and not voice_active:
|
237 |
-
# return True
|
238 |
-
|
239 |
-
client.add_frames(frame_np)
|
240 |
-
return True
|
241 |
-
|
242 |
-
def set_lang(self, websocket, src_lang, dst_lang):
|
243 |
-
client = self.client_manager.get_client(websocket)
|
244 |
-
if isinstance(client, PyWhiperCppServe):
|
245 |
-
client.set_lang(src_lang, dst_lang)
|
246 |
-
|
247 |
-
def recv_audio(self,
|
248 |
-
websocket,
|
249 |
-
backend: BackendType = BackendType.PYWHISPERCPP):
|
250 |
-
|
251 |
-
self.backend = backend
|
252 |
-
if not self.handle_new_connection(websocket):
|
253 |
-
return
|
254 |
-
|
255 |
-
|
256 |
-
try:
|
257 |
-
while not self.client_manager.is_client_timeout(websocket):
|
258 |
-
if not self.process_audio_frames(websocket):
|
259 |
-
break
|
260 |
-
except ConnectionClosed:
|
261 |
-
logging.info("Connection closed by client")
|
262 |
-
except Exception as e:
|
263 |
-
logging.error(f"Unexpected error: {str(e)}")
|
264 |
-
finally:
|
265 |
-
if self.client_manager.get_client(websocket):
|
266 |
-
self.cleanup(websocket)
|
267 |
-
websocket.close()
|
268 |
-
del websocket
|
269 |
-
|
270 |
-
def run(self,
|
271 |
-
host,
|
272 |
-
port=9090,
|
273 |
-
backend="pywhispercpp"):
|
274 |
-
"""
|
275 |
-
Run the transcription server.
|
276 |
-
|
277 |
-
Args:
|
278 |
-
host (str): The host address to bind the server.
|
279 |
-
port (int): The port number to bind the server.
|
280 |
-
"""
|
281 |
-
|
282 |
-
if not BackendType.is_valid(backend):
|
283 |
-
raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")
|
284 |
-
|
285 |
-
with serve(
|
286 |
-
functools.partial(
|
287 |
-
self.recv_audio,
|
288 |
-
backend=BackendType(backend),
|
289 |
-
),
|
290 |
-
host,
|
291 |
-
port
|
292 |
-
) as server:
|
293 |
-
server.serve_forever()
|
294 |
-
|
295 |
-
def voice_activity(self, websocket, frame_np):
|
296 |
-
"""
|
297 |
-
Evaluates the voice activity in a given audio frame and manages the state of voice activity detection.
|
298 |
-
|
299 |
-
This method uses the configured voice activity detection (VAD) model to assess whether the given audio frame
|
300 |
-
contains speech. If the VAD model detects no voice activity for more than three consecutive frames,
|
301 |
-
it sets an end-of-speech (EOS) flag for the associated client. This method aims to efficiently manage
|
302 |
-
speech detection to improve subsequent processing steps.
|
303 |
-
|
304 |
-
Args:
|
305 |
-
websocket: The websocket associated with the current client. Used to retrieve the client object
|
306 |
-
from the client manager for state management.
|
307 |
-
frame_np (numpy.ndarray): The audio frame to be analyzed. This should be a NumPy array containing
|
308 |
-
the audio data for the current frame.
|
309 |
-
|
310 |
-
Returns:
|
311 |
-
bool: True if voice activity is detected in the current frame, False otherwise. When returning False
|
312 |
-
after detecting no voice activity for more than three consecutive frames, it also triggers the
|
313 |
-
end-of-speech (EOS) flag for the client.
|
314 |
-
"""
|
315 |
-
if not self.vad_detector(frame_np):
|
316 |
-
self.no_voice_activity_chunks += 1
|
317 |
-
if self.no_voice_activity_chunks > 3:
|
318 |
-
client = self.client_manager.get_client(websocket)
|
319 |
-
if not client.eos:
|
320 |
-
client.set_eos(True)
|
321 |
-
time.sleep(0.1) # Sleep 100m; wait some voice activity.
|
322 |
-
return False
|
323 |
-
return True
|
324 |
-
|
325 |
-
def cleanup(self, websocket):
|
326 |
-
"""
|
327 |
-
Cleans up resources associated with a given client's websocket.
|
328 |
-
|
329 |
-
Args:
|
330 |
-
websocket: The websocket associated with the client to be cleaned up.
|
331 |
-
"""
|
332 |
-
if self.client_manager.get_client(websocket):
|
333 |
-
self.client_manager.remove_client(websocket)
|
334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transcribe/translatepipes.py
CHANGED
@@ -14,7 +14,6 @@ class TranslatePipes:
|
|
14 |
|
15 |
# llm 翻译
|
16 |
# self._translate_pipe = self._launch_process(TranslatePipe())
|
17 |
-
|
18 |
self._translate_7b_pipe = self._launch_process(Translate7BPipe())
|
19 |
# vad
|
20 |
self._vad_pipe = self._launch_process(VadPipe())
|
|
|
14 |
|
15 |
# llm 翻译
|
16 |
# self._translate_pipe = self._launch_process(TranslatePipe())
|
|
|
17 |
self._translate_7b_pipe = self._launch_process(Translate7BPipe())
|
18 |
# vad
|
19 |
self._vad_pipe = self._launch_process(VadPipe())
|
transcribe/whisper_llm_serve.py
CHANGED
@@ -4,7 +4,6 @@ import queue
|
|
4 |
import threading
|
5 |
import time
|
6 |
from logging import getLogger
|
7 |
-
from typing import List, Optional, Iterator, Tuple, Any
|
8 |
import asyncio
|
9 |
import numpy as np
|
10 |
import config
|
@@ -45,8 +44,6 @@ class WhisperTranscriptionService:
|
|
45 |
self.sample_rate = 16000
|
46 |
|
47 |
self.lock = threading.Lock()
|
48 |
-
|
49 |
-
|
50 |
# 文本分隔符,根据语言设置
|
51 |
self.text_separator = self._get_text_separator(language)
|
52 |
self.loop = asyncio.get_event_loop()
|
@@ -54,7 +51,7 @@ class WhisperTranscriptionService:
|
|
54 |
# 原始音频队列
|
55 |
self._frame_queue = queue.Queue()
|
56 |
# 音频队列缓冲区
|
57 |
-
self.frames_np =
|
58 |
# 完整音频队列
|
59 |
self.segments_queue = collections.deque()
|
60 |
self._temp_string = ""
|
@@ -100,21 +97,6 @@ class WhisperTranscriptionService:
|
|
100 |
"""根据语言返回适当的文本分隔符"""
|
101 |
return "" if language == "zh" else " "
|
102 |
|
103 |
-
async def send_ready_state(self) -> None:
|
104 |
-
"""发送服务就绪状态消息"""
|
105 |
-
await self.websocket.send(json.dumps({
|
106 |
-
"uid": self.client_uid,
|
107 |
-
"message": self.SERVER_READY,
|
108 |
-
"backend": "whisper_transcription"
|
109 |
-
}))
|
110 |
-
|
111 |
-
def set_language(self, source_lang: str, target_lang: str) -> None:
|
112 |
-
"""设置源语言和目标语言"""
|
113 |
-
self.source_language = source_lang
|
114 |
-
self.target_language = target_lang
|
115 |
-
self.text_separator = self._get_text_separator(source_lang)
|
116 |
-
# self._transcrible_analysis = TranscriptStabilityAnalyzer(self.source_language, self.text_separator)
|
117 |
-
|
118 |
def add_frames(self, frame_np: np.ndarray) -> None:
|
119 |
"""添加音频帧到处理队列"""
|
120 |
self._frame_queue.put(frame_np)
|
@@ -135,60 +117,35 @@ class WhisperTranscriptionService:
|
|
135 |
if frame_np is None or len(frame_np) == 0:
|
136 |
continue
|
137 |
with self.lock:
|
138 |
-
|
139 |
-
self.frames_np = frame_np.copy()
|
140 |
-
else:
|
141 |
-
self.frames_np = np.append(self.frames_np, frame_np)
|
142 |
if speech_status == "END" and len(self.frames_np) > 0:
|
143 |
self.segments_queue.appendleft(self.frames_np.copy())
|
144 |
self.frames_np = np.array([], dtype=np.float32)
|
145 |
except queue.Empty:
|
146 |
pass
|
147 |
|
148 |
-
def _process_transcription_results_2(self, seg_text:str,partial):
|
149 |
-
|
150 |
-
item = TransResult(
|
151 |
-
seg_id=self.row_number,
|
152 |
-
context=seg_text,
|
153 |
-
from_=self.source_language,
|
154 |
-
to=self.target_language,
|
155 |
-
tran_content=self._translate_text_large(seg_text),
|
156 |
-
partial=partial
|
157 |
-
)
|
158 |
-
if partial == False:
|
159 |
-
self.row_number += 1
|
160 |
-
return item
|
161 |
-
|
162 |
def _transcription_processing_loop(self) -> None:
|
163 |
"""主转录处理循环"""
|
164 |
frame_epoch = 1
|
165 |
while not self._translate_thread_stop.is_set():
|
166 |
-
|
167 |
-
if self.frames_np is None:
|
168 |
-
time.sleep(0.01)
|
169 |
-
continue
|
170 |
-
|
171 |
|
172 |
-
if len(self.
|
173 |
-
audio_buffer = self.segments_queue.pop()
|
174 |
-
partial = False
|
175 |
-
else:
|
176 |
-
with self.lock:
|
177 |
-
audio_buffer = self.frames_np[:int(frame_epoch * 1.5 * self.sample_rate)].copy()# 获取 1.5s * epoch 个音频长度
|
178 |
-
partial = True
|
179 |
-
|
180 |
-
if len(audio_buffer) ==0:
|
181 |
time.sleep(0.01)
|
182 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
if len(audio_buffer) < int(self.sample_rate):
|
185 |
silence_audio = np.zeros(self.sample_rate, dtype=np.float32)
|
186 |
silence_audio[-len(audio_buffer):] = audio_buffer
|
187 |
audio_buffer = silence_audio
|
188 |
|
189 |
-
|
190 |
logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
|
191 |
-
# try:
|
192 |
meta_item = self._transcribe_audio(audio_buffer)
|
193 |
segments = meta_item.segments
|
194 |
logger.debug(f"Segments: {segments}")
|
@@ -205,22 +162,24 @@ class WhisperTranscriptionService:
|
|
205 |
else:
|
206 |
self._temp_string = ""
|
207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
-
result = self._process_transcription_results_2(seg_text, partial)
|
210 |
self._send_result_to_client(result)
|
211 |
-
time.sleep(0.1)
|
212 |
|
213 |
if partial == False:
|
214 |
frame_epoch = 1
|
215 |
else:
|
216 |
frame_epoch += 1
|
217 |
-
|
218 |
-
# for result in self._process_transcription_results(segments, audio_buffer):
|
219 |
-
# self._send_result_to_client(result)
|
220 |
-
|
221 |
-
# except Exception as e:
|
222 |
-
# logger.error(f"Error processing audio: {e}")
|
223 |
-
|
224 |
|
225 |
def _transcribe_audio(self, audio_buffer: np.ndarray)->MetaItem:
|
226 |
"""转录音频并返回转录片段"""
|
@@ -270,51 +229,6 @@ class WhisperTranscriptionService:
|
|
270 |
return translated_text
|
271 |
|
272 |
|
273 |
-
|
274 |
-
def _process_transcription_results(self, segments: List[TranscriptToken], audio_buffer: np.ndarray) -> Iterator[TransResult]:
|
275 |
-
"""
|
276 |
-
处理转录结果,生成翻译结果
|
277 |
-
|
278 |
-
Returns:
|
279 |
-
TransResult对象的迭代器
|
280 |
-
"""
|
281 |
-
|
282 |
-
if not segments:
|
283 |
-
return
|
284 |
-
start_time = time.perf_counter()
|
285 |
-
for ana_result in self._transcrible_analysis.analysis(segments, len(audio_buffer)/self.sample_rate):
|
286 |
-
if (cut_index :=ana_result.cut_index)>0:
|
287 |
-
# 更新音频缓冲区,移除已处理部分
|
288 |
-
self._update_audio_buffer(cut_index)
|
289 |
-
if ana_result.partial():
|
290 |
-
translated_context = self._translate_text(ana_result.context)
|
291 |
-
else:
|
292 |
-
translated_context = self._translate_text_large(ana_result.context)
|
293 |
-
|
294 |
-
yield TransResult(
|
295 |
-
seg_id=ana_result.seg_id,
|
296 |
-
context=ana_result.context,
|
297 |
-
from_=self.source_language,
|
298 |
-
to=self.target_language,
|
299 |
-
tran_content=translated_context,
|
300 |
-
partial=ana_result.partial()
|
301 |
-
)
|
302 |
-
current_time = time.perf_counter()
|
303 |
-
time_diff = current_time - start_time
|
304 |
-
if config.SAVE_DATA_SAVE:
|
305 |
-
self._save_queue.put(DebugResult(
|
306 |
-
seg_id=ana_result.seg_id,
|
307 |
-
transcrible_time=self._transcrible_time_cost,
|
308 |
-
translate_time=self._translate_time_cost,
|
309 |
-
context=ana_result.context,
|
310 |
-
from_=self.source_language,
|
311 |
-
to=self.target_language,
|
312 |
-
tran_content=translated_context,
|
313 |
-
partial=ana_result.partial()
|
314 |
-
))
|
315 |
-
log_block("🚦 Traffic times diff", round(time_diff, 2), 's')
|
316 |
-
|
317 |
-
|
318 |
def _send_result_to_client(self, result: TransResult) -> None:
|
319 |
"""发送翻译结果到客户端"""
|
320 |
try:
|
|
|
4 |
import threading
|
5 |
import time
|
6 |
from logging import getLogger
|
|
|
7 |
import asyncio
|
8 |
import numpy as np
|
9 |
import config
|
|
|
44 |
self.sample_rate = 16000
|
45 |
|
46 |
self.lock = threading.Lock()
|
|
|
|
|
47 |
# 文本分隔符,根据语言设置
|
48 |
self.text_separator = self._get_text_separator(language)
|
49 |
self.loop = asyncio.get_event_loop()
|
|
|
51 |
# 原始音频队列
|
52 |
self._frame_queue = queue.Queue()
|
53 |
# 音频队列缓冲区
|
54 |
+
self.frames_np = np.array([], dtype=np.float32)
|
55 |
# 完整音频队列
|
56 |
self.segments_queue = collections.deque()
|
57 |
self._temp_string = ""
|
|
|
97 |
"""根据语言返回适当的文本分隔符"""
|
98 |
return "" if language == "zh" else " "
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
def add_frames(self, frame_np: np.ndarray) -> None:
|
101 |
"""添加音频帧到处理队列"""
|
102 |
self._frame_queue.put(frame_np)
|
|
|
117 |
if frame_np is None or len(frame_np) == 0:
|
118 |
continue
|
119 |
with self.lock:
|
120 |
+
self.frames_np = np.append(self.frames_np, frame_np)
|
|
|
|
|
|
|
121 |
if speech_status == "END" and len(self.frames_np) > 0:
|
122 |
self.segments_queue.appendleft(self.frames_np.copy())
|
123 |
self.frames_np = np.array([], dtype=np.float32)
|
124 |
except queue.Empty:
|
125 |
pass
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
def _transcription_processing_loop(self) -> None:
|
128 |
"""主转录处理循环"""
|
129 |
frame_epoch = 1
|
130 |
while not self._translate_thread_stop.is_set():
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
+
if len(self.frames_np) ==0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
time.sleep(0.01)
|
134 |
continue
|
135 |
+
with self.lock:
|
136 |
+
if len(self.segments_queue) >0:
|
137 |
+
audio_buffer = self.segments_queue.pop()
|
138 |
+
partial = False
|
139 |
+
else:
|
140 |
+
audio_buffer = self.frames_np[:int(frame_epoch * 1.5 * self.sample_rate)].copy()# 获取 1.5s * epoch 个音频长度
|
141 |
+
partial = True
|
142 |
|
143 |
if len(audio_buffer) < int(self.sample_rate):
|
144 |
silence_audio = np.zeros(self.sample_rate, dtype=np.float32)
|
145 |
silence_audio[-len(audio_buffer):] = audio_buffer
|
146 |
audio_buffer = silence_audio
|
147 |
|
|
|
148 |
logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
|
|
|
149 |
meta_item = self._transcribe_audio(audio_buffer)
|
150 |
segments = meta_item.segments
|
151 |
logger.debug(f"Segments: {segments}")
|
|
|
162 |
else:
|
163 |
self._temp_string = ""
|
164 |
|
165 |
+
result = TransResult(
|
166 |
+
seg_id=self.row_number,
|
167 |
+
context=seg_text,
|
168 |
+
from_=self.source_language,
|
169 |
+
to=self.target_language,
|
170 |
+
tran_content=self._translate_text_large(seg_text),
|
171 |
+
partial=partial
|
172 |
+
)
|
173 |
+
if partial == False:
|
174 |
+
self.row_number += 1
|
175 |
|
|
|
176 |
self._send_result_to_client(result)
|
|
|
177 |
|
178 |
if partial == False:
|
179 |
frame_epoch = 1
|
180 |
else:
|
181 |
frame_epoch += 1
|
182 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
def _transcribe_audio(self, audio_buffer: np.ndarray)->MetaItem:
|
185 |
"""转录音频并返回转录片段"""
|
|
|
229 |
return translated_text
|
230 |
|
231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
def _send_result_to_client(self, result: TransResult) -> None:
|
233 |
"""发送翻译结果到客户端"""
|
234 |
try:
|