# Hugging Face Spaces page banner captured with this file (Space status: Running).
import streamlit as st | |
import tempfile | |
import os | |
import torch | |
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForSeq2SeqLM | |
import librosa | |
import numpy as np | |
import ffmpeg | |
import time | |
import json | |
import psutil | |
def format_time(seconds):
    """Render a duration given in seconds as ``M:SS``.

    Minutes are not wrapped at the hour, so 3605 seconds becomes ``"60:05"``.
    Fractional parts are truncated, not rounded.
    """
    whole_minutes, leftover = divmod(seconds, 60)
    return f"{int(whole_minutes)}:{int(leftover):02d}"
def seconds_to_srt_time(seconds):
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is converted to whole milliseconds once (rounded to nearest)
    and then split with integer arithmetic. Deriving the millisecond field
    from ``seconds - int(seconds)`` truncates binary floating-point error,
    e.g. 1.001 s would come out as ``,000`` instead of ``,001``.

    Args:
        seconds: Offset from the start of the media, int or float.
            Negative values are clamped to zero because SRT has no
            representation for them.

    Returns:
        The timestamp string, e.g. ``"01:01:01,500"``.
    """
    total_millis = max(0, int(round(float(seconds) * 1000)))
    hours, remainder = divmod(total_millis, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def load_model(language='en', summarizer_type='bart'):
    """Load the speech-recognition model and the summarizer onto one device.

    Args:
        language: ``'ur'`` selects the Urdu fine-tuned Whisper checkpoint;
            any other value selects ``openai/whisper-small``.
        summarizer_type: ``'bart'`` selects ``facebook/bart-large-cnn``;
            any other value selects ``pszemraj/led-large-book-summary``.

    Returns:
        A tuple ``(processor, model, sum_tokenizer, sum_model, device)``.
        CUDA is used when available; in that case the speech model is cast
        to half precision (the summarizer is left as loaded).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Speech-to-text checkpoint.
    if language == 'ur':
        asr_checkpoint = "GogetaBlueMUI/whisper-medium-ur-fleurs"
    else:
        asr_checkpoint = "openai/whisper-small"
    processor = AutoProcessor.from_pretrained(asr_checkpoint)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(asr_checkpoint).to(device)
    if device.type == "cuda":
        model = model.half()

    # Summarization checkpoint.
    if summarizer_type == 'bart':
        summarizer_checkpoint = "facebook/bart-large-cnn"
    else:
        summarizer_checkpoint = "pszemraj/led-large-book-summary"
    sum_tokenizer = AutoTokenizer.from_pretrained(summarizer_checkpoint)
    sum_model = AutoModelForSeq2SeqLM.from_pretrained(summarizer_checkpoint).to(device)

    return processor, model, sum_tokenizer, sum_model, device
def split_audio_into_chunks(audio, sr, chunk_duration):
    """Split ``audio`` into consecutive slices of ``chunk_duration`` seconds.

    Args:
        audio: A sliceable sequence of samples (list, NumPy array, ...).
        sr: Sampling rate in samples per second.
        chunk_duration: Length of each chunk in seconds.

    Returns:
        A list of slices of ``audio``. Every slice holds
        ``int(chunk_duration * sr)`` samples except possibly the last one,
        which holds the remainder. Empty input gives an empty list.

    Raises:
        ValueError: If ``chunk_duration * sr`` is less than one sample.
            A zero step would otherwise surface as an opaque ``range()``
            error and a negative one would silently return no chunks.
    """
    chunk_samples = int(chunk_duration * sr)
    if chunk_samples <= 0:
        raise ValueError(
            f"chunk_duration * sr must be at least one sample, got {chunk_duration} s at {sr} Hz"
        )
    return [audio[start:start + chunk_samples] for start in range(0, len(audio), chunk_samples)]
