Spaces:
Runtime error
Runtime error
Arnaudding001
committed on
Commit
•
039d7b8
1
Parent(s):
5a2469e
Update app.py
Browse files
app.py
CHANGED
@@ -1,88 +1,254 @@
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
-
import
|
|
|
|
|
|
|
3 |
import whisper
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
)
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Iterator
|
2 |
+
|
3 |
+
from io import StringIO
|
4 |
import os
|
5 |
+
import pathlib
|
6 |
+
import tempfile
|
7 |
+
|
8 |
+
# External programs
|
9 |
import whisper
|
10 |
+
import ffmpeg
|
11 |
+
|
12 |
+
# UI
|
13 |
+
import gradio as gr
|
14 |
+
|
15 |
+
from src.download import ExceededMaximumDuration, download_url
|
16 |
+
from src.utils import slugify, write_srt, write_vtt
|
17 |
+
from src.vad import NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
18 |
+
|
19 |
+
# Upper bound on accepted audio length, in seconds (set to -1 to disable the limit)
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600

# Delete every uploaded file once it has been processed, to save disk space
DELETE_UPLOADED_FILES = True

# Gradio seems to truncate file names without keeping the extension, so the
# file prefix is truncated here to this many characters instead
MAX_FILE_PREFIX_LENGTH = 17

