bcci committed on
Commit
80ce7b7
·
verified ·
1 Parent(s): 9c48211

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -154
app.py CHANGED
@@ -10,8 +10,7 @@ from fastapi.responses import StreamingResponse, Response, HTMLResponse
10
  from fastapi.middleware import Middleware
11
  from fastapi.middleware.gzip import GZipMiddleware
12
 
13
- # --- IMPORTANT: Use the AutoregressiveStreamKPipeline ---
14
- from kokoro.pipeline import AutoregressiveStreamKPipeline # Or wherever your pipeline is.
15
 
16
  app = FastAPI(
17
  title="Kokoro TTS FastAPI",
@@ -24,8 +23,9 @@ app = FastAPI(
24
  # Global Pipeline Instance
25
  # ------------------------------------------------------------------------------
26
  # Create one pipeline instance for the entire app.
 
 
27
 
28
- pipeline = AutoregressiveStreamKPipeline(lang_code="a") # Use the autoregressive pipeline
29
 
30
  # ------------------------------------------------------------------------------
31
  # Helper Functions
@@ -48,6 +48,40 @@ def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int,
48
  return header + fmt_chunk + data_chunk_header
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
52
  """
53
  Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
@@ -60,78 +94,42 @@ def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
60
  audio_int16 = np.int16(audio_np * 32767)
61
  return audio_int16.tobytes()
62
 
63
- def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
64
- """
65
- Convert a torch.FloatTensor to Opus encoded bytes.
66
- Requires the 'opuslib' package: pip install opuslib
67
- """
68
- try:
69
- import opuslib
70
- except ImportError:
71
- raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")
72
-
73
- audio_np = audio_tensor.cpu().numpy()
74
- if audio_np.ndim > 1:
75
- audio_np = audio_np.flatten()
76
- # Scale to int16 range. Important for opus.
77
- audio_int16 = np.int16(audio_np * 32767)
78
-
79
- encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP) # 1 channel for mono.
80
-
81
- # Calculate the number of frames to encode. Opus frames are 2.5, 5, 10, or 20 ms long.
82
- frame_size = int(sample_rate * 0.020) # 20ms frame size
83
-
84
- encoded_data = b''
85
- for i in range(0, len(audio_int16), frame_size):
86
- frame = audio_int16[i:i + frame_size]
87
- if len(frame) < frame_size:
88
- # Pad the last frame with zeros if needed.
89
- frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
90
- encoded_frame = encoder.encode(frame.tobytes(), frame_size) # Encode the frame.
91
- encoded_data += encoded_frame
92
 
93
- return encoded_data
94
 
95
  # ------------------------------------------------------------------------------
96
  # Endpoints
97
  # ------------------------------------------------------------------------------
98
 
99
- @app.get("/tts/streaming", summary="Streaming TTS (Autoregressive)")
100
- def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "opus"):
101
  """
102
- Streaming TTS endpoint that attempts autoregressive, near sample-by-sample output.
103
 
104
- IMPORTANT: This is EXPERIMENTAL and may have reduced quality compared to
105
- the full or chunking methods. It's also likely to be slower due to the
106
- per-phoneme processing overhead.
107
  """
108
  sample_rate = 24000
109
  num_channels = 1
110
  sample_width = 2 # 16-bit PCM
111
 
112
  def audio_generator():
113
- if format.lower() == "wav":
114
- # Yield the WAV header first.
115
- header = generate_wav_header(sample_rate, num_channels, sample_width)
116
- yield header
117
 
 
118
  try:
119
- # Use the AUTOREGRESSIVE pipeline
120
- for audio_chunk in pipeline(text, voice=voice, speed=speed):
121
- print(audio_chunk)
122
- if audio_chunk.numel() > 0: # Ensure we have audio data
123
- if format.lower() == "wav":
124
- yield audio_tensor_to_pcm_bytes(audio_chunk)
125
- elif format.lower() == "opus":
126
- yield audio_tensor_to_opus_bytes(audio_chunk, sample_rate=sample_rate)
127
- else:
128
- raise ValueError(f"Unsupported audio format: {format}")
129
-
130
  except Exception as e:
131
- print(f"Error during streaming: {e}")
132
- yield b'' # Yield empty bytes to avoid breaking the stream
 
 
 
133
 
134
- media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
135
  return StreamingResponse(
136
  audio_generator(),
137
  media_type=media_type,
@@ -139,54 +137,41 @@ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format
139
  )
140
 
141
 
142
-
143
- # @app.get("/tts/full", summary="Full TTS")
144
- # def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
145
- # """
146
- # Full TTS endpoint (no streaming). Synthesizes the entire text and returns
147
- # a complete WAV or Opus file.
148
- # """
149
- # # Use newline-based splitting. This is the *original* KPipeline,
150
- # # which is better for full synthesis. It's important to use
151
- # # the right pipeline for the right task.
152
- # from kokoro.pipeline import KPipeline # Import here to avoid circular import
153
- # full_pipeline = KPipeline(lang_code="a")
154
-
155
- # results = list(full_pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
156
- # audio_segments = []
157
- # for result in results:
158
- # if result.audio is not None:
159
- # audio_np = result.audio.cpu().numpy()
160
- # if audio_np.ndim > 1:
161
- # audio_np = audio_np.flatten()
162
- # audio_segments.append(audio_np)
163
-
164
- # if not audio_segments:
165
- # raise HTTPException(status_code=500, detail="No audio generated.")
166
-
167
- # # Concatenate all audio segments.
168
- # full_audio = np.concatenate(audio_segments)
169
-
170
- # # Write the concatenated audio to an in-memory WAV or Opus file.
171
- # sample_rate = 24000
172
- # num_channels = 1
173
- # sample_width = 2 # 16-bit PCM -> 2 bytes per sample
174
- # if format.lower() == "wav":
175
- # wav_io = io.BytesIO()
176
- # with wave.open(wav_io, "wb") as wav_file:
177
- # wav_file.setnchannels(num_channels)
178
- # wav_file.setsampwidth(sample_width)
179
- # wav_file.setframerate(sample_rate)
180
- # full_audio_int16 = np.int16(full_audio * 32767)
181
- # wav_file.writeframes(full_audio_int16.tobytes())
182
- # wav_io.seek(0)
183
- # return Response(content=wav_io.read(), media_type="audio/wav")
184
- # elif format.lower() == "opus":
185
- # opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(full_audio), sample_rate=sample_rate)
186
- # return Response(content=opus_data, media_type="audio/opus")
187
- # else:
188
- # raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
189
-
190
 
191
 
192
 
@@ -194,61 +179,58 @@ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format
194
  def index():
195
  """
196
  HTML demo page for Kokoro TTS.
 
 
 
197
  """
198
  return """
199
- <!DOCTYPE html>
200
- <html>
201
- <head>
202
- <title>Kokoro TTS Demo</title>
203
- </head>
204
- <body>
205
- <h1>Kokoro TTS Demo</h1>
206
- <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
207
- <label for="voice">Voice:</label>
208
- <input type="text" id="voice" value="af_heart"><br>
209
- <label for="speed">Speed:</label>
210
- <input type="number" step="0.1" id="speed" value="1.0"><br>
211
- <label for="format">Format:</label>
212
- <select id="format">
213
- <option value="wav">WAV</option>
214
- <option value="opus" selected>Opus</option>
215
- </select><br><br>
216
- <button onclick="playStreaming()">Play Streaming TTS</button>
217
- <button onclick="playFull()">Play Full TTS</button>
218
- <br><br>
219
- <audio id="audio" controls autoplay></audio>
220
- <script>
221
- function playStreaming() {
222
- const text = document.getElementById('text').value;
223
- const voice = document.getElementById('voice').value;
224
- const speed = document.getElementById('speed').value;
225
- const format = document.getElementById('format').value;
226
- const audio = document.getElementById('audio');
227
- // Set the audio element's source to the streaming endpoint.
228
- audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
229
- audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
230
- audio.play();
231
- }
232
- function playFull() {
233
- const text = document.getElementById('text').value;
234
- const voice = document.getElementById('voice').value;
235
- const speed = document.getElementById('speed').value;
236
- const format = document.getElementById('format').value;
237
- const audio = document.getElementById('audio');
238
- // Set the audio element's source to the full TTS endpoint.
239
- audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
240
- audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
241
- audio.play();
242
- }
243
- </script>
244
- </body>
245
- </html>
246
- """
247
 
248
  # ------------------------------------------------------------------------------
249
  # Run with: uvicorn app:app --reload
250
  # ------------------------------------------------------------------------------
251
-
252
  if __name__ == "__main__":
253
  import uvicorn
254
 
 
10
  from fastapi.middleware import Middleware
11
  from fastapi.middleware.gzip import GZipMiddleware
12
 
13
+ from kokoro import StreamKPipeline, KPipeline # Import StreamKPipeline and KPipeline
 
14
 
15
  app = FastAPI(
16
  title="Kokoro TTS FastAPI",
 
23
  # Global Pipeline Instance
24
  # ------------------------------------------------------------------------------
25
  # Create one pipeline instance for the entire app.
26
+ stream_pipeline = StreamKPipeline(lang_code="a") # Use StreamKPipeline for streaming
27
+ full_pipeline = KPipeline(lang_code="a") # Keep KPipeline for full TTS
28
 
 
29
 
30
  # ------------------------------------------------------------------------------
31
  # Helper Functions
 
48
  return header + fmt_chunk + data_chunk_header
49
 
50
 
51
def custom_split_text(
    text: str,
    initial_chunk_size: int = 2,
    chunk_growth: int = 2,
    sentence_endings: str = ".",
) -> list:
    """
    Split text into progressively larger word chunks, breaking early at sentence ends.

    The text is split on whitespace. The first chunk targets
    ``initial_chunk_size`` words, and the target grows by ``chunk_growth``
    words after every emitted chunk (including chunks that were cut short).
    Within a candidate chunk, if any word other than the chunk's last word
    contains one of the ``sentence_endings`` characters, the chunk is cut
    right after the first such word. When fewer words remain than the
    target, all remaining words form the final chunk.

    Args:
        text: Input text; leading, trailing and repeated whitespace is ignored.
        initial_chunk_size: Target word count of the first chunk (>= 1).
        chunk_growth: Words added to the target after each chunk (>= 0).
        sentence_endings: Characters that mark a sentence end. The default
            "." keeps the original behaviour; pass e.g. ".!?" to also break
            on exclamation and question marks. An empty string disables
            early breaking.

    Returns:
        A list of chunk strings, each made of words joined by single spaces.
        Empty or whitespace-only input yields an empty list.

    Raises:
        ValueError: If ``initial_chunk_size`` < 1 or ``chunk_growth`` < 0,
            since those values could produce empty chunks or never advance.
    """
    if initial_chunk_size < 1:
        raise ValueError("initial_chunk_size must be at least 1")
    if chunk_growth < 0:
        raise ValueError("chunk_growth must not be negative")

    words = text.split()
    total = len(words)
    chunks = []
    chunk_size = initial_chunk_size
    start = 0
    while start < total:
        end = min(start + chunk_size, total)
        # Only words *before* the chunk's last word can trigger an early cut;
        # a sentence end on the last word already coincides with the boundary.
        for idx in range(start, end - 1):
            if any(mark in words[idx] for mark in sentence_endings):
                end = idx + 1
                break
        chunks.append(" ".join(words[start:end]))
        start = end
        # The target grows after every chunk, even one that was cut short.
        chunk_size += chunk_growth
    return chunks
83
+
84
+
85
  def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
86
  """
87
  Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
 
94
  audio_int16 = np.int16(audio_np * 32767)
95
  return audio_int16.tobytes()
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
 
98
 
99
  # ------------------------------------------------------------------------------
100
  # Endpoints
101
  # ------------------------------------------------------------------------------
102
 
103
+ @app.get("/tts/streaming", summary="Streaming TTS")
104
+ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
105
  """
106
+ Streaming TTS endpoint that returns a continuous audio stream in WAV format (PCM).
107
 
108
+ The endpoint yields a WAV header (with a dummy length) only once at the start of the stream,
109
+ then yields PCM audio data chunks as they are generated in real-time.
 
110
  """
111
  sample_rate = 24000
112
  num_channels = 1
113
  sample_width = 2 # 16-bit PCM
114
 
115
  def audio_generator():
116
+ # Yield the WAV header first.
117
+ header = generate_wav_header(sample_rate, num_channels, sample_width)
118
+ yield header
 
119
 
120
+ # Stream audio chunks from StreamKPipeline
121
  try:
122
+ for stream_result in stream_pipeline(text, voice=voice, speed=speed, split_pattern=r'([.!?…])\s+'): # Split at sentence ends
123
+ if stream_result.audio_chunk is not None:
124
+ pcm_bytes = audio_tensor_to_pcm_bytes(stream_result.audio_chunk)
125
+ yield pcm_bytes
 
 
 
 
 
 
 
126
  except Exception as e:
127
+ print(f"Streaming error: {e}")
128
+ yield b'' # Keep stream alive on error
129
+
130
+
131
+ media_type = "audio/wav"
132
 
 
133
  return StreamingResponse(
134
  audio_generator(),
135
  media_type=media_type,
 
137
  )
138
 
139
 
140
@app.get("/tts/full", summary="Full TTS")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
    """
    Full (non-streaming) TTS endpoint: synthesize the text and return one WAV file.

    The text is run through the module-level ``full_pipeline`` with
    ``split_pattern=r"\\n+"``; every result that carries audio is flattened
    to 1-D and concatenated, then written as mono 16-bit PCM at 24 kHz.

    Args:
        text: Text to synthesize.
        voice: Voice identifier forwarded to the pipeline.
        speed: Speed factor forwarded to the pipeline.

    Returns:
        A ``Response`` with media type ``audio/wav`` containing the whole file.

    Raises:
        HTTPException: 500 if the pipeline produced no audio at all.
    """
    # Iterate the pipeline directly; there is no need to materialise a list
    # of results first just to loop over it once.
    audio_segments = []
    for result in full_pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"):
        if result.audio is None:
            continue
        # presumably a torch tensor (it exposes .cpu().numpy()) — confirm
        audio_np = result.audio.cpu().numpy()
        if audio_np.ndim > 1:
            audio_np = audio_np.flatten()
        audio_segments.append(audio_np)

    if not audio_segments:
        raise HTTPException(status_code=500, detail="No audio generated.")

    # Concatenate all audio segments.
    full_audio = np.concatenate(audio_segments)

    # Clip before scaling: samples outside [-1, 1] would otherwise wrap
    # around when cast to int16 and be heard as loud clicks.
    full_audio_int16 = np.int16(np.clip(full_audio, -1.0, 1.0) * 32767)

    # Write the audio to an in-memory WAV file.
    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM -> 2 bytes per sample
    wav_io = io.BytesIO()
    with wave.open(wav_io, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(full_audio_int16.tobytes())
    # The wave module leaves a caller-supplied buffer open, so its contents
    # remain readable after the with-block has finalised the header.
    return Response(content=wav_io.getvalue(), media_type="audio/wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
 
177
 
 
179
  def index():
180
  """
181
  HTML demo page for Kokoro TTS.
182
+
183
+ This page provides a simple UI to enter text, choose a voice and speed,
184
+ and play synthesized audio from both the streaming and full endpoints.
185
  """
186
  return """
187
+ <!DOCTYPE html>
188
+ <html>
189
+ <head>
190
+ <title>Kokoro TTS Demo</title>
191
+ </head>
192
+ <body>
193
+ <h1>Kokoro TTS Demo</h1>
194
+ <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
195
+ <label for="voice">Voice:</label>
196
+ <input type="text" id="voice" value="af_heart"><br>
197
+ <label for="speed">Speed:</label>
198
+ <input type="number" step="0.1" id="speed" value="1.0"><br>
199
+ <br><br>
200
+ <button onclick="playStreaming()">Play Streaming TTS</button>
201
+ <button onclick="playFull()">Play Full TTS (Download WAV)</button>
202
+ <br><br>
203
+ <audio id="audio" controls autoplay></audio>
204
+ <script>
205
+ function playStreaming() {
206
+ const text = document.getElementById('text').value;
207
+ const voice = document.getElementById('voice').value;
208
+ const speed = document.getElementById('speed').value;
209
+ const audio = document.getElementById('audio');
210
+ // Set the audio element's source to the streaming endpoint.
211
+ audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
212
+ audio.type = 'audio/wav';
213
+ audio.play();
214
+ }
215
+ function playFull() {
216
+ const text = document.getElementById('text').value;
217
+ const voice = document.getElementById('voice').value;
218
+ const speed = document.getElementById('speed').value;
219
+ const audio = document.getElementById('audio');
220
+ // Set the audio element's source to the full TTS endpoint.
221
+ audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
222
+ audio.type = 'audio/wav';
223
+ audio.play();
224
+ }
225
+ </script>
226
+ </body>
227
+ </html>
228
+ """
229
+
 
 
 
 
 
230
 
231
  # ------------------------------------------------------------------------------
232
  # Run with: uvicorn app:app --reload
233
  # ------------------------------------------------------------------------------
 
234
  if __name__ == "__main__":
235
  import uvicorn
236