import io
import json
import os
import struct
import time

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import StreamingResponse, Response, HTMLResponse
from huggingface_hub import snapshot_download
from misaki import en, espeak
from onnxruntime import InferenceSession
from scipy.io.wavfile import write as write_wav
# ------------------------------------------------------------------------------
# Load configuration and set up vocabulary
# ------------------------------------------------------------------------------
config_file_path = 'config.json'  # Update with your actual path
with open(config_file_path, 'r') as f:
    config = json.load(f)
phoneme_vocab = config['vocab']
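# `phoneme_vocab` maps each phoneme character to an integer token id. The
# values below are illustrative only; the real ids come from config.json:
#   phoneme_vocab == {"a": 43, "b": 44, ...}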
# ------------------------------------------------------------------------------
# Download the model and voice files from the Hugging Face Hub
# ------------------------------------------------------------------------------
model_repo = "onnx-community/Kokoro-82M-v1.0-ONNX"
model_name = "onnx/model_q4.onnx"  # or "onnx/model.onnx" for the full-precision model
voice_file_pattern = "*.bin"
local_dir = "."
snapshot_download(
    repo_id=model_repo,
    allow_patterns=[model_name, voice_file_pattern],
    local_dir=local_dir
)
# ------------------------------------------------------------------------------
# Load the ONNX model
# ------------------------------------------------------------------------------
model_path = os.path.join(local_dir, model_name)
sess = InferenceSession(model_path)
# ------------------------------------------------------------------------------
# Create the FastAPI app with GZip compression
# ------------------------------------------------------------------------------
app = FastAPI(
    title="Kokoro TTS FastAPI",
    middleware=[Middleware(GZipMiddleware, compresslevel=9)]
)
# ------------------------------------------------------------------------------
# Helper Functions
# ------------------------------------------------------------------------------
def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Generate a WAV header for streaming. Since the final audio size is not
    known in advance, a large dummy value is used for the data chunk size.
    """
    bits_per_sample = sample_width * 8
    byte_rate = sample_rate * num_channels * sample_width
    block_align = num_channels * sample_width
    total_size = 36 + data_size  # RIFF size field: 36 header bytes + data size
    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
    data_chunk_header = struct.pack('<4sI', b'data', data_size)
    return header + fmt_chunk + data_chunk_header

# Pre-built header for 24 kHz mono 16-bit PCM, shared by all streams.
stream_header = generate_wav_header(24000, 1, 2)
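# Sanity check (illustrative): a canonical PCM WAV header is 44 bytes
# (12-byte RIFF descriptor + 24-byte fmt chunk + 8-byte data chunk header).
#   assert len(stream_header) == 44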
def custom_split_text(text: str) -> list:
    """
    Custom splitting strategy:
      - Start with a chunk size of 2 words.
      - Emit the next `chunk_size` words as a chunk (or all remaining words
        if fewer are left).
      - Increase the chunk size by 2 after each chunk, capped at 100 words.
    Small early chunks keep time-to-first-audio low; later chunks grow so
    each inference call covers more text.
    """
    words = text.split()
    chunks = []
    chunk_size = 2
    start = 0
    while start < len(words):
        candidate_end = min(start + chunk_size, len(words))
        chunks.append(" ".join(words[start:candidate_end]))
        start = candidate_end
        if chunk_size < 100:
            chunk_size += 2
    return chunks
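# Example (illustrative): chunk sizes grow 2, 4, 6, ...
#   custom_split_text("one two three four five six seven")
#   -> ["one two", "three four five six", "seven"]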
def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a torch.FloatTensor (values in [-1, 1]) to raw 16-bit PCM bytes.
    """
    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    audio_int16 = np.int16(audio_np * 32767)
    return audio_int16.tobytes()
def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
    """
    Convert a torch.FloatTensor to Opus-encoded bytes.
    Requires the 'opuslib' package: pip install opuslib
    NOTE: `bitrate` is accepted but not currently applied to the encoder.
    """
    try:
        import opuslib
    except ImportError:
        raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")
    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    audio_int16 = np.int16(audio_np * 32767)
    encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP)
    frame_size = int(sample_rate * 0.020)  # 20 ms frame
    encoded_data = b''
    for i in range(0, len(audio_int16), frame_size):
        frame = audio_int16[i:i + frame_size]
        if len(frame) < frame_size:
            # Zero-pad the final partial frame to a full 20 ms.
            frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
        encoded_frame = encoder.encode(frame.tobytes(), frame_size)
        encoded_data += encoded_frame
    return encoded_data
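# Example (illustrative): at 24 kHz a 20 ms frame is 480 samples, so one
# second of audio yields 50 concatenated Opus frames:
#   opus_bytes = audio_tensor_to_opus_bytes(torch.zeros(24000))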
# Set up grapheme-to-phoneme conversion with an espeak fallback.
fbs = espeak.EspeakFallback(british=True)
g2p = en.G2P(trf=False, british=False, fallback=fbs)

def tokenizer(text: str):
    """
    Convert text to a list of phoneme token ids using the global vocabulary.
    Phonemes missing from the vocabulary are silently dropped.
    """
    phonemes_string, _ = g2p(text)
    print(f"{text} {phonemes_string}")
    return [phoneme_vocab[ph] for ph in phonemes_string if ph in phoneme_vocab]
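# Example (illustrative; real phonemes and ids depend on misaki and the
# vocabulary in config.json):
#   tokenizer("Hello")  # g2p -> phoneme string -> e.g. [50, 83, 54, 156, 31]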
# ------------------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------------------
@app.get("/tts/streaming")
def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
    """
    Streaming TTS endpoint.
    Splits the input text into chunks (using the growing-chunk strategy),
    runs inference on each chunk as soon as it is ready, and streams the
    resulting 16-bit PCM audio to the client. Each chunk is wrapped in the
    0 boundary token before inference. Note that only WAV output is
    currently produced; the `format` parameter merely controls whether the
    WAV header is emitted first.
    """
    chunks = custom_split_text(text)
    # Load the voice/style file (must be present in voices/{voice}.bin).
    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    if not os.path.exists(voice_path):
        raise HTTPException(status_code=404, detail="Voice file not found")
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)

    def audio_generator():
        # If outputting a WAV stream, yield the header once up front.
        if format.lower() == "wav":
            yield stream_header
        for i, chunk in enumerate(chunks):
            # Convert the chunk text to tokens and wrap it in boundary
            # tokens, as a batch of one sequence.
            chunk_tokens = tokenizer(chunk)
            final_token = [[0] + chunk_tokens + [0]]
            print(final_token)
            # The style file holds one 256-dim vector per sequence length;
            # select by padded token count, clamped to the available range.
            style_index = min(len(chunk_tokens) + 2, len(voices) - 1)
            ref_s = voices[style_index]
            # Prepare the speed parameter.
            speed_param = np.full(1, speed, dtype=np.float32)
            # Run the model (ONNX inference) for this chunk.
            try:
                start_time = time.time()
                audio_output = sess.run(None, {
                    "input_ids": final_token,
                    "style": ref_s,
                    "speed": speed_param,
                })[0]
                print(f"Chunk {i} inference time: {time.time() - start_time:.3f}s")
            except Exception as e:
                print(f"Error processing chunk {i}: {e}")
                # On error, fall back to a short silent chunk.
                audio_output = np.zeros((24000,), dtype=np.float32)
            # Convert the float32 output in [-1, 1] to int16 PCM, trimming
            # the model's padding (0.25 s at the head, 0.125 s at the tail).
            audio_int16 = (audio_output * 32767).astype(np.int16).flatten()[6000:-3000]
            # Yield the encoded audio chunk.
            yield audio_int16.tobytes()

    return StreamingResponse(
        audio_generator(),
        media_type="audio/wav",
        headers={"Cache-Control": "no-cache"},
    )
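# Example request (illustrative), assuming the server from the __main__ block
# below (port 7860):
#   curl "http://localhost:7860/tts/streaming?text=Hello%20world&voice=af_heart" \
#        --output streaming.wav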
@app.get("/tts/full")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
    """
    Full TTS endpoint that synthesizes the entire text in one pass and
    returns a complete WAV or Opus file.
    """
    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    if not os.path.exists(voice_path):
        raise HTTPException(status_code=404, detail="Voice file not found")
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)
    tokens = tokenizer(text)
    # Select the style vector by token count, clamped to the available range.
    ref_s = voices[min(len(tokens), len(voices) - 1)]
    final_token = [[0, *tokens, 0]]
    start_time = time.time()
    audio = sess.run(None, {
        "input_ids": final_token,
        "style": ref_s,
        "speed": np.full(1, speed, dtype=np.float32),
    })[0]
    print(f"Full TTS inference time: {time.time() - start_time:.3f}s")
    # Convert the float32 output in [-1, 1] to int16 PCM.
    audio = (audio * 32767).astype(np.int16).flatten()
    if format.lower() == "wav":
        wav_io = io.BytesIO()
        write_wav(wav_io, 24000, audio)
        wav_io.seek(0)
        return Response(content=wav_io.read(), media_type="audio/wav")
    elif format.lower() == "opus":
        opus_data = audio_tensor_to_opus_bytes(
            torch.from_numpy(audio.astype(np.float32) / 32767), sample_rate=24000
        )
        return Response(content=opus_data, media_type="audio/opus")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
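# Example request (illustrative): fetch a complete Opus file in one response.
#   curl "http://localhost:7860/tts/full?text=Hello%20world&format=opus" \
#        --output full.opus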
@app.get("/", response_class=HTMLResponse)
def index():
    """
    HTML demo page for Kokoro TTS.
    """
    return """
    <!DOCTYPE html>
    <html>
      <head>
        <title>Kokoro TTS Demo</title>
      </head>
      <body>
        <h1>Kokoro TTS Demo</h1>
        <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
        <label for="voice">Voice:</label>
        <input type="text" id="voice" value="af_heart"><br>
        <label for="speed">Speed:</label>
        <input type="number" step="0.1" id="speed" value="1.0"><br>
        <label for="format">Format:</label>
        <select id="format">
          <option value="wav">WAV</option>
          <option value="opus" selected>Opus</option>
        </select><br><br>
        <button onclick="playStreaming()">Play Streaming TTS</button>
        <button onclick="playFull()">Play Full TTS</button>
        <br><br>
        <audio id="audio" controls autoplay></audio>
        <script>
          function playStreaming() {
            const text = document.getElementById('text').value;
            const voice = document.getElementById('voice').value;
            const speed = document.getElementById('speed').value;
            const format = document.getElementById('format').value;
            const audio = document.getElementById('audio');
            audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
            audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
            audio.play();
          }
          function playFull() {
            const text = document.getElementById('text').value;
            const voice = document.getElementById('voice').value;
            const speed = document.getElementById('speed').value;
            const format = document.getElementById('format').value;
            const audio = document.getElementById('audio');
            audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
            audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus';
            audio.play();
          }
        </script>
      </body>
    </html>
    """
# ------------------------------------------------------------------------------
# Run the app with: uvicorn app:app --reload
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)