import struct

import numpy as np
import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware

from kokoro import KPipeline, StreamKPipeline
from kokoro.model import KModel

app = FastAPI(
    title="Kokoro TTS FastAPI",
    middleware=[
        Middleware(GZipMiddleware, compresslevel=9)  # Add GZip compression
    ]
)

# ------------------------------------------------------------------------------
# Global Pipeline Instance
# ------------------------------------------------------------------------------
# Create one pipeline instance for the entire app.
model = KModel()  # Or however you initialize/load your model
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# pipeline = KPipeline(lang_code="a", model=model)  # non-streaming alternative
voice = "af_heart"
speed = 1.0

pipeline = StreamKPipeline(lang_code="a", model=model, voice=voice, device=device, speed=speed)

# ------------------------------------------------------------------------------
# Helper Functions
# ------------------------------------------------------------------------------

def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Generate a WAV header for streaming.
    Since we don't know the final audio size, we set the data chunk size to a large dummy value.
    This header is sent only once at the start of the stream.
    """
    bits_per_sample = sample_width * 8
    byte_rate = sample_rate * num_channels * sample_width
    block_align = num_channels * sample_width
    # RIFF chunk size field: 36 remaining header bytes + data size
    # (the full header emitted below totals 44 bytes).
    total_size = 36 + data_size
    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
    data_chunk_header = struct.pack('<4sI', b'data', data_size)
    return header + fmt_chunk + data_chunk_header
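
# Illustrative check: the three chunks above (RIFF, fmt, data) always total
# 44 bytes, independent of the dummy data size:
#
#   assert len(generate_wav_header(24000, 1, 2)) == 44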


def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
    """
    # Ensure the tensor is detached, on the CPU, and one-dimensional.
    audio_np = audio_tensor.detach().cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    # Clip to [-1, 1] before scaling so out-of-range samples cannot overflow int16.
    audio_int16 = np.int16(np.clip(audio_np, -1.0, 1.0) * 32767)
    return audio_int16.tobytes()
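
# Illustrative usage: one second of a 440 Hz sine at 24 kHz yields
# 48,000 bytes (24,000 samples * 2 bytes per 16-bit sample):
#
#   import math
#   t = torch.linspace(0, 1, 24000)
#   pcm = audio_tensor_to_pcm_bytes(torch.sin(2 * math.pi * 440 * t))
#   assert len(pcm) == 48000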



# ------------------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------------------

@app.get("/tts/streaming", summary="Streaming TTS")
def tts_streaming(text: str):
    """
    Streaming TTS endpoint that returns a continuous audio stream.

    The endpoint yields a WAV header (with a dummy length) for WAV,
    then yields encoded audio data for each phoneme as soon as it is generated.
    """
    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM

    def audio_generator():
        # Yield the WAV header first.
        header = generate_wav_header(sample_rate, num_channels, sample_width)
        yield header

        # Process and yield each audio chunk as soon as the pipeline produces it.
        try:
            for result in pipeline(text):
                if result.audio is not None:
                    yield audio_tensor_to_pcm_bytes(result.audio)
                else:
                    print("No audio generated for phoneme")
        except Exception as e:
            print(f"Error while streaming: {e}")
            # Nothing more can be sent after a mid-stream error; end the
            # generator gracefully so the client keeps the audio received so far.
            yield b''

    media_type = "audio/wav"

    return StreamingResponse(
        audio_generator(),
        media_type=media_type,
        headers={"Cache-Control": "no-cache"},
    )
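
# Illustrative client for the streaming endpoint (assumes the third-party
# `requests` package and a server running locally on port 7860):
#
#   import requests
#   with requests.get(
#       "http://localhost:7860/tts/streaming",
#       params={"text": "Hello world"},
#       stream=True,
#   ) as resp, open("out.wav", "wb") as f:
#       for chunk in resp.iter_content(chunk_size=4096):
#           f.write(chunk)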

@app.get("/", response_class=HTMLResponse)
def index():
    """
    HTML demo page for Kokoro TTS.

    This page provides a simple UI to enter text and play synthesized audio from the streaming endpoint.
    """
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Kokoro TTS Demo</title>
    </head>
    <body>
        <h1>Kokoro TTS Demo</h1>
        <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br><br>
        <button onclick="playStreaming()">Play Streaming TTS</button>
        <br><br>
        <audio id="audio" controls autoplay></audio>
        <script>
            function playStreaming() {
                const text = document.getElementById('text').value;
                const audio = document.getElementById('audio');
                audio.src = `/tts/streaming?text=${encodeURIComponent(text)}`;
                // <audio> elements have no 'type' property; the response's
                // Content-Type header already identifies the stream as audio/wav.
                audio.play();
            }
        </script>
    </body>
    </html>
    """


# ------------------------------------------------------------------------------
# Run with: uvicorn app:app --reload
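# To fetch the stream from a shell (hypothetical local URL):
#   curl -N "http://localhost:7860/tts/streaming?text=Hello%20world" -o out.wav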
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)