import io
import struct
import os
import time
import json

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, Response, HTMLResponse
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware
from misaki import en, espeak
from onnxruntime import InferenceSession
from huggingface_hub import snapshot_download
from scipy.io.wavfile import write as write_wav

# ------------------------------------------------------------------------------
# Load configuration and set up vocabulary
# ------------------------------------------------------------------------------
config_file_path = 'config.json'  # Update with your actual path

with open(config_file_path, 'r') as f:
    config = json.load(f)

phoneme_vocab = config['vocab']

# ------------------------------------------------------------------------------
# Download the model and voice files from Hugging Face Hub
# ------------------------------------------------------------------------------
model_repo = "onnx-community/Kokoro-82M-v1.0-ONNX"
model_name = "onnx/model_q4.onnx"  # or "onnx/model.onnx" for the full-precision model
voice_file_pattern = "*.bin"
local_dir = "."

snapshot_download(
    repo_id=model_repo,
    allow_patterns=[model_name, voice_file_pattern],
    local_dir=local_dir
)

# ------------------------------------------------------------------------------
# Load the ONNX model
# ------------------------------------------------------------------------------
model_path = os.path.join(local_dir, model_name)
sess = InferenceSession(model_path)

# ------------------------------------------------------------------------------
# Create the FastAPI app with GZip compression
# ------------------------------------------------------------------------------
app = FastAPI(
    title="Kokoro TTS FastAPI",
    middleware=[Middleware(GZipMiddleware, compresslevel=9)]
)

# ------------------------------------------------------------------------------
# Helper Functions
# ------------------------------------------------------------------------------
def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int,
                        data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Generate a WAV header for streaming.

    Since the final audio size is not known up front, a large dummy value is
    used for the data chunk size.
    """
    bits_per_sample = sample_width * 8
    byte_rate = sample_rate * num_channels * sample_width
    block_align = num_channels * sample_width
    total_size = 36 + data_size  # RIFF chunk size: 36 header bytes after this field plus the data size
    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels,
                            sample_rate, byte_rate, block_align, bits_per_sample)
    data_chunk_header = struct.pack('<4sI', b'data', data_size)
    return header + fmt_chunk + data_chunk_header


stream_header = generate_wav_header(24000, 1, 2)
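# Illustrative sanity check (added here as a sketch, not required by the API):
# the three packed chunks above form a standard 44-byte RIFF/WAVE header
# (12 + 24 + 8 bytes); only the dummy data-size field differs from a normal file.
assert len(stream_header) == 44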
""" words = text.split() chunks = [] chunk_size = 2 start = 0 while start < len(words): candidate_end = start + chunk_size if candidate_end > len(words): candidate_end = len(words) chunk_words = words[start:candidate_end] split_index = None # for i in range(len(chunk_words) - 1): # if '.' in chunk_words[i]: # split_index = i # break # if split_index is not None: # candidate_end = start + split_index + 1 # chunk_words = words[start:candidate_end] chunks.append(" ".join(chunk_words)) start = candidate_end if chunk_size < 100: chunk_size += 2 return chunks def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes: """ Convert a torch.FloatTensor (values in [-1, 1]) to raw 16-bit PCM bytes. """ audio_np = audio_tensor.cpu().numpy() if audio_np.ndim > 1: audio_np = audio_np.flatten() audio_int16 = np.int16(audio_np * 32767) return audio_int16.tobytes() def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes: """ Convert a torch.FloatTensor to Opus-encoded bytes. Requires the 'opuslib' package: pip install opuslib """ try: import opuslib except ImportError: raise ImportError("opuslib is not installed. Please install it with: pip install opuslib") audio_np = audio_tensor.cpu().numpy() if audio_np.ndim > 1: audio_np = audio_np.flatten() audio_int16 = np.int16(audio_np * 32767) encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP) frame_size = int(sample_rate * 0.020) # 20 ms frame encoded_data = b'' for i in range(0, len(audio_int16), frame_size): frame = audio_int16[i:i + frame_size] if len(frame) < frame_size: frame = np.pad(frame, (0, frame_size - len(frame)), 'constant') encoded_frame = encoder.encode(frame.tobytes(), frame_size) encoded_data += encoded_frame return encoded_data fbs = espeak.EspeakFallback(british=True) g2p = en.G2P(trf=False, british=False, fallback=fbs) def tokenizer(text: str): """ Converts text to a list of phoneme tokens using the global vocabulary. """ phonemes_string, tokens = g2p(text) phonemes = [ph for ph in phonemes_string] print(text + " " + phonemes_string) tokens = [phoneme_vocab[phoneme] for phoneme in phonemes if phoneme in phoneme_vocab] return tokens # ------------------------------------------------------------------------------ # Endpoints # ------------------------------------------------------------------------------ @app.get("/tts/streaming", summary="Streaming TTS") def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"): """ Streaming TTS endpoint. This endpoint splits the input text into chunks (using the doubling strategy), then for each chunk: - For the first chunk, a 0 is prepended. - For subsequent chunks, the first token is set to the last token from the previous chunk. - For the final chunk, a 0 is appended. The audio for each chunk is generated immediately and streamed to the client. """ chunks = custom_split_text(text) # Load the voice/style file (must be present in voices/{voice}.bin) voice_path = os.path.join(local_dir, f"voices/{voice}.bin") if not os.path.exists(voice_path): raise HTTPException(status_code=404, detail="Voice file not found") voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256) def audio_generator(): # If outputting a WAV stream, yield a WAV header once. if format.lower() == "wav": yield stream_header prev_last_token = None for i, chunk in enumerate(chunks): # Convert the chunk text to tokens. 

# ------------------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------------------
@app.get("/tts/streaming", summary="Streaming TTS")
def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
    """
    Streaming TTS endpoint.

    The input text is split into progressively larger chunks (see
    custom_split_text). Each chunk is tokenized, wrapped with the boundary
    token 0, and synthesized immediately; the resulting 16-bit PCM is streamed
    to the client as it is produced. Only WAV/raw PCM output is produced on
    this path.
    """
    chunks = custom_split_text(text)

    # Load the voice/style file (must be present in voices/{voice}.bin).
    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    if not os.path.exists(voice_path):
        raise HTTPException(status_code=404, detail="Voice file not found")
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)

    def audio_generator():
        # If outputting a WAV stream, yield the header once up front.
        if format.lower() == "wav":
            yield stream_header
        for i, chunk in enumerate(chunks):
            # Convert the chunk text to tokens and wrap it with the boundary token 0.
            chunk_tokens = tokenizer(chunk)
            tokens_to_send = [0] + chunk_tokens + [0]
            final_token = [tokens_to_send]  # batch of one sequence

            # Use the number of tokens to select the appropriate style vector.
            style_index = len(chunk_tokens) + 2
            if style_index >= len(voices):
                style_index = len(voices) - 1  # Fallback if the index is out of bounds.
            ref_s = voices[style_index]

            # Prepare the speed parameter.
            speed_param = np.ones(1, dtype=np.float32) * speed

            # Run ONNX inference for this chunk.
            try:
                start_time = time.time()
                audio_output = sess.run(None, {
                    "input_ids": final_token,
                    "style": ref_s,
                    "speed": speed_param,
                })[0]
                print(f"Chunk {i} inference time: {time.time() - start_time:.3f}s")
            except Exception as e:
                print(f"Error processing chunk {i}: {e}")
                # In case of error, emit one second of silence instead.
                audio_output = np.zeros((24000,), dtype=np.float32)

            # Convert the float32 output in [-1, 1] to int16 PCM, dropping the
            # first 6000 and last 3000 samples (~0.25 s / ~0.125 s at 24 kHz)
            # around the chunk boundaries.
            audio_int16 = (audio_output * 32767).astype(np.int16).flatten()[6000:-3000]

            # Yield the encoded audio chunk.
            yield audio_int16.tobytes()

    return StreamingResponse(
        audio_generator(),
        media_type="audio/wav",
        headers={"Cache-Control": "no-cache"},
    )


@app.get("/tts/full", summary="Full TTS")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
    """
    Full TTS endpoint that synthesizes the entire text and returns a complete
    WAV or Opus file.
    """
    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    if not os.path.exists(voice_path):
        raise HTTPException(status_code=404, detail="Voice file not found")
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)

    tokens = tokenizer(text)
    # Select the style vector by token count, clamping to the available range.
    ref_s = voices[min(len(tokens), len(voices) - 1)]
    final_token = [[0, *tokens, 0]]

    start_time = time.time()
    audio = sess.run(None, {
        "input_ids": final_token,
        "style": ref_s,
        "speed": np.ones(1, dtype=np.float32) * speed,
    })[0]
    print(f"Full TTS inference time: {time.time() - start_time:.3f}s")

    # Convert to int16 PCM.
    audio = (audio * 32767).astype(np.int16).flatten()

    if format.lower() == "wav":
        wav_io = io.BytesIO()
        write_wav(wav_io, 24000, audio)
        wav_io.seek(0)
        return Response(content=wav_io.read(), media_type="audio/wav")
    elif format.lower() == "opus":
        opus_data = audio_tensor_to_opus_bytes(
            torch.from_numpy(audio.astype(np.float32) / 32767), sample_rate=24000)
        return Response(content=opus_data, media_type="audio/opus")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")


@app.get("/", response_class=HTMLResponse)
def index():
    """
    HTML demo page for Kokoro TTS.
    """
    return """