bcci committed
Commit bc93f92 · verified · 1 Parent(s): f023a07

Update app.py

Files changed (1):
  app.py (+164 -204)
app.py CHANGED
@@ -10,229 +10,191 @@ from fastapi.responses import StreamingResponse, Response, HTMLResponse
 from fastapi.middleware import Middleware
 from fastapi.middleware.gzip import GZipMiddleware
 
-from kokoro import KPipeline
 
 app = FastAPI(
-    title="Kokoro TTS FastAPI",
-    middleware=[
-        Middleware(GZipMiddleware, compresslevel=9)  # Add GZip compression
-    ]
 )
 
-------------------------------------------------------------------------------
-Global Pipeline Instance
-------------------------------------------------------------------------------
-Create one pipeline instance for the entire app.
 
-pipeline = KPipeline(lang_code="a")
 
-------------------------------------------------------------------------------
-Helper Functions
-------------------------------------------------------------------------------
 
 def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
-    """
-    Generate a WAV header for streaming.
-    Since we don't know the final audio size, we set the data chunk size to a large dummy value.
-    This header is sent only once at the start of the stream.
-    """
-    bits_per_sample = sample_width * 8
-    byte_rate = sample_rate * num_channels * sample_width
-    block_align = num_channels * sample_width
-    # total file size = 36 + data_size (header is 44 bytes total)
-    total_size = 36 + data_size
-    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
-    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
-    data_chunk_header = struct.pack('<4sI', b'data', data_size)
-    return header + fmt_chunk + data_chunk_header
-
-def custom_split_text(text: str) -> list:
-    """
-    Custom splitting:
-    - Start with a chunk size of 2 words.
-    - For each chunk, if a period (".") is found in any word (except if it’s the very last word),
-      then split the chunk at that word (include words up to that word).
-    - Otherwise, use the current chunk size.
-    - For subsequent chunks, increase the chunk size by 2.
-    - If there are fewer than the desired number of words for a full chunk, add all remaining words.
-    """
-    words = text.split()
-    chunks = []
-    chunk_size = 2
-    start = 0
-    while start < len(words):
-        candidate_end = start + chunk_size
-        if candidate_end > len(words):
-            candidate_end = len(words)
-        chunk_words = words[start:candidate_end]
-        # Look for a period in any word except the last one.
-        split_index = None
-        for i in range(len(chunk_words) - 1):
-            if '.' in chunk_words[i]:
-                split_index = i
-                break
-        if split_index is not None:
-            candidate_end = start + split_index + 1
-            chunk_words = words[start:candidate_end]
-        chunks.append(" ".join(chunk_words))
-        start = candidate_end
-        chunk_size += 2  # Increase the chunk size by 2 for the next iteration.
-    return chunks
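
For reference, the splitter removed above starts with a 2-word chunk, cuts a chunk early at any word containing a period, and grows the chunk size by 2 each round. An illustrative trace (not part of the diff):

    custom_split_text("This is a test. And more words here")
    # -> ['This is', 'a test.', 'And more words here']
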
 
 def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
-    """
-    Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
-    """
-    # Ensure tensor is on CPU and flatten if necessary.
-    audio_np = audio_tensor.cpu().numpy()
-    if audio_np.ndim > 1:
-        audio_np = audio_np.flatten()
-    # Scale to int16 range.
-    audio_int16 = np.int16(audio_np * 32767)
-    return audio_int16.tobytes()
 
 def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
-    """
-    Convert a torch.FloatTensor to Opus encoded bytes.
-    Requires the 'opuslib' package: pip install opuslib
-    """
-    try:
-        import opuslib
-    except ImportError:
-        raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")
-
-    audio_np = audio_tensor.cpu().numpy()
-    if audio_np.ndim > 1:
-        audio_np = audio_np.flatten()
-    # Scale to int16 range. Important for opus.
-    audio_int16 = np.int16(audio_np * 32767)
-
-    encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP)  # 1 channel for mono.
-
-    # Calculate the number of frames to encode. Opus frames are 2.5, 5, 10, or 20 ms long.
-    frame_size = int(sample_rate * 0.020)  # 20ms frame size
-
-    encoded_data = b''
-    for i in range(0, len(audio_int16), frame_size):
-        frame = audio_int16[i:i + frame_size]
-        if len(frame) < frame_size:
-            # Pad the last frame with zeros if needed.
-            frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
-        encoded_frame = encoder.encode(frame.tobytes(), frame_size)  # Encode the frame.
-        encoded_data += encoded_frame
-
-    return encoded_data
-content_copy
-download
-Use code with caution.
-------------------------------------------------------------------------------
-Endpoints
-------------------------------------------------------------------------------
-
-@app.get("/tts/streaming", summary="Streaming TTS")
  def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "opus"):
-    """
-    Streaming TTS endpoint that returns a continuous audio stream.
-    Supports WAV (PCM) and Opus formats. Opus offers significantly better compression.
-
-    The endpoint first yields a WAV header (with a dummy length) for WAV,
-    then yields encoded audio data for each text chunk as soon as it is generated.
-    """
-    # Split the input text using the custom doubling strategy.
-    chunks = custom_split_text(text)
-    sample_rate = 24000
-    num_channels = 1
-    sample_width = 2  # 16-bit PCM
-
-    def audio_generator():
-        if format.lower() == "wav":
-            # Yield the WAV header first.
-            header = generate_wav_header(sample_rate, num_channels, sample_width)
-            yield header
-        # Process and yield each chunk's audio data.
-        for i, chunk in enumerate(chunks):
-            print(f"Processing chunk {i}: {chunk}")  # Debugging
             try:
-                results = list(pipeline(chunk, voice=voice, speed=speed, split_pattern=None))
-                for result in results:
-                    if result.audio is not None:
                         if format.lower() == "wav":
-                            yield audio_tensor_to_pcm_bytes(result.audio)
                         elif format.lower() == "opus":
-                            yield audio_tensor_to_opus_bytes(result.audio, sample_rate=sample_rate)
                         else:
                             raise ValueError(f"Unsupported audio format: {format}")
-                    else:
-                        print(f"Chunk {i}: No audio generated")
             except Exception as e:
-                print(f"Error processing chunk {i}: {e}")
-                yield b''  # important so that streaming continues. Consider returning an error sound.
-
-    media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
-
-    return StreamingResponse(
-        audio_generator(),
-        media_type=media_type,
-        headers={"Cache-Control": "no-cache"},
-    )
-content_copy
-download
-Use code with caution.
 
  @app.get("/tts/full", summary="Full TTS")
 def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
-    """
-    Full TTS endpoint that synthesizes the entire text, concatenates the audio,
-    and returns a complete WAV or Opus file.
-    """
-    # Use newline-based splitting via the pipeline's split_pattern.
-    results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
-    audio_segments = []
-    for result in results:
-        if result.audio is not None:
-            audio_np = result.audio.cpu().numpy()
-            if audio_np.ndim > 1:
-                audio_np = audio_np.flatten()
-            audio_segments.append(audio_np)
-
-    if not audio_segments:
-        raise HTTPException(status_code=500, detail="No audio generated.")
-
-    # Concatenate all audio segments.
-    full_audio = np.concatenate(audio_segments)
-
-    # Write the concatenated audio to an in-memory WAV or Opus file.
-    sample_rate = 24000
-    num_channels = 1
-    sample_width = 2  # 16-bit PCM -> 2 bytes per sample
-    if format.lower() == "wav":
-        wav_io = io.BytesIO()
-        with wave.open(wav_io, "wb") as wav_file:
-            wav_file.setnchannels(num_channels)
-            wav_file.setsampwidth(sample_width)
-            wav_file.setframerate(sample_rate)
-            full_audio_int16 = np.int16(full_audio * 32767)
-            wav_file.writeframes(full_audio_int16.tobytes())
-        wav_io.seek(0)
-        return Response(content=wav_io.read(), media_type="audio/wav")
-    elif format.lower() == "opus":
-        opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(full_audio), sample_rate=sample_rate)
-        return Response(content=opus_data, media_type="audio/opus")
-    else:
-        raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
-content_copy
-download
-Use code with caution.
 
  @app.get("/", response_class=HTMLResponse)
 def index():
-    """
-    HTML demo page for Kokoro TTS.
-
-    This page provides a simple UI to enter text, choose a voice and speed,
-    and play synthesized audio from both the streaming and full endpoints.
-    """
-    return """
 <!DOCTYPE html>
 <html>
 <head>
@@ -281,14 +243,12 @@ return """
 </body>
 </html>
 """
-content_copy
-download
-Use code with caution.
-------------------------------------------------------------------------------
-Run with: uvicorn app:app --reload
-------------------------------------------------------------------------------
-
-if name == "main":
-    import uvicorn
-
-    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 from fastapi.middleware import Middleware
 from fastapi.middleware.gzip import GZipMiddleware
 
+# --- IMPORTANT: Use the AutoregressiveStreamKPipeline ---
+from kokoro.pipeline import AutoregressiveStreamKPipeline  # Or wherever your pipeline is.
 
 app = FastAPI(
+    title="Kokoro TTS FastAPI",
+    middleware=[
+        Middleware(GZipMiddleware, compresslevel=9)  # Add GZip compression
+    ]
 )
 
+# ------------------------------------------------------------------------------
+# Global Pipeline Instance
+# ------------------------------------------------------------------------------
+# Create one pipeline instance for the entire app.
 
+pipeline = AutoregressiveStreamKPipeline(lang_code="a")  # Use the autoregressive pipeline
 
+# ------------------------------------------------------------------------------
+# Helper Functions
+# ------------------------------------------------------------------------------
 
  def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
+    """
+    Generate a WAV header for streaming.
+    Since we don't know the final audio size, we set the data chunk size to a large dummy value.
+    This header is sent only once at the start of the stream.
+    """
+    bits_per_sample = sample_width * 8
+    byte_rate = sample_rate * num_channels * sample_width
+    block_align = num_channels * sample_width
+    # total file size = 36 + data_size (header is 44 bytes total)
+    total_size = 36 + data_size
+    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
+    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
+    data_chunk_header = struct.pack('<4sI', b'data', data_size)
+    return header + fmt_chunk + data_chunk_header
+
 
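A quick sanity check of the 44-byte header this helper emits; a sketch for illustration, not part of the committed file:

    import struct
    hdr = generate_wav_header(24000, 1, 2)
    assert len(hdr) == 44  # 12-byte RIFF header + 24-byte fmt chunk + 8-byte data chunk header
    riff, _, wave_id = struct.unpack('<4sI4s', hdr[:12])
    assert (riff, wave_id) == (b'RIFF', b'WAVE')
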
  def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
+    """
+    Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
+    """
+    # Ensure tensor is on CPU and flatten if necessary.
+    audio_np = audio_tensor.cpu().numpy()
+    if audio_np.ndim > 1:
+        audio_np = audio_np.flatten()
+    # Scale to int16 range.
+    audio_int16 = np.int16(audio_np * 32767)
+    return audio_int16.tobytes()
 
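The float-to-int16 scaling above maps [-1, 1] onto nearly the full int16 range; a small illustration (not from the diff):

    import numpy as np
    samples = np.array([1.0, 0.5, 0.0, -1.0], dtype=np.float32)
    print(np.int16(samples * 32767))  # [ 32767  16383      0 -32767]
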
  def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
+    """
+    Convert a torch.FloatTensor to Opus encoded bytes.
+    Requires the 'opuslib' package: pip install opuslib
+    """
+    try:
+        import opuslib
+    except ImportError:
+        raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")
+
+    audio_np = audio_tensor.cpu().numpy()
+    if audio_np.ndim > 1:
+        audio_np = audio_np.flatten()
+    # Scale to int16 range. Important for opus.
+    audio_int16 = np.int16(audio_np * 32767)
+
+    encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP)  # 1 channel for mono.
+
+    # Calculate the number of frames to encode. Opus frames are 2.5, 5, 10, or 20 ms long.
+    frame_size = int(sample_rate * 0.020)  # 20ms frame size
+
+    encoded_data = b''
+    for i in range(0, len(audio_int16), frame_size):
+        frame = audio_int16[i:i + frame_size]
+        if len(frame) < frame_size:
+            # Pad the last frame with zeros if needed.
+            frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
+        encoded_frame = encoder.encode(frame.tobytes(), frame_size)  # Encode the frame.
+        encoded_data += encoded_frame
+
+    return encoded_data
+
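At the 24 kHz default, the 20 ms frame used above works out to 480 samples (960 bytes of int16 PCM) per encoder call; the arithmetic, spelled out:

    sample_rate = 24000
    frame_size = int(sample_rate * 0.020)  # 480 samples per 20 ms frame
    bytes_per_frame = frame_size * 2       # 960 bytes of 16-bit mono PCM

One caveat worth noting: the helper returns bare concatenated Opus packets rather than an Ogg-encapsulated stream, so many players will not accept the result directly as audio/opus.
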
+# ------------------------------------------------------------------------------
+# Endpoints
+# ------------------------------------------------------------------------------
+
+@app.get("/tts/streaming", summary="Streaming TTS (Autoregressive)")
  def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "opus"):
+    """
+    Streaming TTS endpoint that attempts autoregressive, near sample-by-sample output.
+
+    IMPORTANT: This is EXPERIMENTAL and may have reduced quality compared to
+    the full or chunking methods. It's also likely to be slower due to the
+    per-phoneme processing overhead.
+    """
+    sample_rate = 24000
+    num_channels = 1
+    sample_width = 2  # 16-bit PCM
+
+    def audio_generator():
+        if format.lower() == "wav":
+            # Yield the WAV header first.
+            header = generate_wav_header(sample_rate, num_channels, sample_width)
+            yield header
         try:
+            # Use the AUTOREGRESSIVE pipeline
+            for audio_chunk in pipeline(text, voice=voice, speed=speed):
+                if audio_chunk.numel() > 0:  # Ensure we have audio data
                     if format.lower() == "wav":
+                        yield audio_tensor_to_pcm_bytes(audio_chunk)
                     elif format.lower() == "opus":
+                        yield audio_tensor_to_opus_bytes(audio_chunk, sample_rate=sample_rate)
                     else:
                         raise ValueError(f"Unsupported audio format: {format}")
+
         except Exception as e:
+            print(f"Error during streaming: {e}")
+            yield b''  # Yield empty bytes to avoid breaking the stream
+
+    media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
+    return StreamingResponse(
+        audio_generator(),
+        media_type=media_type,
+        headers={"Cache-Control": "no-cache"},
+    )
 
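A minimal client sketch for the streaming endpoint (assumes the server is running locally on port 7860, as configured at the bottom of this file; URL and filename are illustrative):

    import requests

    with requests.get(
        "http://localhost:7860/tts/streaming",
        params={"text": "Hello there.", "voice": "af_heart", "speed": 1.0, "format": "wav"},
        stream=True,
    ) as resp:
        resp.raise_for_status()
        with open("out.wav", "wb") as f:
            for chunk in resp.iter_content(chunk_size=4096):
                f.write(chunk)  # bytes arrive as each audio chunk is generated
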
  @app.get("/tts/full", summary="Full TTS")
 def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
+    """
+    Full TTS endpoint (no streaming). Synthesizes the entire text and returns
+    a complete WAV or Opus file.
+    """
+    # Use newline-based splitting. This is the *original* KPipeline,
+    # which is better for full synthesis. It's important to use
+    # the right pipeline for the right task.
+    from kokoro.pipeline import KPipeline  # Import here to avoid circular import
+    full_pipeline = KPipeline(lang_code="a")
+
+    results = list(full_pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
+    audio_segments = []
+    for result in results:
+        if result.audio is not None:
+            audio_np = result.audio.cpu().numpy()
+            if audio_np.ndim > 1:
+                audio_np = audio_np.flatten()
+            audio_segments.append(audio_np)
+
+    if not audio_segments:
+        raise HTTPException(status_code=500, detail="No audio generated.")
+
+    # Concatenate all audio segments.
+    full_audio = np.concatenate(audio_segments)
+
+    # Write the concatenated audio to an in-memory WAV or Opus file.
+    sample_rate = 24000
+    num_channels = 1
+    sample_width = 2  # 16-bit PCM -> 2 bytes per sample
+    if format.lower() == "wav":
+        wav_io = io.BytesIO()
+        with wave.open(wav_io, "wb") as wav_file:
+            wav_file.setnchannels(num_channels)
+            wav_file.setsampwidth(sample_width)
+            wav_file.setframerate(sample_rate)
+            full_audio_int16 = np.int16(full_audio * 32767)
+            wav_file.writeframes(full_audio_int16.tobytes())
+        wav_io.seek(0)
+        return Response(content=wav_io.read(), media_type="audio/wav")
+    elif format.lower() == "opus":
+        opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(full_audio), sample_rate=sample_rate)
+        return Response(content=opus_data, media_type="audio/opus")
+    else:
+        raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
 
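The non-streaming counterpart returns one complete file; same assumptions as the sketch above:

    import requests

    resp = requests.get(
        "http://localhost:7860/tts/full",
        params={"text": "Line one.\nLine two.", "voice": "af_heart", "format": "wav"},
    )
    resp.raise_for_status()
    with open("full.wav", "wb") as f:
        f.write(resp.content)
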
  @app.get("/", response_class=HTMLResponse)
 def index():
+    """
+    HTML demo page for Kokoro TTS.
+    """
+    return """
 <!DOCTYPE html>
 <html>
 <head>
 </body>
 </html>
 """
 
+# ------------------------------------------------------------------------------
+# Run with: uvicorn app:app --reload
+# ------------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)