NikolaSelic committed
Commit
f7e8228
1 Parent(s): 0a16df6

Create app.py

Files changed (1)
app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
+ import logging
+ import os
+ import sys
+ import traceback
+ from contextlib import contextmanager
+
+ import diart.operators as dops
+ import numpy as np
+ import rich
+ import rx.operators as ops
+ import whisper_timestamped as whisper
+ from diart import OnlineSpeakerDiarization, PipelineConfig
+ from diart.sources import MicrophoneAudioSource
+ from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
+
+
+ def concat(chunks, collar=0.05):
+     """
+     Concatenate predictions and audio
+     given a list of `(diarization, waveform)` pairs
+     and merge contiguous single-speaker regions
+     with pauses shorter than `collar` seconds.
+     """
+     first_annotation = chunks[0][0]
+     first_waveform = chunks[0][1]
+     annotation = Annotation(uri=first_annotation.uri)
+     data = []
+     for ann, wav in chunks:
+         annotation.update(ann)
+         data.append(wav.data)
+     annotation = annotation.support(collar)
+     window = SlidingWindow(
+         first_waveform.sliding_window.duration,
+         first_waveform.sliding_window.step,
+         first_waveform.sliding_window.start,
+     )
+     data = np.concatenate(data, axis=0)
+     return annotation, SlidingWindowFeature(data, window)
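+ # The batch of per-chunk outputs thus collapses into a single
+ # (Annotation, SlidingWindowFeature) pair that the downstream
+ # operators treat as one longer chunk.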
+
+
+ def colorize_transcription(transcription):
+     colors = 2 * [
+         "bright_red",
+         "bright_blue",
+         "bright_green",
+         "orange3",
+         "deep_pink1",
+         "yellow2",
+         "magenta",
+         "cyan",
+         "bright_magenta",
+         "dodger_blue2",
+     ]
+     result = []
+     for speaker, text in transcription:
+         if speaker == -1:
+             # No speaker found for this text, use default terminal color
+             result.append(text)
+         else:
+             result.append(f"[{colors[speaker]}]{text}")
+     return "\n".join(result)
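+ # e.g. colorize_transcription([(0, "hello"), (1, "hi")]) returns
+ # "[bright_red]hello\n[bright_blue]hi", which rich.print renders in color.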
+
+
+ # @contextmanager
+ # def suppress_stdout():
+ #     with open(os.devnull, "w") as devnull:
+ #         old_stdout = sys.stdout
+ #         sys.stdout = devnull
+ #         try:
+ #             yield
+ #         finally:
+ #             sys.stdout = old_stdout
+
+
+ class WhisperTranscriber:
+     def __init__(self, model="small", device=None):
+         self.model = whisper.load_model(model, device=device)
+         self._buffer = ""
+
+     def transcribe(self, waveform):
+         """Transcribe audio using Whisper"""
+         # Pad/trim audio to fit 30 seconds as required by Whisper
+         audio = waveform.data.astype("float32").reshape(-1)
+         audio = whisper.pad_or_trim(audio)
+
+         # Transcribe the given audio
+         transcription = whisper.transcribe(
+             self.model,
+             audio,
+             # We use past transcriptions to condition the model
+             initial_prompt=self._buffer,
+             verbose=True,  # to avoid the progress bar
+         )
+
+         return transcription
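+     # Note: whisper_timestamped returns a dict with the full "text" plus
+     # "segments", where each segment includes word-level "words" entries
+     # carrying "start"/"end" timestamps (used by identify_speakers below).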
+
+     def identify_speakers(self, transcription, diarization, time_shift):
+         """Iterate over transcription segments to assign speakers"""
+         speaker_captions = []
+         for segment in transcription["segments"]:
+             # Crop diarization to the segment timestamps
+             start = time_shift + segment["words"][0]["start"]
+             end = time_shift + segment["words"][-1]["end"]
+             dia = diarization.crop(Segment(start, end))
+
+             # Assign a speaker to the segment based on diarization
+             speakers = dia.labels()
+             num_speakers = len(speakers)
+             if num_speakers == 0:
+                 # No speakers were detected
+                 caption = (-1, segment["text"])
+             elif num_speakers == 1:
+                 # Only one speaker is active in this segment
+                 spk_id = int(speakers[0].split("speaker")[1])
+                 caption = (spk_id, segment["text"])
+             else:
+                 # Multiple speakers: pick the label that speaks the most, then
+                 # parse its numeric id (argmax alone would give a list index,
+                 # not a speaker id)
+                 max_speaker = speakers[
+                     int(np.argmax([dia.label_duration(spk) for spk in speakers]))
+                 ]
+                 spk_id = int(max_speaker.split("speaker")[1])
+                 caption = (spk_id, segment["text"])
+             speaker_captions.append(caption)
+
+         return speaker_captions
+
+     def __call__(self, diarization, waveform):
+         # Step 1: Transcribe
+         transcription = self.transcribe(waveform)
+         # Update transcription buffer
+         self._buffer += transcription["text"]
+         # The audio may not be the beginning of the conversation
+         time_shift = waveform.sliding_window.start
+         # Step 2: Assign speakers
+         speaker_transcriptions = self.identify_speakers(
+             transcription, diarization, time_shift
+         )
+         return speaker_transcriptions
+
+
+ logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)
+
+ config = PipelineConfig(
+     duration=5, step=0.5, latency="min", tau_active=0.5, rho_update=0.1, delta_new=0.57
+ )
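+ # Rough meaning of the diart hyperparameters (see the diart docs; the values
+ # here follow its tutorial): tau_active is the speech-activity threshold,
+ # rho_update is the minimum amount of speech before updating a speaker model,
+ # and delta_new is the distance threshold for creating a new speaker.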
+ dia = OnlineSpeakerDiarization(config)
+ source = MicrophoneAudioSource(config.sample_rate)
+
+ asr = WhisperTranscriber(model="small")
+
+ transcription_duration = 2
+ batch_size = int(transcription_duration // config.step)
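+ # 2 s of audio at a 0.5 s step -> 4 diarization outputs per transcription batch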
+
+ source.stream.pipe(
+     # Format audio stream to sliding windows of 5s with a step of 500ms
+     dops.rearrange_audio_stream(config.duration, config.step, config.sample_rate),
+     # Wait until a batch is full
+     # The output is a list of audio chunks
+     ops.buffer_with_count(count=batch_size),
+     # Obtain diarization prediction
+     # The output is a list of pairs `(diarization, audio chunk)`
+     ops.map(dia),
+     # Concatenate 500ms predictions/chunks to form a single 2s chunk
+     ops.map(concat),
+     # Ignore this chunk if it does not contain speech
+     ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
+     # Obtain speaker-aware transcriptions
+     # The output is a list of pairs `(speaker: int, caption: str)`
+     ops.starmap(asr),
+     ops.map(colorize_transcription),
+ ).subscribe(
+     on_next=rich.print,  # print colored text
+     on_error=lambda _: traceback.print_exc(),  # print stacktrace if error
+ )
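+ # With latency="min", each diarization output spans one 500 ms step, so a
+ # batch of 4 outputs corresponds to roughly 2 s of new audio per transcription.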
+
+ print("Listening...")
+ source.read()