Spaces:
Sleeping
Sleeping
import torch | |
from torch.utils.checkpoint import checkpoint | |
from taming.models.vqgan import VQModel | |
from omegaconf import OmegaConf | |
from taming.models.vqgan import GumbelVQ | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
class Generator: | |
def __init__(self, config_path, device=device): | |
self.config_path = config_path | |
self.device = device | |
def load_models(self): | |
# Load configuration | |
config = OmegaConf.load(self.config_path) | |
# Extract parameters specific to GumbelVQ | |
vq_params = config.model.params | |
# Initialize the GumbelVQ models | |
model_vaq = GumbelVQ( | |
ddconfig=vq_params.ddconfig, | |
lossconfig=vq_params.lossconfig, | |
n_embed=vq_params.n_embed, | |
embed_dim=vq_params.embed_dim, | |
kl_weight=vq_params.kl_weight, | |
temperature_scheduler_config=vq_params.temperature_scheduler_config, | |
).to(self.device) | |
return model_vaq | |