liuganghuggingface committed on
Commit
6090d6f
·
verified ·
1 Parent(s): cd16fe3

Update graph_decoder/diffusion_model.py

Browse files
Files changed (1) hide show
  1. graph_decoder/diffusion_model.py +352 -338
graph_decoder/diffusion_model.py CHANGED
@@ -19,348 +19,362 @@ class GraphDiT(nn.Module):
19
  model_dtype,
20
  ):
21
  super().__init__()
 
 
 
 
 
22
  pass
 
23
 
24
- # dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
25
-
26
- # input_dims = data_info.input_dims
27
- # output_dims = data_info.output_dims
28
- # nodes_dist = data_info.nodes_dist
29
- # active_index = data_info.active_index
30
-
31
- # self.model_config = dm_cfg
32
- # self.data_info = data_info
33
- # self.T = dm_cfg.diffusion_steps
34
- # self.Xdim = input_dims["X"]
35
- # self.Edim = input_dims["E"]
36
- # self.ydim = input_dims["y"]
37
- # self.Xdim_output = output_dims["X"]
38
- # self.Edim_output = output_dims["E"]
39
- # self.ydim_output = output_dims["y"]
40
- # self.node_dist = nodes_dist
41
- # self.active_index = active_index
42
- # self.max_n_nodes = data_info.max_n_nodes
43
- # self.atom_decoder = data_info.atom_decoder
44
- # self.hidden_size = dm_cfg.hidden_size
45
- # self.mol_visualizer = MolecularVisualization(self.atom_decoder)
46
-
47
- # self.denoiser = Transformer(
48
- # max_n_nodes=self.max_n_nodes,
49
- # hidden_size=dm_cfg.hidden_size,
50
- # depth=dm_cfg.depth,
51
- # num_heads=dm_cfg.num_heads,
52
- # mlp_ratio=dm_cfg.mlp_ratio,
53
- # drop_condition=dm_cfg.drop_condition,
54
- # Xdim=self.Xdim,
55
- # Edim=self.Edim,
56
- # ydim=self.ydim,
57
- # )
58
-
59
- # self.model_dtype = model_dtype
60
- # self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
61
- # dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
62
- # )
63
- # x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
64
- # data_info.node_types.to(self.model_dtype)
65
- # )
66
- # e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
67
- # data_info.edge_types.to(self.model_dtype)
68
- # )
69
- # x_marginals = x_marginals / x_marginals.sum()
70
- # e_marginals = e_marginals / e_marginals.sum()
71
-
72
- # xe_conditions = data_info.transition_E.to(self.model_dtype)
73
- # xe_conditions = xe_conditions[self.active_index][:, self.active_index]
74
-
75
- # xe_conditions = xe_conditions.sum(dim=1)
76
- # ex_conditions = xe_conditions.t()
77
- # xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
78
- # ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
79
-
80
- # self.transition_model = utils.MarginalTransition(
81
- # x_marginals=x_marginals,
82
- # e_marginals=e_marginals,
83
- # xe_conditions=xe_conditions,
84
- # ex_conditions=ex_conditions,
85
- # y_classes=self.ydim_output,
86
- # n_nodes=self.max_n_nodes,
87
- # )
88
- # self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
89
-
90
- # def init_model(self, model_dir):
91
- # model_file = os.path.join(model_dir, 'model.pt')
92
- # if os.path.exists(model_file):
93
- # self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
94
- # else:
95
- # raise FileNotFoundError(f"Model file not found: {model_file}")
96
-
97
- # def disable_grads(self):
98
- # self.denoiser.disable_grads()
 
 
 
 
 
 
 
 
99
 
100
- # def forward(
101
- # self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
102
- # ):
103
- # raise ValueError('Not Implement')
104
-
105
- # def _forward(self, noisy_data, unconditioned=False):
106
- # noisy_x, noisy_e, properties = (
107
- # noisy_data["X_t"].to(self.model_dtype),
108
- # noisy_data["E_t"].to(self.model_dtype),
109
- # noisy_data["y_t"].to(self.model_dtype).clone(),
110
- # )
111
- # node_mask, timestep = (
112
- # noisy_data["node_mask"],
113
- # noisy_data["t"],
114
- # )
115
 
