Update graph_decoder/diffusion_model.py
Browse files- graph_decoder/diffusion_model.py +13 -17
graph_decoder/diffusion_model.py
CHANGED
@@ -43,18 +43,17 @@ class GraphDiT(nn.Module):
|
|
43 |
self.hidden_size = dm_cfg.hidden_size
|
44 |
self.mol_visualizer = MolecularVisualization(self.atom_decoder)
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
self.denoiser = None
|
58 |
|
59 |
self.model_dtype = model_dtype
|
60 |
self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
|
@@ -90,14 +89,12 @@ class GraphDiT(nn.Module):
|
|
90 |
def init_model(self, model_dir):
|
91 |
model_file = os.path.join(model_dir, 'model.pt')
|
92 |
if os.path.exists(model_file):
|
93 |
-
|
94 |
-
pass
|
95 |
else:
|
96 |
raise FileNotFoundError(f"Model file not found: {model_file}")
|
97 |
|
98 |
def disable_grads(self):
|
99 |
-
|
100 |
-
# self.denoiser.disable_grads()
|
101 |
|
102 |
def forward(
|
103 |
self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
|
@@ -193,7 +190,6 @@ class GraphDiT(nn.Module):
|
|
193 |
properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
|
194 |
batch_size = properties.size(0)
|
195 |
assert batch_size == 1
|
196 |
-
# print('self.denoiser.dtype', self.model_dtype)
|
197 |
if num_nodes is None:
|
198 |
num_nodes = self.node_dist.sample_n(batch_size, device)
|
199 |
else:
|
|
|
43 |
self.hidden_size = dm_cfg.hidden_size
|
44 |
self.mol_visualizer = MolecularVisualization(self.atom_decoder)
|
45 |
|
46 |
+
self.denoiser = Transformer(
|
47 |
+
max_n_nodes=self.max_n_nodes,
|
48 |
+
hidden_size=dm_cfg.hidden_size,
|
49 |
+
depth=dm_cfg.depth,
|
50 |
+
num_heads=dm_cfg.num_heads,
|
51 |
+
mlp_ratio=dm_cfg.mlp_ratio,
|
52 |
+
drop_condition=dm_cfg.drop_condition,
|
53 |
+
Xdim=self.Xdim,
|
54 |
+
Edim=self.Edim,
|
55 |
+
ydim=self.ydim,
|
56 |
+
)
|
|
|
57 |
|
58 |
self.model_dtype = model_dtype
|
59 |
self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
|
|
|
89 |
def init_model(self, model_dir):
|
90 |
model_file = os.path.join(model_dir, 'model.pt')
|
91 |
if os.path.exists(model_file):
|
92 |
+
self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
|
|
|
93 |
else:
|
94 |
raise FileNotFoundError(f"Model file not found: {model_file}")
|
95 |
|
96 |
def disable_grads(self):
|
97 |
+
self.denoiser.disable_grads()
|
|
|
98 |
|
99 |
def forward(
|
100 |
self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
|
|
|
190 |
properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
|
191 |
batch_size = properties.size(0)
|
192 |
assert batch_size == 1
|
|
|
193 |
if num_nodes is None:
|
194 |
num_nodes = self.node_dist.sample_n(batch_size, device)
|
195 |
else:
|