liuganghuggingface committed on
Commit
58c5fe2
·
verified ·
1 Parent(s): c7a6b58

Update graph_decoder/diffusion_model.py

Browse files
Files changed (1) hide show
  1. graph_decoder/diffusion_model.py +355 -355
graph_decoder/diffusion_model.py CHANGED
@@ -6,27 +6,11 @@ import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
 
9
- # from . import diffusion_utils as utils
10
- # from .molecule_utils import graph_to_smiles, check_valid
11
- # from .transformer import Transformer
12
- # from .visualize_utils import MolecularVisualization
13
 
14
class GraphDiT(nn.Module):
    """Placeholder ``GraphDiT`` module.

    It exposes the constructor signature and the ``init_model`` /
    ``disable_grads`` entry points, but none of them does any work: the
    constructor only initialises the underlying ``nn.Module`` and the two
    methods are no-ops.
    """

    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        """Initialise the bare module; all three arguments are ignored."""
        super().__init__()

    def init_model(self, model_dir):
        """Do nothing; ``model_dir`` is ignored. Returns ``None``."""
        return None

    def disable_grads(self):
        """Do nothing. Returns ``None``."""
        return None
28
-
29
-
30
  # class GraphDiT(nn.Module):
31
  # def __init__(
32
  # self,
@@ -35,346 +19,362 @@ class GraphDiT(nn.Module):
35
  # model_dtype,
36
  # ):
37
  # super().__init__()
38
- # dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
39
-
40
- # input_dims = data_info.input_dims
41
- # output_dims = data_info.output_dims
42
- # nodes_dist = data_info.nodes_dist
43
- # active_index = data_info.active_index
44
-
45
- # self.model_config = dm_cfg
46
- # self.data_info = data_info
47
- # self.T = dm_cfg.diffusion_steps
48
- # self.Xdim = input_dims["X"]
49
- # self.Edim = input_dims["E"]
50
- # self.ydim = input_dims["y"]
51
- # self.Xdim_output = output_dims["X"]
52
- # self.Edim_output = output_dims["E"]
53
- # self.ydim_output = output_dims["y"]
54
- # self.node_dist = nodes_dist
55
- # self.active_index = active_index
56
- # self.max_n_nodes = data_info.max_n_nodes
57
- # self.atom_decoder = data_info.atom_decoder
58
- # self.hidden_size = dm_cfg.hidden_size
59
- # self.mol_visualizer = MolecularVisualization(self.atom_decoder)
60
-
61
- # self.denoiser = Transformer(
62
- # max_n_nodes=self.max_n_nodes,
63
- # hidden_size=dm_cfg.hidden_size,
64
- # depth=dm_cfg.depth,
65
- # num_heads=dm_cfg.num_heads,
66
- # mlp_ratio=dm_cfg.mlp_ratio,
67
- # drop_condition=dm_cfg.drop_condition,
68
- # Xdim=self.Xdim,
69
- # Edim=self.Edim,
70
- # ydim=self.ydim,
71
- # )
72
-
73
- # self.model_dtype = model_dtype
74
- # self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
75
- # dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
76
- # )
77
- # x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
78
- # data_info.node_types.to(self.model_dtype)
79
- # )
80
- # e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
81
- # data_info.edge_types.to(self.model_dtype)
82
- # )
83
- # x_marginals = x_marginals / x_marginals.sum()
84
- # e_marginals = e_marginals / e_marginals.sum()
85
-
86
- # xe_conditions = data_info.transition_E.to(self.model_dtype)
87
- # xe_conditions = xe_conditions[self.active_index][:, self.active_index]
88
-
89
- # xe_conditions = xe_conditions.sum(dim=1)
90
- # ex_conditions = xe_conditions.t()
91
- # xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
92
- # ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
93
-
94
- # self.transition_model = utils.MarginalTransition(
95
- # x_marginals=x_marginals,
96
- # e_marginals=e_marginals,
97
- # xe_conditions=xe_conditions,
98
- # ex_conditions=ex_conditions,
99
- # y_classes=self.ydim_output,
100
- # n_nodes=self.max_n_nodes,
101
- # )
102
- # self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
103
 
