# task-oriented-dialog-agent / speech_conversation_app.py
import os
import time
import numpy as np
import torch
import gradio as gr
from transformers import (
    pipeline,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AutoProcessor,
    AutoModelForCTC,
)
from datasets import load_dataset
import soundfile as sf
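
# The imports above assume roughly the following environment (a sketch, not a
# pinned requirements file); espnet is only needed for real TTS, otherwise the
# mock fallback below is used:
#   pip install torch transformers datasets gradio soundfile numpy
#   pip install espnet espnet_model_zoo   # optional, for ESPnet TTS
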
# Global variables to track latency
latency_ASR = 0.0
latency_LLM = 0.0
latency_TTS = 0.0
# Global variables to store conversation state
conversation_history = []
audio_output = None
# ASR Models
ASR_OPTIONS = {
    "Whisper Small": "openai/whisper-small",
    "Wav2Vec2": "facebook/wav2vec2-base-960h",
}

# LLM Models
LLM_OPTIONS = {
    "Llama-2 7B Chat": "meta-llama/Llama-2-7b-chat-hf",
    "Flan-T5 Small": "google/flan-t5-small",
}

# TTS Models
TTS_OPTIONS = {
    "VITS": "espnet/kan-bayashi_ljspeech_vits",
    "FastSpeech2": "espnet/kan-bayashi_ljspeech_fastspeech2",
}
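
# Any Hugging Face checkpoint compatible with the loaders below could be added to
# these tables; for example (an illustrative assumption, not part of the original
# configuration):
#   ASR_OPTIONS["Whisper Base"] = "openai/whisper-base"
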
# Load models
asr_models = {}
llm_models = {}
tts_models = {}
def load_asr_model(model_name):
    """Load ASR model from Hugging Face"""
    global asr_models
    if model_name not in asr_models:
        print(f"Loading ASR model: {model_name}")
        model_id = ASR_OPTIONS[model_name]
        if "whisper" in model_id:
            asr_models[model_name] = pipeline("automatic-speech-recognition", model=model_id)
        else:  # wav2vec2 is a CTC model, not an encoder-decoder
            processor = AutoProcessor.from_pretrained(model_id)
            model = AutoModelForCTC.from_pretrained(model_id)
            asr_models[model_name] = {"processor": processor, "model": model}
    return asr_models[model_name]

def load_llm_model(model_name):
    """Load LLM model from Hugging Face"""
    global llm_models
    if model_name not in llm_models:
        print(f"Loading LLM model: {model_name}")
        model_id = LLM_OPTIONS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Flan-T5 is an encoder-decoder model; Llama is a causal (decoder-only) model
        if "t5" in model_id.lower():
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        llm_models[model_name] = {
            "model": model,
            "tokenizer": tokenizer
        }
    return llm_models[model_name]

def load_tts_model(model_name):
    """Load TTS model using ESPnet"""
    global tts_models
    if model_name not in tts_models:
        print(f"Loading TTS model: {model_name}")
        try:
            # Import ESPnet TTS modules
            from espnet2.bin.tts_inference import Text2Speech
            model_id = TTS_OPTIONS[model_name]
            tts = Text2Speech.from_pretrained(model_id)
            tts_models[model_name] = tts
        except ImportError:
            print("ESPnet not installed. Using mock TTS for demonstration.")
            tts_models[model_name] = "mock_tts"
    return tts_models[model_name]

def transcribe_audio(audio_data, sr, asr_model_name):
    """Transcribe audio using selected ASR model"""
    global latency_ASR
    start_time = time.time()
    model = load_asr_model(asr_model_name)
    # Gradio delivers microphone audio as int16 PCM; convert to mono float32 in [-1, 1]
    audio_data = np.asarray(audio_data)
    if np.issubdtype(audio_data.dtype, np.integer):
        audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
    else:
        audio_data = audio_data.astype(np.float32)
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)
    if "whisper" in ASR_OPTIONS[asr_model_name]:
        result = model({"array": audio_data, "sampling_rate": sr})
        transcript = result["text"]
    else:  # wav2vec2: CTC decoding (the checkpoint expects 16 kHz input)
        inputs = model["processor"](audio_data, sampling_rate=sr, return_tensors="pt")
        with torch.no_grad():
            logits = model["model"](**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcript = model["processor"].batch_decode(predicted_ids)[0]
    latency_ASR = time.time() - start_time
    return transcript

def generate_response(transcript, llm_model_name, system_prompt):
    """Generate response using selected LLM model"""
    global latency_LLM, conversation_history
    start_time = time.time()
    model_info = load_llm_model(llm_model_name)
    model = model_info["model"]
    tokenizer = model_info["tokenizer"]
    # Track the conversation for both model families
    if not conversation_history:
        conversation_history.append({"role": "system", "content": system_prompt})
    conversation_history.append({"role": "user", "content": transcript})
    # Format the prompt based on the model
    if "llama" in LLM_OPTIONS[llm_model_name].lower():
        # Use the model's own chat template for Llama
        prompt = tokenizer.apply_chat_template(
            conversation_history,
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        # Simple instruction-style prompt for Flan-T5
        prompt = f"{system_prompt}\nUser: {transcript}\nAssistant:"
    # Generate text
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    # Decode the response
    if "llama" in LLM_OPTIONS[llm_model_name].lower():
        # Causal LMs echo the prompt, so decode only the newly generated tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    else:
        # Seq2seq models return only the generated answer
        response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Add to conversation history
    conversation_history.append({"role": "assistant", "content": response})
    latency_LLM = time.time() - start_time
    return response

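# Optional sketch (not in the original app): a long session can eventually overflow
# the model context. One simple, assumed policy is to keep the system prompt plus
# only the most recent turns before each generation; `trim_history` and `max_turns`
# are illustrative names, not part of the original code.
def trim_history(history, max_turns=10):
    """Keep the system message and the last `max_turns` user/assistant messages."""
    system = [m for m in history if m["role"] == "system"][:1]
    rest = [m for m in history if m["role"] != "system"]
    return system + rest[-max_turns:]
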
def synthesize_speech(text, tts_model_name):
    """Synthesize speech using selected TTS model"""
    global latency_TTS
    start_time = time.time()
    tts = load_tts_model(tts_model_name)
    if tts == "mock_tts":
        # Mock TTS response for demonstration
        # In a real implementation, this would use the ESPnet model
        try:
            sample_rate = 16000
            # Generate a simple sine wave as demo audio
            duration = 2  # seconds
            t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
            audio_data = 0.5 * np.sin(2 * np.pi * 220 * t)  # 220 Hz sine wave
        except Exception as e:
            print(f"Error generating mock audio: {e}")
            audio_data = np.zeros(16000)  # 1 second of silence
            sample_rate = 16000
    else:
        # Use actual ESPnet TTS model
        with torch.no_grad():
            wav = tts(text)["wav"]
        audio_data = wav.cpu().numpy()
        sample_rate = tts.fs
    latency_TTS = time.time() - start_time
    return (sample_rate, audio_data)

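# Optional sketch: soundfile is already imported, so the synthesized reply could be
# written to disk for inspection (the filename is illustrative, not part of the
# original app):
#   sf.write("last_response.wav", audio_data, sample_rate)
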
def process_speech(
    audio_input,
    asr_option,
    llm_option,
    tts_option,
    system_prompt
):
    """Process speech: ASR -> LLM -> TTS pipeline"""
    global audio_output
    # Check if audio input is available
    if audio_input is None:
        return None, "", "", None
    # Get audio data
    sr, audio_data = audio_input
    # ASR: Speech to text
    transcript = transcribe_audio(audio_data, sr, asr_option)
    # LLM: Generate response
    response = generate_response(transcript, llm_option, system_prompt)
    # TTS: Text to speech
    audio_output = synthesize_speech(response, tts_option)
    # Return results
    return audio_input, transcript, response, audio_output

def display_latency():
    """Display latency information"""
    return f"""
ASR Latency: {latency_ASR:.2f} seconds
LLM Latency: {latency_LLM:.2f} seconds
TTS Latency: {latency_TTS:.2f} seconds
Total Latency: {latency_ASR + latency_LLM + latency_TTS:.2f} seconds
"""

def reset_conversation():
    """Reset the conversation history"""
    global conversation_history, audio_output
    conversation_history = []
    audio_output = None
    return None, "", "", None, ""

# Create Gradio interface
with gr.Blocks(title="Conversational Speech System") as demo:
    gr.Markdown(
        """
# Conversational Speech System with ASR, LLM, and TTS

This demo showcases a complete speech-to-speech conversation system using:
- **ASR** (Automatic Speech Recognition) to convert your speech to text
- **LLM** (Large Language Model) to generate responses
- **TTS** (Text-to-Speech) to convert the responses to speech

Speak into your microphone and the system will respond with synthesized speech.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Input components
            audio_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                label="Speak here",
            )
            system_prompt = gr.Textbox(
                label="System Prompt (instructions for the LLM)",
                value="You are a helpful and friendly AI assistant. Keep your responses concise and under 3 sentences."
            )
            asr_dropdown = gr.Dropdown(
                choices=list(ASR_OPTIONS.keys()),
                value=list(ASR_OPTIONS.keys())[0],
                label="Select ASR Model"
            )
            llm_dropdown = gr.Dropdown(
                choices=list(LLM_OPTIONS.keys()),
                value=list(LLM_OPTIONS.keys())[0],
                label="Select LLM Model"
            )
            tts_dropdown = gr.Dropdown(
                choices=list(TTS_OPTIONS.keys()),
                value=list(TTS_OPTIONS.keys())[0],
                label="Select TTS Model"
            )
            reset_btn = gr.Button("Reset Conversation")
        with gr.Column(scale=1):
            # Output components
            user_transcript = gr.Textbox(label="Your Speech (ASR Output)")
            system_response = gr.Textbox(label="AI Response (LLM Output)")
            audio_output_component = gr.Audio(label="AI Voice Response", autoplay=True)
            latency_info = gr.Textbox(label="Performance Metrics")

    # Set up event handlers
    audio_input.change(
        process_speech,
        inputs=[audio_input, asr_dropdown, llm_dropdown, tts_dropdown, system_prompt],
        outputs=[audio_input, user_transcript, system_response, audio_output_component]
    ).then(
        display_latency,
        inputs=[],
        outputs=[latency_info]
    )
    reset_btn.click(
        reset_conversation,
        inputs=[],
        outputs=[audio_input, user_transcript, system_response, audio_output_component, latency_info]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
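    # A possible variant (an assumption, not in the original): demo.launch(share=True)
    # creates a temporary public link, which is convenient when testing the
    # microphone input from another device.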