# Spaces:
# Sleeping
# Sleeping
import io
import logging
import re
import struct
import wave

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, Response, StreamingResponse
from kokoro import KPipeline
# The FastAPI application object for this module.
app = FastAPI(title="Kokoro TTS FastAPI")
# ------------------------------------------------------------------------------
# Global Pipeline Instance
# ------------------------------------------------------------------------------
# Create one pipeline instance for the entire app. It is built once, when this
# module is imported, and every synthesis function below reuses it instead of
# constructing a new pipeline per request.
# NOTE(review): lang_code="a" presumably selects American English (per the
# kokoro README) — confirm before serving other languages.
# NOTE(review): concurrent requests share this single instance — confirm that
# KPipeline is safe to call from several threads at once.
pipeline = KPipeline(lang_code="a")
# ------------------------------------------------------------------------------ | |
# Helper Functions | |
# ------------------------------------------------------------------------------ | |
def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
    """Build the 44-byte canonical PCM WAV header.

    The header consists of the 12-byte RIFF descriptor, the 24-byte ``fmt ``
    sub-chunk, and the 8-byte ``data`` sub-chunk header. For streaming, the
    final audio length is not known in advance, so ``data_size`` defaults to
    a large placeholder value; the header is sent once, before any PCM data.

    Args:
        sample_rate: Frames per second.
        num_channels: Number of interleaved channels.
        sample_width: Bytes per sample (2 for 16-bit PCM).
        data_size: Length of the PCM payload in bytes.

    Returns:
        The packed little-endian header bytes.
    """
    frame_bytes = num_channels * sample_width  # bytes per frame (block align)
    fields = (
        b"RIFF",
        36 + data_size,  # RIFF size: everything after these first 8 bytes
        b"WAVE",
        b"fmt ",
        16,  # size of the fmt sub-chunk body for plain PCM
        1,  # audio format 1 = uncompressed PCM
        num_channels,
        sample_rate,
        sample_rate * frame_bytes,  # byte rate
        frame_bytes,  # block align
        sample_width * 8,  # bits per sample
        b"data",
        data_size,
    )
    return struct.pack("<4sI4s4sIHHIIHH4sI", *fields)
def custom_split_text(text: str) -> list:
    """Split ``text`` on whitespace into word groups of doubling size.

    Group ``k`` (0-based) holds up to ``2**k`` words, so the groups contain
    1, 2, 4, 8, ... words, and the last group holds whatever words remain.
    Presumably the small leading groups let the first audio be produced
    quickly.

    Returns:
        A list of space-joined word groups; empty when ``text`` has no words.
    """
    tokens = text.split()
    # Group k spans word indices [2**k - 1, 2**(k+1) - 1). A group exists while
    # its start index is < len(tokens), i.e. for k < len(tokens).bit_length().
    return [
        " ".join(tokens[(1 << k) - 1:(1 << (k + 1)) - 1])
        for k in range(len(tokens).bit_length())
    ]
def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """Convert floating-point audio samples to raw little-endian 16-bit PCM.

    Samples are expected in [-1, 1], but that is not guaranteed for
    synthesized audio. Values outside the range are clipped first, because
    casting an out-of-range float to int16 would otherwise wrap around and be
    heard as a loud click.

    Args:
        audio_tensor: A torch tensor on any device, of any float dtype and
            shape (it may still require grad), or any array-like accepted by
            ``np.asarray``. Multi-dimensional input is flattened in row-major
            order.

    Returns:
        Raw PCM bytes, two bytes per sample, little-endian as WAV requires.
    """
    if isinstance(audio_tensor, torch.Tensor):
        # detach() allows grad-tracking tensors; float() covers dtypes such as
        # bfloat16 that NumPy cannot represent.
        audio_np = audio_tensor.detach().cpu().float().numpy()
    else:
        audio_np = np.asarray(audio_tensor, dtype=np.float32)
    audio_np = audio_np.reshape(-1)
    clipped = np.clip(audio_np, -1.0, 1.0)
    # Scale to the int16 range. astype truncates toward zero, matching the
    # previous np.int16(...) conversion; "<i2" pins little-endian byte order.
    return (clipped * 32767).astype("<i2").tobytes()
