import io
import re
import wave
import struct
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, Response, HTMLResponse
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware
from misaki import en
import os
from onnxruntime import InferenceSession
from huggingface_hub import snapshot_download
import json

# Load the configuration file
config_file_path = 'config.json'  # Update this with the path to your config file
with open(config_file_path, 'r') as f:
    config = json.load(f)

# Extract the phoneme vocabulary
phoneme_vocab = config['vocab']
# Download the model and voice files from the Hugging Face Hub
model_repo = "onnx-community/Kokoro-82M-v1.0-ONNX"
model_name = "onnx/model_q8f16.onnx"
voice_file = "voices/*"  # glob pattern so every voice embedding under voices/ is fetched
local_dir = "."

snapshot_download(
    repo_id=model_repo,
    local_dir=local_dir,
    allow_patterns=[model_name, voice_file],
)

# Load the ONNX model
model_path = os.path.join(local_dir, model_name)
sess = InferenceSession(model_path)
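
# Quick way to confirm what the exported graph expects (a sketch for debugging, not
# used by the app; exact names and shapes come from the ONNX export, so inspect
# rather than assume):
# for inp in sess.get_inputs():
#     print(inp.name, inp.shape, inp.type)
# For this export the inputs are roughly input_ids (int64 token IDs with a batch
# dimension), style (float32 voice embedding) and speed (float32), and the first
# output is the 24 kHz waveform.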

app = FastAPI(
    title="Kokoro TTS FastAPI",
    middleware=[
        Middleware(GZipMiddleware, compresslevel=9)  # Add GZip compression
    ]
)

# ------------------------------------------------------------------------------
# Global Model Instance
# ------------------------------------------------------------------------------
# The ONNX session (`sess`) created above is shared across all requests.

# ------------------------------------------------------------------------------
# Helper Functions
# ------------------------------------------------------------------------------

def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Generate a WAV header for streaming.
    Since we don't know the final audio size, we set the data chunk size to a large dummy value.
    This header is sent only once at the start of the stream.
    """
    bits_per_sample = sample_width * 8
    byte_rate = sample_rate * num_channels * sample_width
    block_align = num_channels * sample_width
    # Total file size = 36 + data_size (the header itself is 44 bytes).
    total_size = 36 + data_size
    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
    data_chunk_header = struct.pack('<4sI', b'data', data_size)
    return header + fmt_chunk + data_chunk_header
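
# Sanity check for the layout above (illustrative only, not part of the app):
# a 24 kHz, mono, 16-bit header is 44 bytes, with the byte-rate field at offset 28.
# header = generate_wav_header(24000, 1, 2)
# assert len(header) == 44
# assert struct.unpack('<I', header[28:32])[0] == 24000 * 1 * 2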

def custom_split_text(text: str) -> list:
    """
    Custom splitting:
      - Start with a chunk size of 2 words.
      - For each chunk, if a period (".") is found in any word (except the very last word of the chunk),
        split the chunk at that word (include words up to and including that word).
      - Otherwise, use the current chunk size.
      - For subsequent chunks, increase the chunk size by 2.
      - If there are fewer words left than a full chunk, take all remaining words.
    """
    words = text.split()
    chunks = []
    chunk_size = 2
    start = 0
    while start < len(words):
        candidate_end = start + chunk_size
        if candidate_end > len(words):
            candidate_end = len(words)
        chunk_words = words[start:candidate_end]
        # Look for a period in any word except the last one.
        split_index = None
        for i in range(len(chunk_words) - 1):
            if '.' in chunk_words[i]:
                split_index = i
                break
        if split_index is not None:
            candidate_end = start + split_index + 1
            chunk_words = words[start:candidate_end]
        chunks.append(" ".join(chunk_words))
        start = candidate_end
        chunk_size += 2  # Increase the chunk size by 2 for the next iteration.
    return chunks
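
# Worked example of the doubling strategy: chunk sizes grow 2 -> 4 -> 6 words, and a
# period inside a chunk (anywhere but its last word) cuts the chunk short at that word.
# custom_split_text("Hi. Now we go on and on together.")
# -> ['Hi.', 'Now we go on', 'and on together.']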

def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
    """
    # Ensure tensor is on CPU and flatten if necessary.
    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    # Scale to int16 range.
    audio_int16 = np.int16(audio_np * 32767)
    return audio_int16.tobytes()

def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
    """
    Convert a torch.FloatTensor to Opus-encoded bytes.
    Requires the 'opuslib' package: pip install opuslib
    Note: the result is a bare concatenation of Opus frames (no Ogg/WebM container),
    so most players will not decode it directly.
    """
    try:
        import opuslib
    except ImportError:
        raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")
    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    # Scale to int16 range. Important for Opus.
    audio_int16 = np.int16(audio_np * 32767)
    encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP)  # 1 channel for mono.
    # Opus frames must be 2.5, 5, 10, 20, 40, or 60 ms long; use 20 ms frames here.
    frame_size = int(sample_rate * 0.020)  # samples per 20 ms frame
    encoded_data = b''
    for i in range(0, len(audio_int16), frame_size):
        frame = audio_int16[i:i + frame_size]
        if len(frame) < frame_size:
            # Pad the last frame with zeros if needed.
            frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
        encoded_frame = encoder.encode(frame.tobytes(), frame_size)  # Encode the frame.
        encoded_data += encoded_frame
    return encoded_data

g2p = en.G2P(trf=False, british=False, fallback=None)  # no transformer, American English

def tokenizer(text):
    """Convert text to a list of token IDs via misaki G2P and the model's phoneme vocab."""
    phonemes_string, _ = g2p(text)
    phonemes = list(phonemes_string)
    tokens = [phoneme_vocab[phoneme] for phoneme in phonemes if phoneme in phoneme_vocab]
    return tokens
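
# Illustrative usage (the actual integers depend on the vocab in config.json):
# ids = tokenizer("Hello world")
# print(len(ids), ids[:5])  # one ID per phoneme character that exists in the vocab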

# ------------------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------------------
# @app.get("/tts/streaming", summary="Streaming TTS")
# def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "opus"):
#     """
#     Streaming TTS endpoint that returns a continuous audio stream.
#     Supports WAV (PCM) and Opus formats. Opus offers significantly better compression.
#     For WAV, the endpoint first yields a header (with a dummy length),
#     then yields encoded audio data for each text chunk as soon as it is generated.
#     """
#     # Split the input text using the custom doubling strategy.
#     chunks = custom_split_text(text)
#     sample_rate = 24000
#     num_channels = 1
#     sample_width = 2  # 16-bit PCM
#
#     def audio_generator():
#         if format.lower() == "wav":
#             # Yield the WAV header first.
#             header = generate_wav_header(sample_rate, num_channels, sample_width)
#             yield header
#         # Process and yield each chunk's audio data.
#         for i, chunk in enumerate(chunks):
#             print(f"Processing chunk {i}: {chunk}")  # Debugging
#             try:
#                 results = list(pipeline(chunk, voice=voice, speed=speed, split_pattern=None))
#                 for result in results:
#                     if result.audio is not None:
#                         if format.lower() == "wav":
#                             yield audio_tensor_to_pcm_bytes(result.audio)
#                         elif format.lower() == "opus":
#                             yield audio_tensor_to_opus_bytes(result.audio, sample_rate=sample_rate)
#                         else:
#                             raise ValueError(f"Unsupported audio format: {format}")
#                     else:
#                         print(f"Chunk {i}: No audio generated")
#             except Exception as e:
#                 print(f"Error processing chunk {i}: {e}")
#                 yield b''  # important so that streaming continues. Consider returning an error sound.
#
#     media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
#     return StreamingResponse(
#         audio_generator(),
#         media_type=media_type,
#         headers={"Cache-Control": "no-cache"},
#     )

@app.get("/tts/full", summary="Full TTS")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
    """
    Full TTS endpoint that synthesizes the entire text in one pass
    and returns a complete WAV or Opus file.
    """
    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)
    tokens = tokenizer(text)
    # Pick the style vector for this token count (the voice file stores one 256-dim row per length).
    ref_s = voices[len(tokens)]
    # Prepend the boundary token and add a batch dimension; the graph expects int64 IDs.
    final_token = np.array([[0, *tokens]], dtype=np.int64)
    full_audio = sess.run(None, dict(
        input_ids=final_token,
        style=ref_s,
        speed=np.full(1, speed, dtype=np.float32),
    ))[0]

    # Write the audio to an in-memory WAV or Opus file.
    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM -> 2 bytes per sample
    if format.lower() == "wav":
        wav_io = io.BytesIO()
        with wave.open(wav_io, "wb") as wav_file:
            wav_file.setnchannels(num_channels)
            wav_file.setsampwidth(sample_width)
            wav_file.setframerate(sample_rate)
            full_audio_int16 = np.int16(full_audio * 32767)
            wav_file.writeframes(full_audio_int16.tobytes())
        wav_io.seek(0)
        return Response(content=wav_io.read(), media_type="audio/wav")
    elif format.lower() == "opus":
        opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(full_audio), sample_rate=sample_rate)
        return Response(content=opus_data, media_type="audio/opus")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
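
# Example request against a locally running server (port from the __main__ block below;
# "af_heart" assumes the corresponding voices/af_heart.bin was downloaded):
# curl "http://localhost:7860/tts/full?text=Hello%20world&voice=af_heart&format=wav" -o out.wav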

@app.get("/", response_class=HTMLResponse)
def index():
    """
    HTML demo page for Kokoro TTS.
    Provides a simple UI to enter text, choose a voice and speed, and play synthesized
    audio from the full endpoint (and the streaming endpoint, if it is enabled above).
    """
return """ | |
<!DOCTYPE html> | |
<html> | |
<head> | |
<title>Kokoro TTS Demo</title> | |
</head> | |
<body> | |
<h1>Kokoro TTS Demo</h1> | |
<textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br> | |
<label for="voice">Voice:</label> | |
<input type="text" id="voice" value="af_heart"><br> | |
<label for="speed">Speed:</label> | |
<input type="number" step="0.1" id="speed" value="1.0"><br> | |
<label for="format">Format:</label> | |
<select id="format"> | |
<option value="wav">WAV</option> | |
<option value="opus" selected>Opus</option> | |
</select><br><br> | |
<button onclick="playStreaming()">Play Streaming TTS</button> | |
<button onclick="playFull()">Play Full TTS</button> | |
<br><br> | |
<audio id="audio" controls autoplay></audio> | |
<script> | |
function playStreaming() { | |
const text = document.getElementById('text').value; | |
const voice = document.getElementById('voice').value; | |
const speed = document.getElementById('speed').value; | |
const format = document.getElementById('format').value; | |
const audio = document.getElementById('audio'); | |
// Set the audio element's source to the streaming endpoint. | |
audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`; | |
audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus'; | |
audio.play(); | |
} | |
function playFull() { | |
const text = document.getElementById('text').value; | |
const voice = document.getElementById('voice').value; | |
const speed = document.getElementById('speed').value; | |
const format = document.getElementById('format').value; | |
const audio = document.getElementById('audio'); | |
// Set the audio element's source to the full TTS endpoint. | |
audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`; | |
audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus'; | |
audio.play(); | |
} | |
</script> | |
</body> | |
</html> | |
""" | |

# ------------------------------------------------------------------------------
# Run with: uvicorn app:app --reload
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)