116
- # pred = self.denoiser(
117
- # noisy_x,
118
- # noisy_e,
119
- # node_mask,
120
- # properties,
121
- # timestep,
122
- # unconditioned=unconditioned,
123
- # )
124
- # return pred
125
-
126
- # def apply_noise(self, X, E, y, node_mask):
127
- # """Sample noise and apply it to the data."""
128
-
129
- # # Sample a timestep t.
130
- # # When evaluating, the loss for t=0 is computed separately
131
- # lowest_t = 0 if self.training else 1
132
- # t_int = torch.randint(
133
- # lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
134
- # ).to(
135
- # self.model_dtype
136
- # ) # (bs, 1)
137
- # s_int = t_int - 1
138
-
139
- # t_float = t_int / self.T
140
- # s_float = s_int / self.T
141
-
142
- # # beta_t and alpha_s_bar are used for denoising/loss computation
143
- # beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
144
- # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
145
- # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
146
-
147
- # Qtb = self.transition_model.get_Qt_bar(
148
- # alpha_t_bar, X.device
149
- # ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
150
-
151
- # bs, n, d = X.shape
152
- # X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
153
- # prob_all = X_all @ Qtb.X
154
- # probX = prob_all[:, :, : self.Xdim_output]
155
- # probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
156
-
157
- # sampled_t = utils.sample_discrete_features(
158
- # probX=probX, probE=probE, node_mask=node_mask
159
- # )
160
-
161
- # X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
162
- # E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
163
- # assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
164
-
165
- # y_t = y
166
- # z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
167
-
168
- # noisy_data = {
169
- # "t_int": t_int,
170
- # "t": t_float,
171
- # "beta_t": beta_t,
172
- # "alpha_s_bar": alpha_s_bar,
173
- # "alpha_t_bar": alpha_t_bar,
174
- # "X_t": z_t.X,
175
- # "E_t": z_t.E,
176
- # "y_t": z_t.y,
177
- # "node_mask": node_mask,
178
- # }
179
- # return noisy_data
180
-
181
- # @torch.no_grad()
182
- # def generate(
183
- # self,
184
- # properties,
185
- # guide_scale=1.,
186
- # num_nodes=None,
187
- # number_chain_steps=50,
188
- # ):
189
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
190
- # properties = [float('nan') if x is None else x for x in properties]
191
- # properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
192
- # batch_size = properties.size(0)
193
- # assert batch_size == 1
194
- # if num_nodes is None:
195
- # num_nodes = self.node_dist.sample_n(batch_size, device)
196
- # else:
197
- # num_nodes = torch.LongTensor([num_nodes]).to(device)
198
-
199
- # arange = (
200
- # torch.arange(self.max_n_nodes, device=device)
201
- # .unsqueeze(0)
202
- # .expand(batch_size, -1)
203
- # )
204
- # node_mask = arange < num_nodes.unsqueeze(1)
205
-
206
- # z_T = utils.sample_discrete_feature_noise(
207
- # limit_dist=self.limit_dist, node_mask=node_mask
208
- # )
209
- # X, E = z_T.X, z_T.E
210
-
211
- # assert (E == torch.transpose(E, 1, 2)).all()
212
-
213
- # if number_chain_steps > 0:
214
- # chain_X_size = torch.Size((number_chain_steps, X.size(1)))
215
- # chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
216
- # chain_X = torch.zeros(chain_X_size)
217
- # chain_E = torch.zeros(chain_E_size)
218
-
219
- # # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
220
- # y = properties
221
- # for s_int in reversed(range(0, self.T)):
222
- # s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
223
- # t_array = s_array + 1
224
- # s_norm = s_array / self.T
225
- # t_norm = t_array / self.T
226
-
227
- # # Sample z_s
228
- # sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
229
- # s_norm, t_norm, X, E, y, node_mask, guide_scale, device
230
- # )
231
- # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
232
 