# Language names offered in the UI "Language" dropdown
LANGUAGES = [
    "English", "Chinese", "German", "Spanish", "Russian",
    "Korean", "French", "Japanese", "Portuguese", "Turkish",
    "Polish", "Catalan", "Dutch", "Arabic", "Swedish",
    "Italian", "Indonesian", "Hindi", "Finnish", "Vietnamese",
    "Hebrew", "Ukrainian", "Greek", "Malay", "Czech",
    "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
    "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian",
    "Latin", "Maori", "Malayalam", "Welsh", "Slovak",
    "Telugu", "Persian", "Latvian", "Bengali", "Serbian",
    "Azerbaijani", "Slovenian", "Kannada", "Estonian", "Macedonian",
    "Breton", "Basque", "Icelandic", "Armenian", "Nepali",
    "Mongolian", "Bosnian", "Kazakh", "Albanian", "Swahili",
    "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
    "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan",
    "Georgian", "Belarusian", "Tajik", "Sindhi", "Gujarati",
    "Amharic", "Yiddish", "Lao", "Uzbek", "Faroese",
    "Haitian Creole", "Pashto", "Turkmen", "Nynorsk", "Maltese",
    "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan", "Tagalog",
    "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
    "Hausa", "Bashkir", "Javanese", "Sundanese",
]
|
47 |
+
|
48 |
+
class WhisperTranscriber:
    """Transcribe audio with Whisper for the web UI, optionally through a VAD.

    Loaded Whisper models are cached per model name, and the Silero VAD wrapper
    is created lazily the first time a Silero strategy is requested.
    """

    def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
        """Create a transcriber.

        Args:
            inputAudioMaxDuration: Maximum accepted audio length in seconds.
                The local duration check is skipped when this is <= 0.
            deleteUploadedFiles: Whether the source file is removed after a
                request has been handled.
        """
        # Whisper models keyed by model name (filled by transcribe_webui)
        self.model_cache = dict()

        # Created on demand by _create_silero_config
        self.vad_model = None
        self.inputAudioMaxDuration = inputAudioMaxDuration
        self.deleteUploadedFiles = deleteUploadedFiles

    def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
        """Handle one request from the web UI.

        Resolves the audio source (URL, upload or microphone), runs the
        transcription and writes the subtitle/transcript files to a fresh
        temporary directory.

        Returns:
            A tuple ``(output file paths, transcript text, VTT text)``. When
            the audio exceeds the maximum duration, ``([], error message,
            "[ERROR]")`` is returned instead of raising.
        """
        try:
            source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)

            try:
                # The language dropdown has no default value, so it may
                # deliver None as well as an empty string.
                selectedLanguage = languageName.lower() if languageName else None
                selectedModel = modelName if modelName is not None else "base"

                model = self.model_cache.get(selectedModel, None)

                if model is None:
                    model = whisper.load_model(selectedModel)
                    self.model_cache[selectedModel] = model

                # Execute whisper
                result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)

                # Write result
                downloadDirectory = tempfile.mkdtemp()

                filePrefix = slugify(sourceName, allow_unicode=True)
                download, text, vtt = self.write_result(result, filePrefix, downloadDirectory)

                return download, text, vtt

            finally:
                # Cleanup source, whether or not the transcription succeeded
                self.__remove_source(source)

        except ExceededMaximumDuration as e:
            return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"

    def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None,
                        vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
        """Transcribe a single audio file, optionally splitting it with a VAD.

        Args:
            model: The loaded Whisper model.
            audio_path: Path of the audio file.
            language: Language passed to Whisper; when falsy, the language
                detected by the VAD wrapper (if any) is used instead.
            task: Whisper task; overridden by a ``task`` entry in
                ``decodeOptions``.
            vad: One of ``'silero-vad'``, ``'silero-vad-skip-gaps'``,
                ``'silero-vad-expand-into-gaps'`` or ``'periodic-vad'``. Any
                other value transcribes the whole file in one Whisper call.
            vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow:
                Forwarded to the VAD configuration.
            **decodeOptions: Extra options forwarded to ``model.transcribe``.
                ``initial_prompt`` is only applied to the first segment.

        Returns:
            The result returned by Whisper or by the VAD wrapper.
        """
        initial_prompt = decodeOptions.pop('initial_prompt', None)

        if 'task' in decodeOptions:
            task = decodeOptions.pop('task')

        # Callable for processing an audio file (or one segment of it)
        def whisperCallable(audio, segment_index, prompt, detected_language):
            return model.transcribe(
                audio,
                language=language if language else detected_language, task=task,
                # Only the first segment is seeded with the user's initial prompt
                initial_prompt=self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt,
                **decodeOptions)

        # How each Silero variant treats the non-speech gaps between segments
        sileroStrategies = {
            # Non-speech gaps are transcribed
            'silero-vad': NonSpeechStrategy.CREATE_SEGMENT,
            # Non-speech gaps are simply ignored
            'silero-vad-skip-gaps': NonSpeechStrategy.SKIP,
            # Speech segments are expanded into non-speech gaps
            'silero-vad-expand-into-gaps': NonSpeechStrategy.EXPAND_SEGMENT,
        }

        if vad in sileroStrategies:
            # _create_silero_config also makes sure self.vad_model exists
            config = self._create_silero_config(sileroStrategies[vad], vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
            result = self.vad_model.transcribe(audio_path, whisperCallable, config)
        elif vad == 'periodic-vad':
            # Very simple VAD - mark every period of vadMaxMergeSize as speech. This makes it less likely that
            # Whisper enters an infinite loop, but it may create a break in the middle of a sentence, causing
            # some artifacts.
            periodic_vad = VadPeriodicTranscription()
            result = periodic_vad.transcribe(audio_path, whisperCallable, PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow))
        else:
            # No VAD: transcribe the whole file in one go
            result = whisperCallable(audio_path, 0, None, None)

        return result

    def _concat_prompt(self, prompt1, prompt2):
        """Join two optional prompts with a space; a None prompt is ignored."""
        if prompt1 is None:
            return prompt2
        if prompt2 is None:
            return prompt1
        return prompt1 + " " + prompt2

    def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
        """Build a TranscriptionConfig for Silero VAD, creating the VAD model if needed.

        The same padding is applied on the left and the right of each segment.
        """
        # Use Silero VAD
        if self.vad_model is None:
            self.vad_model = VadSileroTranscription()

        return TranscriptionConfig(non_speech_strategy=non_speech_strategy,
                                   max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
                                   segment_padding_left=vadPadding, segment_padding_right=vadPadding,
                                   max_prompt_window=vadPromptWindow)

    def write_result(self, result: dict, source_name: str, output_dir: str):
        """Write the SRT, VTT and plain-text outputs for a transcription result.

        Args:
            result: Transcription result; its ``"text"``, ``"language"`` and
                ``"segments"`` entries are read.
            source_name: Prefix used for the generated file names.
            output_dir: Target directory, created when missing.

        Returns:
            A tuple ``(output file paths, transcript text, VTT text)``.
        """
        # exist_ok avoids the race between an existence check and the creation
        os.makedirs(output_dir, exist_ok=True)

        text = result["text"]
        language = result["language"]
        languageMaxLineWidth = self.__get_max_line_width(language)

        print("Max line width " + str(languageMaxLineWidth))
        vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
        srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)

        output_files = [
            self.__create_file(srt, output_dir, source_name + "-subs.srt"),
            self.__create_file(vtt, output_dir, source_name + "-subs.vtt"),
            self.__create_file(text, output_dir, source_name + "-transcript.txt"),
        ]

        return output_files, text, vtt

    def clear_cache(self):
        """Drop all cached Whisper models and the VAD model."""
        self.model_cache = dict()
        self.vad_model = None

    def __get_source(self, urlData, uploadFile, microphoneData):
        """Resolve the audio source of a request.

        A URL takes precedence over an uploaded file, which takes precedence
        over microphone data.

        Returns:
            A tuple ``(source path, source name)`` where the name is the file
            stem truncated to MAX_FILE_PREFIX_LENGTH plus the file suffix.

        Raises:
            ValueError: If no URL, upload or microphone recording was given.
            ExceededMaximumDuration: If the audio is longer than
                ``inputAudioMaxDuration``.
        """
        if urlData:
            # Download from YouTube
            source = download_url(urlData, self.inputAudioMaxDuration)[0]
        else:
            # File input
            source = uploadFile if uploadFile is not None else microphoneData

            if not source:
                # Fail early with a readable message instead of an obscure
                # error from ffmpeg/pathlib further down
                raise ValueError("No audio source provided: supply a URL, an uploaded file or a microphone recording")

        if self.inputAudioMaxDuration > 0:
            # Calculate audio length
            audioDuration = ffmpeg.probe(source)["format"]["duration"]

            if float(audioDuration) > self.inputAudioMaxDuration:
                # A rejected file never reaches the caller's cleanup, so it is removed here
                self.__remove_source(source)
                raise ExceededMaximumDuration(videoDuration=audioDuration, maxDuration=self.inputAudioMaxDuration, message="Video is too long")

        file_path = pathlib.Path(source)
        sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix

        return source, sourceName

    def __remove_source(self, source: str):
        """Delete the source file when deleteUploadedFiles is enabled."""
        if self.deleteUploadedFiles:
            print("Deleting source file " + source)
            os.remove(source)

    def __get_max_line_width(self, language: str) -> int:
        """Return the subtitle line width (in characters) for a language."""
        if language and language.lower() in ["japanese", "ja", "chinese", "zh"]:
            # Chinese characters and kana are wider, so limit line length to 40 characters
            return 40

        # TODO: Add more languages
        # 80 latin characters should fit on a 1080p/720p screen
        return 80

    def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
        """Render the segments as subtitle text.

        Args:
            segments: The transcription segments.
            format: Either ``'vtt'`` or ``'srt'``.
            maxLineWidth: Maximum line width passed to the writer.

        Raises:
            Exception: If ``format`` is neither ``'vtt'`` nor ``'srt'``.
        """
        segmentStream = StringIO()

        if format == 'vtt':
            write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        elif format == 'srt':
            write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        else:
            raise Exception("Unknown format " + format)

        return segmentStream.getvalue()

    def __create_file(self, text: str, directory: str, fileName: str) -> str:
        """Write text to ``directory/fileName`` as UTF-8 and return the file path."""
        filePath = os.path.join(directory, fileName)

        with open(filePath, 'w', encoding="utf-8") as file:
            file.write(text)

        return filePath
