bcci committed on
Commit
e7655ad
·
verified ·
1 Parent(s): 65e1914

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -29
app.py CHANGED
@@ -7,10 +7,17 @@ import numpy as np
7
  import torch
8
  from fastapi import FastAPI, HTTPException
9
  from fastapi.responses import StreamingResponse, Response, HTMLResponse
 
 
10
 
11
  from kokoro import KPipeline
12
 
13
- app = FastAPI(title="Kokoro TTS FastAPI")
 
 
 
 
 
14
 
15
  # ------------------------------------------------------------------------------
16
  # Global Pipeline Instance
@@ -87,17 +94,51 @@ def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
87
  return audio_int16.tobytes()
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  # ------------------------------------------------------------------------------
91
  # Endpoints
92
  # ------------------------------------------------------------------------------
93
 
94
  @app.get("/tts/streaming", summary="Streaming TTS")
95
- def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
96
  """
97
- Streaming TTS endpoint that returns a continuous WAV stream.
98
-
99
- The endpoint first yields a WAV header (with a dummy length) then yields raw PCM data
100
- for each text chunk as soon as it is generated.
 
101
  """
102
  # Split the input text using the custom doubling strategy.
103
  chunks = custom_split_text(text)
@@ -106,34 +147,43 @@ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
106
  sample_width = 2 # 16-bit PCM
107
 
108
  def audio_generator():
109
- # Yield the WAV header first.
110
- header = generate_wav_header(sample_rate, num_channels, sample_width)
111
- yield header
112
- # Process and yield each chunk's PCM data.
 
113
  for i, chunk in enumerate(chunks):
114
  print(f"Processing chunk {i}: {chunk}") # Debugging
115
  try:
116
  results = list(pipeline(chunk, voice=voice, speed=speed, split_pattern=None))
117
  for result in results:
118
  if result.audio is not None:
119
- yield audio_tensor_to_pcm_bytes(result.audio)
 
 
 
 
 
120
  else:
121
  print(f"Chunk {i}: No audio generated")
122
  except Exception as e:
123
  print(f"Error processing chunk {i}: {e}")
 
 
 
124
 
125
  return StreamingResponse(
126
  audio_generator(),
127
- media_type="audio/wav",
128
  headers={"Cache-Control": "no-cache"},
129
  )
130
 
131
 
132
  @app.get("/tts/full", summary="Full TTS")
133
- def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
134
  """
135
  Full TTS endpoint that synthesizes the entire text, concatenates the audio,
136
- and returns a complete WAV file.
137
  """
138
  # Use newline-based splitting via the pipeline's split_pattern.
139
  results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
@@ -151,26 +201,33 @@ def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
151
  # Concatenate all audio segments.
152
  full_audio = np.concatenate(audio_segments)
153
 
154
- # Write the concatenated audio to an in-memory WAV file.
155
  sample_rate = 24000
156
  num_channels = 1
157
  sample_width = 2 # 16-bit PCM -> 2 bytes per sample
158
- wav_io = io.BytesIO()
159
- with wave.open(wav_io, "wb") as wav_file:
160
- wav_file.setnchannels(num_channels)
161
- wav_file.setsampwidth(sample_width)
162
- wav_file.setframerate(sample_rate)
163
- full_audio_int16 = np.int16(full_audio * 32767)
164
- wav_file.writeframes(full_audio_int16.tobytes())
165
- wav_io.seek(0)
166
- return Response(content=wav_io.read(), media_type="audio/wav")
 
 
 
 
 
 
 
167
 
168
 
169
  @app.get("/", response_class=HTMLResponse)
170
  def index():
171
  """
172
  HTML demo page for Kokoro TTS.
173
-
174
  This page provides a simple UI to enter text, choose a voice and speed,
175
  and play synthesized audio from both the streaming and full endpoints.
176
  """
@@ -186,7 +243,12 @@ def index():
186
  <label for="voice">Voice:</label>
187
  <input type="text" id="voice" value="af_heart"><br>
188
  <label for="speed">Speed:</label>
189
- <input type="number" step="0.1" id="speed" value="1.0"><br><br>
 
 
 
 
 
190
  <button onclick="playStreaming()">Play Streaming TTS</button>
191
  <button onclick="playFull()">Play Full TTS</button>
192
  <br><br>
