import io
import struct
import os
import time
import json

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, Response, HTMLResponse
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware

from misaki import en, espeak

from onnxruntime import InferenceSession
from huggingface_hub import snapshot_download
from scipy.io.wavfile import write as write_wav

# ------------------------------------------------------------------------------
# Load configuration and set up vocabulary
# ------------------------------------------------------------------------------
config_file_path = 'config.json'  # must already exist locally; the snapshot download below does not fetch it
with open(config_file_path, 'r') as f:
    config = json.load(f)
phoneme_vocab = config['vocab']
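# `phoneme_vocab` maps each phoneme character emitted by Misaki to an integer
# token id; the exact ids are defined in the model's config.json.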

# ------------------------------------------------------------------------------
# Download the model and voice files from Hugging Face Hub
# ------------------------------------------------------------------------------
model_repo = "onnx-community/Kokoro-82M-v1.0-ONNX"
model_name = "onnx/model_q4.onnx"  # swap in "onnx/model.onnx" for the full-precision model
voice_file_pattern = "*.bin"
local_dir = "."
snapshot_download(
    repo_id=model_repo,
    allow_patterns=[model_name, voice_file_pattern],
    local_dir=local_dir
)

# ------------------------------------------------------------------------------
# Load the ONNX model
# ------------------------------------------------------------------------------
model_path = os.path.join(local_dir, model_name)
sess = InferenceSession(model_path)

# ------------------------------------------------------------------------------
# Create the FastAPI app with GZip compression
# ------------------------------------------------------------------------------
app = FastAPI(
    title="Kokoro TTS FastAPI",
    middleware=[Middleware(GZipMiddleware, compresslevel=9)]
)

# ------------------------------------------------------------------------------
# Helper Functions
# ------------------------------------------------------------------------------

def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Generate a WAV header for streaming. Since we do not know the final audio size,
    a large dummy value is used for the data chunk size.
    """
    bits_per_sample = sample_width * 8
    byte_rate = sample_rate * num_channels * sample_width
    block_align = num_channels * sample_width
    total_size = 36 + data_size  # RIFF size field: total file size minus the 8-byte RIFF header
    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
    data_chunk_header = struct.pack('<4sI', b'data', data_size)
    return header + fmt_chunk + data_chunk_header

stream_header = generate_wav_header(24000, 1, 2)
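
# Sanity check of the header layout (44 bytes total for a canonical PCM WAV):
#   "RIFF" (4) + size (4) + "WAVE" (4) + "fmt " chunk (8 + 16) + "data" header (8)
#   len(generate_wav_header(24000, 1, 2)) == 44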

def custom_split_text(text: str) -> list:
    """
    Custom splitting strategy:
      - Start with a chunk size of 2 words.
      - For each chunk, if a period (".") is found in any word (except the very last word),
        then split at that word (including it).
      - Otherwise, use the current chunk size.
      - Increase the chunk size by 2 for each subsequent chunk.
      - If there are fewer than the desired number of words remaining, include all of them.
    """
    words = text.split()
    chunks = []
    chunk_size = 2
    start = 0
    while start < len(words):
        candidate_end = start + chunk_size
        if candidate_end > len(words):
            candidate_end = len(words)
        chunk_words = words[start:candidate_end]
        chunks.append(" ".join(chunk_words))
        start = candidate_end
        if chunk_size < 100:
            chunk_size += 2
    return chunks
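
# For example, with the growth schedule above:
#   custom_split_text("one two three four five six seven")
#   -> ["one two", "three four five six", "seven"]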


def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a torch.FloatTensor (values in [-1, 1]) to raw 16-bit PCM bytes.
    """
    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    audio_int16 = np.int16(audio_np * 32767)
    return audio_int16.tobytes()
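
# For example: audio_tensor_to_pcm_bytes(torch.tensor([0.0, 1.0, -1.0]))
# returns b"\x00\x00\xff\x7f\x01\x80" (little-endian int16: 0, 32767, -32767).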


def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
    """
    Convert a torch.FloatTensor to Opus-encoded bytes.
    Requires the 'opuslib' package: pip install opuslib
    """
    try:
        import opuslib
    except ImportError:
        raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")

    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    audio_int16 = np.int16(audio_np * 32767)

    encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP)
    frame_size = int(sample_rate * 0.020)  # 20 ms frame
    encoded_data = b''
    for i in range(0, len(audio_int16), frame_size):
        frame = audio_int16[i:i + frame_size]
        if len(frame) < frame_size:
            frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
        encoded_frame = encoder.encode(frame.tobytes(), frame_size)
        encoded_data += encoded_frame
    return encoded_data

# Grapheme-to-phoneme pipeline: Misaki G2P for American English, with an
# espeak-based fallback (here using a British voice) for words Misaki
# cannot handle natively.
fbs = espeak.EspeakFallback(british=True)
g2p = en.G2P(trf=False, british=False, fallback=fbs)

def tokenizer(text: str):
    """
    Convert text to a list of integer phoneme tokens using the global
    vocabulary. Phonemes missing from the vocabulary are silently dropped.
    """
    phonemes_string, _ = g2p(text)
    print(f"{text}     {phonemes_string}")
    tokens = [phoneme_vocab[ph] for ph in phonemes_string if ph in phoneme_vocab]
    return tokens
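
# For example (token ids depend on the vocab loaded from config.json):
#   tokenizer("Hello world")  # -> a list of ints such as [50, 83, ...]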


# ------------------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------------------

@app.get("/tts/streaming", summary="Streaming TTS")
def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
    """
    Streaming TTS endpoint.
    
    This endpoint splits the input text into chunks (using the doubling strategy),
    then for each chunk:
      - For the first chunk, a 0 is prepended.
      - For subsequent chunks, the first token is set to the last token from the previous chunk.
      - For the final chunk, a 0 is appended.
    
    The audio for each chunk is generated immediately and streamed to the client.
    """
    chunks = custom_split_text(text)

    # Load the voice/style file (must be present in voices/{voice}.bin)
    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    if not os.path.exists(voice_path):
        raise HTTPException(status_code=404, detail="Voice file not found")
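    # The voice file is a table of 256-dim style vectors; the code below picks
    # one row per padded token count, hence the reshape to (rows, 1, 256).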
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)

    def audio_generator():
        # If outputting a WAV stream, yield a WAV header once.
        if format.lower() == "wav":
            yield stream_header

        for i, chunk in enumerate(chunks):
            # Convert the chunk text to tokens.
            chunk_tokens = tokenizer(chunk)

            # Pad the chunk with the boundary token 0 on both sides and wrap
            # it as a batch of one sequence for the ONNX model.
            tokens_to_send = [0] + chunk_tokens + [0]
            final_token = [tokens_to_send]

            # Select the style vector for this input length; index by the
            # padded token count (chunk tokens + 2 boundary zeros).
            style_index = len(chunk_tokens) + 2
            if style_index >= len(voices):
                style_index = len(voices) - 1  # Fallback if index is out-of-bounds.
            ref_s = voices[style_index]

            # Prepare the speed parameter.
            speed_param = np.ones(1, dtype=np.float32) * speed

            # Run the model (ONNX inference) for this chunk.
            try:
                start_time = time.time()
                audio_output = sess.run(None, {
                    "input_ids": final_token,
                    "style": ref_s,
                    "speed": speed_param,
                })[0]
                print(f"Chunk {i} inference time: {time.time() - start_time:.3f}s")
            except Exception as e:
                print(f"Error processing chunk {i}: {e}")
                # On error, fall back to one second of silence (24,000 samples at 24 kHz).
                audio_output = np.zeros((24000,), dtype=np.float32)

            # Convert float32 samples in [-1, 1] to int16 PCM. The slice trims
            # 6,000 samples (250 ms) from the head and 3,000 samples (125 ms)
            # from the tail of each chunk, presumably to mask padding artifacts
            # at the chunk boundaries.
            audio_int16 = (audio_output * 32767).astype(np.int16).flatten()[6000:-3000]

            # Yield the encoded audio chunk.
            yield audio_int16.tobytes()

    # The streaming path always emits WAV-style output (header + raw PCM);
    # the `format` parameter only controls whether the header is written.
    media_type = "audio/wav"
    return StreamingResponse(
        audio_generator(),
        media_type=media_type,
        headers={"Cache-Control": "no-cache"},
    )
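
# Example request (the server runs on port 7860, per the __main__ block below):
#   curl "http://localhost:7860/tts/streaming?text=Hello%20world&voice=af_heart" -o out.wav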


@app.get("/tts/full", summary="Full TTS")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
    """
    Full TTS endpoint that synthesizes the entire text and returns a complete WAV or Opus file.
    """
    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    if not os.path.exists(voice_path):
        raise HTTPException(status_code=404, detail="Voice file not found")
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)

    tokens = tokenizer(text)
    # Select the style vector for this input length, clamped to the table size.
    ref_s = voices[min(len(tokens), len(voices) - 1)]
    final_token = [[0, *tokens, 0]]

    start_time = time.time()
    audio = sess.run(None, {
        "input_ids": final_token,
        "style": ref_s,
        "speed": np.ones(1, dtype=np.float32) * speed,
    })[0]
    print(f"Full TTS inference time: {time.time()-start_time:.3f}s")
    
    # Convert to int16 PCM.
    audio = (audio * 32767).astype(np.int16).flatten()
    
    if format.lower() == "wav":
        wav_io = io.BytesIO()
        write_wav(wav_io, 24000, audio)
        wav_io.seek(0)
        return Response(content=wav_io.read(), media_type="audio/wav")
    elif format.lower() == "opus":
        opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(audio.astype(np.float32)/32767), sample_rate=24000)
        return Response(content=opus_data, media_type="audio/opus")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
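
# Example requests:
#   curl "http://localhost:7860/tts/full?text=Hello%20world" -o full.wav
#   curl "http://localhost:7860/tts/full?text=Hello%20world&format=opus" -o full.opus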


@app.get("/", response_class=HTMLResponse)
def index():
    """
    HTML demo page for Kokoro TTS.
    """
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Kokoro TTS Demo</title>
    </head>
    <body>
        <h1>Kokoro TTS Demo</h1>
        <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
        <label for="voice">Voice:</label>
        <input type="text" id="voice" value="af_heart"><br>
        <label for="speed">Speed:</label>
        <input type="number" step="0.1" id="speed" value="1.0"><br>
        <label for="format">Format:</label>
        <select id="format">
            <option value="wav">WAV</option>
            <option value="opus" selected>Opus</option>
        </select><br><br>
        <button onclick="playStreaming()">Play Streaming TTS</button>
        <button onclick="playFull()">Play Full TTS</button>
        <br><br>
        <audio id="audio" controls autoplay></audio>
        <script>
            function playStreaming() {
                const text = document.getElementById('text').value;
                const voice = document.getElementById('voice').value;
                const speed = document.getElementById('speed').value;
                const format = document.getElementById('format').value;
                const audio = document.getElementById('audio');
                audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
                audio.play();
            }
            function playFull() {
                const text = document.getElementById('text').value;
                const voice = document.getElementById('voice').value;
                const speed = document.getElementById('speed').value;
                const format = document.getElementById('format').value;
                const audio = document.getElementById('audio');
                audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
                audio.play();
            }
        </script>
    </body>
    </html>
    """


# ------------------------------------------------------------------------------
# Run the app with: uvicorn app:app --reload
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)