# VocRT / app.py

import io
import os
import re
import wave
from collections import deque
from concurrent import futures

import grpc
import numpy as np
import torch
from dotenv import load_dotenv
from openai import OpenAI

import text_to_speech_pb2
import text_to_speech_pb2_grpc
from chat_database import save_chat_entry, get_chat_history
from kokoro import generate
from models import build_model

load_dotenv()

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the Kokoro model
MODEL = build_model('kokoro-v0_19.pth', device)

# Specify the voice name and load the voice pack
VOICE_NAME = [
    'af',
    'af_bella', 'af_sarah', 'am_adam', 'am_michael',
    'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
    'af_nicole', 'af_sky',
][0]  # index 0 selects the default voice 'af' (American English, female)
VOICEPACK = torch.load(f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
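# NOTE: assumes OPENAI_API_KEY is set in the environment (e.g. via the .env
# file loaded by load_dotenv() above); the client fails on first use otherwise.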

def chunk_text(text, max_chars=2040):
    """Split text into sentence-aligned chunks of at most max_chars characters."""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = []
    current_length = 0
    for sentence in sentences:
        sentence_length = len(sentence)
        if current_length + sentence_length <= max_chars:
            current_chunk.append(sentence)
            current_length += sentence_length
        else:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            # Start a new chunk; a single over-long sentence becomes its own chunk.
            current_chunk = [sentence]
            current_length = sentence_length
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks
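
# Example (a sketch of the intended behavior, not part of the server):
#     chunk_text("One. Two. Three.", max_chars=10)
#     -> ["One. Two.", "Three."]
# Sentences are packed greedily; a single sentence longer than max_chars
# still becomes its own oversized chunk.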

def generate_audio_from_chunks(text, model, voicepack, voice_name):
    """Synthesize audio for each chunk of text and concatenate the results."""
    chunks = chunk_text(text)
    combined_audio = np.array([])
    for chunk in chunks:
        try:
            # Kokoro's generate() returns (audio, phonemes); the lang code is
            # the first letter of the voice name ('a' = American, 'b' = British).
            audio, _ = generate(model, chunk, voicepack, lang=voice_name[0])
            combined_audio = np.concatenate([combined_audio, audio]) if combined_audio.size > 0 else audio
        except Exception as e:
            # Skip chunks that fail to synthesize rather than aborting the stream.
            print("Error generating audio for chunk:", e)
    return combined_audio

def save_audio_to_file(audio_data, file_number, sample_rate=24000):
    """Write float audio in [-1, 1] to a 16-bit mono WAV file."""
    filename = f"output-{file_number}.wav"
    with wave.open(filename, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        # Clip before scaling so out-of-range samples don't wrap around in int16.
        audio_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
        wav_file.writeframes(audio_int16.tobytes())
    return filename
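
# NOTE: save_audio_to_file is a debugging helper; its only call site is
# commented out in process_queue below.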

def getResponse(text, session_id):
    """Stream a chat completion for the session's accumulated history."""
    try:
        # `text` has already been saved to the history by the caller, so the
        # fetched history includes the latest user message.
        chat_history = get_chat_history(session_id)
        response = client.chat.completions.create(
            model='gpt-3.5-turbo',
            messages=chat_history,
            stream=True,
        )
        return response
    except Exception as e:
        print("Error in getResponse:", e)
        return []  # Return an empty iterable so callers can still loop safely.

def get_audio_bytes(audio_data, sample_rate=24000):
    """Encode float audio in [-1, 1] as in-memory 16-bit mono WAV bytes."""
    wav_bytes = io.BytesIO()
    with wave.open(wav_bytes, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        # Clip before scaling so out-of-range samples don't wrap around in int16.
        audio_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
        wav_file.writeframes(audio_int16.tobytes())
    wav_bytes.seek(0)
    return wav_bytes.read()

def dummy_bytes():
    """Return a small placeholder byte payload."""
    buffer = io.BytesIO()
    buffer.write(b"This is a test of dummy byte data.")
    buffer.seek(0)
    return buffer.getvalue()
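
# dummy_bytes() is only used as a placeholder buffer: ProcessText streams it
# with sequence_id "-2" so the client can recognize the message as an
# interruption marker carrying the user's transcript rather than audio.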

class TextToSpeechServicer(text_to_speech_pb2_grpc.TextToSpeechServiceServicer):
    def ProcessText(self, request_iterator, context):
        try:
            print("Received new request")
            # Per-call state shared with process_queue.
            parameters = {
                "processing_active": False,
                "queue": deque(),      # (sentence, sequence number) pairs awaiting TTS
                "file_number": 0,      # monotonically increasing sequence number
                "session_id": "",
                "interrupt_seq": 0,    # last sequence the client played before interrupting
            }
            for request in request_iterator:
                field = request.WhichOneof('request_data')
                if field == 'metadata':
                    parameters["session_id"] = request.metadata.session_id
                    continue
                elif field == 'text':
                    text = request.text
                    if not text:
                        continue
                    save_chat_entry(parameters["session_id"], "user", text)
                    # New user input interrupts playback: drop queued sentences.
                    parameters["queue"].clear()
                    # sequence_id "-2" marks an interruption notice carrying the transcript.
                    yield text_to_speech_pb2.ProcessTextResponse(
                        buffer=dummy_bytes(),
                        session_id=parameters["session_id"],
                        sequence_id="-2",
                        transcript=text,
                    )
                    final_response = ""
                    response = getResponse(text, parameters["session_id"])
                    for chunk in response:
                        msg = chunk.choices[0].delta.content
                        if msg:
                            final_response += msg
                            # Flush each completed sentence to the TTS queue.
                            if final_response.endswith(('.', '!', '?')):
                                parameters["file_number"] += 1
                                parameters["queue"].append((final_response, parameters["file_number"]))
                                final_response = ""
                                if not parameters["processing_active"]:
                                    yield from self.process_queue(parameters)
                    # Flush any trailing text that did not end with punctuation.
                    if final_response:
                        parameters["file_number"] += 1
                        parameters["queue"].append((final_response, parameters["file_number"]))
                        if not parameters["processing_active"]:
                            yield from self.process_queue(parameters)
                elif field == 'status':
                    # The client reports what it played and where it was interrupted.
                    transcript = request.status.transcript
                    played_seq = request.status.played_seq  # currently unused
                    interrupt_seq = request.status.interrupt_seq
                    parameters["interrupt_seq"] = interrupt_seq
                    save_chat_entry(parameters["session_id"], "assistant", transcript)
                    continue
                else:
                    continue
        except Exception as e:
            print("Error in ProcessText:", e)

    def process_queue(self, parameters):
        try:
            while True:
                if not parameters["queue"]:
                    parameters["processing_active"] = False
                    break
                parameters["processing_active"] = True
                sentence, file_number = parameters["queue"].popleft()
                # Skip sentences the client has already interrupted past.
                if file_number <= int(parameters["interrupt_seq"]):
                    continue
                combined_audio = generate_audio_from_chunks(sentence, MODEL, VOICEPACK, VOICE_NAME)
                audio_bytes = get_audio_bytes(combined_audio)
                # filename = save_audio_to_file(combined_audio, file_number)  # debug output
                yield text_to_speech_pb2.ProcessTextResponse(
                    buffer=audio_bytes,
                    session_id=parameters["session_id"],
                    sequence_id=str(file_number),
                    transcript=sentence,
                )
        except Exception as e:
            parameters["processing_active"] = False
            print("Error in process_queue:", e)

def serve():
    print("Starting gRPC server...")
    # A single worker thread serves at most one streaming call at a time.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    text_to_speech_pb2_grpc.add_TextToSpeechServiceServicer_to_server(TextToSpeechServicer(), server)
    server.add_insecure_port('[::]:8081')
    server.start()
    print("gRPC server is running on port 8081")
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
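
# ---------------------------------------------------------------------------
# Example client (a minimal sketch, not part of the server). The request
# message and field names below are assumptions inferred from ProcessText
# above -- check text_to_speech.proto for the actual definitions.
#
#     import grpc
#     import text_to_speech_pb2
#     import text_to_speech_pb2_grpc
#
#     def run_client():
#         channel = grpc.insecure_channel("localhost:8081")
#         stub = text_to_speech_pb2_grpc.TextToSpeechServiceStub(channel)
#
#         def requests():
#             # First message identifies the session; later ones carry text.
#             yield text_to_speech_pb2.ProcessTextRequest(
#                 metadata=text_to_speech_pb2.Metadata(session_id="demo"))
#             yield text_to_speech_pb2.ProcessTextRequest(text="Hello there!")
#
#         for response in stub.ProcessText(requests()):
#             # sequence_id "-2" is the interruption marker; others are audio.
#             print(response.sequence_id, len(response.buffer), "bytes")
#
#     run_client()
# ---------------------------------------------------------------------------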