liuganghuggingface commited on
Commit
cd16fe3
·
verified ·
1 Parent(s): 1e3c39a

Update graph_decoder/diffusion_model.py

Browse files
Files changed (1) hide show
  1. graph_decoder/diffusion_model.py +339 -338
graph_decoder/diffusion_model.py CHANGED
@@ -19,347 +19,348 @@ class GraphDiT(nn.Module):
19
  model_dtype,
20
  ):
21
  super().__init__()
 
22
 
23
- dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
24
-
25
- input_dims = data_info.input_dims
26
- output_dims = data_info.output_dims
27
- nodes_dist = data_info.nodes_dist
28
- active_index = data_info.active_index
29
-
30
- self.model_config = dm_cfg
31
- self.data_info = data_info
32
- self.T = dm_cfg.diffusion_steps
33
- self.Xdim = input_dims["X"]
34
- self.Edim = input_dims["E"]
35
- self.ydim = input_dims["y"]
36
- self.Xdim_output = output_dims["X"]
37
- self.Edim_output = output_dims["E"]
38
- self.ydim_output = output_dims["y"]
39
- self.node_dist = nodes_dist
40
- self.active_index = active_index
41
- self.max_n_nodes = data_info.max_n_nodes
42
- self.atom_decoder = data_info.atom_decoder
43
- self.hidden_size = dm_cfg.hidden_size
44
- self.mol_visualizer = MolecularVisualization(self.atom_decoder)
45
-
46
- self.denoiser = Transformer(
47
- max_n_nodes=self.max_n_nodes,
48
- hidden_size=dm_cfg.hidden_size,
49
- depth=dm_cfg.depth,
50
- num_heads=dm_cfg.num_heads,
51
- mlp_ratio=dm_cfg.mlp_ratio,
52
- drop_condition=dm_cfg.drop_condition,
53
- Xdim=self.Xdim,
54
- Edim=self.Edim,
55
- ydim=self.ydim,
56
- )
57
-
58
- self.model_dtype = model_dtype
59
- self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
60
- dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
61
- )
62
- x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
63
- data_info.node_types.to(self.model_dtype)
64
- )
65
- e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
66
- data_info.edge_types.to(self.model_dtype)
67
- )
68
- x_marginals = x_marginals / x_marginals.sum()
69
- e_marginals = e_marginals / e_marginals.sum()
70
-
71
- xe_conditions = data_info.transition_E.to(self.model_dtype)
72
- xe_conditions = xe_conditions[self.active_index][:, self.active_index]
73
-
74
- xe_conditions = xe_conditions.sum(dim=1)
75
- ex_conditions = xe_conditions.t()
76
- xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
77
- ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
78
-
79
- self.transition_model = utils.MarginalTransition(
80
- x_marginals=x_marginals,
81
- e_marginals=e_marginals,
82
- xe_conditions=xe_conditions,
83
- ex_conditions=ex_conditions,
84
- y_classes=self.ydim_output,
85
- n_nodes=self.max_n_nodes,
86
- )
87
- self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
88
-
89
- def init_model(self, model_dir):
90
- model_file = os.path.join(model_dir, 'model.pt')
91
- if os.path.exists(model_file):
92
- self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
93
- else:
94
- raise FileNotFoundError(f"Model file not found: {model_file}")
95
-
96
- def disable_grads(self):
97
- self.denoiser.disable_grads()
98
 
99
- def forward(
100
- self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
101
- ):
102
- raise ValueError('Not Implement')
103
-
104
- def _forward(self, noisy_data, unconditioned=False):
105
- noisy_x, noisy_e, properties = (
106
- noisy_data["X_t"].to(self.model_dtype),
107
- noisy_data["E_t"].to(self.model_dtype),
108
- noisy_data["y_t"].to(self.model_dtype).clone(),
109
- )
110
- node_mask, timestep = (
111
- noisy_data["node_mask"],
112
- noisy_data["t"],
113
- )
114
 
