import asyncio
from dataclasses import dataclass
from typing import Optional, List, Tuple
from concurrent.futures import ThreadPoolExecutor

import torch
import numpy as np
from transformers import PreTrainedModel

from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.utils import Counter

from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from gpt_config import XTTSGPTConfig
from xtts2_config import XTTSConfig
from tokenizer import XTTSTokenizerFast


@dataclass
class XTTSRequest:
    """Container for XTTS inference request data."""
    request_id: str
    text: str
    language: str
    gpt_cond_latent: torch.Tensor
    speaker_embedding: torch.Tensor
    temperature: float = 0.75
    top_p: float = 0.85
    top_k: int = 50
    repetition_penalty: float = 10.0
    length_penalty: float = 1.0
    do_sample: bool = True


@dataclass
class XTTSOutput:
    """Container for XTTS inference output."""
    request_id: str
    wav: np.ndarray
    gpt_latents: np.ndarray
    speaker_embedding: torch.Tensor


class Xtts(PreTrainedModel):
    """Async XTTS model implementation using vLLM's AsyncLLMEngine."""

    def __init__(self, hifi_config: XTTSConfig, gpt_config: XTTSGPTConfig, tensor_parallel_size: int = 1, **kwargs):
        # PreTrainedModel needs its config before buffers/submodules can be
        # registered; this assumes XTTSGPTConfig subclasses PretrainedConfig.
        super().__init__(gpt_config)
        self.hifi_config = hifi_config
        self.gpt_config = gpt_config
        self.tp = tensor_parallel_size
        self.tokenizer = XTTSTokenizerFast.from_pretrained("AstraMindAI/xtts2-gpt")
        self.request_counter = Counter()
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.init_models()
        self.register_buffer("mel_stats", torch.ones(80))

    @staticmethod
    def get_memory_percentage(memory: int) -> float:
        """Return `memory` (in bytes) as a fraction of GPU 0's total memory."""
        return memory / torch.cuda.get_device_properties(0).total_memory
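    # Worked example (assuming a 24 GiB card): get_memory_percentage(2 * 1024 ** 3)
    # evaluates to 2/24 ~= 0.083, i.e. the engine may claim roughly 8% of device
    # memory.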

    def init_models(self):
        """Initialize the GPT engine (vLLM AsyncLLMEngine) and the HiFi-GAN decoder."""
        engine_args = AsyncEngineArgs(
            model=self.gpt_config.model_dir,
            tensor_parallel_size=self.tp,
            dtype="auto",
            max_model_len=self.gpt_config.gpt_max_text_tokens + self.gpt_config.gpt_max_audio_tokens,
            # Assumption: a 2 GiB budget for the GPT engine; the raw byte count
            # is converted to the fraction vLLM expects.
            gpu_memory_utilization=self.get_memory_percentage(2 * 1024 ** 3),
            trust_remote_code=True,
            skip_tokenizer_init=True,
            max_num_batched_tokens=4096,
            max_num_seqs=256,
        )
        # Synchronous on purpose: AsyncLLMEngine.from_engine_args is a regular
        # (non-async) call, so this can run from __init__.
        self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args)

        self.hifigan_decoder = HifiDecoder(
            input_sample_rate=self.hifi_config.input_sample_rate,
            output_sample_rate=self.hifi_config.output_sample_rate,
            output_hop_length=self.hifi_config.output_hop_length,
            ar_mel_length_compression=self.hifi_config.gpt_code_stride_len,
            decoder_input_dim=self.hifi_config.decoder_input_dim,
            d_vector_dim=self.hifi_config.d_vector_dim,
            cond_d_vector_in_each_upsampling_layer=self.hifi_config.cond_d_vector_in_each_upsampling_layer,
        )
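    # Context-length check: with stock XTTS-v2 limits (402 text tokens and 605
    # audio tokens; assumed defaults, not defined in this file), max_model_len
    # above comes out to 1007 tokens per sequence.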

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        torch_dtype: torch.dtype = torch.float16,
        device_map: Optional[str] = "auto",
        tensor_parallel_size: int = 1,
        **kwargs,
    ) -> "Xtts":
        """Load a pretrained XTTS model from the HuggingFace Hub or a local path.

        Args:
            pretrained_model_name_or_path (str): Path to pretrained weights or HF Hub model id.
            torch_dtype (torch.dtype, optional): Dtype to load the model as. Defaults to float16.
            device_map (str, optional): Device mapping strategy. Defaults to "auto".
            tensor_parallel_size (int, optional): Number of GPUs for vLLM tensor parallelism. Defaults to 1.
            **kwargs: Additional arguments passed to the model.

        Returns:
            Xtts: Loaded model instance.
        """
        from huggingface_hub import hf_hub_download
        import json
        import os

        if not os.path.exists(pretrained_model_name_or_path):
            config_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="../xtts2_gpt/config.json"
            )
            with open(config_file, 'r') as f:
                config = json.load(f)
        else:
            with open(os.path.join(pretrained_model_name_or_path, "config.json"), 'r') as f:
                config = json.load(f)

        # Both model configs are built from the shared config.json dict.
        gpt_config = XTTSGPTConfig(**config)
        hifi_config = XTTSConfig(**config)

        model = cls(
            hifi_config=hifi_config,
            gpt_config=gpt_config,
            tensor_parallel_size=tensor_parallel_size,
            **kwargs
        )

        if not os.path.exists(pretrained_model_name_or_path):
            # The GPT half of the checkpoint is cached locally but loaded by
            # vLLM itself (via gpt_config.model_dir); this module has no `gpt`
            # submodule to load it into.
            hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="../xtts2_gpt/xttsv2-gpt.safetensors"
            )
            hifigan_weights = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="xttsv2-hifigan-mel.safetensors"
            )
        else:
            hifigan_weights = os.path.join(pretrained_model_name_or_path, "xttsv2-hifigan-mel.safetensors")

        # Only the HiFi-GAN decoder lives in this module, so only its weights
        # are loaded explicitly.
        import safetensors.torch
        hifigan_state = safetensors.torch.load_file(hifigan_weights)
        model.hifigan_decoder.load_state_dict(hifigan_state)

        # Keep the raw config dict around for reference.
        model.config = config

        model = model.to(torch_dtype)

        if device_map:
            from accelerate import dispatch_model, infer_auto_device_map
            # dispatch_model needs a concrete module->device mapping, so the
            # "auto" shorthand is resolved first.
            if device_map == "auto":
                device_map = infer_auto_device_map(model)
            model = dispatch_model(model, device_map=device_map)

        return model
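    # Expected local checkpoint layout when a filesystem path is passed:
    #   <path>/config.json
    #   <path>/xttsv2-gpt.safetensors         (consumed by vLLM via model_dir)
    #   <path>/xttsv2-hifigan-mel.safetensors (loaded into hifigan_decoder)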

    def prepare_inputs(self, text: str, language: str, gpt_cond_latent: torch.Tensor) -> Tuple[List[int], torch.Tensor]:
        """Tokenize the text for the given language; the conditioning latent passes through unchanged."""
        text_tokens = self.tokenizer.encode(text, lang=language)
        return text_tokens, gpt_cond_latent

    async def generate_speech_async(self, request: XTTSRequest) -> XTTSOutput:
        """Generate speech for a single request asynchronously."""
        tokens, gpt_cond_latent = self.prepare_inputs(
            request.text,
            request.language,
            request.gpt_cond_latent
        )

        sampling_params = SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repetition_penalty,
            max_tokens=self.gpt_config.gpt_max_audio_tokens,
            stop=['</s>', '<|endoftext|>']
        )

        engine_inputs = TokensPrompt(prompt_token_ids=tokens)
        if gpt_cond_latent is not None:
            # MultiModalDataDict is a type, not a constructor: attach a plain
            # dict keyed by modality.
            mm_data: MultiModalDataDict = {"audio": gpt_cond_latent}
            engine_inputs["multi_modal_data"] = mm_data

        output_generator = self.llm_engine.generate(
            engine_inputs,
            sampling_params=sampling_params,
            request_id=request.request_id
        )

        # The engine streams incremental RequestOutputs; keep only the final
        # (finished) one rather than returning on the first partial result.
        final_output = None
        async for outputs in output_generator:
            final_output = outputs
        if final_output is None:
            raise RuntimeError(f"engine produced no output for request {request.request_id}")

        generated_tokens = final_output.outputs[0].token_ids

        hidden_states = await self._tokens_to_hidden_states(generated_tokens)

        # Vocoding is blocking torch work; run it in the thread pool so the
        # event loop stays free for other requests.
        wav = await asyncio.get_event_loop().run_in_executor(
            self.executor,
            lambda: self.hifigan_decoder(
                hidden_states,
                g=request.speaker_embedding
            ).cpu().numpy().squeeze()
        )

        return XTTSOutput(
            request_id=request.request_id,
            wav=wav,
            gpt_latents=hidden_states.cpu().numpy(),
            speaker_embedding=request.speaker_embedding
        )

    async def _tokens_to_hidden_states(self, tokens: List[int]) -> torch.Tensor:
        """Convert generated audio tokens into GPT hidden states for the vocoder."""
        token_tensor = torch.tensor(tokens, device=self.device)
        # Assumption: the engine exposes a hook that re-runs the GPT over the
        # generated tokens and returns its final hidden states. Stock vLLM's
        # AsyncLLMEngine.encode() is a pooling entry point that takes a prompt,
        # not a tensor, so this call stands in for such a custom hook.
        hidden_states = await self.llm_engine.encode(token_tensor)
        return hidden_states
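

# Minimal usage sketch. The checkpoint id and the latent/embedding shapes below
# are illustrative assumptions; real conditioning latents and speaker embeddings
# come from the XTTS speaker-conditioning step, which this module does not cover.
async def _demo() -> None:
    model = Xtts.from_pretrained("AstraMindAI/xtts2", tensor_parallel_size=1)
    request = XTTSRequest(
        request_id="demo-0",
        text="Hello from async XTTS.",
        language="en",
        gpt_cond_latent=torch.zeros(1, 32, 1024),  # placeholder shape
        speaker_embedding=torch.zeros(1, 512, 1),  # placeholder shape
    )
    output = await model.generate_speech_async(request)
    print(output.wav.shape)


if __name__ == "__main__":
    asyncio.run(_demo())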