File size: 6,435 Bytes
b99bb69
 
 
21eb51f
 
 
 
b99bb69
7e7acc6
 
 
 
 
b99bb69
 
94f0c4f
 
 
7e7acc6
 
21eb51f
7e7acc6
 
 
b99bb69
 
 
 
 
94f0c4f
b99bb69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94f0c4f
 
 
 
 
 
 
 
 
 
 
 
 
b99bb69
94f0c4f
 
 
b99bb69
 
 
94f0c4f
 
 
 
 
 
 
 
 
 
23ca0a6
94f0c4f
 
 
 
 
b99bb69
94f0c4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23ca0a6
94f0c4f
 
 
 
 
b99bb69
7e7acc6
b99bb69
21eb51f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b99bb69
 
 
 
94f0c4f
b99bb69
21eb51f
b99bb69
 
21eb51f
 
 
 
94f0c4f
21eb51f
 
b99bb69
 
7e7acc6
 
 
 
 
 
 
 
a85fcba
 
 
 
 
 
 
 
 
 
 
 
 
94f0c4f
 
 
 
 
7e7acc6
 
 
21eb51f
94f0c4f
21eb51f
 
 
 
 
 
 
 
 
7e7acc6
 
 
 
94f0c4f
 
21eb51f
7e7acc6
 
a85fcba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import gradio as gr
from transformers import pipeline
import numpy as np
import pandas as pd
import re
from collections import Counter
from functools import reduce

# Speech-recognition pipeline built once at import time from the English-only
# Whisper "base" checkpoint. Timestamps are requested so every transcribed chunk
# comes back with a (start, end) pair that transcribe_live shifts onto the
# timeline of the whole recorded stream.
transcriber = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-base.en",
    return_timestamps=True,
)


# Length, in seconds, of the trailing audio window that is re-transcribed on
# each streaming step; also used as the interval at which chunks are streamed.
MAX_AUDIO_DURATION = 5


def _parse_words(words_list):
    """Split a comma-separated string into stripped, lowercased, non-empty words.

    Empty entries (e.g. from ``"like, so,"`` or an empty textbox) are dropped:
    an empty alternative in the search regex would match at every word boundary.
    If *words_list* is not a string, the user is warned and no words are returned.
    """
    if not isinstance(words_list, str):
        gr.Warning("Please enter a valid list of words to check for")
        return []
    stripped = (word.strip().lower() for word in words_list.split(","))
    return [word for word in stripped if word]


def _to_normalized_mono(y):
    """Return audio samples *y* as a peak-normalized, mono float32 array.

    A multi-dimensional (stereo) input is averaged over axis 1. Silent or empty
    chunks are returned unscaled instead of being divided by a zero peak, which
    would otherwise fill the stream with NaNs.
    """
    if y.ndim > 1:
        y = y.mean(axis=1)
    # astype copies, so the in-place division below never touches the caller's array.
    y = y.astype(np.float32)
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:
        y /= peak
    return y


def _find_word_matches(text, words):
    """Return whole-word, case-insensitive regex matches of any of *words* in *text*.

    Words are passed through ``re.escape`` so they are treated as literal text
    rather than regex syntax (an unbalanced ``(`` would otherwise raise
    ``re.error``). Matching runs on *text* itself, so match offsets index it
    directly.
    """
    if not words:
        return []
    pattern = r"\b(" + "|".join(re.escape(word) for word in words) + r")\b"
    return list(re.finditer(pattern, text, flags=re.IGNORECASE))


def _unchanged_outputs(state):
    """Return the five UI outputs, re-emitting whatever *state* already holds."""
    highlighted = state.get("highlighted_transcription", {"text": "", "entities": []})
    return (
        state,
        state.get("counts_of_words", {}),
        highlighted["text"],
        state.get("transcription_chunks", []),
        highlighted,
    )