115
- pred = self.denoiser(
116
- noisy_x,
117
- noisy_e,
118
- node_mask,
119
- properties,
120
- timestep,
121
- unconditioned=unconditioned,
122
- )
123
- return pred
124
-
125
- def apply_noise(self, X, E, y, node_mask):
126
- """Sample noise and apply it to the data."""
127
-
128
- # Sample a timestep t.
129
- # When evaluating, the loss for t=0 is computed separately
130
- lowest_t = 0 if self.training else 1
131
- t_int = torch.randint(
132
- lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
133
- ).to(
134
- self.model_dtype
135
- ) # (bs, 1)
136
- s_int = t_int - 1
137
-
138
- t_float = t_int / self.T
139
- s_float = s_int / self.T
140
-
141
- # beta_t and alpha_s_bar are used for denoising/loss computation
142
- beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
143
- alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
144
- alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
145
-
146
- Qtb = self.transition_model.get_Qt_bar(
147
- alpha_t_bar, X.device
148
- ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
149
-
150
- bs, n, d = X.shape
151
- X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
152
- prob_all = X_all @ Qtb.X
153
- probX = prob_all[:, :, : self.Xdim_output]
154
- probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
155
-
156
- sampled_t = utils.sample_discrete_features(
157
- probX=probX, probE=probE, node_mask=node_mask
158
- )
159
-
160
- X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
161
- E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
162
- assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
163
-
164
- y_t = y
165
- z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
166
-
167
- noisy_data = {
168
- "t_int": t_int,
169
- "t": t_float,
170
- "beta_t": beta_t,
171
- "alpha_s_bar": alpha_s_bar,
172
- "alpha_t_bar": alpha_t_bar,
173
- "X_t": z_t.X,
174
- "E_t": z_t.E,
175
- "y_t": z_t.y,
176
- "node_mask": node_mask,
177
- }
178
- return noisy_data
179
-
180
- @torch.no_grad()
181
- def generate(
182
- self,
183
- properties,
184
- guide_scale=1.,
185
- num_nodes=None,
186
- number_chain_steps=50,
187
- ):
188
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
189
- properties = [float('nan') if x is None else x for x in properties]
190
- properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
191
- batch_size = properties.size(0)
192
- assert batch_size == 1
193
- if num_nodes is None:
194
- num_nodes = self.node_dist.sample_n(batch_size, device)
195
- else:
196
- num_nodes = torch.LongTensor([num_nodes]).to(device)
197
-
198
- arange = (
199
- torch.arange(self.max_n_nodes, device=device)
200
- .unsqueeze(0)
201
- .expand(batch_size, -1)
202
- )
203
- node_mask = arange < num_nodes.unsqueeze(1)
204
-
205
- z_T = utils.sample_discrete_feature_noise(
206
- limit_dist=self.limit_dist, node_mask=node_mask
207
- )
208
- X, E = z_T.X, z_T.E
209
-
210
- assert (E == torch.transpose(E, 1, 2)).all()
211
-
212
- if number_chain_steps > 0:
213
- chain_X_size = torch.Size((number_chain_steps, X.size(1)))
214
- chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
215
- chain_X = torch.zeros(chain_X_size)
216
- chain_E = torch.zeros(chain_E_size)
217
-
218
- # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
219
- y = properties
220
- for s_int in reversed(range(0, self.T)):
221
- s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
222
- t_array = s_array + 1
223
- s_norm = s_array / self.T
224
- t_norm = t_array / self.T
225
-
226
- # Sample z_s
227
- sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
228
- s_norm, t_norm, X, E, y, node_mask, guide_scale, device
229
- )
230
- X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
231
 