233
- # if number_chain_steps > 0:
234
- # # Save the first keep_chain graphs
235
- # write_index = (s_int * number_chain_steps) // self.T
236
- # chain_X[write_index] = discrete_sampled_s.X[:1]
237
- # chain_E[write_index] = discrete_sampled_s.E[:1]
238
-
239
- # # Sample
240
- # sampled_s = sampled_s.mask(node_mask, collapse=True)
241
- # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
242
-
243
- # molecule_list = []
244
- # n = num_nodes[0]
245
- # atom_types = X[0, :n].cpu()
246
- # edge_types = E[0, :n, :n].cpu()
247
- # molecule_list.append([atom_types, edge_types])
248
- # smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
249
-
250
- # # Visualize Chains
251
- # if number_chain_steps > 0:
252
- # final_X_chain = X[:1]
253
- # final_E_chain = E[:1]
254
-
255
- # chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
256
- # chain_E[0] = final_E_chain
257
-
258
- # chain_X = utils.reverse_tensor(chain_X)
259
- # chain_E = utils.reverse_tensor(chain_E)
260
-
261
- # # Repeat last frame to see final sample better
262
- # chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
263
- # chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
264
- # mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
265
- # else:
266
- # mol_img_list = []
267
-
268
- # return smiles, mol_img_list
269
-
270
- # def check_valid(self, smiles):
271
- # return check_valid(smiles)
272
 
273
- # def sample_p_zs_given_zt(
274
- # self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
275
- # ):
276
- # """Samples from zs ~ p(zs | zt). Only used during sampling.
277
- # if last_step, return the graph prediction as well"""
278
- # bs, n, _ = X_t.shape
279
- # beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
280
- # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
281
- # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
282
-
283
- # # Neural net predictions
284
- # noisy_data = {
285
- # "X_t": X_t,
286
- # "E_t": E_t,
287
- # "y_t": properties,
288
- # "t": t,
289
- # "node_mask": node_mask,
290
- # }
291
-
292
- # def get_prob(noisy_data, unconditioned=False):
293
- # pred = self._forward(noisy_data, unconditioned=unconditioned)
294
-
295
- # # Normalize predictions
296
- # pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
297
- # pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
298
-
299
- # # Retrieve transitions matrix
300
- # Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
301
- # Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
302
- # Qt = self.transition_model.get_Qt(beta_t, device)
303
-
304
- # Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
305
- # predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
306
-
307
- # unnormalized_probX_all = utils.reverse_diffusion(
308
- # predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
309
- # )
310
-
311
- # unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
312
- # unnormalized_prob_E = unnormalized_probX_all[
313
- # :, :, self.Xdim_output :
314
- # ].reshape(bs, n * n, -1)
315
-
316
- # unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
317
- # unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
318
-
319
- # prob_X = unnormalized_prob_X / torch.sum(
320
- # unnormalized_prob_X, dim=-1, keepdim=True
321
- # ) # bs, n, d_t-1
322
- # prob_E = unnormalized_prob_E / torch.sum(
323
- # unnormalized_prob_E, dim=-1, keepdim=True
324
- # ) # bs, n, d_t-1
325
- # prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
326
-
327
- # return prob_X, prob_E
328
-
329
- # prob_X, prob_E = get_prob(noisy_data)
330
-
331
- # ### Guidance
332
- # if guide_scale != 1:
333
- # uncon_prob_X, uncon_prob_E = get_prob(
334
- # noisy_data, unconditioned=True
335
- # )
336
- # prob_X = (
337
- # uncon_prob_X
338
- # * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
339
- # )
340
- # prob_E = (
341
- # uncon_prob_E
342
- # * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
343
- # )
344
- # prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
345
- # prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
346
-
347
- # # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
348
- # # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
349
-
350
- # sampled_s = utils.sample_discrete_features(
351
- # prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
352
- # )
353
-
354
- # X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
355
- # E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
356
-
357
- # assert (E_s == torch.transpose(E_s, 1, 2)).all()
358
- # assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
359
-
360
- # out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
361
- # out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
362
-
363
- # return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
364
- # node_mask, collapse=True
365
- # ).type_as(properties)
366
 
 
19
  model_dtype,
20
  ):
21
  super().__init__()
22
+
23
+ def init_model(self, model_dir):
24
+ pass
25
+
26
+ def disable_grads(self):
27
  pass
28
+
29
 
