bcci committed on
Commit
363625a
·
verified ·
1 Parent(s): b60527b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -62
app.py CHANGED
@@ -15,7 +15,6 @@ app = FastAPI(title="Kokoro TTS FastAPI")
15
  # ------------------------------------------------------------------------------
16
  # Global Pipeline Instance
17
  # ------------------------------------------------------------------------------
18
- # Create one pipeline instance for the entire app.
19
  pipeline = KPipeline(lang_code="a")
20
 
21
 
@@ -27,13 +26,11 @@ def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int,
27
  """
28
  Generate a WAV header for streaming.
29
  Since we don't know the final audio size, we set the data chunk size to a large dummy value.
30
- This header is sent only once at the start of the stream.
31
  """
32
  bits_per_sample = sample_width * 8
33
  byte_rate = sample_rate * num_channels * sample_width
34
  block_align = num_channels * sample_width
35
- # total file size = 36 + data_size (header is 44 bytes total)
36
- total_size = 36 + data_size
37
  header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
38
  fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
39
  data_chunk_header = struct.pack('<4sI', b'data', data_size)
@@ -42,30 +39,47 @@ def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int,
42
 
43
  def custom_split_text(text: str) -> list:
44
  """
45
- Custom splitting: split text into chunks where each chunk doubles in size.
 
 
 
 
 
 
46
  """
47
  words = text.split()
48
  chunks = []
49
  chunk_size = 2
50
  start = 0
51
  while start < len(words):
52
- end = start + chunk_size
53
- chunk = " ".join(words[start:end])
54
- chunks.append(chunk)
55
- start = end
56
- chunk_size += 2
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  return chunks
58
 
59
 
60
  def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
61
  """
62
- Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
63
  """
64
- # Ensure tensor is on CPU and flatten if necessary.
65
  audio_np = audio_tensor.cpu().numpy()
66
  if audio_np.ndim > 1:
67
  audio_np = audio_np.flatten()
68
- # Scale to int16 range.
69
  audio_int16 = np.int16(audio_np * 32767)
70
  return audio_int16.tobytes()
71
 
@@ -82,7 +96,6 @@ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
82
  The endpoint first yields a WAV header (with a dummy length) then yields raw PCM data
83
  for each text chunk as soon as it is generated.
84
  """
85
- # Split the input text using the custom doubling strategy.
86
  chunks = custom_split_text(text)
87
  sample_rate = 24000
88
  num_channels = 1
@@ -94,19 +107,17 @@ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
94
  yield header
95
  # Process and yield each chunk's PCM data.
96
  for i, chunk in enumerate(chunks):
97
- print(f"Processing chunk {i}: {chunk}") # Debugging
98
  try:
99
  results = list(pipeline(chunk, voice=voice, speed=speed, split_pattern=None))
100
  for result in results:
101
  if result.audio is not None:
102
- print(f"Chunk {i}: Audio generated") # Debugging
103
- pcm_bytes = audio_tensor_to_pcm_bytes(result.audio)
104
- yield pcm_bytes
105
  else:
106
  print(f"Chunk {i}: No audio generated")
107
  except Exception as e:
108
  print(f"Error processing chunk {i}: {e}")
109
-
110
  return StreamingResponse(
111
  audio_generator(),
112
  media_type="audio/wav",
@@ -120,7 +131,6 @@ def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
120
  Full TTS endpoint that synthesizes the entire text, concatenates the audio,
121
  and returns a complete WAV file.
122
  """
123
- # Use newline-based splitting via the pipeline's split_pattern.
124
  results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
125
  audio_segments = []
126
  for result in results:
@@ -133,13 +143,11 @@ def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
133
  if not audio_segments:
134
  raise HTTPException(status_code=500, detail="No audio generated.")
135
 
136
- # Concatenate all audio segments.
137
  full_audio = np.concatenate(audio_segments)
138
 
139
- # Write the concatenated audio to an in-memory WAV file.
140
  sample_rate = 24000
141
  num_channels = 1
142
- sample_width = 2 # 16-bit PCM -> 2 bytes per sample
143
  wav_io = io.BytesIO()