232
- if number_chain_steps > 0:
233
- # Save the first keep_chain graphs
234
- write_index = (s_int * number_chain_steps) // self.T
235
- chain_X[write_index] = discrete_sampled_s.X[:1]
236
- chain_E[write_index] = discrete_sampled_s.E[:1]
237
-
238
- # Sample
239
- sampled_s = sampled_s.mask(node_mask, collapse=True)
240
- X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
241
-
242
- molecule_list = []
243
- n = num_nodes[0]
244
- atom_types = X[0, :n].cpu()
245
- edge_types = E[0, :n, :n].cpu()
246
- molecule_list.append([atom_types, edge_types])
247
- smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
248
-
249
- # Visualize Chains
250
- if number_chain_steps > 0:
251
- final_X_chain = X[:1]
252
- final_E_chain = E[:1]
253
-
254
- chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
255
- chain_E[0] = final_E_chain
256
-
257
- chain_X = utils.reverse_tensor(chain_X)
258
- chain_E = utils.reverse_tensor(chain_E)
259
-
260
- # Repeat last frame to see final sample better
261
- chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
262
- chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
263
- mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
264
- else:
265
- mol_img_list = []
266
-
267
- return smiles, mol_img_list
268
-
269
- def check_valid(self, smiles):
270
- return check_valid(smiles)
271
 
272
- def sample_p_zs_given_zt(
273
- self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
274
- ):
275
- """Samples from zs ~ p(zs | zt). Only used during sampling.
276
- if last_step, return the graph prediction as well"""
277
- bs, n, _ = X_t.shape
278
- beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
279
- alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
280
- alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
281
-
282
- # Neural net predictions
283
- noisy_data = {
284
- "X_t": X_t,
285
- "E_t": E_t,
286
- "y_t": properties,
287
- "t": t,
288
- "node_mask": node_mask,
289
- }
290
-
291
- def get_prob(noisy_data, unconditioned=False):
292
- pred = self._forward(noisy_data, unconditioned=unconditioned)
293
-
294
- # Normalize predictions
295
- pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
296
- pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
297
-
298
- # Retrieve transitions matrix
299
- Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
300
- Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
301
- Qt = self.transition_model.get_Qt(beta_t, device)
302
-
303
- Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
304
- predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
305
-
306
- unnormalized_probX_all = utils.reverse_diffusion(
307
- predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
308
- )
309
-
310
- unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
311
- unnormalized_prob_E = unnormalized_probX_all[
312
- :, :, self.Xdim_output :
313
- ].reshape(bs, n * n, -1)
314
-
315
- unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
316
- unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
317
-
318
- prob_X = unnormalized_prob_X / torch.sum(
319
- unnormalized_prob_X, dim=-1, keepdim=True
320
- ) # bs, n, d_t-1
321
- prob_E = unnormalized_prob_E / torch.sum(
322
- unnormalized_prob_E, dim=-1, keepdim=True
323
- ) # bs, n, d_t-1
324
- prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
325
-
326
- return prob_X, prob_E
327
-
328
- prob_X, prob_E = get_prob(noisy_data)
329
-
330
- ### Guidance
331
- if guide_scale != 1:
332
- uncon_prob_X, uncon_prob_E = get_prob(
333
- noisy_data, unconditioned=True
334
- )
335
- prob_X = (
336
- uncon_prob_X
337
- * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
338
- )
339
- prob_E = (
340
- uncon_prob_E
341
- * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
342
- )
343
- prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
344
- prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
345
-
346
- # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
347
- # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
348
-
349
- sampled_s = utils.sample_discrete_features(
350
- prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
351
- )
352
-
353
- X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
354
- E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
355
-
356
- assert (E_s == torch.transpose(E_s, 1, 2)).all()
357
- assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
358
-
359
- out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
360
- out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
361
-
362
- return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
363
- node_mask, collapse=True
364
- ).type_as(properties)
365
 
 
19
  model_dtype,
20
  ):
21
  super().__init__()
22
+ pass
23
 
