bcci committed
Commit 25bd1c6 · verified · 1 parent: 05eca7a

Update app.py

Files changed (1): app.py (+148 −143)
app.py CHANGED
@@ -2,6 +2,9 @@ import io
 import re
 import wave
 import struct
+import os
+import time
+import json
 
 import numpy as np
 import torch
@@ -12,71 +15,58 @@ from fastapi.middleware.gzip import GZipMiddleware
 
 from misaki import en
 
-import os
-import numpy as np
 from onnxruntime import InferenceSession
 from huggingface_hub import snapshot_download
-
-import json
 from scipy.io.wavfile import write as write_wav
 
-import time
-
-# Load the configuration file
-config_file_path = 'config.json' # Update this with the path to your config file
-
+# ------------------------------------------------------------------------------
+# Load configuration and set up vocabulary
+# ------------------------------------------------------------------------------
+config_file_path = 'config.json' # Update with your actual path
 with open(config_file_path, 'r') as f:
     config = json.load(f)
-
-# Extract the phoneme vocabulary
 phoneme_vocab = config['vocab']
 
-# Step 3: Download the model and voice file from Hugging Face Hub
+# ------------------------------------------------------------------------------
+# Download the model and voice files from Hugging Face Hub
+# ------------------------------------------------------------------------------
 model_repo = "onnx-community/Kokoro-82M-v1.0-ONNX"
 model_name = "onnx/model_q8f16.onnx"
 voice_file_pattern = "*.bin"
 local_dir = "."
-
-# Download the model and voice file
 snapshot_download(
     repo_id=model_repo,
     allow_patterns=[model_name, voice_file_pattern],
     local_dir=local_dir
 )
 
-# Step 4: Load the model
+# ------------------------------------------------------------------------------
+# Load the ONNX model
+# ------------------------------------------------------------------------------
 model_path = os.path.join(local_dir, model_name)
 sess = InferenceSession(model_path)
 
+# ------------------------------------------------------------------------------
+# Create the FastAPI app with GZip compression
+# ------------------------------------------------------------------------------
 app = FastAPI(
     title="Kokoro TTS FastAPI",
-    middleware=[
-        Middleware(GZipMiddleware, compresslevel=9) # Add GZip compression
-    ]
+    middleware=[Middleware(GZipMiddleware, compresslevel=9)]
 )
 
-# ------------------------------------------------------------------------------
-# Global Pipeline Instance
-# ------------------------------------------------------------------------------
-# Create one pipeline instance for the entire app.
-
-
-
 # ------------------------------------------------------------------------------
 # Helper Functions
 # ------------------------------------------------------------------------------
 
 def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
     """
-    Generate a WAV header for streaming.
-    Since we don't know the final audio size, we set the data chunk size to a large dummy value.
-    This header is sent only once at the start of the stream.
+    Generate a WAV header for streaming. Since we do not know the final audio size,
+    a large dummy value is used for the data chunk size.
     """
     bits_per_sample = sample_width * 8
     byte_rate = sample_rate * num_channels * sample_width
     block_align = num_channels * sample_width
-    # total file size = 36 + data_size (header is 44 bytes total)
-    total_size = 36 + data_size
+    total_size = 36 + data_size # 36 + data_size (header is 44 bytes total)
     header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
     fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
     data_chunk_header = struct.pack('<4sI', b'data', data_size)
@@ -85,13 +75,13 @@ def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int,
 
 def custom_split_text(text: str) -> list:
     """
-    Custom splitting:
+    Custom splitting strategy:
     - Start with a chunk size of 2 words.
-    - For each chunk, if a period (".") is found in any word (except if it's the very last word),
-      then split the chunk at that word (include words up to that word).
+    - For each chunk, if a period (".") is found in any word (except the very last word),
+      then split at that word (including it).
     - Otherwise, use the current chunk size.
-    - For subsequent chunks, increase the chunk size by 2.
-    - If there are fewer than the desired number of words for a full chunk, add all remaining words.
+    - Increase the chunk size by 2 for each subsequent chunk.
+    - If there are fewer than the desired number of words remaining, include all of them.
    """
    words = text.split()
    chunks = []
@@ -102,7 +92,6 @@ def custom_split_text(text: str) -> list:
        if candidate_end > len(words):
            candidate_end = len(words)
        chunk_words = words[start:candidate_end]
-        # Look for a period in any word except the last one.
        split_index = None
        for i in range(len(chunk_words) - 1):
            if '.' in chunk_words[i]:
@@ -113,26 +102,24 @@ def custom_split_text(text: str) -> list:
            chunk_words = words[start:candidate_end]
        chunks.append(" ".join(chunk_words))
        start = candidate_end
-        chunk_size += 2 # Increase the chunk size by 2 for the next iteration.
+        chunk_size += 2
    return chunks
 
 
 def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
     """
-    Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
+    Convert a torch.FloatTensor (values in [-1, 1]) to raw 16-bit PCM bytes.
     """
-    # Ensure tensor is on CPU and flatten if necessary.
     audio_np = audio_tensor.cpu().numpy()
     if audio_np.ndim > 1:
         audio_np = audio_np.flatten()
-    # Scale to int16 range.
     audio_int16 = np.int16(audio_np * 32767)
     return audio_int16.tobytes()
 
 
 def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
     """
-    Convert a torch.FloatTensor to Opus encoded bytes.
+    Convert a torch.FloatTensor to Opus-encoded bytes.
     Requires the 'opuslib' package: pip install opuslib
     """
     try:
@@ -143,154 +130,175 @@ def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24
     audio_np = audio_tensor.cpu().numpy()
     if audio_np.ndim > 1:
         audio_np = audio_np.flatten()
-    # Scale to int16 range. Important for opus.
     audio_int16 = np.int16(audio_np * 32767)
 
-    encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP) # 1 channel for mono.
-
-    # Calculate the number of frames to encode. Opus frames are 2.5, 5, 10, or 20 ms long.
-    frame_size = int(sample_rate * 0.020) # 20ms frame size
-
+    encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP)
+    frame_size = int(sample_rate * 0.020) # 20 ms frame
     encoded_data = b''
     for i in range(0, len(audio_int16), frame_size):
         frame = audio_int16[i:i + frame_size]
         if len(frame) < frame_size:
-            # Pad the last frame with zeros if needed.
             frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
-        encoded_frame = encoder.encode(frame.tobytes(), frame_size) # Encode the frame.
+        encoded_frame = encoder.encode(frame.tobytes(), frame_size)
         encoded_data += encoded_frame
-
     return encoded_data
 
-g2p = en.G2P(trf=False, british=False, fallback=None) # no transformer, American English
 
-def tokenizer(text):
+# Initialize G2P for English (American)
+g2p = en.G2P(trf=False, british=False, fallback=None)
+
+def tokenizer(text: str):
+    """
+    Converts text to a list of phoneme tokens using the global vocabulary.
+    """
     print("Text: " + text)
     phonemes_string, _ = g2p(text)
-    phonemes = []
-    for i in phonemes_string:
-        phonemes.append(i)
+    phonemes = [ph for ph in phonemes_string]
     tokens = [phoneme_vocab[phoneme] for phoneme in phonemes if phoneme in phoneme_vocab]
-    print(tokens)
+    print("Tokens:", tokens)
     return tokens
-
-
 
 
 # ------------------------------------------------------------------------------
 # Endpoints
 # ------------------------------------------------------------------------------
 
-# @app.get("/tts/streaming", summary="Streaming TTS")
-# def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "opus"):
-#     """
-#     Streaming TTS endpoint that returns a continuous audio stream.
-#     Supports WAV (PCM) and Opus formats. Opus offers significantly better compression.
-
-#     The endpoint first yields a WAV header (with a dummy length) for WAV,
-#     then yields encoded audio data for each text chunk as soon as it is generated.
-#     """
-#     # Split the input text using the custom doubling strategy.
-#     chunks = custom_split_text(text)
-#     sample_rate = 24000
-#     num_channels = 1
-#     sample_width = 2  # 16-bit PCM
-
-
-#     def audio_generator():
-#         if format.lower() == "wav":
-#             # Yield the WAV header first.
-#             header = generate_wav_header(sample_rate, num_channels, sample_width)
-#             yield header
-#         # Process and yield each chunk's audio data.
-#         for i, chunk in enumerate(chunks):
-#             print(f"Processing chunk {i}: {chunk}")  # Debugging
-#             try:
-#                 results = list(pipeline(chunk, voice=voice, speed=speed, split_pattern=None))
-#                 for result in results:
-#                     if result.audio is not None:
-#                         if format.lower() == "wav":
-#                             yield audio_tensor_to_pcm_bytes(result.audio)
-#                         elif format.lower() == "opus":
-#                             yield audio_tensor_to_opus_bytes(result.audio, sample_rate=sample_rate)
-#                         else:
-#                             raise ValueError(f"Unsupported audio format: {format}")
-#                     else:
-#                         print(f"Chunk {i}: No audio generated")
-#             except Exception as e:
-#                 print(f"Error processing chunk {i}: {e}")
-#                 yield b''  # important so that streaming continues. Consider returning an error sound.
-
-#     media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
-
-#     return StreamingResponse(
-#         audio_generator(),
-#         media_type=media_type,
-#         headers={"Cache-Control": "no-cache"},
-#     )
+@app.get("/tts/streaming", summary="Streaming TTS")
+def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "opus"):
+    """
+    Streaming TTS endpoint.
+
+    This endpoint splits the input text into chunks (using the doubling strategy),
+    then for each chunk:
+      - For the first chunk, a 0 is prepended.
+      - For subsequent chunks, the first token is set to the last token from the previous chunk.
+      - For the final chunk, a 0 is appended.
+
+    The audio for each chunk is generated immediately and streamed to the client.
+    """
+    chunks = custom_split_text(text)
+    sample_rate = 24000
+    num_channels = 1
+    sample_width = 2
+
+    # Load the voice/style file (must be present in voices/{voice}.bin)
+    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
+    if not os.path.exists(voice_path):
+        raise HTTPException(status_code=404, detail="Voice file not found")
+    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)
+
+    def audio_generator():
+        # If outputting a WAV stream, yield a WAV header once.
+        if format.lower() == "wav":
+            header = generate_wav_header(sample_rate, num_channels, sample_width)
+            yield header
+
+        prev_last_token = None
+        for i, chunk in enumerate(chunks):
+            print(f"Processing chunk {i}: {chunk}")
+            # Convert the chunk text to tokens.
+            chunk_tokens = tokenizer(chunk)
+
+            # For the first chunk, prepend 0; for later chunks, start with the previous chunk's last token.
+            if i == 0:
+                tokens_to_send = [0] + chunk_tokens
+            else:
+                tokens_to_send = [prev_last_token] + chunk_tokens
+
+            # If this is the final chunk, append 0.
+            if i == len(chunks) - 1:
+                tokens_to_send = tokens_to_send + [0]
+
+            # Save the last token of this chunk for the next iteration.
+            prev_last_token = tokens_to_send[-1]
+
+            # Prepare the model input (a batch of one sequence).
+            final_token = [tokens_to_send]
+
+            # Use the number of tokens to select the appropriate style vector.
+            style_index = len(tokens_to_send)
+            if style_index >= len(voices):
+                style_index = len(voices) - 1  # Fallback if index is out-of-bounds.
+            ref_s = voices[style_index]
+
+            # Prepare the speed parameter.
+            speed_param = np.ones(1, dtype=np.float32) * speed
+
+            # Run the model (ONNX inference) for this chunk.
+            try:
+                start_time = time.time()
+                audio_output = sess.run(None, {
+                    "input_ids": final_token,
+                    "style": ref_s,
+                    "speed": speed_param,
+                })[0]
+                print(f"Chunk {i} inference time: {time.time() - start_time:.3f}s")
+            except Exception as e:
+                print(f"Error processing chunk {i}: {e}")
+                # In case of error, generate a short silent chunk.
+                audio_output = np.zeros((sample_rate,), dtype=np.float32)
+
+            # Convert the model output (assumed to be float32 in [-1, 1]) to int16 PCM.
+            audio_int16 = (audio_output * 32767).astype(np.int16).flatten()
+
+            # Convert to a torch tensor (back into float range) for our helper functions.
+            audio_tensor = torch.from_numpy(audio_int16.astype(np.float32) / 32767)
+
+            # Yield the encoded audio chunk.
+            if format.lower() == "wav":
+                yield audio_tensor_to_pcm_bytes(audio_tensor)
+            elif format.lower() == "opus":
+                yield audio_tensor_to_opus_bytes(audio_tensor, sample_rate=sample_rate)
+            else:
+                raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
+
+    media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
+    return StreamingResponse(
+        audio_generator(),
+        media_type=media_type,
+        headers={"Cache-Control": "no-cache"},
+    )
 
 
 @app.get("/tts/full", summary="Full TTS")
 def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
     """
-    Full TTS endpoint that synthesizes the entire text, concatenates the audio,
-    and returns a complete WAV or Opus file.
+    Full TTS endpoint that synthesizes the entire text and returns a complete WAV or Opus file.
     """
     voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
     voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)
 
     tokens = tokenizer(text)
-
     ref_s = voices[len(tokens)]
-
     final_token = [[0, *tokens, 0]]
 
     start_time = time.time()
-
-    audio = sess.run(None, dict(
-        input_ids=final_token,
-        style=ref_s,
-        speed=np.ones(1, dtype=np.float32),
-    ))[0]
-
-    print(time.time()-start_time)
+    audio = sess.run(None, {
+        "input_ids": final_token,
+        "style": ref_s,
+        "speed": np.ones(1, dtype=np.float32) * speed,
+    })[0]
+    print(f"Full TTS inference time: {time.time()-start_time:.3f}s")
 
-    # Write the concatenated audio to an in-memory WAV or Opus file.
-    sample_rate = 24000
-
-    # audio = np.array(audio, dtype=np.float32) # Ensure it's float32 first
-    audio = (audio * 32767).astype(np.int16) # Scale to int16 range
-
-    # Flatten the array if it's 2D
-    audio = audio.flatten()
+    # Convert to int16 PCM.
+    audio = (audio * 32767).astype(np.int16).flatten()
 
     if format.lower() == "wav":
-
-        # Create an in-memory buffer
         wav_io = io.BytesIO()
-
-        # Write the audio data to the buffer in WAV format
-        write_wav(wav_io, sample_rate, audio)
-
-        # Seek to the beginning of the buffer
+        write_wav(wav_io, 24000, audio)
         wav_io.seek(0)
-
         return Response(content=wav_io.read(), media_type="audio/wav")
     elif format.lower() == "opus":
-        opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(audio), sample_rate=sample_rate)
+        opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(audio.astype(np.float32)/32767), sample_rate=24000)
        return Response(content=opus_data, media_type="audio/opus")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
 
 
-
 @app.get("/", response_class=HTMLResponse)
 def index():
     """
     HTML demo page for Kokoro TTS.
-
-    This page provides a simple UI to enter text, choose a voice and speed,
-    and play synthesized audio from both the streaming and full endpoints.
     """
     return """
     <!DOCTYPE html>
@@ -321,7 +329,6 @@ def index():
         const speed = document.getElementById('speed').value;
         const format = document.getElementById('format').value;
         const audio = document.getElementById('audio');
-        // Set the audio element's source to the streaming endpoint.
         audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
         audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
         audio.play();
@@ -332,7 +339,6 @@ def index():
         const speed = document.getElementById('speed').value;
         const format = document.getElementById('format').value;
         const audio = document.getElementById('audio');
-        // Set the audio element's source to the full TTS endpoint.
         audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
         audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
         audio.play();
@@ -344,9 +350,8 @@ def index():
 
 
 # ------------------------------------------------------------------------------
-# Run with: uvicorn app:app --reload
+# Run the app with: uvicorn app:app --reload
 # ------------------------------------------------------------------------------
 if __name__ == "__main__":
     import uvicorn
-
-    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
+    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
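
To exercise the change end to end, the new /tts/streaming endpoint can be driven with a minimal client sketch like the one below, once the app is running on port 7860 (per the __main__ block). This is illustrative only and not part of the commit: it assumes a local server and the third-party requests package, and the input text, output filenames, and chunk size are made up for the example.

    import requests

    BASE = "http://localhost:7860"  # port configured in app.py's __main__ block (assumption: server runs locally)

    # Stream WAV audio chunk by chunk from the new /tts/streaming endpoint.
    params = {"text": "Hello there. This is a streaming test.", "voice": "af_heart", "speed": 1.0, "format": "wav"}
    with requests.get(f"{BASE}/tts/streaming", params=params, stream=True) as resp:
        resp.raise_for_status()
        with open("streamed.wav", "wb") as f:  # hypothetical output file
            for chunk in resp.iter_content(chunk_size=4096):
                f.write(chunk)  # the first bytes are the WAV header yielded by the generator

    # Fetch a complete file from /tts/full in one request.
    resp = requests.get(f"{BASE}/tts/full", params={"text": "Hello there.", "format": "wav"})
    resp.raise_for_status()
    with open("full.wav", "wb") as f:  # hypothetical output file
        f.write(resp.content)

Because custom_split_text starts at two words and grows each chunk by two, a twelve-word input with no periods is synthesized as chunks of 2, 4, and 6 words, so the first audio bytes arrive after only a two-word synthesis.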