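"""Async XTTS v2 inference: vLLM's AsyncLLMEngine drives the autoregressive GPT
stage, and a HiFi-GAN decoder converts the resulting latents into waveforms."""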
import asyncio
from dataclasses import dataclass
from typing import Optional, List, Tuple
from concurrent.futures import ThreadPoolExecutor
import torch
import numpy as np
from transformers import PreTrainedModel
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.utils import Counter
from TTS.TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from gpt_config import XTTSGPTConfig
from xtts2_config import XTTSConfig
from tokenizer import XTTSTokenizerFast


@dataclass
class XTTSRequest:
    """Container for XTTS inference request data."""
    request_id: str
    text: str
    language: str
    gpt_cond_latent: torch.Tensor
    speaker_embedding: torch.Tensor
    temperature: float = 0.75
    top_p: float = 0.85
    top_k: int = 50
    repetition_penalty: float = 10.0
    length_penalty: float = 1.0
    do_sample: bool = True


@dataclass
class XTTSOutput:
    """Container for XTTS inference output."""
    request_id: str
    wav: np.ndarray
    gpt_latents: np.ndarray
    speaker_embedding: torch.Tensor


class Xtts(PreTrainedModel):
    """Async XTTS model implementation using vLLM's AsyncLLMEngine."""

    def __init__(self, hifi_config: XTTSConfig, gpt_config: XTTSGPTConfig, tensor_parallel_size: int = 1, **kwargs):
        # PreTrainedModel.__init__ must run before register_buffer below; this
        # assumes XTTSGPTConfig subclasses transformers.PretrainedConfig.
        super().__init__(gpt_config)
        self.hifi_config = hifi_config
        self.gpt_config = gpt_config
        self.tp = tensor_parallel_size
        self.tokenizer = XTTSTokenizerFast.from_pretrained("AstraMindAI/xtts2-gpt")
        self.request_counter = Counter()
        self.executor = ThreadPoolExecutor(max_workers=4)  # for CPU-bound tasks
        self.init_models()
        self.register_buffer("mel_stats", torch.ones(80))

    @staticmethod
    def get_memory_percentage(memory: int) -> float:
        """Return `memory` (in bytes) as a fraction of GPU 0's total memory."""
        return memory / torch.cuda.get_device_properties(0).total_memory

    def init_models(self):
        """Initialize the vLLM engine and the HiFi-GAN decoder.

        Synchronous on purpose: nothing here awaits, and __init__ calls it directly.
        """
        # Initialize VLLM engine
        engine_args = AsyncEngineArgs(
            model=self.gpt_config.model_dir,
            tensor_parallel_size=self.tp,
            dtype="auto",
            max_model_len=self.gpt_config.gpt_max_text_tokens + self.gpt_config.gpt_max_audio_tokens,
            # the model needs ~2 GB, so request only the bare-minimum memory fraction
            gpu_memory_utilization=self.get_memory_percentage(2 * 1024 ** 3),
            trust_remote_code=True,
            skip_tokenizer_init=True,  # no need to initialize a tokenizer, we use our own
            max_num_batched_tokens=4096,
            max_num_seqs=256,
        )
        self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args)
        # Initialize HiFi-GAN decoder
        self.hifigan_decoder = HifiDecoder(
            input_sample_rate=self.hifi_config.input_sample_rate,
            output_sample_rate=self.hifi_config.output_sample_rate,
            output_hop_length=self.hifi_config.output_hop_length,
            ar_mel_length_compression=self.hifi_config.gpt_code_stride_len,
            decoder_input_dim=self.hifi_config.decoder_input_dim,
            d_vector_dim=self.hifi_config.d_vector_dim,
            cond_d_vector_in_each_upsampling_layer=self.hifi_config.cond_d_vector_in_each_upsampling_layer,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: torch.dtype = torch.float16,
        device_map: Optional[str] = "auto",
        tensor_parallel_size: int = 1,
        **kwargs,
    ) -> "Xtts":
        """Load a pretrained XTTS model from the HuggingFace Hub or a local path.

        Args:
            pretrained_model_name_or_path (str): Path to pretrained weights or HF Hub model id
            torch_dtype (torch.dtype, optional): Dtype to load the model as. Defaults to float16.
            device_map (str, optional): Device mapping strategy. Defaults to "auto".
            tensor_parallel_size (int, optional): Number of GPUs for vLLM tensor parallelism. Defaults to 1.
            **kwargs: Additional arguments passed to the model.

        Returns:
            Xtts: Loaded model instance
        """
        from huggingface_hub import hf_hub_download
        import json
        import os
        # Download and load configs
        if not os.path.exists(pretrained_model_name_or_path):
            config_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="../xtts2_gpt/config.json"
            )
            with open(config_file, 'r') as f:
                config = json.load(f)
        else:
            # Load from local path
            with open(os.path.join(pretrained_model_name_or_path, "config.json"), 'r') as f:
                config = json.load(f)
        # Initialize both configs from the shared config dict
        gpt_config = XTTSGPTConfig(**config)
        hifi_config = XTTSConfig(**config)
        # Initialize model
        model = cls(
            hifi_config=hifi_config,
            gpt_config=gpt_config,
            tensor_parallel_size=tensor_parallel_size,
            **kwargs
        )
        # Load model weights
        if not os.path.exists(pretrained_model_name_or_path):
            gpt_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="../xtts2_gpt/xttsv2-gpt.safetensors"
            )
            hifigan_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="xttsv2-hifigan-mel.safetensors"
            )
        else:
            gpt_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-gpt.safetensors")
            hifigan_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-hifigan-mel.safetensors")
        # Load GPT weights (assumes a `gpt` submodule exists on the model;
        # autoregressive decoding itself is served by the vLLM engine)
        import safetensors.torch
        state_dict = safetensors.torch.load_file(gpt_weights)
        model.gpt.load_state_dict(state_dict)
        # Load HiFi-GAN weights
        hifigan_state = safetensors.torch.load_file(hifigan_weights)
        model.hifigan_decoder.load_state_dict(hifigan_state)
        # Set model properties
        model.config = config
        # Cast model to the specified dtype
        model = model.to(torch_dtype)
        # Handle device mapping
        if device_map:
            from accelerate import dispatch_model
            model = dispatch_model(model, device_map=device_map)
        return model

    def prepare_inputs(self, text: str, language: str, gpt_cond_latent: torch.Tensor) -> Tuple[List[int], torch.Tensor]:
        """Prepare input text with conditioning tokens.

        Prompt format: <|condition|>latent_data<|endofcondition|>text<|endoftext|>
        """
        text_tokens = self.tokenizer.encode(text, lang=language)
        return text_tokens, gpt_cond_latent

    async def generate_speech_async(self, request: XTTSRequest) -> XTTSOutput:
        """Generate speech for a single request asynchronously."""
        # Prepare input with conditioning
        tokens, gpt_cond_latent = self.prepare_inputs(
            request.text,
            request.language,
            request.gpt_cond_latent
        )
        # Set up sampling parameters
        sampling_params = SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repetition_penalty,
            max_tokens=self.gpt_config.gpt_max_audio_tokens,
            stop=['</s>', '<|endoftext|>']
        )
        engine_inputs = TokensPrompt(prompt_token_ids=tokens)
        if gpt_cond_latent is not None:
            # multi_modal_data is a plain mapping; pass the latent under the "audio" key
            engine_inputs["multi_modal_data"] = {"audio": gpt_cond_latent}
        # Generate tokens using the vLLM engine; the generator yields partial
        # results, so keep only the final output before decoding audio
        output_generator = self.llm_engine.generate(
            inputs=engine_inputs,
            sampling_params=sampling_params,
            request_id=request.request_id
        )
        final_output = None
        async for final_output in output_generator:
            pass
        # Extract generated tokens from the finished request
        generated_tokens = final_output.outputs[0].token_ids
        # Convert to hidden states (this step depends on your model architecture)
        hidden_states = await self._tokens_to_hidden_states(generated_tokens)
        # Generate audio using HiFi-GAN (run in a thread pool to avoid blocking the event loop)
        wav = await asyncio.get_running_loop().run_in_executor(
            self.executor,
            lambda: self.hifigan_decoder(
                hidden_states,
                g=request.speaker_embedding
            ).cpu().numpy().squeeze()
        )
        return XTTSOutput(
            request_id=request.request_id,
            wav=wav,
            gpt_latents=hidden_states.cpu().numpy(),
            speaker_embedding=request.speaker_embedding
        )

    async def _tokens_to_hidden_states(self, tokens: List[int]) -> torch.Tensor:
        """Convert generated tokens to hidden states.

        Placeholder implementation: adapt it to how your model exposes the GPT
        hidden states, since vLLM does not return them by default.
        """
        token_tensor = torch.tensor(tokens, device=self.device)
        # Use the vLLM engine to re-encode the tokens into hidden states
        hidden_states = await self.llm_engine.encode(token_tensor)
        return hidden_states
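

# Minimal usage sketch. Assumptions: `gpt_cond_latent` and `speaker_embedding`
# are precomputed conditioning tensors (e.g. from a reference clip), and the
# repo id shown is illustrative, not verified:
#
#     async def main():
#         model = Xtts.from_pretrained("AstraMindAI/xtts2")
#         request = XTTSRequest(
#             request_id="0",
#             text="Hello world.",
#             language="en",
#             gpt_cond_latent=gpt_cond_latent,
#             speaker_embedding=speaker_embedding,
#         )
#         output = await model.generate_speech_async(request)
#         # output.wav is a numpy waveform at hifi_config.output_sample_rate
#
#     asyncio.run(main())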