xttsv2 / xtts2_modeling.py
mlinmg's picture
Upload 6 files
8b6d69d verified
raw
history blame
9.93 kB
import asyncio
from dataclasses import dataclass
from typing import Optional, List, Tuple
from concurrent.futures import ThreadPoolExecutor
import torch
import numpy as np
from transformers import PreTrainedModel
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.utils import Counter
from TTS.TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from gpt_config import XTTSGPTConfig
from xtts2_config import XTTSConfig
from tokenizer import XTTSTokenizerFast
@dataclass
class XTTSRequest:
    """Container for XTTS inference request data.

    Bundles the text to synthesize with its conditioning tensors and the
    sampling hyperparameters forwarded to the GPT token generator.
    """
    request_id: str  # unique id used to route this request through the async engine
    text: str  # text to synthesize
    language: str  # language code passed to the tokenizer's `lang` argument
    gpt_cond_latent: torch.Tensor  # GPT conditioning latent; shape/dtype not visible here — TODO confirm
    speaker_embedding: torch.Tensor  # speaker d-vector, passed as `g=` to the HiFi-GAN decoder
    temperature: float = 0.75  # sampling temperature
    top_p: float = 0.85  # nucleus-sampling threshold
    top_k: int = 50  # top-k sampling cutoff
    repetition_penalty: float = 10.0  # penalty against repeated tokens
    length_penalty: float = 1.0  # NOTE(review): never forwarded to SamplingParams in generate_speech_async — confirm intent
    do_sample: bool = True  # NOTE(review): never forwarded to SamplingParams in generate_speech_async — confirm intent
@dataclass
class XTTSOutput:
    """Container for XTTS inference output."""
    request_id: str  # id of the originating XTTSRequest
    wav: np.ndarray  # decoded waveform produced by the HiFi-GAN decoder (squeezed by the caller)
    gpt_latents: np.ndarray  # GPT hidden states that were fed to the decoder, moved to CPU
    speaker_embedding: torch.Tensor  # speaker embedding echoed back from the request
class Xtts(PreTrainedModel):
"""Async XTTS model implementation using VLLM's AsyncEngine."""
def __init__(self, hifi_config: XTTSConfig, gpt_config: XTTSGPTConfig, tensor_parallel_size: int = 1, **kwargs):
    """Build the XTTS wrapper: tokenizer, VLLM engine, and HiFi-GAN decoder.

    Args:
        hifi_config: HiFi-GAN / vocoder configuration.
        gpt_config: GPT (text-to-audio-token) configuration.
        tensor_parallel_size: number of GPUs for VLLM tensor parallelism.
        **kwargs: accepted for HF-style constructor compatibility; unused here.
    """
    # nn.Module/PreTrainedModel state (_parameters, _buffers, ...) must exist
    # before the register_buffer call below; the original never called
    # super().__init__() and crashed on the first register_buffer.
    # NOTE(review): assumes XTTSGPTConfig subclasses PretrainedConfig — confirm.
    super().__init__(gpt_config)
    self.hifi_config = hifi_config
    self.gpt_config = gpt_config
    self.tp = tensor_parallel_size
    self.tokenizer = XTTSTokenizerFast.from_pretrained("AstraMindAI/xtts2-gpt")
    self.request_counter = Counter()
    self.executor = ThreadPoolExecutor(max_workers=4)  # For CPU-bound tasks
    self.init_models()
    self.register_buffer("mel_stats", torch.ones(80))
@staticmethod
def get_memory_percentage(memory: int, device: int = 0) -> float:
    """Return *memory* (in bytes) as a fraction of a CUDA device's total memory.

    Used to derive VLLM's ``gpu_memory_utilization`` from an absolute byte
    budget. Requires an available CUDA device.

    Args:
        memory: byte budget to convert into a utilization fraction.
        device: CUDA device index to query. Defaults to 0, matching the
            original hard-coded behavior, so existing callers are unaffected.

    Returns:
        memory divided by the device's total memory in bytes.
    """
    return memory / torch.cuda.get_device_properties(device).total_memory