24
+ # dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
25
+
26
+ # input_dims = data_info.input_dims
27
+ # output_dims = data_info.output_dims
28
+ # nodes_dist = data_info.nodes_dist
29
+ # active_index = data_info.active_index
30
+
31
+ # self.model_config = dm_cfg
32
+ # self.data_info = data_info
33
+ # self.T = dm_cfg.diffusion_steps
34
+ # self.Xdim = input_dims["X"]
35
+ # self.Edim = input_dims["E"]
36
+ # self.ydim = input_dims["y"]
37
+ # self.Xdim_output = output_dims["X"]
38
+ # self.Edim_output = output_dims["E"]
39
+ # self.ydim_output = output_dims["y"]
40
+ # self.node_dist = nodes_dist
41
+ # self.active_index = active_index
42
+ # self.max_n_nodes = data_info.max_n_nodes
43
+ # self.atom_decoder = data_info.atom_decoder
44
+ # self.hidden_size = dm_cfg.hidden_size
45
+ # self.mol_visualizer = MolecularVisualization(self.atom_decoder)
46
+
47
+ # self.denoiser = Transformer(
48
+ # max_n_nodes=self.max_n_nodes,
49
+ # hidden_size=dm_cfg.hidden_size,
50
+ # depth=dm_cfg.depth,
51
+ # num_heads=dm_cfg.num_heads,
52
+ # mlp_ratio=dm_cfg.mlp_ratio,
53
+ # drop_condition=dm_cfg.drop_condition,
54
+ # Xdim=self.Xdim,
55
+ # Edim=self.Edim,
56
+ # ydim=self.ydim,
57
+ # )
58
+
59
+ # self.model_dtype = model_dtype
60
+ # self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
61
+ # dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
62
+ # )
63
+ # x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
64
+ # data_info.node_types.to(self.model_dtype)
65
+ # )
66
+ # e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
67
+ # data_info.edge_types.to(self.model_dtype)
68
+ # )
69
+ # x_marginals = x_marginals / x_marginals.sum()
70
+ # e_marginals = e_marginals / e_marginals.sum()
71
+
72
+ # xe_conditions = data_info.transition_E.to(self.model_dtype)
73
+ # xe_conditions = xe_conditions[self.active_index][:, self.active_index]
74
+
75
+ # xe_conditions = xe_conditions.sum(dim=1)
76
+ # ex_conditions = xe_conditions.t()
77
+ # xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
78
+ # ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
79
+
80
+ # self.transition_model = utils.MarginalTransition(
81
+ # x_marginals=x_marginals,
82
+ # e_marginals=e_marginals,
83
+ # xe_conditions=xe_conditions,
84
+ # ex_conditions=ex_conditions,
85
+ # y_classes=self.ydim_output,
86
+ # n_nodes=self.max_n_nodes,
87
+ # )
88
+ # self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
89
+
90
+ # def init_model(self, model_dir):
91
+ # model_file = os.path.join(model_dir, 'model.pt')
92
+ # if os.path.exists(model_file):
93
+ # self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
94
+ # else:
95
+ # raise FileNotFoundError(f"Model file not found: {model_file}")
96
+
97
+ # def disable_grads(self):
98
+ # self.denoiser.disable_grads()
99
 
100
+ # def forward(
101
+ # self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
102
+ # ):
103
+ # raise ValueError('Not Implement')
104
+
105
+ # def _forward(self, noisy_data, unconditioned=False):
106
+ # noisy_x, noisy_e, properties = (
107
+ # noisy_data["X_t"].to(self.model_dtype),
108
+ # noisy_data["E_t"].to(self.model_dtype),
109
+ # noisy_data["y_t"].to(self.model_dtype).clone(),
110
+ # )
111
+ # node_mask, timestep = (
112
+ # noisy_data["node_mask"],
113
+ # noisy_data["t"],
114
+ # )
115
 