def transcribe_audio(audio, sr, processor, model, device, start_time, language, task="transcribe"):
    """Run the speech model on one audio chunk and return a single segment.

    Args:
        audio: Samples of the chunk.
        sr: Sampling rate of ``audio``.
        processor: Callable feature extractor that also provides ``decode``.
        model: Generation model; its ``dtype`` decides whether the features
            are cast to half precision.
        device: Device the input features are moved to.
        start_time: Offset of this chunk in the full recording, in seconds.
        language: Language code; ``"ur"`` is passed to the model as
            ``"urdu"``, anything else is passed through unchanged.
        task: Generation task forwarded to ``model.generate``.

    Returns:
        A one-element list ``[(text, start_time, end_time)]`` where
        ``end_time`` is ``start_time`` plus the chunk length in seconds.
        If generation or decoding fails, the error is shown in the UI and
        ``text`` is ``"Error: <message>"``.
    """
    features = processor(audio, sampling_rate=sr, return_tensors="pt").input_features.to(device)
    if model.dtype == torch.float16:
        # Keep the input dtype in line with a half-precision model.
        features = features.half()

    generation_options = dict(
        task=task,
        language="urdu" if language == "ur" else language,
        max_new_tokens=128,
        return_timestamps=True,
    )

    try:
        with torch.no_grad():
            generated = model.generate(features, **generation_options)
        decoded = processor.decode(generated[0], skip_special_tokens=True)
        return [(decoded, start_time, start_time + len(audio) / sr)]
    except Exception as e:
        st.error(f"Transcription error: {str(e)}")
        return [(f"Error: {str(e)}", start_time, start_time + len(audio) / sr)]
def process_chunks(chunks, sr, processor, model, device, language, chunk_duration, task="transcribe", transcript_file="temp_transcript.json"):
    """Transcribe every chunk in order, reporting progress in the Streamlit UI.

    After each chunk the transcript collected so far is rewritten to
    ``transcript_file`` as JSON, so partial results exist on disk while the
    run is in progress. Any pre-existing file at that path is removed first.

    Args:
        chunks: Audio chunks in playback order.
        sr: Sampling rate shared by all chunks.
        processor, model, device, language, task: Forwarded unchanged to
            ``transcribe_audio``.
        chunk_duration: Nominal chunk length in seconds; chunk ``i`` is
            given the start offset ``i * chunk_duration``.
        transcript_file: Path of the JSON checkpoint file.

    Returns:
        The list of ``(text, start, end)`` segments gathered so far. If a
        chunk raises, processing stops at that chunk and the segments from
        the earlier chunks are returned.
    """
    transcript = []
    total_chunks = len(chunks)
    progress_bar = st.progress(0)
    status_text = st.empty()
    if os.path.exists(transcript_file):
        os.remove(transcript_file)

    failed_chunk = None
    for i, chunk in enumerate(chunks):
        status_text.text(f"Processing chunk {i+1}/{total_chunks}...")
        try:
            memory = psutil.virtual_memory()
            st.write(f"Memory usage: {memory.percent}% (Chunk {i+1}/{total_chunks})")
            # Derive the offset from the index so rounding error is not
            # accumulated across many additions.
            chunk_start = i * chunk_duration
            chunk_transcript = transcribe_audio(chunk, sr, processor, model, device, chunk_start, language, task)
            transcript.extend(chunk_transcript)
            with open(transcript_file, "w", encoding="utf-8") as f:
                json.dump(transcript, f, ensure_ascii=False)
            progress_bar.progress((i + 1) / total_chunks)
        except Exception as e:
            st.error(f"Error processing chunk {i+1}: {str(e)}")
            failed_chunk = i + 1
            break

    # Do not claim success when the loop was aborted by an error.
    if failed_chunk is None:
        status_text.text("Processing complete!")
    else:
        status_text.text(f"Processing stopped at chunk {failed_chunk}/{total_chunks} because of an error.")
    progress_bar.empty()
    return transcript
