bcci committed on
Commit
009fc5c
·
verified ·
1 Parent(s): ada1283

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -72
app.py CHANGED
@@ -15,6 +15,7 @@ app = FastAPI(title="Kokoro TTS FastAPI")
15
  # ------------------------------------------------------------------------------
16
  # Global Pipeline Instance
17
  # ------------------------------------------------------------------------------
 
18
  pipeline = KPipeline(lang_code="a")
19
 
20
 
@@ -26,11 +27,13 @@ def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int,
26
  """
27
  Generate a WAV header for streaming.
28
  Since we don't know the final audio size, we set the data chunk size to a large dummy value.
 
29
  """
30
  bits_per_sample = sample_width * 8
31
  byte_rate = sample_rate * num_channels * sample_width
32
  block_align = num_channels * sample_width
33
- total_size = 36 + data_size # header (44 bytes) minus 8 + dummy data size
 
34
  header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
35
  fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
36
  data_chunk_header = struct.pack('<4sI', b'data', data_size)
@@ -39,35 +42,18 @@ def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int,
39
 
40
  def custom_split_text(text: str) -> list:
41
  """
42
- Custom splitting:
43
- - Start with a chunk size of 2 words.
44
- - For each chunk, if a period (".") is found in any word (except if it’s the very last word),
45
- then split the chunk at that word (include words up to that word).
46
- - Otherwise, use the current chunk size.
47
- - For subsequent chunks, increase the chunk size by 2.
48
- - If there are fewer than the desired number of words for a full chunk, add all remaining words.
49
  """
50
  words = text.split()
51
  chunks = []
52
- chunk_size = 2
53
  start = 0
54
  while start < len(words):
55
- candidate_end = start + chunk_size
56
- if candidate_end > len(words):
57
- candidate_end = len(words)
58
- chunk_words = words[start:candidate_end]
59
- # Look for a period in any word except the last one.
60
- split_index = None
61
- for i in range(len(chunk_words) - 1):
62
- if '.' in chunk_words[i]:
63
- split_index = i
64
- break
65
- if split_index is not None:
66
- candidate_end = start + split_index + 1
67
- chunk_words = words[start:candidate_end]
68
- chunks.append(" ".join(chunk_words))
69
- start = candidate_end
70
- chunk_size += 2 # Increase the chunk size by 2 for the next iteration.
71
  return chunks
72
 
73
 
@@ -75,9 +61,11 @@ def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
75
  """
76
  Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
77
  """
 
78
  audio_np = audio_tensor.cpu().numpy()
79
  if audio_np.ndim > 1:
80
  audio_np = audio_np.flatten()
 
81
  audio_int16 = np.int16(audio_np * 32767)
82
  return audio_int16.tobytes()
83
 
@@ -91,9 +79,10 @@ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
91
  """
92
  Streaming TTS endpoint that returns a continuous WAV stream.
93
 
94
- This endpoint first yields a WAV header (with a dummy data length) and then yields raw PCM data
95
  for each text chunk as soon as it is generated.
96
  """
 
97
  chunks = custom_split_text(text)
98
  sample_rate = 24000
99
  num_channels = 1
@@ -105,13 +94,15 @@ def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
105
  yield header
106
  # Process and yield each chunk's PCM data.
107
  for i, chunk in enumerate(chunks):
108
- print(f"Processing chunk {i}: {chunk}")
109
  try:
110
  results = list(pipeline(chunk, voice=voice, speed=speed, split_pattern=None))
111
  for result in results:
112
  if result.audio is not None:
113
- print(f"Chunk {i}: Audio generated")
114
- yield audio_tensor_to_pcm_bytes(result.audio)
 
 
115
  else:
116
  print(f"Chunk {i}: No audio generated")
117
  except Exception as e:
@@ -130,6 +121,7 @@ def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
130
  Full TTS endpoint that synthesizes the entire text, concatenates the audio,
131
  and returns a complete WAV file.