116
+ # pred = self.denoiser(
117
+ # noisy_x,
118
+ # noisy_e,
119
+ # node_mask,
120
+ # properties,
121
+ # timestep,
122
+ # unconditioned=unconditioned,
123
+ # )
124
+ # return pred
125
+
126
+ # def apply_noise(self, X, E, y, node_mask):
127
+ # """Sample noise and apply it to the data."""
128
+
129
+ # # Sample a timestep t.
130
+ # # When evaluating, the loss for t=0 is computed separately
131
+ # lowest_t = 0 if self.training else 1
132
+ # t_int = torch.randint(
133
+ # lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
134
+ # ).to(
135
+ # self.model_dtype
136
+ # ) # (bs, 1)
137
+ # s_int = t_int - 1
138
+
139
+ # t_float = t_int / self.T
140
+ # s_float = s_int / self.T
141
+
142
+ # # beta_t and alpha_s_bar are used for denoising/loss computation
143
+ # beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
144
+ # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
145
+ # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
146
+
147
+ # Qtb = self.transition_model.get_Qt_bar(
148
+ # alpha_t_bar, X.device
149
+ # ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
150
+
151
+ # bs, n, d = X.shape
152
+ # X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
153
+ # prob_all = X_all @ Qtb.X
154
+ # probX = prob_all[:, :, : self.Xdim_output]
155
+ # probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
156
+
157
+ # sampled_t = utils.sample_discrete_features(
158
+ # probX=probX, probE=probE, node_mask=node_mask
159
+ # )
160
+
161
+ # X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
162
+ # E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
163
+ # assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
164
+
165
+ # y_t = y
166
+ # z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
167
+
168
+ # noisy_data = {
169
+ # "t_int": t_int,
170
+ # "t": t_float,
171
+ # "beta_t": beta_t,
172
+ # "alpha_s_bar": alpha_s_bar,
173
+ # "alpha_t_bar": alpha_t_bar,
174
+ # "X_t": z_t.X,
175
+ # "E_t": z_t.E,
176
+ # "y_t": z_t.y,
177
+ # "node_mask": node_mask,
178
+ # }
179
+ # return noisy_data
180
+
181
+ # @torch.no_grad()
182
+ # def generate(
183
+ # self,
184
+ # properties,
185
+ # guide_scale=1.,
186
+ # num_nodes=None,
187
+ # number_chain_steps=50,
188
+ # ):
189
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
190
+ # properties = [float('nan') if x is None else x for x in properties]
191
+ # properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
192
+ # batch_size = properties.size(0)
193
+ # assert batch_size == 1
194
+ # if num_nodes is None:
195
+ # num_nodes = self.node_dist.sample_n(batch_size, device)
196
+ # else:
197
+ # num_nodes = torch.LongTensor([num_nodes]).to(device)
198
+
199
+ # arange = (
200
+ # torch.arange(self.max_n_nodes, device=device)
201
+ # .unsqueeze(0)
202
+ # .expand(batch_size, -1)
203
+ # )
204
+ # node_mask = arange < num_nodes.unsqueeze(1)
205
+
206
+ # z_T = utils.sample_discrete_feature_noise(
207
+ # limit_dist=self.limit_dist, node_mask=node_mask
208
+ # )
209
+ # X, E = z_T.X, z_T.E
210
+
211
+ # assert (E == torch.transpose(E, 1, 2)).all()
212
+
213
+ # if number_chain_steps > 0:
214
+ # chain_X_size = torch.Size((number_chain_steps, X.size(1)))
215
+ # chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
216
+ # chain_X = torch.zeros(chain_X_size)
217
+ # chain_E = torch.zeros(chain_E_size)
218
+
219
+ # # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
220
+ # y = properties
221
+ # for s_int in reversed(range(0, self.T)):
222
+ # s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
223
+ # t_array = s_array + 1
224
+ # s_norm = s_array / self.T
225
+ # t_norm = t_array / self.T
226
+
227
+ # # Sample z_s
228
+ # sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
229
+ # s_norm, t_norm, X, E, y, node_mask, guide_scale, device
230
+ # )
231
+ # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
232
 