104
  # def init_model(self, model_dir):
105
- # model_file = os.path.join(model_dir, 'model.pt')
106
- # if os.path.exists(model_file):
107
- # self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
108
- # else:
109
- # raise FileNotFoundError(f"Model file not found: {model_file}")
110
-
111
  # def disable_grads(self):
112
- # self.denoiser.disable_grads()
113
 
114
- # def forward(
115
- # self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
116
- # ):
117
- # raise ValueError('Not Implement')
118
-
119
- # def _forward(self, noisy_data, unconditioned=False):
120
- # noisy_x, noisy_e, properties = (
121
- # noisy_data["X_t"].to(self.model_dtype),
122
- # noisy_data["E_t"].to(self.model_dtype),
123
- # noisy_data["y_t"].to(self.model_dtype).clone(),
124
- # )
125
- # node_mask, timestep = (
126
- # noisy_data["node_mask"],
127
- # noisy_data["t"],
128
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- # pred = self.denoiser(
131
- # noisy_x,
132
- # noisy_e,
133
- # node_mask,
134
- # properties,
135
- # timestep,
136
- # unconditioned=unconditioned,
137
- # )
138
- # return pred
139
-
140
- # def apply_noise(self, X, E, y, node_mask):
141
- # """Sample noise and apply it to the data."""
142
-
143
- # # Sample a timestep t.
144
- # # When evaluating, the loss for t=0 is computed separately
145
- # lowest_t = 0 if self.training else 1
146
- # t_int = torch.randint(
147
- # lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
148
- # ).to(
149
- # self.model_dtype
150
- # ) # (bs, 1)
151
- # s_int = t_int - 1
152
-
153
- # t_float = t_int / self.T
154
- # s_float = s_int / self.T
155
-
156
- # # beta_t and alpha_s_bar are used for denoising/loss computation
157
- # beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
158
- # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
159
- # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
160
-
161
- # Qtb = self.transition_model.get_Qt_bar(
162
- # alpha_t_bar, X.device
163
- # ) # (bs, dx_in, dx_out), (bs, de_in, de_out)
164
-
165
- # bs, n, d = X.shape
166
- # X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
167
- # prob_all = X_all @ Qtb.X
168
- # probX = prob_all[:, :, : self.Xdim_output]
169
- # probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
170
-
171
- # sampled_t = utils.sample_discrete_features(
172
- # probX=probX, probE=probE, node_mask=node_mask
173
- # )
174
-
175
- # X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
176
- # E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
177
- # assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
178
-
179
- # y_t = y
180
- # z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)
181
-
182
- # noisy_data = {
183
- # "t_int": t_int,
184
- # "t": t_float,
185
- # "beta_t": beta_t,
186
- # "alpha_s_bar": alpha_s_bar,
187
- # "alpha_t_bar": alpha_t_bar,
188
- # "X_t": z_t.X,
189
- # "E_t": z_t.E,
190
- # "y_t": z_t.y,
191
- # "node_mask": node_mask,
192
- # }
193
- # return noisy_data
194
-
195
- # @torch.no_grad()
196
- # def generate(
197
- # self,
198
- # properties,
199
- # guide_scale=1.,
200
- # num_nodes=None,
201
- # number_chain_steps=50,
202
- # ):
203
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
204
- # properties = [float('nan') if x is None else x for x in properties]
205
- # properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
206
- # batch_size = properties.size(0)
207
- # assert batch_size == 1
208
- # if num_nodes is None:
209
- # num_nodes = self.node_dist.sample_n(batch_size, device)
210
- # else:
211
- # num_nodes = torch.LongTensor([num_nodes]).to(device)
212
-
213
- # arange = (
214
- # torch.arange(self.max_n_nodes, device=device)
215
- # .unsqueeze(0)
216
- # .expand(batch_size, -1)
217
- # )
218
- # node_mask = arange < num_nodes.unsqueeze(1)
219
-
220
- # z_T = utils.sample_discrete_feature_noise(
221
- # limit_dist=self.limit_dist, node_mask=node_mask
222
- # )
223
- # X, E = z_T.X, z_T.E
224
-
225
- # assert (E == torch.transpose(E, 1, 2)).all()
226
-
227
- # if number_chain_steps > 0:
228
- # chain_X_size = torch.Size((number_chain_steps, X.size(1)))
229
- # chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
230
- # chain_X = torch.zeros(chain_X_size)
231
- # chain_E = torch.zeros(chain_E_size)
232
-
233
- # # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
234
- # y = properties
235
- # for s_int in reversed(range(0, self.T)):
236
- # s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
237
- # t_array = s_array + 1
238
- # s_norm = s_array / self.T
239
- # t_norm = t_array / self.T
240
-
241
- # # Sample z_s
242
- # sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
243
- # s_norm, t_norm, X, E, y, node_mask, guide_scale, device
244
- # )
245
- # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
246
 
