import gradio as gr
import torch
import librosa
from transformers import Wav2Vec2Processor, AutoModelForCTC
import zipfile
import os
import firebase_admin
from firebase_admin import credentials, firestore
from datetime import datetime
import json
import tempfile

# Initialize Firebase from service-account credentials stored in the environment
firebase_config = json.loads(os.environ.get('firebase_creds'))
cred = credentials.Certificate(firebase_config)
firebase_admin.initialize_app(cred)
db = firestore.client()

# Load the ASR model and processor
MODEL_NAME = "eleferrand/xlsr53_Amis"
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = AutoModelForCTC.from_pretrained(MODEL_NAME)


def transcribe(audio_file):
    try:
        # Resample to the 16 kHz rate the model was trained on
        audio, rate = librosa.load(audio_file, sr=16000)
        input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
        with torch.no_grad():
            logits = model(input_values).logits
        # Greedy CTC decoding: pick the most likely token at each frame
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        return transcription.replace("[UNK]", "")
    except Exception as e:
        return f"Error processing file: {e}"


def transcribe_both(audio_file):
    start_time = datetime.now()
    transcription = transcribe(audio_file)
    processing_time = (datetime.now() - start_time).total_seconds()
    # Return the transcription twice: once for the read-only display,
    # once as the editable copy the user can correct
    return transcription, transcription, processing_time


def store_correction(original_transcription, corrected_transcription, audio_file,
                     processing_time, age, native_speaker):
    try:
        audio_metadata = {}
        if audio_file and os.path.exists(audio_file):
            audio, sr = librosa.load(audio_file, sr=16000)
            duration = librosa.get_duration(y=audio, sr=sr)
            file_size = os.path.getsize(audio_file)
            audio_metadata = {'duration': duration, 'file_size': file_size}
        combined_data = {
            'original_text': original_transcription,
            'corrected_text': corrected_transcription,
            'timestamp': datetime.now().isoformat(),
            'processing_time': processing_time,
            'audio_metadata': audio_metadata,
            'audio_url': None,
            'model_name': MODEL_NAME,
            'user_info': {
                'native_amis_speaker': native_speaker,
                'age': age
            }
        }
        db.collection('transcriptions').add(combined_data)
        return "Correction saved successfully!"
    except Exception as e:
        return f"Error saving correction: {e}"


def prepare_download(audio_file, original_transcription, corrected_transcription):
    if audio_file is None:
        return None
    # Bundle the audio and both transcriptions into a downloadable zip
    tmp_zip = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
    tmp_zip.close()
    with zipfile.ZipFile(tmp_zip.name, "w") as zf:
        if os.path.exists(audio_file):
            zf.write(audio_file, arcname="audio.wav")
        orig_txt = "original_transcription.txt"
        with open(orig_txt, "w", encoding="utf-8") as f:
            f.write(original_transcription)
        zf.write(orig_txt, arcname="original_transcription.txt")
        os.remove(orig_txt)
        corr_txt = "corrected_transcription.txt"
        with open(corr_txt, "w", encoding="utf-8") as f:
            f.write(corrected_transcription)
        zf.write(corr_txt, arcname="corrected_transcription.txt")
        os.remove(corr_txt)
    return tmp_zip.name


# Interface layout
with gr.Blocks(css="""
    .container { max-width: 800px; margin: auto; padding: 20px; font-family: Arial, sans-serif; }
    .header { text-align: center; margin-bottom: 30px; }
    .section { margin-bottom: 30px; padding: 15px; border: 1px solid #ddd; border-radius: 8px; background-color: #808080; }
    .section h3 { margin-top: 0; margin-bottom: 15px; text-align: center; }
    .button-row { display: flex; justify-content: center; gap: 10px; flex-wrap: wrap; }
    @media (max-width: 600px) { .gradio-row { flex-direction: column; } }
""") as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("