233
+ # if number_chain_steps > 0:
234
+ # # Save the first keep_chain graphs
235
+ # write_index = (s_int * number_chain_steps) // self.T
236
+ # chain_X[write_index] = discrete_sampled_s.X[:1]
237
+ # chain_E[write_index] = discrete_sampled_s.E[:1]
238
+
239
+ # # Sample
240
+ # sampled_s = sampled_s.mask(node_mask, collapse=True)
241
+ # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
242
+
243
+ # molecule_list = []
244
+ # n = num_nodes[0]
245
+ # atom_types = X[0, :n].cpu()
246
+ # edge_types = E[0, :n, :n].cpu()
247
+ # molecule_list.append([atom_types, edge_types])
248
+ # smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
249
+
250
+ # # Visualize Chains
251
+ # if number_chain_steps > 0:
252
+ # final_X_chain = X[:1]
253
+ # final_E_chain = E[:1]
254
+
255
+ # chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
256
+ # chain_E[0] = final_E_chain
257
+
258
+ # chain_X = utils.reverse_tensor(chain_X)
259
+ # chain_E = utils.reverse_tensor(chain_E)
260
+
261
+ # # Repeat last frame to see final sample better
262
+ # chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
263
+ # chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
264
+ # mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
265
+ # else:
266
+ # mol_img_list = []
267
+
268
+ # return smiles, mol_img_list
269
+
270
+ # def check_valid(self, smiles):
271
+ # return check_valid(smiles)
272
 
273
+ # def sample_p_zs_given_zt(
274
+ # self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
275
+ # ):
276
+ # """Samples from zs ~ p(zs | zt). Only used during sampling.
277
+ # if last_step, return the graph prediction as well"""
278
+ # bs, n, _ = X_t.shape
279
+ # beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
280
+ # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
281
+ # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
282
+
283
+ # # Neural net predictions
284
+ # noisy_data = {
285
+ # "X_t": X_t,
286
+ # "E_t": E_t,
287
+ # "y_t": properties,
288
+ # "t": t,
289
+ # "node_mask": node_mask,
290
+ # }
291
+
292
+ # def get_prob(noisy_data, unconditioned=False):
293
+ # pred = self._forward(noisy_data, unconditioned=unconditioned)
294
+
295
+ # # Normalize predictions
296
+ # pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
297
+ # pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
298
+
299
+ # # Retrieve transitions matrix
300
+ # Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
301
+ # Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
302
+ # Qt = self.transition_model.get_Qt(beta_t, device)
303
+
304
+ # Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
305
+ # predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
306
+
307
+ # unnormalized_probX_all = utils.reverse_diffusion(
308
+ # predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
309
+ # )
310
+
311
+ # unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
312
+ # unnormalized_prob_E = unnormalized_probX_all[
313
+ # :, :, self.Xdim_output :
314
+ # ].reshape(bs, n * n, -1)
315
+
316
+ # unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
317
+ # unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
318
+
319
+ # prob_X = unnormalized_prob_X / torch.sum(
320
+ # unnormalized_prob_X, dim=-1, keepdim=True
321
+ # ) # bs, n, d_t-1
322
+ # prob_E = unnormalized_prob_E / torch.sum(
323
+ # unnormalized_prob_E, dim=-1, keepdim=True
324
+ # ) # bs, n, d_t-1
325
+ # prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
326
+
327
+ # return prob_X, prob_E
328
+
329
+ # prob_X, prob_E = get_prob(noisy_data)
330
+
331
+ # ### Guidance
332
+ # if guide_scale != 1:
333
+ # uncon_prob_X, uncon_prob_E = get_prob(
334
+ # noisy_data, unconditioned=True
335
+ # )
336
+ # prob_X = (
337
+ # uncon_prob_X
338
+ # * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
339
+ # )
340
+ # prob_E = (
341
+ # uncon_prob_E
342
+ # * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
343
+ # )
344
+ # prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
345
+ # prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
346
+
347
+ # # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
348
+ # # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
349
+
350
+ # sampled_s = utils.sample_discrete_features(
351
+ # prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
352
+ # )
353
+
354
+ # X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
355
+ # E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
356
+
357
+ # assert (E_s == torch.transpose(E_s, 1, 2)).all()
358
+ # assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
359
+
360
+ # out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
361
+ # out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
362
+
363
+ # return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
364
+ # node_mask, collapse=True
365
+ # ).type_as(properties)
366