def _generate_summary(tokenizer, model, input_ids, max_length, min_length):
    """Beam-search one summary for ``input_ids`` and decode it to text."""
    # No ``temperature`` here: sampling is not enabled, so it would have no
    # effect on beam search and only makes the library emit a warning.
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids,
            num_beams=4,
            max_length=max_length,
            min_length=min_length,
            early_stopping=True,
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def summarize_text(text, tokenizer, model, device, summarizer_type='bart'):
    """Summarize ``text``, chunking it when it exceeds the model's input size.

    Limits depend on ``summarizer_type``: for ``'bart'`` the input limit is
    1024 tokens, the summary limit 150 and the chunk size 512; for any other
    value they are 16384, 512 and 8192.

    Inputs within the limit are summarized in one pass. Longer inputs are
    split into consecutive token chunks, each chunk is summarized with half
    the summary budget, and the joined partial summaries are summarized
    once more (truncated to the input limit) to produce the final text.

    Args:
        text: Text to summarize.
        tokenizer: Tokenizer matching ``model``.
        model: Sequence-to-sequence model exposing ``generate``.
        device: Device the token ids are moved to before generation.
        summarizer_type: Selects the limits described above.

    Returns:
        The summary; a fixed notice when the input has fewer than 50
        tokens; or ``"Error: <message>"`` if generation fails (the error
        is also shown in the UI).
    """
    if summarizer_type == 'bart':
        max_input_length, max_summary_length, chunk_size = 1024, 150, 512
    else:
        max_input_length, max_summary_length, chunk_size = 16384, 512, 8192

    input_ids = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"]
    num_tokens = input_ids.shape[1]
    st.write(f"Number of tokens in input: {num_tokens}")
    if num_tokens < 50:
        return "Transcript too short to summarize effectively."

    try:
        if num_tokens <= max_input_length:
            # The text already fits, so the ids computed above are reused
            # rather than tokenizing the same text a second time.
            return _generate_summary(tokenizer, model, input_ids.to(device), max_summary_length, 50)

        st.write(f"Transcript exceeds {max_input_length} tokens. Processing in chunks...")
        partial_summaries = [
            _generate_summary(
                tokenizer,
                model,
                input_ids[:, start:start + chunk_size].to(device),
                max_summary_length // 2,
                25,
            )
            for start in range(0, num_tokens, chunk_size)
        ]
        combined_summary = " ".join(partial_summaries)
        combined_ids = tokenizer(
            combined_summary, return_tensors="pt", truncation=True, max_length=max_input_length
        )["input_ids"].to(device)
        return _generate_summary(tokenizer, model, combined_ids, max_summary_length, 50)
    except Exception as e:
        st.error(f"Summarization error: {str(e)}")
        return f"Error: {str(e)}"
