import fastapi import numpy as np import torch import torchaudio from silero_vad import get_speech_timestamps, load_silero_vad import whisperx import edge_tts import gc import logging import time import os from openai import OpenAI import asyncio from pydub import AudioSegment from io import BytesIO import threading # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # Configure FastAPI app = fastapi.FastAPI() # Load Silero VAD model device = 'cuda' if torch.cuda.is_available() else 'cpu' logging.info(f'Using device: {device}') vad_model = load_silero_vad().to(device) logging.info('Loaded Silero VAD model') # Load WhisperX model whisper_model = whisperx.load_model("tiny", device, compute_type="float16") logging.info('Loaded WhisperX model') OPENAI_API_KEY = "" if not OPENAI_API_KEY: logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.") raise ValueError("OpenAI API key not found.") logging.info('Initialized OpenAI client') llm_client = OpenAI(api_key=OPENAI_API_KEY) # Corrected import # TTS Voice TTS_VOICE = "en-GB-SoniaNeural" # Function to check voice activity using Silero VAD def check_vad(audio_data, sample_rate): logging.info('Checking voice activity') target_sample_rate = 16000 if sample_rate != target_sample_rate: resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate) audio_tensor = resampler(torch.from_numpy(audio_data)) else: audio_tensor = torch.from_numpy(audio_data) audio_tensor = audio_tensor.to(device) speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate) logging.info(f'Found {len(speech_timestamps)} speech timestamps') return len(speech_timestamps) > 0 # Async function to transcribe audio using WhisperX def transcribe(audio_data, sample_rate): logging.info('Transcribing audio') target_sample_rate = 16000 if sample_rate != target_sample_rate: resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate) audio_data = resampler(torch.from_numpy(audio_data)).numpy() else: audio_data = audio_data batch_size = 16 # Adjust as needed result = whisper_model.transcribe(audio_data, batch_size=batch_size) text = result["segments"][0]["text"] if len(result["segments"]) > 0 else "" logging.info(f'Transcription result: {text}') del result gc.collect() if device == 'cuda': torch.cuda.empty_cache() return text # Function to convert text to speech using Edge TTS and stream the audio def tts_streaming(text_stream): logging.info('Performing TTS') buffer = "" punctuation = {'.', '!', '?'} for text_chunk in text_stream: if text_chunk is not None: buffer += text_chunk # Check for sentence completion sentences = [] start = 0 for i, char in enumerate(buffer): if char in punctuation: sentences.append(buffer[start:i+1].strip()) start = i+1 buffer = buffer[start:] for sentence in sentences: if sentence: communicate = edge_tts.Communicate(sentence, TTS_VOICE) for chunk in communicate.stream_sync(): if chunk["type"] == "audio": yield chunk["data"] # Process any remaining text if buffer.strip(): communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE) for chunk in communicate.stream_sync(): if chunk["type"] == "audio": yield chunk["data"] # Function to perform language model completion using OpenAI API def llm(text): logging.info('Getting response from OpenAI API') response = llm_client.chat.completions.create( model="gpt-4o", # Updated to a more recent model messages=[ {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."}, {"role": "user", "content": text} ], stream=True, temperature=0.7, top_p=0.9 ) for chunk in response: yield chunk.choices[0].delta.content class Conversation: def __init__(self): self.mode = 'idle' # idle, listening, speaking self.audio_stream = [] self.valid_chunk_queue = [] self.first_valid_chunk = None self.last_valid_chunks = [] self.valid_chunk_transcriptions = '' self.in_transcription = False self.llm_n_tts_task = None self.stop_signal = False self.sample_rate = 0 self.out_audio_stream = [] self.chunk_buffer = 0.5 # seconds def llm_n_tts(self): for text_chunk in llm(self.transcription): if self.stop_signal: break for audio_chunk in tts_streaming([text_chunk]): if self.stop_signal: break self.out_audio_stream.append(np.frombuffer(audio_chunk, dtype=np.int16)) def process_audio_chunk(self, audio_chunk): # Construct audio stream audio_data = AudioSegment.from_file(BytesIO(audio_chunk), format="wav") audio_data = np.array(audio_data.get_array_of_samples()) self.sample_rate = audio_data.frame_rate # Check for voice activity vad = check_vad(audio_data, self.sample_rate) if vad: # Voice activity detected if self.first_valid_chunk is not None: self.valid_chunk_queue.append(self.first_valid_chunk) self.first_valid_chunk = None self.valid_chunk_queue.append(audio_chunk) if len(self.valid_chunk_queue) > 2: # i.e. 3 chunks: 1 non valid chunk + 2 valid chunks # this is to ensure that the speaker is speaking if self.mode == 'idle': self.mode = 'listening' elif self.mode == 'speaking': # Stop llm and tts if self.llm_n_tts_task is not None: self.stop_signal = True self.llm_n_tts_task self.stop_signal = False self.mode = 'listening' else: # No voice activity if self.mode == 'listening': self.last_valid_chunks.append(audio_chunk) if len(self.last_valid_chunks) > 2: # i.e. 2 chunks where the speaker stopped speaking, but we account for natural pauses # so on the 1.5th second of no voice activity, we append the first 2 of the last valid chunks to the valid chunk queue # stop listening and start speaking self.valid_chunk_queue.extend(self.last_valid_chunks[:2]) self.last_valid_chunks = [] while len(self.valid_chunk_queue) > 0: time.sleep(0.1) self.mode = 'speaking' self.llm_n_tts_task = threading.Thread(target=self.llm_n_tts) self.llm_n_tts_task.start() def transcribe_loop(self): while True: if self.mode == 'listening': if len(self.valid_chunk_queue) > 0: accumulated_chunks = np.concatenate(self.valid_chunk_queue) total_duration = len(accumulated_chunks) / self.sample_rate if total_duration >= 3.0 and self.in_transcription == True: # i.e. we have at least 3 seconds of audio so we can start transcribing to reduce latency first_2s_audio = accumulated_chunks[:int(2 * self.sample_rate)] transcribed_text = transcribe(first_2s_audio, self.sample_rate) self.valid_chunk_transcriptions += transcribed_text self.valid_chunk_queue = [accumulated_chunks[int(2 * self.sample_rate):]] if self.mode == any(['idle', 'speaking']): # i.e. the request to stop transcription has been made # so process the remaining audio transcribed_text = transcribe(accumulated_chunks, self.sample_rate) self.valid_chunk_transcriptions += transcribed_text self.valid_chunk_queue = [] else: time.sleep(0.1) def stream_out_audio(self): while True: if len(self.out_audio_stream) > 0: yield AudioSegment(data=self.out_audio_stream.pop(0), sample_width=2, frame_rate=self.sample_rate, channels=1).raw_data @app.websocket("/ws") async def websocket_endpoint(websocket: fastapi.WebSocket): # Accept connection await websocket.accept() # Initialize conversation conversation = Conversation() # Start conversation threads transcribe_thread = threading.Thread(target=conversation.transcribe_loop) transcribe_thread.start() # Process audio chunks chunk_buffer_size = conversation.chunk_buffer while True: try: audio_chunk = await websocket.receive_bytes() conversation.process_audio_chunk(audio_chunk) if conversation.mode == 'speaking': for audio_chunk in conversation.stream_out_audio(): await websocket.send_bytes(audio_chunk) else: await websocket.send_bytes(b'') except Exception as e: logging.error(e) break @app.get("/") async def index(): return fastapi.responses.FileResponse("index.html") if __name__ == '__main__': import uvicorn uvicorn.run(app, host='0.0.0.0', port=8000)