import gc
import queue
from typing import Generator

import numpy as np
import torch
from loguru import logger

from fish_speech.inference_engine.reference_loader import ReferenceLoader
from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
from fish_speech.inference_engine.vq_manager import VQManager
from fish_speech.models.text2semantic.inference import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
)
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps, set_seed
from fish_speech.utils.schema import ServeTTSRequest


class TTSInferenceEngine(ReferenceLoader, VQManager):
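    """TTS engine that requests semantic tokens from the LLAMA model and
    decodes them to audio with the Firefly (VQ-GAN) decoder."""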

    def __init__(
        self,
        llama_queue: queue.Queue,
        decoder_model: FireflyArchitecture,
        precision: torch.dtype,
        compile: bool,
    ) -> None:
        super().__init__()

        self.llama_queue = llama_queue
        self.decoder_model = decoder_model
        self.precision = precision
        self.compile = compile

    @torch.inference_mode()
    def inference(
        self, req: ServeTTSRequest
    ) -> Generator[InferenceResult, None, None]:
        """
        Main inference function:
        - Loads the reference audio and text.
        - Calls the LLAMA model for inference.
        - Decodes the VQ tokens to audio.
        """
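        # Resolve the reference voice: either a pre-registered reference id or
        # ad-hoc references attached to the request, cached per use_memory_cache.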
        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []

        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
        elif req.references:
            prompt_tokens, prompt_texts = self.load_by_hash(
                req.references, req.use_memory_cache
            )
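        # Optional determinism: fixing the seed makes token sampling reproducible.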
        if req.seed is not None:
            set_seed(req.seed)
            logger.warning(f"set seed: {req.seed}")
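        # Hand the request to the LLAMA worker; generated token chunks arrive
        # asynchronously on the returned response queue.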
        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)

        sample_rate = self.decoder_model.spec_transform.sample_rate
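        # In streaming mode, emit the WAV header first so clients can start
        # playback before any audio has been decoded.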
        if req.streaming:
            yield InferenceResult(
                code="header",
                audio=(
                    sample_rate,
                    np.array(wav_chunk_header(sample_rate=sample_rate)),
                ),
                error=None,
            )

        segments = []
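        # Consume responses until the worker signals completion: responses
        # carrying VQ codes are decoded to audio; a "next" action ends the loop.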
        while True:
            wrapped_result: WrappedGenerateResponse = response_queue.get()
            if wrapped_result.status == "error":
                yield InferenceResult(
                    code="error",
                    audio=None,
                    error=(
                        wrapped_result.response
                        if isinstance(wrapped_result.response, Exception)
                        else Exception("Unknown error")
                    ),
                )
                break
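            # Defensive check: a successful response must be a GenerateResponse.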
            if not isinstance(wrapped_result.response, GenerateResponse):
                raise TypeError(
                    f"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
                )

            result: GenerateResponse = wrapped_result.response
            if result.action != "next":
                segment = self.get_audio_segment(result)

                if req.streaming:
                    yield InferenceResult(
                        code="segment",
                        audio=(sample_rate, segment),
                        error=None,
                    )
                segments.append(segment)
            else:
                break
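        # Release cached GPU memory now that decoding is done.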
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
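        # Assemble the final result: concatenate all segments, or report an
        # error if generation produced no audio at all.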
        if len(segments) == 0:
            yield InferenceResult(
                code="error",
                audio=None,
                error=RuntimeError("No audio generated, please check the input text."),
            )
        else:
            audio = np.concatenate(segments, axis=0)
            yield InferenceResult(
                code="final",
                audio=(sample_rate, audio),
                error=None,
            )

        return None

    def send_Llama_request(
        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
    ) -> queue.Queue:
        """
        Send a request to the LLAMA model to generate the semantic tokens.
        """
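        # Build the generation request; Chinese text normalization is applied
        # only when the request opts in.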
        request = dict(
            device=self.decoder_model.device,
            max_new_tokens=req.max_new_tokens,
            text=(
                req.text
                if not req.normalize
                else ChnNormedText(raw_text=req.text).normalize()
            ),
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            compile=self.compile,
            iterative_prompt=req.chunk_length > 0,
            chunk_length=req.chunk_length,
            max_length=4096,
            prompt_tokens=prompt_tokens,
            prompt_text=prompt_texts,
        )
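        # Per-request queue on which the LLAMA worker delivers its responses.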
        response_queue = queue.Queue()
        self.llama_queue.put(
            GenerateRequest(
                request=request,
                response_queue=response_queue,
            )
        )

        return response_queue

    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
        """
        Decode the VQ tokens to audio.
        """
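        # Run the decoder under autocast at the configured precision; the
        # helper skips autocast on MPS devices.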
        with autocast_exclude_mps(
            device_type=self.decoder_model.device.type, dtype=self.precision
        ):
            segment = self.decode_vq_tokens(codes=result.codes)
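        # Return a float32 numpy array on the CPU for downstream processing.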
        return segment.float().cpu().numpy()