|
219 |
+
|
220 |
+
|
221 |
+
def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
    """Build the Gradio interface around a WhisperTranscriber and launch it.

    Args:
        inputAudioMaxDuration: Maximum audio length in seconds, passed to the
            transcriber; shown in the description when greater than 0.
        share: Forwarded to ``demo.launch``.
        server_name: Forwarded to ``demo.launch``.
    """
    transcriber = WhisperTranscriber(inputAudioMaxDuration)

    # The description is assembled from parts and joined once at the end
    description_parts = [
        "Whisper是一个语音转文字模型,经过多个语音数据集的训练而成。也可以进行多语言的识别任务和翻译(多种语言翻译成英文)",
        "\n\n\n\n对于时长大于10分钟的非英语音频文件,建议选择VAD选项中的Silero VAD (语音活动检测器)。",
    ]
    if inputAudioMaxDuration > 0:
        description_parts.append("\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s")
    ui_description = "".join(description_parts)

    # Order matters: it must match the parameter order of transcribe_webui
    input_components = [
        gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
        gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
        gr.Text(label="URL (YouTube, etc.)"),
        gr.Audio(source="upload", type="filepath", label="Upload Audio"),
        gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
        gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], label="VAD"),
        gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
        gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
        gr.Number(label="VAD - Padding (s)", precision=None, value=1),
        gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3),
    ]

    # One component per element of the tuple returned by transcribe_webui
    output_components = [
        gr.File(label="Download"),
        gr.Text(label="Transcription"),
        gr.Text(label="Segments"),
    ]

    demo = gr.Interface(
        fn=transcriber.transcribe_webui,
        description=ui_description,
        inputs=input_components,
        outputs=output_components,
    )

    demo.launch(share=share, server_name=server_name)
|
252 |
+
|
253 |
+
if __name__ == "__main__":
    # Run as a script: start the web UI with the default duration limit
    create_ui(inputAudioMaxDuration=DEFAULT_INPUT_AUDIO_MAX_DURATION)
|