30
+ # class GraphDiT(nn.Module):
31
+ # def __init__(
32
+ # self,
33
+ # model_config_path,
34
+ # data_info_path,
35
+ # model_dtype,
36
+ # ):
37
+ # super().__init__()
38
+ # dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
39
+
40
+ # input_dims = data_info.input_dims
41
+ # output_dims = data_info.output_dims
42
+ # nodes_dist = data_info.nodes_dist
43
+ # active_index = data_info.active_index
44
+
45
+ # self.model_config = dm_cfg
46
+ # self.data_info = data_info
47
+ # self.T = dm_cfg.diffusion_steps
48
+ # self.Xdim = input_dims["X"]
49
+ # self.Edim = input_dims["E"]
50
+ # self.ydim = input_dims["y"]
51
+ # self.Xdim_output = output_dims["X"]
52
+ # self.Edim_output = output_dims["E"]
53
+ # self.ydim_output = output_dims["y"]
54
+ # self.node_dist = nodes_dist
55
+ # self.active_index = active_index
56
+ # self.max_n_nodes = data_info.max_n_nodes
57
+ # self.atom_decoder = data_info.atom_decoder
58
+ # self.hidden_size = dm_cfg.hidden_size
59
+ # self.mol_visualizer = MolecularVisualization(self.atom_decoder)
60
+
61
+ # self.denoiser = Transformer(
62
+ # max_n_nodes=self.max_n_nodes,
63
+ # hidden_size=dm_cfg.hidden_size,
64
+ # depth=dm_cfg.depth,
65
+ # num_heads=dm_cfg.num_heads,
66
+ # mlp_ratio=dm_cfg.mlp_ratio,
67
+ # drop_condition=dm_cfg.drop_condition,
68
+ # Xdim=self.Xdim,
69
+ # Edim=self.Edim,
70
+ # ydim=self.ydim,
71
+ # )
72
+
73
+ # self.model_dtype = model_dtype
74
+ # self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
75
+ # dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
76
+ # )
77
+ # x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
78
+ # data_info.node_types.to(self.model_dtype)
79
+ # )
80
+ # e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
81
+ # data_info.edge_types.to(self.model_dtype)
82
+ # )
83
+ # x_marginals = x_marginals / x_marginals.sum()
84
+ # e_marginals = e_marginals / e_marginals.sum()
85
+
86
+ # xe_conditions = data_info.transition_E.to(self.model_dtype)
87
+ # xe_conditions = xe_conditions[self.active_index][:, self.active_index]
88
+
89
+ # xe_conditions = xe_conditions.sum(dim=1)
90
+ # ex_conditions = xe_conditions.t()
91
+ # xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
92
+ # ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
93
+
94
+ # self.transition_model = utils.MarginalTransition(
95
+ # x_marginals=x_marginals,
96
+ # e_marginals=e_marginals,
97
+ # xe_conditions=xe_conditions,
98
+ # ex_conditions=ex_conditions,
99
+ # y_classes=self.ydim_output,
100
+ # n_nodes=self.max_n_nodes,
101
+ # )
102
+ # self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
103
+
104
+ # def init_model(self, model_dir):
105
+ # model_file = os.path.join(model_dir, 'model.pt')
106
+ # if os.path.exists(model_file):
107
+ # self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
108
+ # else:
109
+ # raise FileNotFoundError(f"Model file not found: {model_file}")
110
+
111
+ # def disable_grads(self):
112
+ # self.denoiser.disable_grads()
113
 
114
+ # def forward(
115
+ # self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
116
+ # ):
117
+ # raise ValueError('Not Implement')
118
+
119
+ # def _forward(self, noisy_data, unconditioned=False):
120
+ # noisy_x, noisy_e, properties = (
121
+ # noisy_data["X_t"].to(self.model_dtype),
122
+ # noisy_data["E_t"].to(self.model_dtype),
123
+ # noisy_data["y_t"].to(self.model_dtype).clone(),
124
+ # )
125
+ # node_mask, timestep = (
126
+ # noisy_data["node_mask"],
127
+ # noisy_data["t"],
128
+ # )
129
 