def save_uploaded_file(uploaded_file):
    """Copy an uploaded file object to a persistent temporary ``.mp4`` file.

    Args:
        uploaded_file: Object with a ``read()`` method returning bytes.

    Returns:
        Path of the temporary file, or ``None`` if saving failed (the error
        is shown in the UI). The file is created with ``delete=False``, so
        the caller is responsible for removing it.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(uploaded_file.read())
        return tmp_path
    except Exception as e:
        # The file outlives the context manager (delete=False), so a failed
        # write would otherwise leave a partial file behind.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)
        st.error(f"Error saving uploaded file: {str(e)}")
        return None
def merge_intervals(intervals):
    """Merge overlapping or touching ``(start, end)`` intervals.

    Args:
        intervals: Iterable-backed list of ``(start, end)`` pairs in any
            order. It is not modified; a sorted copy is used instead of
            sorting the caller's list in place.

    Returns:
        A new list ordered by start. Intervals whose start is less than or
        equal to the previous interval's end are fused into one tuple;
        intervals that needed no merging are returned as given.
    """
    if not intervals:
        return []
    ordered = sorted(intervals, key=lambda interval: interval[0])
    merged = [ordered[0]]
    for current in ordered[1:]:
        previous = merged[-1]
        if previous[1] >= current[0]:
            merged[-1] = (previous[0], max(previous[1], current[1]))
        else:
            merged.append(current)
    return merged
def create_edited_video(video_path, transcript, keep_indices):
    """Cut the kept subtitle intervals out of a video and join them.

    Each merged interval is extracted with stream copy into its own segment
    file, the segments are listed in a concat list file, and ffmpeg joins
    them into ``edited_video.mp4`` in the current working directory.

    Segment files and the list file live in a private temporary directory
    that is removed on success and on failure alike. This avoids fixed
    names in the working directory, which could clash between concurrent
    runs and would be left behind when ffmpeg raised.

    Args:
        video_path: Path of the source video.
        transcript: Sequence of ``(text, start, end)`` entries.
        keep_indices: Indices into ``transcript`` whose time ranges are kept.

    Returns:
        ``"edited_video.mp4"`` on success, or ``None`` if anything failed
        (the error is shown in the UI).
    """
    edited_video_path = "edited_video.mp4"
    try:
        intervals_to_keep = [(transcript[i][1], transcript[i][2]) for i in keep_indices]
        merged_intervals = merge_intervals(intervals_to_keep)
        if not merged_intervals:
            raise ValueError("no subtitle intervals selected to keep")

        with tempfile.TemporaryDirectory() as work_dir:
            list_path = os.path.join(work_dir, "list.txt")
            with open(list_path, "w", encoding="utf-8") as list_file:
                for j, (start, end) in enumerate(merged_intervals):
                    segment_name = f"temp_{j}.mp4"
                    segment_path = os.path.join(work_dir, segment_name)
                    ffmpeg.input(video_path, ss=start, to=end).output(segment_path, c='copy').run(overwrite_output=True, quiet=True)
                    # Bare file names only: the segments sit next to the
                    # list file, so no path escaping is needed here.
                    list_file.write(f"file '{segment_name}'\n")
            ffmpeg.input(list_path, format='concat', safe=0).output(edited_video_path, c='copy').run(overwrite_output=True, quiet=True)
        return edited_video_path
    except Exception as e:
        st.error(f"Error creating edited video: {str(e)}")
        return None
def generate_srt(transcript, include_timeframe=True):
    """Build subtitle text from ``(text, start, end)`` transcript entries.

    With ``include_timeframe`` each entry becomes a numbered block (counting
    from 1) with a ``start --> end`` line produced by
    ``seconds_to_srt_time``. Without it only the text of each entry is
    emitted. Every entry is followed by one blank line.
    """
    blocks = []
    for idx, (text, start, end) in enumerate(transcript, 1):
        if not include_timeframe:
            blocks.append(f"{text}\n\n")
            continue
        start_time = seconds_to_srt_time(start)
        end_time = seconds_to_srt_time(end)
        blocks.append(f"{idx}\n{start_time} --> {end_time}\n{text}\n\n")
    return "".join(blocks)
def main():
    """Render the Streamlit app for the current ``app_state``.

    The app moves through four states kept in ``st.session_state``:

    * ``upload``: receive an MP4 and store it in a temporary file.
    * ``processing``: pick language and chunk size, extract audio with
      ffmpeg, transcribe it and optionally translate Urdu to English.
    * ``results``: show the video, searchable transcript buttons that jump
      the player to a segment, summary generation, subtitle download and
      the edited video if one exists.
    * ``editing``: choose subtitle indices to delete and rebuild the video
      from the remaining intervals.

    A Reset button, shown once a video is stored, deletes the stored files
    and clears the session.
    """
    st.title("Video Transcription and Summarization App")
    st.markdown("Upload a video to transcribe its audio, generate a summary, and edit subtitles.")
    # Inject CSS to set video height
    st.markdown("""
