# Kokoro text-to-speech server (Hugging Face Spaces app).
import io
import os
import shutil

import litserve as ls
import soundfile as sf
from fastapi.responses import Response
from kokoro import KPipeline

from audio_utils import combine_audio_files
class KokoroAPI(ls.LitAPI):
    """
    LitServe API wrapping the Kokoro model for text-to-speech.

    Methods:
    - setup(device): Called once at startup for the task-specific setup.
    - decode_request(request): Convert the request payload to model input.
    - predict(inputs): Uses the model to generate audio from the input text.
    - encode_response(output): Convert the model output to a response payload.
    """

    def __init__(self):
        super().__init__()
        # The pipeline is built lazily in decode_request because it depends
        # on the per-request language code.
        self.pipeline = None
        self.current_lang = None

    def setup(self, device):
        """Record the target device; the pipeline itself is built lazily."""
        self.device = device

    def decode_request(self, request):
        """
        Convert the request payload to model input.

        Expected keys (all optional): "language_code" (default "a"),
        "text" (default ""), "voice" (default "af_heart").
        Returns a (text, voice) tuple consumed by predict().
        """
        language_code = request.get("language_code", "a")
        text = request.get("text", "")
        voice = request.get("voice", "af_heart")
        # Rebuild the pipeline only when the requested language changes.
        # NOTE(review): the shared pipeline is swapped in place — not safe if
        # concurrent requests use different languages; confirm server workers.
        if self.current_lang != language_code:
            self.current_lang = language_code
            self.pipeline = KPipeline(lang_code=language_code, device=self.device)
        return text, voice

    def predict(self, inputs):
        """
        Run inference and generate audio using the Kokoro model.

        Returns the combined audio as WAV bytes, or None when the model
        produced no audio segments (e.g. empty text).
        """
        text, voice = inputs
        # Compute the scratch directory BEFORE the try block so the cleanup
        # in `finally` never hits a NameError if the pipeline call raises.
        output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'output')
        try:
            # Kokoro yields one audio segment per chunk (split on newlines).
            generator = self.pipeline(text, voice=voice, speed=1, split_pattern=r"\n+")
            os.makedirs(output_dir, exist_ok=True)
            # Write each segment to disk at Kokoro's 24 kHz sample rate.
            file_count = 0
            for i, (gs, ps, audio) in enumerate(generator):
                sf.write(f"{output_dir}/{i}.wav", audio, 24000)
                file_count = i + 1  # Keep track of number of files
            if file_count == 0:
                # Handle case where no audio was generated
                return None
            # Stitch the per-segment files into a single waveform.
            final_audio, samplerate = combine_audio_files(output_dir, file_count - 1)
            # Serialize the combined audio to an in-memory WAV file.
            with io.BytesIO() as audio_buffer:
                sf.write(audio_buffer, final_audio, samplerate, format="WAV")
                return audio_buffer.getvalue()
        finally:
            # Best-effort removal of the scratch directory and its contents;
            # replaces the old per-file loop with bare `except:` clauses.
            shutil.rmtree(output_dir, ignore_errors=True)

    def encode_response(self, output):
        """Package the generated WAV bytes into an HTTP response."""
        return Response(content=output, headers={"Content-Type": "audio/wav"})
if __name__ == "__main__":
    # Stand up the LitServe server hosting the Kokoro TTS API.
    server = ls.LitServer(KokoroAPI(), track_requests=True)
    server.run(port=7860)