import asyncio
from dataclasses import dataclass
from typing import Optional, List, Tuple
from concurrent.futures import ThreadPoolExecutor

import torch
import numpy as np
from transformers import PreTrainedModel
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.utils import Counter

from TTS.TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from gpt_config import XTTSGPTConfig
from xtts2_config import XTTSConfig
from tokenizer import XTTSTokenizerFast


@dataclass
class XTTSRequest:
    """Container for XTTS inference request data."""
    request_id: str
    text: str
    language: str
    gpt_cond_latent: torch.Tensor
    speaker_embedding: torch.Tensor
    temperature: float = 0.75
    top_p: float = 0.85
    top_k: int = 50
    repetition_penalty: float = 10.0
    length_penalty: float = 1.0
    do_sample: bool = True


@dataclass
class XTTSOutput:
    """Container for XTTS inference output."""
    request_id: str
    wav: np.ndarray
    gpt_latents: np.ndarray
    speaker_embedding: torch.Tensor


class Xtts(PreTrainedModel):
    """Async XTTS model implementation using vLLM's AsyncLLMEngine."""

    def __init__(self, hifi_config: XTTSConfig, gpt_config: XTTSGPTConfig,
                 tensor_parallel_size: int = 1, **kwargs):
        # PreTrainedModel must be initialized before buffers can be registered;
        # gpt_config is assumed to be a PretrainedConfig subclass.
        super().__init__(gpt_config)
        self.hifi_config = hifi_config
        self.gpt_config = gpt_config
        self.tp = tensor_parallel_size
        self.tokenizer = XTTSTokenizerFast.from_pretrained("AstraMindAI/xtts2-gpt")
        self.request_counter = Counter()
        self.executor = ThreadPoolExecutor(max_workers=4)  # For CPU-bound tasks

        self.init_models()
        self.register_buffer("mel_stats", torch.ones(80))

    @staticmethod
    def get_memory_percentage(memory: int) -> float:
        """Return `memory` (in bytes) as a fraction of the GPU's total memory."""
        return memory / torch.cuda.get_device_properties(0).total_memory

    def init_models(self):
        """Initialize the vLLM engine and the HiFi-GAN decoder."""
        # Initialize vLLM engine
        engine_args = AsyncEngineArgs(
            model=self.gpt_config.model_dir,
            tensor_parallel_size=self.tp,
            dtype="auto",
            max_model_len=self.gpt_config.gpt_max_text_tokens + self.gpt_config.gpt_max_audio_tokens,
            # The GPT needs roughly 2 GB, so request only the bare-minimum fraction of GPU memory
            gpu_memory_utilization=self.get_memory_percentage(2 * 1024 ** 3),
            trust_remote_code=True,
            skip_tokenizer_init=True,  # no need to initialize a tokenizer, we use our own
            max_num_batched_tokens=4096,
            max_num_seqs=256,
        )
        self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args)

        # Initialize HiFi-GAN decoder
        self.hifigan_decoder = HifiDecoder(
            input_sample_rate=self.hifi_config.input_sample_rate,
            output_sample_rate=self.hifi_config.output_sample_rate,
            output_hop_length=self.hifi_config.output_hop_length,
            ar_mel_length_compression=self.hifi_config.gpt_code_stride_len,
            decoder_input_dim=self.hifi_config.decoder_input_dim,
            d_vector_dim=self.hifi_config.d_vector_dim,
            cond_d_vector_in_each_upsampling_layer=self.hifi_config.cond_d_vector_in_each_upsampling_layer,
        )
    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: str,
            torch_dtype: torch.dtype = torch.float16,
            device_map: Optional[str] = "auto",
            tensor_parallel_size: int = 1,
            **kwargs,
    ) -> "Xtts":
        """Load a pretrained XTTS model from the HuggingFace Hub or a local path.

        Args:
            pretrained_model_name_or_path (str): Path to pretrained weights or HF Hub model id.
            torch_dtype (torch.dtype, optional): Dtype to load the model as. Defaults to float16.
            device_map (str, optional): Device mapping strategy. Defaults to "auto".
            tensor_parallel_size (int, optional): Number of GPUs used by vLLM. Defaults to 1.
            **kwargs: Additional arguments passed to the model.

        Returns:
            Xtts: Loaded model instance.
        """
        from huggingface_hub import hf_hub_download
        import json
        import os

        # Download and load configs
        if not os.path.exists(pretrained_model_name_or_path):
            config_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="../xtts2_gpt/config.json"
            )
            with open(config_file, 'r') as f:
                config = json.load(f)

            gpt_config_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="gpt_config.py"
            )
            with open(gpt_config_file, 'r') as f:
                gpt_config = json.loads(f.read())

            hifigan_config_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="xtts2_config.py"
            )
            with open(hifigan_config_file, 'r') as f:
                hifigan_config = json.loads(f.read())
        else:
            # Load from local path
            with open(os.path.join(pretrained_model_name_or_path, "config.json"), 'r') as f:
                config = json.load(f)

        # Initialize configs
        gpt_config = XTTSGPTConfig(**config)
        hifi_config = XTTSConfig(**config)

        # Initialize model
        model = cls(
            hifi_config=hifi_config,
            gpt_config=gpt_config,
            tensor_parallel_size=tensor_parallel_size,
            **kwargs
        )

        # Load model weights
        if not os.path.exists(pretrained_model_name_or_path):
            gpt_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="../xtts2_gpt/xttsv2-gpt.safetensors"
            )
            hifigan_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="xttsv2-hifigan-mel.safetensors"
            )
        else:
            gpt_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-gpt.safetensors")
            hifigan_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-hifigan-mel.safetensors")

        # Load GPT weights
        import safetensors.torch
        state_dict = safetensors.torch.load_file(gpt_weights)
        model.gpt.load_state_dict(state_dict)

        # Load HiFi-GAN weights
        hifigan_state = safetensors.torch.load_file(hifigan_weights)
        model.hifigan_decoder.load_state_dict(hifigan_state)

        # Set model properties
        model.config = config

        # Cast model to the requested dtype
        model = model.to(torch_dtype)

        # Handle device mapping
        if device_map:
            from accelerate import dispatch_model
            model = dispatch_model(model, device_map=device_map)

        return model

    def prepare_inputs(self, text: str, language: str,
                       gpt_cond_latent: torch.Tensor) -> Tuple[List[int], torch.Tensor]:
        """Prepare input text with conditioning tokens."""
        # Add special tokens and conditioning format
        # Format: <|condition|>latent_data<|endofcondition|>text<|endoftext|>
        text_tokens = self.tokenizer.encode(text, lang=language)
        return text_tokens, gpt_cond_latent

    async def generate_speech_async(self, request: XTTSRequest) -> XTTSOutput:
        """Generate speech for a single request asynchronously."""
        # Prepare input with conditioning
        tokens, gpt_cond_latent = self.prepare_inputs(
            request.text,
            request.language,
            request.gpt_cond_latent
        )

        # Set up sampling parameters
        sampling_params = SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repetition_penalty,
            max_tokens=self.gpt_config.gpt_max_audio_tokens,
            stop=['', '<|endoftext|>']
        )

        engine_inputs = TokensPrompt(prompt_token_ids=tokens)
        if gpt_cond_latent is not None:
            engine_inputs["multi_modal_data"] = MultiModalDataDict({"audio": gpt_cond_latent})

        # Generate tokens using vLLM
        output_generator = self.llm_engine.generate(
            inputs=engine_inputs,
            sampling_params=sampling_params,
            request_id=request.request_id
        )

        async for outputs in output_generator:
            # The engine streams incremental outputs; only decode audio once the request is finished
            if not outputs.finished:
                continue

            # Extract generated tokens
            generated_tokens = outputs.outputs[0].token_ids

            # Convert to hidden states (this step depends on your model architecture)
            hidden_states = await self._tokens_to_hidden_states(generated_tokens)

            # Generate audio using HiFi-GAN (run in a thread pool to avoid blocking the event loop)
            wav = await asyncio.get_running_loop().run_in_executor(
                self.executor,
                lambda: self.hifigan_decoder(
                    hidden_states,
                    g=request.speaker_embedding
                ).cpu().numpy().squeeze()
            )

            return XTTSOutput(
                request_id=request.request_id,
                wav=wav,
                gpt_latents=hidden_states.cpu().numpy(),
                speaker_embedding=request.speaker_embedding
            )
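    # Sketch (not part of the original class): one way to fan several XTTSRequest
    # objects out through the shared AsyncLLMEngine concurrently. The method name
    # `generate_speech_batch_async` is an assumption; it only composes the
    # `generate_speech_async` coroutine defined above with asyncio.gather, so the
    # vLLM scheduler can batch the requests internally.
    async def generate_speech_batch_async(self, requests: List[XTTSRequest]) -> List[XTTSOutput]:
        """Illustrative helper: run several requests concurrently on the async engine."""
        results = await asyncio.gather(*(self.generate_speech_async(r) for r in requests))
        return list(results)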
    async def _tokens_to_hidden_states(self, tokens: List[int]) -> torch.Tensor:
        """Convert generated tokens to hidden states.

        This implementation depends on your specific model architecture; adapt it
        based on how your model processes tokens. This is a placeholder implementation.
        """
        token_tensor = torch.tensor(tokens, device=self.device)

        # Use the vLLM engine to get hidden states
        hidden_states = await self.llm_engine.encode(token_tensor)

        return hidden_states
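
# Usage sketch (illustrative, not part of the original file): the model id,
# conditioning-latent shape, and speaker-embedding shape below are placeholders
# and must match the checkpoint you actually load.
if __name__ == "__main__":
    async def _demo():
        model = Xtts.from_pretrained(
            "AstraMindAI/xtts2",          # placeholder model id
            tensor_parallel_size=1,
        )
        request = XTTSRequest(
            request_id="demo-0",
            text="Hello from async XTTS.",
            language="en",
            gpt_cond_latent=torch.zeros(1, 32, 1024),   # placeholder conditioning latent
            speaker_embedding=torch.zeros(1, 512, 1),   # placeholder speaker embedding
        )
        output = await model.generate_speech_async(request)
        print("waveform samples:", output.wav.shape)

    asyncio.run(_demo())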