144
  with wave.open(wav_io, "wb") as wav_file:
145
  wav_file.setnchannels(num_channels)
@@ -156,48 +164,131 @@ def index():
156
  """
157
  HTML demo page for Kokoro TTS.
158
 
159
- This page provides a simple UI to enter text, choose a voice and speed,
160
- and play synthesized audio from both the streaming and full endpoints.
 
 
 
161
  """
162
  return """
163
- <!DOCTYPE html>
164
- <html>
165
- <head>
166
- <title>Kokoro TTS Demo</title>
167
- </head>
168
- <body>
169
- <h1>Kokoro TTS Demo</h1>
170
- <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
171
- <label for="voice">Voice:</label>
172
- <input type="text" id="voice" value="af_heart"><br>
173
- <label for="speed">Speed:</label>
174
- <input type="number" step="0.1" id="speed" value="1.0"><br><br>
175
- <button onclick="playStreaming()">Play Streaming TTS</button>
176
- <button onclick="playFull()">Play Full TTS</button>
177
- <br><br>
178
- <audio id="audio" controls autoplay></audio>
179
- <script>
180
- function playStreaming() {
181
- const text = document.getElementById('text').value;
182
- const voice = document.getElementById('voice').value;
183
- const speed = document.getElementById('speed').value;
184
- const audio = document.getElementById('audio');
185
- // Set the audio element's source to the streaming endpoint.
186
- audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
187
- audio.play();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  }
189
- function playFull() {
190
- const text = document.getElementById('text').value;
191
- const voice = document.getElementById('voice').value;
192
- const speed = document.getElementById('speed').value;
193
- const audio = document.getElementById('audio');
194
- // Set the audio element's source to the full TTS endpoint.
195
- audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
196
- audio.play();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  }
198
- </script>
199
- </body>
200
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  """
202
 
203
 
@@ -206,5 +297,4 @@ def index():
206
  # ------------------------------------------------------------------------------
207
  if __name__ == "__main__":
208
  import uvicorn
209
-
210
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 
15
  # ------------------------------------------------------------------------------
16
  # Global Pipeline Instance
17
  # ------------------------------------------------------------------------------
 
18
  pipeline = KPipeline(lang_code="a")
19
 
20
 
 
26
  """
27
  Generate a WAV header for streaming.
28
  Since we don't know the final audio size, we set the data chunk size to a large dummy value.
 
29
  """
30
  bits_per_sample = sample_width * 8
31
  byte_rate = sample_rate * num_channels * sample_width
32
  block_align = num_channels * sample_width
33
+ total_size = 36 + data_size # header (44 bytes) minus 8 + dummy data size
 
34
  header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
35
  fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
36
  data_chunk_header = struct.pack('<4sI', b'data', data_size)
 
39
 
40
def custom_split_text(text: str) -> list:
    """
    Split *text* into progressively larger word chunks.

    Strategy:
      - The first chunk is 2 words; each subsequent chunk grows by 2 words
        (2, 4, 6, ...), so early audio chunks stream quickly while later
        ones are larger.
      - If any word inside the current chunk contains a period and it is
        not the chunk's last word, the chunk is cut just after that word
        so a sentence boundary ends the chunk.
      - The final chunk takes however many words remain.

    Args:
        text: Input text, split on whitespace.

    Returns:
        A list of chunk strings; empty list for empty/whitespace-only input.
    """
    words = text.split()
    chunks = []
    chunk_size = 2
    start = 0
    while start < len(words):
        candidate_end = min(start + chunk_size, len(words))
        chunk_words = words[start:candidate_end]
        # Scan right-to-left for a word containing a period (a likely
        # sentence boundary) inside the candidate chunk.
        split_index = None
        for i in reversed(range(len(chunk_words))):
            if '.' in chunk_words[i]:
                split_index = i
                break
        # BUG FIX: the original used `!==`, which is JavaScript syntax and a
        # SyntaxError in Python; the correct inequality operator is `!=`.
        if split_index is not None and split_index != len(chunk_words) - 1:
            # A period was found before the chunk's last word: end the chunk
            # at that word so the next sentence starts a fresh chunk.
            candidate_end = start + split_index + 1
            chunk_words = words[start:candidate_end]
        chunks.append(" ".join(chunk_words))
        start = candidate_end
        chunk_size += 2  # grow additively: 2, 4, 6, ...
    return chunks
