# Spaces status: Runtime error
import io | |
import re | |
import wave | |
import struct | |
import time | |
import numpy as np | |
import torch | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import StreamingResponse, Response, HTMLResponse | |
from fastapi.middleware import Middleware | |
from fastapi.middleware.gzip import GZipMiddleware | |
from kokoro import KPipeline, StreamKPipeline | |
from kokoro.model import KModel | |
# Middleware stack for the app: every response is gzip-compressed at level 9,
# the slowest / highest-ratio setting.
# NOTE(review): this also applies to the streamed PCM audio, which compresses
# poorly; confirm the CPU cost and any added streaming latency are acceptable.
_middleware = [
    Middleware(GZipMiddleware, compresslevel=9),
]

app = FastAPI(title="Kokoro TTS FastAPI", middleware=_middleware)
# ------------------------------------------------------------------------------
# Global Pipeline Instance
# ------------------------------------------------------------------------------
# One model and one pipeline instance are created at import time and shared by
# every request handled by this app.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = KModel()  # Or however you initialize/load your model
model.to(device)
# Switch to inference mode: without this the module stays in training mode and
# training-only layers (e.g. dropout) stay active while synthesizing.
model.eval()

voice = "af_heart"
speed = 1.0
pipeline = StreamKPipeline(lang_code="a", model=model, voice=voice, device=device, speed=speed)
# ------------------------------------------------------------------------------ | |
# Helper Functions | |
# ------------------------------------------------------------------------------ | |
def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Build the 44-byte canonical PCM WAV header used at the start of a stream.

    The total audio length is unknown when streaming begins. The ``data`` chunk
    size therefore defaults to a large placeholder (0x7FFFFFFF), and the RIFF
    size field is derived from it as ``36 + data_size``. The header is meant
    to be sent once, before any sample bytes.

    Args:
        sample_rate: Samples per second, per channel.
        num_channels: Number of interleaved channels.
        sample_width: Bytes per sample (2 for 16-bit PCM).
        data_size: Value written into the ``data`` chunk size field.

    Returns:
        RIFF descriptor + ``fmt `` chunk + ``data`` chunk header, as bytes.
    """
    frame_bytes = num_channels * sample_width
    riff = struct.pack('<4sI4s', b'RIFF', 36 + data_size, b'WAVE')
    fmt = struct.pack(
        '<4sIHHIIHH',
        b'fmt ',
        16,                          # length of the PCM fmt chunk body
        1,                           # audio format 1 = uncompressed PCM
        num_channels,
        sample_rate,
        sample_rate * frame_bytes,   # byte rate
        frame_bytes,                 # block align
        sample_width * 8,            # bits per sample
    )
    data = struct.pack('<4sI', b'data', data_size)
    return b''.join((riff, fmt, data))
def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a float waveform tensor to raw little-endian 16-bit PCM bytes.

    Samples are expected in [-1, 1]. Values outside that range are clipped
    first, so they saturate instead of wrapping around when cast to int16;
    wrap-around would turn a slight overshoot into a full-scale click.
    Multi-dimensional input is flattened in row-major order. Scaled values
    are truncated toward zero, as in the original conversion.

    Args:
        audio_tensor: Waveform tensor on any device; it may require grad.

    Returns:
        PCM bytes, 2 bytes per sample, little-endian as required by WAV.
    """
    # detach() lets tensors that track gradients be converted to numpy;
    # reshape(-1) flattens any rank (including 0-d) to 1-D.
    samples = audio_tensor.detach().cpu().numpy().reshape(-1)
    # Clip before scaling: float -> int16 casts of out-of-range values do not saturate.
    scaled = np.clip(samples, -1.0, 1.0) * 32767
    # Explicit '<i2' so the byte order matches WAV regardless of host endianness.
    return scaled.astype('<i2').tobytes()
# ------------------------------------------------------------------------------ | |
# Endpoints | |
# ------------------------------------------------------------------------------ | |
@app.get("/tts/streaming")
def tts_streaming(text: str):
    """
    Stream synthesized speech for ``text`` as a WAV response.

    The stream is declared as 16-bit mono PCM at 24000 Hz. A WAV header with a
    placeholder data length is sent first, because the total length is not
    known in advance. Then the PCM bytes of each result produced by the shared
    ``pipeline`` are sent as soon as they are generated.

    Args:
        text: Text to synthesize (query parameter).

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``.

    Raises:
        HTTPException: 400 if ``text`` is empty or whitespace-only.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Validate before streaming starts: once the first chunk is sent the status
    # is already 200, and errors can no longer be reported to the client.
    if not text.strip():
        raise HTTPException(status_code=400, detail="Query parameter 'text' must not be empty.")

    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM

    def audio_generator():
        # The header goes out exactly once, before any samples.
        yield generate_wav_header(sample_rate, num_channels, sample_width)
        try:
            for result in pipeline(text):  # Use StreamKPipeline
                if result.audio is None:
                    logger.warning("No audio generated for phoneme")
                    continue
                yield audio_tensor_to_pcm_bytes(result.audio)
        except Exception:
            # The response is already partially sent, so the stream can only be
            # ended early; log with the full traceback so the failure is diagnosable.
            logger.exception("Error processing TTS stream")

    return StreamingResponse(
        audio_generator(),
        media_type="audio/wav",
        headers={"Cache-Control": "no-cache"},
    )
# NOTE: the non-streaming ("full") TTS endpoint has been removed; only streaming synthesis is served.
@app.get("/", response_class=HTMLResponse)
def index():
    """
    Serve the HTML demo page for Kokoro TTS.

    The page has a textarea and a button. Clicking the button points the
    ``<audio>`` element at ``/tts/streaming?text=...`` and starts playback
    of the streamed WAV. Empty input is ignored client-side, and playback
    failures (e.g. autoplay blocked by the browser) are logged to the console
    instead of surfacing as unhandled promise rejections.

    Returns:
        The page markup; ``response_class=HTMLResponse`` makes FastAPI send it
        as ``text/html`` rather than JSON-encoding the string.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<title>Kokoro TTS Demo</title>
</head>
<body>
<h1>Kokoro TTS Demo</h1>
<textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br><br>
<button onclick="playStreaming()">Play Streaming TTS</button>
<br><br>
<audio id="audio" controls autoplay></audio>
<script>
function playStreaming() {
    const text = document.getElementById('text').value;
    if (!text.trim()) {
        return;
    }
    const audio = document.getElementById('audio');
    audio.src = `/tts/streaming?text=${encodeURIComponent(text)}`;
    audio.play().catch((err) => console.error('Playback failed:', err));
}
</script>
</body>
</html>
"""
# ------------------------------------------------------------------------------ | |
# Run with: uvicorn app:app --reload | |
# ------------------------------------------------------------------------------ | |
if __name__ == "__main__":
    # Deferred import: uvicorn is only needed when this file is run directly.
    from uvicorn import run as serve_app

    # The app is passed as an import string ("module:attribute"), which the
    # reload option requires.
    serve_app(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )