import io
import re
import wave
import struct
import os
import time
import json
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, Response, HTMLResponse
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware
from misaki import en, espeak
from onnxruntime import InferenceSession
from huggingface_hub import snapshot_download
from scipy.io.wavfile import write as write_wav
# ------------------------------------------------------------------------------
# Load configuration and set up vocabulary
# ------------------------------------------------------------------------------
config_file_path = 'config.json' # Update with your actual path
with open(config_file_path, 'r') as f:
config = json.load(f)
phoneme_vocab = config['vocab']
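# The vocabulary maps individual phoneme characters to integer token IDs,
# which are fed to the model as "input_ids" in the endpoints below.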
# ------------------------------------------------------------------------------
# Download the model and voice files from Hugging Face Hub
# ------------------------------------------------------------------------------
model_repo = "onnx-community/Kokoro-82M-v1.0-ONNX"
model_name = "onnx/model_q4.onnx"  # quantized model; use "onnx/model.onnx" for the full-precision version
voice_file_pattern = "*.bin"
local_dir = "."
snapshot_download(
repo_id=model_repo,
allow_patterns=[model_name, voice_file_pattern],
local_dir=local_dir
)
# ------------------------------------------------------------------------------
# Load the ONNX model
# ------------------------------------------------------------------------------
model_path = os.path.join(local_dir, model_name)
sess = InferenceSession(model_path)
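# The session is run below with three inputs: "input_ids" (a batch of token IDs),
# "style" (a (1, 256) voice/style vector) and "speed" (a 1-element float32 array).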
# ------------------------------------------------------------------------------
# Create the FastAPI app with GZip compression
# ------------------------------------------------------------------------------
app = FastAPI(
title="Kokoro TTS FastAPI",
middleware=[Middleware(GZipMiddleware, compresslevel=9)]
)
# ------------------------------------------------------------------------------
# Helper Functions
# ------------------------------------------------------------------------------
def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
"""
Generate a WAV header for streaming. Since we do not know the final audio size,
a large dummy value is used for the data chunk size.
"""
bits_per_sample = sample_width * 8
byte_rate = sample_rate * num_channels * sample_width
block_align = num_channels * sample_width
    total_size = 36 + data_size  # RIFF chunk size: 36 bytes of header fields + data bytes (full header = 44 bytes)
header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
data_chunk_header = struct.pack('<4sI', b'data', data_size)
return header + fmt_chunk + data_chunk_header
stream_header = generate_wav_header(24000, 1, 2)
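# Pre-built header for the streaming endpoint: 24 kHz, mono, 16-bit PCM with a
# dummy data size, so clients can start playback before the total length is known.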
def custom_split_text(text: str) -> list:
"""
Custom splitting strategy:
- Start with a chunk size of 2 words.
- For each chunk, if a period (".") is found in any word (except the very last word),
then split at that word (including it).
- Otherwise, use the current chunk size.
- Increase the chunk size by 2 for each subsequent chunk.
- If there are fewer than the desired number of words remaining, include all of them.
"""
words = text.split()
chunks = []
chunk_size = 2
start = 0
while start < len(words):
candidate_end = start + chunk_size
if candidate_end > len(words):
candidate_end = len(words)
chunk_words = words[start:candidate_end]
split_index = None
# for i in range(len(chunk_words) - 1):
# if '.' in chunk_words[i]:
# split_index = i
# break
# if split_index is not None:
# candidate_end = start + split_index + 1
# chunk_words = words[start:candidate_end]
chunks.append(" ".join(chunk_words))
start = candidate_end
if chunk_size < 100:
chunk_size += 2
return chunks
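# For example, custom_split_text("one two three four five six")
# returns ["one two", "three four five six"].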
def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
"""
Convert a torch.FloatTensor (values in [-1, 1]) to raw 16-bit PCM bytes.
"""
audio_np = audio_tensor.cpu().numpy()
if audio_np.ndim > 1:
audio_np = audio_np.flatten()
audio_int16 = np.int16(audio_np * 32767)
return audio_int16.tobytes()
def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
"""
Convert a torch.FloatTensor to Opus-encoded bytes.
Requires the 'opuslib' package: pip install opuslib
"""
try:
import opuslib
except ImportError:
raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")
audio_np = audio_tensor.cpu().numpy()
if audio_np.ndim > 1:
audio_np = audio_np.flatten()
audio_int16 = np.int16(audio_np * 32767)
encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP)
frame_size = int(sample_rate * 0.020) # 20 ms frame
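    # At 24 kHz a 20 ms frame is 480 samples; a trailing partial frame is
    # zero-padded to full length below before encoding.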
encoded_data = b''
for i in range(0, len(audio_int16), frame_size):
frame = audio_int16[i:i + frame_size]
if len(frame) < frame_size:
frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
encoded_frame = encoder.encode(frame.tobytes(), frame_size)
encoded_data += encoded_frame
return encoded_data
fbs = espeak.EspeakFallback(british=True)
g2p = en.G2P(trf=False, british=False, fallback=fbs)
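# misaki's English G2P performs the primary phonemization; the espeak backend
# (British voice) is only used as a fallback for words misaki cannot handle.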
def tokenizer(text: str):
    """
    Convert text to a list of integer token IDs: run G2P to get a phoneme string,
    then look up each phoneme character in the model vocabulary. Phonemes missing
    from the vocabulary are silently dropped.
    """
    phonemes_string, _ = g2p(text)
    print(text + " " + phonemes_string)
    tokens = [phoneme_vocab[ph] for ph in phonemes_string if ph in phoneme_vocab]
    return tokens
# ------------------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------------------
@app.get("/tts/streaming", summary="Streaming TTS")
def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
"""
Streaming TTS endpoint.
This endpoint splits the input text into chunks (using the doubling strategy),
then for each chunk:
- For the first chunk, a 0 is prepended.
- For subsequent chunks, the first token is set to the last token from the previous chunk.
- For the final chunk, a 0 is appended.
The audio for each chunk is generated immediately and streamed to the client.
"""
chunks = custom_split_text(text)
# Load the voice/style file (must be present in voices/{voice}.bin)
voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
if not os.path.exists(voice_path):
raise HTTPException(status_code=404, detail="Voice file not found")
voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)
def audio_generator():
# If outputting a WAV stream, yield a WAV header once.
if format.lower() == "wav":
yield stream_header
prev_last_token = None
for i, chunk in enumerate(chunks):
# Convert the chunk text to tokens.
chunk_tokens = tokenizer(chunk)
# For the first chunk, prepend 0; for later chunks, start with the previous chunk's last token.
# if i == 0:
# tokens_to_send = [0] + chunk_tokens + [0]
# else:
# tokens_to_send = [0] + chunk_tokens + [0]
# token_to_send = [0] + chunk_tokens
            # Remember the last token of this chunk (currently unused; kept for the
            # commented-out carry-over logic above).
            prev_last_token = chunk_tokens[-1:]
# Prepare the model input (a batch of one sequence).
tokens_to_send = [0] + chunk_tokens + [0]
final_token = [tokens_to_send]
print(final_token)
# Use the number of tokens to select the appropriate style vector.
style_index = len(chunk_tokens) + 2
if style_index >= len(voices):
style_index = len(voices) - 1 # Fallback if index is out-of-bounds.
ref_s = voices[style_index]
# Prepare the speed parameter.
speed_param = np.ones(1, dtype=np.float32) * speed
# Run the model (ONNX inference) for this chunk.
try:
start_time = time.time()
audio_output = sess.run(None, {
"input_ids": final_token,
"style": ref_s,
"speed": speed_param,
})[0]
print(f"Chunk {i} inference time: {time.time() - start_time:.3f}s")
except Exception as e:
print(f"Error processing chunk {i}: {e}")
                # In case of error, emit one second of silence (24000 samples at 24 kHz).
                audio_output = np.zeros((24000,), dtype=np.float32)
            # Convert the model output (float32 in [-1, 1]) to int16 PCM and trim
            # 6000 samples (0.25 s) from the head and 3000 samples (0.125 s) from the
            # tail, presumably to drop padding/boundary artifacts.
            audio_int16 = (audio_output * 32767).astype(np.int16).flatten()[6000:-3000]
            print(f"Chunk {i}: {audio_int16.shape[0]} samples")
# Convert to a torch tensor (back into float range) for our helper functions.
# audio_tensor = torch.from_numpy(audio_int16.astype(np.float32) / 32767)
# Yield the encoded audio chunk.
yield audio_int16.tobytes()
    # The generator always yields 16-bit PCM (with a WAV header when requested),
    # so the stream is advertised as WAV regardless of the `format` parameter.
    media_type = "audio/wav"
return StreamingResponse(
audio_generator(),
media_type=media_type,
headers={"Cache-Control": "no-cache"},
)
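# Example request (assuming the server is running locally on port 7860, as in the __main__ block below):
#   curl "http://localhost:7860/tts/streaming?text=Hello%20world&voice=af_heart" -o hello.wav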
@app.get("/tts/full", summary="Full TTS")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
"""
Full TTS endpoint that synthesizes the entire text and returns a complete WAV or Opus file.
"""
    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    if not os.path.exists(voice_path):
        raise HTTPException(status_code=404, detail="Voice file not found")
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)
tokens = tokenizer(text)
    style_index = min(len(tokens), len(voices) - 1)  # clamp to avoid an out-of-range index for very long inputs
    ref_s = voices[style_index]
final_token = [[0, *tokens, 0]]
start_time = time.time()
audio = sess.run(None, {
"input_ids": final_token,
"style": ref_s,
"speed": np.ones(1, dtype=np.float32) * speed,
})[0]
print(f"Full TTS inference time: {time.time()-start_time:.3f}s")
# Convert to int16 PCM.
audio = (audio * 32767).astype(np.int16).flatten()
if format.lower() == "wav":
wav_io = io.BytesIO()
write_wav(wav_io, 24000, audio)
wav_io.seek(0)
return Response(content=wav_io.read(), media_type="audio/wav")
elif format.lower() == "opus":
opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(audio.astype(np.float32)/32767), sample_rate=24000)
return Response(content=opus_data, media_type="audio/opus")
else:
raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
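# Example request for the full endpoint (same local-server assumption as above):
#   curl "http://localhost:7860/tts/full?text=Hello%20world&voice=af_heart&format=wav" -o hello.wav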
@app.get("/", response_class=HTMLResponse)
def index():
"""
HTML demo page for Kokoro TTS.
"""
return """
<!DOCTYPE html>
<html>
<head>
<title>Kokoro TTS Demo</title>
</head>
<body>
<h1>Kokoro TTS Demo</h1>
<textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
<label for="voice">Voice:</label>
<input type="text" id="voice" value="af_heart"><br>
<label for="speed">Speed:</label>
<input type="number" step="0.1" id="speed" value="1.0"><br>
<label for="format">Format:</label>
<select id="format">
<option value="wav">WAV</option>
<option value="opus" selected>Opus</option>
</select><br><br>
<button onclick="playStreaming()">Play Streaming TTS</button>
<button onclick="playFull()">Play Full TTS</button>
<br><br>
<audio id="audio" controls autoplay></audio>
<script>
function playStreaming() {
const text = document.getElementById('text').value;
const voice = document.getElementById('voice').value;
const speed = document.getElementById('speed').value;
const format = document.getElementById('format').value;
const audio = document.getElementById('audio');
audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
audio.play();
}
function playFull() {
const text = document.getElementById('text').value;
const voice = document.getElementById('voice').value;
const speed = document.getElementById('speed').value;
const format = document.getElementById('format').value;
const audio = document.getElementById('audio');
audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
audio.play();
}
</script>
</body>
</html>
"""
# ------------------------------------------------------------------------------
# Run the app with: uvicorn app:app --reload
# ------------------------------------------------------------------------------
if __name__ == "__main__":
import uvicorn
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)