|
"""gRPC text-to-speech server.

Streams chat completions from OpenAI, synthesizes each finished sentence
with the Kokoro TTS model, and streams the resulting WAV audio back to
the client over a bidirectional gRPC stream, persisting the conversation
via chat_database.
"""

import io
import os
import re
import wave
from collections import deque
from concurrent import futures

import grpc
import numpy as np
import torch
from dotenv import load_dotenv
from openai import OpenAI

import text_to_speech_pb2
import text_to_speech_pb2_grpc
from chat_database import save_chat_entry, get_chat_history
from kokoro import generate
from models import build_model

# Load environment variables (expects OPENAI_API_KEY in a .env file).
load_dotenv()

# Run inference on the GPU when available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the Kokoro v0.19 TTS model from its checkpoint.
MODEL = build_model('kokoro-v0_19.pth', device)

# Available Kokoro voices; index 0 picks the default voice 'af'. The first
# letter of the voice name encodes the accent ('a' = American English,
# 'b' = British English) and is what generate() expects as its lang argument.
VOICE_NAME = [
    'af',
    'af_bella', 'af_sarah', 'am_adam', 'am_michael',
    'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
    'af_nicole', 'af_sky',
][0]
VOICEPACK = torch.load(f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
|
def chunk_text(text, max_chars=2040):
    """Split text into chunks of at most max_chars, on sentence boundaries.

    Sentences are never split internally, so a single sentence longer than
    max_chars becomes its own oversized chunk.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = []
    current_length = 0
    for sentence in sentences:
        sentence_length = len(sentence)
        if current_length + sentence_length <= max_chars:
            current_chunk.append(sentence)
            current_length += sentence_length
        else:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_length = sentence_length
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks
|
|
def generate_audio_from_chunks(text, model, voicepack, voice_name):
    """Synthesize text chunk by chunk and concatenate the resulting audio."""
    chunks = chunk_text(text)
    combined_audio = np.array([])
    for chunk in chunks:
        try:
            audio, _ = generate(model, chunk, voicepack, lang=voice_name[0])
            combined_audio = np.concatenate([combined_audio, audio]) if combined_audio.size > 0 else audio
        except Exception as e:
            # Skip a chunk that fails to synthesize, but don't fail silently.
            print("Error generating audio for chunk:", e)
    return combined_audio
|
|
def save_audio_to_file(audio_data, file_number, sample_rate=24000):
    """Write float audio in [-1, 1] to a 16-bit mono WAV file on disk."""
    filename = f"output-{file_number}.wav"
    with wave.open(filename, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(sample_rate)
        # Clip before scaling so out-of-range samples can't wrap around.
        audio_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
        wav_file.writeframes(audio_int16.tobytes())
    return filename
|
|
def get_response(text, session_id):
    """Stream a chat completion for the session.

    get_chat_history is expected to return OpenAI-style message dicts
    ([{"role": ..., "content": ...}, ...]) that already include `text`,
    which is saved as the latest user message before this is called.
    Returns None on failure.
    """
    try:
        chat_history = get_chat_history(session_id)
        response = client.chat.completions.create(
            model='gpt-3.5-turbo',
            messages=chat_history,
            stream=True
        )
        return response
    except Exception as e:
        print("Error in get_response:", e)
        return None
|
|
def get_audio_bytes(audio_data, sample_rate=24000):
    """Encode float audio in [-1, 1] as in-memory 16-bit mono WAV bytes."""
    wav_bytes = io.BytesIO()
    with wave.open(wav_bytes, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(sample_rate)
        # Clip before scaling so out-of-range samples can't wrap around.
        audio_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
        wav_file.writeframes(audio_int16.tobytes())
    wav_bytes.seek(0)
    return wav_bytes.read()
|
|
def dummy_bytes():
    """Placeholder payload for the sequence_id "-2" response, which echoes
    the user's transcript before any real audio has been generated."""
    return b"This is a test of dummy byte data."
|
|
class TextToSpeechServicer(text_to_speech_pb2_grpc.TextToSpeechServiceServicer): |
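    """Bidirectional streaming text-to-speech service.

    The text_to_speech_pb2 stubs are assumed (reconstructed from how they
    are used below, not taken from the actual .proto file) to correspond
    to messages roughly like:

        message ProcessTextRequest {
          oneof request_data {
            Metadata metadata = 1;  // session_id
            string text = 2;        // user utterance
            Status status = 3;      // transcript, played_seq, interrupt_seq
          }
        }

        message ProcessTextResponse {
          bytes buffer = 1;       // WAV-encoded audio
          string session_id = 2;
          string sequence_id = 3; // "-2" marks a placeholder frame
          string transcript = 4;
        }
    """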
|
    def ProcessText(self, request_iterator, context):
        try:
            print("Received new request")
            # Per-stream state, shared with process_queue().
            parameters = {
                "processing_active": False,  # True while process_queue is draining
                "queue": deque(),            # (sentence, sequence number) pairs awaiting TTS
                "file_number": 0,            # monotonically increasing sequence number
                "session_id": "",
                "interrupt_seq": 0,          # last sequence the client interrupted at
            }
            for request in request_iterator:
                field = request.WhichOneof('request_data')
                if field == 'metadata':
                    parameters["session_id"] = request.metadata.session_id
                    continue
                elif field == 'text':
                    text = request.text
                    if not text:
                        continue
                    save_chat_entry(parameters["session_id"], "user", text)
                    # New user input supersedes anything still queued.
                    parameters["queue"].clear()
                    # Echo the transcript immediately; sequence_id "-2" marks
                    # a placeholder frame that carries no real audio.
                    yield text_to_speech_pb2.ProcessTextResponse(
                        buffer=dummy_bytes(),
                        session_id=parameters["session_id"],
                        sequence_id="-2",
                        transcript=text,
                    )
                    response = get_response(text, parameters["session_id"])
                    if response is None:
                        continue
                    # Accumulate streamed tokens; whenever the buffer ends in
                    # terminal punctuation, queue the sentence for synthesis.
                    final_response = ""
                    for chunk in response:
                        msg = chunk.choices[0].delta.content
                        if msg:
                            final_response += msg
                            if final_response.endswith(('.', '!', '?')):
                                parameters["file_number"] += 1
                                parameters["queue"].append((final_response, parameters["file_number"]))
                                final_response = ""
                                if not parameters["processing_active"]:
                                    yield from self.process_queue(parameters)
                    # Flush any trailing text that never reached end punctuation.
                    if final_response:
                        parameters["file_number"] += 1
                        parameters["queue"].append((final_response, parameters["file_number"]))
                        if not parameters["processing_active"]:
                            yield from self.process_queue(parameters)
                elif field == 'status':
                    # The client reports what it actually played; persist that
                    # as the assistant's turn and record the interruption point.
                    transcript = request.status.transcript
                    played_seq = request.status.played_seq  # reported but unused here
                    interrupt_seq = request.status.interrupt_seq
                    parameters["interrupt_seq"] = interrupt_seq
                    save_chat_entry(parameters["session_id"], "assistant", transcript)
                    continue
                else:
                    continue
        except Exception as e:
            print("Error in ProcessText:", e)

    def process_queue(self, parameters):
        try:
            while True:
                if not parameters["queue"]:
                    parameters["processing_active"] = False
                    break
                parameters["processing_active"] = True
                sentence, file_number = parameters["queue"].popleft()
                # Drop sentences at or before the client's interruption point.
                if file_number <= int(parameters["interrupt_seq"]):
                    continue
                combined_audio = generate_audio_from_chunks(sentence, MODEL, VOICEPACK, VOICE_NAME)
                audio_bytes = get_audio_bytes(combined_audio)
                yield text_to_speech_pb2.ProcessTextResponse(
                    buffer=audio_bytes,
                    session_id=parameters["session_id"],
                    sequence_id=str(file_number),
                    transcript=sentence,
                )
        except Exception as e:
            parameters["processing_active"] = False
            print("Error in process_queue:", e)
|
|
def serve():
    print("Starting gRPC server...")
    # A single worker thread serves one client stream at a time; raise
    # max_workers to handle concurrent streams.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    text_to_speech_pb2_grpc.add_TextToSpeechServiceServicer_to_server(TextToSpeechServicer(), server)
    server.add_insecure_port('[::]:8081')
    server.start()
    print("gRPC server is running on port 8081")
    server.wait_for_termination()
|
|
if __name__ == "__main__":
    serve()
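
# Setup notes (a sketch under assumptions, not authoritative):
# - If the .proto changes, regenerate the stubs with grpcio-tools, e.g.
#   (assuming the file is named text_to_speech.proto):
#       python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. text_to_speech.proto
# - Put OPENAI_API_KEY=... in a .env file next to this script.
# - The Kokoro checkpoint (kokoro-v0_19.pth) and the voices/ directory are
#   expected in the working directory; then run this script with Python.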
|
|