@@ -196,18 +258,22 @@ def index():
196
  const text = document.getElementById('text').value;
197
  const voice = document.getElementById('voice').value;
198
  const speed = document.getElementById('speed').value;
 
199
  const audio = document.getElementById('audio');
200
  // Set the audio element's source to the streaming endpoint.
201
- audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
 
202
  audio.play();
203
  }
204
  function playFull() {
205
  const text = document.getElementById('text').value;
206
  const voice = document.getElementById('voice').value;
207
  const speed = document.getElementById('speed').value;
 
208
  const audio = document.getElementById('audio');
209
  // Set the audio element's source to the full TTS endpoint.
210
- audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
 
211
  audio.play();
212
  }
213
  </script>
@@ -222,4 +288,4 @@ def index():
222
  if __name__ == "__main__":
223
  import uvicorn
224
 
225
- uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 
7
  import torch
8
  from fastapi import FastAPI, HTTPException
9
  from fastapi.responses import StreamingResponse, Response, HTMLResponse
10
+ from fastapi.middleware import Middleware
11
+ from fastapi.middleware.gzip import GZipMiddleware
12
 
13
  from kokoro import KPipeline
14
 
15
+ app = FastAPI(
16
+ title="Kokoro TTS FastAPI",
17
+ middleware=[
18
+ Middleware(GZipMiddleware, compresslevel=9) # Add GZip compression
19
+ ]
20
+ )
21
 
22
  # ------------------------------------------------------------------------------
23
  # Global Pipeline Instance
 
94
  return audio_int16.tobytes()
95
 
96
 
97
def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
    """
    Convert a torch.FloatTensor to Opus-encoded bytes.

    Requires the 'opuslib' package: pip install opuslib

    Args:
        audio_tensor: Float audio samples; flattened to mono if multi-dimensional.
            Samples are assumed to lie in [-1.0, 1.0] -- TODO confirm with callers.
        sample_rate: Sample rate in Hz handed to the Opus encoder.
        bitrate: Currently unused; kept for backward-compatible interface.

    Returns:
        Concatenated raw Opus packets. NOTE(review): raw packets carry no length
        framing, so a standard decoder/player cannot split them back apart --
        "audio/opus" playback normally requires Ogg encapsulation. Confirm the
        downstream consumer can handle this before relying on it.

    Raises:
        ImportError: If opuslib is not installed.
    """
    try:
        import opuslib
    except ImportError:
        raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")

    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    # Scale to int16 range. Important for opus.
    audio_int16 = np.int16(audio_np * 32767)

    encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP)  # 1 channel for mono.

    # Opus frames must be 2.5, 5, 10, or 20 ms long; use 20 ms frames.
    frame_size = int(sample_rate * 0.020)

    # Collect encoded frames in a list and join once -- avoids the quadratic
    # cost of repeated bytes concatenation on long inputs.
    encoded_frames = []
    for start in range(0, len(audio_int16), frame_size):
        frame = audio_int16[start:start + frame_size]
        if len(frame) < frame_size:
            # Pad the last frame with zeros so the encoder sees a full frame.
            frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
        encoded_frames.append(encoder.encode(frame.tobytes(), frame_size))

    return b''.join(encoded_frames)
128
+
129
+
130
  # ------------------------------------------------------------------------------
131
  # Endpoints
132
  # ------------------------------------------------------------------------------
133
 
134
  @app.get("/tts/streaming", summary="Streaming TTS")
135
+ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "opus"):
136
  """
137
+ Streaming TTS endpoint that returns a continuous audio stream.
138
+ Supports WAV (PCM) and Opus formats. Opus offers significantly better compression.
139
+
140
+ The endpoint first yields a WAV header (with a dummy length) for WAV,
141
+ then yields encoded audio data for each text chunk as soon as it is generated.
142
  """
143
  # Split the input text using the custom doubling strategy.
144
  chunks = custom_split_text(text)
 
147
  sample_width = 2 # 16-bit PCM
148
 
149
  def audio_generator():
150
+ if format.lower() == "wav":
151
+ # Yield the WAV header first.
152
+ header = generate_wav_header(sample_rate, num_channels, sample_width)
153
+ yield header
154
+ # Process and yield each chunk's audio data.
155
  for i, chunk in enumerate(chunks):
