File size: 6,400 Bytes
5fc76ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import gc
import queue
from typing import Generator
import numpy as np
import torch
from loguru import logger
from fish_speech.inference_engine.reference_loader import ReferenceLoader
from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
from fish_speech.inference_engine.vq_manager import VQManager
from fish_speech.models.text2semantic.inference import (
GenerateRequest,
GenerateResponse,
WrappedGenerateResponse,
)
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps, set_seed
from fish_speech.utils.schema import ServeTTSRequest
class TTSInferenceEngine(ReferenceLoader, VQManager):
    """Text-to-speech engine pairing a LLAMA token generator with a VQ decoder.

    Generation requests are posted as ``GenerateRequest`` objects onto
    ``llama_queue``; the VQ codes that come back are decoded to audio with
    ``decoder_model``. ``load_by_id`` / ``load_by_hash`` / ``decode_vq_tokens``
    are not defined here and are presumably provided by the ``ReferenceLoader``
    and ``VQManager`` mixins.
    """

    def __init__(
        self,
        llama_queue: queue.Queue,
        decoder_model: FireflyArchitecture,
        precision: torch.dtype,
        compile: bool,
    ) -> None:
        """
        Args:
            llama_queue: Queue onto which ``send_Llama_request`` puts one
                ``GenerateRequest`` per inference call (presumably consumed by
                a LLAMA worker thread — confirm against the caller).
            decoder_model: Decoder used to turn VQ codes into audio; its
                ``device`` and ``spec_transform.sample_rate`` are read here.
            precision: dtype passed to autocast while decoding.
            compile: Forwarded unchanged as the ``compile`` field of each
                LLAMA request.
        """
        super().__init__()
        self.llama_queue = llama_queue
        self.decoder_model = decoder_model
        self.precision = precision
        self.compile = compile

    @torch.inference_mode()
    def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
        """
        Main inference function:
        - Loads the reference audio and text (by ``reference_id``, otherwise
          from ``references`` when given).
        - Sends a generation request to the LLAMA model.
        - Decodes each returned chunk of VQ codes to audio.

        Yields, in order:
        - ``"header"`` (streaming only): a WAV header for the decoder's sample rate.
        - ``"segment"`` (streaming only): one result per decoded chunk.
        - Exactly one terminal result: ``"error"`` if the LLAMA model reported
          a failure or no audio was produced, otherwise ``"final"`` holding all
          segments concatenated along axis 0.

        Raises:
            TypeError: if the LLAMA model returns something other than a
                ``GenerateResponse``.
        """
        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []

        # Load the reference audio and text based on id or hash
        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
        elif req.references:
            prompt_tokens, prompt_texts = self.load_by_hash(
                req.references, req.use_memory_cache
            )

        # Set the random seed if provided
        if req.seed is not None:
            set_seed(req.seed)
            logger.warning(f"set seed: {req.seed}")

        # Get the symbolic tokens from the LLAMA model
        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)

        # Get the sample rate from the decoder model
        sample_rate = self.decoder_model.spec_transform.sample_rate

        # If streaming, send the header
        if req.streaming:
            yield InferenceResult(
                code="header",
                audio=(
                    sample_rate,
                    np.array(wav_chunk_header(sample_rate=sample_rate)),
                ),
                error=None,
            )

        segments: list[np.ndarray] = []
        try:
            while True:
                # Blocks until the LLAMA side posts its next response.
                wrapped_result: WrappedGenerateResponse = response_queue.get()

                if wrapped_result.status == "error":
                    yield InferenceResult(
                        code="error",
                        audio=None,
                        error=(
                            wrapped_result.response
                            if isinstance(wrapped_result.response, Exception)
                            else Exception("Unknown error")
                        ),
                    )
                    # The failure is already reported: stop here instead of
                    # following it with a misleading "No audio generated"
                    # error or a "final" built from partial audio.
                    return

                # Check the response type
                if not isinstance(wrapped_result.response, GenerateResponse):
                    raise TypeError(
                        f"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
                    )

                result: GenerateResponse = wrapped_result.response
                # An action of "next" ends the stream; any other action
                # carries codes to decode.
                if result.action == "next":
                    break

                segment = self.get_audio_segment(result)
                if req.streaming:  # Used only by the API server
                    yield InferenceResult(
                        code="segment",
                        audio=(sample_rate, segment),
                        error=None,
                    )
                segments.append(segment)
        finally:
            # Clean up the memory on every exit path, including exceptions
            # and the consumer closing this generator mid-stream.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

        # Edge case: no audio generated
        if not segments:
            yield InferenceResult(
                code="error",
                audio=None,
                error=RuntimeError("No audio generated, please check the input text."),
            )
        else:
            # Streaming or not, return the final audio
            yield InferenceResult(
                code="final",
                audio=(sample_rate, np.concatenate(segments, axis=0)),
                error=None,
            )

    def send_Llama_request(
        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
    ) -> queue.Queue:
        """
        Send a request to the LLAMA model to generate the symbolic tokens.

        The text is normalized with ``ChnNormedText`` when ``req.normalize``
        is set. Iterative prompting is enabled whenever ``req.chunk_length``
        is positive, and ``max_length`` is fixed at 4096.

        Returns:
            A fresh queue, attached to the request, on which the responses
            for this request are expected to arrive.
        """
        # Prepare the request
        request = dict(
            device=self.decoder_model.device,
            max_new_tokens=req.max_new_tokens,
            text=(
                req.text
                if not req.normalize
                else ChnNormedText(raw_text=req.text).normalize()
            ),
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            compile=self.compile,
            iterative_prompt=req.chunk_length > 0,
            chunk_length=req.chunk_length,
            max_length=4096,
            prompt_tokens=prompt_tokens,
            prompt_text=prompt_texts,
        )

        # A per-request queue keeps concurrent requests' responses separate.
        response_queue = queue.Queue()

        # Send the request to the LLAMA model
        self.llama_queue.put(
            GenerateRequest(
                request=request,
                response_queue=response_queue,
            )
        )

        return response_queue

    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
        """
        Decode the VQ tokens in ``result.codes`` to audio.

        Returns:
            The decoded audio cast to float32 and moved to a CPU numpy array.
        """
        # Don't use autocast on MPS devices
        with autocast_exclude_mps(
            device_type=self.decoder_model.device.type, dtype=self.precision
        ):
            # Decode the symbolic tokens to audio
            segment = self.decode_vq_tokens(codes=result.codes)

        # Convert the audio to numpy
        return segment.float().cpu().numpy()
|