rayan-saleh commited on
Commit
04878d0
·
0 Parent(s):

First model version

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
LICENSE.md ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ ==============
3
+
4
+ _Version 2.0, January 2004_
5
+ _&lt;<http://www.apache.org/licenses/>&gt;_
6
+
7
+ ### Terms and Conditions for use, reproduction, and distribution
8
+
9
+ #### 1. Definitions
10
+
11
+ “License” shall mean the terms and conditions for use, reproduction, and
12
+ distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ “Licensor” shall mean the copyright owner or entity authorized by the copyright
15
+ owner that is granting the License.
16
+
17
+ “Legal Entity” shall mean the union of the acting entity and all other entities
18
+ that control, are controlled by, or are under common control with that entity.
19
+ For the purposes of this definition, “control” means **(i)** the power, direct or
20
+ indirect, to cause the direction or management of such entity, whether by
21
+ contract or otherwise, or **(ii)** ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or **(iii)** beneficial ownership of such entity.
23
+
24
+ “You” (or “Your”) shall mean an individual or Legal Entity exercising
25
+ permissions granted by this License.
26
+
27
+ “Source” form shall mean the preferred form for making modifications, including
28
+ but not limited to software source code, documentation source, and configuration
29
+ files.
30
+
31
+ “Object” form shall mean any form resulting from mechanical transformation or
32
+ translation of a Source form, including but not limited to compiled object code,
33
+ generated documentation, and conversions to other media types.
34
+
35
+ “Work” shall mean the work of authorship, whether in Source or Object form, made
36
+ available under the License, as indicated by a copyright notice that is included
37
+ in or attached to the work (an example is provided in the Appendix below).
38
+
39
+ “Derivative Works” shall mean any work, whether in Source or Object form, that
40
+ is based on (or derived from) the Work and for which the editorial revisions,
41
+ annotations, elaborations, or other modifications represent, as a whole, an
42
+ original work of authorship. For the purposes of this License, Derivative Works
43
+ shall not include works that remain separable from, or merely link (or bind by
44
+ name) to the interfaces of, the Work and Derivative Works thereof.
45
+
46
+ “Contribution” shall mean any work of authorship, including the original version
47
+ of the Work and any modifications or additions to that Work or Derivative Works
48
+ thereof, that is intentionally submitted to Licensor for inclusion in the Work
49
+ by the copyright owner or by an individual or Legal Entity authorized to submit
50
+ on behalf of the copyright owner. For the purposes of this definition,
51
+ “submitted” means any form of electronic, verbal, or written communication sent
52
+ to the Licensor or its representatives, including but not limited to
53
+ communication on electronic mailing lists, source code control systems, and
54
+ issue tracking systems that are managed by, or on behalf of, the Licensor for
55
+ the purpose of discussing and improving the Work, but excluding communication
56
+ that is conspicuously marked or otherwise designated in writing by the copyright
57
+ owner as “Not a Contribution.”
58
+
59
+ “Contributor” shall mean Licensor and any individual or Legal Entity on behalf
60
+ of whom a Contribution has been received by Licensor and subsequently
61
+ incorporated within the Work.
62
+
63
+ #### 2. Grant of Copyright License
64
+
65
+ Subject to the terms and conditions of this License, each Contributor hereby
66
+ grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
67
+ irrevocable copyright license to reproduce, prepare Derivative Works of,
68
+ publicly display, publicly perform, sublicense, and distribute the Work and such
69
+ Derivative Works in Source or Object form.
70
+
71
+ #### 3. Grant of Patent License
72
+
73
+ Subject to the terms and conditions of this License, each Contributor hereby
74
+ grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
75
+ irrevocable (except as stated in this section) patent license to make, have
76
+ made, use, offer to sell, sell, import, and otherwise transfer the Work, where
77
+ such license applies only to those patent claims licensable by such Contributor
78
+ that are necessarily infringed by their Contribution(s) alone or by combination
79
+ of their Contribution(s) with the Work to which such Contribution(s) was
80
+ submitted. If You institute patent litigation against any entity (including a
81
+ cross-claim or counterclaim in a lawsuit) alleging that the Work or a
82
+ Contribution incorporated within the Work constitutes direct or contributory
83
+ patent infringement, then any patent licenses granted to You under this License
84
+ for that Work shall terminate as of the date such litigation is filed.
85
+
86
+ #### 4. Redistribution
87
+
88
+ You may reproduce and distribute copies of the Work or Derivative Works thereof
89
+ in any medium, with or without modifications, and in Source or Object form,
90
+ provided that You meet the following conditions:
91
+
92
+ * **(a)** You must give any other recipients of the Work or Derivative Works a copy of
93
+ this License; and
94
+ * **(b)** You must cause any modified files to carry prominent notices stating that You
95
+ changed the files; and
96
+ * **(c)** You must retain, in the Source form of any Derivative Works that You distribute,
97
+ all copyright, patent, trademark, and attribution notices from the Source form
98
+ of the Work, excluding those notices that do not pertain to any part of the
99
+ Derivative Works; and
100
+ * **(d)** If the Work includes a “NOTICE” text file as part of its distribution, then any
101
+ Derivative Works that You distribute must include a readable copy of the
102
+ attribution notices contained within such NOTICE file, excluding those notices
103
+ that do not pertain to any part of the Derivative Works, in at least one of the
104
+ following places: within a NOTICE text file distributed as part of the
105
+ Derivative Works; within the Source form or documentation, if provided along
106
+ with the Derivative Works; or, within a display generated by the Derivative
107
+ Works, if and wherever such third-party notices normally appear. The contents of
108
+ the NOTICE file are for informational purposes only and do not modify the
109
+ License. You may add Your own attribution notices within Derivative Works that
110
+ You distribute, alongside or as an addendum to the NOTICE text from the Work,
111
+ provided that such additional attribution notices cannot be construed as
112
+ modifying the License.
113
+
114
+ You may add Your own copyright statement to Your modifications and may provide
115
+ additional or different license terms and conditions for use, reproduction, or
116
+ distribution of Your modifications, or for any such Derivative Works as a whole,
117
+ provided Your use, reproduction, and distribution of the Work otherwise complies
118
+ with the conditions stated in this License.
119
+
120
+ #### 5. Submission of Contributions
121
+
122
+ Unless You explicitly state otherwise, any Contribution intentionally submitted
123
+ for inclusion in the Work by You to the Licensor shall be under the terms and
124
+ conditions of this License, without any additional terms or conditions.
125
+ Notwithstanding the above, nothing herein shall supersede or modify the terms of
126
+ any separate license agreement you may have executed with Licensor regarding
127
+ such Contributions.
128
+
129
+ #### 6. Trademarks
130
+
131
+ This License does not grant permission to use the trade names, trademarks,
132
+ service marks, or product names of the Licensor, except as required for
133
+ reasonable and customary use in describing the origin of the Work and
134
+ reproducing the content of the NOTICE file.
135
+
136
+ #### 7. Disclaimer of Warranty
137
+
138
+ Unless required by applicable law or agreed to in writing, Licensor provides the
139
+ Work (and each Contributor provides its Contributions) on an “AS IS” BASIS,
140
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
141
+ including, without limitation, any warranties or conditions of TITLE,
142
+ NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
143
+ solely responsible for determining the appropriateness of using or
144
+ redistributing the Work and assume any risks associated with Your exercise of
145
+ permissions under this License.
146
+
147
+ #### 8. Limitation of Liability
148
+
149
+ In no event and under no legal theory, whether in tort (including negligence),
150
+ contract, or otherwise, unless required by applicable law (such as deliberate
151
+ and grossly negligent acts) or agreed to in writing, shall any Contributor be
152
+ liable to You for damages, including any direct, indirect, special, incidental,
153
+ or consequential damages of any character arising as a result of this License or
154
+ out of the use or inability to use the Work (including but not limited to
155
+ damages for loss of goodwill, work stoppage, computer failure or malfunction, or
156
+ any and all other commercial damages or losses), even if such Contributor has
157
+ been advised of the possibility of such damages.
158
+
159
+ #### 9. Accepting Warranty or Additional Liability
160
+
161
+ While redistributing the Work or Derivative Works thereof, You may choose to
162
+ offer, and charge a fee for, acceptance of support, warranty, indemnity, or
163
+ other liability obligations and/or rights consistent with this License. However,
164
+ in accepting such obligations, You may act only on Your own behalf and on Your
165
+ sole responsibility, not on behalf of any other Contributor, and only if You
166
+ agree to indemnify, defend, and hold each Contributor harmless for any liability
167
+ incurred by, or claims asserted against, such Contributor by reason of your
168
+ accepting any such warranty or additional liability.
169
+
170
+ _END OF TERMS AND CONDITIONS_
171
+
172
+ ### APPENDIX: How to apply the Apache License to your work
173
+
174
+ To apply the Apache License to your work, attach the following boilerplate
175
+ notice, with the fields enclosed by brackets `[]` replaced with your own
176
+ identifying information. (Don't include the brackets!) The text should be
177
+ enclosed in the appropriate comment syntax for the file format. We also
178
+ recommend that a file or class name and description of purpose be included on
179
+ the same “printed page” as the copyright notice for easier identification within
180
+ third-party archives.
181
+
182
+ Copyright [yyyy] [name of copyright owner]
183
+
184
+ Licensed under the Apache License, Version 2.0 (the "License");
185
+ you may not use this file except in compliance with the License.
186
+ You may obtain a copy of the License at
187
+
188
+ http://www.apache.org/licenses/LICENSE-2.0
189
+
190
+ Unless required by applicable law or agreed to in writing, software
191
+ distributed under the License is distributed on an "AS IS" BASIS,
192
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
193
+ See the License for the specific language governing permissions and
194
+ limitations under the License.
195
+
README.md ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Whisper Webui
3
+ emoji: ⚡
4
+ colorFrom: pink
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.3.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ # Running Locally
16
+
17
+ To run this program locally, first install Python 3.9+ and Git. Then install Pytorch 10.1+ and all the other dependencies:
18
+ ```
19
+ pip install -r requirements.txt
20
+ ```
21
+
22
+ You can find detailed instructions for how to install this on Windows 10/11 [here (PDF)](docs/windows/install_win10_win11.pdf).
23
+
24
+ Finally, run the full version (no audio length restrictions) of the app with parallel CPU/GPU enabled:
25
+ ```
26
+ python app.py --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
27
+ ```
28
+
29
+ You can also run the CLI interface, which is similar to Whisper's own CLI but also supports the following additional arguments:
30
+ ```
31
+ python cli.py \
32
+ [--vad {none,silero-vad,silero-vad-skip-gaps,silero-vad-expand-into-gaps,periodic-vad}] \
33
+ [--vad_merge_window VAD_MERGE_WINDOW] \
34
+ [--vad_max_merge_size VAD_MAX_MERGE_SIZE] \
35
+ [--vad_padding VAD_PADDING] \
36
+ [--vad_prompt_window VAD_PROMPT_WINDOW]
37
+ [--vad_cpu_cores NUMBER_OF_CORES]
38
+ [--vad_parallel_devices COMMA_DELIMITED_DEVICES]
39
+ [--auto_parallel BOOLEAN]
40
+ ```
41
+ In addition, you may also use URL's in addition to file paths as input.
42
+ ```
43
+ python cli.py --model large --vad silero-vad --language Japanese "https://www.youtube.com/watch?v=4cICErqqRSM"
44
+ ```
45
+
46
+ ## Google Colab
47
+
48
+ You can also run this Web UI directly on [Google Colab](https://colab.research.google.com/drive/1qeTSvi7Bt_5RMm88ipW4fkcsMOKlDDss?usp=sharing), if you haven't got a GPU powerful enough to run the larger models.
49
+
50
+ See the [colab documentation](docs/colab.md) for more information.
51
+
52
+ ## Parallel Execution
53
+
54
+ You can also run both the Web-UI or the CLI on multiple GPUs in parallel, using the `vad_parallel_devices` option. This takes a comma-delimited list of
55
+ device IDs (0, 1, etc.) that Whisper should be distributed to and run on concurrently:
56
+ ```
57
+ python cli.py --model large --vad silero-vad --language Japanese \
58
+ --vad_parallel_devices 0,1 "https://www.youtube.com/watch?v=4cICErqqRSM"
59
+ ```
60
+
61
+ Note that this requires a VAD to function properly, otherwise only the first GPU will be used. Though you could use `period-vad` to avoid taking the hit
62
+ of running Silero-Vad, at a slight cost to accuracy.
63
+
64
+ This is achieved by creating N child processes (where N is the number of selected devices), where Whisper is run concurrently. In `app.py`, you can also
65
+ set the `vad_process_timeout` option. This configures the number of seconds until a process is killed due to inactivity, freeing RAM and video memory.
66
+ The default value is 30 minutes.
67
+
68
+ ```
69
+ python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600
70
+ ```
71
+
72
+ To execute the Silero VAD itself in parallel, use the `vad_cpu_cores` option:
73
+ ```
74
+ python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600 --vad_cpu_cores 4
75
+ ```
76
+
77
+ You may also use `vad_process_timeout` with a single device (`--vad_parallel_devices 0`), if you prefer to always free video memory after a period of time.
78
+
79
+ ### Auto Parallel
80
+
81
+ You can also set `auto_parallel` to `True`. This will set `vad_parallel_devices` to use all the GPU devices on the system, and `vad_cpu_cores` to be equal to the number of
82
+ cores (up to 8):
83
+ ```
84
+ python app.py --input_audio_max_duration -1 --auto_parallel True
85
+ ```
86
+
87
+ ### Multiple Files
88
+
89
+ You can upload multiple files either through the "Upload files" option, or as a playlist on YouTube.
90
+ Each audio file will then be processed in turn, and the resulting SRT/VTT/Transcript will be made available in the "Download" section.
91
+ When more than one file is processed, the UI will also generate a "All_Output" zip file containing all the text output files.
92
+
93
+ # Docker
94
+
95
+ To run it in Docker, first install Docker and optionally the NVIDIA Container Toolkit in order to use the GPU.
96
+ Then either use the GitLab hosted container below, or check out this repository and build an image:
97
+ ```
98
+ sudo docker build -t whisper-webui:1 .
99
+ ```
100
+
101
+ You can then start the WebUI with GPU support like so:
102
+ ```
103
+ sudo docker run -d --gpus=all -p 7860:7860 whisper-webui:1
104
+ ```
105
+
106
+ Leave out "--gpus=all" if you don't have access to a GPU with enough memory, and are fine with running it on the CPU only:
107
+ ```
108
+ sudo docker run -d -p 7860:7860 whisper-webui:1
109
+ ```
110
+
111
+ # GitLab Docker Registry
112
+
113
+ This Docker container is also hosted on GitLab:
114
+
115
+ ```
116
+ sudo docker run -d --gpus=all -p 7860:7860 registry.gitlab.com/aadnk/whisper-webui:latest
117
+ ```
118
+
119
+ ## Custom Arguments
120
+
121
+ You can also pass custom arguments to `app.py` in the Docker container, for instance to be able to use all the GPUs in parallel:
122
+ ```
123
+ sudo docker run -d --gpus all -p 7860:7860 \
124
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
125
+ --restart=on-failure:15 registry.gitlab.com/aadnk/whisper-webui:latest \
126
+ app.py --input_audio_max_duration -1 --server_name 0.0.0.0 --auto_parallel True \
127
+ --default_vad silero-vad --default_model_name large
128
+ ```
129
+
130
+ You can also call `cli.py` the same way:
131
+ ```
132
+ sudo docker run --gpus all \
133
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
134
+ --mount type=bind,source=${PWD},target=/app/data \
135
+ registry.gitlab.com/aadnk/whisper-webui:latest \
136
+ cli.py --model large --auto_parallel True --vad silero-vad \
137
+ --output_dir /app/data /app/data/YOUR-FILE-HERE.mp4
138
+ ```
139
+
140
+ ## Caching
141
+
142
+ Note that the models themselves are currently not included in the Docker images, and will be downloaded on the demand.
143
+ To avoid this, bind the directory /root/.cache/whisper to some directory on the host (for instance /home/administrator/.cache/whisper), where you can (optionally)
144
+ prepopulate the directory with the different Whisper models.
145
+ ```
146
+ sudo docker run -d --gpus=all -p 7860:7860 \
147
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
148
+ registry.gitlab.com/aadnk/whisper-webui:latest
149
+ ```
app-local.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Run the app with no audio file restrictions
2
+ from app import create_ui
3
+ create_ui(-1)
app-network.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Run the app with no audio file restrictions, and make it available on the network
2
+ from app import create_ui
3
+ create_ui(-1, server_name="0.0.0.0")
app-shared.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Run the app with no audio file restrictions
2
+ from app import create_ui
3
+ create_ui(-1, share=True)
app.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import math
3
+ from typing import Iterator
4
+ import argparse
5
+
6
+ from io import StringIO
7
+ import os
8
+ import pathlib
9
+ import tempfile
10
+ import zipfile
11
+ import numpy as np
12
+
13
+ import torch
14
+ from src.modelCache import ModelCache
15
+ from src.source import get_audio_source_collection
16
+ from src.vadParallel import ParallelContext, ParallelTranscription
17
+
18
+ # External programs
19
+ import ffmpeg
20
+
21
+ # UI
22
+ import gradio as gr
23
+
24
+ from src.download import ExceededMaximumDuration, download_url
25
+ from src.utils import slugify, write_srt, write_vtt
26
+ from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
27
+ from src.whisperContainer import WhisperContainer
28
+
29
+ # Limitations (set to -1 to disable)
30
+ DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
31
+
32
+ # Whether or not to automatically delete all uploaded files, to save disk space
33
+ DELETE_UPLOADED_FILES = True
34
+
35
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
36
+ MAX_FILE_PREFIX_LENGTH = 17
37
+
38
+ # Limit auto_parallel to a certain number of CPUs (specify vad_cpu_cores to get a higher number)
39
+ MAX_AUTO_CPU_CORES = 8
40
+
41
+ LANGUAGES = [
42
+ "English", "Chinese", "German", "Spanish", "Russian", "Korean",
43
+ "French", "Japanese", "Portuguese", "Turkish", "Polish", "Catalan",
44
+ "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi",
45
+ "Finnish", "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay",
46
+ "Czech", "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
47
+ "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin",
48
+ "Maori", "Malayalam", "Welsh", "Slovak", "Telugu", "Persian",
49
+ "Latvian", "Bengali", "Serbian", "Azerbaijani", "Slovenian",
50
+ "Kannada", "Estonian", "Macedonian", "Breton", "Basque", "Icelandic",
51
+ "Armenian", "Nepali", "Mongolian", "Bosnian", "Kazakh", "Albanian",
52
+ "Swahili", "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
53
+ "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan", "Georgian",
54
+ "Belarusian", "Tajik", "Sindhi", "Gujarati", "Amharic", "Yiddish",
55
+ "Lao", "Uzbek", "Faroese", "Haitian Creole", "Pashto", "Turkmen",
56
+ "Nynorsk", "Maltese", "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan",
57
+ "Tagalog", "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
58
+ "Hausa", "Bashkir", "Javanese", "Sundanese"
59
+ ]
60
+
61
+ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
62
+
63
+ class WhisperTranscriber:
64
+ def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None,
65
+ vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES, output_dir: str = None):
66
+ self.model_cache = ModelCache()
67
+ self.parallel_device_list = None
68
+ self.gpu_parallel_context = None
69
+ self.cpu_parallel_context = None
70
+ self.vad_process_timeout = vad_process_timeout
71
+ self.vad_cpu_cores = vad_cpu_cores
72
+
73
+ self.vad_model = None
74
+ self.inputAudioMaxDuration = input_audio_max_duration
75
+ self.deleteUploadedFiles = delete_uploaded_files
76
+ self.output_dir = output_dir
77
+
78
+ def set_parallel_devices(self, vad_parallel_devices: str):
79
+ self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
80
+
81
+ def set_auto_parallel(self, auto_parallel: bool):
82
+ if auto_parallel:
83
+ if torch.cuda.is_available():
84
+ self.parallel_device_list = [ str(gpu_id) for gpu_id in range(torch.cuda.device_count())]
85
+
86
+ self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
87
+ print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
88
+
89
+ # Entry function for the simple tab
90
+ def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
91
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
92
+
93
+ # Entry function for the full tab
94
+ def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
95
+ initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
96
+ condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
97
+ compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
98
+
99
+ # Handle temperature_increment_on_fallback
100
+ if temperature_increment_on_fallback is not None:
101
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
102
+ else:
103
+ temperature = [temperature]
104
+
105
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
106
+ initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
107
+ condition_on_previous_text=condition_on_previous_text, fp16=fp16,
108
+ compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold)
109
+
110
+ def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, **decodeOptions: dict):
111
+ try:
112
+ sources = self.__get_source(urlData, multipleFiles, microphoneData)
113
+
114
+ try:
115
+ selectedLanguage = languageName.lower() if len(languageName) > 0 else None
116
+ selectedModel = modelName if modelName is not None else "base"
117
+
118
+ model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)
119
+
120
+ # Result
121
+ download = []
122
+ zip_file_lookup = {}
123
+ text = ""
124
+ vtt = ""
125
+
126
+ # Write result
127
+ downloadDirectory = tempfile.mkdtemp()
128
+ source_index = 0
129
+
130
+ outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
131
+
132
+ # Execute whisper
133
+ for source in sources:
134
+ source_prefix = ""
135
+
136
+ if (len(sources) > 1):
137
+ # Prefix (minimum 2 digits)
138
+ source_index += 1
139
+ source_prefix = str(source_index).zfill(2) + "_"
140
+ print("Transcribing ", source.source_path)
141
+
142
+ # Transcribe
143
+ result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, **decodeOptions)
144
+ filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
145
+
146
+ source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
147
+
148
+ if len(sources) > 1:
149
+ # Add new line separators
150
+ if (len(source_text) > 0):
151
+ source_text += os.linesep + os.linesep
152
+ if (len(source_vtt) > 0):
153
+ source_vtt += os.linesep + os.linesep
154
+
155
+ # Append file name to source text too
156
+ source_text = source.get_full_name() + ":" + os.linesep + source_text
157
+ source_vtt = source.get_full_name() + ":" + os.linesep + source_vtt
158
+
159
+ # Add to result
160
+ download.extend(source_download)
161
+ text += source_text
162
+ vtt += source_vtt
163
+
164
+ if (len(sources) > 1):
165
+ # Zip files support at least 260 characters, but we'll play it safe and use 200
166
+ zipFilePrefix = slugify(source_prefix + source.get_short_name(max_length=200), allow_unicode=True)
167
+
168
+ # File names in ZIP file can be longer
169
+ for source_download_file in source_download:
170
+ # Get file postfix (after last -)
171
+ filePostfix = os.path.basename(source_download_file).split("-")[-1]
172
+ zip_file_name = zipFilePrefix + "-" + filePostfix
173
+ zip_file_lookup[source_download_file] = zip_file_name
174
+
175
+ # Create zip file from all sources
176
+ if len(sources) > 1:
177
+ downloadAllPath = os.path.join(downloadDirectory, "All_Output-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
178
+
179
+ with zipfile.ZipFile(downloadAllPath, 'w', zipfile.ZIP_DEFLATED) as zip:
180
+ for download_file in download:
181
+ # Get file name from lookup
182
+ zip_file_name = zip_file_lookup.get(download_file, os.path.basename(download_file))
183
+ zip.write(download_file, arcname=zip_file_name)
184
+
185
+ download.insert(0, downloadAllPath)
186
+
187
+ return download, text, vtt
188
+
189
+ finally:
190
+ # Cleanup source
191
+ if self.deleteUploadedFiles:
192
+ for source in sources:
193
+ print("Deleting source file " + source.source_path)
194
+
195
+ try:
196
+ os.remove(source.source_path)
197
+ except Exception as e:
198
+ # Ignore error - it's just a cleanup
199
+ print("Error deleting source file " + source.source_path + ": " + str(e))
200
+
201
+ except ExceededMaximumDuration as e:
202
+ return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
203
+
204
+ def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
205
+ vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
206
+
207
+ initial_prompt = decodeOptions.pop('initial_prompt', None)
208
+
209
+ if ('task' in decodeOptions):
210
+ task = decodeOptions.pop('task')
211
+
212
+ # Callable for processing an audio file
213
+ whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
214
+
215
+ # The results
216
+ if (vad == 'silero-vad'):
217
+ # Silero VAD where non-speech gaps are transcribed
218
+ process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
219
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps)
220
+ elif (vad == 'silero-vad-skip-gaps'):
221
+ # Silero VAD where non-speech gaps are simply ignored
222
+ skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
223
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps)
224
+ elif (vad == 'silero-vad-expand-into-gaps'):
225
+ # Use Silero VAD where speech-segments are expanded into non-speech gaps
226
+ expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
227
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps)
228
+ elif (vad == 'periodic-vad'):
229
+ # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
230
+ # it may create a break in the middle of a sentence, causing some artifacts.
231
+ periodic_vad = VadPeriodicTranscription()
232
+ period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
233
+ result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
234
+
235
+ else:
236
+ if (self._has_parallel_devices()):
237
+ # Use a simple period transcription instead, as we need to use the parallel context
238
+ periodic_vad = VadPeriodicTranscription()
239
+ period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
240
+
241
+ result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
242
+ else:
243
+ # Default VAD
244
+ result = whisperCallable.invoke(audio_path, 0, None, None)
245
+
246
+ return result
247
+
248
+ def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig):
249
+ if (not self._has_parallel_devices()):
250
+ # No parallel devices, so just run the VAD and Whisper in sequence
251
+ return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
252
+
253
+ gpu_devices = self.parallel_device_list
254
+
255
+ if (gpu_devices is None or len(gpu_devices) == 0):
256
+ # No GPU devices specified, pass the current environment variable to the first GPU process. This may be NULL.
257
+ gpu_devices = [os.environ.get("CUDA_VISIBLE_DEVICES", None)]
258
+
259
+ # Create parallel context if needed
260
+ if (self.gpu_parallel_context is None):
261
+ # Create a context wih processes and automatically clear the pool after 1 hour of inactivity
262
+ self.gpu_parallel_context = ParallelContext(num_processes=len(gpu_devices), auto_cleanup_timeout_seconds=self.vad_process_timeout)
263
+ # We also need a CPU context for the VAD
264
+ if (self.cpu_parallel_context is None):
265
+ self.cpu_parallel_context = ParallelContext(num_processes=self.vad_cpu_cores, auto_cleanup_timeout_seconds=self.vad_process_timeout)
266
+
267
+ parallel_vad = ParallelTranscription()
268
+ return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
269
+ config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
270
+ cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context)
271
+
272
+ def _has_parallel_devices(self):
273
+ return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
274
+
275
+ def _concat_prompt(self, prompt1, prompt2):
276
+ if (prompt1 is None):
277
+ return prompt2
278
+ elif (prompt2 is None):
279
+ return prompt1
280
+ else:
281
+ return prompt1 + " " + prompt2
282
+
283
+ def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
284
+ # Use Silero VAD
285
+ if (self.vad_model is None):
286
+ self.vad_model = VadSileroTranscription()
287
+
288
+ config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
289
+ max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
290
+ segment_padding_left=vadPadding, segment_padding_right=vadPadding,
291
+ max_prompt_window=vadPromptWindow)
292
+
293
+ return config
294
+
295
+ def write_result(self, result: dict, source_name: str, output_dir: str):
296
+ if not os.path.exists(output_dir):
297
+ os.makedirs(output_dir)
298
+
299
+ text = result["text"]
300
+ language = result["language"]
301
+ languageMaxLineWidth = self.__get_max_line_width(language)
302
+
303
+ print("Max line width " + str(languageMaxLineWidth))
304
+ vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
305
+ srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
306
+
307
+ output_files = []
308
+ output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
309
+ output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
310
+ output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
311
+
312
+ return output_files, text, vtt
313
+
314
+ def clear_cache(self):
315
+ self.model_cache.clear()
316
+ self.vad_model = None
317
+
318
+ def __get_source(self, urlData, multipleFiles, microphoneData):
319
+ return get_audio_source_collection(urlData, multipleFiles, microphoneData, self.inputAudioMaxDuration)
320
+
321
+ def __get_max_line_width(self, language: str) -> int:
322
+ if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
323
+ # Chinese characters and kana are wider, so limit line length to 40 characters
324
+ return 40
325
+ else:
326
+ # TODO: Add more languages
327
+ # 80 latin characters should fit on a 1080p/720p screen
328
+ return 80
329
+
330
+ def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
331
+ segmentStream = StringIO()
332
+
333
+ if format == 'vtt':
334
+ write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
335
+ elif format == 'srt':
336
+ write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
337
+ else:
338
+ raise Exception("Unknown format " + format)
339
+
340
+ segmentStream.seek(0)
341
+ return segmentStream.read()
342
+
343
+ def __create_file(self, text: str, directory: str, fileName: str) -> str:
344
+ # Write the text to a file
345
+ with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
346
+ file.write(text)
347
+
348
+ return file.name
349
+
350
+ def close(self):
351
+ print("Closing parallel contexts")
352
+ self.clear_cache()
353
+
354
+ if (self.gpu_parallel_context is not None):
355
+ self.gpu_parallel_context.close()
356
+ if (self.cpu_parallel_context is not None):
357
+ self.cpu_parallel_context.close()
358
+
359
+
360
+ def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
361
+ default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None,
362
+ vad_process_timeout: float = None, vad_cpu_cores: int = 1, auto_parallel: bool = False,
363
+ output_dir: str = None):
364
+ ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores, DELETE_UPLOADED_FILES, output_dir)
365
+
366
+ # Specify a list of devices to use for parallel processing
367
+ ui.set_parallel_devices(vad_parallel_devices)
368
+ ui.set_auto_parallel(auto_parallel)
369
+
370
+ ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
371
+ ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
372
+ ui_description += " as well as speech translation and language identification. "
373
+
374
+ ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
375
+
376
+ if input_audio_max_duration > 0:
377
+ ui_description += "\n\n" + "Max audio file length: " + str(input_audio_max_duration) + " s"
378
+
379
+ ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)"
380
+
381
+ simple_inputs = lambda : [
382
+ gr.Dropdown(choices=WHISPER_MODELS, value=default_model_name, label="Model"),
383
+ gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
384
+ gr.Text(label="URL (YouTube, etc.)"),
385
+ gr.File(label="Upload Files", file_count="multiple"),
386
+ gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
387
+ gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
388
+ gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=default_vad, label="VAD"),
389
+ gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
390
+ gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
391
+ gr.Number(label="VAD - Padding (s)", precision=None, value=1),
392
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
393
+ ]
394
+
395
+ simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple, description=ui_description, article=ui_article, inputs=simple_inputs(), outputs=[
396
+ gr.File(label="Download"),
397
+ gr.Text(label="Transcription"),
398
+ gr.Text(label="Segments")
399
+ ])
400
+
401
+ full_description = ui_description + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
402
+
403
+ full_transcribe = gr.Interface(fn=ui.transcribe_webui_full, description=full_description, article=ui_article, inputs=[
404
+ *simple_inputs(),
405
+ gr.TextArea(label="Initial Prompt"),
406
+ gr.Number(label="Temperature", value=0),
407
+ gr.Number(label="Best Of - Non-zero temperature", value=5, precision=0),
408
+ gr.Number(label="Beam Size - Zero temperature", value=5, precision=0),
409
+ gr.Number(label="Patience - Zero temperature", value=None),
410
+ gr.Number(label="Length Penalty - Any temperature", value=None),
411
+ gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value="-1"),
412
+ gr.Checkbox(label="Condition on previous text", value=True),
413
+ gr.Checkbox(label="FP16", value=True),
414
+ gr.Number(label="Temperature increment on fallback", value=0.2),
415
+ gr.Number(label="Compression ratio threshold", value=2.4),
416
+ gr.Number(label="Logprob threshold", value=-1.0),
417
+ gr.Number(label="No speech threshold", value=0.6)
418
+ ], outputs=[
419
+ gr.File(label="Download"),
420
+ gr.Text(label="Transcription"),
421
+ gr.Text(label="Segments")
422
+ ])
423
+
424
+ demo = gr.TabbedInterface([simple_transcribe, full_transcribe], tab_names=["Simple", "Full"])
425
+
426
+ demo.launch(share=share, server_name=server_name, server_port=server_port)
427
+
428
+ # Clean up
429
+ ui.close()
430
+
431
+ if __name__ == '__main__':
432
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
433
+ parser.add_argument("--input_audio_max_duration", type=int, default=DEFAULT_INPUT_AUDIO_MAX_DURATION, help="Maximum audio file length in seconds, or -1 for no limit.")
434
+ parser.add_argument("--share", type=bool, default=False, help="True to share the app on HuggingFace.")
435
+ parser.add_argument("--server_name", type=str, default=None, help="The host or IP to bind to. If None, bind to localhost.")
436
+ parser.add_argument("--server_port", type=int, default=7860, help="The port to bind to.")
437
+ parser.add_argument("--default_model_name", type=str, choices=WHISPER_MODELS, default="medium", help="The default model name.")
438
+ parser.add_argument("--default_vad", type=str, default="silero-vad", help="The default VAD.")
439
+ parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
440
+ parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
441
+ parser.add_argument("--vad_process_timeout", type=float, default="1800", help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
442
+ parser.add_argument("--auto_parallel", type=bool, default=False, help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")
443
+ parser.add_argument("--output_dir", "-o", type=str, default=None, help="directory to save the outputs")
444
+
445
+ args = parser.parse_args().__dict__
446
+ create_ui(**args)
cli.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ from urllib.parse import urlparse
5
+ import warnings
6
+ import numpy as np
7
+
8
+ import torch
9
+ from app import LANGUAGES, WHISPER_MODELS, WhisperTranscriber
10
+ from src.download import download_url
11
+
12
+ from src.utils import optional_float, optional_int, str2bool
13
+ from src.whisperContainer import WhisperContainer
14
+
15
+ def cli():
16
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
17
+ parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
18
+ parser.add_argument("--model", default="small", choices=WHISPER_MODELS, help="name of the Whisper model to use")
19
+ parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
20
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
21
+ parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
22
+ parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
23
+
24
+ parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
25
+ parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), help="language spoken in the audio, specify None to perform language detection")
26
+
27
+ parser.add_argument("--vad", type=str, default="none", choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], help="The voice activity detection algorithm to use")
28
+ parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
29
+ parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
30
+ parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
31
+ parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
32
+ parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
33
+ parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
34
+ parser.add_argument("--auto_parallel", type=bool, default=False, help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")
35
+
36
+ parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
37
+ parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
38
+ parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
39
+ parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
40
+ parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
41
+
42
+ parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
43
+ parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
44
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
45
+ parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
46
+
47
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
48
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
49
+ parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
50
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
51
+
52
+ args = parser.parse_args().__dict__
53
+ model_name: str = args.pop("model")
54
+ model_dir: str = args.pop("model_dir")
55
+ output_dir: str = args.pop("output_dir")
56
+ device: str = args.pop("device")
57
+ os.makedirs(output_dir, exist_ok=True)
58
+
59
+ if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
60
+ warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
61
+ args["language"] = "en"
62
+
63
+ temperature = args.pop("temperature")
64
+ temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
65
+ if temperature_increment_on_fallback is not None:
66
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
67
+ else:
68
+ temperature = [temperature]
69
+
70
+ vad = args.pop("vad")
71
+ vad_merge_window = args.pop("vad_merge_window")
72
+ vad_max_merge_size = args.pop("vad_max_merge_size")
73
+ vad_padding = args.pop("vad_padding")
74
+ vad_prompt_window = args.pop("vad_prompt_window")
75
+ vad_cpu_cores = args.pop("vad_cpu_cores")
76
+ auto_parallel = args.pop("auto_parallel")
77
+
78
+ model = WhisperContainer(model_name, device=device, download_root=model_dir)
79
+ transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores)
80
+ transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
81
+ transcriber.set_auto_parallel(auto_parallel)
82
+
83
+ if (transcriber._has_parallel_devices()):
84
+ print("Using parallel devices:", transcriber.parallel_device_list)
85
+
86
+ for audio_path in args.pop("audio"):
87
+ sources = []
88
+
89
+ # Detect URL and download the audio
90
+ if (uri_validator(audio_path)):
91
+ # Download from YouTube/URL directly
92
+ for source_path in download_url(audio_path, maxDuration=-1, destinationDirectory=output_dir, playlistItems=None):
93
+ source_name = os.path.basename(source_path)
94
+ sources.append({ "path": source_path, "name": source_name })
95
+ else:
96
+ sources.append({ "path": audio_path, "name": os.path.basename(audio_path) })
97
+
98
+ for source in sources:
99
+ source_path = source["path"]
100
+ source_name = source["name"]
101
+
102
+ result = transcriber.transcribe_file(model, source_path, temperature=temperature,
103
+ vad=vad, vadMergeWindow=vad_merge_window, vadMaxMergeSize=vad_max_merge_size,
104
+ vadPadding=vad_padding, vadPromptWindow=vad_prompt_window, **args)
105
+
106
+ transcriber.write_result(result, source_name, output_dir)
107
+
108
+ transcriber.close()
109
+
110
+ def uri_validator(x):
111
+ try:
112
+ result = urlparse(x)
113
+ return all([result.scheme, result.netloc])
114
+ except:
115
+ return False
116
+
117
+ if __name__ == '__main__':
118
+ cli()
dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM huggingface/transformers-pytorch-gpu
2
+ EXPOSE 7860
3
+
4
+ ADD . /opt/whisper-webui/
5
+
6
+ # Latest version of transformers-pytorch-gpu seems to lack tk.
7
+ # Further, pip install fails, so we must upgrade pip first.
8
+ RUN apt-get -y install python3-tk
9
+ RUN python3 -m pip install --upgrade pip &&\
10
+ python3 -m pip install -r /opt/whisper-webui/requirements.txt
11
+
12
+ # Note: Models will be downloaded on demand to the directory /root/.cache/whisper.
13
+ # You can also bind this directory in the container to somewhere on the host.
14
+
15
+ # To be able to see logs in real time
16
+ ENV PYTHONUNBUFFERED=1
17
+
18
+ WORKDIR /opt/whisper-webui/
19
+ ENTRYPOINT ["python3"]
20
+ CMD ["app.py", "--input_audio_max_duration", "-1", "--server_name", "0.0.0.0", "--auto_parallel", "True"]
docs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
docs/colab.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Running Whisper on Google Colab
2
+
3
+ If you don't have a decent GPU or any experience in running command-line applications, you might want to try this Google Colab instead:
4
+
5
+ * [Google Colab - Whisper WebUI GPU](https://colab.research.google.com/drive/1qeTSvi7Bt_5RMm88ipW4fkcsMOKlDDss?usp=sharing)
6
+ * [Screenshots](https://imgur.com/a/ZfY6uBO)
7
+
8
+ The runtime (Runtime -> Change runtime type -> Hardware accelerator) should already be set top GPU. But if not, change it to GPU.
9
+
10
+ Then, sign in to Google if you haven't already. Next, click on "Connect" at the top right.
11
+
12
+ Under "Checking out WebUI from Git", click on the [play icon](https://imgur.com/a/81gOLyD) that appears in "[ ]" at the left. If you get a warning, click "Run anyway".
13
+
14
+ After this step has completed, it should be get a green check mark. Then move on to the next section under "Installing dependencies", and click in "[ ]" again. This might take approximately 30 seconds.
15
+
16
+ Once this has completed, scroll down to the "Run WebUI" section, and click on "[ ]". This will launch the WebUI in a shared link (expires in 72 hours). To open the UI, click on the link next to "Running on public URL", which will be something like https://12xxx.gradio.app/
17
+
18
+ The audio length in this version is not restricted, and it will run much faster as it is backed by a GPU. You can also run it using the "Large" model. Also note that it might take some time to start the model the first time, as it may need to download a 2.8 GB file on Google's servers.
19
+
20
+ Once you're done, you can close the WebUI session by clicking the animated close button under "Run WebUI". You can also do this if you encounter any errors and need to restart the UI. You should also go to "Manage Sessions" and terminate the session, otherwise you may end up using all your free compute credits.
docs/options.md ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard Options
2
+ To transcribe or translate an audio file, you can either copy an URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
3
+ supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
4
+ in the file selector to select any file type, including video files) or use the microphone.
5
+
6
+ For longer audio files (>10 minutes), it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option, especially if you are using the `large-v1` model. Note that `large-v2` is a lot more forgiving, but you may still want to use a VAD with a slightly higher "VAD - Max Merge Size (s)" (60 seconds or more).
7
+
8
+ ## Model
9
+ Select the model that Whisper will use to transcribe the audio:
10
+
11
+ | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
12
+ |-----------|------------|--------------------|--------------------|---------------|----------------|
13
+ | tiny | 39 M | tiny.en | tiny | ~1 GB | ~32x |
14
+ | base | 74 M | base.en | base | ~1 GB | ~16x |
15
+ | small | 244 M | small.en | small | ~2 GB | ~6x |
16
+ | medium | 769 M | medium.en | medium | ~5 GB | ~2x |
17
+ | large | 1550 M | N/A | large | ~10 GB | 1x |
18
+ | large-v2 | 1550 M | N/A | large | ~10 GB | 1x |
19
+
20
+ ## Language
21
+
22
+ Select the language, or leave it empty for Whisper to automatically detect it.
23
+
24
+ Note that if the selected language and the language in the audio differs, Whisper may start to translate the audio to the selected
25
+ language. For instance, if the audio is in English but you select Japaneese, the model may translate the audio to Japanese.
26
+
27
+ ## Inputs
28
+ The options "URL (YouTube, etc.)", "Upload Files" or "Micriphone Input" allows you to send an audio input to the model.
29
+
30
+ ### Multiple Files
31
+ Note that the UI will only process either the given URL or the upload files (including microphone) - not both.
32
+
33
+ But you can upload multiple files either through the "Upload files" option, or as a playlist on YouTube. Each audio file will then be processed in turn, and the resulting SRT/VTT/Transcript will be made available in the "Download" section. When more than one file is processed, the UI will also generate a "All_Output" zip file containing all the text output files.
34
+
35
+ ## Task
36
+ Select the task - either "transcribe" to transcribe the audio to text, or "translate" to translate it to English.
37
+
38
+ ## Vad
39
+ Using a VAD will improve the timing accuracy of each transcribed line, as well as prevent Whisper getting into an infinite
40
+ loop detecting the same sentence over and over again. The downside is that this may be at a cost to text accuracy, especially
41
+ with regards to unique words or names that appear in the audio. You can compensate for this by increasing the prompt window.
42
+
43
+ Note that English is very well handled by Whisper, and it's less susceptible to issues surrounding bad timings and infinite loops.
44
+ So you may only need to use a VAD for other languages, such as Japanese, or when the audio is very long.
45
+
46
+ * none
47
+ * Run whisper on the entire audio input
48
+ * silero-vad
49
+ * Use Silero VAD to detect sections that contain speech, and run Whisper on independently on each section. Whisper is also run
50
+ on the gaps between each speech section, by either expanding the section up to the max merge size, or running Whisper independently
51
+ on the non-speech section.
52
+ * silero-vad-expand-into-gaps
53
+ * Use Silero VAD to detect sections that contain speech, and run Whisper on independently on each section. Each spech section will be expanded
54
+ such that they cover any adjacent non-speech sections. For instance, if an audio file of one minute contains the speech sections
55
+ 00:00 - 00:10 (A) and 00:30 - 00:40 (B), the first section (A) will be expanded to 00:00 - 00:30, and (B) will be expanded to 00:30 - 00:60.
56
+ * silero-vad-skip-gaps
57
+ * As above, but sections that doesn't contain speech according to Silero will be skipped. This will be slightly faster, but
58
+ may cause dialogue to be skipped.
59
+ * periodic-vad
60
+ * Create sections of speech every 'VAD - Max Merge Size' seconds. This is very fast and simple, but will potentially break
61
+ a sentence or word in two.
62
+
63
+ ## VAD - Merge Window
64
+ If set, any adjacent speech sections that are at most this number of seconds apart will be automatically merged.
65
+
66
+ ## VAD - Max Merge Size (s)
67
+ Disables merging of adjacent speech sections if they are this number of seconds long.
68
+
69
+ ## VAD - Padding (s)
70
+ The number of seconds (floating point) to add to the beginning and end of each speech section. Setting this to a number
71
+ larger than zero ensures that Whisper is more likely to correctly transcribe a sentence in the beginning of
72
+ a speech section. However, this also increases the probability of Whisper assigning the wrong timestamp
73
+ to each transcribed line. The default value is 1 second.
74
+
75
+ ## VAD - Prompt Window (s)
76
+ The text of a detected line will be included as a prompt to the next speech section, if the speech section starts at most this
77
+ number of seconds after the line has finished. For instance, if a line ends at 10:00, and the next speech section starts at
78
+ 10:04, the line's text will be included if the prompt window is 4 seconds or more (10:04 - 10:00 = 4 seconds).
79
+
80
+ Note that detected lines in gaps between speech sections will not be included in the prompt
81
+ (if silero-vad or silero-vad-expand-into-gaps) is used.
82
+
83
+ # Command Line Options
84
+
85
+ Both `app.py` and `cli.py` also accept command line options, such as the ability to enable parallel execution on multiple
86
+ CPU/GPU cores, the default model name/VAD and so on. Consult the README in the root folder for more information.
87
+
88
+ # Additional Options
89
+
90
+ In addition to the above, there's also a "Full" options interface that allows you to set all the options available in the Whisper
91
+ model. The options are as follows:
92
+
93
+ ## Initial Prompt
94
+ Optional text to provide as a prompt for the first 30 seconds window. Whisper will attempt to use this as a starting point for the transcription, but you can
95
+ also get creative and specify a style or format for the output of the transcription.
96
+
97
+ For instance, if you use the prompt "hello how is it going always use lowercase no punctuation goodbye one two three start stop i you me they", Whisper will
98
+ be biased to output lower capital letters and no punctuation, and may also be biased to output the words in the prompt more often.
99
+
100
+ ## Temperature
101
+ The temperature to use when sampling. Default is 0 (zero). A higher temperature will result in more random output, while a lower temperature will be more deterministic.
102
+
103
+ ## Best Of - Non-zero temperature
104
+ The number of candidates to sample from when sampling with non-zero temperature. Default is 5.
105
+
106
+ ## Beam Size - Zero temperature
107
+ The number of beams to use in beam search when sampling with zero temperature. Default is 5.
108
+
109
+ ## Patience - Zero temperature
110
+ The patience value to use in beam search when sampling with zero temperature. As in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search.
111
+
112
+ ## Length Penalty - Any temperature
113
+ The token length penalty coefficient (alpha) to use when sampling with any temperature. As in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.
114
+
115
+ ## Suppress Tokens - Comma-separated list of token IDs
116
+ A comma-separated list of token IDs to suppress during sampling. The default value of "-1" will suppress most special characters except common punctuations.
117
+
118
+ ## Condition on previous text
119
+ If True, provide the previous output of the model as a prompt for the next window. Disabling this may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop.
120
+
121
+ ## FP16
122
+ Whether to perform inference in fp16. True by default.
123
+
124
+ ## Temperature increment on fallback
125
+ The temperature to increase when falling back when the decoding fails to meet either of the thresholds below. Default is 0.2.
126
+
127
+ ## Compression ratio threshold
128
+ If the gzip compression ratio is higher than this value, treat the decoding as failed. Default is 2.4.
129
+
130
+ ## Logprob threshold
131
+ If the average log probability is lower than this value, treat the decoding as failed. Default is -1.0.
132
+
133
+ ## No speech threshold
134
+ If the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence. Default is 0.6.
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ git+https://github.com/openai/whisper.git
2
+ transformers
3
+ ffmpeg-python==0.2.0
4
+ gradio==3.13.0
5
+ yt-dlp
6
+ torchaudio
7
+ altair
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes). View file
 
src/__pycache__/download.cpython-310.pyc ADDED
Binary file (2.64 kB). View file
 
src/__pycache__/modelCache.cpython-310.pyc ADDED
Binary file (819 Bytes). View file
 
src/__pycache__/segments.cpython-310.pyc ADDED
Binary file (1.03 kB). View file
 
src/__pycache__/source.cpython-310.pyc ADDED
Binary file (2.73 kB). View file
 
src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.91 kB). View file
 
