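# Web-page residue from the file viewer this script was copied from
# (uploader avatar caption "Godota's picture", ref de3ffd2,
# "raw" / "history" / "blame" links, file size 10.6 kB) - not part of the program.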
import logging
logging.getLogger("httpx").setLevel(logging.WARNING)
import gradio as gr
from transformers import pipeline
import torch
import re
from typing import List, Dict
import spaces
from speechbrain.inference import EncoderClassifier
import torchaudio
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import tempfile
import os
import subprocess
# Initialize the Whisper ASR pipeline once at import time: fp16 weights on
# the first CUDA device, with the FlashAttention-2 attention implementation.
# (The identical initialization used to be executed twice back to back,
# which only loaded the same model a second time.)
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device="cuda:0",
    model_kwargs={"attn_implementation": "flash_attention_2"},
)
# Speaker-embedding model; stays None until init_speaker_model() loads it.
speaker_model = None


def init_speaker_model():
    """Load the speaker encoder into the module-level ``speaker_model``.

    The SpeechBrain ``spkrec-ecapa-voxceleb`` model is created on ``cuda:0``
    the first time this is called; later calls return immediately because
    the global is already set.
    """
    global speaker_model
    if speaker_model is not None:
        return
    speaker_model = EncoderClassifier.from_hparams(
        source="speechbrain/spkrec-ecapa-voxceleb",
        savedir="pretrained_models/spkrec-ecapa-voxceleb",
        run_opts={"device": "cuda:0"},
    )
def format_time(seconds):
    """Format a duration in seconds as ``MM:SS``, or ``HH:MM:SS`` from one hour up.

    Fractional seconds are truncated, and every field is zero-padded to at
    least two digits.
    """
    hrs = int(seconds // 3600)
    mins = int(seconds % 3600 // 60)
    secs = int(seconds % 60)
    clock = f"{mins:02}:{secs:02}"
    # The hour field is only shown when it is non-zero.
    return f"{hrs:02}:{clock}" if hrs > 0 else clock
def extract_audio_from_video(video_path):
    """Extract the audio track of a video into a temporary WAV file via ffmpeg.

    ffmpeg is asked for 16-bit PCM, 16 kHz, mono output with the video
    stream dropped.

    Args:
        video_path: Path of the input video file.

    Returns:
        Path of the newly created temporary ``.wav`` file. The file is not
        deleted here; the caller is responsible for removing it.

    Raises:
        gr.Error: If the temporary file cannot be created, ffmpeg cannot be
            started, or ffmpeg exits with a non-zero status. Any partially
            written temporary file is removed first.
    """
    # Bound before the try so the error handler can always inspect it, even
    # when creating the temporary file itself is what failed.
    temp_audio_path = None
    try:
        # Only the path is needed: the handle is closed right away and
        # ffmpeg overwrites the empty file (-y).
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_audio:
            temp_audio_path = temp_audio.name

        command = [
            'ffmpeg',
            '-i', video_path,
            '-vn',
            '-acodec', 'pcm_s16le',
            '-ar', '16000',
            '-ac', '1',
            '-y',
            temp_audio_path,
        ]
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if result.returncode != 0:
            # ffmpeg's diagnostics are not guaranteed to be valid UTF-8.
            details = result.stderr.decode(errors="replace")
            raise RuntimeError(f"FFmpeg error: {details}")
        return temp_audio_path
    except Exception as e:
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.unlink(temp_audio_path)
        raise gr.Error(f"Error extracting audio from video: {str(e)}") from e
def cleanup_temp_file(file_path):
    """Delete ``file_path`` if it exists, never raising.

    Any error raised while checking for or removing the file is caught and
    reported with ``print`` instead of being propagated.
    """
    try:
        if not os.path.exists(file_path):
            return
        os.remove(file_path)
    except Exception as err:
        print(f"Error cleaning up temporary file: {str(err)}")
# Parameters of the sliding-window speaker analysis.
_SPEAKER_SAMPLE_RATE = 16000  # sample rate (Hz) the audio is resampled to
_SPEAKER_WINDOW_S = 3  # length of each analysed window, in seconds
_SPEAKER_STEP_S = 2  # hop between consecutive windows, in seconds
# Inputs with these extensions get their audio extracted with ffmpeg first.
_VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')


def _chunk_bounds(chunk):
    """Return ``(start, end)`` in seconds for a pipeline chunk, tolerating ``None``."""
    start = chunk["timestamp"][0]
    end = chunk["timestamp"][1]
    if start is None:
        start = 0.0
    if end is None:
        # Fall back to the start time rather than 0.0 so the printed range
        # never runs backwards (e.g. "[01:20 -> 00:00]").
        end = start
    return start, end


def _speaker_labels(audio_path):
    """Return one cluster label per sliding window of the audio file."""
    waveform, sample_rate = torchaudio.load(audio_path)
    if waveform.shape[0] > 1:
        # Average the channels down to a single one.
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sample_rate != _SPEAKER_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, _SPEAKER_SAMPLE_RATE
        )

    window_size = _SPEAKER_WINDOW_S * _SPEAKER_SAMPLE_RATE
    step_size = _SPEAKER_STEP_S * _SPEAKER_SAMPLE_RATE
    embeddings = []
    with torch.no_grad():
        # The range bound guarantees every slice is a full window.
        for start in range(0, waveform.shape[1] - window_size + 1, step_size):
            window = waveform[:, start:start + window_size]
            embedding = speaker_model.encode_batch(window)
            embeddings.append(embedding.squeeze().cpu().numpy())

    # Clustering needs at least two windows; otherwise use a single speaker.
    if len(embeddings) <= 1:
        return [0]
    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=0.85,
        metric='cosine',
        linkage='average',
    ).fit(np.array(embeddings))
    return clustering.labels_


@spaces.GPU
def transcribe_with_language(input_file, language, detect_speakers=False):
    """Transcribe an audio or video file, optionally labelling speakers.

    Args:
        input_file: Path of the uploaded file. Files ending in .mp4, .avi,
            .mov or .mkv have their audio extracted to a temporary WAV first.
        language: Language name passed to the Whisper pipeline, or ``"auto"``
            to let the pipeline pick the language itself.
        detect_speakers: When true, cluster sliding-window speaker embeddings
            and prefix each line with a ``SPEAKER_<n>`` label, separating
            speaker changes with a blank line.

    Returns:
        The transcript as newline-joined lines of the form
        ``[start -> end] text`` (or ``[start -> end] SPEAKER_<n>: text``).

    Raises:
        gr.Error: If no input file was provided, or if audio extraction from
            a video fails.
    """
    if not input_file:
        # Without this guard a missing upload crashes on ``None.lower()``.
        raise gr.Error("No input file provided")

    temp_audio_path = None
    try:
        gr.Info("Processing input file...")
        if input_file.lower().endswith(_VIDEO_EXTENSIONS):
            gr.Info("Extracting audio from video...")
            temp_audio_path = extract_audio_from_video(input_file)
            audio_path = temp_audio_path
        else:
            audio_path = input_file

        gr.Info("Starting transcription task")
        generate_kwargs = {"task": "transcribe"}
        if language != "auto":
            generate_kwargs["language"] = language
        outputs = pipe(
            audio_path,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        gr.Info("Finished transcription task")

        # Plain transcript: one timestamped line per chunk.
        if not detect_speakers:
            lines = []
            for chunk in outputs["chunks"]:
                start_time, end_time = _chunk_bounds(chunk)
                text = chunk["text"].strip()
                lines.append(
                    f"[{format_time(start_time)} -> {format_time(end_time)}] {text}"
                )
            return "\n".join(lines)

        # Speaker detection
        init_speaker_model()
        gr.Info("Detecting speakers...")
        labels = _speaker_labels(audio_path)

        transcription = []
        current_speaker = None
        for chunk in outputs["chunks"]:
            start_time, end_time = _chunk_bounds(chunk)
            text = chunk["text"].strip()
            # Window i starts at i * step seconds, so the chunk's start time
            # selects a window; clamp to the last window that exists.
            window_idx = min(int(start_time / _SPEAKER_STEP_S), len(labels) - 1)
            speaker = f"SPEAKER_{labels[window_idx]}"
            if speaker != current_speaker:
                # Blank line between consecutive speakers' passages.
                if transcription:
                    transcription.append("")
                current_speaker = speaker
            transcription.append(
                f"[{format_time(start_time)} -> {format_time(end_time)}] {speaker}: {text}"
            )
        return "\n".join(transcription)
    finally:
        # Remove the extracted WAV (if any) whether or not transcription succeeded.
        if temp_audio_path:
            cleanup_temp_file(temp_audio_path)
# Gradio interface
# Multi-line texts are kept at column 0 so their content is exactly the
# text shown to the user / sent to the browser.
_HEADER_MARKDOWN = """
# 🎙️ Надшвидке розпізнавання мовлення
Завантажте аудіо або відео файл для транскрибації.
"""

_INFO_MARKDOWN = """
### ℹ️ Інформація
- Підтримує більшість аудіо та відео форматів
- Може обробляти файли тривалістю до декількох годин
- Визначення спікерів збільшить час обробки
- Автоматичне визначення мови працює досить точно
- Для кращого результату рекомендується вибрати конкретну мову
"""

# Browser-side handler: writes the text to the clipboard, then passes it on.
_COPY_JS = """async (text) => {
if (text && text.trim()) {
await navigator.clipboard.writeText(text);
}
return text;
}"""

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"),
) as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    with gr.Row():
        # Left column: inputs and the start button.
        with gr.Column(scale=1):
            file_input = gr.File(
                label="📁 Завантажити аудіо/відео файл",
                file_types=["audio", "video"],
                type="filepath",
                interactive=True,
            )
            language = gr.Dropdown(
                choices=[
                    "auto",
                    "ukrainian",
                    "english",
                    "russian",
                    "polish",
                    "german",
                    "french",
                    "spanish",
                ],
                value="auto",
                label="🌐 Мова аудіо",
                info="Виберіть мову аудіо для кращого розпізнавання",
            )
            detect_speakers = gr.Checkbox(
                label="🎭 Визначати спікерів",
                value=False,
                info="Увімкнути розпізнавання різних спікерів",
            )
            btn = gr.Button("🎯 Розпочати обробку", variant="primary", size="lg")
            gr.Markdown(_INFO_MARKDOWN)

        # Right column: transcript and the custom copy button.
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="📝 Результат розпізнавання",
                interactive=False,
                lines=25,
                show_copy_button=False,
            )
            copy_btn = gr.Button(
                "📋 Копіювати текст",
                variant="secondary",
                interactive=False,  # disabled until there is text to copy
            )

    # Add progress indicator
    progress = gr.Progress(track_tqdm=True)

    def update_copy_button(text):
        """Enable the copy button only when ``text`` has non-whitespace content."""
        has_text = bool(text and text.strip())
        return gr.update(interactive=has_text)

    def copy_text(text):
        """Show a "copied" notification and return ``text`` unchanged."""
        gr.Info("Текст скопійовано!")
        return text

    # Run the transcription, then refresh the copy button's state; `.then`
    # (rather than `.success`) is used for the follow-up step.
    transcription_event = btn.click(
        fn=transcribe_with_language,
        inputs=[file_input, language, detect_speakers],
        outputs=output,
    )
    transcription_event.then(
        fn=update_copy_button,
        inputs=[output],
        outputs=copy_btn,
    )

    # Handle copy button
    copy_btn.click(
        fn=copy_text,
        inputs=[output],
        outputs=output,
        js=_COPY_JS,
    )

if __name__ == "__main__":
    demo.queue().launch()