# Provenance (copied from the hosting page's file-view header):
#   author: Jannat24
#   commit: c9224f7 "2025_march16" (verified)
#   original size: 997 bytes
import torch
from torch.utils.checkpoint import checkpoint
from taming.models.vqgan import VQModel
from omegaconf import OmegaConf
from taming.models.vqgan import GumbelVQ
# Module-wide default device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class Generator:
    """Build a taming-transformers ``GumbelVQ`` model from an OmegaConf config file.

    Attributes:
        config_path: Path handed to ``OmegaConf.load`` by ``load_models``.
        device: Device the constructed model is moved to.
    """

    def __init__(self, config_path, device=device):
        """Store construction settings; nothing is read until ``load_models``.

        Args:
            config_path: Path to the config file.
            device: Target device. Defaults to the module-level ``device``
                (bound once, when this class is defined).
        """
        self.config_path = config_path
        self.device = device

    def load_models(self, ckpt_path=None):
        """Construct a ``GumbelVQ`` from the config's ``model.params`` section.

        Args:
            ckpt_path: Optional checkpoint path, forwarded to ``GumbelVQ`` as
                its ``ckpt_path`` keyword. When omitted (the default), the
                constructor is called exactly as before and this method does
                not itself load any weights.

        Returns:
            The ``GumbelVQ`` instance, moved to ``self.device``.
        """
        config = OmegaConf.load(self.config_path)
        # Hyper-parameters for the GumbelVQ constructor live under model.params.
        vq_params = config.model.params

        # Forward ckpt_path only when the caller supplied one, so the default
        # invocation is identical to the original constructor call.
        extra_kwargs = {} if ckpt_path is None else {"ckpt_path": ckpt_path}

        model_vaq = GumbelVQ(
            ddconfig=vq_params.ddconfig,
            lossconfig=vq_params.lossconfig,
            n_embed=vq_params.n_embed,
            embed_dim=vq_params.embed_dim,
            kl_weight=vq_params.kl_weight,
            temperature_scheduler_config=vq_params.temperature_scheduler_config,
            **extra_kwargs,
        )
        return model_vaq.to(self.device)