tts-kokoro / app /server.py
nam pham
fix: change local variable
6394a29
import io
import os
import shutil
import tempfile

import litserve as ls
import soundfile as sf
from fastapi.responses import Response
from kokoro import KPipeline

from audio_utils import combine_audio_files
class KokoroAPI(ls.LitAPI):
    """
    KokoroAPI is a subclass of ls.LitAPI that provides an interface to the
    Kokoro model for the text-to-speech task.

    Methods:
    - setup(device): Called once at startup; remembers the target device.
    - decode_request(request): Convert the request payload to model input.
    - predict(inputs): Use the model to generate WAV audio from the input text.
    - encode_response(output): Convert the model output to a response payload.
    """

    # Sample rate (Hz) used when writing each generated audio segment to disk.
    SAMPLE_RATE = 24000

    def __init__(self):
        super().__init__()
        # The pipeline is built lazily in decode_request, because the language
        # is only known once a request arrives.
        self.pipeline = None
        self.current_lang = None

    def setup(self, device):
        """
        Store the device so pipelines created later are placed on it.
        """
        self.device = device

    def decode_request(self, request):
        """
        Convert the request payload to model input.

        Reads "language_code" (default "a"), "text" (default "") and "voice"
        (default "af_heart") from the payload, and (re)builds the pipeline
        when none exists yet or the requested language differs from the
        current one.

        Returns:
            A (text, voice) tuple consumed by predict().
        """
        # Extract the inputs from request payload
        language_code = request.get("language_code", "a")
        text = request.get("text", "")
        voice = request.get("voice", "af_heart")

        # Initialize or update pipeline if language changes. The explicit
        # None check covers a payload whose language_code is null, which
        # would otherwise compare equal to the initial current_lang and
        # leave the pipeline unset.
        if self.pipeline is None or self.current_lang != language_code:
            self.pipeline = KPipeline(lang_code=language_code, device=self.device)
            # Record the language only after the pipeline was built, so a
            # failed construction is retried on the next request instead of
            # silently reusing the previous pipeline.
            self.current_lang = language_code

        # Return the inputs
        return text, voice

    def predict(self, inputs):
        """
        Run inference and generate audio using the Kokoro model.

        Each generated segment is written as "<index>.wav" into a scratch
        directory private to this call, the segments are merged with
        combine_audio_files, and the merged audio is encoded as WAV.

        Args:
            inputs: The (text, voice) tuple produced by decode_request().

        Returns:
            The WAV file content as bytes, or None when the pipeline yielded
            no audio segments.
        """
        # Get the inputs
        text, voice = inputs

        # A per-call scratch directory keeps overlapping requests from
        # overwriting or deleting each other's segment files. It is created
        # before the try block so the cleanup below can always refer to it.
        work_dir = tempfile.mkdtemp(prefix="kokoro_")
        try:
            # Generate audio segments
            generator = self.pipeline(text, voice=voice, speed=1, split_pattern=r"\n+")

            # Save each audio segment
            file_count = 0
            for i, (_, _, audio) in enumerate(generator):
                sf.write(os.path.join(work_dir, f"{i}.wav"), audio, self.SAMPLE_RATE)
                file_count = i + 1  # Keep track of number of files

            if file_count == 0:
                # Handle case where no audio was generated
                return None

            # Combine all audio files. The helper receives file_count - 1,
            # presumably the index of the last segment -- confirm against
            # audio_utils.combine_audio_files.
            final_audio, samplerate = combine_audio_files(work_dir, file_count - 1)

            # Encode the final audio into an in-memory WAV file
            with io.BytesIO() as audio_buffer:
                sf.write(audio_buffer, final_audio, samplerate, format="WAV")
                return audio_buffer.getvalue()
        finally:
            # Best-effort cleanup: a failure to delete scratch files must not
            # mask the result or the original exception.
            shutil.rmtree(work_dir, ignore_errors=True)

    def encode_response(self, output):
        """
        Convert the model output to a response payload.

        Args:
            output: The value returned by predict() (WAV bytes, or None when
                no audio was generated).

        Returns:
            A Response carrying the output with Content-Type "audio/wav".
        """
        # Package the generated audio data into a response
        return Response(content=output, headers={"Content-Type": "audio/wav"})
if __name__ == "__main__":
    # Script entry point: wrap a KokoroAPI instance in a LitServer with
    # request tracking enabled and serve it on port 7860.
    ls.LitServer(KokoroAPI(), track_requests=True).run(port=7860)