import io
import struct
import wave

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, Response, HTMLResponse
from kokoro import KPipeline

app = FastAPI(title="Kokoro TTS FastAPI")

# ------------------------------------------------------------------------------
# Global Pipeline Instance
# ------------------------------------------------------------------------------
# Create one pipeline instance for the entire app.
pipeline = KPipeline(lang_code="a")


# ------------------------------------------------------------------------------
# Helper Functions
# ------------------------------------------------------------------------------

def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int,
                        data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Generate a WAV header for streaming.

    Since we don't know the final audio size, we set the data chunk size to a
    large dummy value. This header is sent only once at the start of the stream.
    """
    bits_per_sample = sample_width * 8
    byte_rate = sample_rate * num_channels * sample_width
    block_align = num_channels * sample_width
    # Total file size = 36 + data_size (the header itself is 44 bytes).
    total_size = 36 + data_size
    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels,
                            sample_rate, byte_rate, block_align, bits_per_sample)
    data_chunk_header = struct.pack('<4sI', b'data', data_size)
    return header + fmt_chunk + data_chunk_header


def custom_split_text(text: str) -> list:
    """
    Custom splitting: split text into chunks where each chunk doubles in size.

    The first chunk is one word, the next two words, then four, and so on, so
    the first audio arrives quickly. For example, a 10-word sentence is split
    into chunks of 1, 2, 4, and 3 words.
    """
    words = text.split()
    chunks = []
    chunk_size = 1
    start = 0
    while start < len(words):
        end = start + chunk_size
        chunk = " ".join(words[start:end])
        chunks.append(chunk)
        start = end
        chunk_size *= 2  # Double the chunk size for the next iteration.
    return chunks


def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
    """
    # Ensure the tensor is on the CPU and flatten if necessary.
    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    # Clip to [-1, 1] first so out-of-range samples don't wrap around when
    # scaled to the int16 range.
    audio_int16 = np.int16(np.clip(audio_np, -1.0, 1.0) * 32767)
    return audio_int16.tobytes()


# ------------------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------------------

@app.get("/tts/streaming", summary="Streaming TTS")
def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
    """
    Streaming TTS endpoint that returns a continuous WAV stream.

    The endpoint first yields a WAV header (with a dummy length), then yields
    raw PCM data for each text chunk as soon as it is generated.
    """
    # Split the input text using the custom doubling strategy.
    chunks = custom_split_text(text)
    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM

    def audio_generator():
        # Yield the WAV header first.
        header = generate_wav_header(sample_rate, num_channels, sample_width)
        yield header
        # Process and yield each chunk's PCM data.
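        # Each chunk is synthesized independently, so the client can begin
        # playback as soon as the first (shortest) chunk's PCM arrives.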
        for i, chunk in enumerate(chunks):
            print(f"Processing chunk {i}: {chunk}")  # Debugging
            try:
                results = list(pipeline(chunk, voice=voice, speed=speed,
                                        split_pattern=None))
                for result in results:
                    if result.audio is not None:
                        print(f"Chunk {i}: Audio generated")  # Debugging
                        pcm_bytes = audio_tensor_to_pcm_bytes(result.audio)
                        # Yield fixed-size slices; use a name other than `i`
                        # so the chunk index above is not shadowed.
                        slice_size = 4096  # bytes per yielded slice
                        for offset in range(0, len(pcm_bytes), slice_size):
                            yield pcm_bytes[offset:offset + slice_size]
                    else:
                        print(f"Chunk {i}: No audio generated")
            except Exception as e:
                print(f"Error processing chunk {i}: {e}")

    return StreamingResponse(
        audio_generator(),
        media_type="audio/wav",
        headers={"Cache-Control": "no-cache"},
    )


@app.get("/tts/full", summary="Full TTS")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
    """
    Full TTS endpoint that synthesizes the entire text, concatenates the audio,
    and returns a complete WAV file.
    """
    # Use newline-based splitting via the pipeline's split_pattern.
    results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
    audio_segments = []
    for result in results:
        if result.audio is not None:
            audio_np = result.audio.cpu().numpy()
            if audio_np.ndim > 1:
                audio_np = audio_np.flatten()
            audio_segments.append(audio_np)

    if not audio_segments:
        raise HTTPException(status_code=500, detail="No audio generated.")

    # Concatenate all audio segments.
    full_audio = np.concatenate(audio_segments)

    # Write the concatenated audio to an in-memory WAV file.
    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM -> 2 bytes per sample
    wav_io = io.BytesIO()
    with wave.open(wav_io, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)
        full_audio_int16 = np.int16(np.clip(full_audio, -1.0, 1.0) * 32767)
        wav_file.writeframes(full_audio_int16.tobytes())
    wav_io.seek(0)
    return Response(content=wav_io.read(), media_type="audio/wav")


@app.get("/", response_class=HTMLResponse)
def index():
    """
    HTML demo page for Kokoro TTS.

    This page provides a simple UI to enter text, choose a voice and speed,
    and play synthesized audio from both the streaming and full endpoints.
    """
    return """
<!DOCTYPE html>
<html>
<head>
    <title>Kokoro TTS Demo</title>
</head>
<body>
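    <!-- Simple demo UI: enter text, pick a voice and speed, then play the
         result from either the streaming or the full endpoint. -->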

    <h1>Kokoro TTS Demo</h1>

    <textarea id="text" rows="4" cols="60">Hello! This is a test of Kokoro text-to-speech.</textarea>
    <br>
    <label>Voice: <input id="voice" value="af_heart"></label>
    <label>Speed: <input id="speed" type="number" step="0.1" value="1.0"></label>
    <br>
    <button onclick="play('streaming')">Play (streaming)</button>
    <button onclick="play('full')">Play (full)</button>
    <br>
    <audio id="player" controls></audio>

    <script>
        function play(endpoint) {
            const text = encodeURIComponent(document.getElementById("text").value);
            const voice = encodeURIComponent(document.getElementById("voice").value);
            const speed = encodeURIComponent(document.getElementById("speed").value);
            const player = document.getElementById("player");
            player.src = "/tts/" + endpoint +
                "?text=" + text + "&voice=" + voice + "&speed=" + speed;
            player.play();
        }
    </script>
</body>
</html>
""" # ------------------------------------------------------------------------------ # Run with: uvicorn app:app --reload # ------------------------------------------------------------------------------ if __name__ == "__main__": import uvicorn uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)