132
  """
 
133
  results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
134
  audio_segments = []
135
  for result in results:
@@ -142,8 +134,10 @@ def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
142
  if not audio_segments:
143
  raise HTTPException(status_code=500, detail="No audio generated.")
144
 
 
145
  full_audio = np.concatenate(audio_segments)
146
 
 
147
  sample_rate = 24000
148
  num_channels = 1
149
  sample_width = 2 # 16-bit PCM -> 2 bytes per sample
@@ -162,51 +156,49 @@ def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
162
  def index():
163
  """
164
  HTML demo page for Kokoro TTS.
165
-
166
- Two playback methods are provided:
167
- - "Play Streaming TTS" sets the <audio> element's src to the streaming endpoint.
168
- - "Play Full TTS" sets the <audio> element's src to the full synthesis endpoint.
169
- The browser’s native playback handles streaming (progressive download) of the WAV data.
170
  """
171
  return """
172
- <!DOCTYPE html>
173
- <html>
174
- <head>
175
- <title>Kokoro TTS Demo</title>
176
- </head>
177
- <body>
178
- <h1>Kokoro TTS Demo</h1>
179
- <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
180
- <label for="voice">Voice:</label>
181
- <input type="text" id="voice" value="af_heart"><br>
182
- <label for="speed">Speed:</label>
183
- <input type="number" step="0.1" id="speed" value="1.0"><br><br>
184
- <button onclick="playStreaming()">Play Streaming TTS</button>
185
- <button onclick="playFull()">Play Full TTS</button>
186
- <br><br>
187
- <audio id="audioPlayer" controls autoplay></audio>
188
- <script>
189
- function playStreaming() {
190
- const text = document.getElementById('text').value;
191
- const voice = document.getElementById('voice').value;
192
- const speed = document.getElementById('speed').value;
193
- const audio = document.getElementById('audioPlayer');
194
- // Simply point the audio element to the streaming endpoint.
195
- audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
196
- audio.play();
197
- }
198
- function playFull() {
199
- const text = document.getElementById('text').value;
200
- const voice = document.getElementById('voice').value;
201
- const speed = document.getElementById('speed').value;
202
- const audio = document.getElementById('audioPlayer');
203
- // Simply point the audio element to the full synthesis endpoint.
204
- audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
205
- audio.play();
206
- }
207
- </script>
208
- </body>
209
- </html>
210
  """
211
 
212
 
@@ -215,4 +207,5 @@ def index():
215
  # ------------------------------------------------------------------------------
216
  if __name__ == "__main__":
217
  import uvicorn
 
218
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 
15
  # ------------------------------------------------------------------------------
16
  # Global Pipeline Instance
17
  # ------------------------------------------------------------------------------
18
+ # Create one pipeline instance for the entire app.
19
  pipeline = KPipeline(lang_code="a")
20
 
21
 
 
27
  """
28
  Generate a WAV header for streaming.
29
  Since we don't know the final audio size, we set the data chunk size to a large dummy value.
30
+ This header is sent only once at the start of the stream.
31
  """
32
  bits_per_sample = sample_width * 8
33
  byte_rate = sample_rate * num_channels * sample_width
34
  block_align = num_channels * sample_width
35
+ # total file size = 36 + data_size (header is 44 bytes total)
36
+ total_size = 36 + data_size
37
  header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
38
  fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
39
  data_chunk_header = struct.pack('<4sI', b'data', data_size)
 
42
 
43
def custom_split_text(text: str) -> list:
    """
    Split *text* into whitespace-delimited chunks of doubling word count.

    The first chunk holds 1 word, the next 2, then 4, 8, ... so that the
    TTS stream can start playing quickly while later chunks carry more text.

    Args:
        text: Input text to split.

    Returns:
        List of chunk strings; an empty list for empty or blank input.
    """
    words = text.split()
    pieces = []
    size = 1
    pos = 0
    total = len(words)
    while pos < total:
        # Slicing past the end is safe: the final chunk simply takes
        # whatever words remain.
        pieces.append(" ".join(words[pos:pos + size]))
        pos += size
        size *= 2  # next chunk is twice as large
    return pieces
58
 
59
 
 
def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.

    Args:
        audio_tensor: Audio samples; any shape is accepted and flattened.

    Returns:
        Raw signed 16-bit PCM data as bytes (native byte order).
    """
    # Ensure the tensor is on CPU before converting to a NumPy array.
    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    # Clip first: samples outside [-1, 1] would otherwise wrap around when
    # cast to int16 and produce loud clicks/corruption in the audio stream.
    audio_np = np.clip(audio_np, -1.0, 1.0)
    # Scale to the int16 range.
    audio_int16 = np.int16(audio_np * 32767)
    return audio_int16.tobytes()
 
 
