import torch
from funasr import AutoModel
from loguru import logger

from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.text2semantic.inference import (
    launch_thread_safe_queue,
    launch_thread_safe_queue_agent,
)
from fish_speech.models.vqgan.inference import load_model as load_decoder_model
from fish_speech.utils.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference

ASR_MODEL_NAME = "iic/SenseVoiceSmall"


class ModelManager:
    """Load and hold the ASR, LLAMA (text2semantic) and decoder (VQGAN) models used by the server."""

    def __init__(
        self,
        mode: str,
        device: str,
        half: bool,
        compile: bool,
        asr_enabled: bool,
        llama_checkpoint_path: str,
        decoder_checkpoint_path: str,
        decoder_config_name: str,
    ) -> None:

        self.mode = mode
        self.device = device
        self.half = half
        self.compile = compile

        self.precision = torch.half if half else torch.bfloat16

        # Override the requested device: prefer MPS when available,
        # otherwise fall back to CPU if CUDA is missing
        if torch.backends.mps.is_available():
            self.device = "mps"
            logger.info("mps is available, running on mps.")
        elif not torch.cuda.is_available():
            self.device = "cpu"
            logger.info("CUDA is not available, running on CPU.")

        # Load the ASR model only if requested
        if asr_enabled:
            self.load_asr_model(self.device)

        # Load the LLAMA and decoder models, then build the TTS inference engine
        self.load_llama_model(
            llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
        )
        self.load_decoder_model(
            decoder_config_name, decoder_checkpoint_path, self.device
        )
        self.tts_inference_engine = TTSInferenceEngine(
            llama_queue=self.llama_queue,
            decoder_model=self.decoder_model,
            precision=self.precision,
            compile=self.compile,
        )

        # Warm up the engine when serving TTS
        if self.mode == "tts":
            self.warm_up(self.tts_inference_engine)

    def load_asr_model(self, device, hub="ms") -> None:
        self.asr_model = AutoModel(
            model=ASR_MODEL_NAME,
            device=device,
            disable_pbar=True,
            hub=hub,
        )
        logger.info("ASR model loaded.")

    def load_llama_model(
        self, checkpoint_path, device, precision, compile, mode
    ) -> None:

        if mode == "tts":
            self.llama_queue = launch_thread_safe_queue(
                checkpoint_path=checkpoint_path,
                device=device,
                precision=precision,
                compile=compile,
            )
        elif mode == "agent":
            self.llama_queue, self.tokenizer, self.config = (
                launch_thread_safe_queue_agent(
                    checkpoint_path=checkpoint_path,
                    device=device,
                    precision=precision,
                    compile=compile,
                )
            )
        else:
            raise ValueError(f"Invalid mode: {mode}")

        logger.info("LLAMA model loaded.")

    def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
        self.decoder_model = load_decoder_model(
            config_name=config_name,
            checkpoint_path=checkpoint_path,
            device=device,
        )
        logger.info("Decoder model loaded.")

    def warm_up(self, tts_inference_engine) -> None:
        # Synthesize a short dummy utterance and drain the generator to warm the models up
        request = ServeTTSRequest(
            text="Hello world.",
            references=[],
            reference_id=None,
            max_new_tokens=1024,
            chunk_length=200,
            top_p=0.7,
            repetition_penalty=1.2,
            temperature=0.7,
            format="wav",
        )
        list(inference(request, tts_inference_engine))
        logger.info("Models warmed up.")
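
# Example usage (an illustrative sketch only; the checkpoint paths and config
# name below are placeholders, not values shipped with this module):
#
#     manager = ModelManager(
#         mode="tts",
#         device="cuda",
#         half=False,
#         compile=False,
#         asr_enabled=False,
#         llama_checkpoint_path="path/to/text2semantic/checkpoint",
#         decoder_checkpoint_path="path/to/decoder/checkpoint.pth",
#         decoder_config_name="your_decoder_config",
#     )
#     engine = manager.tts_inference_engine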