samarth2002's picture
files added
5fc76ef
raw
history blame
6.4 kB
import gc
import queue
from typing import Generator
import numpy as np
import torch
from loguru import logger
from fish_speech.inference_engine.reference_loader import ReferenceLoader
from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
from fish_speech.inference_engine.vq_manager import VQManager
from fish_speech.models.text2semantic.inference import (
GenerateRequest,
GenerateResponse,
WrappedGenerateResponse,
)
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps, set_seed
from fish_speech.utils.schema import ServeTTSRequest
class TTSInferenceEngine(ReferenceLoader, VQManager):
    """Text-to-speech engine pairing a LLAMA token generator with a VQ decoder.

    Generation requests are put on ``llama_queue`` for the LLAMA worker; the
    VQ codes it sends back are decoded into audio with ``decoder_model``.
    Reference loading (``load_by_id`` / ``load_by_hash``) and VQ decoding
    (``decode_vq_tokens``) are not defined here and come from the base
    classes (presumably ``ReferenceLoader`` and ``VQManager`` respectively).
    """

    def __init__(
        self,
        llama_queue: queue.Queue,
        decoder_model: FireflyArchitecture,
        precision: torch.dtype,
        compile: bool,
    ) -> None:
        """
        Args:
            llama_queue: Queue onto which ``GenerateRequest`` objects are put
                for the LLAMA worker.
            decoder_model: Model used to decode VQ tokens; its ``device`` and
                ``spec_transform.sample_rate`` are also read by this class.
            precision: dtype passed to autocast while decoding.
            compile: Forwarded as the ``compile`` field of every LLAMA request.
        """
        super().__init__()
        self.llama_queue = llama_queue
        self.decoder_model = decoder_model
        self.precision = precision
        self.compile = compile

    @torch.inference_mode()
    def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
        """Run the full TTS pipeline for a single request.

        Steps:
        - Load the reference tokens/texts (by ``req.reference_id`` or, failing
          that, from ``req.references``).
        - Send a generation request to the LLAMA model.
        - Decode each returned set of VQ codes to audio.

        Yields:
            ``InferenceResult`` objects, in this order:
            - ``code="header"``: a WAV chunk header (only if ``req.streaming``).
            - ``code="segment"``: each decoded segment (only if ``req.streaming``).
            - then exactly one of:
              ``code="error"`` if the LLAMA worker reported an error or no
              audio was produced, or ``code="final"`` with all segments
              concatenated as ``(sample_rate, audio)``.

        Raises:
            TypeError: if a non-error response does not carry a
                ``GenerateResponse``.
        """
        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []

        # An explicit reference id takes priority over inline references.
        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
        elif req.references:
            prompt_tokens, prompt_texts = self.load_by_hash(
                req.references, req.use_memory_cache
            )

        if req.seed is not None:
            set_seed(req.seed)
            logger.warning(f"set seed: {req.seed}")

        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
        sample_rate = self.decoder_model.spec_transform.sample_rate

        # Streaming clients need the WAV header before any audio bytes.
        if req.streaming:
            yield InferenceResult(
                code="header",
                audio=(
                    sample_rate,
                    np.array(wav_chunk_header(sample_rate=sample_rate)),
                ),
                error=None,
            )

        segments = []
        failed = False
        while True:
            # Blocks until the LLAMA worker posts its next response.
            wrapped_result: WrappedGenerateResponse = response_queue.get()
            if wrapped_result.status == "error":
                failed = True
                yield InferenceResult(
                    code="error",
                    audio=None,
                    error=(
                        wrapped_result.response
                        if isinstance(wrapped_result.response, Exception)
                        else Exception("Unknown error")
                    ),
                )
                break

            if not isinstance(wrapped_result.response, GenerateResponse):
                raise TypeError(
                    "Expected GenerateResponse, got "
                    f"{type(wrapped_result.response).__name__}"
                )

            result: GenerateResponse = wrapped_result.response
            # An action of "next" ends the loop; any other action carries
            # codes to decode.
            if result.action == "next":
                break

            segment = self.get_audio_segment(result)
            if req.streaming:  # Used only by the API server
                yield InferenceResult(
                    code="segment",
                    audio=(sample_rate, segment),
                    error=None,
                )
            segments.append(segment)

        # Collect garbage first so tensors it frees go back to the CUDA
        # caching allocator before empty_cache() releases the cached blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # The error has already been reported; don't follow it with a
        # misleading "no audio generated" error or a partial "final" result.
        if failed:
            return None

        if not segments:
            yield InferenceResult(
                code="error",
                audio=None,
                error=RuntimeError("No audio generated, please check the input text."),
            )
        else:
            # Streaming or not, return the final concatenated audio.
            audio = np.concatenate(segments, axis=0)
            yield InferenceResult(
                code="final",
                audio=(sample_rate, audio),
                error=None,
            )

        return None

    def send_Llama_request(
        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
    ) -> queue.Queue:
        """Queue a generation request for the LLAMA model.

        Args:
            req: The TTS request; its text (normalized with ``ChnNormedText``
                when ``req.normalize`` is set) and sampling parameters are
                forwarded.
            prompt_tokens: Reference VQ tokens used as the prompt.
            prompt_texts: Reference texts paired with ``prompt_tokens``.

        Returns:
            A fresh queue on which the LLAMA worker is expected to post
            ``WrappedGenerateResponse`` objects for this request.
        """
        request = dict(
            device=self.decoder_model.device,
            max_new_tokens=req.max_new_tokens,
            text=(
                req.text
                if not req.normalize
                else ChnNormedText(raw_text=req.text).normalize()
            ),
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            compile=self.compile,
            # A positive chunk length switches on iterative prompting.
            iterative_prompt=req.chunk_length > 0,
            chunk_length=req.chunk_length,
            max_length=4096,
            prompt_tokens=prompt_tokens,
            prompt_text=prompt_texts,
        )

        # A per-request queue so responses are routed back only to this caller.
        response_queue = queue.Queue()
        self.llama_queue.put(
            GenerateRequest(
                request=request,
                response_queue=response_queue,
            )
        )
        return response_queue

    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
        """Decode the VQ codes in ``result`` to a float32 numpy waveform on CPU."""
        # Don't use autocast on MPS devices
        with autocast_exclude_mps(
            device_type=self.decoder_model.device.type, dtype=self.precision
        ):
            segment = self.decode_vq_tokens(codes=result.codes)

        # Move to CPU as float32 so segments can be concatenated with numpy.
        return segment.float().cpu().numpy()