247
- # if number_chain_steps > 0:
248
- # # Save the first keep_chain graphs
249
- # write_index = (s_int * number_chain_steps) // self.T
250
- # chain_X[write_index] = discrete_sampled_s.X[:1]
251
- # chain_E[write_index] = discrete_sampled_s.E[:1]
252
-
253
- # # Sample
254
- # sampled_s = sampled_s.mask(node_mask, collapse=True)
255
- # X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
256
-
257
- # molecule_list = []
258
- # n = num_nodes[0]
259
- # atom_types = X[0, :n].cpu()
260
- # edge_types = E[0, :n, :n].cpu()
261
- # molecule_list.append([atom_types, edge_types])
262
- # smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]
263
-
264
- # # Visualize Chains
265
- # if number_chain_steps > 0:
266
- # final_X_chain = X[:1]
267
- # final_E_chain = E[:1]
268
-
269
- # chain_X[0] = final_X_chain # Overwrite last frame with the resulting X, E
270
- # chain_E[0] = final_E_chain
271
-
272
- # chain_X = utils.reverse_tensor(chain_X)
273
- # chain_E = utils.reverse_tensor(chain_E)
274
-
275
- # # Repeat last frame to see final sample better
276
- # chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
277
- # chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
278
- # mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
279
- # else:
280
- # mol_img_list = []
281
-
282
- # return smiles, mol_img_list
283
-
284
- # def check_valid(self, smiles):
285
- # return check_valid(smiles)
286
 