74
 
75
 
76
def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a torch.FloatTensor (values assumed in [-1, 1]) to raw
    16-bit PCM bytes.

    Args:
        audio_tensor: Audio samples; multi-dimensional tensors are
            flattened into a single mono stream.

    Returns:
        The samples scaled to the int16 range, serialized as raw bytes.
    """
    samples = audio_tensor.cpu().numpy()
    # Collapse any batched / multi-channel layout into a flat 1-D stream.
    if samples.ndim > 1:
        samples = samples.flatten()
    # Map [-1, 1] floats onto the signed 16-bit sample range.
    return np.int16(samples * 32767).tobytes()
85
 
 
96
  The endpoint first yields a WAV header (with a dummy length) then yields raw PCM data
97
  for each text chunk as soon as it is generated.
98
  """
 
99
  chunks = custom_split_text(text)
100
  sample_rate = 24000
101
  num_channels = 1
 
107
  yield header
108
  # Process and yield each chunk's PCM data.
109
  for i, chunk in enumerate(chunks):
110
+ print(f"Processing chunk {i}: {chunk}")
111
  try:
112
  results = list(pipeline(chunk, voice=voice, speed=speed, split_pattern=None))
113
  for result in results:
114
  if result.audio is not None:
115
+ print(f"Chunk {i}: Audio generated")
116
+ yield audio_tensor_to_pcm_bytes(result.audio)
 
117
  else:
118
  print(f"Chunk {i}: No audio generated")
119
  except Exception as e:
120
  print(f"Error processing chunk {i}: {e}")
 
