daihui.zhang committed on
Commit ea9e44a · 1 Parent(s): 55cf28e

remove unused codes

config.py CHANGED
@@ -3,10 +3,11 @@ import re
3
  import logging
4
 
5
  DEBUG = True
 
6
 
7
  logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
8
  logging.basicConfig(
9
- level=logging.DEBUG if DEBUG else logging.INFO,
10
  format="%(asctime)s - %(levelname)s - %(message)s",
11
  filename='translator.log',
12
  datefmt="%H:%M:%S"
@@ -15,7 +16,7 @@ logging.basicConfig(
15
  SAVE_DATA_SAVE = False
16
  # Add terminal log
17
  console_handler = logging.StreamHandler()
18
- console_handler.setLevel(logging.DEBUG if DEBUG else logging.INFO)
19
  console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
20
  console_handler.setFormatter(console_formatter)
21
  logging.getLogger().addHandler(console_handler)
 
3
  import logging
4
 
5
  DEBUG = True
6
+ LOG_LEVEL = logging.WARNING if DEBUG else logging.INFO
7
 
8
  logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
9
  logging.basicConfig(
10
+ level=LOG_LEVEL,
11
  format="%(asctime)s - %(levelname)s - %(message)s",
12
  filename='translator.log',
13
  datefmt="%H:%M:%S"
 
16
  SAVE_DATA_SAVE = False
17
  # Add terminal log
18
  console_handler = logging.StreamHandler()
19
+ console_handler.setLevel(LOG_LEVEL)
20
  console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
21
  console_handler.setFormatter(console_formatter)
22
  logging.getLogger().addHandler(console_handler)
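
Net effect of the config.py hunks above: the file logger and the console handler now share a single LOG_LEVEL constant instead of each computing its own level from DEBUG. A minimal, runnable sketch of the resulting setup (reconstructed from the hunks shown; not the full config.py):

```python
import logging

DEBUG = True
LOG_LEVEL = logging.WARNING if DEBUG else logging.INFO  # single shared level

logging.getLogger("pywhispercpp").setLevel(logging.WARNING)
logging.basicConfig(
    level=LOG_LEVEL,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename="translator.log",
    datefmt="%H:%M:%S",
)

# The terminal handler reuses the same LOG_LEVEL rather than recomputing it.
console_handler = logging.StreamHandler()
console_handler.setLevel(LOG_LEVEL)
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logging.getLogger().addHandler(console_handler)
```
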
main.py CHANGED
@@ -11,6 +11,7 @@ from fastapi.staticfiles import StaticFiles
11
  from fastapi.responses import RedirectResponse
12
  import os
13
  from transcribe.utils import pcm_bytes_to_np_array
 
14
  logger = getLogger(__name__)
15
 
16
 
@@ -39,9 +40,6 @@ async def lifespan(app:FastAPI):
39
  yield
40
 
41
 
42
- # Absolute path of the directory containing this file
43
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
44
- # Build the absolute path to the frontend directory
45
  FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")
46
 
47
 
@@ -66,9 +64,7 @@ async def translate(websocket: WebSocket):
66
  client_uid=f"{uuid1()}",
67
  )
68
 
69
-
70
  if from_lang and to_lang and client:
71
- client.set_language(from_lang, to_lang)
72
  logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
73
  await websocket.accept()
74
  try:
 
11
  from fastapi.responses import RedirectResponse
12
  import os
13
  from transcribe.utils import pcm_bytes_to_np_array
14
+ from config import BASE_DIR
15
  logger = getLogger(__name__)
16
 
17
 
 
40
  yield
41
 
42
 
 
 
 
43
  FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")
44
 
45
 
 
64
  client_uid=f"{uuid1()}",
65
  )
66
 
 
67
  if from_lang and to_lang and client:
 
68
  logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
69
  await websocket.accept()
70
  try:
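
main.py now imports BASE_DIR from config instead of computing it locally. A hedged sketch of the assumed counterpart in config.py (this definition is not shown in the diff; it is the line removed from main.py, relocated so that `from config import BASE_DIR` resolves):

```python
# config.py (assumed location of the relocated definition)
import os

# Absolute path of the directory containing config.py, formerly computed in main.py
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
```

main.py then keeps building FRONTEND_DIR = os.path.join(BASE_DIR, "frontend") exactly as before.
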
transcribe/client.py DELETED
@@ -1,677 +0,0 @@
1
- import json
2
- import os
3
- import shutil
4
- import threading
5
- import time
6
- import uuid
7
- import wave
8
-
9
- import av
10
- import numpy as np
11
- import pyaudio
12
- import websocket
13
-
14
- import transcribe.utils as utils
15
-
16
-
17
- class Client:
18
- """
19
- Handles communication with a server using WebSocket.
20
- """
21
- INSTANCES = {}
22
- END_OF_AUDIO = "END_OF_AUDIO"
23
-
24
- def __init__(
25
- self,
26
- host=None,
27
- port=None,
28
- lang=None,
29
- log_transcription=True,
30
- max_clients=4,
31
- max_connection_time=600,
32
- dst_lang='zh',
33
- ):
34
- """
35
- Initializes a Client instance for audio recording and streaming to a server.
36
-
37
- If host and port are not provided, the WebSocket connection will not be established.
38
- the audio recording starts immediately upon initialization.
39
-
40
- Args:
41
- host (str): The hostname or IP address of the server.
42
- port (int): The port number for the WebSocket server.
43
- lang (str, optional): The selected language for transcription. Default is None.
44
- log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
45
- max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
46
- max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
47
- """
48
- self.recording = False
49
- self.uid = str(uuid.uuid4())
50
- self.waiting = False
51
- self.last_response_received = None
52
- self.disconnect_if_no_response_for = 15
53
- self.language = lang
54
- self.server_error = False
55
- self.last_segment = None
56
- self.last_received_segment = None
57
- self.log_transcription = log_transcription
58
- self.max_clients = max_clients
59
- self.max_connection_time = max_connection_time
60
- self.dst_lang = dst_lang
61
-
62
- self.audio_bytes = None
63
-
64
- if host is not None and port is not None:
65
- socket_url = f"ws://{host}:{port}?from={self.language}&to={self.dst_lang}"
66
- self.client_socket = websocket.WebSocketApp(
67
- socket_url,
68
- on_open=lambda ws: self.on_open(ws),
69
- on_message=lambda ws, message: self.on_message(ws, message),
70
- on_error=lambda ws, error: self.on_error(ws, error),
71
- on_close=lambda ws, close_status_code, close_msg: self.on_close(
72
- ws, close_status_code, close_msg
73
- ),
74
- )
75
- else:
76
- print("[ERROR]: No host or port specified.")
77
- return
78
-
79
- Client.INSTANCES[self.uid] = self
80
-
81
- # start websocket client in a thread
82
- self.ws_thread = threading.Thread(target=self.client_socket.run_forever)
83
- self.ws_thread.daemon = True
84
- self.ws_thread.start()
85
-
86
- self.transcript = []
87
- print("[INFO]: * recording")
88
-
89
- def handle_status_messages(self, message_data):
90
- """Handles server status messages."""
91
- status = message_data["status"]
92
- if status == "WAIT":
93
- self.waiting = True
94
- print(f"[INFO]: Server is full. Estimated wait time {round(message_data['message'])} minutes.")
95
- elif status == "ERROR":
96
- print(f"Message from Server: {message_data['message']}")
97
- self.server_error = True
98
- elif status == "WARNING":
99
- print(f"Message from Server: {message_data['message']}")
100
-
101
- def process_segments(self, segments):
102
- """Processes transcript segments."""
103
- text = []
104
- for i, seg in enumerate(segments):
105
- if not text or text[-1] != seg["text"]:
106
- text.append(seg["text"])
107
- if i == len(segments) - 1 and not seg.get("completed", False):
108
- self.last_segment = seg
109
-
110
- # update last received segment and last valid response time
111
- if self.last_received_segment is None or self.last_received_segment != segments[-1]["text"]:
112
- self.last_response_received = time.time()
113
- self.last_received_segment = segments[-1]["text"]
114
-
115
- if self.log_transcription:
116
- # Truncate to last 3 entries for brevity.
117
- text = text[-3:]
118
- utils.clear_screen()
119
- utils.print_transcript(text)
120
-
121
- def on_message(self, ws, message):
122
- """
123
- Callback function called when a message is received from the server.
124
-
125
- It updates various attributes of the client based on the received message, including
126
- recording status, language detection, and server messages. If a disconnect message
127
- is received, it sets the recording status to False.
128
-
129
- Args:
130
- ws (websocket.WebSocketApp): The WebSocket client instance.
131
- message (str): The received message from the server.
132
-
133
- """
134
- message = json.loads(message)
135
-
136
- # if self.uid != message.get("uid"):
137
- # print("[ERROR]: invalid client uid")
138
- # return
139
-
140
- if "status" in message.keys():
141
- self.handle_status_messages(message)
142
- return
143
-
144
- if "message" in message.keys() and message["message"] == "DISCONNECT":
145
- print("[INFO]: Server disconnected due to overtime.")
146
- self.recording = False
147
-
148
- if "message" in message.keys() and message["message"] == "SERVER_READY":
149
- self.last_response_received = time.time()
150
- self.recording = True
151
- self.server_backend = message["backend"]
152
- print(f"[INFO]: Server Running with backend {self.server_backend}")
153
- return
154
-
155
- if "language" in message.keys():
156
- self.language = message.get("language")
157
- lang_prob = message.get("language_prob")
158
- print(
159
- f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
160
- )
161
- return
162
-
163
- if "segments" in message.keys():
164
- self.process_segments(message["segments"])
165
-
166
- def on_error(self, ws, error):
167
- print(f"[ERROR] WebSocket Error: {error}")
168
- self.server_error = True
169
- self.error_message = error
170
-
171
- def on_close(self, ws, close_status_code, close_msg):
172
- print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}")
173
- self.recording = False
174
- self.waiting = False
175
-
176
- def on_open(self, ws):
177
- """
178
- Callback function called when the WebSocket connection is successfully opened.
179
-
180
- Sends an initial configuration message to the server, including client UID,
181
- language selection, and task type.
182
-
183
- Args:
184
- ws (websocket.WebSocketApp): The WebSocket client instance.
185
-
186
- """
187
- print("[INFO]: Opened connection")
188
- ws.send(
189
- json.dumps(
190
- {
191
- "uid": self.uid,
192
- "language": self.language,
193
- "max_clients": self.max_clients,
194
- "max_connection_time": self.max_connection_time,
195
- }
196
- )
197
- )
198
-
199
- def send_packet_to_server(self, message):
200
- """
201
- Send an audio packet to the server using WebSocket.
202
-
203
- Args:
204
- message (bytes): The audio data packet in bytes to be sent to the server.
205
-
206
- """
207
- try:
208
- self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY)
209
- except Exception as e:
210
- print(e)
211
-
212
- def close_websocket(self):
213
- """
214
- Close the WebSocket connection and join the WebSocket thread.
215
-
216
- First attempts to close the WebSocket connection using `self.client_socket.close()`. After
217
- closing the connection, it joins the WebSocket thread to ensure proper termination.
218
-
219
- """
220
- try:
221
- self.client_socket.close()
222
- except Exception as e:
223
- print("[ERROR]: Error closing WebSocket:", e)
224
-
225
- try:
226
- self.ws_thread.join()
227
- except Exception as e:
228
- print("[ERROR:] Error joining WebSocket thread:", e)
229
-
230
- def get_client_socket(self):
231
- """
232
- Get the WebSocket client socket instance.
233
-
234
- Returns:
235
- WebSocketApp: The WebSocket client socket instance currently in use by the client.
236
- """
237
- return self.client_socket
238
-
239
- def wait_before_disconnect(self):
240
- """Waits a bit before disconnecting in order to process pending responses."""
241
- assert self.last_response_received
242
- while time.time() - self.last_response_received < self.disconnect_if_no_response_for:
243
- continue
244
-
245
-
246
- class TranscriptionTeeClient:
247
- """
248
- Client for handling audio recording, streaming, and transcription tasks via one or more
249
- WebSocket connections.
250
-
251
- Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
252
- to send audio data for transcription to one or more servers, and receive transcribed text segments.
253
- Args:
254
- clients (list): one or more previously initialized Client instances
255
-
256
- Attributes:
257
- clients (list): the underlying Client instances responsible for handling WebSocket connections.
258
- """
259
-
260
- def __init__(self, clients, save_output_recording=False, output_recording_filename="./output_recording.wav",
261
- mute_audio_playback=False):
262
- self.clients = clients
263
- if not self.clients:
264
- raise Exception("At least one client is required.")
265
- self.chunk = 4096
266
- self.format = pyaudio.paInt16
267
- self.channels = 1
268
- self.rate = 16000
269
- self.record_seconds = 60000
270
- self.save_output_recording = save_output_recording
271
- self.output_recording_filename = output_recording_filename
272
- self.mute_audio_playback = mute_audio_playback
273
- self.frames = b""
274
- self.p = pyaudio.PyAudio()
275
- try:
276
- self.stream = self.p.open(
277
- format=self.format,
278
- channels=self.channels,
279
- rate=self.rate,
280
- input=True,
281
- frames_per_buffer=self.chunk,
282
- )
283
- except OSError as error:
284
- print(f"[WARN]: Unable to access microphone. {error}")
285
- self.stream = None
286
-
287
- def __call__(self, audio=None, rtsp_url=None, hls_url=None, save_file=None):
288
- """
289
- Start the transcription process.
290
-
291
- Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server
292
- to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it
293
- will be played and streamed to the server; otherwise, it will perform live recording.
294
-
295
- Args:
296
- audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording.
297
-
298
- """
299
- assert sum(
300
- source is not None for source in [audio, rtsp_url, hls_url]
301
- ) <= 1, 'You must provide only one selected source'
302
-
303
- print("[INFO]: Waiting for server ready ...")
304
- for client in self.clients:
305
- while not client.recording:
306
- if client.waiting or client.server_error:
307
- self.close_all_clients()
308
- return
309
-
310
- print("[INFO]: Server Ready!")
311
- if hls_url is not None:
312
- self.process_hls_stream(hls_url, save_file)
313
- elif audio is not None:
314
- resampled_file = utils.resample(audio)
315
- self.play_file(resampled_file)
316
- elif rtsp_url is not None:
317
- self.process_rtsp_stream(rtsp_url)
318
- else:
319
- self.record()
320
-
321
- def close_all_clients(self):
322
- """Closes all client websockets."""
323
- for client in self.clients:
324
- client.close_websocket()
325
-
326
- def multicast_packet(self, packet, unconditional=False):
327
- """
328
- Sends an identical packet via all clients.
329
-
330
- Args:
331
- packet (bytes): The audio data packet in bytes to be sent.
332
- unconditional (bool, optional): If true, send regardless of whether clients are recording. Default is False.
333
- """
334
- for client in self.clients:
335
- if (unconditional or client.recording):
336
- client.send_packet_to_server(packet)
337
-
338
- def play_file(self, filename):
339
- """
340
- Play an audio file and send it to the server for processing.
341
-
342
- Reads an audio file, plays it through the audio output, and simultaneously sends
343
- the audio data to the server for processing. It uses PyAudio to create an audio
344
- stream for playback. The audio data is read from the file in chunks, converted to
345
- floating-point format, and sent to the server using WebSocket communication.
346
- This method is typically used when you want to process pre-recorded audio and send it
347
- to the server in real-time.
348
-
349
- Args:
350
- filename (str): The path to the audio file to be played and sent to the server.
351
- """
352
-
353
- # read audio and create pyaudio stream
354
- with wave.open(filename, "rb") as wavfile:
355
- self.stream = self.p.open(
356
- format=self.p.get_format_from_width(wavfile.getsampwidth()),
357
- channels=wavfile.getnchannels(),
358
- rate=wavfile.getframerate(),
359
- input=True,
360
- output=True,
361
- frames_per_buffer=self.chunk,
362
- )
363
- chunk_duration = self.chunk / float(wavfile.getframerate())
364
- try:
365
- while any(client.recording for client in self.clients):
366
- data = wavfile.readframes(self.chunk)
367
- if data == b"":
368
- break
369
-
370
- audio_array = self.bytes_to_float_array(data)
371
- self.multicast_packet(audio_array.tobytes())
372
- if self.mute_audio_playback:
373
- time.sleep(chunk_duration)
374
- else:
375
- self.stream.write(data)
376
-
377
- wavfile.close()
378
-
379
- for client in self.clients:
380
- client.wait_before_disconnect()
381
- self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
382
- self.stream.close()
383
- self.close_all_clients()
384
-
385
- except KeyboardInterrupt:
386
- wavfile.close()
387
- self.stream.stop_stream()
388
- self.stream.close()
389
- self.p.terminate()
390
- self.close_all_clients()
391
- print("[INFO]: Keyboard interrupt.")
392
-
393
- def process_rtsp_stream(self, rtsp_url):
394
- """
395
- Connect to an RTSP source, process the audio stream, and send it for transcription.
396
-
397
- Args:
398
- rtsp_url (str): The URL of the RTSP stream source.
399
- """
400
- print("[INFO]: Connecting to RTSP stream...")
401
- try:
402
- container = av.open(rtsp_url, format="rtsp", options={"rtsp_transport": "tcp"})
403
- self.process_av_stream(container, stream_type="RTSP")
404
- except Exception as e:
405
- print(f"[ERROR]: Failed to process RTSP stream: {e}")
406
- finally:
407
- for client in self.clients:
408
- client.wait_before_disconnect()
409
- self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
410
- self.close_all_clients()
411
- print("[INFO]: RTSP stream processing finished.")
412
-
413
- def process_hls_stream(self, hls_url, save_file=None):
414
- """
415
- Connect to an HLS source, process the audio stream, and send it for transcription.
416
-
417
- Args:
418
- hls_url (str): The URL of the HLS stream source.
419
- save_file (str, optional): Local path to save the network stream.
420
- """
421
- print("[INFO]: Connecting to HLS stream...")
422
- try:
423
- container = av.open(hls_url, format="hls")
424
- self.process_av_stream(container, stream_type="HLS", save_file=save_file)
425
- except Exception as e:
426
- print(f"[ERROR]: Failed to process HLS stream: {e}")
427
- finally:
428
- for client in self.clients:
429
- client.wait_before_disconnect()
430
- self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
431
- self.close_all_clients()
432
- print("[INFO]: HLS stream processing finished.")
433
-
434
- def process_av_stream(self, container, stream_type, save_file=None):
435
- """
436
- Process an AV container stream and send audio packets to the server.
437
-
438
- Args:
439
- container (av.container.InputContainer): The input container to process.
440
- stream_type (str): The type of stream being processed ("RTSP" or "HLS").
441
- save_file (str, optional): Local path to save the stream. Default is None.
442
- """
443
- audio_stream = next((s for s in container.streams if s.type == "audio"), None)
444
- if not audio_stream:
445
- print(f"[ERROR]: No audio stream found in {stream_type} source.")
446
- return
447
-
448
- output_container = None
449
- if save_file:
450
- output_container = av.open(save_file, mode="w")
451
- output_audio_stream = output_container.add_stream(codec_name="pcm_s16le", rate=self.rate)
452
-
453
- try:
454
- for packet in container.demux(audio_stream):
455
- for frame in packet.decode():
456
- audio_data = frame.to_ndarray().tobytes()
457
- self.multicast_packet(audio_data)
458
-
459
- if save_file:
460
- output_container.mux(frame)
461
- except Exception as e:
462
- print(f"[ERROR]: Error during {stream_type} stream processing: {e}")
463
- finally:
464
- # Wait for server to send any leftover transcription.
465
- time.sleep(5)
466
- self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
467
- if output_container:
468
- output_container.close()
469
- container.close()
470
-
471
- def save_chunk(self, n_audio_file):
472
- """
473
- Saves the current audio frames to a WAV file in a separate thread.
474
-
475
- Args:
476
- n_audio_file (int): The index of the audio file which determines the filename.
477
- This helps in maintaining the order and uniqueness of each chunk.
478
- """
479
- t = threading.Thread(
480
- target=self.write_audio_frames_to_file,
481
- args=(self.frames[:], f"chunks/{n_audio_file}.wav",),
482
- )
483
- t.start()
484
-
485
- def finalize_recording(self, n_audio_file):
486
- """
487
- Finalizes the recording process by saving any remaining audio frames,
488
- closing the audio stream, and terminating the process.
489
-
490
- Args:
491
- n_audio_file (int): The file index to be used if there are remaining audio frames to be saved.
492
- This index is incremented before use if the last chunk is saved.
493
- """
494
- if self.save_output_recording and len(self.frames):
495
- self.write_audio_frames_to_file(
496
- self.frames[:], f"chunks/{n_audio_file}.wav"
497
- )
498
- n_audio_file += 1
499
- self.stream.stop_stream()
500
- self.stream.close()
501
- self.p.terminate()
502
- self.close_all_clients()
503
- if self.save_output_recording:
504
- self.write_output_recording(n_audio_file)
505
-
506
- def record(self):
507
- """
508
- Record audio data from the input stream and save it to a WAV file.
509
-
510
- Continuously records audio data from the input stream, sends it to the server via a WebSocket
511
- connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when
512
- the `RECORD_SECONDS` duration is reached or when the `RECORDING` flag is set to `False`.
513
-
514
- Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file.
515
- The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`.
516
- The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
517
- the method combines all the saved audio chunks into the specified `out_file`.
518
- """
519
- n_audio_file = 0
520
- if self.save_output_recording:
521
- if os.path.exists("chunks"):
522
- shutil.rmtree("chunks")
523
- os.makedirs("chunks")
524
- try:
525
- for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
526
- if not any(client.recording for client in self.clients):
527
- break
528
- data = self.stream.read(self.chunk, exception_on_overflow=False)
529
- self.frames += data
530
-
531
- audio_array = self.bytes_to_float_array(data)
532
-
533
- self.multicast_packet(audio_array.tobytes())
534
-
535
- # save frames if more than a minute
536
- if len(self.frames) > 60 * self.rate:
537
- if self.save_output_recording:
538
- self.save_chunk(n_audio_file)
539
- n_audio_file += 1
540
- self.frames = b""
541
-
542
- except KeyboardInterrupt:
543
- self.finalize_recording(n_audio_file)
544
-
545
- def write_audio_frames_to_file(self, frames, file_name):
546
- """
547
- Write audio frames to a WAV file.
548
-
549
- The WAV file is created or overwritten with the specified name. The audio frames should be
550
- in the correct format and match the specified channel, sample width, and sample rate.
551
-
552
- Args:
553
- frames (bytes): The audio frames to be written to the file.
554
- file_name (str): The name of the WAV file to which the frames will be written.
555
-
556
- """
557
- with wave.open(file_name, "wb") as wavfile:
558
- wavfile: wave.Wave_write
559
- wavfile.setnchannels(self.channels)
560
- wavfile.setsampwidth(2)
561
- wavfile.setframerate(self.rate)
562
- wavfile.writeframes(frames)
563
-
564
- def write_output_recording(self, n_audio_file):
565
- """
566
- Combine and save recorded audio chunks into a single WAV file.
567
-
568
- The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk
569
- file, appends its audio data to the final recording, and then deletes the chunk file. After combining
570
- and saving, the final recording is stored in the specified `out_file`.
571
-
572
-
573
- Args:
574
- n_audio_file (int): The number of audio chunk files to combine.
575
- out_file (str): The name of the output WAV file to save the final recording.
576
-
577
- """
578
- input_files = [
579
- f"chunks/{i}.wav"
580
- for i in range(n_audio_file)
581
- if os.path.exists(f"chunks/{i}.wav")
582
- ]
583
- with wave.open(self.output_recording_filename, "wb") as wavfile:
584
- wavfile: wave.Wave_write
585
- wavfile.setnchannels(self.channels)
586
- wavfile.setsampwidth(2)
587
- wavfile.setframerate(self.rate)
588
- for in_file in input_files:
589
- with wave.open(in_file, "rb") as wav_in:
590
- while True:
591
- data = wav_in.readframes(self.chunk)
592
- if data == b"":
593
- break
594
- wavfile.writeframes(data)
595
- # remove this file
596
- os.remove(in_file)
597
- wavfile.close()
598
- # clean up temporary directory to store chunks
599
- if os.path.exists("chunks"):
600
- shutil.rmtree("chunks")
601
-
602
- @staticmethod
603
- def bytes_to_float_array(audio_bytes):
604
- """
605
- Convert audio data from bytes to a NumPy float array.
606
-
607
- It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to
608
- have values between -1 and 1.
609
-
610
- Args:
611
- audio_bytes (bytes): Audio data in bytes.
612
-
613
- Returns:
614
- np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.
615
- """
616
- raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
617
- return raw_data.astype(np.float32) / 32768.0
618
-
619
-
620
- class TranscriptionClient(TranscriptionTeeClient):
621
- """
622
- Client for handling audio transcription tasks via a single WebSocket connection.
623
-
624
- Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
625
- to send audio data for transcription to a server and receive transcribed text segments.
626
-
627
- Args:
628
- host (str): The hostname or IP address of the server.
629
- port (int): The port number to connect to on the server.
630
- lang (str, optional): The primary language for transcription. Default is None, which defaults to English ('en').
631
- save_output_recording (bool, optional): Whether to save the microphone recording. Default is False.
632
- output_recording_filename (str, optional): Path to save the output recording WAV file. Default is "./output_recording.wav".
633
- output_transcription_path (str, optional): File path to save the output transcription (SRT file). Default is "./output.srt".
634
- log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
635
- max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
636
- max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
637
- mute_audio_playback (bool, optional): If True, mutes audio playback during file playback. Default is False.
638
-
639
- Attributes:
640
- client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.
641
-
642
- Example:
643
- To create a TranscriptionClient and start transcription on microphone audio:
644
- ```python
645
- transcription_client = TranscriptionClient(host="localhost", port=9090)
646
- transcription_client()
647
- ```
648
- """
649
-
650
- def __init__(
651
- self,
652
- host,
653
- port,
654
- lang=None,
655
- save_output_recording=False,
656
- output_recording_filename="./output_recording.wav",
657
- log_transcription=True,
658
- max_clients=4,
659
- max_connection_time=600,
660
- mute_audio_playback=False,
661
- dst_lang='en',
662
- ):
663
- self.client = Client(
664
- host, port, lang, log_transcription=log_transcription, max_clients=max_clients,
665
- max_connection_time=max_connection_time, dst_lang=dst_lang
666
- )
667
-
668
- if save_output_recording and not output_recording_filename.endswith(".wav"):
669
- raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
670
-
671
- TranscriptionTeeClient.__init__(
672
- self,
673
- [self.client],
674
- save_output_recording=save_output_recording,
675
- output_recording_filename=output_recording_filename,
676
- mute_audio_playback=mute_audio_playback,
677
- )
transcribe/helpers/vadprocessor.py CHANGED
@@ -36,7 +36,7 @@ class AdaptiveSilenceController:
36
  speed_factor = 0.5
37
  elif avg_speech < 600:
38
  speed_factor = 0.8
39
-
40
  # 3. also take the trend of silence changes into account
41
  adaptive = self.base * speed_factor + 0.3 * avg_silence
42
 
 
36
  speed_factor = 0.5
37
  elif avg_speech < 600:
38
  speed_factor = 0.8
39
+ logging.warning(f"Avg speech :{avg_speech}, Avg silence: {avg_silence}")
40
  # 3. also take the trend of silence changes into account
41
  adaptive = self.base * speed_factor + 0.3 * avg_silence
42
 
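For context on the hunk above: AdaptiveSilenceController scales a base silence window by a speech-speed factor and adds a fraction of the recent average silence. A standalone sketch of that arithmetic (the < 300 ms branch and the 1.0 fallback are assumptions; the diff only shows the < 600 ms branch and the 0.5/0.8 factors):

```python
def adaptive_silence_ms(base: float, avg_speech: float, avg_silence: float) -> float:
    # Faster speech (shorter average speech segments) shrinks the silence window.
    if avg_speech < 300:        # assumed threshold for the 0.5 branch
        speed_factor = 0.5
    elif avg_speech < 600:      # branch shown in the diff
        speed_factor = 0.8
    else:
        speed_factor = 1.0      # assumed fallback
    # 3. also take the trend of silence changes into account
    return base * speed_factor + 0.3 * avg_silence

# Example: base=500 ms, avg_speech=400 ms, avg_silence=200 ms -> 500*0.8 + 0.3*200 = 460 ms
```
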
transcribe/pipelines/pipe_vad.py CHANGED
@@ -3,10 +3,8 @@ from .base import MetaItem, BasePipe
3
  from ..helpers.vadprocessor import FixedVADIterator, AdaptiveSilenceController
4
 
5
  import numpy as np
6
- from silero_vad import get_speech_timestamps
7
- from typing import List
8
  import logging
9
- import time
10
  # import noisereduce as nr
11
 
12
 
@@ -60,27 +58,22 @@ class VadPipe(BasePipe):
60
 
61
  def update_silence_ms(self):
62
  min_silence = self.adaptive_ctrl.get_adaptive_silence_ms()
63
- <<<<<<< HEAD
64
  min_silence_samples = self.sample_rate * min_silence / 1000
65
- self.vac.min_silence_samples = min_silence_samples
66
- logging.warning(f"🫠 update_silence_ms :{min_silence} => current: {self.vac.min_silence_samples} ")
67
-
68
- =======
69
- logging.warning(f"🫠 update_silence_ms :{min_silence} => current: {self.vac.min_silence_duration_ms} ")
70
- self.vac.min_silence_duration_ms = min_silence
71
-
72
- >>>>>>> efad27a (add log to debug silence ms)
73
  def process(self, in_data: MetaItem) -> MetaItem:
74
  if self._offset == 0:
75
  self.vac.reset_states()
76
-
77
  # silence_audio_100ms = np.zeros(int(0.1*self.sample_rate))
78
  source_audio = np.frombuffer(in_data.source_audio, dtype=np.float32)
79
  speech_data = self._process_speech_chunk(source_audio)
80
 
81
  if speech_data: # indicates a speech change point in the audio
82
- # self.update_silence_ms()
83
- rel_start_frame, rel_end_frame = speech_data
84
  if rel_start_frame is not None and rel_end_frame is None:
85
  self._status = "START" # speech start
86
  target_audio = source_audio[rel_start_frame:]
 
3
  from ..helpers.vadprocessor import FixedVADIterator, AdaptiveSilenceController
4
 
5
  import numpy as np
 
 
6
  import logging
7
+
8
  # import noisereduce as nr
9
 
10
 
 
58
 
59
  def update_silence_ms(self):
60
  min_silence = self.adaptive_ctrl.get_adaptive_silence_ms()
 
61
  min_silence_samples = self.sample_rate * min_silence / 1000
62
+ old_silence_samples = self.vac.min_silence_samples
63
+ logging.warning(f"🫠 update_silence_ms :{old_silence_samples * 1000 / self.sample_rate :.2f}ms => current: {min_silence}ms ")
64
+ # self.vac.min_silence_samples = min_silence_samples
65
+
 
 
 
 
66
  def process(self, in_data: MetaItem) -> MetaItem:
67
  if self._offset == 0:
68
  self.vac.reset_states()
69
+
70
  # silence_audio_100ms = np.zeros(int(0.1*self.sample_rate))
71
  source_audio = np.frombuffer(in_data.source_audio, dtype=np.float32)
72
  speech_data = self._process_speech_chunk(source_audio)
73
 
74
  if speech_data: # indicates a speech change point in the audio
75
+
76
+ rel_start_frame, rel_end_frame = speech_data
77
  if rel_start_frame is not None and rel_end_frame is None:
78
  self._status = "START" # speech start
79
  target_audio = source_audio[rel_start_frame:]
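
In update_silence_ms, the adaptive window is converted from milliseconds to samples, and the new warning log converts the old value back to milliseconds for display. A tiny sketch of that round trip, assuming the 16 kHz sample rate used elsewhere in the pipeline:

```python
SAMPLE_RATE = 16000  # assumed; matches the 16 kHz rate used by the audio pipeline

def ms_to_samples(ms: float, sample_rate: int = SAMPLE_RATE) -> float:
    # sample_rate * min_silence / 1000, as in the hunk above
    return sample_rate * ms / 1000

def samples_to_ms(samples: float, sample_rate: int = SAMPLE_RATE) -> float:
    # inverse conversion used to format old_silence_samples in the log line
    return samples * 1000 / sample_rate

assert ms_to_samples(450) == 7200
assert samples_to_ms(7200) == 450
```
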
transcribe/server.py DELETED
@@ -1,382 +0,0 @@
1
-
2
- import json
3
- import logging
4
- import threading
5
- import time
6
- import config
7
- import librosa
8
- import numpy as np
9
- import soundfile
10
- from pywhispercpp.model import Model
11
-
12
- logging.basicConfig(level=logging.INFO)
13
-
14
- class ServeClientBase(object):
15
- RATE = 16000
16
- SERVER_READY = "SERVER_READY"
17
- DISCONNECT = "DISCONNECT"
18
-
19
- def __init__(self, client_uid, websocket):
20
- self.client_uid = client_uid
21
- self.websocket = websocket
22
- self.frames = b""
23
- self.timestamp_offset = 0.0
24
- self.frames_np = None
25
- self.frames_offset = 0.0
26
- self.text = []
27
- self.current_out = ''
28
- self.prev_out = ''
29
- self.t_start = None
30
- self.exit = False
31
- self.same_output_count = 0
32
- self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds
33
- self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds
34
- self.transcript = []
35
- self.send_last_n_segments = 10
36
-
37
- # text formatting
38
- self.pick_previous_segments = 2
39
-
40
- # threading
41
- self.lock = threading.Lock()
42
-
43
- def speech_to_text(self):
44
- raise NotImplementedError
45
-
46
- def transcribe_audio(self):
47
- raise NotImplementedError
48
-
49
- def handle_transcription_output(self):
50
- raise NotImplementedError
51
-
52
- def add_frames(self, frame_np):
53
- """
54
- Add audio frames to the ongoing audio stream buffer.
55
-
56
- This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
57
- of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
58
- to prevent excessive memory usage.
59
-
60
- If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
61
- of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
62
- audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.
63
-
64
- Args:
65
- frame_np (numpy.ndarray): The audio frame data as a NumPy array.
66
-
67
- """
68
- self.lock.acquire()
69
- if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
70
- self.frames_offset += 30.0
71
- self.frames_np = self.frames_np[int(30 * self.RATE):]
72
- # check timestamp offset(should be >= self.frame_offset)
73
- # this basically means that there is no speech as timestamp offset hasnt updated
74
- # and is less than frame_offset
75
- if self.timestamp_offset < self.frames_offset:
76
- self.timestamp_offset = self.frames_offset
77
- if self.frames_np is None:
78
- self.frames_np = frame_np.copy()
79
- else:
80
- self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
81
- self.lock.release()
82
-
83
- def clip_audio_if_no_valid_segment(self):
84
- """
85
- Update the timestamp offset based on audio buffer status.
86
- Clip audio if the current chunk exceeds 30 seconds, this basically implies that
87
- no valid segment for the last 30 seconds from whisper
88
- """
89
- with self.lock:
90
- if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
91
- duration = self.frames_np.shape[0] / self.RATE
92
- self.timestamp_offset = self.frames_offset + duration - 5
93
-
94
- def get_audio_chunk_for_processing(self):
95
- """
96
- Retrieves the next chunk of audio data for processing based on the current offsets.
97
-
98
- Calculates which part of the audio data should be processed next, based on
99
- the difference between the current timestamp offset and the frame's offset, scaled by
100
- the audio sample rate (RATE). It then returns this chunk of audio data along with its
101
- duration in seconds.
102
-
103
- Returns:
104
- tuple: A tuple containing:
105
- - input_bytes (np.ndarray): The next chunk of audio data to be processed.
106
- - duration (float): The duration of the audio chunk in seconds.
107
- """
108
- with self.lock:
109
- samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
110
- input_bytes = self.frames_np[int(samples_take):].copy()
111
- duration = input_bytes.shape[0] / self.RATE
112
- return input_bytes, duration
113
-
114
- def prepare_segments(self, last_segment=None):
115
- """
116
- Prepares the segments of transcribed text to be sent to the client.
117
-
118
- This method compiles the recent segments of transcribed text, ensuring that only the
119
- specified number of the most recent segments are included. It also appends the most
120
- recent segment of text if provided (which is considered incomplete because of the possibility
121
- of the last word being truncated in the audio chunk).
122
-
123
- Args:
124
- last_segment (str, optional): The most recent segment of transcribed text to be added
125
- to the list of segments. Defaults to None.
126
-
127
- Returns:
128
- list: A list of transcribed text segments to be sent to the client.
129
- """
130
- segments = []
131
- if len(self.transcript) >= self.send_last_n_segments:
132
- segments = self.transcript[-self.send_last_n_segments:].copy()
133
- else:
134
- segments = self.transcript.copy()
135
- if last_segment is not None:
136
- segments = segments + [last_segment]
137
- logging.info(f"{segments}")
138
- return segments
139
-
140
- def get_audio_chunk_duration(self, input_bytes):
141
- """
142
- Calculates the duration of the provided audio chunk.
143
-
144
- Args:
145
- input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.
146
-
147
- Returns:
148
- float: The duration of the audio chunk in seconds.
149
- """
150
- return input_bytes.shape[0] / self.RATE
151
-
152
- def send_transcription_to_client(self, segments):
153
- """
154
- Sends the specified transcription segments to the client over the websocket connection.
155
-
156
- This method formats the transcription segments into a JSON object and attempts to send
157
- this object to the client. If an error occurs during the send operation, it logs the error.
158
-
159
- Returns:
160
- segments (list): A list of transcription segments to be sent to the client.
161
- """
162
- try:
163
- self.websocket.send(
164
- json.dumps({
165
- "uid": self.client_uid,
166
- "segments": segments,
167
- })
168
- )
169
- except Exception as e:
170
- logging.error(f"[ERROR]: Sending data to client: {e}")
171
-
172
- def disconnect(self):
173
- """
174
- Notify the client of disconnection and send a disconnect message.
175
-
176
- This method sends a disconnect message to the client via the WebSocket connection to notify them
177
- that the transcription service is disconnecting gracefully.
178
-
179
- """
180
- self.websocket.send(json.dumps({
181
- "uid": self.client_uid,
182
- "message": self.DISCONNECT
183
- }))
184
-
185
- def cleanup(self):
186
- """
187
- Perform cleanup tasks before exiting the transcription service.
188
-
189
- This method performs necessary cleanup tasks, including stopping the transcription thread, marking
190
- the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
191
- associated with the transcription process.
192
-
193
- """
194
- logging.info("Cleaning up.")
195
- self.exit = True
196
-
197
-
198
- class ServeClientWhisperCPP(ServeClientBase):
199
- SINGLE_MODEL = None
200
- SINGLE_MODEL_LOCK = threading.Lock()
201
-
202
- def __init__(self, websocket, language=None, client_uid=None,
203
- single_model=False):
204
- """
205
- Initialize a ServeClient instance.
206
- The Whisper model is initialized based on the client's language and device availability.
207
- The transcription thread is started upon initialization. A "SERVER_READY" message is sent
208
- to the client to indicate that the server is ready.
209
-
210
- Args:
211
- websocket (WebSocket): The WebSocket connection for the client.
212
- language (str, optional): The language for transcription. Defaults to None.
213
- client_uid (str, optional): A unique identifier for the client. Defaults to None.
214
- single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
215
-
216
- """
217
- super().__init__(client_uid, websocket)
218
- self.language = language
219
- self.eos = False
220
-
221
- if single_model:
222
- if ServeClientWhisperCPP.SINGLE_MODEL is None:
223
- self.create_model()
224
- ServeClientWhisperCPP.SINGLE_MODEL = self.transcriber
225
- else:
226
- self.transcriber = ServeClientWhisperCPP.SINGLE_MODEL
227
- else:
228
- self.create_model()
229
-
230
- # threading
231
- logging.info('Create a thread to process audio.')
232
- self.trans_thread = threading.Thread(target=self.speech_to_text)
233
- self.trans_thread.start()
234
-
235
- self.websocket.send(json.dumps({
236
- "uid": self.client_uid,
237
- "message": self.SERVER_READY,
238
- "backend": "pywhispercpp"
239
- }))
240
-
241
- def create_model(self, warmup=True):
242
- """
243
- Instantiates a new model, sets it as the transcriber and does warmup if desired.
244
- """
245
-
246
- self.transcriber = Model(model=config.WHISPER_MODEL, models_dir=config.MODEL_DIR)
247
- if warmup:
248
- self.warmup()
249
-
250
- def warmup(self, warmup_steps=1):
251
- """
252
- Warmup TensorRT since first few inferences are slow.
253
-
254
- Args:
255
- warmup_steps (int): Number of steps to warm up the model for.
256
- """
257
- logging.info("[INFO:] Warming up whisper.cpp engine..")
258
- mel, _, = soundfile.read("assets/jfk.flac")
259
- for i in range(warmup_steps):
260
- self.transcriber.transcribe(mel, print_progress=False)
261
-
262
- def set_eos(self, eos):
263
- """
264
- Sets the End of Speech (EOS) flag.
265
-
266
- Args:
267
- eos (bool): The value to set for the EOS flag.
268
- """
269
- self.lock.acquire()
270
- self.eos = eos
271
- self.lock.release()
272
-
273
- def handle_transcription_output(self, last_segment, duration):
274
- """
275
- Handle the transcription output, updating the transcript and sending data to the client.
276
-
277
- Args:
278
- last_segment (str): The last segment from the whisper output which is considered to be incomplete because
279
- of the possibility of word being truncated.
280
- duration (float): Duration of the transcribed audio chunk.
281
- """
282
- segments = self.prepare_segments({"text": last_segment})
283
- self.send_transcription_to_client(segments)
284
- if self.eos:
285
- self.update_timestamp_offset(last_segment, duration)
286
-
287
- def transcribe_audio(self, input_bytes):
288
- """
289
- Transcribe the audio chunk and send the results to the client.
290
-
291
- Args:
292
- input_bytes (np.array): The audio chunk to transcribe.
293
- """
294
- if ServeClientWhisperCPP.SINGLE_MODEL:
295
- ServeClientWhisperCPP.SINGLE_MODEL_LOCK.acquire()
296
- logging.info(f"[pywhispercpp:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
297
- mel = input_bytes
298
- duration = librosa.get_duration(y=input_bytes, sr=self.RATE)
299
-
300
- if self.language == "zh":
301
- prompt = '以下是简体中文普通话的句子。'
302
- else:
303
- prompt = 'The following is an English sentence.'
304
-
305
- segments = self.transcriber.transcribe(
306
- mel,
307
- language=self.language,
308
- initial_prompt=prompt,
309
- token_timestamps=True,
310
- # max_len=max_len,
311
- print_progress=False
312
- )
313
- text = []
314
- for segment in segments:
315
- content = segment.text
316
- text.append(content)
317
- last_segment = ' '.join(text)
318
-
319
- logging.info(f"[pywhispercpp:] Last segment: {last_segment}")
320
-
321
- if ServeClientWhisperCPP.SINGLE_MODEL:
322
- ServeClientWhisperCPP.SINGLE_MODEL_LOCK.release()
323
- if last_segment:
324
- self.handle_transcription_output(last_segment, duration)
325
-
326
- def update_timestamp_offset(self, last_segment, duration):
327
- """
328
- Update timestamp offset and transcript.
329
-
330
- Args:
331
- last_segment (str): Last transcribed audio from the whisper model.
332
- duration (float): Duration of the last audio chunk.
333
- """
334
- if not len(self.transcript):
335
- self.transcript.append({"text": last_segment + " "})
336
- elif self.transcript[-1]["text"].strip() != last_segment:
337
- self.transcript.append({"text": last_segment + " "})
338
-
339
- logging.info(f'Transcript list context: {self.transcript}')
340
-
341
- with self.lock:
342
- self.timestamp_offset += duration
343
-
344
- def speech_to_text(self):
345
- """
346
- Process an audio stream in an infinite loop, continuously transcribing the speech.
347
-
348
- This method continuously receives audio frames, performs real-time transcription, and sends
349
- transcribed segments to the client via a WebSocket connection.
350
-
351
- If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
352
- It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
353
- are sent to the client in real-time, and a history of segments is maintained to provide context.Pauses in speech
354
- (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if
355
- there is no speech for a specified duration to indicate a pause.
356
-
357
- Raises:
358
- Exception: If there is an issue with audio processing or WebSocket communication.
359
-
360
- """
361
- while True:
362
- if self.exit:
363
- logging.info("Exiting speech to text thread")
364
- break
365
-
366
- if self.frames_np is None:
367
- time.sleep(0.02) # wait for any audio to arrive
368
- continue
369
-
370
- self.clip_audio_if_no_valid_segment()
371
-
372
- input_bytes, duration = self.get_audio_chunk_for_processing()
373
- if duration < 1:
374
- continue
375
-
376
- try:
377
- input_sample = input_bytes.copy()
378
- logging.info(f"[pywhispercpp:] Processing audio with duration: {duration}")
379
- self.transcribe_audio(input_sample)
380
-
381
- except Exception as e:
382
- logging.error(f"[ERROR]: {e}")
transcribe/strategy.py DELETED
@@ -1,405 +0,0 @@
1
-
2
- import collections
3
- import logging
4
- from difflib import SequenceMatcher
5
- from itertools import chain
6
- from dataclasses import dataclass, field
7
- from typing import List, Tuple, Optional, Deque, Any, Iterator,Literal
8
- from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE
9
- from enum import Enum
10
- import wordninja
11
- import config
12
- import re
13
- logger = logging.getLogger("TranscriptionStrategy")
14
-
15
-
16
- class SplitMode(Enum):
17
- PUNCTUATION = "punctuation"
18
- PAUSE = "pause"
19
- END = "end"
20
-
21
-
22
-
23
- @dataclass
24
- class TranscriptResult:
25
- seg_id: int = 0
26
- cut_index: int = 0
27
- is_end_sentence: bool = False
28
- context: str = ""
29
-
30
- def partial(self):
31
- return not self.is_end_sentence
32
-
33
- @dataclass
34
- class TranscriptToken:
35
- """表示一个转录片段,包含文本和时间信息"""
36
- text: str # transcribed text content
37
- t0: int # start time (in hundredths of a second)
38
- t1: int # end time (in hundredths of a second)
39
-
40
- def is_punctuation(self):
41
- """检查文本是否包含标点符号"""
42
- return REGEX_MARKERS.search(self.text.strip()) is not None
43
-
44
- def is_end(self):
45
- """检查文本是否为句子结束标记"""
46
- return SENTENCE_END_PATTERN.search(self.text.strip()) is not None
47
-
48
- def is_pause(self):
49
- """检查文本是否为暂停标记"""
50
- return PAUSEE_END_PATTERN.search(self.text.strip()) is not None
51
-
52
- def buffer_index(self) -> int:
53
- return int(self.t1 / 100 * SAMPLE_RATE)
54
-
55
- @dataclass
56
- class TranscriptChunk:
57
- """表示一组转录片段,支持分割和比较操作"""
58
- separator: str = "" # separator used to join segments
59
- items: list[TranscriptToken] = field(default_factory=list) # list of transcription segments
60
-
61
- @staticmethod
62
- def _calculate_similarity(text1: str, text2: str) -> float:
63
- """计算两段文本的相似度"""
64
- return SequenceMatcher(None, text1, text2).ratio()
65
-
66
- def split_by(self, mode: SplitMode) -> list['TranscriptChunk']:
67
- """根据文本中的标点符号分割片段列表"""
68
- if mode == SplitMode.PUNCTUATION:
69
- indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()]
70
- elif mode == SplitMode.PAUSE:
71
- indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()]
72
- elif mode == SplitMode.END:
73
- indexes = [i for i, seg in enumerate(self.items) if seg.is_end()]
74
- else:
75
- raise ValueError(f"Unsupported mode: {mode}")
76
-
77
- # Shift each cut point forward by one index, so the separator stays with the preceding chunk
78
- cut_points = [0] + sorted(i + 1 for i in indexes) + [len(self.items)]
79
- chunks = [
80
- TranscriptChunk(items=self.items[start:end], separator=self.separator)
81
- for start, end in zip(cut_points, cut_points[1:])
82
- ]
83
- return [
84
- ck
85
- for ck in chunks
86
- if not ck.only_punctuation()
87
- ]
88
-
89
-
90
- def get_split_first_rest(self, mode: SplitMode):
91
- chunks = self.split_by(mode)
92
- fisrt_chunk = chunks[0] if chunks else self
93
- rest_chunks = chunks[1:] if chunks else None
94
- return fisrt_chunk, rest_chunks
95
-
96
- def puncation_numbers(self) -> int:
97
- """计算片段中标点符号的数量"""
98
- return sum(1 for seg in self.items if seg.is_punctuation())
99
-
100
- def length(self) -> int:
101
- """返回片段列表的长度"""
102
- return len(self.items)
103
-
104
- def join(self) -> str:
105
- """将片段连接为一个字符串"""
106
- return self.separator.join(seg.text for seg in self.items)
107
-
108
- def compare(self, chunk: Optional['TranscriptChunk'] = None) -> float:
109
- """比较当前片段与另一个片段的相似度"""
110
- if not chunk:
111
- return 0
112
-
113
- score = self._calculate_similarity(self.join(), chunk.join())
114
- # logger.debug(f"Compare: {self.join()} vs {chunk.join()} : {score}")
115
- return score
116
-
117
- def only_punctuation(self)->bool:
118
- return all(seg.is_punctuation() for seg in self.items)
119
-
120
- def has_punctuation(self) -> bool:
121
- return any(seg.is_punctuation() for seg in self.items)
122
-
123
- def get_buffer_index(self) -> int:
124
- return self.items[-1].buffer_index()
125
-
126
- def is_end_sentence(self) ->bool:
127
- return self.items[-1].is_end()
128
-
129
-
130
- class TranscriptHistory:
131
- """管理转录片段的历史记录"""
132
-
133
- def __init__(self) -> None:
134
- self.history = collections.deque(maxlen=2) # keep the two most recent chunks
135
-
136
- def add(self, chunk: TranscriptChunk):
137
- """添加新的片段到历史记录"""
138
- self.history.appendleft(chunk)
139
-
140
- def previous_chunk(self) -> Optional[TranscriptChunk]:
141
- """获取上一个片段(如果存在)"""
142
- return self.history[1] if len(self.history) == 2 else None
143
-
144
- def lastest_chunk(self):
145
- """获取最后一个片段"""
146
- return self.history[-1]
147
-
148
- def clear(self):
149
- self.history.clear()
150
-
151
- class TranscriptBuffer:
152
- """
153
- Manages the tiered structure of transcribed text: pending string -> short sentence -> full paragraph
154
-
155
- |-- confirmed text --|-- observation window --|-- new input --|
156
-
157
- Manages the pending -> line -> paragraph buffering logic
158
-
159
- """
160
-
161
- def __init__(self, source_lang:str, separator:str):
162
- self._segments: List[str] = collections.deque(maxlen=2) # confirmed full paragraphs
163
- self._sentences: List[str] = collections.deque() # short sentences in the current paragraph
164
- self._buffer: str = "" # text currently in the buffer
165
- self._current_seg_id: int = 0
166
- self.source_language = source_lang
167
- self._separator = separator
168
-
169
- def get_seg_id(self) -> int:
170
- return self._current_seg_id
171
-
172
- @property
173
- def current_sentences_length(self) -> int:
174
- count = 0
175
- for item in self._sentences:
176
- if self._separator:
177
- count += len(item.split(self._separator))
178
- else:
179
- count += len(item)
180
- return count
181
-
182
- def update_pending_text(self, text: str) -> None:
183
- """更新临时缓冲字符串"""
184
- self._buffer = text
185
-
186
- def commit_line(self,) -> None:
187
- """将缓冲字符串提交为短句"""
188
- if self._buffer:
189
- self._sentences.append(self._buffer)
190
- self._buffer = ""
191
-
192
- def commit_paragraph(self) -> None:
193
- """
194
- Commit the current short sentences as a full paragraph (e.g. when a sentence ends)
195
-
196
- Args:
197
- end_of_sentence: whether this is the end of a sentence (e.g. a period was detected)
198
- """
199
-
200
- count = 0
201
- current_sentences = []
202
- while len(self._sentences): # and count < 20:
203
- item = self._sentences.popleft()
204
- current_sentences.append(item)
205
- if self._separator:
206
- count += len(item.split(self._separator))
207
- else:
208
- count += len(item)
209
- if current_sentences:
210
- self._segments.append("".join(current_sentences))
211
- logger.debug(f"=== count to paragraph ===")
212
- logger.debug(f"push: {current_sentences}")
213
- logger.debug(f"rest: {self._sentences}")
214
- # if self._sentences:
215
- # self._segments.append("".join(self._sentences))
216
- # self._sentences.clear()
217
-
218
- def rebuild(self, text):
219
- output = self.split_and_join(
220
- text.replace(
221
- self._separator, ""))
222
-
223
- logger.debug("==== rebuild string ====")
224
- logger.debug(text)
225
- logger.debug(output)
226
-
227
- return output
228
-
229
- @staticmethod
230
- def split_and_join(text):
231
- tokens = []
232
- word_buf = ''
233
-
234
- for char in text:
235
- if char in ALL_MARKERS:
236
- if word_buf:
237
- tokens.extend(wordninja.split(word_buf))
238
- word_buf = ''
239
- tokens.append(char)
240
- else:
241
- word_buf += char
242
- if word_buf:
243
- tokens.extend(wordninja.split(word_buf))
244
-
245
- output = ''
246
- for i, token in enumerate(tokens):
247
- if i == 0:
248
- output += token
249
- elif token in ALL_MARKERS:
250
- output += (token + " ")
251
- else:
252
- output += ' ' + token
253
- return output
254
-
255
-
256
- def update_and_commit(self, stable_strings: List[str], remaining_strings:List[str], is_end_sentence=False):
257
- if self.source_language == "en":
258
- stable_strings = [self.rebuild(i) for i in stable_strings]
259
- remaining_strings =[self.rebuild(i) for i in remaining_strings]
260
- remaining_string = "".join(remaining_strings)
261
-
262
- logger.debug(f"{self.__dict__}")
263
- if is_end_sentence:
264
- for stable_str in stable_strings:
265
- self.update_pending_text(stable_str)
266
- self.commit_line()
267
-
268
- current_text_len = len(self.current_not_commit_text.split(self._separator)) if self._separator else len(self.current_not_commit_text)
269
- # current_text_len = len(self.current_not_commit_text.split(self._separator))
270
- self.update_pending_text(remaining_string)
271
- if current_text_len >= config.TEXT_THREHOLD:
272
- self.commit_paragraph()
273
- self._current_seg_id += 1
274
- return True
275
- else:
276
- for stable_str in stable_strings:
277
- self.update_pending_text(stable_str)
278
- self.commit_line()
279
- self.update_pending_text(remaining_string)
280
- return False
281
-
282
-
283
- @property
284
- def un_commit_paragraph(self) -> str:
285
- """当前短句组合"""
286
- return "".join([i for i in self._sentences])
287
-
288
- @property
289
- def pending_text(self) -> str:
290
- """当前缓冲内容"""
291
- return self._buffer
292
-
293
- @property
294
- def latest_paragraph(self) -> str:
295
- """最新确认的段落"""
296
- return self._segments[-1] if self._segments else ""
297
-
298
- @property
299
- def current_not_commit_text(self) -> str:
300
- return self.un_commit_paragraph + self.pending_text
301
-
302
-
303
-
304
- class TranscriptStabilityAnalyzer:
305
- def __init__(self, source_lang, separator) -> None:
306
- self._transcript_buffer = TranscriptBuffer(source_lang=source_lang,separator=separator)
307
- self._transcript_history = TranscriptHistory()
308
- self._separator = separator
309
- logger.debug(f"Current separator: {self._separator}")
310
-
311
- def merge_chunks(self, chunks: List[TranscriptChunk])->str:
312
- if not chunks:
313
- return [""]
314
- output = list(r.join() for r in chunks if r)
315
- return output
316
-
317
-
318
- def analysis(self, current: TranscriptChunk, buffer_duration: float) -> Iterator[TranscriptResult]:
319
- current = TranscriptChunk(items=current, separator=self._separator)
320
- self._transcript_history.add(current)
321
-
322
- prev = self._transcript_history.previous_chunk()
323
- self._transcript_buffer.update_pending_text(current.join())
324
- if not prev:  # no history yet, so this is a new utterance; emit it directly
325
- yield TranscriptResult(
326
- context=self._transcript_buffer.current_not_commit_text,
327
- seg_id=self._transcript_buffer.get_seg_id()
328
- )
329
- return
330
-
331
- # yield from self._handle_short_buffer(current, prev)
332
- if buffer_duration <= 4:
333
- yield from self._handle_short_buffer(current, prev)
334
- else:
335
- yield from self._handle_long_buffer(current)
336
-
337
-
338
- def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
339
- curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
340
- prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)
341
-
342
- # logger.debug("==== Current cut item ====")
343
- # logger.debug(f"{curr.join()} ")
344
- # logger.debug(f"{prev.join()}")
345
- # logger.debug("==========================")
346
-
347
- if curr_first and prev_first:
348
-
349
- core = curr_first.compare(prev_first)
350
- has_punctuation = curr_first.has_punctuation()
351
- if core >= 0.8 and has_punctuation:
352
- yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
353
- return
354
-
355
- yield TranscriptResult(
356
- seg_id=self._transcript_buffer.get_seg_id(),
357
- context=self._transcript_buffer.current_not_commit_text
358
- )
359
-
360
-
361
- def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
362
- chunks = curr.split_by(SplitMode.PUNCTUATION)
363
- if len(chunks) > 1:
364
- stable, remaining = chunks[:-1], chunks[-1:]
365
- # stable_str = self.merge_chunks(stable)
366
- # remaining_str = self.merge_chunks(remaining)
367
- yield from self._yield_commit_results(
368
- stable, remaining, is_end_sentence=True  # temporarily hard-coded to True
369
- )
370
- else:
371
- yield TranscriptResult(
372
- seg_id=self._transcript_buffer.get_seg_id(),
373
- context=self._transcript_buffer.current_not_commit_text
374
- )
375
-
376
-
377
- def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
378
- stable_str_list = [stable_chunk.join()] if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
379
- remaining_str_list = self.merge_chunks(remaining_chunks)
380
- frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()
381
-
382
- prev_seg_id = self._transcript_buffer.get_seg_id()
383
- commit_paragraph = self._transcript_buffer.update_and_commit(stable_str_list, remaining_str_list, is_end_sentence)
384
- logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")
385
-
386
- if commit_paragraph:
387
- # a new paragraph was produced; start a new line
388
- yield TranscriptResult(
389
- seg_id=prev_seg_id,
390
- cut_index=frame_cut_index,
391
- context=self._transcript_buffer.latest_paragraph,
392
- is_end_sentence=True
393
- )
394
- if (context := self._transcript_buffer.current_not_commit_text.strip()):
395
- yield TranscriptResult(
396
- seg_id=self._transcript_buffer.get_seg_id(),
397
- context=context,
398
- )
399
- else:
400
- yield TranscriptResult(
401
- seg_id=self._transcript_buffer.get_seg_id(),
402
- cut_index=frame_cut_index,
403
- context=self._transcript_buffer.current_not_commit_text,
404
- )
405
-
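Note on the removed transcribe/client.py: the split_and_join helper above re-inserted word boundaries into English transcripts after the separator had been stripped, using wordninja plus the module's punctuation markers. A minimal, self-contained sketch of the same idea (hedged: ALL_MARKERS below is an illustrative stand-in for the set defined in the deleted module):

import wordninja  # pip install wordninja

ALL_MARKERS = {",", ".", "!", "?", ";", ":"}  # illustrative subset of the deleted module's marker set

def split_and_join(text: str) -> str:
    tokens, word_buf = [], ""
    for char in text:
        if char in ALL_MARKERS:
            if word_buf:
                tokens.extend(wordninja.split(word_buf))  # recover word boundaries
                word_buf = ""
            tokens.append(char)
        else:
            word_buf += char
    if word_buf:
        tokens.extend(wordninja.split(word_buf))
    # Re-join the tokens; punctuation is appended without a leading space.
    output = ""
    for i, token in enumerate(tokens):
        if i == 0:
            output += token
        elif token in ALL_MARKERS:
            output += token + " "
        else:
            output += " " + token
    return output

For example, split_and_join("howareyou,iamfine") would recover spacing roughly as "how are you,  i am fine" (note the double space the original algorithm leaves after punctuation).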
transcribe/transcription.py DELETED
@@ -1,334 +0,0 @@
1
- import logging
2
- import time
3
- import functools
4
- import json
5
- import logging
6
- import time
7
- from enum import Enum
8
- from typing import List, Optional
9
- import numpy as np
10
- from .server import ServeClientBase
11
- from .whisper_llm_serve import PyWhiperCppServe
12
- from .vad import VoiceActivityDetector
13
- from urllib.parse import urlparse, parse_qsl
14
- from websockets.exceptions import ConnectionClosed
15
- from websockets.sync.server import serve
16
- from uuid import uuid1
17
-
18
-
19
- logging.basicConfig(level=logging.INFO)
20
-
21
-
22
- class ClientManager:
23
- def __init__(self, max_clients=4, max_connection_time=600):
24
- """
25
- Initializes the ClientManager with specified limits on client connections and connection durations.
26
-
27
- Args:
28
- max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4.
29
- max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected. Defaults
30
- to 600 seconds (10 minutes).
31
- """
32
- self.clients = {}
33
- self.start_times = {}
34
- self.max_clients = max_clients
35
- self.max_connection_time = max_connection_time
36
-
37
- def add_client(self, websocket, client):
38
- """
39
- Adds a client and their connection start time to the tracking dictionaries.
40
-
41
- Args:
42
- websocket: The websocket associated with the client to add.
43
- client: The client object to be added and tracked.
44
- """
45
- self.clients[websocket] = client
46
- self.start_times[websocket] = time.time()
47
-
48
- def get_client(self, websocket):
49
- """
50
- Retrieves a client associated with the given websocket.
51
-
52
- Args:
53
- websocket: The websocket associated with the client to retrieve.
54
-
55
- Returns:
56
- The client object if found, False otherwise.
57
- """
58
- if websocket in self.clients:
59
- return self.clients[websocket]
60
- return False
61
-
62
- def remove_client(self, websocket):
63
- """
64
- Removes a client and their connection start time from the tracking dictionaries. Performs cleanup on the
65
- client if necessary.
66
-
67
- Args:
68
- websocket: The websocket associated with the client to be removed.
69
- """
70
- client = self.clients.pop(websocket, None)
71
- if client:
72
- client.cleanup()
73
- self.start_times.pop(websocket, None)
74
-
75
- def get_wait_time(self):
76
- """
77
- Calculates the estimated wait time for new clients based on the remaining connection times of current clients.
78
-
79
- Returns:
80
- The estimated wait time in minutes for new clients to connect. Returns 0 if there are available slots.
81
- """
82
- wait_time = None
83
- for start_time in self.start_times.values():
84
- current_client_time_remaining = self.max_connection_time - (time.time() - start_time)
85
- if wait_time is None or current_client_time_remaining < wait_time:
86
- wait_time = current_client_time_remaining
87
- return wait_time / 60 if wait_time is not None else 0
88
-
89
- def is_server_full(self, websocket, options):
90
- """
91
- Checks if the server is at its maximum client capacity and sends a wait message to the client if necessary.
92
-
93
- Args:
94
- websocket: The websocket of the client attempting to connect.
95
- options: A dictionary of options that may include the client's unique identifier.
96
-
97
- Returns:
98
- True if the server is full, False otherwise.
99
- """
100
- if len(self.clients) >= self.max_clients:
101
- wait_time = self.get_wait_time()
102
- response = {"uid": options["uid"], "status": "WAIT", "message": wait_time}
103
- websocket.send(json.dumps(response))
104
- return True
105
- return False
106
-
107
- def is_client_timeout(self, websocket):
108
- """
109
- Checks if a client has exceeded the maximum allowed connection time and disconnects them if so, issuing a warning.
110
-
111
- Args:
112
- websocket: The websocket associated with the client to check.
113
-
114
- Returns:
115
- True if the client's connection time has exceeded the maximum limit, False otherwise.
116
- """
117
- elapsed_time = time.time() - self.start_times[websocket]
118
- if elapsed_time >= self.max_connection_time:
119
- self.clients[websocket].disconnect()
120
- logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected: maximum connection time exceeded.")
121
- return True
122
- return False
123
-
124
-
125
- class BackendType(Enum):
126
- PYWHISPERCPP = "pywhispercpp"
127
-
128
- @staticmethod
129
- def valid_types() -> List[str]:
130
- return [backend_type.value for backend_type in BackendType]
131
-
132
- @staticmethod
133
- def is_valid(backend: str) -> bool:
134
- return backend in BackendType.valid_types()
135
-
136
- def is_pywhispercpp(self) -> bool:
137
- return self == BackendType.PYWHISPERCPP
138
-
139
-
140
- class TranscriptionServer:
141
- RATE = 16000
142
-
143
- def __init__(self):
144
- self.client_manager = None
145
- self.no_voice_activity_chunks = 0
146
- self.single_model = False
147
-
148
- def initialize_client(
149
- self, websocket, options
150
- ):
151
- client: Optional[ServeClientBase] = None
152
-
153
- if self.backend.is_pywhispercpp():
154
- client = PyWhiperCppServe(
155
- websocket,
156
- language=options["language"],
157
- client_uid=options["uid"],
158
- )
159
- logging.info("Running pywhispercpp backend.")
160
-
161
- if client is None:
162
- raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")
163
-
164
- self.client_manager.add_client(websocket, client)
165
-
166
- def get_audio_from_websocket(self, websocket):
167
- """
168
- Receives audio buffer from websocket and creates a numpy array out of it.
169
-
170
- Args:
171
- websocket: The websocket to receive audio from.
172
-
173
- Returns:
174
- A numpy array containing the audio.
175
- """
176
- frame_data = websocket.recv()
177
- if frame_data == b"END_OF_AUDIO":
178
- return False
179
- return np.frombuffer(frame_data, dtype=np.int16).astype(np.float32) / 32768.0
180
- # return np.frombuffer(frame_data, dtype=np.float32)
181
-
182
-
183
- def handle_new_connection(self, websocket):
184
- query_parameters_dict = dict(parse_qsl(urlparse(websocket.request.path).query))
185
- from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
186
-
187
- try:
188
- logging.info("New client connected")
189
- options = websocket.recv()
190
- try:
191
- options = json.loads(options)
192
- except Exception as e:
193
- options = {"language": from_lang, "uid": str(uuid1())}
194
- if self.client_manager is None:
195
- max_clients = options.get('max_clients', 4)
196
- max_connection_time = options.get('max_connection_time', 600)
197
- self.client_manager = ClientManager(max_clients, max_connection_time)
198
-
199
- if self.client_manager.is_server_full(websocket, options):
200
- websocket.close()
201
- return False # Indicates that the connection should not continue
202
-
203
- if self.backend.is_pywhispercpp():
204
- self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
205
-
206
- self.initialize_client(websocket, options)
207
- if from_lang and to_lang:
208
- self.set_lang(websocket, from_lang, to_lang)
209
- logging.info(f"Source lang: {from_lang} -> Dst lang: {to_lang}")
210
- return True
211
- except json.JSONDecodeError:
212
- logging.error("Failed to decode JSON from client")
213
- return False
214
- except ConnectionClosed:
215
- logging.info("Connection closed by client")
216
- return False
217
- except Exception as e:
218
- logging.error(f"Error during new connection initialization: {str(e)}")
219
- return False
220
-
221
- def process_audio_frames(self, websocket):
222
- frame_np = self.get_audio_from_websocket(websocket)
223
- client = self.client_manager.get_client(websocket)
224
-
225
- # TODO: VAD still has an issue; it can block the processing loop
226
- # if frame_np is False:
227
- # if self.backend.is_pywhispercpp():
228
- # client.set_eos(True)
229
- # return False
230
-
231
- # if self.backend.is_pywhispercpp():
232
- # voice_active = self.voice_activity(websocket, frame_np)
233
- # if voice_active:
234
- # self.no_voice_activity_chunks = 0
235
- # client.set_eos(False)
236
- # if self.use_vad and not voice_active:
237
- # return True
238
-
239
- client.add_frames(frame_np)
240
- return True
241
-
242
- def set_lang(self, websocket, src_lang, dst_lang):
243
- client = self.client_manager.get_client(websocket)
244
- if isinstance(client, PyWhiperCppServe):
245
- client.set_lang(src_lang, dst_lang)
246
-
247
- def recv_audio(self,
248
- websocket,
249
- backend: BackendType = BackendType.PYWHISPERCPP):
250
-
251
- self.backend = backend
252
- if not self.handle_new_connection(websocket):
253
- return
254
-
255
-
256
- try:
257
- while not self.client_manager.is_client_timeout(websocket):
258
- if not self.process_audio_frames(websocket):
259
- break
260
- except ConnectionClosed:
261
- logging.info("Connection closed by client")
262
- except Exception as e:
263
- logging.error(f"Unexpected error: {str(e)}")
264
- finally:
265
- if self.client_manager.get_client(websocket):
266
- self.cleanup(websocket)
267
- websocket.close()
268
- del websocket
269
-
270
- def run(self,
271
- host,
272
- port=9090,
273
- backend="pywhispercpp"):
274
- """
275
- Run the transcription server.
276
-
277
- Args:
278
- host (str): The host address to bind the server.
279
- port (int): The port number to bind the server.
280
- """
281
-
282
- if not BackendType.is_valid(backend):
283
- raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")
284
-
285
- with serve(
286
- functools.partial(
287
- self.recv_audio,
288
- backend=BackendType(backend),
289
- ),
290
- host,
291
- port
292
- ) as server:
293
- server.serve_forever()
294
-
295
- def voice_activity(self, websocket, frame_np):
296
- """
297
- Evaluates the voice activity in a given audio frame and manages the state of voice activity detection.
298
-
299
- This method uses the configured voice activity detection (VAD) model to assess whether the given audio frame
300
- contains speech. If the VAD model detects no voice activity for more than three consecutive frames,
301
- it sets an end-of-speech (EOS) flag for the associated client. This method aims to efficiently manage
302
- speech detection to improve subsequent processing steps.
303
-
304
- Args:
305
- websocket: The websocket associated with the current client. Used to retrieve the client object
306
- from the client manager for state management.
307
- frame_np (numpy.ndarray): The audio frame to be analyzed. This should be a NumPy array containing
308
- the audio data for the current frame.
309
-
310
- Returns:
311
- bool: True if voice activity is detected in the current frame, False otherwise. When returning False
312
- after detecting no voice activity for more than three consecutive frames, it also triggers the
313
- end-of-speech (EOS) flag for the client.
314
- """
315
- if not self.vad_detector(frame_np):
316
- self.no_voice_activity_chunks += 1
317
- if self.no_voice_activity_chunks > 3:
318
- client = self.client_manager.get_client(websocket)
319
- if not client.eos:
320
- client.set_eos(True)
321
- time.sleep(0.1)  # Sleep 100 ms; wait for some voice activity.
322
- return False
323
- return True
324
-
325
- def cleanup(self, websocket):
326
- """
327
- Cleans up resources associated with a given client's websocket.
328
-
329
- Args:
330
- websocket: The websocket associated with the client to be cleaned up.
331
- """
332
- if self.client_manager.get_client(websocket):
333
- self.client_manager.remove_client(websocket)
334
-
transcribe/translatepipes.py CHANGED
@@ -14,7 +14,6 @@ class TranslatePipes:
14
 
15
  # LLM translation
16
  # self._translate_pipe = self._launch_process(TranslatePipe())
17
-
18
  self._translate_7b_pipe = self._launch_process(Translate7BPipe())
19
  # vad
20
  self._vad_pipe = self._launch_process(VadPipe())
 
14
 
15
  # LLM translation
16
  # self._translate_pipe = self._launch_process(TranslatePipe())
 
17
  self._translate_7b_pipe = self._launch_process(Translate7BPipe())
18
  # vad
19
  self._vad_pipe = self._launch_process(VadPipe())
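TranslatePipes starts each stage (Whisper transcription, the 7B translation model, VAD) as its own worker via _launch_process. The pipe classes themselves are not part of this diff, so the following is only a hedged sketch of such a launch pattern built on multiprocessing queues; every name in it is illustrative rather than the project's actual API:

import multiprocessing as mp

class ExamplePipe(mp.Process):
    """Illustrative worker: reads items from an input queue, writes results to an output queue."""

    def __init__(self):
        super().__init__(daemon=True)
        self.input_queue = mp.Queue()
        self.output_queue = mp.Queue()

    def run(self):
        # Replace this body with real work (transcription, translation, VAD, ...).
        for item in iter(self.input_queue.get, None):  # a None item stops the worker
            self.output_queue.put(item)

def launch_process(pipe):
    pipe.start()
    return pipe

if __name__ == "__main__":
    vad_like_pipe = launch_process(ExamplePipe())
    vad_like_pipe.input_queue.put("audio-chunk")
    print(vad_like_pipe.output_queue.get())  # -> "audio-chunk"
    vad_like_pipe.input_queue.put(None)  # shut the worker down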
transcribe/whisper_llm_serve.py CHANGED
@@ -4,7 +4,6 @@ import queue
4
  import threading
5
  import time
6
  from logging import getLogger
7
- from typing import List, Optional, Iterator, Tuple, Any
8
  import asyncio
9
  import numpy as np
10
  import config
@@ -45,8 +44,6 @@ class WhisperTranscriptionService:
45
  self.sample_rate = 16000
46
 
47
  self.lock = threading.Lock()
48
-
49
-
50
  # Text separator, set according to the language
51
  self.text_separator = self._get_text_separator(language)
52
  self.loop = asyncio.get_event_loop()
@@ -54,7 +51,7 @@ class WhisperTranscriptionService:
54
  # Raw audio frame queue
55
  self._frame_queue = queue.Queue()
56
  # Audio buffer
57
- self.frames_np = None
58
  # Queue of completed audio segments
59
  self.segments_queue = collections.deque()
60
  self._temp_string = ""
@@ -100,21 +97,6 @@ class WhisperTranscriptionService:
100
  """Return the appropriate text separator for the language."""
101
  return "" if language == "zh" else " "
102
 
103
- async def send_ready_state(self) -> None:
104
- """Send the server-ready status message."""
105
- await self.websocket.send(json.dumps({
106
- "uid": self.client_uid,
107
- "message": self.SERVER_READY,
108
- "backend": "whisper_transcription"
109
- }))
110
-
111
- def set_language(self, source_lang: str, target_lang: str) -> None:
112
- """Set the source and target languages."""
113
- self.source_language = source_lang
114
- self.target_language = target_lang
115
- self.text_separator = self._get_text_separator(source_lang)
116
- # self._transcrible_analysis = TranscriptStabilityAnalyzer(self.source_language, self.text_separator)
117
-
118
  def add_frames(self, frame_np: np.ndarray) -> None:
119
  """添加音频帧到处理队列"""
120
  self._frame_queue.put(frame_np)
@@ -135,60 +117,35 @@ class WhisperTranscriptionService:
135
  if frame_np is None or len(frame_np) == 0:
136
  continue
137
  with self.lock:
138
- if self.frames_np is None:
139
- self.frames_np = frame_np.copy()
140
- else:
141
- self.frames_np = np.append(self.frames_np, frame_np)
142
  if speech_status == "END" and len(self.frames_np) > 0:
143
  self.segments_queue.appendleft(self.frames_np.copy())
144
  self.frames_np = np.array([], dtype=np.float32)
145
  except queue.Empty:
146
  pass
147
 
148
- def _process_transcription_results_2(self, seg_text:str,partial):
149
-
150
- item = TransResult(
151
- seg_id=self.row_number,
152
- context=seg_text,
153
- from_=self.source_language,
154
- to=self.target_language,
155
- tran_content=self._translate_text_large(seg_text),
156
- partial=partial
157
- )
158
- if partial == False:
159
- self.row_number += 1
160
- return item
161
-
162
  def _transcription_processing_loop(self) -> None:
163
  """Main transcription processing loop."""
164
  frame_epoch = 1
165
  while not self._translate_thread_stop.is_set():
166
-
167
- if self.frames_np is None:
168
- time.sleep(0.01)
169
- continue
170
-
171
 
172
- if len(self.segments_queue) >0:
173
- audio_buffer = self.segments_queue.pop()
174
- partial = False
175
- else:
176
- with self.lock:
177
- audio_buffer = self.frames_np[:int(frame_epoch * 1.5 * self.sample_rate)].copy()  # take 1.5 s * epoch of audio
178
- partial = True
179
-
180
- if len(audio_buffer) ==0:
181
  time.sleep(0.01)
182
  continue
183
 
184
  if len(audio_buffer) < int(self.sample_rate):
185
  silence_audio = np.zeros(self.sample_rate, dtype=np.float32)
186
  silence_audio[-len(audio_buffer):] = audio_buffer
187
  audio_buffer = silence_audio
188
 
189
-
190
  logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
191
- # try:
192
  meta_item = self._transcribe_audio(audio_buffer)
193
  segments = meta_item.segments
194
  logger.debug(f"Segments: {segments}")
@@ -205,22 +162,24 @@ class WhisperTranscriptionService:
205
  else:
206
  self._temp_string = ""
207
 
208
 
209
- result = self._process_transcription_results_2(seg_text, partial)
210
  self._send_result_to_client(result)
211
- time.sleep(0.1)
212
 
213
  if partial == False:
214
  frame_epoch = 1
215
  else:
216
  frame_epoch += 1
217
- # 处理转录结果并发送到客户端
218
- # for result in self._process_transcription_results(segments, audio_buffer):
219
- # self._send_result_to_client(result)
220
-
221
- # except Exception as e:
222
- # logger.error(f"Error processing audio: {e}")
223
-
224
 
225
  def _transcribe_audio(self, audio_buffer: np.ndarray) -> MetaItem:
226
  """Transcribe the audio and return transcript segments."""
@@ -270,51 +229,6 @@ class WhisperTranscriptionService:
270
  return translated_text
271
 
272
 
273
-
274
- def _process_transcription_results(self, segments: List[TranscriptToken], audio_buffer: np.ndarray) -> Iterator[TransResult]:
275
- """
276
- Process the transcription results and generate translation results.
277
-
278
- Returns:
279
- An iterator of TransResult objects.
280
- """
281
-
282
- if not segments:
283
- return
284
- start_time = time.perf_counter()
285
- for ana_result in self._transcrible_analysis.analysis(segments, len(audio_buffer)/self.sample_rate):
286
- if (cut_index :=ana_result.cut_index)>0:
287
- # 更新音频缓冲区,移除已处理部分
288
- self._update_audio_buffer(cut_index)
289
- if ana_result.partial():
290
- translated_context = self._translate_text(ana_result.context)
291
- else:
292
- translated_context = self._translate_text_large(ana_result.context)
293
-
294
- yield TransResult(
295
- seg_id=ana_result.seg_id,
296
- context=ana_result.context,
297
- from_=self.source_language,
298
- to=self.target_language,
299
- tran_content=translated_context,
300
- partial=ana_result.partial()
301
- )
302
- current_time = time.perf_counter()
303
- time_diff = current_time - start_time
304
- if config.SAVE_DATA_SAVE:
305
- self._save_queue.put(DebugResult(
306
- seg_id=ana_result.seg_id,
307
- transcrible_time=self._transcrible_time_cost,
308
- translate_time=self._translate_time_cost,
309
- context=ana_result.context,
310
- from_=self.source_language,
311
- to=self.target_language,
312
- tran_content=translated_context,
313
- partial=ana_result.partial()
314
- ))
315
- log_block("🚦 Traffic times diff", round(time_diff, 2), 's')
316
-
317
-
318
  def _send_result_to_client(self, result: TransResult) -> None:
319
  """Send the translation result to the client."""
320
  try:
 
4
  import threading
5
  import time
6
  from logging import getLogger
 
7
  import asyncio
8
  import numpy as np
9
  import config
 
44
  self.sample_rate = 16000
45
 
46
  self.lock = threading.Lock()
 
 
47
  # Text separator, set according to the language
48
  self.text_separator = self._get_text_separator(language)
49
  self.loop = asyncio.get_event_loop()
 
51
  # Raw audio frame queue
52
  self._frame_queue = queue.Queue()
53
  # Audio buffer
54
+ self.frames_np = np.array([], dtype=np.float32)
55
  # Queue of completed audio segments
56
  self.segments_queue = collections.deque()
57
  self._temp_string = ""
 
97
  """Return the appropriate text separator for the language."""
98
  return "" if language == "zh" else " "
99
 
100
  def add_frames(self, frame_np: np.ndarray) -> None:
101
  """Add an audio frame to the processing queue."""
102
  self._frame_queue.put(frame_np)
 
117
  if frame_np is None or len(frame_np) == 0:
118
  continue
119
  with self.lock:
120
+ self.frames_np = np.append(self.frames_np, frame_np)
 
 
 
121
  if speech_status == "END" and len(self.frames_np) > 0:
122
  self.segments_queue.appendleft(self.frames_np.copy())
123
  self.frames_np = np.array([], dtype=np.float32)
124
  except queue.Empty:
125
  pass
126
 
127
  def _transcription_processing_loop(self) -> None:
128
  """Main transcription processing loop."""
129
  frame_epoch = 1
130
  while not self._translate_thread_stop.is_set():
 
 
 
 
 
131
 
132
+ if len(self.frames_np) ==0:
133
  time.sleep(0.01)
134
  continue
135
+ with self.lock:
136
+ if len(self.segments_queue) >0:
137
+ audio_buffer = self.segments_queue.pop()
138
+ partial = False
139
+ else:
140
+ audio_buffer = self.frames_np[:int(frame_epoch * 1.5 * self.sample_rate)].copy()# 获取 1.5s * epoch 个音频长度
141
+ partial = True
142
 
143
  if len(audio_buffer) < int(self.sample_rate):
144
  silence_audio = np.zeros(self.sample_rate, dtype=np.float32)
145
  silence_audio[-len(audio_buffer):] = audio_buffer
146
  audio_buffer = silence_audio
147
 
 
148
  logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
 
149
  meta_item = self._transcribe_audio(audio_buffer)
150
  segments = meta_item.segments
151
  logger.debug(f"Segments: {segments}")
 
162
  else:
163
  self._temp_string = ""
164
 
165
+ result = TransResult(
166
+ seg_id=self.row_number,
167
+ context=seg_text,
168
+ from_=self.source_language,
169
+ to=self.target_language,
170
+ tran_content=self._translate_text_large(seg_text),
171
+ partial=partial
172
+ )
173
+ if partial == False:
174
+ self.row_number += 1
175
 
 
176
  self._send_result_to_client(result)
 
177
 
178
  if partial == False:
179
  frame_epoch = 1
180
  else:
181
  frame_epoch += 1
182
+
183
 
184
  def _transcribe_audio(self, audio_buffer: np.ndarray) -> MetaItem:
185
  """Transcribe the audio and return transcript segments."""
 
229
  return translated_text
230
 
231
 
232
  def _send_result_to_client(self, result: TransResult) -> None:
233
  """Send the translation result to the client."""
234
  try:
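Taken together, the whisper_llm_serve.py changes simplify the buffering strategy: frames_np is now always a float32 array, completed speech segments popped from segments_queue are sent as final results (partial=False), and the still-growing buffer is transcribed in slices of 1.5 s * frame_epoch as partial results, padded with leading silence whenever the slice is shorter than one second. A simplified standalone sketch of that slicing and padding logic (illustrative only; it omits the queues, locking, transcription and translation steps of the real service):

import numpy as np

SAMPLE_RATE = 16000

def next_audio_buffer(frames: np.ndarray, frame_epoch: int) -> np.ndarray:
    """Slice of the growing buffer to transcribe on this iteration (frames assumed non-empty)."""
    audio = frames[: int(frame_epoch * 1.5 * SAMPLE_RATE)].copy()
    if len(audio) < SAMPLE_RATE:
        # Pad to one second with leading silence so the model always sees a full window.
        padded = np.zeros(SAMPLE_RATE, dtype=np.float32)
        padded[-len(audio):] = audio
        audio = padded
    return audio

# Example: 0.5 s of audio on the first epoch is padded to exactly 1 s.
half_second = np.ones(SAMPLE_RATE // 2, dtype=np.float32)
assert next_audio_buffer(half_second, frame_epoch=1).shape == (SAMPLE_RATE,)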