bcci committed on
Commit
f023a07
·
verified ·
1 Parent(s): b9f1e8b

Update app.py

Files changed (1)
  1. app.py +262 -224
app.py CHANGED
@@ -10,247 +10,285 @@ from fastapi.responses import StreamingResponse, Response, HTMLResponse
10
  from fastapi.middleware import Middleware
11
  from fastapi.middleware.gzip import GZipMiddleware
12
 
13
- from kokoro import StreamKPipeline, KModel, KPipeline # Import StreamKPipeline and KModel
14
 
15
  app = FastAPI(
16
- title="Kokoro Streaming TTS FastAPI",
17
- middleware=[
18
- Middleware(GZipMiddleware, compresslevel=9) # Add GZip compression
19
- ]
20
  )
21
 
22
- # ------------------------------------------------------------------------------
23
- # Global Pipeline Instance
24
- # ------------------------------------------------------------------------------
25
- # Create one pipeline instance for the entire app.
26
- model = KModel() # Initialize KModel
27
- stream_pipeline = StreamKPipeline(lang_code="a", model=model, chunk_size=2) # Initialize StreamKPipeline, passing the model
28
- pipeline = KPipeline(lang_code="a", model=model) # Initialize KPipeline, passing the model
29
 
 
30
 
31
-
32
- # ------------------------------------------------------------------------------
33
- # Helper Functions
34
- # ------------------------------------------------------------------------------
35
 
36
  def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
37
- """
38
- Generate a WAV header for streaming.
39
- Since we don't know the final audio size, we set the data chunk size to a large dummy value.
40
- This header is sent only once at the start of the stream.
41
- """
42
- bits_per_sample = sample_width * 8
43
- byte_rate = sample_rate * num_channels * sample_width
44
- block_align = num_channels * sample_width
45
- # total file size = 36 + data_size (header is 44 bytes total)
46
- total_size = 36 + data_size
47
- header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
48
- fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
49
- data_chunk_header = struct.pack('<4sI', b'data', data_size)
50
- return header + fmt_chunk + data_chunk_header
51
-
52
-
53
- def audio_chunk_to_pcm_bytes(audio_chunk: torch.Tensor) -> bytes:
54
- """
55
- Convert a torch.FloatTensor audio chunk (values in [-1, 1]) to raw 16-bit PCM bytes.
56
- """
57
- # Ensure tensor is on CPU and flatten if necessary.
58
- audio_np = audio_chunk.cpu().numpy()
59
- if audio_np.ndim > 1:
60
- audio_np = audio_np.flatten()
61
- # Scale to int16 range.
62
- audio_int16 = np.int16(audio_np * 32767)
63
- return audio_int16.tobytes()
64
-
65
-
66
- def audio_chunk_to_opus_bytes(audio_chunk: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
67
- """
68
- Convert a torch.FloatTensor audio chunk to Opus encoded bytes.
69
- Requires the 'opuslib' package: pip install opuslib
70
- """
71
- try:
72
- import opuslib
73
- except ImportError:
74
- raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")
75
-
76
- audio_np = audio_chunk.cpu().numpy()
77
- if audio_np.ndim > 1:
78
- audio_np = audio_np.flatten()
79
- # Scale to int16 range. Important for opus.
80
- audio_int16 = np.int16(audio_np * 32767)
81
-
82
- encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP) # 1 channel for mono.
83
-
84
- # Calculate the number of frames to encode. Opus frames are 2.5, 5, 10, or 20 ms long.
85
- frame_size = int(sample_rate * 0.020) # 20ms frame size
86
-
87
- encoded_data = b''
88
- for i in range(0, len(audio_int16), frame_size):
89
- frame = audio_int16[i:i + frame_size]
90
- if len(frame) < frame_size:
91
- # Pad the last frame with zeros if needed.
92
- frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
93
- encoded_frame = encoder.encode(frame.tobytes(), frame_size) # Encode the frame.
94
- encoded_data += encoded_frame
95
-
96
- return encoded_data
97
-
98
-
99
- # ------------------------------------------------------------------------------
100
- # Streaming TTS Endpoint
101
- # ------------------------------------------------------------------------------
102
 
103
  @app.get("/tts/streaming", summary="Streaming TTS")
104
  def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "opus"):
105
- """
106
- Streaming TTS endpoint that returns a continuous audio stream.
107
- Supports WAV (PCM) and Opus formats. Opus offers significantly better compression.
108
-
109
- The endpoint first yields a WAV header (with a dummy length) for WAV,
110
- then yields encoded audio data chunks as they are generated.
111
- """
112
- sample_rate = 24000
113
- num_channels = 1
114
- sample_width = 2 # 16-bit PCM
115
-
116
- def audio_chunk_generator():
117
- if format.lower() == "wav":
118
- # Yield the WAV header first for PCM WAV format.
119
- header = generate_wav_header(sample_rate, num_channels, sample_width)
120
- yield header
121
-
122
- # Stream audio chunks from the pipeline.
123
- for audio_chunk in stream_pipeline(text=text, voice=voice, speed=speed):
124
- if audio_chunk is not None and audio_chunk.numel() > 0:
125
- if format.lower() == "wav":
126
- yield audio_chunk_to_pcm_bytes(audio_chunk)
127
- elif format.lower() == "opus":
128
- yield audio_chunk_to_opus_bytes(audio_chunk, sample_rate=sample_rate)
129
  else:
130
- raise ValueError(f"Unsupported audio format: {format}")
131
-
 
 
132
 
133
- return StreamingResponse(
134
- audio_chunk_generator(),
135
- media_type="audio/wav",
136
- headers={"Cache-Control": "no-cache"},
137
- )
138
 
139
-
140
- # ------------------------------------------------------------------------------
141
- # Full TTS Endpoint (unchanged from your original code)
142
- # ------------------------------------------------------------------------------
143
 
144
  @app.get("/tts/full", summary="Full TTS")
145
  def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
146
- """
147
- Full TTS endpoint that synthesizes the entire text, concatenates the audio,
148
- and returns a complete WAV or Opus file.
149
- """
150
- # Use newline-based splitting via the pipeline's split_pattern.
151
- results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
152
- audio_segments = []
153
- for result in results:
154
- if result.audio is not None:
155
- audio_np = result.audio.cpu().numpy()
156
- if audio_np.ndim > 1:
157
- audio_np = audio_np.flatten()
158
- audio_segments.append(audio_np)
159
-
160
- if not audio_segments:
161
- raise HTTPException(status_code=500, detail="No audio generated.")
162
-
163
- # Concatenate all audio segments.
164
- full_audio = np.concatenate(audio_segments)
165
-
166
- # Write the concatenated audio to an in-memory WAV or Opus file.
167
- sample_rate = 24000
168
- num_channels = 1
169
- sample_width = 2 # 16-bit PCM -> 2 bytes per sample
170
- if format.lower() == "wav":
171
- wav_io = io.BytesIO()
172
- with wave.open(wav_io, "wb") as wav_file:
173
- wav_file.setnchannels(num_channels)
174
- wav_file.setsampwidth(sample_width)
175
- wav_file.setframerate(sample_rate)
176
- full_audio_int16 = np.int16(full_audio * 32767)
177
- wav_file.writeframes(full_audio_int16.tobytes())
178
- wav_io.seek(0)
179
- return Response(content=wav_io.read(), media_type="audio/wav")
180
- elif format.lower() == "opus":
181
- opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(full_audio), sample_rate=sample_rate)
182
- return Response(content=opus_data, media_type="audio/opus")
183
- else:
184
- raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
185
-
186
-
187
- # ------------------------------------------------------------------------------
188
- # HTML Demo Page Endpoint (unchanged from your original code, but updated to call new streaming endpoint)
189
- # ------------------------------------------------------------------------------
190
 
191
  @app.get("/", response_class=HTMLResponse)
192
  def index():
193
- """
194
- HTML demo page for Kokoro TTS.
195
-
196
- This page provides a simple UI to enter text, choose a voice and speed,
197
- and play synthesized audio from both the streaming and full endpoints.
198
- """
199
- return """
200
- <!DOCTYPE html>
201
- <html>
202
- <head>
203
- <title>Kokoro Streaming TTS Demo</title>
204
- </head>
205
- <body>
206
- <h1>Kokoro Streaming TTS Demo</h1>
207
- <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
208
- <label for="voice">Voice:</label>
209
- <input type="text" id="voice" value="af_heart"><br>
210
- <label for="speed">Speed:</label>
211
- <input type="number" step="0.1" id="speed" value="1.0"><br>
212
- <label for="format">Format:</label>
213
- <select id="format">
214
- <option value="wav">WAV</option>
215
- <option value="opus" selected>Opus</option>
216
- </select><br><br>
217
- <button onclick="playStreaming()">Play Streaming TTS</button>
218
- <button onclick="playFull()">Play Full TTS</button>
219
- <br><br>
220
- <audio id="audio" controls autoplay></audio>
221
- <script>
222
- function playStreaming() {
223
- const text = document.getElementById('text').value;
224
- const voice = document.getElementById('voice').value;
225
- const speed = document.getElementById('speed').value;
226
- const format = document.getElementById('format').value;
227
- const audio = document.getElementById('audio');
228
- // Set the audio element's source to the streaming endpoint.
229
- audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
230
- audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
231
- audio.play();
232
- }
233
- function playFull() {
234
- const text = document.getElementById('text').value;
235
- const voice = document.getElementById('voice').value;
236
- const speed = document.getElementById('speed').value;
237
- const format = document.getElementById('format').value;
238
- const audio = document.getElementById('audio');
239
- // Set the audio element's source to the full TTS endpoint.
240
- audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
241
- audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
242
- audio.play();
243
- }
244
- </script>
245
- </body>
246
- </html>
247
- """
248
-
249
-
250
- # ------------------------------------------------------------------------------
251
- # Run with: uvicorn app:app --reload
252
- # ------------------------------------------------------------------------------
253
- if __name__ == "__main__":
254
- import uvicorn
255
-
256
- uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 
 
 
10
  from fastapi.middleware import Middleware
11
  from fastapi.middleware.gzip import GZipMiddleware
12
 
13
+ from kokoro import KPipeline
14
 
15
  app = FastAPI(
16
+ title="Kokoro TTS FastAPI",
17
+ middleware=[
18
+ Middleware(GZipMiddleware, compresslevel=9) # Add GZip compression
19
+ ]
20
  )
21
 
22
+ # ------------------------------------------------------------------------------
23
+ # Global Pipeline Instance
24
+ # ------------------------------------------------------------------------------
25
+ # Create one pipeline instance for the entire app.
26
 
27
+ pipeline = KPipeline(lang_code="a")
28
 
29
+ # ------------------------------------------------------------------------------
30
+ # Helper Functions
31
+ # ------------------------------------------------------------------------------
 
32
 
33
  def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
34
+ """
35
+ Generate a WAV header for streaming.
36
+ Since we don't know the final audio size, we set the data chunk size to a large dummy value.
37
+ This header is sent only once at the start of the stream.
38
+ """
39
+ bits_per_sample = sample_width * 8
40
+ byte_rate = sample_rate * num_channels * sample_width
41
+ block_align = num_channels * sample_width
42
+ # total file size = 36 + data_size (header is 44 bytes total)
43
+ total_size = 36 + data_size
44
+ header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
45
+ fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
46
+ data_chunk_header = struct.pack('<4sI', b'data', data_size)
47
+ return header + fmt_chunk + data_chunk_header
48
+
49
+ def custom_split_text(text: str) -> list:
50
+ """
51
+ Custom splitting:
52
+ - Start with a chunk size of 2 words.
53
+ - For each chunk, if a period (".") is found in any word (except if it's the very last word),
54
+ then split the chunk at that word (include words up to that word).
55
+ - Otherwise, use the current chunk size.
56
+ - For subsequent chunks, increase the chunk size by 2.
57
+ - If there are fewer than the desired number of words for a full chunk, add all remaining words.
58
+ """
59
+ words = text.split()
60
+ chunks = []
61
+ chunk_size = 2
62
+ start = 0
63
+ while start < len(words):
64
+ candidate_end = start + chunk_size
65
+ if candidate_end > len(words):
66
+ candidate_end = len(words)
67
+ chunk_words = words[start:candidate_end]
68
+ # Look for a period in any word except the last one.
69
+ split_index = None
70
+ for i in range(len(chunk_words) - 1):
71
+ if '.' in chunk_words[i]:
72
+ split_index = i
73
+ break
74
+ if split_index is not None:
75
+ candidate_end = start + split_index + 1
76
+ chunk_words = words[start:candidate_end]
77
+ chunks.append(" ".join(chunk_words))
78
+ start = candidate_end
79
+ chunk_size += 2 # Increase the chunk size by 2 for the next iteration.
80
+ return chunks
81
+
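+ # Example (hypothetical input): custom_split_text("Hello world. This is a test.")
+ # would return ["Hello world.", "This is a test."]; chunk sizes grow 2, 4, 6, ...
+ 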
82
+ def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
83
+ """
84
+ Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
85
+ """
86
+ # Ensure tensor is on CPU and flatten if necessary.
87
+ audio_np = audio_tensor.cpu().numpy()
88
+ if audio_np.ndim > 1:
89
+ audio_np = audio_np.flatten()
90
+ # Scale to int16 range.
91
+ audio_int16 = np.int16(audio_np * 32767)
92
+ return audio_int16.tobytes()
93
+
94
+ def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
95
+ """
96
+ Convert a torch.FloatTensor to Opus encoded bytes.
97
+ Requires the 'opuslib' package: pip install opuslib
98
+ """
99
+ try:
100
+ import opuslib
101
+ except ImportError:
102
+ raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")
103
+
104
+ audio_np = audio_tensor.cpu().numpy()
105
+ if audio_np.ndim > 1:
106
+ audio_np = audio_np.flatten()
107
+ # Scale to int16 range. Important for opus.
108
+ audio_int16 = np.int16(audio_np * 32767)
109
+
110
+ encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP) # 1 channel for mono.
111
+
112
+ # Calculate the number of frames to encode. Opus frames are 2.5, 5, 10, or 20 ms long.
113
+ frame_size = int(sample_rate * 0.020) # 20ms frame size
114
+
115
+ encoded_data = b''
116
+ for i in range(0, len(audio_int16), frame_size):
117
+ frame = audio_int16[i:i + frame_size]
118
+ if len(frame) < frame_size:
119
+ # Pad the last frame with zeros if needed.
120
+ frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
121
+ encoded_frame = encoder.encode(frame.tobytes(), frame_size) # Encode the frame.
122
+ encoded_data += encoded_frame
123
+
124
+ return encoded_data
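+ 
+ # Note: the bytes returned here are raw Opus packets (no Ogg/WebM container);
+ # browsers typically need a container to play Opus in an <audio> element.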
125
128
+ # ------------------------------------------------------------------------------
129
+ # Endpoints
130
+ # ------------------------------------------------------------------------------
131
 
132
  @app.get("/tts/streaming", summary="Streaming TTS")
133
  def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "opus"):
134
+ """
135
+ Streaming TTS endpoint that returns a continuous audio stream.
136
+ Supports WAV (PCM) and Opus formats. Opus offers significantly better compression.
137
+
138
+ The endpoint first yields a WAV header (with a dummy length) for WAV,
139
+ then yields encoded audio data for each text chunk as soon as it is generated.
140
+ """
141
+ # Split the input text using the custom doubling strategy.
142
+ chunks = custom_split_text(text)
143
+ sample_rate = 24000
144
+ num_channels = 1
145
+ sample_width = 2 # 16-bit PCM
146
+
147
+ def audio_generator():
148
+ if format.lower() == "wav":
149
+ # Yield the WAV header first.
150
+ header = generate_wav_header(sample_rate, num_channels, sample_width)
151
+ yield header
152
+ # Process and yield each chunk's audio data.
153
+ for i, chunk in enumerate(chunks):
154
+ print(f"Processing chunk {i}: {chunk}") # Debugging
155
+ try:
156
+ results = list(pipeline(chunk, voice=voice, speed=speed, split_pattern=None))
157
+ for result in results:
158
+ if result.audio is not None:
159
+ if format.lower() == "wav":
160
+ yield audio_tensor_to_pcm_bytes(result.audio)
161
+ elif format.lower() == "opus":
162
+ yield audio_tensor_to_opus_bytes(result.audio, sample_rate=sample_rate)
163
+ else:
164
+ raise ValueError(f"Unsupported audio format: {format}")
165
  else:
166
+ print(f"Chunk {i}: No audio generated")
167
+ except Exception as e:
168
+ print(f"Error processing chunk {i}: {e}")
169
+ yield b''  # Yield empty bytes so the stream keeps going; consider yielding an error sound instead.
170
 
171
+ media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
172
 
173
+ return StreamingResponse(
174
+ audio_generator(),
175
+ media_type=media_type,
176
+ headers={"Cache-Control": "no-cache"},
177
+ )
178
181
 
182
  @app.get("/tts/full", summary="Full TTS")
183
  def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
184
+ """
185
+ Full TTS endpoint that synthesizes the entire text, concatenates the audio,
186
+ and returns a complete WAV or Opus file.
187
+ """
188
+ # Use newline-based splitting via the pipeline's split_pattern.
189
+ results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
190
+ audio_segments = []
191
+ for result in results:
192
+ if result.audio is not None:
193
+ audio_np = result.audio.cpu().numpy()
194
+ if audio_np.ndim > 1:
195
+ audio_np = audio_np.flatten()
196
+ audio_segments.append(audio_np)
197
+
198
+ if not audio_segments:
199
+ raise HTTPException(status_code=500, detail="No audio generated.")
200
+
201
+ # Concatenate all audio segments.
202
+ full_audio = np.concatenate(audio_segments)
203
+
204
+ # Write the concatenated audio to an in-memory WAV or Opus file.
205
+ sample_rate = 24000
206
+ num_channels = 1
207
+ sample_width = 2 # 16-bit PCM -> 2 bytes per sample
208
+ if format.lower() == "wav":
209
+ wav_io = io.BytesIO()
210
+ with wave.open(wav_io, "wb") as wav_file:
211
+ wav_file.setnchannels(num_channels)
212
+ wav_file.setsampwidth(sample_width)
213
+ wav_file.setframerate(sample_rate)
214
+ full_audio_int16 = np.int16(full_audio * 32767)
215
+ wav_file.writeframes(full_audio_int16.tobytes())
216
+ wav_io.seek(0)
217
+ return Response(content=wav_io.read(), media_type="audio/wav")
218
+ elif format.lower() == "opus":
219
+ opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(full_audio), sample_rate=sample_rate)
220
+ return Response(content=opus_data, media_type="audio/opus")
221
+ else:
222
+ raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
223
226
 
227
  @app.get("/", response_class=HTMLResponse)
228
  def index():
229
+ """
230
+ HTML demo page for Kokoro TTS.
231
+
232
+ This page provides a simple UI to enter text, choose a voice and speed,
233
+ and play synthesized audio from both the streaming and full endpoints.
234
+ """
235
+ return """
236
+ <!DOCTYPE html>
237
+ <html>
238
+ <head>
239
+ <title>Kokoro TTS Demo</title>
240
+ </head>
241
+ <body>
242
+ <h1>Kokoro TTS Demo</h1>
243
+ <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
244
+ <label for="voice">Voice:</label>
245
+ <input type="text" id="voice" value="af_heart"><br>
246
+ <label for="speed">Speed:</label>
247
+ <input type="number" step="0.1" id="speed" value="1.0"><br>
248
+ <label for="format">Format:</label>
249
+ <select id="format">
250
+ <option value="wav">WAV</option>
251
+ <option value="opus" selected>Opus</option>
252
+ </select><br><br>
253
+ <button onclick="playStreaming()">Play Streaming TTS</button>
254
+ <button onclick="playFull()">Play Full TTS</button>
255
+ <br><br>
256
+ <audio id="audio" controls autoplay></audio>
257
+ <script>
258
+ function playStreaming() {
259
+ const text = document.getElementById('text').value;
260
+ const voice = document.getElementById('voice').value;
261
+ const speed = document.getElementById('speed').value;
262
+ const format = document.getElementById('format').value;
263
+ const audio = document.getElementById('audio');
264
+ // Set the audio element's source to the streaming endpoint.
265
+ audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
266
+ audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
267
+ audio.play();
268
+ }
269
+ function playFull() {
270
+ const text = document.getElementById('text').value;
271
+ const voice = document.getElementById('voice').value;
272
+ const speed = document.getElementById('speed').value;
273
+ const format = document.getElementById('format').value;
274
+ const audio = document.getElementById('audio');
275
+ // Set the audio element's source to the full TTS endpoint.
276
+ audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
277
+ audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
278
+ audio.play();
279
+ }
280
+ </script>
281
+ </body>
282
+ </html>
283
+ """
284
287
+ # ------------------------------------------------------------------------------
288
+ # Run with: uvicorn app:app --reload
289
+ # ------------------------------------------------------------------------------
290
+
291
+ if __name__ == "__main__":
292
+ import uvicorn
293
+
294
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
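+ 
+ # Example requests (illustrative; assumes the server is running on localhost:7860):
+ #   curl "http://localhost:7860/tts/full?text=Hello%20world&voice=af_heart&format=wav" -o out.wav
+ #   curl "http://localhost:7860/tts/streaming?text=Hello%20world&format=opus" -o out.opus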