File size: 7,649 Bytes
d7dfeff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from concurrent import futures
import torch
from models import build_model
import numpy as np
import re
import wave
from kokoro import generate
from openai import OpenAI
from collections import deque
import grpc
import text_to_speech_pb2
import text_to_speech_pb2_grpc
import io
from dotenv import load_dotenv
import os
from chat_database import save_chat_entry, get_chat_history

# Load variables from a local .env file into the environment so that
# OPENAI_API_KEY (read below) can be supplied without exporting it manually.
load_dotenv()

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the Kokoro model once at import time; it is shared by every request.
MODEL = build_model('kokoro-v0_19.pth', device)

# Specify the voice name and load the voice pack.
# The list enumerates the known voice names and the trailing index picks the
# one in use (index 0 -> 'af'). The first character of the chosen name is
# later passed to kokoro's generate() as its `lang` argument.
VOICE_NAME = [
    'af',
    'af_bella', 'af_sarah', 'am_adam', 'am_michael',
    'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
    'af_nicole', 'af_sky',
][0]
# The voice pack is read from voices/<VOICE_NAME>.pt (path relative to the
# working directory) and moved to the same device as the model.
VOICEPACK = torch.load(f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)


# OpenAI client used for chat completions. os.getenv returns None when
# OPENAI_API_KEY is not set, in which case None is passed as the key.
client = OpenAI(
    api_key= os.getenv("OPENAI_API_KEY")
)

