File size: 3,633 Bytes
87337b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import asyncio
from dataclasses import dataclass

from websocket import WebSocketConnectionClosedException

from ten.async_ten_env import AsyncTenEnv
from ten_ai_base.config import BaseConfig
import dashscope
from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback


@dataclass
class CosyTTSConfig(BaseConfig):
    """Settings for the DashScope CosyVoice text-to-speech client."""

    # DashScope API key; CosyTTS copies it onto the SDK's module-level key.
    api_key: str = ""
    # Voice identifier handed to SpeechSynthesizer.
    voice: str = "longxiaochun"
    # CosyVoice model name handed to SpeechSynthesizer.
    model: str = "cosyvoice-v1"
    # PCM sample rate in Hz; the default matches the 16 kHz format the
    # synthesizer requests.
    sample_rate: int = 16_000


class AsyncIteratorCallback(ResultCallback):
    """Forward DashScope synthesis events to TEN logs and audio to an asyncio queue.

    The SDK may invoke these methods off the event-loop thread, which is why
    audio chunks are handed to ``self.loop`` with
    ``asyncio.run_coroutine_threadsafe`` instead of touching ``queue`` directly.
    """

    def __init__(self, ten_env: AsyncTenEnv, queue: asyncio.Queue) -> None:
        self.closed = False  # flipped by close(), which on_close() calls
        self.ten_env = ten_env
        self.loop = self._current_loop()
        self.queue = queue

    @staticmethod
    def _current_loop() -> asyncio.AbstractEventLoop:
        """Return the event loop that audio chunks should be delivered to.

        Prefers the running loop. ``asyncio.get_event_loop()`` is deprecated
        when no loop is running on newer Python versions, so it is used only
        as a fallback to keep the previous behaviour in that case.
        """
        try:
            return asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.get_event_loop()

    def close(self):
        """Mark the connection closed; later audio chunks are dropped."""
        self.closed = True

    def on_open(self):
        """Log that the synthesizer's websocket opened."""
        self.ten_env.log_info("websocket is open.")

    def on_complete(self):
        """Log that the synthesis task finished."""
        self.ten_env.log_info("speech synthesis task complete successfully.")

    def on_error(self, message: str):
        """Log a synthesis failure reported by the SDK."""
        self.ten_env.log_error(f"speech synthesis task failed, {message}")

    def on_close(self):
        """Log websocket closure and stop accepting further audio."""
        self.ten_env.log_info("websocket is closed.")
        self.close()

    def on_event(self, message: str) -> None:
        """Log a raw SDK event at debug level."""
        self.ten_env.log_debug(f"received event: {message}")

    def on_data(self, data: bytes) -> None:
        """Queue one audio chunk on the event loop; drop it after close.

        Args:
            data: Audio bytes delivered by the SDK.
        """
        if self.closed:
            self.ten_env.log_warn(
                f"received data: {len(data)} bytes but connection was closed"
            )
            return
        self.ten_env.log_debug(f"received data: {len(data)} bytes")
        put = self.queue.put(data)
        try:
            asyncio.run_coroutine_threadsafe(put, self.loop)
        except RuntimeError as e:
            # The loop is closed (e.g. during shutdown). Close the coroutine so
            # it is not reported as "never awaited", and keep the error from
            # propagating into the SDK callback.
            put.close()
            self.ten_env.log_warn(
                f"dropping {len(data)} bytes of audio, event loop unavailable: {e}"
            )


class CosyTTS:
    """Stream text to DashScope CosyVoice and collect audio on an asyncio queue.

    Text arrives in pieces via :meth:`text_to_speech_stream`. Every piece of
    one segment goes to the same synthesizer, which is completed and released
    when a piece flagged ``end_of_segment`` arrives. Audio chunks delivered by
    the SDK callback are read back with :meth:`get_audio_bytes`.
    """

    def __init__(self, config: CosyTTSConfig) -> None:
        self.config = config
        self.synthesizer = None  # created lazily for the first piece of a segment
        self._callback = None  # callback bound to the current synthesizer
        self.queue = asyncio.Queue()
        # The DashScope SDK is configured through a process-wide module attribute.
        dashscope.api_key = config.api_key

    def _audio_format(self, ten_env: AsyncTenEnv):
        """Return the 16-bit mono PCM ``AudioFormat`` for ``config.sample_rate``.

        Falls back to 16 kHz (the previously hard-coded format) when the SDK
        has no PCM member for the configured rate.
        """
        fmt = getattr(
            AudioFormat, f"PCM_{self.config.sample_rate}HZ_MONO_16BIT", None
        )
        if fmt is None:
            ten_env.log_warn(
                f"unsupported sample_rate {self.config.sample_rate}, "
                "falling back to 16000 Hz"
            )
            return AudioFormat.PCM_16000HZ_MONO_16BIT
        return fmt

    def _create_synthesizer(
        self, ten_env: AsyncTenEnv, callback: AsyncIteratorCallback
    ):
        """Replace ``self.synthesizer`` with a fresh one reporting to *callback*."""
        ten_env.log_info("Creating new synthesizer")
        self.synthesizer = SpeechSynthesizer(
            model=self.config.model,
            voice=self.config.voice,
            format=self._audio_format(ten_env),
            callback=callback,
        )
        self._callback = callback

    def _reset(self) -> None:
        """Forget the current synthesizer so the next text starts a new one."""
        self.synthesizer = None
        self._callback = None

    async def get_audio_bytes(self) -> bytes:
        """Wait for and return the next audio chunk queued by the callback."""
        return await self.queue.get()

    def text_to_speech_stream(
        self, ten_env: AsyncTenEnv, text: str, end_of_segment: bool
    ) -> None:
        """Send *text* to the current segment's synthesizer.

        A synthesizer is created on demand: for the first piece of a segment,
        or when the previous one's websocket reported closure. When
        *end_of_segment* is true the synthesizer is completed and released so
        the next call starts a new segment. Errors are logged rather than
        raised, and the synthesizer is dropped.

        Args:
            ten_env: TEN environment used for logging.
            text: Next piece of text to synthesize.
            end_of_segment: Whether *text* is the last piece of the segment.
        """
        try:
            # Reuse the in-flight synthesizer for every piece, including the
            # final one. Replacing it here would discard earlier pieces without
            # streaming_complete() and leak its connection.
            stale = self._callback is not None and self._callback.closed
            if self.synthesizer is None or stale:
                self._create_synthesizer(
                    ten_env, AsyncIteratorCallback(ten_env, self.queue)
                )

            self.synthesizer.streaming_call(text)

            if end_of_segment:
                ten_env.log_info("Streaming complete")
                self.synthesizer.streaming_complete()
                self._reset()
        except WebSocketConnectionClosedException as e:
            ten_env.log_error(f"WebSocket connection closed, {e}")
            self._reset()
        except Exception as e:
            ten_env.log_error(f"Error streaming text, {e}")
            self._reset()

    def cancel(self, ten_env: AsyncTenEnv) -> None:
        """Abort the in-flight segment, if any; errors are logged, not raised."""
        if self.synthesizer is None:
            return
        try:
            self.synthesizer.streaming_cancel()
        except WebSocketConnectionClosedException as e:
            ten_env.log_error(f"WebSocket connection closed, {e}")
        except Exception as e:
            ten_env.log_error(f"Error cancelling streaming, {e}")
        self._reset()