156
  print(f"Processing chunk {i}: {chunk}") # Debugging
157
  try:
158
  results = list(pipeline(chunk, voice=voice, speed=speed, split_pattern=None))
159
  for result in results:
160
  if result.audio is not None:
161
+ if format.lower() == "wav":
162
+ yield audio_tensor_to_pcm_bytes(result.audio)
163
+ elif format.lower() == "opus":
164
+ yield audio_tensor_to_opus_bytes(result.audio, sample_rate=sample_rate)
165
+ else:
166
+ raise ValueError(f"Unsupported audio format: {format}")
167
  else:
168
  print(f"Chunk {i}: No audio generated")
169
  except Exception as e:
170
  print(f"Error processing chunk {i}: {e}")
171
+ yield b'' # important so that streaming continues. Consider returning an error sound.
172
+
173
+ media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
174
 
175
  return StreamingResponse(
176
  audio_generator(),
177
+ media_type=media_type,
178
  headers={"Cache-Control": "no-cache"},
179
  )
180
 
181
 
182
  @app.get("/tts/full", summary="Full TTS")
183
+ def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
184
  """
185
  Full TTS endpoint that synthesizes the entire text, concatenates the audio,
186
+ and returns a complete WAV or Opus file.
187
  """
188
  # Use newline-based splitting via the pipeline's split_pattern.
189
  results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
 
201
  # Concatenate all audio segments.
202
  full_audio = np.concatenate(audio_segments)
203
 
204
+ # Write the concatenated audio to an in-memory WAV or Opus file.
205
  sample_rate = 24000
206
  num_channels = 1
207
  sample_width = 2 # 16-bit PCM -> 2 bytes per sample
208
+ if format.lower() == "wav":
209
+ wav_io = io.BytesIO()
210
+ with wave.open(wav_io, "wb") as wav_file:
211
+ wav_file.setnchannels(num_channels)
212
+ wav_file.setsampwidth(sample_width)
213
+ wav_file.setframerate(sample_rate)
214
+ full_audio_int16 = np.int16(full_audio * 32767)
215
+ wav_file.writeframes(full_audio_int16.tobytes())
216
+ wav_io.seek(0)
217
+ return Response(content=wav_io.read(), media_type="audio/wav")
218
+ elif format.lower() == "opus":
219
+ opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(full_audio), sample_rate=sample_rate)
220
+ return Response(content=opus_data, media_type="audio/opus")
221
+ else:
222
+ raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
223
+
224
 
225
 
226
  @app.get("/", response_class=HTMLResponse)
227
  def index():
228
  """
229
  HTML demo page for Kokoro TTS.
230
+
231
  This page provides a simple UI to enter text, choose a voice and speed,
232
  and play synthesized audio from both the streaming and full endpoints.
233
  """
 
243
  <label for="voice">Voice:</label>
244
  <input type="text" id="voice" value="af_heart"><br>
245
  <label for="speed">Speed:</label>
246
+ <input type="number" step="0.1" id="speed" value="1.0"><br>
247
+ <label for="format">Format:</label>
248
+ <select id="format">
249
+ <option value="wav">WAV</option>
250
+ <option value="opus" selected>Opus</option>
251
+ </select><br><br>
252
  <button onclick="playStreaming()">Play Streaming TTS</button>
253
  <button onclick="playFull()">Play Full TTS</button>
254
  <br><br>
 
258
  const text = document.getElementById('text').value;
259
  const voice = document.getElementById('voice').value;
260
  const speed = document.getElementById('speed').value;
261
+ const format = document.getElementById('format').value;
262
  const audio = document.getElementById('audio');
263
  // Set the audio element's source to the streaming endpoint.
264
+ audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
265
+ audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
266
  audio.play();
267
  }
268
  function playFull() {
269
  const text = document.getElementById('text').value;
270
  const voice = document.getElementById('voice').value;
271
  const speed = document.getElementById('speed').value;
272
+ const format = document.getElementById('format').value;
273
  const audio = document.getElementById('audio');
274
  // Set the audio element's source to the full TTS endpoint.
275
+ audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
276
+ audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
277
  audio.play();
278
  }
279
  </script>
 
288
if __name__ == "__main__":
    # Launch the development server when this module is executed directly.
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)