|
import io |
|
from hashlib import sha256 |
|
from pathlib import Path |
|
from typing import Callable, Literal, Tuple |
|
|
|
import torch |
|
import torchaudio |
|
from loguru import logger |
|
|
|
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture |
|
from fish_speech.utils.file import ( |
|
AUDIO_EXTENSIONS, |
|
audio_to_bytes, |
|
list_files, |
|
read_ref_text, |
|
) |
|
from fish_speech.utils.schema import ServeReferenceAudio |
|
|
|
|
|
class ReferenceLoader:
    """
    Load reference audio/text prompts and cache their encoded form.

    Component of the TTSInferenceEngine class.
    References can be addressed either by an id (a folder below
    ``references/``) or by the content of the uploaded audio (sha256 hash).
    """

    def __init__(self) -> None:
        """
        Create the empty caches and select the torchaudio backend.
        """
        # id -> (prompt_tokens, prompt_texts) exactly as returned by load_by_id.
        self.ref_by_id: dict = {}
        # sha256 hex digest of one reference's audio bytes -> the value
        # returned by encode_reference for that single reference.
        self.ref_by_hash: dict = {}

        # Declared but not assigned in this class.
        # NOTE(review): presumably provided by the class this component is
        # combined with -- confirm against TTSInferenceEngine.
        self.decoder_model: FireflyArchitecture
        self.encode_reference: Callable

        # Prefer ffmpeg when torchaudio reports it, otherwise use soundfile.
        backends = torchaudio.list_audio_backends()
        self.backend = "ffmpeg" if "ffmpeg" in backends else "soundfile"

    def load_by_id(
        self,
        id: str,
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """
        Load the references stored in the folder ``references/<id>``.

        Args:
            id: Name of the reference folder below ``references/``.
            use_cache: ``"off"`` forces re-encoding; any other value reuses
                the cached result for ``id`` when one exists.

        Returns:
            ``(prompt_tokens, prompt_texts)``: one encoded prompt and one
            text per audio file found in the folder.
        """
        # NOTE(review): `id` is joined into a filesystem path without
        # validation and the folder is created on demand -- confirm that
        # callers sanitise it.
        ref_folder = Path("references") / id
        ref_folder.mkdir(parents=True, exist_ok=True)

        if use_cache != "off" and id in self.ref_by_id:
            # Reuse the already encoded references.
            logger.info("Use same references")
            prompt_tokens, prompt_texts = self.ref_by_id[id]
            return prompt_tokens, prompt_texts

        # The directory listing is only needed when we actually (re-)encode.
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )
        prompt_tokens = [
            self.encode_reference(
                reference_audio=audio_to_bytes(str(ref_audio)),
                enable_reference_audio=True,
            )
            for ref_audio in ref_audios
        ]
        # The text of each audio file is read from the sibling ".lab" file.
        prompt_texts = [
            read_ref_text(str(ref_audio.with_suffix(".lab")))
            for ref_audio in ref_audios
        ]
        # The cache is refreshed even when use_cache is "off".
        self.ref_by_id[id] = (prompt_tokens, prompt_texts)

        return prompt_tokens, prompt_texts

    def load_by_hash(
        self,
        references: "list[ServeReferenceAudio]",
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """
        Encode uploaded references, caching the result per audio content.

        Args:
            references: Objects exposing ``audio`` (bytes) and ``text``.
            use_cache: ``"off"`` forces re-encoding; any other value reuses
                the cached encoding of an audio whose sha256 is already known.

        Returns:
            ``(prompt_tokens, prompt_texts)``: two fresh lists with one entry
            per reference, in the order of ``references``.
        """
        # Hash everything up front so that bad input fails before any
        # (potentially expensive) encoding starts.
        audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]

        cache_used = False
        prompt_tokens, prompt_texts = [], []
        for ref, audio_hash in zip(references, audio_hashes):
            if use_cache == "off" or audio_hash not in self.ref_by_hash:
                encoded = self.encode_reference(
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                # Cache exactly one encoding per hash. Storing the running
                # result lists here would alias them across hashes and leak
                # unrelated references into later requests.
                self.ref_by_hash[audio_hash] = encoded
            else:
                encoded = self.ref_by_hash[audio_hash]
                cache_used = True

            # Always append, so a cache hit never discards the references
            # collected earlier in this call. The text comes from the
            # request because the cache key only covers the audio bytes.
            prompt_tokens.append(encoded)
            prompt_texts.append(ref.text)

        if cache_used:
            logger.info("Use same references")

        return prompt_tokens, prompt_texts

    def load_audio(self, reference_audio, sr: int):
        """
        Load the audio data from a file or bytes.

        Args:
            reference_audio: Path of an existing audio file, or the raw
                encoded audio as a bytes-like object.
            sr: Target sample rate; the audio is resampled when it differs.

        Returns:
            The waveform as a numpy array after down-mixing to a single
            channel and squeezing singleton dimensions.
        """
        # Bytes-like input is always raw audio data. Checking the type first
        # keeps short payloads away from Path(), which rejects bytes.
        is_raw = isinstance(reference_audio, (bytes, bytearray, memoryview))
        if (
            is_raw
            or len(reference_audio) > 255
            or not Path(reference_audio).exists()
        ):
            reference_audio = io.BytesIO(reference_audio)

        waveform, original_sr = torchaudio.load(
            reference_audio, backend=self.backend
        )

        # Average over the first dimension when it holds more than one entry
        # (torchaudio returns channels first by default).
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if original_sr != sr:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr, new_freq=sr
            )
            waveform = resampler(waveform)

        return waveform.squeeze().numpy()
|
|