import gradio as gr
import moviepy.editor as mp
from moviepy.video.tools.subtitles import SubtitlesClip
from datetime import timedelta
import os
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    MarianMTModel,
    MarianTokenizer,
    pipeline
)
import torch
import numpy as np
from pydub import AudioSegment
import spaces
# Dictionary of supported languages and their codes for MarianMT
LANGUAGE_CODES = {
    "English": "en",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
    "Italian": "it",
    "Portuguese": "pt",
    "Russian": "ru",
    "Chinese": "zh",
    "Japanese": "ja",
    "Korean": "ko"
}
def get_model_name(source_lang, target_lang):
"""Get MarianMT model name for language pair"""
return f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
def format_timestamp(seconds):
"""Convert seconds to SRT timestamp format"""
td = timedelta(seconds=seconds)
hours = td.seconds//3600
minutes = (td.seconds//60)%60
seconds = td.seconds%60
milliseconds = td.microseconds//1000
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def translate_text(text, source_lang, target_lang):
"""Translate text using MarianMT"""
if source_lang == target_lang:
return text
try:
model_name = get_model_name(source_lang, target_lang)
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
translated = model.generate(**inputs)
translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
return translated_text
except Exception as e:
print(f"Translation error: {e}")
return text
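
# Note: translate_text reloads the MarianMT tokenizer and model on every call.
# If that becomes a bottleneck, one option (a sketch, not wired in here) is a
# module-level cache keyed by model name:
#     _mt_cache = {}
#     def _load_mt(model_name):
#         if model_name not in _mt_cache:
#             _mt_cache[model_name] = (MarianTokenizer.from_pretrained(model_name),
#                                      MarianMTModel.from_pretrained(model_name))
#         return _mt_cache[model_name]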
def load_audio(video_path):
"""Extract and load audio from video file"""
video = mp.VideoFileClip(video_path)
temp_audio_path = "temp_audio.wav"
video.audio.write_audiofile(temp_audio_path)
# Load audio using pydub
audio = AudioSegment.from_wav(temp_audio_path)
audio_array = np.array(audio.get_array_of_samples())
# Convert to float32 and normalize
audio_array = audio_array.astype(np.float32) / np.iinfo(np.int16).max
# If stereo, convert to mono
if len(audio_array.shape) > 1:
audio_array = audio_array.mean(axis=1)
return audio_array, audio.frame_rate, video, temp_audio_path
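
# Note: the waveform is returned at the video's native sample rate (often
# 44.1 kHz), so process_video passes sampling_rate along to the ASR pipeline
# and lets it resample to the 16 kHz expected by the Whisper feature extractor.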
def create_srt(segments, target_lang="en"):
"""Convert transcribed segments to SRT format with optional translation"""
srt_content = ""
for i, segment in enumerate(segments, start=1):
start_time = format_timestamp(segment['start'])
end_time = format_timestamp(segment['end'])
text = segment['text'].strip()
# Translate if target language is different
if segment.get('language') and segment['language'] != target_lang:
text = translate_text(text, segment['language'], target_lang)
srt_content += f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
return srt_content
def create_subtitle_clips(segments, videosize, target_lang="en"):
"""Create subtitle clips for moviepy with translation support"""
subtitle_clips = []
for segment in segments:
start_time = segment['start']
end_time = segment['end']
duration = end_time - start_time
text = segment['text'].strip()
# Translate if target language is different
if segment.get('language') and segment['language'] != target_lang:
text = translate_text(text, segment['language'], target_lang)
text_clip = mp.TextClip(
text,
font='Arial',
fontsize=24,
color='white',
stroke_color='black',
stroke_width=1,
size=videosize,
method='caption'
).set_position(('center', 'bottom'))
text_clip = text_clip.set_start(start_time).set_duration(duration)
subtitle_clips.append(text_clip)
return subtitle_clips
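
# Note: create_subtitle_clips relies on mp.TextClip(method='caption'), which
# renders text through ImageMagick; ImageMagick must therefore be installed in
# the runtime (on a Hugging Face Space, typically by listing it in packages.txt).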
@spaces.GPU
def process_video(video_path, target_lang="en"):
    """Main function: transcribe a video, translate, and embed subtitles."""
    # Load the CrisperWhisper model and processor
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_id = "nyrahealth/CrisperWhisper"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
        use_safetensors=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    # Load audio and video
    audio_array, sampling_rate, video, temp_audio_path = load_audio(video_path)
    # Create the ASR pipeline
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        return_timestamps=True,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device=device,
    )
    # Transcribe audio; pass the sampling rate so the pipeline can resample correctly
    result = pipe({"raw": audio_array, "sampling_rate": sampling_rate}, return_timestamps="word")
    # Convert word-level timestamps into subtitle segments
    segments = []
    current_segment = {"text": "", "start": result["chunks"][0]["timestamp"][0]}
    for chunk in result["chunks"]:
        current_segment["text"] += " " + chunk["text"]
        current_segment["end"] = chunk["timestamp"][1]
        # Start a new segment once the text is long enough or enough time has passed
        if len(current_segment["text"].split()) > 10 or \
                (current_segment["end"] - current_segment["start"]) > 5.0:
            segments.append(current_segment)
            current_segment = {"text": "", "start": chunk["timestamp"][1]}
    # Add the last segment if it still holds text
    if current_segment["text"]:
        segments.append(current_segment)
    # Add the detected language to each segment
    detected_language = "en"  # CrisperWhisper is English-focused
    for segment in segments:
        segment['language'] = detected_language
    # Create SRT content
    srt_content = create_srt(segments, target_lang)
    # Save the SRT file
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    srt_path = f"{video_name}_subtitles_{target_lang}.srt"
    with open(srt_path, "w", encoding="utf-8") as f:
        f.write(srt_content)
    # Create subtitle clips
    subtitle_clips = create_subtitle_clips(segments, video.size, target_lang)
    # Combine the video with its subtitles
    final_video = mp.CompositeVideoClip([video] + subtitle_clips)
    # Save the final video
    output_video_path = f"{video_name}_with_subtitles_{target_lang}.mp4"
    final_video.write_videofile(output_video_path)
    # Clean up
    os.remove(temp_audio_path)
    video.close()
    final_video.close()
    return output_video_path, srt_path
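
# Note: load_audio writes to a fixed "temp_audio.wav" path, so two requests
# processed at the same time would overwrite each other's audio; a per-request
# filename (e.g. via tempfile.NamedTemporaryFile) would avoid that collision.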
def gradio_interface(video_file, target_language):
"""Gradio interface function with language selection"""
try:
video_path = video_file.name
target_lang = LANGUAGE_CODES[target_language]
output_video, srt_file = process_video(video_path, target_lang)
return output_video, srt_file
except Exception as e:
return str(e), None
# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Dropdown(
            choices=list(LANGUAGE_CODES.keys()),
            value="English",
            label="Target Language"
        )
    ],
    outputs=[
        gr.Video(label="Video with Subtitles"),
        gr.File(label="SRT Subtitle File")
    ],
    title="Video Subtitler with CrisperWhisper",
    description="Upload a video to generate subtitles using CrisperWhisper, translate them to your chosen language, and embed them directly in the video."
)
if __name__ == "__main__":
    iface.launch()