vitaliy-sharandin committed
Commit e0f5494
1 Parent(s): b6f9245

Update app.py

Files changed (1)
  1. app.py +287 -36
app.py CHANGED
@@ -1,37 +1,288 @@
 
 import gradio as gr
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
-
-model = AutoModelForCausalLM.from_pretrained(
-    'vitaliy-sharandin/wiseai',
-    load_in_8bit=True,
-    device_map = {"": 0}
-)
-tokenizer = AutoTokenizer.from_pretrained('vitaliy-sharandin/wiseai')
-
-pipe = pipeline('text-generation', model=model,tokenizer=tokenizer)
-
-def generate_text(instruction, input):
-    if not instruction.strip():
-        return str('The instruction field is required.')
-
-    if instruction.strip() and input.strip():
-        input_prompt = (f"Below is an instruction that describes a task. "
-                        "Write a response that appropriately completes the request.\n\n"
-                        "### Instruction:\n"
-                        f"{instruction}\n\n"
-                        "### Input:\n"
-                        f"{input}\n\n"
-                        f"### Response: \n")
-    else :
-        input_prompt = (f"Below is an instruction that describes a task. "
-                        "Write a response that appropriately completes the request.\n\n"
-                        "### Instruction:\n"
-                        f"{instruction}\n\n"
-                        f"### Response: \n")
-    result = pipe(input_prompt, max_length=200, top_p=0.9, temperature=0.9, num_return_sequences=1, return_full_text=False)[0]['generated_text']
-    return result[:str(result).find("###")]
-
-iface = gr.Interface(fn=generate_text, inputs=[gr.Textbox(label="Instruction"),
-                                               gr.Textbox(label="Additional Input")],
-                     outputs=gr.Textbox(label="Response"))
-iface.launch()
+import os
 import gradio as gr
+import whisperx
+import numpy as np
+import moviepy.editor as mp
+from moviepy.audio.AudioClip import AudioArrayClip
+from pytube import YouTube
+import deepl
+import torch
+import pyrubberband as pyrb
+import soundfile as sf
+import librosa
+from TTS.api import TTS
+
+os.environ["COQUI_TOS_AGREED"] = "1"
+HF_TOKEN = os.environ["HF_TOKEN"]
+DEEPL_TOKEN = os.environ["DEEPL_TOKEN"]
+
+# Download video from Youtube
+def download_youtube_video(url):
+    yt = YouTube(url)
+    stream = yt.streams.filter(file_extension='mp4').first()
+    output_path = stream.download()
+    return output_path
+
+
+# Extract audio from video
+def extract_audio(video_path):
+    clip = mp.VideoFileClip(video_path)
+    audio_path = os.path.splitext(video_path)[0] + ".wav"
+    clip.audio.write_audiofile(audio_path)
+    return audio_path
+
+
+# Perform speech diarization
+def speech_diarization(audio_path, hf_token):
+    device = "cuda"
+    batch_size = 16
+    compute_type = "float16"
+    model = whisperx.load_model("large-v2", device, compute_type=compute_type)
+
+    # 1. Transcribe audio
+    audio = whisperx.load_audio(audio_path)
+    result = model.transcribe(audio, batch_size=batch_size)
+
+    # delete model if low on GPU resources
+    import gc; gc.collect(); torch.cuda.empty_cache(); del model
+
+    # 2. Align whisper output
+    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
+
+    # delete model if low on GPU resources
+    import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
+
+    # 3. Assign speaker labels
+    diarize_model = whisperx.DiarizationPipeline(model_name='pyannote/speaker-diarization@2.1', use_auth_token=hf_token, device=device)
+
+    # add min/max number of speakers if known
+    diarize_segments = diarize_model(audio)
+    # diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
+
+    result = whisperx.assign_word_speakers(diarize_segments, result)
+    print(f'\n[Original transcript]:\n{result["segments"]}\n')
+
+    return result["segments"]
+
+
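For reference, each entry in the diarized `result["segments"]` list above is consumed by the later steps through its `start`, `end`, `text` and `speaker` keys. A rough illustration of that shape (invented values, other whisperx keys omitted):

```python
# Illustrative diarized segment (hypothetical values); only the keys that the
# translation, voice-clip and dubbing steps actually read are shown.
example_segment = {
    "start": 0.52,                # segment start, in seconds
    "end": 3.80,                  # segment end, in seconds
    "text": " Hello and welcome to the show.",
    "speaker": "SPEAKER_00",      # label assigned by the pyannote diarization
}
```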
+# Create per speaker voice clips for tts voice cloning
+def speaker_voice_clips(transcription, audio_path):
+    # Create 3 uninterrupted per speaker timecodes
+    snippets_timecodes = {}
+    for segment in transcription:
+        speaker = segment['speaker']
+
+        if speaker not in snippets_timecodes:
+            snippets_timecodes[speaker] = []
+
+        if len(snippets_timecodes[speaker]) < 3:
+            snippet = {
+                'start': segment['start'],
+                'end': segment['end']
+            }
+            snippets_timecodes[speaker].append(snippet)
+
+    # Cut voice clips and stitch them together
+    original_audio = mp.AudioFileClip(audio_path)
+    audio_file_directory = os.path.dirname(audio_path)
+
+    voice_clips = {}
+    for speaker, speaker_snippets in snippets_timecodes.items():
+        subclips = []
+        for snippet in speaker_snippets:
+            start, end = snippet['start'], snippet['end']
+            subclip = original_audio.subclip(start, end)
+            subclips.append(subclip)
+
+        concatenated_clip = mp.concatenate_audioclips(subclips)
+
+        output_filename = os.path.join(audio_file_directory, f"{speaker}_voice_clips.wav")
+        concatenated_clip.write_audiofile(output_filename)
+        voice_clips[speaker] = output_filename
+
+    return voice_clips
+
+
+# Perform text translation
+def translate_transcript(transcript, target_language, deepl_token):
+    translator = deepl.Translator(deepl_token)
+
+    translated_transcript = []
+    for segment in transcript:
+        text_to_translate = segment['text']
+        translated_text = translator.translate_text(text_to_translate, target_lang=target_language)
+
+        translated_segment = {
+            'start': segment['start'],
+            'end': segment['end'],
+            'text': translated_text.text,
+            'speaker': segment['speaker']
+        }
+
+        translated_transcript.append(translated_segment)
+
+    print(f'\n[Translated transcript]:\n{translated_transcript}\n')
+
+    return translated_transcript
+
+
+# Adjust voice pace
+def adjust_voice_pace(sound_array, sample_rate, target_duration):
+    duration = len(sound_array) / sample_rate
+    tempo_change = duration / target_duration
+    sound_array_stretched = pyrb.time_stretch(sound_array, sample_rate, tempo_change)
+    return sound_array_stretched
+
+
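As a sanity check on the ratio used above: `pyrubberband.time_stretch` speeds audio up when the rate is greater than 1, so a clip that runs longer than its slot comes back shorter. A minimal sketch with hypothetical numbers (pyrubberband shells out to the `rubberband` command-line tool, which must be installed):

```python
import numpy as np
import pyrubberband as pyrb

# A 3 s clip has to fit a 2 s slot, so the rate is duration / target_duration
# = 3 / 2 = 1.5; time_stretch with rate > 1 returns a shorter, faster signal.
sample_rate = 22050
clip = np.zeros(3 * sample_rate)                       # 3 s stand-in signal
stretched = pyrb.time_stretch(clip, sample_rate, 1.5)
print(len(stretched) / sample_rate)                    # ~2.0 seconds
```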
+# Perform voice cloning
+def voice_cloning_translation(translated_transcription, speakers_voice_clips, target_language, speaker_model, audio_path):
+    device = "cuda"
+
+    vits_language_map = {
+        'en':'eng',
+        'ru':'rus',
+        'uk':'ukr',
+        'pl':'pol'
+    }
+
+    # Select model
+    selected_model = None
+
+    if 'vits' in speaker_model.lower() or target_language == 'uk':
+        selected_model = f'tts_models/{vits_language_map[target_language]}/fairseq/vits'
+    else:
+        selected_model = 'tts_models/multilingual/multi-dataset/xtts_v2'
+
+    print(selected_model)
+
+
+    tts = None
+    final_audio_track = None
+
+    try:
+        # TODO uncomment when https://github.com/coqui-ai/TTS/issues/3224 is resolved
+        # tts = TTS(selected_model).to(device)
+
+        # Generate and concatenate voice clips per speaker
+
+        last_end_time = 0
+        clips = []
+
+        # Generate sentences
+        for speech_item in translated_transcription:
+
+            speech_item_duration = speech_item['end'] - speech_item['start']
+
+            # Silence
+            gap_duration = speech_item['start'] - last_end_time
+            if gap_duration > 0:
+                silent_audio = np.zeros((int(44100 * gap_duration), 2))
+                silent_clip = AudioArrayClip(silent_audio, fps=44100)
+                clips.append(silent_clip)
+                print(f"\nAdded silence: Start={last_end_time}, Duration={gap_duration}")
+
+            # Generate speech
+            print(f"[{speech_item['speaker']}]")
+            tts = TTS(selected_model).to(device)
+            audio = tts.tts_with_vc(text=speech_item['text'], speaker_wav=speakers_voice_clips[speech_item['speaker']], language=target_language)
+            sample_rate = tts.voice_converter.vc_config.audio.output_sample_rate
+
+            # Adjust pace to fit the speech timeframe if translated audio is longer than phrase
+            audio_duration = len(audio) / sample_rate
+            if speech_item_duration < audio_duration:
+                audio = adjust_voice_pace(audio, sample_rate, speech_item_duration)
+
+            # Resample to higher rate
+            new_sample_rate = 44100
+            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=new_sample_rate)
+
+            # Transform to AudioArrayClip object
+            audio = np.expand_dims(audio, axis=1)
+            audio_stereo = np.repeat(audio, 2, axis=1)
+            audio_clip = AudioArrayClip(audio_stereo, fps=44100)
+
+            # Cut out possible glitch from AudioArrayClip end
+            audio_clip = audio_clip.subclip(0, audio_clip.duration - 0.2)
+            clips.append(audio_clip)
+            print(f"Added speech: Start={speech_item['start']}, Final duration={audio_clip.duration}, Original duration={speech_item_duration}")
+
+            last_end_time = speech_item['start'] + audio_clip.duration
+
+            del tts; import gc; gc.collect(); torch.cuda.empty_cache()
+
+        # Merge sentences
+        final_audio_track = mp.concatenate_audioclips(clips)
+
+        audio_files_directory = os.path.dirname(audio_path)
+        final_audio_track.write_audiofile(os.path.join(audio_files_directory, "translated_voice_track.wav"), fps=44100)
+
+    except Exception as e:
+        if tts is not None:
+            import gc; gc.collect(); torch.cuda.empty_cache(); del tts
+        raise e
+
+    return final_audio_track
+
+
+def dub_video(video_path, translated_audio_track, target_language):
+    video = mp.VideoFileClip(video_path)
+    video = video.subclip(0, translated_audio_track.duration)
+    original_audio = video.audio.volumex(0.2)
+    dubbed_audio = mp.CompositeAudioClip([original_audio, translated_audio_track.set_start(0)])
+    video_with_dubbing = video.set_audio(dubbed_audio)
+
+    video_with_dubbing_path = os.path.splitext(video_path)[0] + "_" + target_language + ".mp4"
+    video_with_dubbing.write_videofile(video_with_dubbing_path)
+
+    return video_with_dubbing_path
+
+
+# Perform video translation
+def video_translation(video_path, target_language, speaker_model, hf_token, deepl_token):
+
+    original_audio_path = extract_audio(video_path)
+
+    transcription = speech_diarization(original_audio_path, hf_token)
+
+    translated_transcription = translate_transcript(transcription, target_language, deepl_token)
+
+    speakers_voice_clips = speaker_voice_clips(transcription, original_audio_path)
+
+    translated_audio_track = voice_cloning_translation(translated_transcription, speakers_voice_clips, target_language, speaker_model, original_audio_path)
+
+    video_with_dubbing = dub_video(video_path, translated_audio_track, target_language)
+
+    return video_with_dubbing
+
+def translate_video(_, video_path, __, youtube_link, ___, target_language, speaker_model):
+    try:
+        if not video_path and not youtube_link:
+            gr.Warning("You should either upload a video or paste a YouTube link")
+            return None
+        if youtube_link:
+            video_path = download_youtube_video(youtube_link)
+        dubbed_video = video_translation(video_path, target_language, speaker_model, HF_TOKEN, DEEPL_TOKEN)
+    except Exception as e:
+        print(f"An error occurred: {e}")
+        return None
+    return gr.components.Video(dubbed_video)
+
+
+inputs = [
+    gr.Markdown("## Currently supported languages are: English, Polish, Ukrainian and Russian"),
+    gr.Video(label="Upload a video file"),
+    gr.Markdown("**OR**"),
+    gr.Textbox(label="Paste YouTube link"),
+    gr.Markdown("---"),
+    gr.Dropdown(["en", "pl", "uk", "ru"], value="pl", label="Select translation target language"),
+    gr.Dropdown(["(Recommended) XTTS_V2", "VITs (will be default for Ukrainian)"], value="(Recommended) XTTS_V2", label="Select text-to-speech generation model")
+]
+
+outputs = gr.Video(label="Translated video")
+
+gr.Interface(fn=translate_video,
+             inputs=inputs,
+             outputs=outputs,
+             title="🌐AI Video Translation",
+             theme=gr.themes.Base()
+             ).launch(show_error=True, debug=True)
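For local experimentation outside the Gradio UI, the pipeline can also be driven directly through `video_translation`. A minimal sketch, assuming a CUDA GPU (the diarization and TTS steps hard-code `device = "cuda"`), valid `HF_TOKEN` and `DEEPL_TOKEN` environment variables, and a hypothetical local file `input.mp4`:

```python
import os

os.environ["COQUI_TOS_AGREED"] = "1"            # same Coqui TOS flag the app sets

dubbed_path = video_translation(
    video_path="input.mp4",                     # hypothetical local video
    target_language="pl",                       # one of: en, pl, uk, ru
    speaker_model="(Recommended) XTTS_V2",      # or the VITs dropdown option
    hf_token=os.environ["HF_TOKEN"],
    deepl_token=os.environ["DEEPL_TOKEN"],
)
print(dubbed_path)                              # e.g. "input_pl.mp4"
```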