121
  return StreamingResponse(
122
  audio_generator(),
123
  media_type="audio/wav",
 
131
  Full TTS endpoint that synthesizes the entire text, concatenates the audio,
132
  and returns a complete WAV file.
133
  """
 
134
  results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
135
  audio_segments = []
136
  for result in results:
 
143
  if not audio_segments:
144
  raise HTTPException(status_code=500, detail="No audio generated.")
145
 
 
146
  full_audio = np.concatenate(audio_segments)
147
 
 
148
  sample_rate = 24000
149
  num_channels = 1
150
+ sample_width = 2 # 16-bit PCM
151
  wav_io = io.BytesIO()
152
  with wave.open(wav_io, "wb") as wav_file:
153
  wav_file.setnchannels(num_channels)
 
164
  """
165
  HTML demo page for Kokoro TTS.
166
 
167
+ Two playback methods are provided:
168
+ - "Play Full TTS" uses a standard <audio> element.
169
+ - "Play Streaming TTS" uses the Web Audio API (via a ScriptProcessorNode) to stream
170
+ the raw PCM data as it arrives. This method first reads the WAV header (44 bytes)
171
+ then continuously pulls in PCM data, converts it to Float32, and plays it.
172
  """
173
  return """
174
+ <!DOCTYPE html>
175
+ <html>
176
+ <head>
177
+ <title>Kokoro TTS Demo</title>
178
+ </head>
179
+ <body>
180
+ <h1>Kokoro TTS Demo</h1>
181
+ <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
182
+ <label for="voice">Voice:</label>
183
+ <input type="text" id="voice" value="af_heart"><br>
184
+ <label for="speed">Speed:</label>
185
+ <input type="number" step="0.1" id="speed" value="1.0"><br><br>
186
+ <button onclick="startStreaming()">Play Streaming TTS (Web Audio API)</button>
187
+ <button onclick="playFull()">Play Full TTS (Standard Audio)</button>
188
+ <br><br>
189
+ <audio id="fullAudio" controls></audio>
190
+ <script>
191
+ // Function to play full TTS by simply setting the <audio> element's source.
192
+ function playFull() {
193
+ const text = document.getElementById('text').value;
194
+ const voice = document.getElementById('voice').value;
195
+ const speed = document.getElementById('speed').value;
196
+ const audio = document.getElementById('fullAudio');
197
+ audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
198
+ audio.play();
199
+ }
200
+
201
+ // Function to stream audio using the Web Audio API.
202
+ async function startStreaming() {
203
+ const text = document.getElementById('text').value;
204
+ const voice = document.getElementById('voice').value;
205
+ const speed = document.getElementById('speed').value;
206
+ const response = await fetch(`/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`);
207
+ if (!response.body) {
208
+ alert("Streaming not supported in this browser.");
209
+ return;
210
+ }
211
+
212
+ const reader = response.body.getReader();
213
+ const audioContext = new (window.AudioContext || window.webkitAudioContext)();
214
+ // Create a ScriptProcessorNode (buffer size of 4096 samples)
215
+ const scriptNode = audioContext.createScriptProcessor(4096, 1, 1);
216
+ let bufferQueue = [];
217
+ let currentBuffer = new Float32Array(0);
218
+ let headerRead = false;
219
+ let headerBytes = new Uint8Array(0);
220
+
221
+ // Helper: Convert Int16 PCM (little-endian) to Float32.
222
+ function int16ToFloat32(buffer) {
223
+ const len = buffer.length;
224
+ const floatBuffer = new Float32Array(len);
225
+ for (let i = 0; i < len; i++) {
226
+ floatBuffer[i] = buffer[i] / 32767;
227
  }
228
+ return floatBuffer;
229
+ }
230
+
231
+ scriptNode.onaudioprocess = function(e) {
232
+ const output = e.outputBuffer.getChannelData(0);
233
+ let offset = 0;
234
+ while (offset < output.length) {
235
+ if (currentBuffer.length === 0) {
236
+ if (bufferQueue.length > 0) {
237
+ currentBuffer = bufferQueue.shift();
238
+ } else {
239
+ // If no data is available, output silence.
240
+ for (let i = offset; i < output.length; i++) {
241
+ output[i] = 0;
242
+ }
243
+ break;
244
+ }
245
+ }
246
+ const needed = output.length - offset;
247
+ const available = currentBuffer.length;
248
+ const toCopy = Math.min(needed, available);
249
+ output.set(currentBuffer.slice(0, toCopy), offset);
250
+ offset += toCopy;
251
+ if (toCopy < currentBuffer.length) {
252
+ currentBuffer = currentBuffer.slice(toCopy);
253
+ } else {
254
+ currentBuffer = new Float32Array(0);
255
+ }
256
  }
257
+ };
258
+ scriptNode.connect(audioContext.destination);
259
+
260
+ // Read the response stream.
261
+ while (true) {
262
+ const { done, value } = await reader.read();
263
+ if (done) break;
264
+ let chunk = value;
265
+ // First, accumulate the 44-byte WAV header.
266
+ if (!headerRead) {
267
+ let combined = new Uint8Array(headerBytes.length + chunk.length);
268
+ combined.set(headerBytes);
269
+ combined.set(chunk, headerBytes.length);
270
+ if (combined.length >= 44) {
271
+ headerBytes = combined.slice(0, 44);
272
+ headerRead = true;
273
+ // Remove the header bytes from the chunk.
274
+ chunk = combined.slice(44);
275
+ } else {
276
+ headerBytes = combined;
277
+ continue;
278
+ }
279
+ }
280
+ // Make sure the chunk length is even (2 bytes per sample).
281
+ if (chunk.length % 2 !== 0) {
282
+ chunk = chunk.slice(0, chunk.length - 1);
283
+ }
284
+ const int16Buffer = new Int16Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 2);
285
+ const floatBuffer = int16ToFloat32(int16Buffer);
286
+ bufferQueue.push(floatBuffer);
287
+ }
288
+ }
289
+ </script>
290
+ </body>
291
+ </html>
292
  """
293
 
294
 
 
297
# ------------------------------------------------------------------------------
# Entry point: launch the FastAPI app with uvicorn when run as a script.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)