import gc
import queue
from typing import Generator

import numpy as np
import torch
from loguru import logger

from fish_speech.inference_engine.reference_loader import ReferenceLoader
from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
from fish_speech.inference_engine.vq_manager import VQManager
from fish_speech.models.text2semantic.inference import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
)
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps, set_seed
from fish_speech.utils.schema import ServeTTSRequest


class TTSInferenceEngine(ReferenceLoader, VQManager):

    def __init__(
        self,
        llama_queue: queue.Queue,
        decoder_model: FireflyArchitecture,
        precision: torch.dtype,
        compile: bool,
    ) -> None:
        super().__init__()

        self.llama_queue = llama_queue
        self.decoder_model = decoder_model
        self.precision = precision
        self.compile = compile

    @torch.inference_mode()
    def inference(
        self, req: ServeTTSRequest
    ) -> Generator[InferenceResult, None, None]:
        """
        Main inference function:
        - Loads the reference audio and text.
        - Calls the LLAMA model for inference.
        - Decodes the VQ tokens to audio.
        """

        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []

        # Load the reference audio and text based on id or hash
        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
        elif req.references:
            prompt_tokens, prompt_texts = self.load_by_hash(
                req.references, req.use_memory_cache
            )

        # Set the random seed if provided
        if req.seed is not None:
            set_seed(req.seed)
            logger.warning(f"set seed: {req.seed}")

        # Get the symbolic tokens from the LLAMA model
        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)

        # Get the sample rate from the decoder model
        sample_rate = self.decoder_model.spec_transform.sample_rate

        # If streaming, send the header
        if req.streaming:
            yield InferenceResult(
                code="header",
                audio=(
                    sample_rate,
                    np.array(wav_chunk_header(sample_rate=sample_rate)),
                ),
                error=None,
            )

        segments = []

        while True:
            # Get the response from the LLAMA model
            wrapped_result: WrappedGenerateResponse = response_queue.get()
            if wrapped_result.status == "error":
                yield InferenceResult(
                    code="error",
                    audio=None,
                    error=(
                        wrapped_result.response
                        if isinstance(wrapped_result.response, Exception)
                        else Exception("Unknown error")
                    ),
                )
                break

            # Check the response type
            if not isinstance(wrapped_result.response, GenerateResponse):
                raise TypeError(
                    f"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
                )

            result: GenerateResponse = wrapped_result.response
            if result.action != "next":
                segment = self.get_audio_segment(result)

                if req.streaming:  # Used only by the API server
                    yield InferenceResult(
                        code="segment",
                        audio=(sample_rate, segment),
                        error=None,
                    )
                segments.append(segment)
            else:
                break

        # Clean up the memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        # Edge case: no audio generated
        if len(segments) == 0:
            yield InferenceResult(
                code="error",
                audio=None,
                error=RuntimeError("No audio generated, please check the input text."),
            )
        else:
            # Streaming or not, return the final audio
            audio = np.concatenate(segments, axis=0)
            yield InferenceResult(
                code="final",
                audio=(sample_rate, audio),
                error=None,
            )

        return None
    def send_Llama_request(
        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
    ) -> queue.Queue:
        """
        Send a request to the LLAMA model to generate the symbolic tokens.
        """

        # Prepare the request
        request = dict(
            device=self.decoder_model.device,
            max_new_tokens=req.max_new_tokens,
            text=(
                req.text
                if not req.normalize
                else ChnNormedText(raw_text=req.text).normalize()
            ),
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            compile=self.compile,
            iterative_prompt=req.chunk_length > 0,
            chunk_length=req.chunk_length,
            max_length=4096,
            prompt_tokens=prompt_tokens,
            prompt_text=prompt_texts,
        )

        # Create a queue to get the response
        response_queue = queue.Queue()

        # Send the request to the LLAMA model
        self.llama_queue.put(
            GenerateRequest(
                request=request,
                response_queue=response_queue,
            )
        )

        return response_queue

    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
        """
        Decode the VQ tokens to audio.
        """

        # Don't use autocast on MPS devices
        with autocast_exclude_mps(
            device_type=self.decoder_model.device.type, dtype=self.precision
        ):
            # Decode the symbolic tokens to audio
            segment = self.decode_vq_tokens(codes=result.codes)

        # Convert the audio to numpy
        return segment.float().cpu().numpy()
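
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept commented out): shows how the engine
# defined above is typically driven once a LLAMA worker queue and a decoder
# model exist. `build_llama_queue` and `build_decoder_model` are hypothetical
# placeholders for whatever loading code the application uses; the
# ServeTTSRequest construction assumes its other fields have defaults.
#
# if __name__ == "__main__":
#     llama_queue = build_llama_queue()        # hypothetical loader
#     decoder_model = build_decoder_model()    # hypothetical loader
#
#     engine = TTSInferenceEngine(
#         llama_queue=llama_queue,
#         decoder_model=decoder_model,
#         precision=torch.bfloat16,
#         compile=False,
#     )
#
#     request = ServeTTSRequest(text="Hello, world!", streaming=False)
#     for result in engine.inference(request):
#         if result.code == "final" and result.audio is not None:
#             sample_rate, audio = result.audio
#             ...  # write `audio` (np.ndarray) at `sample_rate` to disk
#         elif result.code == "error" and result.error is not None:
#             raise result.error
# ---------------------------------------------------------------------------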