bcci committed on
Commit
0870f8f
·
verified ·
1 Parent(s): 80ce7b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -111
app.py CHANGED
@@ -2,6 +2,7 @@ import io
2
  import re
3
  import wave
4
  import struct
 
5
 
6
  import numpy as np
7
  import torch
@@ -10,7 +11,8 @@ from fastapi.responses import StreamingResponse, Response, HTMLResponse
10
  from fastapi.middleware import Middleware
11
  from fastapi.middleware.gzip import GZipMiddleware
12
 
13
- from kokoro import StreamKPipeline, KPipeline # Import StreamKPipeline and KPipeline
 
14
 
15
  app = FastAPI(
16
  title="Kokoro TTS FastAPI",
@@ -23,9 +25,14 @@ app = FastAPI(
23
  # Global Pipeline Instance
24
  # ------------------------------------------------------------------------------
25
  # Create one pipeline instance for the entire app.
26
- stream_pipeline = StreamKPipeline(lang_code="a") # Use StreamKPipeline for streaming
27
- full_pipeline = KPipeline(lang_code="a") # Keep KPipeline for full TTS
 
 
 
 
28
 
 
29
 
30
  # ------------------------------------------------------------------------------
31
  # Helper Functions
@@ -48,40 +55,6 @@ def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int,
48
  return header + fmt_chunk + data_chunk_header
49
 
50
 
51
- def custom_split_text(text: str) -> list:
52
- """
53
- Custom splitting:
54
- - Start with a chunk size of 2 words.
55
- - For each chunk, if a period (".") is found in any word (except if it’s the very last word),
56
- then split the chunk at that word (include words up to that word).
57
- - Otherwise, use the current chunk size.
58
- - For subsequent chunks, increase the chunk size by 2.
59
- - If there are fewer than the desired number of words for a full chunk, add all remaining words.
60
- """
61
- words = text.split()
62
- chunks = []
63
- chunk_size = 2
64
- start = 0
65
- while start < len(words):
66
- candidate_end = start + chunk_size
67
- if candidate_end > len(words):
68
- candidate_end = len(words)
69
- chunk_words = words[start:candidate_end]
70
- # Look for a period in any word except the last one.
71
- split_index = None
72
- for i in range(len(chunk_words) - 1):
73
- if '.' in chunk_words[i]:
74
- split_index = i
75
- break
76
- if split_index is not None:
77
- candidate_end = start + split_index + 1
78
- chunk_words = words[start:candidate_end]
79
- chunks.append(" ".join(chunk_words))
80
- start = candidate_end
81
- chunk_size += 2 # Increase the chunk size by 2 for the next iteration.
82
- return chunks
83
-
84
-
85
  def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
86
  """
87
  Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
@@ -101,12 +74,12 @@ def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
101
  # ------------------------------------------------------------------------------
102
 
103
  @app.get("/tts/streaming", summary="Streaming TTS")
104
- def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
105
  """
106
- Streaming TTS endpoint that returns a continuous audio stream in WAV format (PCM).
107
 
108
- The endpoint yields a WAV header (with a dummy length) only once at the start of the stream,
109
- then yields PCM audio data chunks as they are generated in real-time.
110
  """
111
  sample_rate = 24000
112
  num_channels = 1
@@ -117,16 +90,18 @@ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
117
  header = generate_wav_header(sample_rate, num_channels, sample_width)
118
  yield header
119
 
120
- # Stream audio chunks from StreamKPipeline
121
  try:
122
- for stream_result in stream_pipeline(text, voice=voice, speed=speed, split_pattern=r'([.!?…])\s+'): # Split at sentence ends
123
- if stream_result.audio_chunk is not None:
124
- pcm_bytes = audio_tensor_to_pcm_bytes(stream_result.audio_chunk)
125
- yield pcm_bytes
126
- except Exception as e:
127
- print(f"Streaming error: {e}")
128
- yield b'' # Keep stream alive on error
129
 
 
 
 
 
 
 
 
 
130
 
131
  media_type = "audio/wav"
132
 
@@ -136,52 +111,13 @@ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
136
  headers={"Cache-Control": "no-cache"},
137
  )
138
 
139
-
140
- @app.get("/tts/full", summary="Full TTS")
141
- def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
142
- """
143
- Full TTS endpoint that synthesizes the entire text using KPipeline,
144
- concatenates the audio, and returns a complete WAV file.
145
- """
146
- # Use newline-based splitting via the pipeline's split_pattern.
147
- results = list(full_pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
148
- audio_segments = []
149
- for result in results:
150
- if result.audio is not None:
151
- audio_np = result.audio.cpu().numpy()
152
- if audio_np.ndim > 1:
153
- audio_np = audio_np.flatten()
154
- audio_segments.append(audio_np)
155
-
156
- if not audio_segments:
157
- raise HTTPException(status_code=500, detail="No audio generated.")
158
-
159
- # Concatenate all audio segments.
160
- full_audio = np.concatenate(audio_segments)
161
-
162
- # Write the concatenated audio to an in-memory WAV file.
163
- sample_rate = 24000
164
- num_channels = 1
165
- sample_width = 2 # 16-bit PCM -> 2 bytes per sample
166
- wav_io = io.BytesIO()
167
- with wave.open(wav_io, "wb") as wav_file:
168
- wav_file.setnchannels(num_channels)
169
- wav_file.setsampwidth(sample_width)
170
- wav_file.setframerate(sample_rate)
171
- full_audio_int16 = np.int16(full_audio * 32767)
172
- wav_file.writeframes(full_audio_int16.tobytes())
173
- wav_io.seek(0)
174
- return Response(content=wav_io.read(), media_type="audio/wav")
175
-
176
-
177
-
178
  @app.get("/", response_class=HTMLResponse)
179
  def index():
180
  """
181
  HTML demo page for Kokoro TTS.
182
 
183
- This page provides a simple UI to enter text, choose a voice and speed,
184
- and play synthesized audio from both the streaming and full endpoints.
185
  """
186
  return """
187
  <!DOCTYPE html>
@@ -191,34 +127,15 @@ def index():
191
  </head>
192
  <body>
193
  <h1>Kokoro TTS Demo</h1>
194
- <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
195
- <label for="voice">Voice:</label>
196
- <input type="text" id="voice" value="af_heart"><br>
197
- <label for="speed">Speed:</label>
198
- <input type="number" step="0.1" id="speed" value="1.0"><br>
199
- <br><br>
200
  <button onclick="playStreaming()">Play Streaming TTS</button>
201
- <button onclick="playFull()">Play Full TTS (Download WAV)</button>
202
  <br><br>
203
  <audio id="audio" controls autoplay></audio>
204
  <script>
205
  function playStreaming() {
206
  const text = document.getElementById('text').value;
207
- const voice = document.getElementById('voice').value;
208
- const speed = document.getElementById('speed').value;
209
- const audio = document.getElementById('audio');
210
- // Set the audio element's source to the streaming endpoint.
211
- audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
212
- audio.type = 'audio/wav';
213
- audio.play();
214
- }
215
- function playFull() {
216
- const text = document.getElementById('text').value;
217
- const voice = document.getElementById('voice').value;
218
- const speed = document.getElementById('speed').value;
219
  const audio = document.getElementById('audio');
220
- // Set the audio element's source to the full TTS endpoint.
221
- audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
222
  audio.type = 'audio/wav';
223
  audio.play();
224
  }
 
2
  import re
3
  import wave
4
  import struct
5
+ import time
6
 
7
  import numpy as np
8
  import torch
 
11
  from fastapi.middleware import Middleware
12
  from fastapi.middleware.gzip import GZipMiddleware
13
 
14
+ from kokoro import KPipeline, StreamKPipeline
15
+ from kokoro.model import KModel
16
 
17
  app = FastAPI(
18
  title="Kokoro TTS FastAPI",
 
25
  # Global Pipeline Instance
26
  # ------------------------------------------------------------------------------
27
  # Create one pipeline instance for the entire app.
28
+ model = KModel() # Or however you initialize/load your model
29
+ device = "cuda" if torch.cuda.is_available() else "cpu"
30
+ model.to(device)
31
+ #pipeline = KPipeline(lang_code="a",model=model)
32
+ voice = "af_heart"
33
+ speed = 1.0
34
 
35
+ pipeline = StreamKPipeline(lang_code="a", model=model, voice=voice, device=device, speed=speed)
36
 
37
  # ------------------------------------------------------------------------------
38
  # Helper Functions
 
55
  return header + fmt_chunk + data_chunk_header
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
59
  """
60
  Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
 
74
  # ------------------------------------------------------------------------------
75
 
76
  @app.get("/tts/streaming", summary="Streaming TTS")
77
+ def tts_streaming(text: str):
78
  """
79
+ Streaming TTS endpoint that returns a continuous audio stream.
80
 
81
+ The endpoint yields a WAV header (with a dummy length) for WAV,
82
+ then yields encoded audio data for each phoneme as soon as it is generated.
83
  """
84
  sample_rate = 24000
85
  num_channels = 1
 
90
  header = generate_wav_header(sample_rate, num_channels, sample_width)
91
  yield header
92
 
93
+ # Process and yield each audio chunk.
94
  try:
95
+ for result in pipeline(text): # Use StreamKPipeline
 
 
 
 
 
 
96
 
97
+ if result.audio is not None:
98
+ yield audio_tensor_to_pcm_bytes(result.audio)
99
+
100
+ else:
101
+ print("No audio generated for phoneme")
102
+ except Exception as e:
103
+ print(f"Error processing: {e}")
104
+ yield b'' # Important so that streaming continues.
105
 
106
  media_type = "audio/wav"
107
 
 
111
  headers={"Cache-Control": "no-cache"},
112
  )
113
 
114
+ #Remove full tts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  @app.get("/", response_class=HTMLResponse)
116
  def index():
117
  """
118
  HTML demo page for Kokoro TTS.
119
 
120
+ This page provides a simple UI to enter text and play synthesized audio from the streaming endpoint.
 
121
  """
122
  return """
123
  <!DOCTYPE html>
 
127
  </head>
128
  <body>
129
  <h1>Kokoro TTS Demo</h1>
130
+ <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br><br>
 
 
 
 
 
131
  <button onclick="playStreaming()">Play Streaming TTS</button>
 
132
  <br><br>
133
  <audio id="audio" controls autoplay></audio>
134
  <script>
135
  function playStreaming() {
136
  const text = document.getElementById('text').value;
 
 
 
 
 
 
 
 
 
 
 
 
137
  const audio = document.getElementById('audio');
138
+ audio.src = `/tts/streaming?text=${encodeURIComponent(text)}`;
 
139
  audio.type = 'audio/wav';
140
  audio.play();
141
  }