Spaces:
tsqn
/
Runtime error

Rex Cheng
initial commit
dbac20f
raw
history blame contribute delete
986 Bytes
from pathlib import Path
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from mmaudio.ext.bigvgan.models import BigVGANVocoder
# Default vocoder config file: 'bigvgan_vocoder.yml' in the same directory as this module.
_bigvgan_vocoder_path = Path(__file__).parent.joinpath('bigvgan_vocoder.yml')
class BigVGAN(nn.Module):
    """Inference-only wrapper around a pretrained ``BigVGANVocoder``.

    On construction the vocoder is built from an OmegaConf config, its
    weights are loaded from the ``'generator'`` entry of a checkpoint, it is
    switched to eval mode, and weight normalization is removed.

    Args:
        ckpt_path: Checkpoint file loadable by ``torch.load``. It is loaded
            onto the CPU with ``weights_only=True`` and must contain a
            ``'generator'`` key holding the vocoder state dict.
        config_path: Vocoder config file read with ``OmegaConf.load``.
            Defaults to ``bigvgan_vocoder.yml`` next to this module.
    """

    def __init__(self, ckpt_path, config_path=_bigvgan_vocoder_path):
        super().__init__()
        vocoder_cfg = OmegaConf.load(config_path)
        self.vocoder = BigVGANVocoder(vocoder_cfg).eval()
        # weights_only=True restricts unpickling to tensors and primitive
        # containers, so loading an untrusted checkpoint cannot run code.
        vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)['generator']
        self.vocoder.load_state_dict(vocoder_ckpt)
        # True once weight norm has been stripped from the vocoder; guards
        # both forward() and repeated remove_weight_norm() calls.
        self.weight_norm_removed = False
        self.remove_weight_norm()

    @torch.inference_mode()
    def forward(self, x):
        """Run the vocoder on ``x`` with autograd disabled (inference mode).

        Args:
            x: Input passed straight to the underlying vocoder; presumably a
                batch of mel spectrograms (TODO confirm expected shape).

        Returns:
            Whatever the underlying vocoder returns (presumably a waveform
            tensor; not established in this file).

        Raises:
            AssertionError: if weight normalization has not been removed.
        """
        assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
        return self.vocoder(x)

    def remove_weight_norm(self):
        """Remove weight normalization from the vocoder; safe to call repeatedly.

        The first call delegates to ``self.vocoder.remove_weight_norm()`` and
        sets ``weight_norm_removed``. Later calls are no-ops. Weight norm is
        already removed in ``__init__``, and removing it a second time from
        the same layers would fail (``torch.nn.utils.remove_weight_norm``
        raises ``ValueError`` when no weight norm is present).

        Returns:
            ``self``, to allow call chaining.
        """
        if not self.weight_norm_removed:
            self.vocoder.remove_weight_norm()
            # Only mark as removed after the vocoder call succeeds.
            self.weight_norm_removed = True
        return self