# ------------------------------------------------------------------------------ | |
# Endpoints | |
# ------------------------------------------------------------------------------ | |
@app.get("/tts/streaming", summary="Streaming TTS")
def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
    """Stream synthesized speech as a continuous WAV byte stream.

    The response first yields a WAV header with a placeholder data length,
    because the total length is unknown up front. It then yields raw 16-bit
    mono PCM at 24 kHz for each text chunk as soon as that chunk has been
    synthesized. Text is split with ``custom_split_text`` (1, 2, 4, ... words
    per chunk), so the first chunk is short.

    Args:
        text: Text to synthesize.
        voice: Voice name passed to the pipeline.
        speed: Speed value passed to the pipeline.

    Raises:
        HTTPException: 400 if ``text`` contains no words.

    A chunk that fails to synthesize is logged and skipped. Once the header
    has been sent, the HTTP status can no longer change, so the stream
    continues with the next chunk instead of aborting.
    """
    logger = logging.getLogger(__name__)

    chunks = custom_split_text(text)
    if not chunks:
        raise HTTPException(status_code=400, detail="Text must not be empty.")

    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM
    # Write about 100 ms of audio per yield. Using a whole number of frames
    # keeps every write sample-aligned.
    bytes_per_write = (sample_rate // 10) * num_channels * sample_width

    def audio_generator():
        # The header goes out exactly once, before any PCM data.
        yield generate_wav_header(sample_rate, num_channels, sample_width)
        for chunk_index, chunk in enumerate(chunks):
            logger.debug("Processing chunk %d: %s", chunk_index, chunk)
            try:
                # Iterate lazily so audio is forwarded as soon as each result exists.
                for result in pipeline(chunk, voice=voice, speed=speed, split_pattern=None):
                    if result.audio is None:
                        logger.debug("Chunk %d: no audio generated", chunk_index)
                        continue
                    logger.debug("Chunk %d: audio generated", chunk_index)
                    pcm_bytes = audio_tensor_to_pcm_bytes(result.audio)
                    for offset in range(0, len(pcm_bytes), bytes_per_write):
                        yield pcm_bytes[offset:offset + bytes_per_write]
            except Exception:
                # Best effort: log with traceback and move on to the next chunk.
                logger.exception("Error processing chunk %d", chunk_index)

    return StreamingResponse(
        audio_generator(),
        media_type="audio/wav",
        headers={"Cache-Control": "no-cache"},
    )
@app.get("/tts/full", summary="Full TTS")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
    """Synthesize the whole text and return it as one complete WAV file.

    The pipeline splits the input on runs of newlines. Every audio segment it
    yields is flattened and concatenated, then written as a 24 kHz, mono,
    16-bit PCM WAV response.

    Args:
        text: Text to synthesize.
        voice: Voice name passed to the pipeline.
        speed: Speed value passed to the pipeline.

    Raises:
        HTTPException: 400 if ``text`` is blank; 500 if the pipeline produced
            no audio.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty.")

    # Newline-based splitting is delegated to the pipeline via split_pattern.
    audio_segments = []
    for result in pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"):
        if result.audio is not None:
            # reshape(-1) flattens any rank, including 0-d, to 1-D for concatenation.
            audio_segments.append(result.audio.detach().cpu().numpy().reshape(-1))
    if not audio_segments:
        raise HTTPException(status_code=500, detail="No audio generated.")

    full_audio = np.concatenate(audio_segments)
    # Clip before scaling. Out-of-range samples would otherwise wrap around
    # when cast to int16 and be heard as loud clicks.
    full_audio_int16 = (np.clip(full_audio, -1.0, 1.0) * 32767).astype(np.int16)

    # Write the concatenated audio to an in-memory WAV file.
    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM -> 2 bytes per sample
    wav_io = io.BytesIO()
    with wave.open(wav_io, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(full_audio_int16.tobytes())
    return Response(content=wav_io.getvalue(), media_type="audio/wav")
@app.get("/", response_class=HTMLResponse)
def index():
    """
    HTML demo page for Kokoro TTS.

    The page provides a simple UI to enter text, choose a voice and speed,
    and play synthesized audio from the /tts/streaming and /tts/full
    endpoints. Text and voice are URL-encoded in the page's script before
    being placed in the query string.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<title>Kokoro TTS Demo</title>
</head>
<body>
<h1>Kokoro TTS Demo</h1>
<textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
<label for="voice">Voice:</label>
<input type="text" id="voice" value="af_heart"><br>
<label for="speed">Speed:</label>
<input type="number" step="0.1" id="speed" value="1.0"><br><br>
<button onclick="playStreaming()">Play Streaming TTS</button>
<button onclick="playFull()">Play Full TTS</button>
<br><br>
<audio id="audio" controls autoplay></audio>
<script>
function playStreaming() {
const text = document.getElementById('text').value;
const voice = document.getElementById('voice').value;
const speed = document.getElementById('speed').value;
const audio = document.getElementById('audio');
// Set the audio element's source to the streaming endpoint.
audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
audio.play();
}
function playFull() {
const text = document.getElementById('text').value;
const voice = document.getElementById('voice').value;
const speed = document.getElementById('speed').value;
const audio = document.getElementById('audio');
// Set the audio element's source to the full TTS endpoint.
audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
audio.play();
}
</script>
</body>
</html>
"""
# ------------------------------------------------------------------------------ | |
# Run with: uvicorn app:app --reload | |
# ------------------------------------------------------------------------------ | |
if __name__ == "__main__":
    # uvicorn is imported only when this file is executed as a script.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",  # listen on all network interfaces
        "port": 7860,
        "reload": True,  # restart the server when source files change
    }
    # NOTE(review): "app:app" assumes this file is saved as app.py — confirm.
    uvicorn.run("app:app", **server_options)