130
+ # pred = self.denoiser(
131
+ # noisy_x,
132
+ # noisy_e,
133
+ # node_mask,
134
+ # properties,
135
+ # timestep,
136
+ # unconditioned=unconditioned,
137
+ # )
138
+ # return pred
139
+
140
+ # def apply_noise(self, X, E, y, node_mask):
141
+ # """Sample noise and apply it to the data."""
142
+
143
+ # # Sample a timestep t.
144
+ # # When evaluating, the loss for t=0 is computed separately
145
+ # lowest_t = 0 if self.training else 1
146
+ # t_int = torch.randint(
147
+ # lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
148
+ # ).to(
149
+ # self.model_dtype
150
+ # ) # (bs, 1)
151
+ # s_int = t_int - 1
152
+
153
+ # t_float = t_int / self.T
154
+ # s_float = s_int / self.T
155
+
156
+ # # beta_t and alpha_s_bar are used for denoising/loss computation
157
+ # beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
158
+ # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
159
+ # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
160
+
161
+ # Qtb = self.transition_model.get_Qt_bar(
162
+ # alpha_t_bar, X.device
163
+ # ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
164
+
165
+ # bs, n, d = X.shape
166
+ # X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
167
+ # prob_all = X_all @ Qtb.X
168
+ # probX = prob_all[:, :, : self.Xdim_output]
169
+ # probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
170
+
171
+ # sampled_t = utils.sample_discrete_features(
172
+ # probX=probX, probE=probE, node_mask=node_mask
173
+ # )
174
+
175
+ # X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
176
+ # E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
177
+ # assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
178
+
179
+ # y_t = y
180
+ # z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
181
+
182
+ # noisy_data = {
183
+ # "t_int": t_int,
184
+ # "t": t_float,
185
+ # "beta_t": beta_t,
186
+ # "alpha_s_bar": alpha_s_bar,
187
+ # "alpha_t_bar": alpha_t_bar,
188
+ # "X_t": z_t.X,
189
+ # "E_t": z_t.E,
190
+ # "y_t": z_t.y,
191
+ # "node_mask": node_mask,
192
+ # }
193
+ # return noisy_data
194
+
195
+ # @torch.no_grad()
196
+ # def generate(
197
+ # self,
198
+ # properties,
199
+ # guide_scale=1.,
200
+ # num_nodes=None,
201
+ # number_chain_steps=50,
202
+ # ):
203
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
204
+ # properties = [float('nan') if x is None else x for x in properties]
205
+ # properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
206
+ # batch_size = properties.size(0)
207
+ # assert batch_size == 1
208
+ # if num_nodes is None:
209
+ # num_nodes = self.node_dist.sample_n(batch_size, device)
210
+ # else:
211
+ # num_nodes = torch.LongTensor([num_nodes]).to(device)
212
+
213
+ # arange = (
214
+ # torch.arange(self.max_n_nodes, device=device)
215
+ # .unsqueeze(0)
216
+ # .expand(batch_size, -1)
217
+ # )
218
+ # node_mask = arange < num_nodes.unsqueeze(1)
219
+
220
+ # z_T = utils.sample_discrete_feature_noise(
221
+ # limit_dist=self.limit_dist, node_mask=node_mask
222
+ # )
223
+ # X, E = z_T.X, z_T.E
224
+
225
+ # assert (E == torch.transpose(E, 1, 2)).all()
226
+
227
+ # if number_chain_steps > 0:
228
+ # chain_X_size = torch.Size((number_chain_steps, X.size(1)))
229
+ # chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
230
+ # chain_X = torch.zeros(chain_X_size)
231
+ # chain_E = torch.zeros(chain_E_size)
232
+
233
+ # # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
234
+ # y = properties
235
+ # for s_int in reversed(range(0, self.T)):
236
+ # s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
237
+ # t_array = s_array + 1
238
+ # s_norm = s_array / self.T
239
+ # t_norm = t_array / self.T
240
+
241
+ # # Sample z_s
242
+ # sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
243
+ # s_norm, t_norm, X, E, y, node_mask, guide_scale, device
244
+ # )
245
+ # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
246
 