src/__pycache__/vad.cpython-310.pyc ADDED
Binary file (14.4 kB). View file
 
src/__pycache__/vadParallel.cpython-310.pyc ADDED
Binary file (7.49 kB). View file
 
src/__pycache__/whisperContainer.cpython-310.pyc ADDED
Binary file (4.85 kB). View file
 
src/download.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tempfile import mkdtemp
2
+ from typing import List
3
+ from yt_dlp import YoutubeDL
4
+
5
+ import yt_dlp
6
+ from yt_dlp.postprocessor import PostProcessor
7
+
8
+ class FilenameCollectorPP(PostProcessor):
9
+ def __init__(self):
10
+ super(FilenameCollectorPP, self).__init__(None)
11
+ self.filenames = []
12
+
13
+ def run(self, information):
14
+ self.filenames.append(information["filepath"])
15
+ return [], information
16
+
17
+ def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
18
+ try:
19
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate=None, destinationDirectory=destinationDirectory, playlistItems=playlistItems)
20
+ except yt_dlp.utils.DownloadError as e:
21
+ # In case of an OS error, try again with a different output template
22
+ if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
23
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
24
+ pass
25
+
26
+ def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
27
+ # Create a temporary directory to store the downloaded files
28
+ if destinationDirectory is None:
29
+ destinationDirectory = mkdtemp()
30
+
31
+ ydl_opts = {
32
+ "format": "bestaudio/best",
33
+ 'paths': {
34
+ 'home': destinationDirectory
35
+ }
36
+ }
37
+ if (playlistItems):
38
+ ydl_opts['playlist_items'] = playlistItems
39
+
40
+ # Add output template if specified
41
+ if outputTemplate:
42
+ ydl_opts['outtmpl'] = outputTemplate
43
+
44
+ filename_collector = FilenameCollectorPP()
45
+
46
+ with YoutubeDL(ydl_opts) as ydl:
47
+ if maxDuration and maxDuration > 0:
48
+ info = ydl.extract_info(url, download=False)
49
+ entries = "entries" in info and info["entries"] or [info]
50
+
51
+ total_duration = 0
52
+
53
+ # Compute total duration
54
+ for entry in entries:
55
+ total_duration += float(entry["duration"])
56
+
57
+ if total_duration >= maxDuration:
58
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=maxDuration, message="Video is too long")
59
+
60
+ ydl.add_post_processor(filename_collector)
61
+ ydl.download([url])
62
+
63
+ if len(filename_collector.filenames) <= 0:
64
+ raise Exception("Cannot download " + url)
65
+
66
+ result = []
67
+
68
+ for filename in filename_collector.filenames:
69
+ result.append(filename)
70
+ print("Downloaded " + filename)
71
+
72
+ return result
73
+
74
+ class ExceededMaximumDuration(Exception):
75
+ def __init__(self, videoDuration, maxDuration, message):
76
+ self.videoDuration = videoDuration
77
+ self.maxDuration = maxDuration
78
+ super().__init__(message)
src/modelCache.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ModelCache:
2
+ def __init__(self):
3
+ self._cache = dict()
4
+
5
+ def get(self, model_key: str, model_factory):
6
+ result = self._cache.get(model_key)
7
+
8
+ if result is None:
9
+ result = model_factory()
10
+ self._cache[model_key] = result
11
+ return result
12
+
13
+ def clear(self):
14
+ self._cache.clear()
15
+
16
+ # A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
17
+ GLOBAL_MODEL_CACHE = ModelCache()
src/segments.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ import copy
4
+
5
+ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
6
+ result = []
7
+
8
+ if len(timestamps) == 0:
9
+ return result
10
+ if max_merge_size is None:
11
+ return timestamps
12
+
13
+ if padding_left is None:
14
+ padding_left = 0
15
+ if padding_right is None:
16
+ padding_right = 0
17
+
18
+ processed_time = 0
19
+ current_segment = None
20
+
21
+ for i in range(len(timestamps)):
22
+ next_segment = timestamps[i]
23
+
24
+ delta = next_segment['start'] - processed_time
25
+
26
+ # Note that segments can still be longer than the max merge size, they just won't be merged in that case
27
+ if current_segment is None or (merge_window is not None and delta > merge_window) \
28
+ or next_segment['end'] - current_segment['start'] > max_merge_size:
29
+ # Finish the current segment
30
+ if current_segment is not None:
31
+ # Add right padding
32
+ finish_padding = min(padding_right, delta / 2) if delta < padding_left + padding_right else padding_right
33
+ current_segment['end'] += finish_padding
34
+ delta -= finish_padding
35
+
36
+ result.append(current_segment)
37
+
38
+ # Start a new segment
39
+ current_segment = copy.deepcopy(next_segment)
40
+
41
+ # Pad the segment
42
+ current_segment['start'] = current_segment['start'] - min(padding_left, delta)
43
+ processed_time = current_segment['end']
44
+
45
+ else:
46
+ # Merge the segment
47
+ current_segment['end'] = next_segment['end']
48
+ processed_time = current_segment['end']
49
+
50
+ # Add the last segment
51
+ if current_segment is not None:
52
+ current_segment['end'] += padding_right
53
+ result.append(current_segment)
54
+
55
+ return result
src/source.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
2
+ import os
3
+ import pathlib
4
+ from typing import List
5
+ import zipfile
6
+
7
+ import ffmpeg
8
+ from more_itertools import unzip
9
+
10
+ from src.download import ExceededMaximumDuration, download_url
11
+
12
+ MAX_FILE_PREFIX_LENGTH = 17
13
+
14
+ class AudioSource:
15
+ def __init__(self, source_path, source_name = None):
16
+ self.source_path = source_path
17
+ self.source_name = source_name
18
+
19
+ # Load source name if not provided
20
+ if (self.source_name is None):
21
+ file_path = pathlib.Path(self.source_path)
22
+ self.source_name = file_path.name
23
+
24
+ def get_full_name(self):
25
+ return self.source_name
26
+
27
+ def get_short_name(self, max_length: int = MAX_FILE_PREFIX_LENGTH):
28
+ file_path = pathlib.Path(self.source_name)
29
+ short_name = file_path.stem[:max_length] + file_path.suffix
30
+
31
+ return short_name
32
+
33
+ def __str__(self) -> str:
34
+ return self.source_path
35
+
36
+ class AudioSourceCollection:
37
+ def __init__(self, sources: List[AudioSource]):
38
+ self.sources = sources
39
+
40
+ def __iter__(self):
41
+ return iter(self.sources)
42
+
43
+ def get_audio_source_collection(urlData: str, multipleFiles: List, microphoneData: str, input_audio_max_duration: float = -1) -> List[AudioSource]:
44
+ output: List[AudioSource] = []
45
+
46
+ if urlData:
47
+ # Download from YouTube. This could also be a playlist or a channel.
48
+ output.extend([ AudioSource(x) for x in download_url(urlData, input_audio_max_duration, playlistItems=None) ])
49
+ else:
50
+ # Add input files
51
+ if (multipleFiles is not None):
52
+ output.extend([ AudioSource(x.name) for x in multipleFiles ])
53
+ if (microphoneData is not None):
54
+ output.append(AudioSource(microphoneData))
55
+
56
+ total_duration = 0
57
+
58
+ # Calculate total audio length. We do this even if input_audio_max_duration
59
+ # is disabled to ensure that all the audio files are valid.
60
+ for source in output:
61
+ audioDuration = ffmpeg.probe(source.source_path)["format"]["duration"]
62
+ total_duration += float(audioDuration)
63
+
64
+ # Ensure the total duration of the audio is not too long
65
+ if input_audio_max_duration > 0:
66
+ if float(total_duration) > input_audio_max_duration:
67
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")
68
+
69
+ # Return a list of audio sources
70
+ return output
src/utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import textwrap
2
+ import unicodedata
3
+ import re
4
+
5
+ import zlib
6
+ from typing import Iterator, TextIO
7
+
8
+
9
+ def exact_div(x, y):
10
+ assert x % y == 0
11
+ return x // y
12
+
13
+
14
+ def str2bool(string):
15
+ str2val = {"True": True, "False": False}
16
+ if string in str2val:
17
+ return str2val[string]
18
+ else:
19
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
20
+
21
+
22
+ def optional_int(string):
23
+ return None if string == "None" else int(string)
24
+
25
+
26
+ def optional_float(string):
27
+ return None if string == "None" else float(string)
28
+
29
+
30
+ def compression_ratio(text) -> float:
31
+ return len(text) / len(zlib.compress(text.encode("utf-8")))
32
+
33
+
34
+ def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
35
+ assert seconds >= 0, "non-negative timestamp expected"
36
+ milliseconds = round(seconds * 1000.0)
37
+
38
+ hours = milliseconds // 3_600_000
39
+ milliseconds -= hours * 3_600_000
40
+
41
+ minutes = milliseconds // 60_000
42
+ milliseconds -= minutes * 60_000
43
+
44
+ seconds = milliseconds // 1_000
45
+ milliseconds -= seconds * 1_000
46
+
47
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
48
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
49
+
50
+
51
+ def write_txt(transcript: Iterator[dict], file: TextIO):
52
+ for segment in transcript:
53
+ print(segment['text'].strip(), file=file, flush=True)
54
+
55
+
56
+ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
57
+ print("WEBVTT\n", file=file)
58
+ for segment in transcript:
59
+ text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
60
+
61
+ print(
62
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
63
+ f"{text}\n",
64
+ file=file,
65
+ flush=True,
66
+ )
67
+
68
+
69
+ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
70
+ """
71
+ Write a transcript to a file in SRT format.
72
+ Example usage:
73
+ from pathlib import Path
74
+ from whisper.utils import write_srt
75
+ result = transcribe(model, audio_path, temperature=temperature, **args)
76
+ # save SRT
77
+ audio_basename = Path(audio_path).stem
78
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
79
+ write_srt(result["segments"], file=srt)
80
+ """
81
+ for i, segment in enumerate(transcript, start=1):
82
+ text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
83
+
84
+ # write srt lines
85
+ print(
86
+ f"{i}\n"
87
+ f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
88
+ f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
89
+ f"{text}\n",
90
+ file=file,
91
+ flush=True,
92
+ )
93
+
94
+ def process_text(text: str, maxLineWidth=None):
95
+ if (maxLineWidth is None or maxLineWidth < 0):
96
+ return text
97
+
98
+ lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
99
+ return '\n'.join(lines)
100
+
101
+ def slugify(value, allow_unicode=False):
102
+ """
103
+ Taken from https://github.com/django/django/blob/master/django/utils/text.py
104
+ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
105
+ dashes to single dashes. Remove characters that aren't alphanumerics,
106
+ underscores, or hyphens. Convert to lowercase. Also strip leading and
107
+ trailing whitespace, dashes, and underscores.
108
+ """
109
+ value = str(value)
110
+ if allow_unicode:
111
+ value = unicodedata.normalize('NFKC', value)
112
+ else:
113
+ value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
114
+ value = re.sub(r'[^\w\s-]', '', value.lower())
115
+ return re.sub(r'[-\s]+', '-', value).strip('-_')
src/vad.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from collections import Counter, deque
3
+ import time
4
+
5
+ from typing import Any, Deque, Iterator, List, Dict
6
+
7
+ from pprint import pprint
8
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
9
+
10
+ from src.segments import merge_timestamps
11
+ from src.whisperContainer import WhisperCallback
12
+
13
+ # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
14
+ try:
15
+ import tensorflow as tf
16
+ except ModuleNotFoundError:
17
+ # Error handling
18
+ pass
19
+
20
+ import torch
21
+
22
+ import ffmpeg
23
+ import numpy as np
24
+
25
+ from src.utils import format_timestamp
26
+ from enum import Enum
27
+
28
+ class NonSpeechStrategy(Enum):
29
+ """
30
+ Ignore non-speech frames segments.
31
+ """
32
+ SKIP = 1
33
+ """
34
+ Just treat non-speech segments as speech.
35
+ """
36
+ CREATE_SEGMENT = 2
37
+ """
38
+ Expand speech segments into subsequent non-speech segments.
39
+ """
40
+ EXPAND_SEGMENT = 3
41
+
42
+ # Defaults for Silero
43
+ SPEECH_TRESHOLD = 0.3
44
+
45
+ # Minimum size of segments to process
46
+ MIN_SEGMENT_DURATION = 1
47
+
48
+ # The maximum time for texts from old segments to be used in the next segment
49
+ MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
50
+ PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
51
+
52
+ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
53
+
54
+ class TranscriptionConfig(ABC):
55
+ def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
56
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
57
+ max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
58
+ self.non_speech_strategy = non_speech_strategy
59
+ self.segment_padding_left = segment_padding_left
60
+ self.segment_padding_right = segment_padding_right
61
+ self.max_silent_period = max_silent_period
62
+ self.max_merge_size = max_merge_size
63
+ self.max_prompt_window = max_prompt_window
64
+ self.initial_segment_index = initial_segment_index
65
+
66
+ class PeriodicTranscriptionConfig(TranscriptionConfig):
67
+ def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
68
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
69
+ max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
70
+ super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window, initial_segment_index)
71
+ self.periodic_duration = periodic_duration
72
+
73
+ class AbstractTranscription(ABC):
74
+ def __init__(self, sampling_rate: int = 16000):
75
+ self.sampling_rate = sampling_rate
76
+
77
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
78
+ return load_audio(str, self.sampling_rate, start_time, duration)
79
+
80
+ def is_transcribe_timestamps_fast(self):
81
+ """
82
+ Determine if get_transcribe_timestamps is fast enough to not need parallelization.
83
+ """
84
+ return False
85
+
86
+ @abstractmethod
87
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
88
+ """
89
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method.
90
+
91
+ Parameters
92
+ ----------
93
+ audio: str
94
+ The audio file.
95
+ config: TranscriptionConfig
96
+ The transcription configuration.
97
+
98
+ Returns
99
+ -------
100
+ A list of start and end timestamps, in fractional seconds.
101
+ """
102
+ return
103
+
104
+ def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
105
+ """
106
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method,
107
+ after merging the given segments using the specified configuration.
108
+
109
+ Parameters
110
+ ----------
111
+ audio: str
112
+ The audio file.
113
+ config: TranscriptionConfig
114
+ The transcription configuration.
115
+
116
+ Returns
117
+ -------
118
+ A list of start and end timestamps, in fractional seconds.
119
+ """
120
+ merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
121
+ config.segment_padding_left, config.segment_padding_right)
122
+
123
+ if config.non_speech_strategy != NonSpeechStrategy.SKIP:
124
+ # Expand segments to include the gaps between them
125
+ if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
126
+ # When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
127
+ merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
128
+ elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
129
+ # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
130
+ merged = self.expand_gaps(merged, total_duration=total_duration)
131
+ else:
132
+ raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
133
+
134
+ print("Transcribing non-speech:")
135
+ pprint(merged)
136
+ return merged
137
+
138
+ def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig):
139
+ """
140
+ Transcribe the given audo file.
141
+
142
+ Parameters
143
+ ----------
144
+ audio: str
145
+ The audio file.
146
+ whisperCallable: WhisperCallback
147
+ A callback object to call to transcribe each segment.
148
+
149
+ Returns
150
+ -------
151
+ A list of start and end timestamps, in fractional seconds.
152
+ """
153
+
154
+ max_audio_duration = get_audio_duration(audio)
155
+ timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
156
+
157
+ # Get speech timestamps from full audio file
158
+ merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
159
+
160
+ # A deque of transcribed segments that is passed to the next segment as a prompt
161
+ prompt_window = deque()
162
+
163
+ print("Processing timestamps:")
164
+ pprint(merged)
165
+
166
+ result = {
167
+ 'text': "",
168
+ 'segments': [],
169
+ 'language': ""
170
+ }
171
+ languageCounter = Counter()
172
+ detected_language = None
173
+
174
+ segment_index = config.initial_segment_index
175
+
176
+ # For each time segment, run whisper
177
+ for segment in merged:
178
+ segment_index += 1
179
+ segment_start = segment['start']
180
+ segment_end = segment['end']
181
+ segment_expand_amount = segment.get('expand_amount', 0)
182
+ segment_gap = segment.get('gap', False)
183
+
184
+ segment_duration = segment_end - segment_start
185
+
186
+ if segment_duration < MIN_SEGMENT_DURATION:
187
+ continue;
188
+
189
+ # Audio to run on Whisper
190
+ segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
191
+ # Previous segments to use as a prompt
192
+ segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
193
+
194
+ # Detected language
195
+ detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
196
+
197
+ print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
198
+ segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
199
+ segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language)
200
+
201
+ adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
202
+
203
+ # Propagate expand amount to the segments
204
+ if (segment_expand_amount > 0):
205
+ segment_without_expansion = segment_duration - segment_expand_amount
206
+
207
+ for adjusted_segment in adjusted_segments:
208
+ adjusted_segment_end = adjusted_segment['end']
209
+
210
+ # Add expand amount if the segment got expanded
211
+ if (adjusted_segment_end > segment_without_expansion):
212
+ adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
213
+
214
+ # Append to output
215
+ result['text'] += segment_result['text']
216
+ result['segments'].extend(adjusted_segments)
217
+
218
+ # Increment detected language
219
+ if not segment_gap:
220
+ languageCounter[segment_result['language']] += 1
221
+
222
+ # Update prompt window
223
+ self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
224
+
225
+ if detected_language is not None:
226
+ result['language'] = detected_language
227
+
228
+ return result
229
+
230
+ def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
231
+ if (config.max_prompt_window is not None and config.max_prompt_window > 0):
232
+ # Add segments to the current prompt window (unless it is a speech gap)
233
+ if not segment_gap:
234
+ for segment in adjusted_segments:
235
+ if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
236
+ prompt_window.append(segment)
237
+
238
+ while (len(prompt_window) > 0):
239
+ first_end_time = prompt_window[0].get('end', 0)
240
+ # Time expanded in the segments should be discounted from the prompt window
241
+ first_expand_time = prompt_window[0].get('expand_amount', 0)
242
+
243
+ if (first_end_time - first_expand_time < segment_end - config.max_prompt_window):
244
+ prompt_window.popleft()
245
+ else:
246
+ break
247
+
248
+ def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
249
+ result = []
250
+ last_end_time = 0
251
+
252
+ for segment in segments:
253
+ segment_start = float(segment['start'])
254
+ segment_end = float(segment['end'])
255
+
256
+ if (last_end_time != segment_start):
257
+ delta = segment_start - last_end_time
258
+
259
+ if (min_gap_length is None or delta >= min_gap_length):
260
+ result.append( { 'start': last_end_time, 'end': segment_start, 'gap': True } )
261
+
262
+ last_end_time = segment_end
263
+ result.append(segment)
264
+
265
+ # Also include total duration if specified
266
+ if (total_duration is not None and last_end_time < total_duration):
267
+ delta = total_duration - segment_start
268
+
269
+ if (min_gap_length is None or delta >= min_gap_length):
270
+ result.append( { 'start': last_end_time, 'end': total_duration, 'gap': True } )
271
+
272
+ return result
273
+
274
+ # Expand the end time of each segment to the start of the next segment
275
+ def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
276
+ result = []
277
+
278
+ if len(segments) == 0:
279
+ return result
280
+
281
+ # Add gap at the beginning if needed
282
+ if (segments[0]['start'] > 0):
283
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
284
+
285
+ for i in range(len(segments) - 1):
286
+ current_segment = segments[i]
287
+ next_segment = segments[i + 1]
288
+
289
+ delta = next_segment['start'] - current_segment['end']
290
+
291
+ # Expand if the gap actually exists
292
+ if (delta >= 0):
293
+ current_segment = current_segment.copy()
294
+ current_segment['expand_amount'] = delta
295
+ current_segment['end'] = next_segment['start']
296
+
297
+ result.append(current_segment)
298
+
299
+ # Add last segment
300
+ last_segment = segments[-1]
301
+ result.append(last_segment)
302
+
303
+ # Also include total duration if specified
304
+ if (total_duration is not None):
305
+ last_segment = result[-1]
306
+
307
+ if (last_segment['end'] < total_duration):
308
+ last_segment = last_segment.copy()
309
+ last_segment['end'] = total_duration
310
+ result[-1] = last_segment
311
+
312
+ return result
313
+
314
+ def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
315
+ result = []
316
+
317
+ if len(segments) == 0:
318
+ return result
319
+
320
+ # Add gap at the beginning if needed
321
+ if (segments[0]['start'] > 0):
322
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
323
+
324
+ for i in range(len(segments) - 1):
325
+ expanded = False
326
+ current_segment = segments[i]
327
+ next_segment = segments[i + 1]
328
+
329
+ delta = next_segment['start'] - current_segment['end']
330
+
331
+ if (max_expand_size is not None and delta <= max_expand_size):
332
+ # Just expand the current segment
333
+ current_segment = current_segment.copy()
334
+ current_segment['expand_amount'] = delta
335
+ current_segment['end'] = next_segment['start']
336
+ expanded = True
337
+
338
+ result.append(current_segment)
339
+
340
+ # Add a gap to the next segment if needed
341
+ if (delta >= 0 and not expanded):
342
+ result.append({ 'start': current_segment['end'], 'end': next_segment['start'], 'gap': True } )
343
+
344
+ # Add last segment
345
+ last_segment = segments[-1]
346
+ result.append(last_segment)
347
+
348
+ # Also include total duration if specified
349
+ if (total_duration is not None):
350
+ last_segment = result[-1]
351
+
352
+ delta = total_duration - last_segment['end']
353
+
354
+ if (delta > 0):
355
+ if (max_expand_size is not None and delta <= max_expand_size):
356
+ # Expand the last segment
357
+ last_segment = last_segment.copy()
358
+ last_segment['expand_amount'] = delta
359
+ last_segment['end'] = total_duration
360
+ result[-1] = last_segment
361
+ else:
362
+ result.append({ 'start': last_segment['end'], 'end': total_duration, 'gap': True } )
363
+
364
+ return result
365
+
366
+ def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
367
+ result = []
368
+
369
+ for segment in segments:
370
+ segment_start = float(segment['start'])
371
+ segment_end = float(segment['end'])
372
+
373
+ # Filter segments?
374
+ if (max_source_time is not None):
375
+ if (segment_start > max_source_time):
376
+ continue
377
+ segment_end = min(max_source_time, segment_end)
378
+
379
+ new_segment = segment.copy()
380
+
381
+ # Add to start and end
382
+ new_segment['start'] = segment_start + adjust_seconds
383
+ new_segment['end'] = segment_end + adjust_seconds
384
+ result.append(new_segment)
385
+ return result
386
+
387
+ def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
388
+ result = []
389
+
390
+ for entry in timestamps:
391
+ start = entry['start']
392
+ end = entry['end']
393
+
394
+ result.append({
395
+ 'start': start * factor,
396
+ 'end': end * factor
397
+ })
398
+ return result
399
+
400
+
401
+ class VadSileroTranscription(AbstractTranscription):
402
+ def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
403
+ super().__init__(sampling_rate=sampling_rate)
404
+ self.model = None
405
+ self.cache = cache
406
+ self._initialize_model()
407
+
408
+ def _initialize_model(self):
409
+ if (self.cache is not None):
410
+ model_key = "VadSileroTranscription"
411
+ self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
412
+ print("Loaded Silerio model from cache.")
413
+ else:
414
+ self.model, self.get_speech_timestamps = self._create_model()
415
+ print("Created Silerio model")
416
+
417
+ def _create_model(self):
418
+ model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
419
+
420
+ # Silero does not benefit from multi-threading
421
+ torch.set_num_threads(1) # JIT
422
+ (get_speech_timestamps, _, _, _, _) = utils
423
+
424
+ return model, get_speech_timestamps
425
+
426
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
427
+ result = []
428
+
429
+ print("Getting timestamps from audio file: {}, start: {}, duration: {}".format(audio, start_time, end_time))
430
+ perf_start_time = time.perf_counter()
431
+
432
+ # Divide procesisng of audio into chunks
433
+ chunk_start = start_time
434
+
435
+ while (chunk_start < end_time):
436
+ chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)
437
+
438
+ print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
439
+ wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
440
+
441
+ sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
442
+ seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
443
+ adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_start + chunk_duration)
444
+
445
+ #pprint(adjusted)
446
+
447
+ result.extend(adjusted)
448
+ chunk_start += chunk_duration
449
+
450
+ perf_end_time = time.perf_counter()
451
+ print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))
452
+
453
+ return result
454
+
455
+ def __getstate__(self):
456
+ # We only need the sampling rate
457
+ return { 'sampling_rate': self.sampling_rate }
458
+
459
+ def __setstate__(self, state):
460
+ self.sampling_rate = state['sampling_rate']
461
+ self.model = None
462
+ # Use the global cache
463
+ self.cache = GLOBAL_MODEL_CACHE
464
+ self._initialize_model()
465
+
466
+ # A very simple VAD that just marks every N seconds as speech
467
+ class VadPeriodicTranscription(AbstractTranscription):
468
+ def __init__(self, sampling_rate: int = 16000):
469
+ super().__init__(sampling_rate=sampling_rate)
470
+
471
+ def is_transcribe_timestamps_fast(self):
472
+ # This is a very fast VAD - no need to parallelize it
473
+ return True
474
+
475
+ def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
476
+ result = []
477
+
478
+ # Generate a timestamp every N seconds
479
+ start_timestamp = start_time
480
+
481
+ while (start_timestamp < end_time):
482
+ end_timestamp = min(start_timestamp + config.periodic_duration, end_time)
483
+ segment_duration = end_timestamp - start_timestamp
484
+
485
+ # Minimum duration is 1 second
486
+ if (segment_duration >= 1):
487
+ result.append( { 'start': start_timestamp, 'end': end_timestamp } )
488
+
489
+ start_timestamp = end_timestamp
490
+
491
+ return result
492
+
493
+ def get_audio_duration(file: str):
494
+ return float(ffmpeg.probe(file)["format"]["duration"])
495
+
496
+ def load_audio(file: str, sample_rate: int = 16000,
497
+ start_time: str = None, duration: str = None):
498
+ """
499
+ Open an audio file and read as mono waveform, resampling as necessary
500
+
501
+ Parameters
502
+ ----------
503
+ file: str
504
+ The audio file to open
505
+
506
+ sr: int
507
+ The sample rate to resample the audio if necessary
508
+
509
+ start_time: str
510
+ The start time, using the standard FFMPEG time duration syntax, or None to disable.
511
+
512
+ duration: str
513
+ The duration, using the standard FFMPEG time duration syntax, or None to disable.
514
+
515
+ Returns
516
+ -------
517
+ A NumPy array containing the audio waveform, in float32 dtype.
518
+ """
519
+ try:
520
+ inputArgs = {'threads': 0}
521
+
522
+ if (start_time is not None):
523
+ inputArgs['ss'] = start_time
524
+ if (duration is not None):
525
+ inputArgs['t'] = duration
526
+
527
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
528
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
529
+ out, _ = (
530
+ ffmpeg.input(file, **inputArgs)
531
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
532
+ .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
533
+ )
534
+ except ffmpeg.Error as e:
535
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
536
+
537
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
src/vadParallel.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ import threading
3
+ import time
4
+ from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
5
+ from src.whisperContainer import WhisperCallback
6
+
7
+ from multiprocessing import Pool
8
+
9
+ from typing import Any, Dict, List
10
+ import os
11
+
12
+
13
+ class ParallelContext:
14
+ def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
15
+ self.num_processes = num_processes
16
+ self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
17
+ self.lock = threading.Lock()
18
+
19
+ self.ref_count = 0
20
+ self.pool = None
21
+ self.cleanup_timer = None
22
+
23
+ def get_pool(self):
24
+ # Initialize pool lazily
25
+ if (self.pool is None):
26
+ context = multiprocessing.get_context('spawn')
27
+ self.pool = context.Pool(self.num_processes)
28
+
29
+ self.ref_count = self.ref_count + 1
30
+
31
+ if (self.auto_cleanup_timeout_seconds is not None):
32
+ self._stop_auto_cleanup()
33
+
34
+ return self.pool
35
+
36
+ def return_pool(self, pool):
37
+ if (self.pool == pool and self.ref_count > 0):
38
+ self.ref_count = self.ref_count - 1
39
+
40
+ if (self.ref_count == 0):
41
+ if (self.auto_cleanup_timeout_seconds is not None):
42
+ self._start_auto_cleanup()
43
+
44
+ def _start_auto_cleanup(self):
45
+ if (self.cleanup_timer is not None):
46
+ self.cleanup_timer.cancel()
47
+ self.cleanup_timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup)
48
+ self.cleanup_timer.start()
49
+
50
+ print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds")
51
+
52
+ def _stop_auto_cleanup(self):
53
+ if (self.cleanup_timer is not None):
54
+ self.cleanup_timer.cancel()
55
+ self.cleanup_timer = None
56
+
57
+ print("Stopped auto cleanup of pool")
58
+
59
+ def _execute_cleanup(self):
60
+ print("Executing cleanup of pool")
61
+
62
+ if (self.ref_count == 0):
63
+ self.close()
64
+
65
+ def close(self):
66
+ self._stop_auto_cleanup()
67
+
68
+ if (self.pool is not None):
69
+ print("Closing pool of " + str(self.num_processes) + " processes")
70
+ self.pool.close()
71
+ self.pool.join()
72
+ self.pool = None
73
+
74
+ class ParallelTranscriptionConfig(TranscriptionConfig):
75
+ def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
76
+ super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
77
+ self.device_id = device_id
78
+ self.override_timestamps = override_timestamps
79
+
80
+ class ParallelTranscription(AbstractTranscription):
81
+ # Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
82
+ # into smaller segments than 2 minute (min 6 seconds per CPU core)
83
+ MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60
84
+
85
+ def __init__(self, sampling_rate: int = 16000):
86
+ super().__init__(sampling_rate=sampling_rate)
87
+
88
+ def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
89
+ cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None):
90
+ total_duration = get_audio_duration(audio)
91
+
92
+ # First, get the timestamps for the original audio
93
+ if (cpu_device_count > 1 and not transcription.is_transcribe_timestamps_fast()):
94
+ merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
95
+ else:
96
+ timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
97
+ merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)
98
+
99
+ # We must make sure the whisper model is downloaded
100
+ if (len(gpu_devices) > 1):
101
+ whisperCallable.model_container.ensure_downloaded()
102
+
103
+ # Split into a list for each device
104
+ # TODO: Split by time instead of by number of chunks
105
+ merged_split = list(self._split(merged, len(gpu_devices)))
106
+
107
+ # Parameters that will be passed to the transcribe function
108
+ parameters = []
109
+ segment_index = config.initial_segment_index
110
+
111
+ for i in range(len(gpu_devices)):
112
+ # Note that device_segment_list can be empty. But we will still create a process for it,
113
+ # as otherwise we run the risk of assigning the same device to multiple processes.
114
+ device_segment_list = list(merged_split[i]) if i < len(merged_split) else []
115
+ device_id = gpu_devices[i]
116
+
117
+ print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
118
+
119
+ # Create a new config with the given device ID
120
+ device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
121
+ segment_index += len(device_segment_list)
122
+
123
+ parameters.append([audio, whisperCallable, device_config]);
124
+
125
+ merged = {
126
+ 'text': '',
127
+ 'segments': [],
128
+ 'language': None
129
+ }
130
+
131
+ created_context = False
132
+
133
+ perf_start_gpu = time.perf_counter()
134
+
135
+ # Spawn a separate process for each device
136
+ try:
137
+ if (gpu_parallel_context is None):
138
+ gpu_parallel_context = ParallelContext(len(gpu_devices))
139
+ created_context = True
140
+
141
+ # Get a pool of processes
142
+ pool = gpu_parallel_context.get_pool()
143
+
144
+ # Run the transcription in parallel
145
+ results = pool.starmap(self.transcribe, parameters)
146
+
147
+ for result in results:
148
+ # Merge the results
149
+ if (result['text'] is not None):
150
+ merged['text'] += result['text']
151
+ if (result['segments'] is not None):
152
+ merged['segments'].extend(result['segments'])
153
+ if (result['language'] is not None):
154
+ merged['language'] = result['language']
155
+
156
+ finally:
157
+ # Return the pool to the context
158
+ if (gpu_parallel_context is not None):
159
+ gpu_parallel_context.return_pool(pool)
160
+ # Always close the context if we created it
161
+ if (created_context):
162
+ gpu_parallel_context.close()
163
+
164
+ perf_end_gpu = time.perf_counter()
165
+ print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")
166
+
167
+ return merged
168
+
169
+ def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float,
170
+ cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
171
+ parameters = []
172
+
173
+ chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
174
+ chunk_start = 0
175
+ cpu_device_id = 0
176
+
177
+ perf_start_time = time.perf_counter()
178
+
179
+ # Create chunks that will be processed on the CPU
180
+ while (chunk_start < total_duration):
181
+ chunk_end = min(chunk_start + chunk_size, total_duration)
182
+
183
+ if (chunk_end - chunk_start < 1):
184
+ # No need to process chunks that are less than 1 second
185
+ break
186
+
187
+ print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " +
188
+ str(chunk_end) + " on CPU device " + str(cpu_device_id))
189
+ parameters.append([audio, config, chunk_start, chunk_end]);
190
+
191
+ cpu_device_id += 1
192
+ chunk_start = chunk_end
193
+
194
+ created_context = False
195
+
196
+ # Spawn a separate process for each device
197
+ try:
198
+ if (cpu_parallel_context is None):
199
+ cpu_parallel_context = ParallelContext(cpu_device_count)
200
+ created_context = True
201
+
202
+ # Get a pool of processes
203
+ pool = cpu_parallel_context.get_pool()
204
+
205
+ # Run the transcription in parallel. Note that transcription must be picklable.
206
+ results = pool.starmap(transcription.get_transcribe_timestamps, parameters)
207
+
208
+ timestamps = []
209
+
210
+ # Flatten the results
211
+ for result in results:
212
+ timestamps.extend(result)
213
+
214
+ merged = transcription.get_merged_timestamps(timestamps, config, total_duration)
215
+
216
+ perf_end_time = time.perf_counter()
217
+ print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))
218
+ return merged
219
+
220
+ finally:
221
+ # Return the pool to the context
222
+ if (cpu_parallel_context is not None):
223
+ cpu_parallel_context.return_pool(pool)
224
+ # Always close the context if we created it
225
+ if (created_context):
226
+ cpu_parallel_context.close()
227
+
228
+ def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
229
+ return []
230
+
231
+ def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
232
+ # Override timestamps that will be processed
233
+ if (config.override_timestamps is not None):
234
+ print("Using override timestamps of size " + str(len(config.override_timestamps)))
235
+ return config.override_timestamps
236
+ return super().get_merged_timestamps(timestamps, config, total_duration)
237
+
238
+ def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
239
+ # Override device ID the first time
240
+ if (os.environ.get("INITIALIZED", None) is None):
241
+ os.environ["INITIALIZED"] = "1"
242
+
243
+ # Note that this may be None if the user didn't specify a device. In that case, Whisper will
244
+ # just use the default GPU device.
245
+ if (config.device_id is not None):
246
+ print("Using device " + config.device_id)
247
+ os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
248
+
249
+ return super().transcribe(audio, whisperCallable, config)
250
+
251
+ def _split(self, a, n):
252
+ """Split a list into n approximately equal parts."""
253
+ k, m = divmod(len(a), n)
254
+ return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
255
+
src/whisperContainer.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # External programs
2
+ import os
3
+ import whisper
4
+
5
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
6
+
7
+ class WhisperContainer:
8
+ def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
9
+ self.model_name = model_name
10
+ self.device = device
11
+ self.download_root = download_root
12
+ self.cache = cache
13
+
14
+ # Will be created on demand
15
+ self.model = None
16
+
17
+ def get_model(self):
18
+ if self.model is None:
19
+
20
+ if (self.cache is None):
21
+ self.model = self._create_model()
22
+ else:
23
+ model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
24
+ self.model = self.cache.get(model_key, self._create_model)
25
+ return self.model
26
+
27
+ def ensure_downloaded(self):
28
+ """
29
+ Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
30
+ passing the container to a subprocess.
31
+ """
32
+ # Warning: Using private API here
33
+ try:
34
+ root_dir = self.download_root
35
+
36
+ if root_dir is None:
37
+ root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
38
+
39
+ if self.model_name in whisper._MODELS:
40
+ whisper._download(whisper._MODELS[self.model_name], root_dir, False)
41
+ return True
42
+ except Exception as e:
43
+ # Given that the API is private, it could change at any time. We don't want to crash the program
44
+ print("Error pre-downloading model: " + str(e))
45
+ return False
46
+
47
+ def _create_model(self):
48
+ print("Loading whisper model " + self.model_name)
49
+ return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
50
+
51
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
52
+ """
53
+ Create a WhisperCallback object that can be used to transcript audio files.
54
+
55
+ Parameters
56
+ ----------
57
+ language: str
58
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
59
+ task: str
60
+ The task - either translate or transcribe.
61
+ initial_prompt: str
62
+ The initial prompt to use for the transcription.
63
+ decodeOptions: dict
64
+ Additional options to pass to the decoder. Must be pickleable.
65
+
66
+ Returns
67
+ -------
68
+ A WhisperCallback object.
69
+ """
70
+ return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
71
+
72
+ # This is required for multiprocessing
73
+ def __getstate__(self):
74
+ return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }
75
+
76
+ def __setstate__(self, state):
77
+ self.model_name = state["model_name"]
78
+ self.device = state["device"]
79
+ self.download_root = state["download_root"]
80
+ self.model = None
81
+ # Depickled objects must use the global cache
82
+ self.cache = GLOBAL_MODEL_CACHE
83
+
84
+
85
+ class WhisperCallback:
86
+ def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
87
+ self.model_container = model_container
88
+ self.language = language
89
+ self.task = task
90
+ self.initial_prompt = initial_prompt
91
+ self.decodeOptions = decodeOptions
92
+
93
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
94
+ """
95
+ Peform the transcription of the given audio file or data.
96
+
97
+ Parameters
98
+ ----------
99
+ audio: Union[str, np.ndarray, torch.Tensor]
100
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
101
+ segment_index: int
102
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
103
+ task: str
104
+ The task - either translate or transcribe.
105
+ prompt: str
106
+ The prompt to use for the transcription.
107
+ detected_language: str
108
+ The detected language of the audio file.
109
+
110
+ Returns
111
+ -------
112
+ The result of the Whisper call.
113
+ """
114
+ model = self.model_container.get_model()
115
+
116
+ return model.transcribe(audio, \
117
+ language=self.language if self.language else detected_language, task=self.task, \
118
+ initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
119
+ **self.decodeOptions)
120
+
121
+ def _concat_prompt(self, prompt1, prompt2):
122
+ if (prompt1 is None):
123
+ return prompt2
124
+ elif (prompt2 is None):
125
+ return prompt1
126
+ else:
127
+ return prompt1 + " " + prompt2
tests/segments_test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import unittest
3
+
4
+ sys.path.append('../whisper-webui')
5
+
6
+ from src.segments import merge_timestamps
7
+
8
+ class TestSegments(unittest.TestCase):
9
+ def __init__(self, *args, **kwargs):
10
+ super(TestSegments, self).__init__(*args, **kwargs)
11
+
12
+ def test_merge_segments(self):
13
+ segments = [
14
+ {'start': 10.0, 'end': 20.0},
15
+ {'start': 22.0, 'end': 27.0},
16
+ {'start': 31.0, 'end': 35.0},
17
+ {'start': 45.0, 'end': 60.0},
18
+ {'start': 61.0, 'end': 65.0},
19
+ {'start': 68.0, 'end': 98.0},
20
+ {'start': 100.0, 'end': 102.0},
21
+ {'start': 110.0, 'end': 112.0}
22
+ ]
23
+
24
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
25
+
26
+ self.assertListEqual(result, [
27
+ {'start': 9.0, 'end': 36.0},
28
+ {'start': 44.0, 'end': 66.0},
29
+ {'start': 67.0, 'end': 99.0},
30
+ {'start': 99.0, 'end': 103.0},
31
+ {'start': 109.0, 'end': 113.0}
32
+ ])
33
+
34
+ def test_overlap_next(self):
35
+ segments = [
36
+ {'start': 5.0, 'end': 39.182},
37
+ {'start': 39.986, 'end': 40.814}
38
+ ]
39
+
40
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
41
+
42
+ self.assertListEqual(result, [
43
+ {'start': 4.0, 'end': 39.584},
44
+ {'start': 39.584, 'end': 41.814}
45
+ ])
46
+
47
+ if __name__ == '__main__':
48
+ unittest.main()
tests/vad_test.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ import unittest
3
+ import numpy as np
4
+ import sys
5
+
6
+ sys.path.append('../whisper-webui')
7
+
8
+ from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
9
+
10
+ class TestVad(unittest.TestCase):
11
+ def __init__(self, *args, **kwargs):
12
+ super(TestVad, self).__init__(*args, **kwargs)
13
+ self.transcribe_calls = []
14
+
15
+ def test_transcript(self):
16
+ mock = MockVadTranscription()
17
+
18
+ self.transcribe_calls.clear()
19
+ result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
20
+
21
+ self.assertListEqual(self.transcribe_calls, [
22
+ [30, 30],
23
+ [100, 100]
24
+ ])
25
+
26
+ self.assertListEqual(result['segments'],
27
+ [{'end': 50.0, 'start': 40.0, 'text': 'Hello world '},
28
+ {'end': 120.0, 'start': 110.0, 'text': 'Hello world '}]
29
+ )
30
+
31
+ def transcribe_segments(self, segment):
32
+ self.transcribe_calls.append(segment.tolist())
33
+
34
+ # Dummy text
35
+ return {
36
+ 'text': "Hello world ",
37
+ 'segments': [
38
+ {
39
+ "start": 10.0,
40
+ "end": 20.0,
41
+ "text": "Hello world "
42
+ }
43
+ ],
44
+ 'language': ""
45
+ }
46
+
47
+ class MockVadTranscription(AbstractTranscription):
48
+ def __init__(self):
49
+ super().__init__()
50
+
51
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
52
+ start_time_seconds = float(start_time.removesuffix("s"))
53
+ duration_seconds = float(duration.removesuffix("s"))
54
+
55
+ # For mocking, this just returns a simple numppy array
56
+ return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
57
+
58
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, duration: float):
59
+ result = []
60
+
61
+ result.append( { 'start': 30, 'end': 60 } )
62
+ result.append( { 'start': 100, 'end': 200 } )
63
+ return result
64
+
65
+ if __name__ == '__main__':
66
+ unittest.main()