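"""Echo Mind 🧠 — a Streamlit Space that records a voice clip, transcribes it
with Whisper Large V3 Turbo, and re-speaks the transcription in the speaker's
own voice via Zonos voice cloning.

Run locally (assuming this file is saved as app.py):

    streamlit run app.py
"""
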
import streamlit as st
import tempfile
import os
import torch
import torchaudio
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict
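
# Assumed dependencies: streamlit, torch, torchaudio, transformers, and the
# zonos package (assumed to come from https://github.com/Zyphra/Zonos).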

st.set_page_config(page_title="Echo Mind 🧠", page_icon="🧠")
st.title("Echo Mind 🧠")
st.write("Voice Transcription with Whisper Large V3 Turbo and Zonos Speech Synthesis")


@st.cache_resource
def load_whisper_model():
    """Load the Whisper ASR pipeline once and reuse it across Streamlit reruns."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model_id = "openai/whisper-large-v3-turbo"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=False, use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    return pipe


@st.cache_resource
def load_zonos_model():
    """Load the Zonos TTS model once and reuse it across Streamlit reruns."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer", device=device)
    return model


def transcribe_audio(audio_path, pipe):
    # The ASR pipeline accepts a file path directly
    result = pipe(audio_path)
    return result["text"]


def synthesize_speech(text, audio_path, zonos_model):
    """Generate speech using Zonos with the voice from the input audio."""
    try:
        # Load the input audio to extract speaker characteristics
        wav, sampling_rate = torchaudio.load(audio_path)
        speaker = zonos_model.make_speaker_embedding(wav, sampling_rate)
        # Prepare conditioning and generate speech
        cond_dict = make_cond_dict(text=text, speaker=speaker, language="en-us")
        conditioning = zonos_model.prepare_conditioning(cond_dict)
        codes = zonos_model.generate(conditioning)
        wavs = zonos_model.autoencoder.decode(codes).cpu()
        # Save to a temporary file (NamedTemporaryFile avoids the race-prone tempfile.mktemp)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            output_path = tmp_file.name
        torchaudio.save(output_path, wavs[0], zonos_model.autoencoder.sampling_rate)
        return output_path
    except Exception as e:
        st.error(f"Speech synthesis error: {e}")
        return None


# Load the Whisper model
with st.spinner("Loading Whisper Large V3 Turbo model..."):
    pipe = load_whisper_model()
st.success("Model loaded!")

# Load the Zonos model
with st.spinner("Loading Zonos voice synthesis model..."):
    zonos_model = load_zonos_model()
st.success("Voice synthesis model loaded!")

audio_bytes = st.audio_input("Record")

if audio_bytes:
    st.audio(audio_bytes, format="audio/wav")
    # Save the recording to a temporary file, replacing the file from a previous rerun
    old_path = st.session_state.get("audio_path")
    if old_path and os.path.exists(old_path):
        os.unlink(old_path)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        # getvalue() returns the full recording even if the buffer was already read
        tmp_file.write(audio_bytes.getvalue())
        audio_path = tmp_file.name
    st.session_state.audio_path = audio_path

    col1, col2 = st.columns(2)
    with col1:
        if st.button("Transcribe Audio"):
            with st.spinner("Transcribing with Whisper Large V3 Turbo..."):
                transcription = transcribe_audio(audio_path, pipe)
                st.session_state.transcription = transcription
            st.subheader("Transcription")
            st.write(transcription)
    with col2:
        if st.button("Speak Transcription") and "transcription" in st.session_state:
            with st.spinner("Synthesizing speech with Zonos..."):
                output_path = synthesize_speech(
                    st.session_state.transcription,
                    st.session_state.audio_path,
                    zonos_model,
                )
            if output_path:
                st.subheader("Synthesized Speech")
                with open(output_path, "rb") as f:
                    synthesized_bytes = f.read()  # avoid shadowing the recorder's audio_bytes
                st.audio(synthesized_bytes, format="audio/wav")
                # Clean up the synthesized file once its bytes are in memory
                os.unlink(output_path)