GogetaBlueMUI committed
Commit 00cde4f · verified · 1 Parent(s): 77a8514

Update app.py

Files changed (1)
  1. app.py +398 -415
app.py CHANGED
@@ -1,416 +1,399 @@
- import streamlit as st
- import tempfile
- import os
- import torch
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForSeq2SeqLM
- import librosa
- import numpy as np
- import ffmpeg
- import time
- import json
- import psutil
-
- def format_time(seconds):
-     minutes = int(seconds // 60)
-     secs = int(seconds % 60)
-     return f"{minutes}:{secs:02d}"
-
- def seconds_to_srt_time(seconds):
-     hours = int(seconds // 3600)
-     minutes = int((seconds % 3600) // 60)
-     secs = int(seconds % 60)
-     millis = int((seconds - int(seconds)) * 1000)
-     return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
-
- @st.cache_resource
- def load_model(language='en', summarizer_type='bart'):
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     if language == 'ur':
-         processor = AutoProcessor.from_pretrained("GogetaBlueMUI/whisper-medium-ur-fleurs")
-         model = AutoModelForSpeechSeq2Seq.from_pretrained("GogetaBlueMUI/whisper-medium-ur-fleurs").to(device)
-     else:
-         processor = AutoProcessor.from_pretrained("openai/whisper-small")
-         model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small").to(device)
-     if device.type == "cuda":
-         model = model.half()
-     if summarizer_type == 'bart':
-         sum_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
-         sum_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").to(device)
-     else:
-         sum_tokenizer = AutoTokenizer.from_pretrained("pszemraj/led-large-book-summary")
-         sum_model = AutoModelForSeq2SeqLM.from_pretrained("pszemraj/led-large-book-summary").to(device)
-     return processor, model, sum_tokenizer, sum_model, device
-
- def split_audio_into_chunks(audio, sr, chunk_duration):
-     chunk_samples = int(chunk_duration * sr)
-     chunks = [audio[start:start + chunk_samples] for start in range(0, len(audio), chunk_samples)]
-     return chunks
-
- def transcribe_audio(audio, sr, processor, model, device, start_time, language, task="transcribe"):
-     inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
-     input_features = inputs.input_features.to(device)
-     if model.dtype == torch.float16:
-         input_features = input_features.half()
-     generate_kwargs = {
-         "task": task,
-         "language": "urdu" if language == "ur" else language,
-         "max_new_tokens": 128,
-         "return_timestamps": True
-     }
-     try:
-         with torch.no_grad():
-             outputs = model.generate(input_features, **generate_kwargs)
-         text = processor.decode(outputs[0], skip_special_tokens=True)
-         return [(text, start_time, start_time + len(audio) / sr)]
-     except Exception as e:
-         st.error(f"Transcription error: {str(e)}")
-         return [(f"Error: {str(e)}", start_time, start_time + len(audio) / sr)]
-
- def process_chunks(chunks, sr, processor, model, device, language, chunk_duration, task="transcribe", transcript_file="temp_transcript.json"):
-     transcript = []
-     chunk_start = 0
-     total_chunks = len(chunks)
-     progress_bar = st.progress(0)
-     status_text = st.empty()
-     if os.path.exists(transcript_file):
-         os.remove(transcript_file)
-     for i, chunk in enumerate(chunks):
-         status_text.text(f"Processing chunk {i+1}/{total_chunks}...")
-         try:
-             memory = psutil.virtual_memory()
-             st.write(f"Memory usage: {memory.percent}% (Chunk {i+1}/{total_chunks})")
-             chunk_transcript = transcribe_audio(chunk, sr, processor, model, device, chunk_start, language, task)
-             transcript.extend(chunk_transcript)
-             with open(transcript_file, "w", encoding="utf-8") as f:
-                 json.dump(transcript, f, ensure_ascii=False)
-             chunk_start += chunk_duration
-             progress_bar.progress((i + 1) / total_chunks)
-         except Exception as e:
-             st.error(f"Error processing chunk {i+1}: {str(e)}")
-             break
-     status_text.text("Processing complete!")
-     progress_bar.empty()
-     return transcript
-
- def summarize_text(text, tokenizer, model, device, summarizer_type='bart'):
-     if summarizer_type == 'bart':
-         max_input_length = 1024
-         max_summary_length = 150
-         chunk_size = 512
-     else:
-         max_input_length = 16384
-         max_summary_length = 512
-         chunk_size = 8192
-     inputs = tokenizer(text, return_tensors="pt", truncation=False)
-     input_ids = inputs["input_ids"].to(device)
-     num_tokens = input_ids.shape[1]
-     st.write(f"Number of tokens in input: {num_tokens}")
-     if num_tokens < 50:
-         return "Transcript too short to summarize effectively."
-     try:
-         summaries = []
-         if num_tokens <= max_input_length:
-             truncated_inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_input_length).to(device)
-             with torch.no_grad():
-                 summary_ids = model.generate(truncated_inputs["input_ids"], num_beams=4, max_length=max_summary_length, min_length=50, early_stopping=True, temperature=0.7)
-             summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
-         else:
-             st.write(f"Transcript exceeds {max_input_length} tokens. Processing in chunks...")
-             tokens = input_ids[0].tolist()
-             for i in range(0, num_tokens, chunk_size):
-                 chunk_tokens = tokens[i:i + chunk_size]
-                 chunk_input_ids = torch.tensor([chunk_tokens]).to(device)
-                 with torch.no_grad():
-                     summary_ids = model.generate(chunk_input_ids, num_beams=4, max_length=max_summary_length // 2, min_length=25, early_stopping=True, temperature=0.7)
-                 summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
-             combined_summary = " ".join(summaries)
-             combined_inputs = tokenizer(combined_summary, return_tensors="pt", truncation=True, max_length=max_input_length).to(device)
-             with torch.no_grad():
-                 final_summary_ids = model.generate(combined_inputs["input_ids"], num_beams=4, max_length=max_summary_length, min_length=50, early_stopping=True, temperature=0.7)
-             summaries = [tokenizer.decode(final_summary_ids[0], skip_special_tokens=True)]
-         return " ".join(summaries)
-     except Exception as e:
-         st.error(f"Summarization error: {str(e)}")
-         return f"Error: {str(e)}"
-
- def save_uploaded_file(uploaded_file):
-     try:
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
-             tmp_file.write(uploaded_file.read())
-             return tmp_file.name
-     except Exception as e:
-         st.error(f"Error saving uploaded file: {str(e)}")
-         return None
-
- def merge_intervals(intervals):
-     if not intervals:
-         return []
-     intervals.sort(key=lambda x: x[0])
-     merged = [intervals[0]]
-     for current in intervals[1:]:
-         previous = merged[-1]
-         if previous[1] >= current[0]:
-             merged[-1] = (previous[0], max(previous[1], current[1]))
-         else:
-             merged.append(current)
-     return merged
-
- def create_edited_video(video_path, transcript, keep_indices):
-     try:
-         intervals_to_keep = [(transcript[i][1], transcript[i][2]) for i in keep_indices]
-         merged_intervals = merge_intervals(intervals_to_keep)
-         temp_files = []
-         resolution = st.session_state.get('resolution', '1280:720')  # Default to 720p
-         crf = st.session_state.get('crf', 23)  # Default to medium quality
-         scale_filter = f"scale={resolution}" if resolution else None
-         for j, (start, end) in enumerate(merged_intervals):
-             temp_file = f"temp_{j}.mp4"
-             output_args = {
-                 'vcodec': 'libx264',
-                 'preset': 'medium',
-                 'crf': crf,
-                 'r': 30,  # 30 FPS
-                 'acodec': 'aac',
-                 'ab': '128k',
-                 'map_metadata': -1
-             }
-             if scale_filter:
-                 output_args['vf'] = scale_filter
-             ffmpeg.input(video_path, ss=start, to=end).output(
-                 temp_file, **output_args
-             ).run(overwrite_output=True, quiet=True)
-             temp_files.append(temp_file)
-         with open("list.txt", "w") as f:
-             for temp_file in temp_files:
-                 f.write(f"file '{temp_file}'\n")
-         edited_video_path = "edited_video.mp4"
-         ffmpeg.input('list.txt', format='concat', safe=0).output(
-             edited_video_path, **output_args
-         ).run(overwrite_output=True, quiet=True)
-         for temp_file in temp_files:
-             if os.path.exists(temp_file):
-                 os.remove(temp_file)
-         if os.path.exists("list.txt"):
-             os.remove("list.txt")
-         return edited_video_path
-     except Exception as e:
-         st.error(f"Error creating edited video: {str(e)}")
-         return None
-
- def generate_srt(transcript, include_timeframe=True):
-     srt_content = ""
-     for idx, (text, start, end) in enumerate(transcript, 1):
-         if include_timeframe:
-             start_time = seconds_to_srt_time(start)
-             end_time = seconds_to_srt_time(end)
-             srt_content += f"{idx}\n{start_time} --> {end_time}\n{text}\n\n"
-         else:
-             srt_content += f"{text}\n\n"
-     return srt_content
-
- def main():
-     st.title("Video Transcription and Summarization App")
-     st.markdown("Upload a video to transcribe its audio, generate a summary, and edit subtitles.")
-
-     if 'app_state' not in st.session_state:
-         st.session_state['app_state'] = 'upload'
-     if 'video_path' not in st.session_state:
-         st.session_state['video_path'] = None
-     if 'primary_transcript' not in st.session_state:
-         st.session_state['primary_transcript'] = None
-     if 'english_transcript' not in st.session_state:
-         st.session_state['english_transcript'] = None
-     if 'english_summary' not in st.session_state:
-         st.session_state['english_summary'] = None
-     if 'language' not in st.session_state:
-         st.session_state['language'] = None
-     if 'language_code' not in st.session_state:
-         st.session_state['language_code'] = None
-     if 'translate_to_english' not in st.session_state:
-         st.session_state['translate_to_english'] = False
-     if 'summarizer_type' not in st.session_state:
-         st.session_state['summarizer_type'] = None
-     if 'summary_generated' not in st.session_state:
-         st.session_state['summary_generated'] = False
-     if 'current_time' not in st.session_state:
-         st.session_state['current_time'] = 0
-     if 'edited_video_path' not in st.session_state:
-         st.session_state['edited_video_path'] = None
-     if 'search_query' not in st.session_state:
-         st.session_state['search_query'] = ""
-     if 'show_timeframe' not in st.session_state:
-         st.session_state['show_timeframe'] = True
-     if 'resolution' not in st.session_state:
-         st.session_state['resolution'] = '1280:720'  # Default to 720p
-     if 'crf' not in st.session_state:
-         st.session_state['crf'] = 23  # Default to medium quality
-
-     st.write(f"Current app state: {st.session_state['app_state']}")
-
-     if st.session_state['app_state'] == 'upload':
-         with st.form(key="upload_form"):
-             uploaded_file = st.file_uploader("Upload a video", type=["mp4"])
-             if st.form_submit_button("Upload") and uploaded_file:
-                 video_path = save_uploaded_file(uploaded_file)
-                 if video_path:
-                     st.session_state['video_path'] = video_path
-                     st.session_state['app_state'] = 'processing'
-                     st.write(f"Uploaded file: {uploaded_file.name}")
-                     st.rerun()
-
-     if st.session_state['app_state'] == 'processing':
-         with st.form(key="processing_form"):
-             language = st.selectbox("Select language", ["English", "Urdu"], key="language_select")
-             language_code = "en" if language == "English" else "ur"
-             st.session_state['language'] = language
-             st.session_state['language_code'] = language_code
-             chunk_duration = st.number_input("Duration per chunk (seconds):", min_value=1.0, step=0.1, value=10.0)
-             resolution = st.selectbox("Output resolution", ["Original", "1080p", "720p"], key="resolution_select")
-             quality = st.selectbox("Video quality", ["High", "Medium", "Low"], key="quality_select")
-             resolution_map = {"Original": None, "1080p": "1920:1080", "720p": "1280:720"}
-             crf_map = {"High": 18, "Medium": 23, "Low": 28}
-             st.session_state['resolution'] = resolution_map[resolution]
-             st.session_state['crf'] = crf_map[quality]
-             if language_code == "ur":
-                 translate_to_english = st.checkbox("Generate English translation", key="translate_checkbox")
-                 st.session_state['translate_to_english'] = translate_to_english
-             else:
-                 st.session_state['translate_to_english'] = False
-             if st.form_submit_button("Process"):
-                 with st.spinner("Processing video..."):
-                     start_time = time.time()
-                     try:
-                         st.write("Extracting audio...")
-                         audio_path = "processed_audio.wav"
-                         ffmpeg.input(st.session_state['video_path']).output(audio_path, ac=1, ar=16000).run(overwrite_output=True, quiet=True)
-                         audio, sr = librosa.load(audio_path, sr=16000)
-                         audio = np.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0)
-                         audio_duration = len(audio) / sr
-                         st.write(f"Audio duration: {audio_duration:.2f} seconds")
-                         if audio_duration < 5:
-                             st.error("Audio too short (< 5s). Upload a longer video.")
-                             return
-                         summarizer_type = 'bart' if audio_duration <= 300 else 'led'
-                         st.write(f"Using summarizer: {summarizer_type}")
-                         st.session_state['summarizer_type'] = summarizer_type
-                         st.write("Loading models...")
-                         processor, model, sum_tokenizer, sum_model, device = load_model(language_code, summarizer_type)
-                         st.write("Splitting audio into chunks...")
-                         chunks = split_audio_into_chunks(audio, sr, chunk_duration)
-                         st.write(f"Number of chunks: {len(chunks)}")
-                         st.write("Transcribing audio...")
-                         primary_transcript = process_chunks(chunks, sr, processor, model, device, language_code, chunk_duration, task="transcribe", transcript_file="temp_primary_transcript.json")
-                         english_transcript = None
-                         if st.session_state['translate_to_english'] and language_code == "ur":
-                             st.write("Translating to English...")
-                             processor, model, _, _, device = load_model('en', summarizer_type)
-                             english_transcript = process_chunks(chunks, sr, processor, model, device, 'ur', chunk_duration, task="translate", transcript_file="temp_english_transcript.json")
-                         st.session_state.update({
-                             'primary_transcript': primary_transcript,
-                             'english_transcript': english_transcript,
-                             'summary_generated': False,
-                             'app_state': 'results'
-                         })
-                         st.write("Processing completed successfully!")
-                         st.rerun()
-                     except Exception as e:
-                         st.error(f"Processing failed: {str(e)}")
-                     finally:
-                         if os.path.exists(audio_path):
-                             os.remove(audio_path)
-                         for temp_file in ["temp_primary_transcript.json", "temp_english_transcript.json"]:
-                             if os.path.exists(temp_file):
-                                 os.remove(temp_file)
-
-     if st.session_state['app_state'] == 'results':
-         st.video(st.session_state['video_path'], start_time=st.session_state['current_time'])
-         st.session_state['show_timeframe'] = st.checkbox("Show timeframe in transcript", value=st.session_state['show_timeframe'])
-         st.markdown("### Search Subtitles")
-         search_query = st.text_input("Search subtitles...", value=st.session_state['search_query'] or "", key="search_input")
-         st.session_state['search_query'] = search_query.lower()
-         st.markdown(f"### {st.session_state['language']} Transcript")
-         for text, start, end in st.session_state['primary_transcript']:
-             display_text = text.lower()
-             if not search_query or search_query in display_text:
-                 label = f"[{format_time(start)} - {format_time(end)}] {text}" if st.session_state['show_timeframe'] else text
-                 if st.button(label, key=f"primary_{start}"):
-                     st.session_state['current_time'] = start
-                     st.rerun()
-         if st.session_state['english_transcript']:
-             st.markdown("### English Translation")
-             for text, start, end in st.session_state['english_transcript']:
-                 display_text = text.lower()
-                 if not search_query or search_query in display_text:
-                     label = f"[{format_time(start)} - {format_time(end)}] {text}" if st.session_state['show_timeframe'] else text
-                     if st.button(label, key=f"english_{start}"):
-                         st.session_state['current_time'] = start
-                         st.rerun()
-         if (st.session_state['language_code'] == 'en' or st.session_state['translate_to_english']) and not st.session_state['summary_generated']:
-             if st.button("Generate Summary"):
-                 with st.spinner("Generating summary..."):
-                     try:
-                         _, _, sum_tokenizer, sum_model, device = load_model(st.session_state['language_code'], st.session_state['summarizer_type'])
-                         full_text = " ".join([text for text, _, _ in (st.session_state['english_transcript'] or st.session_state['primary_transcript'])])
-                         english_summary = summarize_text(full_text, sum_tokenizer, sum_model, device, st.session_state['summarizer_type'])
-                         st.session_state['english_summary'] = english_summary
-                         st.session_state['summary_generated'] = True
-                     except Exception as e:
-                         st.error(f"Summary generation failed: {str(e)}")
-         if st.session_state['english_summary'] and st.session_state['summary_generated']:
-             st.markdown("### Summary")
-             st.write(st.session_state['english_summary'])
-         st.markdown("### Download Subtitles")
-         include_timeframe = st.checkbox("Include timeframe in subtitles", value=True)
-         transcript_to_download = st.session_state['primary_transcript'] or st.session_state['english_transcript']
-         if transcript_to_download:
-             srt_content = generate_srt(transcript_to_download, include_timeframe)
-             st.download_button(label="Download Subtitles (SRT)", data=srt_content, file_name="subtitles.srt", mime="text/plain")
-         st.markdown("### Edit Subtitles")
-         transcript_to_edit = st.session_state['primary_transcript'] or st.session_state['english_transcript']
-         if transcript_to_edit and st.button("Delete Subtitles"):
-             st.session_state['app_state'] = 'editing'
-             st.rerun()
-
-     if st.session_state['app_state'] == 'editing':
-         st.markdown("### Delete Subtitles")
-         transcript_to_edit = st.session_state['primary_transcript'] or st.session_state['english_transcript']
-         for i, (text, start, end) in enumerate(transcript_to_edit):
-             st.write(f"{i}: [{format_time(start)} - {format_time(end)}] {text}")
-         indices_input = st.text_input("Enter the indices of subtitles to delete (comma-separated, e.g., 0,1,3):")
-         if st.button("Confirm Deletion"):
-             try:
-                 delete_indices = [int(idx.strip()) for idx in indices_input.split(',') if idx.strip()]
-                 delete_indices = [idx for idx in delete_indices if 0 <= idx < len(transcript_to_edit)]
-                 keep_indices = [i for i in range(len(transcript_to_edit)) if i not in delete_indices]
-                 if not keep_indices:
-                     st.error("All subtitles are deleted. No video to generate.")
-                 else:
-                     edited_video_path = create_edited_video(st.session_state['video_path'], transcript_to_edit, keep_indices)
-                     if edited_video_path:
-                         st.session_state['edited_video_path'] = edited_video_path
-                         st.session_state['app_state'] = 'results'
-                         st.rerun()
-             except ValueError:
-                 st.error("Invalid input. Please enter comma-separated integers.")
-             except Exception as e:
-                 st.error(f"Error during video editing: {str(e)}")
-         if st.button("Cancel Deletion"):
-             st.session_state['app_state'] = 'results'
-             st.rerun()
-
-     if st.session_state['app_state'] == 'results' and st.session_state['edited_video_path']:
-         st.markdown("### Edited Video")
-         st.video(st.session_state['edited_video_path'])
-         with open(st.session_state['edited_video_path'], "rb") as file:
-             st.download_button(label="Download Edited Video", data=file, file_name="edited_video.mp4", mime="video/mp4")
-
-     if st.session_state.get('video_path') and st.button("Reset"):
-         if st.session_state['video_path'] and os.path.exists(st.session_state['video_path']):
-             os.remove(st.session_state['video_path'])
-         if st.session_state['edited_video_path'] and os.path.exists(st.session_state['edited_video_path']):
-             os.remove(st.session_state['edited_video_path'])
-         st.session_state.clear()
-         st.rerun()
-
- if __name__ == "__main__":
+ import streamlit as st
+ import tempfile
+ import os
+ import torch
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForSeq2SeqLM
+ import librosa
+ import numpy as np
+ import ffmpeg
+ import time
+ import json
+ import psutil
+
+ def format_time(seconds):
+     minutes = int(seconds // 60)
+     secs = int(seconds % 60)
+     return f"{minutes}:{secs:02d}"
+
+ def seconds_to_srt_time(seconds):
+     hours = int(seconds // 3600)
+     minutes = int((seconds % 3600) // 60)
+     secs = int(seconds % 60)
+     millis = int((seconds - int(seconds)) * 1000)
+     return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
+
+ @st.cache_resource
+ def load_model(language='en', summarizer_type='bart'):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     if language == 'ur':
+         processor = AutoProcessor.from_pretrained("GogetaBlueMUI/whisper-medium-ur-fleurs")
+         model = AutoModelForSpeechSeq2Seq.from_pretrained("GogetaBlueMUI/whisper-medium-ur-fleurs").to(device)
+     else:
+         processor = AutoProcessor.from_pretrained("openai/whisper-small")
+         model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small").to(device)
+     if device.type == "cuda":
+         model = model.half()
+     if summarizer_type == 'bart':
+         sum_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
+         sum_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").to(device)
+     else:
+         sum_tokenizer = AutoTokenizer.from_pretrained("pszemraj/led-large-book-summary")
+         sum_model = AutoModelForSeq2SeqLM.from_pretrained("pszemraj/led-large-book-summary").to(device)
+     return processor, model, sum_tokenizer, sum_model, device
+
+ def split_audio_into_chunks(audio, sr, chunk_duration):
+     chunk_samples = int(chunk_duration * sr)
+     chunks = [audio[start:start + chunk_samples] for start in range(0, len(audio), chunk_samples)]
+     return chunks
+
+ def transcribe_audio(audio, sr, processor, model, device, start_time, language, task="transcribe"):
+     inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
+     input_features = inputs.input_features.to(device)
+     if model.dtype == torch.float16:
+         input_features = input_features.half()
+     generate_kwargs = {
+         "task": task,
+         "language": "urdu" if language == "ur" else language,
+         "max_new_tokens": 128,
+         "return_timestamps": True
+     }
+     try:
+         with torch.no_grad():
+             outputs = model.generate(input_features, **generate_kwargs)
+         text = processor.decode(outputs[0], skip_special_tokens=True)
+         return [(text, start_time, start_time + len(audio) / sr)]
+     except Exception as e:
+         st.error(f"Transcription error: {str(e)}")
+         return [(f"Error: {str(e)}", start_time, start_time + len(audio) / sr)]
+
+ def process_chunks(chunks, sr, processor, model, device, language, chunk_duration, task="transcribe", transcript_file="temp_transcript.json"):
+     transcript = []
+     chunk_start = 0
+     total_chunks = len(chunks)
+     progress_bar = st.progress(0)
+     status_text = st.empty()
+     if os.path.exists(transcript_file):
+         os.remove(transcript_file)
+     for i, chunk in enumerate(chunks):
+         status_text.text(f"Processing chunk {i+1}/{total_chunks}...")
+         try:
+             memory = psutil.virtual_memory()
+             st.write(f"Memory usage: {memory.percent}% (Chunk {i+1}/{total_chunks})")
+             chunk_transcript = transcribe_audio(chunk, sr, processor, model, device, chunk_start, language, task)
+             transcript.extend(chunk_transcript)
+             with open(transcript_file, "w", encoding="utf-8") as f:
+                 json.dump(transcript, f, ensure_ascii=False)
+             chunk_start += chunk_duration
+             progress_bar.progress((i + 1) / total_chunks)
+         except Exception as e:
+             st.error(f"Error processing chunk {i+1}: {str(e)}")
+             break
+     status_text.text("Processing complete!")
+     progress_bar.empty()
+     return transcript
+
+ def summarize_text(text, tokenizer, model, device, summarizer_type='bart'):
+     if summarizer_type == 'bart':
+         max_input_length = 1024
+         max_summary_length = 150
+         chunk_size = 512
+     else:
+         max_input_length = 16384
+         max_summary_length = 512
+         chunk_size = 8192
+     inputs = tokenizer(text, return_tensors="pt", truncation=False)
+     input_ids = inputs["input_ids"].to(device)
+     num_tokens = input_ids.shape[1]
+     st.write(f"Number of tokens in input: {num_tokens}")
+     if num_tokens < 50:
+         return "Transcript too short to summarize effectively."
+     try:
+         summaries = []
+         if num_tokens <= max_input_length:
+             truncated_inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_input_length).to(device)
+             with torch.no_grad():
+                 summary_ids = model.generate(truncated_inputs["input_ids"], num_beams=4, max_length=max_summary_length, min_length=50, early_stopping=True, temperature=0.7)
+             summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
+         else:
+             st.write(f"Transcript exceeds {max_input_length} tokens. Processing in chunks...")
+             tokens = input_ids[0].tolist()
+             for i in range(0, num_tokens, chunk_size):
+                 chunk_tokens = tokens[i:i + chunk_size]
+                 chunk_input_ids = torch.tensor([chunk_tokens]).to(device)
+                 with torch.no_grad():
+                     summary_ids = model.generate(chunk_input_ids, num_beams=4, max_length=max_summary_length // 2, min_length=25, early_stopping=True, temperature=0.7)
+                 summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
+             combined_summary = " ".join(summaries)
+             combined_inputs = tokenizer(combined_summary, return_tensors="pt", truncation=True, max_length=max_input_length).to(device)
+             with torch.no_grad():
+                 final_summary_ids = model.generate(combined_inputs["input_ids"], num_beams=4, max_length=max_summary_length, min_length=50, early_stopping=True, temperature=0.7)
+             summaries = [tokenizer.decode(final_summary_ids[0], skip_special_tokens=True)]
+         return " ".join(summaries)
+     except Exception as e:
+         st.error(f"Summarization error: {str(e)}")
+         return f"Error: {str(e)}"
+
+ def save_uploaded_file(uploaded_file):
+     try:
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
+             tmp_file.write(uploaded_file.read())
+             return tmp_file.name
+     except Exception as e:
+         st.error(f"Error saving uploaded file: {str(e)}")
+         return None
+
+ def merge_intervals(intervals):
+     if not intervals:
+         return []
+     intervals.sort(key=lambda x: x[0])
+     merged = [intervals[0]]
+     for current in intervals[1:]:
+         previous = merged[-1]
+         if previous[1] >= current[0]:
+             merged[-1] = (previous[0], max(previous[1], current[1]))
+         else:
+             merged.append(current)
+     return merged
+
+ def create_edited_video(video_path, transcript, keep_indices):
+     try:
+         intervals_to_keep = [(transcript[i][1], transcript[i][2]) for i in keep_indices]
+         merged_intervals = merge_intervals(intervals_to_keep)
+         temp_files = []
+         for j, (start, end) in enumerate(merged_intervals):
+             temp_file = f"temp_{j}.mp4"
+             ffmpeg.input(video_path, ss=start, to=end).output(temp_file, c='copy').run(overwrite_output=True, quiet=True)
+             temp_files.append(temp_file)
+         with open("list.txt", "w") as f:
+             for temp_file in temp_files:
+                 f.write(f"file '{temp_file}'\n")
+         edited_video_path = "edited_video.mp4"
+         ffmpeg.input('list.txt', format='concat', safe=0).output(edited_video_path, c='copy').run(overwrite_output=True, quiet=True)
+         for temp_file in temp_files:
+             if os.path.exists(temp_file):
+                 os.remove(temp_file)
+         if os.path.exists("list.txt"):
+             os.remove("list.txt")
+         return edited_video_path
+     except Exception as e:
+         st.error(f"Error creating edited video: {str(e)}")
+         return None
+
+ def generate_srt(transcript, include_timeframe=True):
+     srt_content = ""
+     for idx, (text, start, end) in enumerate(transcript, 1):
+         if include_timeframe:
+             start_time = seconds_to_srt_time(start)
+             end_time = seconds_to_srt_time(end)
+             srt_content += f"{idx}\n{start_time} --> {end_time}\n{text}\n\n"
+         else:
+             srt_content += f"{text}\n\n"
+     return srt_content
+
+ def main():
+     st.title("Video Transcription and Summarization App")
+     st.markdown("Upload a video to transcribe its audio, generate a summary, and edit subtitles.")
+
+     # Inject CSS to set video height
+     st.markdown("""
+         <style>
+         video {
+             width: 350px !important;
+             height: 500px !important;
+             object-fit: contain;
+         }
+         </style>
+     """, unsafe_allow_html=True)
+
+     if 'app_state' not in st.session_state:
+         st.session_state['app_state'] = 'upload'
+     if 'video_path' not in st.session_state:
+         st.session_state['video_path'] = None
+     if 'primary_transcript' not in st.session_state:
+         st.session_state['primary_transcript'] = None
+     if 'english_transcript' not in st.session_state:
+         st.session_state['english_transcript'] = None
+     if 'english_summary' not in st.session_state:
+         st.session_state['english_summary'] = None
+     if 'language' not in st.session_state:
+         st.session_state['language'] = None
+     if 'language_code' not in st.session_state:
+         st.session_state['language_code'] = None
+     if 'translate_to_english' not in st.session_state:
+         st.session_state['translate_to_english'] = False
+     if 'summarizer_type' not in st.session_state:
+         st.session_state['summarizer_type'] = None
+     if 'summary_generated' not in st.session_state:
+         st.session_state['summary_generated'] = False
+     if 'current_time' not in st.session_state:
+         st.session_state['current_time'] = 0
+     if 'edited_video_path' not in st.session_state:
+         st.session_state['edited_video_path'] = None
+     if 'search_query' not in st.session_state:
+         st.session_state['search_query'] = ""
+     if 'show_timeframe' not in st.session_state:
+         st.session_state['show_timeframe'] = True
+
+     st.write(f"Current app state: {st.session_state['app_state']}")
+
+     if st.session_state['app_state'] == 'upload':
+         with st.form(key="upload_form"):
+             uploaded_file = st.file_uploader("Upload a video", type=["mp4"])
+             if st.form_submit_button("Upload") and uploaded_file:
+                 video_path = save_uploaded_file(uploaded_file)
+                 if video_path:
+                     st.session_state['video_path'] = video_path
+                     st.session_state['app_state'] = 'processing'
+                     st.write(f"Uploaded file: {uploaded_file.name}")
+                     st.rerun()
+
+     if st.session_state['app_state'] == 'processing':
+         with st.form(key="processing_form"):
+             language = st.selectbox("Select language", ["English", "Urdu"], key="language_select")
+             language_code = "en" if language == "English" else "ur"
+             st.session_state['language'] = language
+             st.session_state['language_code'] = language_code
+             chunk_duration = st.number_input("Duration per chunk (seconds):", min_value=1.0, step=0.1, value=10.0)
+             if language_code == "ur":
+                 translate_to_english = st.checkbox("Generate English translation", key="translate_checkbox")
+                 st.session_state['translate_to_english'] = translate_to_english
+             else:
+                 st.session_state['translate_to_english'] = False
+             if st.form_submit_button("Process"):
+                 with st.spinner("Processing video..."):
+                     start_time = time.time()
+                     try:
+                         st.write("Extracting audio...")
+                         audio_path = "processed_audio.wav"
+                         ffmpeg.input(st.session_state['video_path']).output(audio_path, ac=1, ar=16000).run(overwrite_output=True, quiet=True)
+                         audio, sr = librosa.load(audio_path, sr=16000)
+                         audio = np.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0)
+                         audio_duration = len(audio) / sr
+                         st.write(f"Audio duration: {audio_duration:.2f} seconds")
+                         if audio_duration < 5:
+                             st.error("Audio too short (< 5s). Upload a longer video.")
+                             return
+                         summarizer_type = 'bart' if audio_duration <= 300 else 'led'
+                         st.write(f"Using summarizer: {summarizer_type}")
+                         st.session_state['summarizer_type'] = summarizer_type
+                         st.write("Loading models...")
+                         processor, model, sum_tokenizer, sum_model, device = load_model(language_code, summarizer_type)
+                         st.write("Splitting audio into chunks...")
+                         chunks = split_audio_into_chunks(audio, sr, chunk_duration)
+                         st.write(f"Number of chunks: {len(chunks)}")
+                         st.write("Transcribing audio...")
+                         primary_transcript = process_chunks(chunks, sr, processor, model, device, language_code, chunk_duration, task="transcribe", transcript_file="temp_primary_transcript.json")
+                         english_transcript = None
+                         if st.session_state['translate_to_english'] and language_code == "ur":
+                             st.write("Translating to English...")
+                             processor, model, _, _, device = load_model('en', summarizer_type)
+                             english_transcript = process_chunks(chunks, sr, processor, model, device, 'ur', chunk_duration, task="translate", transcript_file="temp_english_transcript.json")
+                         st.session_state.update({
+                             'primary_transcript': primary_transcript,
+                             'english_transcript': english_transcript,
+                             'summary_generated': False,
+                             'app_state': 'results'
+                         })
+                         st.write("Processing completed successfully!")
+                         st.rerun()
+                     except Exception as e:
+                         st.error(f"Processing failed: {str(e)}")
+                     finally:
+                         if os.path.exists(audio_path):
+                             os.remove(audio_path)
+                         for temp_file in ["temp_primary_transcript.json", "temp_english_transcript.json"]:
+                             if os.path.exists(temp_file):
+                                 os.remove(temp_file)
+
+     if st.session_state['app_state'] == 'results':
+         st.video(st.session_state['video_path'], start_time=st.session_state['current_time'])
+         st.session_state['show_timeframe'] = st.checkbox("Show timeframe in transcript", value=st.session_state['show_timeframe'])
+         st.markdown("### Search Subtitles")
+         search_query = st.text_input("Search subtitles...", value=st.session_state['search_query'] or "", key="search_input")
+         st.session_state['search_query'] = search_query.lower()
+         st.markdown(f"### {st.session_state['language']} Transcript")
+         for text, start, end in st.session_state['primary_transcript']:
+             display_text = text.lower()
+             if not search_query or search_query in display_text:
+                 label = f"[{format_time(start)} - {format_time(end)}] {text}" if st.session_state['show_timeframe'] else text
+                 if st.button(label, key=f"primary_{start}"):
+                     st.session_state['current_time'] = start
+                     st.rerun()
+         if st.session_state['english_transcript']:
+             st.markdown("### English Translation")
+             for text, start, end in st.session_state['english_transcript']:
+                 display_text = text.lower()
+                 if not search_query or search_query in display_text:
+                     label = f"[{format_time(start)} - {format_time(end)}] {text}" if st.session_state['show_timeframe'] else text
+                     if st.button(label, key=f"english_{start}"):
+                         st.session_state['current_time'] = start
+                         st.rerun()
+         if (st.session_state['language_code'] == 'en' or st.session_state['translate_to_english']) and not st.session_state['summary_generated']:
+             if st.button("Generate Summary"):
+                 with st.spinner("Generating summary..."):
+                     try:
+                         _, _, sum_tokenizer, sum_model, device = load_model(st.session_state['language_code'], st.session_state['summarizer_type'])
+                         full_text = " ".join([text for text, _, _ in (st.session_state['english_transcript'] or st.session_state['primary_transcript'])])
+                         english_summary = summarize_text(full_text, sum_tokenizer, sum_model, device, st.session_state['summarizer_type'])
+                         st.session_state['english_summary'] = english_summary
+                         st.session_state['summary_generated'] = True
+                     except Exception as e:
+                         st.error(f"Summary generation failed: {str(e)}")
+         if st.session_state['english_summary'] and st.session_state['summary_generated']:
+             st.markdown("### Summary")
+             st.write(st.session_state['english_summary'])
+         st.markdown("### Download Subtitles")
+         include_timeframe = st.checkbox("Include timeframe in subtitles", value=True)
+         transcript_to_download = st.session_state['primary_transcript'] or st.session_state['english_transcript']
+         if transcript_to_download:
+             srt_content = generate_srt(transcript_to_download, include_timeframe)
+             st.download_button(label="Download Subtitles (SRT)", data=srt_content, file_name="subtitles.srt", mime="text/plain")
+         st.markdown("### Edit Subtitles")
+         transcript_to_edit = st.session_state['primary_transcript'] or st.session_state['english_transcript']
+         if transcript_to_edit and st.button("Delete Subtitles"):
+             st.session_state['app_state'] = 'editing'
+             st.rerun()
+
+     if st.session_state['app_state'] == 'editing':
+         st.markdown("### Delete Subtitles")
+         transcript_to_edit = st.session_state['primary_transcript'] or st.session_state['english_transcript']
+         for i, (text, start, end) in enumerate(transcript_to_edit):
+             st.write(f"{i}: [{format_time(start)} - {format_time(end)}] {text}")
+         indices_input = st.text_input("Enter the indices of subtitles to delete (comma-separated, e.g., 0,1,3):")
+         if st.button("Confirm Deletion"):
+             try:
+                 delete_indices = [int(idx.strip()) for idx in indices_input.split(',') if idx.strip()]
+                 delete_indices = [idx for idx in delete_indices if 0 <= idx < len(transcript_to_edit)]
+                 keep_indices = [i for i in range(len(transcript_to_edit)) if i not in delete_indices]
+                 if not keep_indices:
+                     st.error("All subtitles are deleted. No video to generate.")
+                 else:
+                     edited_video_path = create_edited_video(st.session_state['video_path'], transcript_to_edit, keep_indices)
+                     if edited_video_path:
+                         st.session_state['edited_video_path'] = edited_video_path
+                         st.session_state['app_state'] = 'results'
+                         st.rerun()
+             except ValueError:
+                 st.error("Invalid input. Please enter comma-separated integers.")
+             except Exception as e:
+                 st.error(f"Error during video editing: {str(e)}")
+         if st.button("Cancel Deletion"):
+             st.session_state['app_state'] = 'results'
+             st.rerun()
+
+     if st.session_state['app_state'] == 'results' and st.session_state['edited_video_path']:
+         st.markdown("### Edited Video")
+         st.video(st.session_state['edited_video_path'])
+         with open(st.session_state['edited_video_path'], "rb") as file:
+             st.download_button(label="Download Edited Video", data=file, file_name="edited_video.mp4", mime="video/mp4")
+
+     if st.session_state.get('video_path') and st.button("Reset"):
+         if st.session_state['video_path'] and os.path.exists(st.session_state['video_path']):
+             os.remove(st.session_state['video_path'])
+         if st.session_state['edited_video_path'] and os.path.exists(st.session_state['edited_video_path']):
+             os.remove(st.session_state['edited_video_path'])
+         st.session_state.clear()
+         st.rerun()
+
+ if __name__ == "__main__":
      main()
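
The functional change in this commit is in `create_edited_video`: the previous version re-encoded every kept interval with libx264, applying the resolution and CRF options the processing form stored in session state, while the new version cuts and joins the intervals with ffmpeg stream copy (`c='copy'`) and drops those form controls and the `resolution`/`crf` session keys; it also injects CSS that pins the Streamlit video player to 350x500 px. As a minimal sketch of what the new ffmpeg-python calls expand to (the timestamps and filenames below are illustrative, not taken from the commit):

```python
import ffmpeg

# Cut one kept interval with stream copy, as the new create_edited_video does.
# The 12.0-34.5 s interval and the filenames are illustrative placeholders.
cut = ffmpeg.input("input.mp4", ss=12.0, to=34.5).output("temp_0.mp4", c="copy")
print(ffmpeg.compile(cut))
# e.g. ['ffmpeg', '-ss', '12.0', '-to', '34.5', '-i', 'input.mp4', '-c', 'copy', 'temp_0.mp4']

# Join the cut segments listed in list.txt, again without re-encoding.
cat = ffmpeg.input("list.txt", format="concat", safe=0).output("edited_video.mp4", c="copy")
print(ffmpeg.compile(cat))
# e.g. ['ffmpeg', '-f', 'concat', '-safe', '0', '-i', 'list.txt', '-c', 'copy', 'edited_video.mp4']
```

Stream copy makes the edit fast and avoids generation loss, but with `-c copy` the cut points can only land on keyframes, so segment boundaries may shift slightly relative to the subtitle timestamps; the removed libx264 path traded that speed for frame-accurate cuts and uniform output settings.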