def chunk_text(text, max_chars=2040):
    """Split *text* into sentence-aligned chunks of at most *max_chars* characters.

    Sentences are detected as runs of text ending in ``.``, ``!`` or ``?``
    followed by whitespace. They are packed greedily, joined by single
    spaces, and the joining spaces are counted toward the limit so that a
    chunk built from several sentences never exceeds *max_chars*.

    A single sentence that is longer than *max_chars* is not split; it is
    returned as its own (oversized) chunk. Empty pieces produced by the split
    (for example from empty input or trailing whitespace) are skipped.

    Args:
        text: The text to split.
        max_chars: Maximum chunk length in characters.

    Returns:
        The chunks, in their original order. Empty for empty input.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = []
    current_length = 0
    for sentence in sentences:
        if not sentence:
            continue
        # ' '.join() inserts one space before every sentence except the first,
        # so that space has to be part of the length budget.
        separator = 1 if current_chunk else 0
        projected_length = current_length + separator + len(sentence)
        if projected_length <= max_chars:
            current_chunk.append(sentence)
            current_length = projected_length
        else:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_length = len(sentence)
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks

def generate_audio_from_chunks(text, model, voicepack, voice_name):
    """Synthesize *text* chunk by chunk and return the audio as one array.

    The text is split with ``chunk_text`` and each chunk is passed to
    kokoro's ``generate``. Synthesis is best effort: a chunk that fails is
    reported on stdout and skipped, and the remaining chunks are still
    synthesized.

    Args:
        text: The text to speak.
        model: The loaded Kokoro model.
        voicepack: The voice pack matching *voice_name*.
        voice_name: Name of the voice; its first character is passed to
            ``generate`` as the ``lang`` argument.

    Returns:
        The audio of all successfully synthesized chunks, concatenated in
        order. An empty array when nothing could be synthesized.
    """
    pieces = []
    for chunk in chunk_text(text):
        try:
            audio, _ = generate(model, chunk, voicepack, lang=voice_name[0])
        except Exception as e:
            # Keep going with the other chunks, but do not hide the failure.
            print("Error in generate_audio_from_chunks:", e)
            continue
        pieces.append(audio)

    if not pieces:
        return np.array([])
    if len(pieces) == 1:
        return pieces[0]
    # Concatenate once instead of re-copying the growing buffer per chunk.
    return np.concatenate(pieces)

def save_audio_to_file(audio_data, file_number, sample_rate=24000):
    """Write *audio_data* to ``output-<file_number>.wav`` as 16-bit mono PCM.

    Samples are treated as floats in the range [-1.0, 1.0]. Values outside
    that range are clipped before scaling; without clipping they would wrap
    around when cast to 16-bit integers and produce loud artifacts.

    Args:
        audio_data: Array of float samples.
        file_number: Value inserted into the output file name.
        sample_rate: Sample rate written to the WAV header, in Hz.

    Returns:
        The name of the file that was written (relative to the working
        directory).
    """
    filename = f"output-{file_number}.wav"
    # Convert first so that a conversion error does not leave an empty file.
    audio_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
    with wave.open(filename, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_int16.tobytes())
    return filename

def getResponse(text, session_id):
    """Start a streaming chat completion for a session's stored history.

    The messages sent to the model are whatever ``get_chat_history`` returns
    for *session_id*; *text* itself is not used here.

    Args:
        text: The latest user text (unused by this function).
        session_id: Identifier of the chat session whose history is sent.

    Returns:
        The streaming response object from the OpenAI client, or ``None``
        when loading the history or creating the completion raised; the
        error is printed in that case.
    """
    try:
        history = get_chat_history(session_id)
        stream = client.chat.completions.create(
            model='gpt-3.5-turbo',
            messages=history,
            stream=True,
        )
    except Exception as e:
        print("Error in getResponse : ", e)
        return None
    return stream

def get_audio_bytes(audio_data, sample_rate=24000):
    """Encode *audio_data* as an in-memory 16-bit mono PCM WAV file.

    Samples are treated as floats in the range [-1.0, 1.0]. Values outside
    that range are clipped before scaling; without clipping they would wrap
    around when cast to 16-bit integers and produce loud artifacts.

    Args:
        audio_data: Array of float samples (may be empty).
        sample_rate: Sample rate written to the WAV header, in Hz.

    Returns:
        The complete WAV file (header plus frames) as ``bytes``.
    """
    audio_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
    wav_bytes = io.BytesIO()
    with wave.open(wav_bytes, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_int16.tobytes())
    # Closing the wave writer finalizes the header but leaves the buffer open.
    return wav_bytes.getvalue()

def dummy_bytes():
    """Return a fixed placeholder payload as ``bytes``.

    Used where a response needs a buffer value but there is no audio to send.
    """
    placeholder = b"This is a test of dummy byte data."
    return io.BytesIO(placeholder).getvalue()


class TextToSpeechServicer(text_to_speech_pb2_grpc.TextToSpeechServiceServicer):
    """gRPC servicer that answers user text with synthesized speech.

    All per-stream state lives in a ``parameters`` dict that is created in
    ``ProcessText`` and handed to the helper methods; the servicer instance
    itself holds no state.
    """

    def ProcessText(self, request_iterator, context):
        """Consume a stream of requests and yield a stream of responses.

        Each request carries one of the ``request_data`` fields:

        * ``metadata`` - stores the session id used for everything after it.
        * ``text`` - saves the text as a "user" chat entry, acknowledges it,
          asks the chat model for a reply and yields the reply as audio,
          sentence by sentence.
        * ``status`` - records the client's ``interrupt_seq`` (sentences
          numbered at or below it are dropped instead of synthesized) and
          saves the reported transcript as an "assistant" chat entry.

        Requests with any other field are ignored. Any exception is printed
        and ends the response stream.

        Args:
            request_iterator: Iterator over the incoming request messages.
            context: The gRPC servicer context (not used).

        Yields:
            ``ProcessTextResponse`` messages.
        """
        try:
            print("Received new request")
            parameters = {
                "processing_active": False,
                "queue": deque(),
                "file_number": 0,
                "session_id": "",
                "interrupt_seq": 0,
            }
            for request in request_iterator:
                field = request.WhichOneof('request_data')
                if field == 'metadata':
                    parameters["session_id"] = request.metadata.session_id
                elif field == 'text':
                    yield from self._handle_text(request.text, parameters)
                elif field == 'status':
                    parameters["interrupt_seq"] = request.status.interrupt_seq
                    save_chat_entry(parameters["session_id"], "assistant", request.status.transcript)
        except Exception as e:
            print("Error in ProcessText:", e)

    def _handle_text(self, text, parameters):
        """Answer one user utterance: acknowledge it, then stream reply audio."""
        if not text:
            return
        save_chat_entry(parameters["session_id"], "user", text)
        # New user input supersedes any sentences still waiting to be spoken.
        parameters["queue"].clear()
        # Acknowledge the text with placeholder bytes and sequence id "-2".
        yield text_to_speech_pb2.ProcessTextResponse(
            buffer=dummy_bytes(),
            session_id=parameters["session_id"],
            sequence_id="-2",
            transcript=text,
        )

        response = getResponse(text, parameters["session_id"])
        if response is None:
            # getResponse returns None (after printing the error) when the
            # completion request failed. Skip this utterance instead of
            # failing on the iteration below and ending the whole stream.
            return

        pending = ""
        for chunk in response:
            # A streamed chunk can arrive without any choices; skip it rather
            # than fail on choices[0].
            if not chunk.choices:
                continue
            msg = chunk.choices[0].delta.content
            if not msg:
                continue
            pending += msg
            # Hand text to the synthesizer as soon as a sentence is complete.
            if pending.endswith(('.', '!', '?')):
                yield from self._enqueue_sentence(pending, parameters)
                pending = ""

        # Flush whatever is left when the reply did not end in punctuation.
        if pending:
            yield from self._enqueue_sentence(pending, parameters)

    def _enqueue_sentence(self, sentence, parameters):
        """Queue *sentence* under the next sequence number and drain the queue."""
        parameters["file_number"] += 1
        parameters["queue"].append((sentence, parameters["file_number"]))
        if not parameters["processing_active"]:
            yield from self.process_queue(parameters)

    def process_queue(self, parameters):
        """Synthesize every queued sentence and yield one response per sentence.

        Sentences whose sequence number is at or below the recorded
        ``interrupt_seq`` are dropped without being synthesized.
        ``processing_active`` is set while the queue is being drained and
        cleared when it is empty or an error occurred. Errors are printed
        and end the draining.

        Args:
            parameters: The per-stream state dict created in ``ProcessText``.

        Yields:
            ``ProcessTextResponse`` messages carrying WAV bytes, the sequence
            number as a string and the sentence as transcript.
        """
        try:
            while parameters["queue"]:
                parameters["processing_active"] = True
                sentence, file_number = parameters["queue"].popleft()
                if file_number <= int(parameters["interrupt_seq"]):
                    continue
                combined_audio = generate_audio_from_chunks(sentence, MODEL, VOICEPACK, VOICE_NAME)
                audio_bytes = get_audio_bytes(combined_audio)
                yield text_to_speech_pb2.ProcessTextResponse(
                    buffer=audio_bytes,
                    session_id=parameters["session_id"],
                    sequence_id=str(file_number),
                    transcript=sentence,
                )
            parameters["processing_active"] = False
        except Exception as e:
            parameters["processing_active"] = False
            print("Error in process_queue:", e)


def serve(port=8081, max_workers=1):
    """Start the gRPC text-to-speech server and block until it terminates.

    The server listens without transport security on all interfaces
    (``[::]``).

    Args:
        port: TCP port to listen on.
        max_workers: Size of the thread pool handed to ``grpc.server``.

    Raises:
        RuntimeError: If the port could not be bound.
    """
    print("Starting gRPC server...")
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    text_to_speech_pb2_grpc.add_TextToSpeechServiceServicer_to_server(TextToSpeechServicer(), server)
    bound_port = server.add_insecure_port(f'[::]:{port}')
    # add_insecure_port reports a failed bind by returning 0 on gRPC versions
    # that do not raise; fail loudly instead of serving on no port at all.
    if bound_port == 0:
        raise RuntimeError(f"Could not bind gRPC server to port {port}")
    server.start()
    print(f"gRPC server is running on port {bound_port}")
    server.wait_for_termination()


if __name__ == "__main__":
    serve()