import os
import io

import soundfile as sf
import litserve as ls
from fastapi.responses import Response
from kokoro import KPipeline

from audio_utils import combine_audio_files


class KokoroAPI(ls.LitAPI):
    """
    KokoroAPI is a subclass of ls.LitAPI that provides an interface to the
    Kokoro model for text-to-speech tasks.

    Methods:
    - setup(device): Called once at startup for the task-specific setup.
    - decode_request(request): Convert the request payload to model input.
    - predict(inputs): Uses the model to generate audio from the input text.
    - encode_response(output): Convert the model output to a response payload.
    """

    def __init__(self):
        super().__init__()
        self.pipeline = None
        self.current_lang = None

    def setup(self, device):
        self.device = device

    def decode_request(self, request):
        """
        Convert the request payload to model input.
        """
        # Extract the inputs from the request payload
        language_code = request.get("language_code", "a")
        text = request.get("text", "")
        voice = request.get("voice", "af_heart")

        # Initialize or replace the pipeline if the language changes
        if self.current_lang != language_code:
            self.current_lang = language_code
            self.pipeline = KPipeline(lang_code=language_code, device=self.device)

        # Return the inputs
        return text, voice

    def predict(self, inputs):
        """
        Run inference and generate an audio file using the Kokoro model.
        """
        # Get the inputs
        text, voice = inputs

        # Create the output directory if it does not exist (defined before the
        # try block so the finally clause can always reference it)
        output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "output")
        os.makedirs(output_dir, exist_ok=True)

        try:
            # Generate audio segments, splitting the text on blank lines
            generator = self.pipeline(text, voice=voice, speed=1, split_pattern=r"\n+")

            # Save each audio segment to its own file
            file_count = 0
            for i, (gs, ps, audio) in enumerate(generator):
                file_path = os.path.join(output_dir, f"{i}.wav")
                sf.write(file_path, audio, 24000)
                file_count = i + 1  # Keep track of the number of files

            if file_count == 0:
                # Handle the case where no audio was generated
                return None

            # Combine all audio files
            final_audio, samplerate = combine_audio_files(output_dir, file_count - 1)

            # Write the final audio to an in-memory buffer
            audio_buffer = io.BytesIO()
            sf.write(audio_buffer, final_audio, samplerate, format="WAV")
            audio_buffer.seek(0)
            audio_data = audio_buffer.getvalue()
            audio_buffer.close()

            return audio_data
        finally:
            # Clean up the output directory if it exists
            if os.path.exists(output_dir):
                for file in os.listdir(output_dir):
                    file_path = os.path.join(output_dir, file)
                    try:
                        os.remove(file_path)
                    except OSError:
                        pass
                try:
                    os.rmdir(output_dir)
                except OSError:
                    pass

    def encode_response(self, output):
        """
        Convert the model output to a response payload.
        """
        # Package the generated audio data into a response
        return Response(content=output, headers={"Content-Type": "audio/wav"})


if __name__ == "__main__":
    # Create an instance of the KokoroAPI class and run the server
    api = KokoroAPI()
    server = ls.LitServer(api, track_requests=True)
    server.run(port=7860)
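
# Example client usage (a minimal sketch, not part of this module): it assumes the
# server is running locally on port 7860 and that LitServe serves requests on its
# default /predict endpoint; the `requests` package is an extra dependency here.
#
#   import requests
#
#   payload = {"language_code": "a", "text": "Hello from Kokoro!", "voice": "af_heart"}
#   response = requests.post("http://localhost:7860/predict", json=payload)
#   response.raise_for_status()
#
#   # The server responds with raw WAV bytes (Content-Type: audio/wav)
#   with open("speech.wav", "wb") as f:
#       f.write(response.content)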