<style>
video {
    width: 350px !important;
    height: 500px !important;
    object-fit: contain;
}
</style>
""", unsafe_allow_html=True)

    # Seed every session key once so later code can index without checks.
    session_defaults = {
        'app_state': 'upload',
        'video_path': None,
        'primary_transcript': None,
        'english_transcript': None,
        'english_summary': None,
        'language': None,
        'language_code': None,
        'translate_to_english': False,
        'summarizer_type': None,
        'summary_generated': False,
        'current_time': 0,
        'edited_video_path': None,
        'search_query': "",
        'show_timeframe': True,
    }
    for state_key, default_value in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value

    st.write(f"Current app state: {st.session_state['app_state']}")

    if st.session_state['app_state'] == 'upload':
        with st.form(key="upload_form"):
            uploaded_file = st.file_uploader("Upload a video", type=["mp4"])
            if st.form_submit_button("Upload") and uploaded_file:
                video_path = save_uploaded_file(uploaded_file)
                if video_path:
                    st.session_state['video_path'] = video_path
                    st.session_state['app_state'] = 'processing'
                    st.write(f"Uploaded file: {uploaded_file.name}")
                    st.rerun()

    if st.session_state['app_state'] == 'processing':
        with st.form(key="processing_form"):
            language = st.selectbox("Select language", ["English", "Urdu"], key="language_select")
            language_code = "en" if language == "English" else "ur"
            st.session_state['language'] = language
            st.session_state['language_code'] = language_code
            chunk_duration = st.number_input("Duration per chunk (seconds):", min_value=1.0, step=0.1, value=10.0)
            # Widgets inside a form do not rerun the script until submit, so
            # a checkbox shown only after "Urdu" is selected could never be
            # ticked beforehand. Always show it and honour it for Urdu only.
            translate_requested = st.checkbox("Generate English translation (Urdu only)", key="translate_checkbox")
            st.session_state['translate_to_english'] = bool(translate_requested) and language_code == "ur"

            if st.form_submit_button("Process"):
                with st.spinner("Processing video..."):
                    # Bound before the try so the cleanup below can always
                    # refer to it.
                    audio_path = "processed_audio.wav"
                    try:
                        st.write("Extracting audio...")
                        ffmpeg.input(st.session_state['video_path']).output(audio_path, ac=1, ar=16000).run(overwrite_output=True, quiet=True)
                        audio, sr = librosa.load(audio_path, sr=16000)
                        audio = np.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0)
                        audio_duration = len(audio) / sr
                        st.write(f"Audio duration: {audio_duration:.2f} seconds")
                        if audio_duration < 5:
                            st.error("Audio too short (< 5s). Upload a longer video.")
                            return

                        summarizer_type = 'bart' if audio_duration <= 300 else 'led'
                        st.write(f"Using summarizer: {summarizer_type}")
                        st.session_state['summarizer_type'] = summarizer_type

                        st.write("Loading models...")
                        processor, model, _, _, device = load_model(language_code, summarizer_type)
                        st.write("Splitting audio into chunks...")
                        chunks = split_audio_into_chunks(audio, sr, chunk_duration)
                        st.write(f"Number of chunks: {len(chunks)}")

                        st.write("Transcribing audio...")
                        primary_transcript = process_chunks(chunks, sr, processor, model, device, language_code, chunk_duration, task="transcribe", transcript_file="temp_primary_transcript.json")

                        english_transcript = None
                        if st.session_state['translate_to_english'] and language_code == "ur":
                            st.write("Translating to English...")
                            processor, model, _, _, device = load_model('en', summarizer_type)
                            english_transcript = process_chunks(chunks, sr, processor, model, device, 'ur', chunk_duration, task="translate", transcript_file="temp_english_transcript.json")

                        st.session_state.update({
                            'primary_transcript': primary_transcript,
                            'english_transcript': english_transcript,
                            'summary_generated': False,
                            'app_state': 'results'
                        })
                        st.write("Processing completed successfully!")
                        st.rerun()
                    except Exception as e:
                        st.error(f"Processing failed: {str(e)}")
                    finally:
                        for leftover in (audio_path, "temp_primary_transcript.json", "temp_english_transcript.json"):
                            if os.path.exists(leftover):
                                os.remove(leftover)

    if st.session_state['app_state'] == 'results':
        st.video(st.session_state['video_path'], start_time=st.session_state['current_time'])
        st.session_state['show_timeframe'] = st.checkbox("Show timeframe in transcript", value=st.session_state['show_timeframe'])

        st.markdown("### Search Subtitles")
        raw_query = st.text_input("Search subtitles...", value=st.session_state['search_query'] or "", key="search_input")
        # Compare lower-case against lower-case; matching the raw query
        # against lower-cased text never matched queries with capitals.
        search_query = raw_query.lower()
        st.session_state['search_query'] = search_query

        transcript_sections = [
            (f"### {st.session_state['language']} Transcript", "primary", st.session_state['primary_transcript']),
        ]
        if st.session_state['english_transcript']:
            transcript_sections.append(("### English Translation", "english", st.session_state['english_transcript']))

        for heading, key_prefix, entries in transcript_sections:
            st.markdown(heading)
            for text, start, end in entries:
                if search_query and search_query not in text.lower():
                    continue
                label = f"[{format_time(start)} - {format_time(end)}] {text}" if st.session_state['show_timeframe'] else text
                # Clicking a segment restarts the player at its start time.
                if st.button(label, key=f"{key_prefix}_{start}"):
                    st.session_state['current_time'] = start
                    st.rerun()

        can_summarize = st.session_state['language_code'] == 'en' or st.session_state['translate_to_english']
        if can_summarize and not st.session_state['summary_generated']:
            if st.button("Generate Summary"):
                with st.spinner("Generating summary..."):
                    try:
                        _, _, sum_tokenizer, sum_model, device = load_model(st.session_state['language_code'], st.session_state['summarizer_type'])
                        source_transcript = st.session_state['english_transcript'] or st.session_state['primary_transcript']
                        full_text = " ".join(entry[0] for entry in source_transcript)
                        english_summary = summarize_text(full_text, sum_tokenizer, sum_model, device, st.session_state['summarizer_type'])
                        st.session_state['english_summary'] = english_summary
                        st.session_state['summary_generated'] = True
                    except Exception as e:
                        st.error(f"Summary generation failed: {str(e)}")

        if st.session_state['english_summary'] and st.session_state['summary_generated']:
            st.markdown("### Summary")
            st.write(st.session_state['english_summary'])

        st.markdown("### Download Subtitles")
        include_timeframe = st.checkbox("Include timeframe in subtitles", value=True)
        transcript_to_download = st.session_state['primary_transcript'] or st.session_state['english_transcript']
        if transcript_to_download:
            srt_content = generate_srt(transcript_to_download, include_timeframe)
            st.download_button(label="Download Subtitles (SRT)", data=srt_content, file_name="subtitles.srt", mime="text/plain")

        st.markdown("### Edit Subtitles")
        transcript_to_edit = st.session_state['primary_transcript'] or st.session_state['english_transcript']
        if transcript_to_edit and st.button("Delete Subtitles"):
            st.session_state['app_state'] = 'editing'
            st.rerun()

    if st.session_state['app_state'] == 'editing':
        st.markdown("### Delete Subtitles")
        transcript_to_edit = st.session_state['primary_transcript'] or st.session_state['english_transcript']
        for i, (text, start, end) in enumerate(transcript_to_edit):
            st.write(f"{i}: [{format_time(start)} - {format_time(end)}] {text}")
        indices_input = st.text_input("Enter the indices of subtitles to delete (comma-separated, e.g., 0,1,3):")

        if st.button("Confirm Deletion"):
            try:
                requested = [int(idx.strip()) for idx in indices_input.split(',') if idx.strip()]
                # Out-of-range indices are ignored; a set keeps the
                # membership test below cheap.
                delete_indices = {idx for idx in requested if 0 <= idx < len(transcript_to_edit)}
                keep_indices = [i for i in range(len(transcript_to_edit)) if i not in delete_indices]
                if not keep_indices:
                    st.error("All subtitles are deleted. No video to generate.")
                else:
                    edited_video_path = create_edited_video(st.session_state['video_path'], transcript_to_edit, keep_indices)
                    if edited_video_path:
                        st.session_state['edited_video_path'] = edited_video_path
                        st.session_state['app_state'] = 'results'
                        st.rerun()
            except ValueError:
                st.error("Invalid input. Please enter comma-separated integers.")
            except Exception as e:
                st.error(f"Error during video editing: {str(e)}")

        if st.button("Cancel Deletion"):
            st.session_state['app_state'] = 'results'
            st.rerun()

    if st.session_state['app_state'] == 'results' and st.session_state['edited_video_path']:
        st.markdown("### Edited Video")
        st.video(st.session_state['edited_video_path'])
        with open(st.session_state['edited_video_path'], "rb") as file:
            st.download_button(label="Download Edited Video", data=file, file_name="edited_video.mp4", mime="video/mp4")

    if st.session_state.get('video_path') and st.button("Reset"):
        # Remove the stored files before the session forgets their paths.
        for stored_path in (st.session_state['video_path'], st.session_state['edited_video_path']):
            if stored_path and os.path.exists(stored_path):
                os.remove(stored_path)
        st.session_state.clear()
        st.rerun()
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()