247
+ # if number_chain_steps > 0:
248
+ # # Save the first keep_chain graphs
249
+ # write_index = (s_int * number_chain_steps) // self.T
250
+ # chain_X[write_index] = discrete_sampled_s.X[:1]
251
+ # chain_E[write_index] = discrete_sampled_s.E[:1]
252
+
253
+ # # Sample
254
+ # sampled_s = sampled_s.mask(node_mask, collapse=True)
255
+ # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
256
+
257
+ # molecule_list = []
258
+ # n = num_nodes[0]
259
+ # atom_types = X[0, :n].cpu()
260
+ # edge_types = E[0, :n, :n].cpu()
261
+ # molecule_list.append([atom_types, edge_types])
262
+ # smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
263
+
264
+ # # Visualize Chains
265
+ # if number_chain_steps > 0:
266
+ # final_X_chain = X[:1]
267
+ # final_E_chain = E[:1]
268
+
269
+ # chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
270
+ # chain_E[0] = final_E_chain
271
+
272
+ # chain_X = utils.reverse_tensor(chain_X)
273
+ # chain_E = utils.reverse_tensor(chain_E)
274
+
275
+ # # Repeat last frame to see final sample better
276
+ # chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
277
+ # chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
278
+ # mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
279
+ # else:
280
+ # mol_img_list = []
281
+
282
+ # return smiles, mol_img_list
283
+
284
+ # def check_valid(self, smiles):
285
+ # return check_valid(smiles)
286
 
287
+ # def sample_p_zs_given_zt(
288
+ # self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
289
+ # ):
290
+ # """Samples from zs ~ p(zs | zt). Only used during sampling.
291
+ # if last_step, return the graph prediction as well"""
292
+ # bs, n, _ = X_t.shape
293
+ # beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
294
+ # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
295
+ # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
296
+
297
+ # # Neural net predictions
298
+ # noisy_data = {
299
+ # "X_t": X_t,
300
+ # "E_t": E_t,
301
+ # "y_t": properties,
302
+ # "t": t,
303
+ # "node_mask": node_mask,
304
+ # }
305
+
306
+ # def get_prob(noisy_data, unconditioned=False):
307
+ # pred = self._forward(noisy_data, unconditioned=unconditioned)
308
+
309
+ # # Normalize predictions
310
+ # pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
311
+ # pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
312
+
313
+ # # Retrieve transitions matrix
314
+ # Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
315
+ # Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
316
+ # Qt = self.transition_model.get_Qt(beta_t, device)
317
+
318
+ # Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
319
+ # predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
320
+
321
+ # unnormalized_probX_all = utils.reverse_diffusion(
322
+ # predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
323
+ # )
324
+
325
+ # unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
326
+ # unnormalized_prob_E = unnormalized_probX_all[
327
+ # :, :, self.Xdim_output :
328
+ # ].reshape(bs, n * n, -1)
329
+
330
+ # unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
331
+ # unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
332
+
333
+ # prob_X = unnormalized_prob_X / torch.sum(
334
+ # unnormalized_prob_X, dim=-1, keepdim=True
335
+ # ) # bs, n, d_t-1
336
+ # prob_E = unnormalized_prob_E / torch.sum(
337
+ # unnormalized_prob_E, dim=-1, keepdim=True
338
+ # ) # bs, n, d_t-1
339
+ # prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
340
+
341
+ # return prob_X, prob_E
342
+
343
+ # prob_X, prob_E = get_prob(noisy_data)
344
+
345
+ # ### Guidance
346
+ # if guide_scale != 1:
347
+ # uncon_prob_X, uncon_prob_E = get_prob(
348
+ # noisy_data, unconditioned=True
349
+ # )
350
+ # prob_X = (
351
+ # uncon_prob_X
352
+ # * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
353
+ # )
354
+ # prob_E = (
355
+ # uncon_prob_E
356
+ # * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
357
+ # )
358
+ # prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
359
+ # prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
360
+
361
+ # # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
362
+ # # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
363
+
364
+ # sampled_s = utils.sample_discrete_features(
365
+ # prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
366
+ # )
367
+
368
+ # X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
369
+ # E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
370
+
371
+ # assert (E_s == torch.transpose(E_s, 1, 2)).all()
372
+ # assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
373
+
374
+ # out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
375
+ # out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
376
+
377
+ # return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
378
+ # node_mask, collapse=True
379
+ # ).type_as(properties)
380