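"""Reconstruct audio with the Firefly GAN VQ model.

Given an audio file, encode it into VQ codebook indices (also saved as a .npy
file next to the output); given a precomputed .npy file, load the indices
directly. In both cases, decode the indices back into a waveform and write it
to the output path.
"""
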
from pathlib import Path

import click
import hydra
import numpy as np
import soundfile as sf
import torch
import torchaudio
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from omegaconf import OmegaConf

from tools.file import AUDIO_EXTENSIONS

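# Allow "${eval:...}" expressions to be resolved inside the Hydra/OmegaConf configs.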
OmegaConf.register_new_resolver("eval", eval)


def load_model(config_name, checkpoint_path, device="cuda"):
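    # Clear any previously initialized Hydra instance so initialize() can be called again.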
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        cfg = compose(config_name=config_name)

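    # Build the model from the composed config, then load the checkpoint weights.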
    model = instantiate(cfg)
    state_dict = torch.load(
        checkpoint_path,
        map_location=device,
    )
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

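    # If the checkpoint stores prefixed generator weights, keep only those and strip the "generator." prefix.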
if any("generator" in k for k in state_dict): |
|
state_dict = { |
|
k.replace("generator.", ""): v |
|
for k, v in state_dict.items() |
|
if "generator." in k |
|
} |
|
|
|
result = model.load_state_dict(state_dict, strict=False) |
|
model.eval() |
|
model.to(device) |
|
|
|
logger.info(f"Loaded model: {result}") |
|
return model |
|
|
|
|
|
@torch.no_grad()
@click.command()
@click.option(
    "--input-path",
    "-i",
    default="test.wav",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
)
@click.option("--config-name", default="firefly_gan_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
@click.option(
    "--device",
    "-d",
    default="cuda",
)
def main(input_path, output_path, config_name, checkpoint_path, device):
    model = load_model(config_name, checkpoint_path, device=device)

    if input_path.suffix in AUDIO_EXTENSIONS:
        logger.info(f"Processing in-place reconstruction of {input_path}")

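        # Load the audio, downmix multi-channel input to mono, and resample to the model's sample rate.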
        audio, sr = torchaudio.load(str(input_path))
        if audio.shape[0] > 1:
            audio = audio.mean(0, keepdim=True)
        audio = torchaudio.functional.resample(
            audio, sr, model.spec_transform.sample_rate
        )

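        # Add a batch dimension and move the audio to the target device.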
        audios = audio[None].to(device)
        logger.info(
            f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
        )

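        # Encode the audio into VQ codebook indices.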
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=device, dtype=torch.long
        )
        indices = model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Generated indices of shape {indices.shape}")

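        # Save the indices next to the output so they can be reused later as an .npy input.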
        np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
    elif input_path.suffix == ".npy":
        logger.info(f"Processing precomputed indices from {input_path}")
        indices = np.load(input_path)
        indices = torch.from_numpy(indices).to(device).long()
        assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
    else:
        raise ValueError(f"Unknown input type: {input_path}")

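    # Decode the codebook indices back into a waveform.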
    feature_lengths = torch.tensor([indices.shape[1]], device=device)
    fake_audios, _ = model.decode(
        indices=indices[None], feature_lengths=feature_lengths
    )
    audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate

    logger.info(
        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds "
        f"from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
    )

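    # Write the reconstructed waveform to disk.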
    fake_audio = fake_audios[0, 0].float().cpu().numpy()
    sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
    logger.info(f"Saved audio to {output_path}")


if __name__ == "__main__":
    main()