Spaces:
Sleeping
Sleeping
from ttts.vqvae.xtts_dvae import DiscreteVAE | |
from ttts.diffusion.model import DiffusionTts | |
from ttts.gpt.model import UnifiedVoice | |
from ttts.classifier.model import AudioMiniEncoderWithClassifierHead | |
from omegaconf import OmegaConf | |
from ttts.diffusion.aa_model import AA_diffusion | |
import json | |
import torch | |
import os | |
def load_model(model_name, model_path, config_path, device):
    """Build one of the supported models and load its checkpoint weights.

    Args:
        model_name: Which model to build. One of ``'vqvae'``, ``'gpt'``,
            ``'diffusion'`` or ``'classifier'``.
        model_path: Path to the checkpoint file; ``~`` is expanded. The
            checkpoint must be a mapping whose ``'model'`` entry holds the
            state dict.
        config_path: Path to the configuration file; ``~`` is expanded.
            Files ending in ``.json`` are parsed with ``json``; anything
            else is handed to ``OmegaConf.load``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The constructed model, with weights loaded strictly, moved to
        ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    config_path = os.path.expanduser(config_path)
    model_path = os.path.expanduser(model_path)

    if config_path.endswith('.json'):
        # Use a context manager so the config file handle is always closed.
        with open(config_path, encoding='utf-8') as config_file:
            config = json.load(config_file)
    else:
        config = OmegaConf.load(config_path)

    if model_name == 'vqvae':
        model = DiscreteVAE(**config['vqvae'])
    elif model_name == 'gpt':
        model = UnifiedVoice(**config['gpt'])
    elif model_name == 'diffusion':
        # Unlike the other models, AA_diffusion receives the whole config
        # object rather than a single unpacked section.
        model = AA_diffusion(config)
    elif model_name == 'classifier':
        model = AudioMiniEncoderWithClassifierHead(**config['classifier'])
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # `model` further down.
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of "
            "'vqvae', 'gpt', 'diffusion', 'classifier'"
        )

    # NOTE(security): torch.load unpickles the file, so only load
    # checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=device)['model']
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    return model.eval()