287
- # def sample_p_zs_given_zt(
288
- # self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
289
- # ):
290
- # """Samples from zs ~ p(zs | zt). Only used during sampling.
291
- # if last_step, return the graph prediction as well"""
292
- # bs, n, _ = X_t.shape
293
- # beta_t = self.noise_schedule(t_normalized=t) # (bs, 1)
294
- # alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
295
- # alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)
296
-
297
- # # Neural net predictions
298
- # noisy_data = {
299
- # "X_t": X_t,
300
- # "E_t": E_t,
301
- # "y_t": properties,
302
- # "t": t,
303
- # "node_mask": node_mask,
304
- # }
305
-
306
- # def get_prob(noisy_data, unconditioned=False):
307
- # pred = self._forward(noisy_data, unconditioned=unconditioned)
308
-
309
- # # Normalize predictions
310
- # pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0
311
- # pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0
312
-
313
- # # Retrieve transitions matrix
314
- # Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
315
- # Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
316
- # Qt = self.transition_model.get_Qt(beta_t, device)
317
-
318
- # Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
319
- # predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)
320
-
321
- # unnormalized_probX_all = utils.reverse_diffusion(
322
- # predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
323
- # )
324
-
325
- # unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
326
- # unnormalized_prob_E = unnormalized_probX_all[
327
- # :, :, self.Xdim_output :
328
- # ].reshape(bs, n * n, -1)
329
-
330
- # unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
331
- # unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
332
-
333
- # prob_X = unnormalized_prob_X / torch.sum(
334
- # unnormalized_prob_X, dim=-1, keepdim=True
335
- # ) # bs, n, d_t-1
336
- # prob_E = unnormalized_prob_E / torch.sum(
337
- # unnormalized_prob_E, dim=-1, keepdim=True
338
- # ) # bs, n, d_t-1
339
- # prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])
340
-
341
- # return prob_X, prob_E
342
-
343
- # prob_X, prob_E = get_prob(noisy_data)
344
-
345
- # ### Guidance
346
- # if guide_scale != 1:
347
- # uncon_prob_X, uncon_prob_E = get_prob(
348
- # noisy_data, unconditioned=True
349
- # )
350
- # prob_X = (
351
- # uncon_prob_X
352
- # * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
353
- # )
354
- # prob_E = (
355
- # uncon_prob_E
356
- # * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
357
- # )
358
- # prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
359
- # prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
360
-
361
- # # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
362
- # # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()
363
-
364
- # sampled_s = utils.sample_discrete_features(
365
- # prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
366
- # )
367
-
368
- # X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
369
- # E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)
370
-
371
- # assert (E_s == torch.transpose(E_s, 1, 2)).all()
372
- # assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)
373
-
374
- # out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
375
- # out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
376
-
377
- # return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
378
- # node_mask, collapse=True
379
- # ).type_as(properties)
380
 
 
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
 
9
+ from . import diffusion_utils as utils
10
+ from .molecule_utils import graph_to_smiles, check_valid
11
+ from .transformer import Transformer
12
+ from .visualize_utils import MolecularVisualization
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # class GraphDiT(nn.Module):
15
  # def __init__(
16
  # self,
 
19
  # model_dtype,
20
  # ):
21
  # super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # def init_model(self, model_dir):
24
+ # pass
25
+
 
 
 
 
26
  # def disable_grads(self):
27
+ # pass
28
 
