Spaces:
Sleeping
Sleeping
import io | |
import re | |
import wave | |
import struct | |
import numpy as np | |
import torch | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import StreamingResponse, Response, HTMLResponse | |
from kokoro import KPipeline | |
# ASGI application object; uvicorn serves it as "<module>:app".
app = FastAPI(title="Kokoro TTS FastAPI")

# ------------------------------------------------------------------------------
# Global Pipeline Instance
# ------------------------------------------------------------------------------
# A single KPipeline is constructed at import time and shared by every endpoint,
# so it is built once rather than per request. lang_code="a" is forwarded to
# KPipeline as-is (presumably American English — confirm against the kokoro docs).
pipeline = KPipeline(lang_code="a")
# ------------------------------------------------------------------------------ | |
# Helper Functions | |
# ------------------------------------------------------------------------------ | |
def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Build a 44-byte canonical PCM WAV header (RIFF + "fmt " + "data" chunk headers).

    When streaming, the final audio length is unknown, so ``data_size`` defaults
    to a large placeholder value instead of the real byte count.

    Args:
        sample_rate: Frames per second.
        num_channels: Number of interleaved channels.
        sample_width: Bytes per sample (2 for 16-bit PCM).
        data_size: Byte length written into the "data" chunk header.

    Returns:
        The packed little-endian header bytes.
    """
    block_align = num_channels * sample_width
    fields = (
        # RIFF chunk: its size field counts everything after the first 8 bytes,
        # i.e. the remaining 36 header bytes plus the audio payload.
        b"RIFF", 36 + data_size, b"WAVE",
        # "fmt " chunk: 16-byte body, audio format 1 (uncompressed PCM).
        b"fmt ", 16, 1, num_channels, sample_rate,
        sample_rate * block_align,  # byte rate
        block_align,
        sample_width * 8,  # bits per sample
        # "data" chunk header; the PCM samples follow immediately after it.
        b"data", data_size,
    )
    # '<' means little-endian with no alignment padding, so one pack equals
    # the concatenation of the three per-chunk packs.
    return struct.pack("<4sI4s" "4sIHHIIHH" "4sI", *fields)
def custom_split_text(text: str, *, initial_size: int = 2, growth: int = 2) -> list:
    """
    Split text into word chunks of increasing size, preferring to end at a period.

    Rules:
      - The first candidate chunk holds up to ``initial_size`` words. Each later
        candidate may hold ``growth`` more words than the previous one
        (2, 4, 6, ... with the defaults). The size grows even when a chunk was
        shortened by a period.
      - Within a candidate, the right-most word containing "." ends the chunk.
        Words after it start the next chunk. With no such word, the whole
        candidate is used.
      - The final chunk takes whatever words remain.

    Note that any "." counts, so decimals ("3.14") and abbreviations ("e.g.")
    also end a chunk. Small early chunks are presumably meant to get the first
    audio out quickly when streaming.

    Args:
        text: Input text; split on runs of whitespace.
        initial_size: Maximum word count of the first chunk (>= 1).
        growth: Increase in maximum chunk size after each chunk (>= 0).

    Returns:
        List of chunks, each a single-space-joined string of words. Empty or
        whitespace-only text yields an empty list.

    Raises:
        ValueError: If ``initial_size`` < 1 or ``growth`` < 0.
    """
    if initial_size < 1:
        raise ValueError("initial_size must be at least 1")
    if growth < 0:
        raise ValueError("growth must be non-negative")

    words = text.split()
    chunks = []
    chunk_size = initial_size
    start = 0
    while start < len(words):
        end = min(start + chunk_size, len(words))
        # Scan the candidate right-to-left. The first word containing a period
        # becomes the chunk's last word (a no-op if it already is the last word).
        for i in range(end - 1, start - 1, -1):
            if "." in words[i]:
                end = i + 1
                break
        chunks.append(" ".join(words[start:end]))
        start = end
        chunk_size += growth  # additive growth, not multiplicative
    return chunks
def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a floating-point audio tensor to raw signed 16-bit little-endian PCM.

    Samples are expected in [-1, 1]. Out-of-range values are clipped so they
    saturate instead of wrapping around when cast to int16, and NaN becomes
    silence. Multi-dimensional input is flattened in row-major order.

    Args:
        audio_tensor: Audio samples on any device; may require grad.

    Returns:
        ``2 * audio_tensor.numel()`` bytes of PCM data.
    """
    # detach() lets tensors that track gradients be converted to NumPy.
    audio_np = audio_tensor.detach().cpu().numpy().reshape(-1)
    audio_np = np.clip(np.nan_to_num(audio_np, nan=0.0), -1.0, 1.0)
    # "<i2" pins little-endian order, which WAV requires regardless of host.
    return (audio_np * 32767).astype("<i2").tobytes()