def init_models(self):
    """Initialize the VLLM engine and the HiFi-GAN decoder.

    Made synchronous: the original was declared ``async`` while awaiting
    nothing, and its only caller (``__init__``) invoked it without awaiting,
    so the coroutine was never executed and neither model was ever built.
    """
    engine_args = AsyncEngineArgs(
        model=self.gpt_config.model_dir,
        tensor_parallel_size=self.tp,
        dtype="auto",  # was "auto " — the trailing space is not a valid dtype name
        max_model_len=self.gpt_config.gpt_max_text_tokens + self.gpt_config.gpt_max_audio_tokens,
        # The model needs ~2 GB, so pass the budget in BYTES: the original
        # passed the bare integer 2, which produced a near-zero utilization
        # fraction after dividing by total device memory.
        gpu_memory_utilization=self.get_memory_percentage(2 * 1024 ** 3),
        trust_remote_code=True,
        skip_tokenizer_init=True,  # no need to initialize tokenizer, we use our own
        max_num_batched_tokens=4096,
        max_num_seqs=256,
    )
    # The original rebound self.llm_engine to the AsyncLLMEngine CLASS on the
    # next line, discarding the instance just constructed; keep the instance.
    self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args)

    # Initialize HiFi-GAN decoder from the vocoder config.
    self.hifigan_decoder = HifiDecoder(
        input_sample_rate=self.hifi_config.input_sample_rate,
        output_sample_rate=self.hifi_config.output_sample_rate,
        output_hop_length=self.hifi_config.output_hop_length,
        ar_mel_length_compression=self.hifi_config.gpt_code_stride_len,
        decoder_input_dim=self.hifi_config.decoder_input_dim,
        d_vector_dim=self.hifi_config.d_vector_dim,
        cond_d_vector_in_each_upsampling_layer=self.hifi_config.cond_d_vector_in_each_upsampling_layer,
    )
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: str,
    torch_dtype: torch.dtype = torch.float16,
    device_map: Optional[str] = "auto",
    tensor_parallel_size: int = 1,
    **kwargs,
) -> "Xtts":
    """Load pretrained XTTS model from HuggingFace Hub or a local directory.

    Args:
        pretrained_model_name_or_path (str): Path to pretrained weights or HF Hub model id.
        torch_dtype (torch.dtype, optional): Type to load the model as. Defaults to float16.
        device_map (str, optional): Device mapping strategy. Defaults to "auto".
        tensor_parallel_size (int, optional): GPUs used for VLLM tensor parallelism.
        **kwargs: Additional arguments passed to the model constructor.
    Returns:
        Xtts: Loaded model instance
    """
    from huggingface_hub import hf_hub_download
    import json
    import os
    # Download and load configs. Branch on whether the path exists locally:
    # if not, treat it as a HF Hub repo id.
    if not os.path.exists(pretrained_model_name_or_path):
        config_file = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            # NOTE(review): a "../" component in a hub filename is unusual —
            # confirm this matches the actual repo layout.
            filename="../xtts2_gpt/config.json"
        )
        with open(config_file, 'r') as f:
            config = json.load(f)
        # NOTE(review): the two downloads below parse *.py files with
        # json.loads, which raises unless those files actually contain JSON;
        # their results are also overwritten unconditionally further down
        # (gpt_config / hifigan_config are rebuilt from `config`), so this
        # looks like dead or broken code — verify before relying on it.
        gpt_config_file = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="gpt_config.py"
        )
        with open(gpt_config_file, 'r') as f:
            gpt_config = json.loads(f.read())
        hifigan_config_file = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="xtts2_config.py"
        )
        with open(hifigan_config_file, 'r') as f:
            hifigan_config = json.loads(f.read())
    else:
        # Load from local path
        with open(os.path.join(pretrained_model_name_or_path, "config.json"), 'r') as f:
            config = json.load(f)
    # Initialize configs — both branches end up here, so only `config` from
    # above actually matters.
    gpt_config = XTTSGPTConfig(**config)
    hifi_config = XTTSConfig(**config)
    # Initialize model
    model = cls(
        hifi_config=hifi_config,
        gpt_config=gpt_config,
        tensor_parallel_size=tensor_parallel_size,
        **kwargs
    )
    # Load model weights (hub download vs. local file, mirroring the config branch)
    if not os.path.exists(pretrained_model_name_or_path):
        gpt_weights = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="../xtts2_gpt/xttsv2-gpt.safetensors"
        )
        hifigan_weights = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="xttsv2-hifigan-mel.safetensors"
        )
    else:
        gpt_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-gpt.safetensors")
        hifigan_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-hifigan-mel.safetensors")
    # Load GPT weights
    import safetensors.torch
    state_dict = safetensors.torch.load_file(gpt_weights)
    # NOTE(review): nothing visible in this file ever assigns `self.gpt`
    # (init_models builds llm_engine and hifigan_decoder only) — confirm
    # `model.gpt` exists before this line runs.
    model.gpt.load_state_dict(state_dict)
    # Load HiFi-GAN weights
    hifigan_state = safetensors.torch.load_file(hifigan_weights)
    model.hifigan_decoder.load_state_dict(hifigan_state)
    # Set model properties (raw dict, replacing whatever super().__init__ set)
    model.config = config
    # Cast model to specified dtype
    model = model.to(torch_dtype)
    # Handle device mapping
    if device_map:
        from accelerate import dispatch_model
        model = dispatch_model(model, device_map=device_map)
    return model