29
+
30
class GraphDiT(nn.Module):
    """Discrete graph-diffusion model built around a ``Transformer`` denoiser.

    The module bundles the denoiser with the noise schedule, the marginal
    transition model and the limit distribution that are needed to corrupt
    graphs (``apply_noise``) and to sample new ones (``generate``).
    """

    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        """Build the denoiser and the diffusion helpers from two config files.

        Args:
            model_config_path: Passed to ``utils.load_config`` together with
                ``data_info_path``; yields the diffusion/model config.
            data_info_path: Passed to ``utils.load_config``; yields the
                dataset description (dims, node/edge statistics, decoder).
            model_dtype: dtype the marginals and transition statistics are
                converted to, and the dtype inputs are cast to in
                ``_forward``.
        """
        super().__init__()
        dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)

        input_dims = data_info.input_dims
        output_dims = data_info.output_dims
        nodes_dist = data_info.nodes_dist
        active_index = data_info.active_index

        self.model_config = dm_cfg
        self.data_info = data_info
        self.T = dm_cfg.diffusion_steps
        self.Xdim = input_dims["X"]
        self.Edim = input_dims["E"]
        self.ydim = input_dims["y"]
        self.Xdim_output = output_dims["X"]
        self.Edim_output = output_dims["E"]
        self.ydim_output = output_dims["y"]
        self.node_dist = nodes_dist
        self.active_index = active_index
        self.max_n_nodes = data_info.max_n_nodes
        self.atom_decoder = data_info.atom_decoder
        self.hidden_size = dm_cfg.hidden_size
        self.mol_visualizer = MolecularVisualization(self.atom_decoder)

        self.denoiser = Transformer(
            max_n_nodes=self.max_n_nodes,
            hidden_size=dm_cfg.hidden_size,
            depth=dm_cfg.depth,
            num_heads=dm_cfg.num_heads,
            mlp_ratio=dm_cfg.mlp_ratio,
            drop_condition=dm_cfg.drop_condition,
            Xdim=self.Xdim,
            Edim=self.Edim,
            ydim=self.ydim,
        )

        self.model_dtype = model_dtype
        self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
            dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
        )

        # Convert the raw type statistics once, then normalise them.
        node_types = data_info.node_types.to(self.model_dtype)
        edge_types = data_info.edge_types.to(self.model_dtype)
        x_marginals = node_types / torch.sum(node_types)
        e_marginals = edge_types / torch.sum(edge_types)
        # The second normalisation is kept from the original code;
        # presumably it absorbs rounding error in low-precision dtypes.
        x_marginals = x_marginals / x_marginals.sum()
        e_marginals = e_marginals / e_marginals.sum()

        # Restrict the transition statistics to the active rows/columns,
        # collapse dim 1, and row-normalise both the result and its transpose.
        xe_conditions = data_info.transition_E.to(self.model_dtype)
        xe_conditions = xe_conditions[self.active_index][:, self.active_index]

        xe_conditions = xe_conditions.sum(dim=1)
        ex_conditions = xe_conditions.t()
        xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
        ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)

        self.transition_model = utils.MarginalTransition(
            x_marginals=x_marginals,
            e_marginals=e_marginals,
            xe_conditions=xe_conditions,
            ex_conditions=ex_conditions,
            y_classes=self.ydim_output,
            n_nodes=self.max_n_nodes,
        )
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)

    def init_model(self, model_dir):
        """Load the denoiser weights from ``<model_dir>/model.pt``.

        Args:
            model_dir: Directory expected to contain ``model.pt``.

        Raises:
            FileNotFoundError: If ``model.pt`` does not exist in ``model_dir``.
        """
        model_file = os.path.join(model_dir, 'model.pt')
        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model file not found: {model_file}")

        # weights_only=True restricts unpickling to tensors/plain containers.
        state_dict = torch.load(model_file, map_location='cpu', weights_only=True)
        self.denoiser.load_state_dict(state_dict)

    def disable_grads(self):
        """Delegate to ``self.denoiser.disable_grads()``."""
        self.denoiser.disable_grads()

    def forward(
        self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
    ):
        """Unsupported entry point; always raises ``ValueError``.

        Use ``generate`` (sampling) or ``_forward`` (a single denoiser call).
        """
        raise ValueError('Not Implement')

    def _forward(self, noisy_data, unconditioned=False):
        """Run the denoiser on one noisy graph batch.

        Args:
            noisy_data: Mapping with keys ``"X_t"``, ``"E_t"``, ``"y_t"``,
                ``"node_mask"`` and ``"t"``.
            unconditioned: Forwarded to the denoiser as ``unconditioned``.

        Returns:
            Whatever ``self.denoiser`` returns (callers read ``.X`` and ``.E``).
        """
        noisy_x = noisy_data["X_t"].to(self.model_dtype)
        noisy_e = noisy_data["E_t"].to(self.model_dtype)
        # Cloned so the denoiser gets its own copy of the property tensor.
        properties = noisy_data["y_t"].to(self.model_dtype).clone()
        node_mask = noisy_data["node_mask"]
        timestep = noisy_data["t"]

        pred = self.denoiser(
            noisy_x,
            noisy_e,
            node_mask,
            properties,
            timestep,
            unconditioned=unconditioned,
        )
        return pred

    def apply_noise(self, X, E, y, node_mask):
        """Sample noise and apply it to the data.

        Args:
            X: Node features; unpacked as ``(bs, n, d)``.
            E: Edge features; reshaped to ``(bs, n, -1)`` for the transition.
            y: Properties, carried through unchanged as ``y_t``.
            node_mask: Mask forwarded to the sampler and to ``PlaceHolder.mask``.

        Returns:
            dict with keys ``t_int``, ``t``, ``beta_t``, ``alpha_s_bar``,
            ``alpha_t_bar``, ``X_t``, ``E_t``, ``y_t`` and ``node_mask``.
        """
        # Sample a timestep t.
        # When evaluating, the loss for t=0 is computed separately
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
        ).to(self.model_dtype)  # (bs, 1)
        s_int = t_int - 1

        t_float = t_int / self.T
        s_float = s_int / self.T

        # beta_t and alpha_s_bar are used for denoising/loss computation
        beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

        Qtb = self.transition_model.get_Qt_bar(
            alpha_t_bar, X.device
        )  # (bs, dx_in, dx_out), (bs, de_in, de_out)

        # Nodes and their flattened edge rows go through one joint transition.
        bs, n, _ = X.shape
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
        prob_all = X_all @ Qtb.X
        probX = prob_all[:, :, : self.Xdim_output]
        probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)

        sampled_t = utils.sample_discrete_features(
            probX=probX, probE=probE, node_mask=node_mask
        )

        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        y_t = y
        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)

        noisy_data = {
            "t_int": t_int,
            "t": t_float,
            "beta_t": beta_t,
            "alpha_s_bar": alpha_s_bar,
            "alpha_t_bar": alpha_t_bar,
            "X_t": z_t.X,
            "E_t": z_t.E,
            "y_t": z_t.y,
            "node_mask": node_mask,
        }
        return noisy_data

    @torch.no_grad()
    def generate(
        self,
        properties,
        guide_scale=1.,
        num_nodes=None,
        number_chain_steps=50,
    ):
        """Sample one graph conditioned on ``properties`` and decode it.

        Args:
            properties: Iterable of property values; ``None`` entries are
                replaced by NaN. Reshaped to a single row, so exactly one
                graph is generated per call.
            guide_scale: Guidance exponent passed to ``sample_p_zs_given_zt``;
                values other than 1 trigger an extra unconditional pass.
            num_nodes: Number of nodes of the generated graph. When ``None``
                it is drawn from ``self.node_dist``.
            number_chain_steps: Number of intermediate frames kept for the
                visualisation; a value ``<= 0`` disables the visualisation.

        Returns:
            ``(smiles, mol_img_list)``: the first entry returned by
            ``graph_to_smiles`` and the output of
            ``self.mol_visualizer.visualize_chain`` (or ``[]`` when the chain
            is disabled).
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        properties = [float('nan') if x is None else x for x in properties]
        properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
        batch_size = properties.size(0)
        # Invariant of the reshape above; the chain/decoding code relies on it.
        assert batch_size == 1
        if num_nodes is None:
            num_nodes = self.node_dist.sample_n(batch_size, device)
        else:
            num_nodes = torch.tensor([num_nodes], dtype=torch.long, device=device)

        arange = (
            torch.arange(self.max_n_nodes, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        node_mask = arange < num_nodes.unsqueeze(1)

        z_T = utils.sample_discrete_feature_noise(
            limit_dist=self.limit_dist, node_mask=node_mask
        )
        X, E = z_T.X, z_T.E

        assert (E == torch.transpose(E, 1, 2)).all()

        keep_chain = number_chain_steps > 0
        if keep_chain:
            # Frame buffers are allocated on the CPU (default device).
            chain_X = torch.zeros(number_chain_steps, X.size(1))
            chain_E = torch.zeros(number_chain_steps, E.size(1), E.size(2))

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        y = properties
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s
            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
                s_norm, t_norm, X, E, y, node_mask, guide_scale, device
            )
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

            if keep_chain:
                # Map the T diffusion steps onto the retained frames.
                write_index = (s_int * number_chain_steps) // self.T
                chain_X[write_index] = discrete_sampled_s.X[:1]
                chain_E[write_index] = discrete_sampled_s.E[:1]

        # Collapse the final one-hot sample to discrete labels.
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        n = num_nodes[0]
        atom_types = X[0, :n].cpu()
        edge_types = E[0, :n, :n].cpu()
        molecule_list = [[atom_types, edge_types]]
        smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]

        # Visualize Chains
        if keep_chain:
            # Overwrite last frame with the resulting X, E
            chain_X[0] = X[:1]
            chain_E[0] = E[:1]

            chain_X = utils.reverse_tensor(chain_X)
            chain_E = utils.reverse_tensor(chain_E)

            # Repeat last frame to see final sample better
            chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
            chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
            mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
        else:
            mol_img_list = []

        return smiles, mol_img_list

    def check_valid(self, smiles):
        """Return the result of the module-level ``check_valid(smiles)``."""
        return check_valid(smiles)

    def sample_p_zs_given_zt(
        self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
    ):
        """Sample ``z_s ~ p(z_s | z_t)``. Only used during sampling.

        Args:
            s: Normalised previous timestep; ``s[0, 0]`` is passed to the
                sampler as ``step``.
            t: Normalised current timestep.
            X_t: Current node features; unpacked as ``(bs, n, _)``.
            E_t: Current edge features.
            properties: Conditioning tensor, passed to the denoiser as ``y_t``
                and attached to both outputs as ``y``.
            node_mask: Mask forwarded to the denoiser, sampler and outputs.
            guide_scale: When different from 1, the unconditional prediction
                is also computed and the two are combined as
                ``p_uncond * (p_cond / p_uncond) ** guide_scale``.
            device: Device on which the transition matrices are requested.

        Returns:
            A pair of ``utils.PlaceHolder`` objects built from the same
            sample: the first masked with ``mask(node_mask)``, the second
            with ``mask(node_mask, collapse=True)``.
        """
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Retrieve transitions matrix.
        # These depend only on the schedule values and on (X_t, E_t), never on
        # the `unconditioned` flag, so they are built once per step instead of
        # once per denoiser call (which doubled the work under guidance).
        # NOTE(review): assumes get_Qt_bar/get_Qt are pure functions of their
        # arguments and that the denoiser does not mutate X_t/E_t in place.
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
        Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
        Qt = self.transition_model.get_Qt(beta_t, device)
        Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)

        # Neural net predictions
        noisy_data = {
            "X_t": X_t,
            "E_t": E_t,
            "y_t": properties,
            "t": t,
            "node_mask": node_mask,
        }

        def get_prob(noisy_data, unconditioned=False):
            """Return normalised ``(prob_X, prob_E)`` for one denoiser pass."""
            pred = self._forward(noisy_data, unconditioned=unconditioned)

            # Normalize predictions
            pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
            pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

            predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)

            unnormalized_probX_all = utils.reverse_diffusion(
                predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
            )

            unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
            unnormalized_prob_E = unnormalized_probX_all[
                :, :, self.Xdim_output :
            ].reshape(bs, n * n, -1)

            # Rows that sum to zero would divide by zero below; give them a
            # small uniform mass instead.
            unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

            prob_X = unnormalized_prob_X / torch.sum(
                unnormalized_prob_X, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = unnormalized_prob_E / torch.sum(
                unnormalized_prob_E, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

            return prob_X, prob_E

        prob_X, prob_E = get_prob(noisy_data)

        ### Guidance
        if guide_scale != 1:
            uncon_prob_X, uncon_prob_E = get_prob(
                noisy_data, unconditioned=True
            )
            prob_X = (
                uncon_prob_X
                * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
            )
            prob_E = (
                uncon_prob_E
                * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
            )
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)

        # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
        # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()

        sampled_s = utils.sample_discrete_features(
            prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
        )

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)

        return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
            node_mask, collapse=True
        ).type_as(properties)
380