import asyncio
from dataclasses import dataclass
from typing import Optional, List, Tuple
from concurrent.futures import ThreadPoolExecutor
import torch
import numpy as np
from transformers import PreTrainedModel

from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.utils import Counter

from TTS.TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from gpt_config import XTTSGPTConfig
from xtts2_config import XTTSConfig
from tokenizer import XTTSTokenizerFast


@dataclass
class XTTSRequest:
    """Container for XTTS inference request data"""
    request_id: str
    text: str
    language: str
    gpt_cond_latent: torch.Tensor
    speaker_embedding: torch.Tensor
    temperature: float = 0.75
    top_p: float = 0.85
    top_k: int = 50
    repetition_penalty: float = 10.0
    length_penalty: float = 1.0
    do_sample: bool = True


@dataclass
class XTTSOutput:
    """Container for XTTS inference output"""
    request_id: str
    wav: np.ndarray
    gpt_latents: np.ndarray
    speaker_embedding: torch.Tensor


class Xtts(PreTrainedModel):
    """Async XTTS model implementation using VLLM's AsyncEngine."""

    def __init__(self, hifi_config: XTTSConfig, gpt_config: XTTSGPTConfig, tensor_parallel_size: int = 1, **kwargs):
        # nn.Module must be initialized before buffers can be registered;
        # this assumes XTTSGPTConfig is a PretrainedConfig subclass.
        super().__init__(gpt_config)
        self.hifi_config = hifi_config
        self.gpt_config = gpt_config
        self.tp = tensor_parallel_size
        self.tokenizer = XTTSTokenizerFast.from_pretrained("AstraMindAI/xtts2-gpt")
        self.request_counter = Counter()
        self.executor = ThreadPoolExecutor(max_workers=4)  # For CPU-bound tasks (e.g. HiFi-GAN vocoding)
        self.init_models()
        self.register_buffer("mel_stats", torch.ones(80))

    @staticmethod
    def get_memory_percentage(memory: int) -> float:
        """Return `memory` (in bytes) as a fraction of GPU 0's total memory."""
        return memory / torch.cuda.get_device_properties(0).total_memory

    def init_models(self):
        """Initialize the vLLM engine and the HiFi-GAN decoder."""
        # Initialize vLLM engine
        engine_args = AsyncEngineArgs(
            model=self.gpt_config.model_dir,
            tensor_parallel_size=self.tp,
            dtype="auto",
            max_model_len=self.gpt_config.gpt_max_text_tokens + self.gpt_config.gpt_max_audio_tokens,
            # the GPT needs roughly 2 GB, so reserve only that fraction of the GPU
            gpu_memory_utilization=self.get_memory_percentage(2 * 1024 ** 3),
            trust_remote_code=True,
            skip_tokenizer_init=True,  # no need to initialize a tokenizer, we use our own
            max_num_batched_tokens=4096,
            max_num_seqs=256,
        )

        self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args)

        # Initialize HiFi-GAN decoder
        self.hifigan_decoder = HifiDecoder(
            input_sample_rate=self.hifi_config.input_sample_rate,
            output_sample_rate=self.hifi_config.output_sample_rate,
            output_hop_length=self.hifi_config.output_hop_length,
            ar_mel_length_compression=self.hifi_config.gpt_code_stride_len,
            decoder_input_dim=self.hifi_config.decoder_input_dim,
            d_vector_dim=self.hifi_config.d_vector_dim,
            cond_d_vector_in_each_upsampling_layer=self.hifi_config.cond_d_vector_in_each_upsampling_layer,
        )

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: str,
            torch_dtype: torch.dtype = torch.float16,
            device_map: Optional[str] = "auto",
            tensor_parallel_size: int = 1,
            **kwargs,
    ) -> "Xtts":
        """Load pretrained XTTS model from HuggingFace Hub.

        Args:
            pretrained_model_name_or_path (str): Path to pretrained weights or HF Hub model id
            torch_dtype (torch.dtype, optional): Type to load the model as. Defaults to float16.
            device_map (str, optional): Device mapping strategy. Defaults to "auto".
            **kwargs: Additional arguments passed to the model.

        Returns:
            Xtts: Loaded model instance
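
        Example (illustrative; the repo id is an assumption, not taken from
        this file):

            >>> model = Xtts.from_pretrained("AstraMindAI/xtts2", torch_dtype=torch.float16)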
        """
        from huggingface_hub import hf_hub_download
        import json
        import os

        # Download and load configs
        if not os.path.exists(pretrained_model_name_or_path):
            config_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="../xtts2_gpt/config.json"
            )
            with open(config_file, 'r') as f:
                config = json.load(f)

        else:
            # Load from local path
            with open(os.path.join(pretrained_model_name_or_path, "config.json"), 'r') as f:
                config = json.load(f)


        # Initialize configs
        gpt_config = XTTSGPTConfig(**config)
        hifi_config = XTTSConfig(**config)

        # Initialize model
        model = cls(
            hifi_config=hifi_config,
            gpt_config=gpt_config,
            tensor_parallel_size=tensor_parallel_size,
            **kwargs
        )

        # Load model weights
        if not os.path.exists(pretrained_model_name_or_path):
            gpt_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="../xtts2_gpt/xttsv2-gpt.safetensors"
            )
            hifigan_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="xttsv2-hifigan-mel.safetensors"
            )
        else:
            gpt_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-gpt.safetensors")
            hifigan_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-hifigan-mel.safetensors")

        # Load GPT weights
        import safetensors.torch
        state_dict = safetensors.torch.load_file(gpt_weights)
        # Assumes the GPT backbone is exposed as `model.gpt`
        model.gpt.load_state_dict(state_dict)

        # Load HiFi-GAN weights
        hifigan_state = safetensors.torch.load_file(hifigan_weights)
        model.hifigan_decoder.load_state_dict(hifigan_state)

        # Set model properties
        model.config = config

        # Cast model to specified dtype
        model = model.to(torch_dtype)

        # Handle device mapping
        if device_map:
            from accelerate import dispatch_model
            model = dispatch_model(model, device_map=device_map)

        return model

    def prepare_inputs(self, text: str, language: str, gpt_cond_latent: torch.Tensor) -> Tuple[List[int], torch.Tensor]:
        """Tokenize the input text; the conditioning latent passes through
        unchanged and is attached to the request later as multi-modal data."""
        text_tokens = self.tokenizer.encode(text, lang=language)
        return text_tokens, gpt_cond_latent

    async def generate_speech_async(self, request: XTTSRequest) -> XTTSOutput:
        """Generate speech for a single request asynchronously."""
        # Prepare input with conditioning
        tokens, gpt_cond_latent = self.prepare_inputs(
            request.text,
            request.language,
            request.gpt_cond_latent
        )

        # Setup sampling parameters
        sampling_params = SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repetition_penalty,
            max_tokens=self.gpt_config.gpt_max_audio_tokens,
            stop=['</s>', '<|endoftext|>']
        )
        engine_inputs = TokensPrompt(prompt_token_ids=tokens)
        if gpt_cond_latent is not None:
            engine_inputs["multi_modal_data"] = MultiModalDataDict({"audio": gpt_cond_latent})

        # Generate tokens with vLLM. The async generator yields partial results
        # as decoding progresses, so keep only the last (finished) output
        # instead of returning on the first partial one.
        output_generator = self.llm_engine.generate(
            inputs=engine_inputs,
            sampling_params=sampling_params,
            request_id=request.request_id
        )

        final_output = None
        async for outputs in output_generator:
            final_output = outputs
        if final_output is None:
            raise RuntimeError(f"Engine produced no output for request {request.request_id}")

        # Extract the generated tokens from the finished request
        generated_tokens = final_output.outputs[0].token_ids

        # Convert tokens to hidden states (this step depends on the model architecture)
        hidden_states = await self._tokens_to_hidden_states(generated_tokens)

        # Run HiFi-GAN in the thread pool so vocoding does not block the event loop
        wav = await asyncio.get_running_loop().run_in_executor(
            self.executor,
            lambda: self.hifigan_decoder(
                hidden_states,
                g=request.speaker_embedding
            ).cpu().numpy().squeeze()
        )

        return XTTSOutput(
            request_id=request.request_id,
            wav=wav,
            gpt_latents=hidden_states.cpu().numpy(),
            speaker_embedding=request.speaker_embedding
        )


    async def _tokens_to_hidden_states(self, tokens: List[int]) -> torch.Tensor:
        """Convert generated audio tokens into GPT hidden states for the vocoder.

        Placeholder: the real implementation depends on how the underlying GPT
        exposes its final hidden states; adapt it to your architecture.
        """
        token_tensor = torch.tensor(tokens, device=self.device)
        # Placeholder call: AsyncLLMEngine does not expose hidden states through
        # a simple awaitable API, so replace this with a model-specific forward pass.
        hidden_states = await self.llm_engine.encode(token_tensor)
        return hidden_states
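

# ---------------------------------------------------------------------------
# Usage sketch (assumptions flagged inline): drives from_pretrained() and
# generate_speech_async() end to end. The checkpoint path, latent/embedding
# shapes, and the soundfile dependency are illustrative, not taken from the
# classes above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import soundfile as sf  # assumed available for writing the waveform

    async def _demo():
        # Hypothetical local checkpoint directory containing config.json and
        # the *.safetensors weight files expected by from_pretrained()
        model = Xtts.from_pretrained("./xtts2_checkpoint", tensor_parallel_size=1)

        request = XTTSRequest(
            request_id=str(next(model.request_counter)),
            text="Hello from XTTS.",
            language="en",
            # Shapes are illustrative; use latents from your conditioning encoder
            gpt_cond_latent=torch.zeros(1, 32, 1024),
            speaker_embedding=torch.zeros(1, 512, 1),
        )

        output = await model.generate_speech_async(request)
        sf.write("output.wav", output.wav, model.hifi_config.output_sample_rate)

    asyncio.run(_demo())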