79
  """
80
  Streaming TTS endpoint that returns a continuous WAV stream.
81
 
82
+ The endpoint first yields a WAV header (with a dummy length) then yields raw PCM data
83
  for each text chunk as soon as it is generated.
84
  """
85
+ # Split the input text using the custom doubling strategy.
86
  chunks = custom_split_text(text)
87
  sample_rate = 24000
88
  num_channels = 1
 
94
  yield header
95
  # Process and yield each chunk's PCM data.
96
  for i, chunk in enumerate(chunks):
97
+ print(f"Processing chunk {i}: {chunk}") # Debugging
98
  try:
99
  results = list(pipeline(chunk, voice=voice, speed=speed, split_pattern=None))
100
  for result in results:
101
  if result.audio is not None:
102
+ print(f"Chunk {i}: Audio generated") # Debugging
103
+ pcm_bytes = audio_tensor_to_pcm_bytes(result.audio)
104
+ for i in range(0, len(pcm_bytes), 100):
105
+ yield pcm_bytes[i:i + chunk_size]
106
  else:
107
  print(f"Chunk {i}: No audio generated")
108
  except Exception as e:
 
121
  Full TTS endpoint that synthesizes the entire text, concatenates the audio,
122
  and returns a complete WAV file.
123
  """
124
+ # Use newline-based splitting via the pipeline's split_pattern.
125
  results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
126
  audio_segments = []
127
  for result in results:
 
134
  if not audio_segments:
135
  raise HTTPException(status_code=500, detail="No audio generated.")
136
 
137
+ # Concatenate all audio segments.
138
  full_audio = np.concatenate(audio_segments)
139
 
140
+ # Write the concatenated audio to an in-memory WAV file.
141
  sample_rate = 24000
142
  num_channels = 1
143
  sample_width = 2 # 16-bit PCM -> 2 bytes per sample
 
156
  def index():
157
  """
158
  HTML demo page for Kokoro TTS.
159
+
160
+ This page provides a simple UI to enter text, choose a voice and speed,
161
+ and play synthesized audio from both the streaming and full endpoints.
 
 
162
  """
163
  return """
164
+ <!DOCTYPE html>
165
+ <html>
166
+ <head>
167
+ <title>Kokoro TTS Demo</title>
168
+ </head>
169
+ <body>
170
+ <h1>Kokoro TTS Demo</h1>
171
+ <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
172
+ <label for="voice">Voice:</label>
173
+ <input type="text" id="voice" value="af_heart"><br>
174
+ <label for="speed">Speed:</label>
175
+ <input type="number" step="0.1" id="speed" value="1.0"><br><br>
176
+ <button onclick="playStreaming()">Play Streaming TTS</button>
177
+ <button onclick="playFull()">Play Full TTS</button>
178
+ <br><br>
179
+ <audio id="audio" controls autoplay></audio>
180
+ <script>
181
+ function playStreaming() {
182
+ const text = document.getElementById('text').value;
183
+ const voice = document.getElementById('voice').value;
184
+ const speed = document.getElementById('speed').value;
185
+ const audio = document.getElementById('audio');
186
+ // Set the audio element's source to the streaming endpoint.
187
+ audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
188
+ audio.play();
189
+ }
190
+ function playFull() {
191
+ const text = document.getElementById('text').value;
192
+ const voice = document.getElementById('voice').value;
193
+ const speed = document.getElementById('speed').value;
194
+ const audio = document.getElementById('audio');
195
+ // Set the audio element's source to the full TTS endpoint.
196
+ audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
197
+ audio.play();
198
+ }
199
+ </script>
200
+ </body>
201
+ </html>
202
  """
203
 
204
 
 
207
  # ------------------------------------------------------------------------------
208
  if __name__ == "__main__":
209
  import uvicorn
210
+
211
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)