from concurrent import futures

import torch
from models import build_model
import numpy as np
import re
import wave
from kokoro import generate
from openai import OpenAI
from collections import deque
import grpc
import text_to_speech_pb2
import text_to_speech_pb2_grpc
import io
from dotenv import load_dotenv
import os
from chat_database import save_chat_entry, get_chat_history

load_dotenv()

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the Kokoro model
MODEL = build_model('kokoro-v0_19.pth', device)

# Specify the voice name and load the voice pack
VOICE_NAME = [
    'af', 'af_bella', 'af_sarah', 'am_adam', 'am_michael',
    'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
    'af_nicole', 'af_sky',
][0]
VOICEPACK = torch.load(f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def chunk_text(text, max_chars=2040):
    """Split text into chunks of whole sentences, each at most max_chars long."""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = []
    current_length = 0
    for sentence in sentences:
        sentence_length = len(sentence)
        if current_length + sentence_length <= max_chars:
            current_chunk.append(sentence)
            current_length += sentence_length
        else:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_length = sentence_length
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks


def generate_audio_from_chunks(text, model, voicepack, voice_name):
    """Run Kokoro over each chunk and concatenate the resulting audio."""
    chunks = chunk_text(text)
    combined_audio = np.array([])
    for chunk in chunks:
        try:
            # Kokoro takes a one-letter language code, which is the
            # first character of the voice name ('a' or 'b').
            audio, _ = generate(model, chunk, voicepack, lang=voice_name[0])
            combined_audio = np.concatenate([combined_audio, audio]) if combined_audio.size > 0 else audio
        except Exception as e:
            # Log and skip a failed chunk instead of silently swallowing it.
            print("Error generating audio for chunk:", e)
    return combined_audio


def save_audio_to_file(audio_data, file_number, sample_rate=24000):
    """Write float audio in [-1, 1] to a 16-bit mono WAV file."""
    filename = f"output-{file_number}.wav"
    with wave.open(filename, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        audio_int16 = (audio_data * 32767).astype(np.int16)
        wav_file.writeframes(audio_int16.tobytes())
    return filename


def getResponse(text, session_id):
    """Stream a chat completion over the session's history."""
    try:
        chat_history = get_chat_history(session_id)
        response = client.chat.completions.create(
            model='gpt-3.5-turbo',
            messages=chat_history,
            stream=True
        )
        return response
    except Exception as e:
        print("Error in getResponse:", e)
        # Return an empty iterator so the caller's `for chunk in response`
        # loop does not crash on None.
        return iter(())


def get_audio_bytes(audio_data, sample_rate=24000):
    """Encode float audio as an in-memory 16-bit mono WAV and return its bytes."""
    wav_bytes = io.BytesIO()
    with wave.open(wav_bytes, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        audio_int16 = (audio_data * 32767).astype(np.int16)
        wav_file.writeframes(audio_int16.tobytes())
    wav_bytes.seek(0)
    return wav_bytes.read()


def dummy_bytes():
    """Return placeholder bytes used to signal the start of a new utterance."""
    buffer = io.BytesIO()
    dummy_data = b"This is a test of dummy byte data."
    buffer.write(dummy_data)
    buffer.seek(0)
    byte_value = buffer.getvalue()
    return byte_value


class TextToSpeechServicer(text_to_speech_pb2_grpc.TextToSpeechServiceServicer):
    def ProcessText(self, request_iterator, context):
        try:
            print("Received new request")
            # Per-stream state: the queue holds (sentence, sequence_number)
            # pairs waiting to be synthesized.
            parameters = {
                "processing_active": False,
                "queue": deque(),
                "file_number": 0,
                "session_id": "",
                "interrupt_seq": 0,
            }
            for request in request_iterator:
                field = request.WhichOneof('request_data')
                if field == 'metadata':
                    parameters["session_id"] = request.metadata.session_id
                    continue
                elif field == 'text':
                    text = request.text
                    if not text:
                        continue
                    save_chat_entry(parameters["session_id"], "user", text)
                    # New user input invalidates any speech still queued.
                    parameters["queue"].clear()
                    # sequence_id "-2" signals the client that a new
                    # utterance is starting.
                    yield text_to_speech_pb2.ProcessTextResponse(
                        buffer=dummy_bytes(),
                        session_id=parameters["session_id"],
                        sequence_id="-2",
                        transcript=text,
                    )
                    final_response = ""
                    response = getResponse(text, parameters["session_id"])
                    for chunk in response:
                        msg = chunk.choices[0].delta.content
                        if msg:
                            final_response += msg
                            # Flush each completed sentence to the TTS queue.
                            if final_response.endswith(('.', '!', '?')):
                                parameters["file_number"] += 1
                                parameters["queue"].append((final_response, parameters["file_number"]))
                                final_response = ""
                                if not parameters["processing_active"]:
                                    yield from self.process_queue(parameters)
                    # Flush any trailing text that did not end with punctuation.
                    if final_response:
                        parameters["file_number"] += 1
                        parameters["queue"].append((final_response, parameters["file_number"]))
                    if not parameters["processing_active"]:
                        yield from self.process_queue(parameters)
                elif field == 'status':
                    transcript = request.status.transcript
                    played_seq = request.status.played_seq
                    interrupt_seq = request.status.interrupt_seq
                    parameters["interrupt_seq"] = interrupt_seq
                    save_chat_entry(parameters["session_id"], "assistant", transcript)
                    continue
                else:
                    continue
        except Exception as e:
            print("Error in ProcessText:", e)

    def process_queue(self, parameters):
        """Drain the sentence queue, yielding one audio response per sentence."""
        try:
            while True:
                if not parameters["queue"]:
                    parameters["processing_active"] = False
                    break
                parameters["processing_active"] = True
                sentence, file_number = parameters["queue"].popleft()
                # Skip sentences the client has already interrupted past.
                if file_number <= int(parameters["interrupt_seq"]):
                    continue
                combined_audio = generate_audio_from_chunks(sentence, MODEL, VOICEPACK, VOICE_NAME)
                audio_bytes = get_audio_bytes(combined_audio)
                # filename = save_audio_to_file(combined_audio, file_number)
                yield text_to_speech_pb2.ProcessTextResponse(
                    buffer=audio_bytes,
                    session_id=parameters["session_id"],
                    sequence_id=str(file_number),
                    transcript=sentence,
                )
        except Exception as e:
            parameters["processing_active"] = False
            print("Error in process_queue:", e)


def serve():
    print("Starting gRPC server...")
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    text_to_speech_pb2_grpc.add_TextToSpeechServiceServicer_to_server(TextToSpeechServicer(), server)
    server.add_insecure_port('[::]:8081')
    server.start()
    print("gRPC server is running on port 8081")
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
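
# ---------------------------------------------------------------------------
# Example client (sketch, kept as comments so it never runs with the server).
# The stub class name TextToSpeechServiceStub follows from the service name
# registered above; the request message name ProcessTextRequest and its
# Metadata sub-message are assumptions inferred from how this servicer reads
# the stream, not confirmed by the .proto file.
#
#   import grpc
#   import text_to_speech_pb2
#   import text_to_speech_pb2_grpc
#
#   def requests():
#       # First message carries session metadata; later ones carry text.
#       yield text_to_speech_pb2.ProcessTextRequest(
#           metadata=text_to_speech_pb2.Metadata(session_id="demo-session"))
#       yield text_to_speech_pb2.ProcessTextRequest(text="Hello there!")
#
#   channel = grpc.insecure_channel("localhost:8081")
#   stub = text_to_speech_pb2_grpc.TextToSpeechServiceStub(channel)
#   for response in stub.ProcessText(requests()):
#       # sequence_id "-2" marks the start of an utterance; positive ids
#       # carry WAV audio for one sentence each.
#       print(response.sequence_id, len(response.buffer), "bytes")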