zhzluke96
update
d2b7e94
raw
history blame contribute delete
789 Bytes
import logging
from functools import cache
import torch
from ..denoiser.denoiser import Denoiser
from ..inference import inference
from .hparams import HParams
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@cache
def load_denoiser(run_dir, device):
    """Build a ``Denoiser`` and return it in eval mode on ``device``.

    The result is memoized by ``functools.cache`` on the call arguments, so
    repeated calls with the same ``run_dir`` and ``device`` return the same
    model object. Both arguments therefore have to be hashable.

    Args:
        run_dir: Run directory used for ``HParams.load`` and for locating the
            checkpoint at ``ds/G/default/mp_rank_00_model_states.pt``. It
            must support the ``/`` path operator (e.g. ``pathlib.Path``).
            When ``None``, the denoiser is built from default ``HParams()``
            and no checkpoint is loaded.
        device: Device the model is moved to before it is returned.

    Returns:
        The ``Denoiser`` instance, switched to eval mode and moved to
        ``device``.
    """
    if run_dir is None:
        # No run directory: default hyperparameters, no checkpoint to load.
        denoiser = Denoiser(HParams())
    else:
        hp = HParams.load(run_dir)
        denoiser = Denoiser(hp)
        path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
        # NOTE(security): torch.load unpickles the file, so only point this
        # at checkpoints from a trusted source.
        # Load onto CPU first; the model is moved to ``device`` below.
        state_dict = torch.load(path, map_location="cpu")["module"]
        denoiser.load_state_dict(state_dict)

    # Apply to both branches: previously the ``run_dir is None`` model was
    # returned in training mode and never moved to ``device``, even though
    # ``denoise`` uses it for inference on that device.
    denoiser.eval()
    denoiser.to(device)
    return denoiser
@torch.inference_mode()
def denoise(dwav, sr, run_dir, device):
    """Denoise ``dwav`` with the model selected by ``run_dir``.

    Obtains the model from ``load_denoiser(run_dir, device)`` and hands it,
    together with the remaining arguments, to ``inference``. The whole call
    executes under ``torch.inference_mode()``.

    Args:
        dwav: Input forwarded unchanged to ``inference`` (presumably the
            waveform to clean up; confirm against ``inference``).
        sr: Forwarded unchanged to ``inference`` (presumably the sample rate
            of ``dwav``; confirm against ``inference``).
        run_dir: Run directory passed to ``load_denoiser``; may be ``None``.
        device: Device passed to both ``load_denoiser`` and ``inference``.

    Returns:
        Whatever ``inference`` returns for the given model and input.
    """
    model = load_denoiser(run_dir, device)
    result = inference(model=model, dwav=dwav, sr=sr, device=device)
    return result