def prepare_inputs(self, text: str, language: str, gpt_cond_latent: torch.Tensor) -> Tuple[List[int], torch.Tensor]:
    """Tokenize *text* for *language* and pair it with its conditioning latent.

    Returns:
        A ``(token_ids, gpt_cond_latent)`` tuple ready to be handed to the
        generation engine; the latent is passed through untouched.
    """
    token_ids = self.tokenizer.encode(text, lang=language)
    return token_ids, gpt_cond_latent
async def generate_speech_async(self, request: XTTSRequest) -> XTTSOutput:
    """Generate speech for a single request asynchronously.

    Runs GPT token generation through VLLM, converts the generated tokens to
    hidden states, then vocodes them with HiFi-GAN in a worker thread so the
    event loop is not blocked.

    Args:
        request: fully populated XTTSRequest.

    Returns:
        XTTSOutput with the decoded waveform and intermediate latents.

    Raises:
        RuntimeError: if the engine produces no output for the request.
    """
    # Prepare input with conditioning
    tokens, gpt_cond_latent = self.prepare_inputs(
        request.text,
        request.language,
        request.gpt_cond_latent
    )
    # Setup sampling parameters.
    # NOTE(review): request.length_penalty and request.do_sample are not
    # forwarded here — confirm whether SamplingParams should receive them.
    sampling_params = SamplingParams(
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        max_tokens=self.gpt_config.gpt_max_audio_tokens,
        stop=['</s>', '<|endoftext|>']
    )
    engine_inputs = TokensPrompt(prompt_token_ids=tokens)
    if gpt_cond_latent is not None:
        engine_inputs["multi_modal_data"] = MultiModalDataDict({"audio": gpt_cond_latent})
    # VLLM streams incremental RequestOutputs; the original returned inside
    # the FIRST iteration of the async-for, truncating generation (and fell
    # through to an implicit None if nothing was yielded). Drain the stream
    # and keep only the final, complete output.
    final_output = None
    async for partial_output in self.llm_engine.generate(
        inputs=engine_inputs,
        sampling_params=sampling_params,
        request_id=request.request_id
    ):
        final_output = partial_output
    if final_output is None:
        raise RuntimeError(f"No output generated for request {request.request_id}")
    # Extract generated tokens from the completed output
    generated_tokens = final_output.outputs[0].token_ids
    # Convert to hidden states (model-specific placeholder step)
    hidden_states = await self._tokens_to_hidden_states(generated_tokens)
    # Generate audio using HiFi-GAN (run in thread pool to avoid blocking)
    wav = await asyncio.get_event_loop().run_in_executor(
        self.executor,
        lambda: self.hifigan_decoder(
            hidden_states,
            g=request.speaker_embedding
        ).cpu().numpy().squeeze()
    )
    return XTTSOutput(
        request_id=request.request_id,
        wav=wav,
        gpt_latents=hidden_states.cpu().numpy(),
        speaker_embedding=request.speaker_embedding
    )
async def _tokens_to_hidden_states(self, tokens: List[int]) -> torch.Tensor:
    """Convert generated token ids to GPT hidden states.

    Explicitly a placeholder — the real mapping depends on the specific
    model architecture.

    NOTE(review): ``AsyncLLMEngine.encode`` is an async *generator* in recent
    VLLM releases, so a bare ``await`` on it may not behave as intended —
    verify against the pinned VLLM version. Also confirm ``self.device``
    resolves as expected (provided by PreTrainedModel).
    """
    # This implementation depends on your specific model architecture
    # You'll need to adapt this based on how your model processes tokens
    # This is a placeholder implementation
    token_tensor = torch.tensor(tokens, device=self.device)
    # Use VLLM's engine to get hidden states
    hidden_states = await self.llm_engine.encode(token_tensor)
    return hidden_states