# ------------------------------------------------------------------------------ | |
# Endpoints | |
# ------------------------------------------------------------------------------ | |
@app.get("/tts/streaming")
def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
    """
    Stream synthesized speech as a continuous WAV response.

    A WAV header with a placeholder data length is sent first. After it, the
    raw 16-bit PCM for each text chunk (from ``custom_split_text``) is sent as
    soon as the pipeline produces it.

    Args:
        text: Text to synthesize; must contain a non-whitespace character.
        voice: Voice name forwarded to the Kokoro pipeline.
        speed: Speed value forwarded to the pipeline; must be finite and > 0.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``.

    Raises:
        HTTPException: 400 if ``text`` is blank or ``speed`` is not a positive
            finite number.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty.")
    # Chained comparison also rejects NaN and infinity.
    if not 0 < speed < float("inf"):
        raise HTTPException(status_code=400, detail="Speed must be a positive number.")
    # NOTE(review): `voice` comes straight from the query string and is passed to
    # KPipeline to resolve voice data. Confirm kokoro cannot be made to load
    # arbitrary local files this way before exposing this publicly.

    chunks = custom_split_text(text)
    sample_rate = 24000  # must match the client; presumably Kokoro's output rate
    num_channels = 1
    sample_width = 2  # 16-bit PCM

    def audio_generator():
        # The header goes out first so the client can start decoding at once.
        yield generate_wav_header(sample_rate, num_channels, sample_width)
        for i, chunk in enumerate(chunks):
            print(f"Processing chunk {i}: {chunk}")
            try:
                # Iterate lazily so each result is sent as soon as it is ready.
                for result in pipeline(chunk, voice=voice, speed=speed, split_pattern=None):
                    if result.audio is not None:
                        print(f"Chunk {i}: Audio generated")
                        yield audio_tensor_to_pcm_bytes(result.audio)
                    else:
                        print(f"Chunk {i}: No audio generated")
            except Exception as e:
                # The 200 status and header are already sent, so the client
                # cannot be told. Log and continue with the remaining chunks.
                print(f"Error processing chunk {i}: {e}")

    return StreamingResponse(
        audio_generator(),
        media_type="audio/wav",
        headers={"Cache-Control": "no-cache"},
    )
@app.get("/tts/full")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
    """
    Synthesize the entire text and return it as one complete WAV file.

    The pipeline splits the text on runs of newlines. All generated segments
    are concatenated and encoded as mono 16-bit PCM at 24 kHz.

    Args:
        text: Text to synthesize; must contain a non-whitespace character.
        voice: Voice name forwarded to the Kokoro pipeline.
        speed: Speed value forwarded to the pipeline; must be finite and > 0.

    Returns:
        A ``Response`` containing the WAV bytes (media type ``audio/wav``).

    Raises:
        HTTPException: 400 for blank text or an invalid speed. 500 if the
            pipeline produced no audio at all.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty.")
    # Chained comparison also rejects NaN and infinity.
    if not 0 < speed < float("inf"):
        raise HTTPException(status_code=400, detail="Speed must be a positive number.")

    audio_segments = []
    for result in pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"):
        if result.audio is not None:
            # Move to CPU, drop autograd tracking, flatten to 1-D for concatenation.
            audio_segments.append(result.audio.detach().cpu().numpy().reshape(-1))
    if not audio_segments:
        raise HTTPException(status_code=500, detail="No audio generated.")
    full_audio = np.concatenate(audio_segments)

    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM
    # Clip before scaling. Samples outside [-1, 1] would otherwise overflow
    # int16 and wrap around into loud clicks.
    full_audio_int16 = (np.clip(full_audio, -1.0, 1.0) * 32767).astype("<i2")

    wav_io = io.BytesIO()
    with wave.open(wav_io, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(full_audio_int16.tobytes())
    return Response(content=wav_io.getvalue(), media_type="audio/wav")
@app.get("/", response_class=HTMLResponse)
def index():
    """
    Serve the HTML demo page for Kokoro TTS.

    Two playback methods are provided:
      - "Play Full TTS" points a standard <audio> element at /tts/full.
      - "Play Streaming TTS" fetches /tts/streaming and plays it with the Web
        Audio API (via a ScriptProcessorNode). It skips the 44-byte WAV header,
        then decodes the following little-endian 16-bit PCM into Float32 and
        plays it as it arrives. The AudioContext is created at 24 kHz to match
        the server, and an odd trailing byte of a network chunk is carried
        into the next chunk so samples stay aligned.

    Returns:
        The page as an HTML string (wrapped in an HTMLResponse by FastAPI).
    """
    return r"""
<!DOCTYPE html>
<html>
<head>
<title>Kokoro TTS Demo</title>
</head>
<body>
<h1>Kokoro TTS Demo</h1>
<textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
<label for="voice">Voice:</label>
<input type="text" id="voice" value="af_heart"><br>
<label for="speed">Speed:</label>
<input type="number" step="0.1" id="speed" value="1.0"><br><br>
<button onclick="startStreaming()">Play Streaming TTS (Web Audio API)</button>
<button onclick="playFull()">Play Full TTS (Standard Audio)</button>
<br><br>
<audio id="fullAudio" controls></audio>
<script>
// Sample rate of the PCM produced by the server (must match the backend).
const SAMPLE_RATE = 24000;
// Size of the canonical PCM WAV header sent before the raw samples.
const WAV_HEADER_BYTES = 44;

// Build the query string shared by both endpoints.
function buildQuery() {
  const text = document.getElementById('text').value;
  const voice = document.getElementById('voice').value;
  const speed = document.getElementById('speed').value;
  return `text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${encodeURIComponent(speed)}`;
}

// Play full TTS by setting the <audio> element's source.
function playFull() {
  const audio = document.getElementById('fullAudio');
  audio.src = `/tts/full?${buildQuery()}`;
  audio.play();
}

// Create an AudioContext at the server's sample rate so samples are not played
// too fast/slow; fall back to the default rate on browsers without the option.
function createAudioContext() {
  const AudioCtx = window.AudioContext || window.webkitAudioContext;
  try {
    return new AudioCtx({ sampleRate: SAMPLE_RATE });
  } catch (err) {
    console.warn("Could not create a 24 kHz AudioContext; playback speed may be wrong.", err);
    return new AudioCtx();
  }
}

// Stream audio using the Web Audio API.
async function startStreaming() {
  // Create the context before any await so it belongs to the click gesture.
  const audioContext = createAudioContext();
  if (audioContext.state === "suspended") {
    audioContext.resume();
  }

  let response;
  try {
    response = await fetch(`/tts/streaming?${buildQuery()}`);
  } catch (err) {
    audioContext.close();
    alert(`TTS request failed: ${err}`);
    return;
  }
  if (!response.ok) {
    audioContext.close();
    alert(`TTS request failed (HTTP ${response.status}).`);
    return;
  }
  if (!response.body) {
    audioContext.close();
    alert("Streaming not supported in this browser.");
    return;
  }

  const reader = response.body.getReader();
  // ScriptProcessorNode with a 4096-sample buffer, mono in/out.
  const scriptNode = audioContext.createScriptProcessor(4096, 1, 1);
  const bufferQueue = [];
  let currentBuffer = new Float32Array(0);
  let headerRemaining = WAV_HEADER_BYTES;
  let leftover = null;   // odd trailing byte carried into the next network chunk
  let streamDone = false;
  let closed = false;

  // Release the audio graph once everything has been played.
  function finish() {
    if (closed) return;
    closed = true;
    scriptNode.disconnect();
    audioContext.close();
  }

  // Decode little-endian 16-bit PCM bytes into Float32 samples. DataView works
  // for any byteOffset, unlike Int16Array which needs a 2-byte-aligned offset.
  function pcm16ToFloat32(bytes) {
    const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
    const count = bytes.byteLength >> 1;
    const out = new Float32Array(count);
    for (let i = 0; i < count; i++) {
      out[i] = view.getInt16(i * 2, true) / 32767;
    }
    return out;
  }

  scriptNode.onaudioprocess = function(e) {
    const output = e.outputBuffer.getChannelData(0);
    let offset = 0;
    while (offset < output.length) {
      if (currentBuffer.length === 0) {
        if (bufferQueue.length > 0) {
          currentBuffer = bufferQueue.shift();
          continue;
        }
        // No data available: output silence for the rest of this block.
        output.fill(0, offset);
        if (streamDone) {
          finish();
        }
        return;
      }
      const toCopy = Math.min(output.length - offset, currentBuffer.length);
      output.set(currentBuffer.subarray(0, toCopy), offset);
      offset += toCopy;
      currentBuffer = currentBuffer.subarray(toCopy);
    }
  };
  scriptNode.connect(audioContext.destination);

  // Read the response stream.
  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      let bytes = value;
      // Skip the WAV header, which may be split across several network chunks.
      if (headerRemaining > 0) {
        const skip = Math.min(headerRemaining, bytes.length);
        headerRemaining -= skip;
        bytes = bytes.subarray(skip);
      }
      // Re-attach the odd byte left over from the previous chunk. Dropping it
      // would shift every following sample by one byte and turn audio into noise.
      if (leftover !== null && bytes.length > 0) {
        const joined = new Uint8Array(bytes.length + 1);
        joined[0] = leftover;
        joined.set(bytes, 1);
        bytes = joined;
        leftover = null;
      }
      if (bytes.length % 2 !== 0) {
        leftover = bytes[bytes.length - 1];
        bytes = bytes.subarray(0, bytes.length - 1);
      }
      if (bytes.length > 0) {
        bufferQueue.push(pcm16ToFloat32(bytes));
      }
    }
  } catch (err) {
    console.error("Streaming interrupted:", err);
  } finally {
    // Let the processor drain the queue, then close the context.
    streamDone = true;
  }
}
</script>
</body>
</html>
"""
# ------------------------------------------------------------------------------ | |
# Run with: uvicorn app:app --reload | |
# ------------------------------------------------------------------------------ | |
if __name__ == "__main__":
    from pathlib import Path

    import uvicorn

    # With reload=True, uvicorn needs an import string rather than the app
    # object. Derive the module name from this file instead of hard-coding
    # "app", so the script still starts if the file is renamed.
    uvicorn.run(f"{Path(__file__).stem}:app", host="0.0.0.0", port=7860, reload=True)