Jannat24 commited on
Commit
3f2989b
·
verified ·
1 Parent(s): 2a681ac

Delete modules/finetunedvqgan.py

Browse files
Files changed (1) hide show
  1. modules/finetunedvqgan.py +0 -34
modules/finetunedvqgan.py DELETED
@@ -1,34 +0,0 @@
1
- import torch
2
- from torch.utils.checkpoint import checkpoint
3
- from taming.models.vqgan import VQModel
4
- from omegaconf import OmegaConf
5
- from taming.models.vqgan import GumbelVQ
6
-
7
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
-
9
- class Generator:
10
- def __init__(self, config_path, checkpoint_path, device=device):
11
- self.config_path = config_path
12
- self.checkpoint_path = checkpoint_path
13
- self.device = device
14
-
15
- def load_models(self):
16
- # Load configuration
17
- config = OmegaConf.load(self.config_path)
18
- # Extract parameters specific to GumbelVQ
19
- vq_params = config.model.params
20
- # Initialize the GumbelVQ models
21
- model_vaq = GumbelVQ(
22
- ddconfig=vq_params.ddconfig,
23
- lossconfig=vq_params.lossconfig,
24
- n_embed=vq_params.n_embed,
25
- embed_dim=vq_params.embed_dim,
26
- kl_weight=vq_params.kl_weight,
27
- temperature_scheduler_config=vq_params.temperature_scheduler_config,
28
- ).to(self.device)
29
- # Load model checkpoints
30
- checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
31
- # Load the state dictionary into the models
32
- model_vaq.load_state_dict(checkpoint, strict=True)
33
-
34
- return model_vaq.eval()