def transcribe_live(state, words_list, new_chunk):
    """Transcribe the latest streamed audio chunk and count the requested words.

    The new chunk is appended to the accumulated stream. Only the trailing
    ``MAX_AUDIO_DURATION`` seconds are re-transcribed. Their chunk timestamps are
    shifted onto the full stream's timeline and merged with earlier chunks that
    end before the re-transcribed window starts.

    Args:
        state: Session dict. This function reads ``"stream"`` and
            ``"transcription_chunks"`` and writes back ``"stream"``,
            ``"transcription_chunks"``, ``"counts_of_words"`` and
            ``"highlighted_transcription"``.
        words_list: Comma-separated words to count (case-insensitive).
        new_chunk: ``(sampling_rate, samples)`` tuple, or ``None`` when no audio
            is available yet.

    Returns:
        A 5-tuple: new state, ``{word: count}``, full transcription text, merged
        timestamped chunks, and a highlighted-text value
        (``{"text": ..., "entities": [...]}``) marking each match as ``FILLER``.
        When there is no chunk or transcription fails, the previous outputs are
        re-emitted so the UI is not cleared.
    """
    words_to_check_for = _parse_words(words_list)

    if new_chunk is None:
        gr.Info("You can start transcribing by clicking on the Record button")
        return _unchanged_outputs(state)

    sr, y = new_chunk
    y = _to_normalized_mono(y)

    stream = state.get("stream", None)
    stream = y if stream is None else np.concatenate([stream, y])

    duration_of_the_stream = len(stream) / sr
    print(f"Duration of the stream: {duration_of_the_stream}")

    # Only re-transcribe the last MAX_AUDIO_DURATION seconds of the stream;
    # slicing past the start simply yields the whole (shorter) stream.
    window = stream[-int(sr * MAX_AUDIO_DURATION):]

    # Time (in seconds, on the full stream's timeline) at which the window starts.
    start_of_the_stream = duration_of_the_stream - len(window) / sr

    # Chunks from the new transcription are kept only if they fall entirely
    # inside [start_time_cutoff, end_time_cutoff]. These are computed before
    # remapping because the remap uses end_time_cutoff as a fallback.
    start_time_cutoff = start_of_the_stream
    end_time_cutoff = start_of_the_stream + MAX_AUDIO_DURATION

    print(f"Start time cutoff: {start_time_cutoff}")
    print(f"End time cutoff: {end_time_cutoff}")
    print(f"Start of the stream: {start_of_the_stream}")

    try:
        new_transcription = transcriber({"sampling_rate": sr, "raw": window})
    except Exception as e:
        # gr.Error is an exception class and does nothing unless raised;
        # gr.Warning shows the message and lets the stream keep running.
        gr.Warning(f"Transcription failed. Error: {e}")
        print(f"Transcription failed. Error: {e}")
        # Keep the new audio so it is re-transcribed on the next step.
        return _unchanged_outputs({**state, "stream": stream})

    # We get something like: 'chunks': [{'timestamp': (0.0, 10.0), 'text': " I'm going to go."}]}
    # Shift each chunk's timestamps from window-relative to stream-relative.
    new_chunks_remapped = []
    for chunk in new_transcription["chunks"]:
        chunk_start, chunk_end = chunk["timestamp"]
        new_chunks_remapped.append(
            {
                "timestamp": (
                    chunk_start + start_of_the_stream,
                    # A chunk may come back without an end time (None); assume
                    # it runs to the end of the window.
                    chunk_end + start_of_the_stream
                    if chunk_end is not None
                    else end_time_cutoff,
                ),
                "text": chunk["text"],
            }
        )

    print(f"Before filtering: {new_chunks_remapped}")

    new_chunks_remapped = [
        chunk
        for chunk in new_chunks_remapped
        if chunk["timestamp"][0] >= start_time_cutoff
        and chunk["timestamp"][1] <= end_time_cutoff
    ]

    print(f"After filtering: {new_chunks_remapped}")

    # Merge: keep previous chunks that end before the new window starts, then
    # append the freshly transcribed ones.
    previous_chunks = state.get("transcription_chunks", [])
    merged_chunks = [
        chunk for chunk in previous_chunks if chunk["timestamp"][1] <= start_time_cutoff
    ] + new_chunks_remapped

    # Each chunk's text is prefixed by one space (linear-time join).
    full_transcription_text = "".join(" " + chunk["text"] for chunk in merged_chunks)

    matches = _find_word_matches(full_transcription_text, words_to_check_for)

    counter = Counter(match.group(0).lower() for match in matches)
    new_counts_of_words = {word: counter.get(word, 0) for word in words_to_check_for}

    new_highlighted_transcription = {
        "text": full_transcription_text,
        "entities": [
            {
                "entity": "FILLER",
                "start": match.start(),
                "end": match.end(),
            }
            for match in matches
        ],
    }

    new_state = {
        "stream": stream,
        "transcription_chunks": merged_chunks,
        "counts_of_words": new_counts_of_words,
        "highlighted_transcription": new_highlighted_transcription,
    }

    return (
        new_state,
        new_counts_of_words,
        full_transcription_text,
        merged_chunks,
        new_highlighted_transcription,
    )


with gr.Blocks() as demo:
    # Per-session state passed to, and returned by, transcribe_live on every
    # streamed audio chunk.
    state = gr.State(
        value={
            "stream": None,
            "transcription_chunks": [],
            "counts_of_words": {},
        }
    )

    gr.Markdown(
        """
        # GrammASRian

        This app transcribes your speech in real-time and counts the number of filler words you use.

        The intended use case is to help you become more aware of the filler words you use, so you can reduce them and improve your speech.

        It uses the OpenAI Whisper model for transcription in a streaming configuration.
        """
    )

    filler_words = gr.Textbox(
        label="List of filler words",
        value="like, so, you know",
        info="Enter a comma-separated list of words to check for",
    )
    recording = gr.Audio(streaming=True, label="Recording")

    word_counts = gr.JSON(label="Filler words count", value={})
    # Hidden outputs: the plain transcription text and the raw timestamped
    # chunks (presumably kept for debugging).
    transcription = gr.Textbox(label="Transcription", value="", visible=False)
    chunks = gr.JSON(label="Chunks", value=[], visible=False)

    highlighted_transcription = gr.HighlightedText(
        label="Transcription",
        value={
            "text": "",
            "entities": [],
        },
        color_map={"FILLER": "red"},
    )

    # Every MAX_AUDIO_DURATION seconds the recorded audio is sent to
    # transcribe_live; its five return values map onto `outputs` in order.
    recording.stream(
        transcribe_live,
        inputs=[state, filler_words, recording],
        outputs=[state, word_counts, transcription, chunks, highlighted_transcription],
        stream_every=MAX_AUDIO_DURATION,
        time_limit=-1,
    )

demo.launch()