Update graph_decoder/diffusion_model.py
Browse files- graph_decoder/diffusion_model.py +67 -66
graph_decoder/diffusion_model.py
CHANGED
@@ -19,72 +19,73 @@ class GraphDiT(nn.Module):
|
|
19 |
model_dtype,
|
20 |
):
|
21 |
super().__init__()
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
self.
|
32 |
-
self.
|
33 |
-
self.
|
34 |
-
self.
|
35 |
-
self.
|
36 |
-
self.
|
37 |
-
self.
|
38 |
-
self.
|
39 |
-
self.
|
40 |
-
self.
|
41 |
-
self.
|
42 |
-
self.
|
43 |
-
self.
|
44 |
-
self.
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
self.
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
)
|
65 |
-
|
66 |
-
|
67 |
-
)
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
xe_conditions =
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
88 |
|
89 |
def init_model(self, model_dir):
|
90 |
model_file = os.path.join(model_dir, 'model.pt')
|
|
|
19 |
model_dtype,
|
20 |
):
|
21 |
super().__init__()
|
22 |
+
pass
|
23 |
+
|
24 |
+
# dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
|
25 |
+
|
26 |
+
# input_dims = data_info.input_dims
|
27 |
+
# output_dims = data_info.output_dims
|
28 |
+
# nodes_dist = data_info.nodes_dist
|
29 |
+
# active_index = data_info.active_index
|
30 |
+
|
31 |
+
# self.model_config = dm_cfg
|
32 |
+
# self.data_info = data_info
|
33 |
+
# self.T = dm_cfg.diffusion_steps
|
34 |
+
# self.Xdim = input_dims["X"]
|
35 |
+
# self.Edim = input_dims["E"]
|
36 |
+
# self.ydim = input_dims["y"]
|
37 |
+
# self.Xdim_output = output_dims["X"]
|
38 |
+
# self.Edim_output = output_dims["E"]
|
39 |
+
# self.ydim_output = output_dims["y"]
|
40 |
+
# self.node_dist = nodes_dist
|
41 |
+
# self.active_index = active_index
|
42 |
+
# self.max_n_nodes = data_info.max_n_nodes
|
43 |
+
# self.atom_decoder = data_info.atom_decoder
|
44 |
+
# self.hidden_size = dm_cfg.hidden_size
|
45 |
+
# self.mol_visualizer = MolecularVisualization(self.atom_decoder)
|
46 |
+
|
47 |
+
# self.denoiser = Transformer(
|
48 |
+
# max_n_nodes=self.max_n_nodes,
|
49 |
+
# hidden_size=dm_cfg.hidden_size,
|
50 |
+
# depth=dm_cfg.depth,
|
51 |
+
# num_heads=dm_cfg.num_heads,
|
52 |
+
# mlp_ratio=dm_cfg.mlp_ratio,
|
53 |
+
# drop_condition=dm_cfg.drop_condition,
|
54 |
+
# Xdim=self.Xdim,
|
55 |
+
# Edim=self.Edim,
|
56 |
+
# ydim=self.ydim,
|
57 |
+
# )
|
58 |
+
|
59 |
+
# self.model_dtype = model_dtype
|
60 |
+
# self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
|
61 |
+
# dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
|
62 |
+
# )
|
63 |
+
# x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
|
64 |
+
# data_info.node_types.to(self.model_dtype)
|
65 |
+
# )
|
66 |
+
# e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
|
67 |
+
# data_info.edge_types.to(self.model_dtype)
|
68 |
+
# )
|
69 |
+
# x_marginals = x_marginals / x_marginals.sum()
|
70 |
+
# e_marginals = e_marginals / e_marginals.sum()
|
71 |
+
|
72 |
+
# xe_conditions = data_info.transition_E.to(self.model_dtype)
|
73 |
+
# xe_conditions = xe_conditions[self.active_index][:, self.active_index]
|
74 |
+
|
75 |
+
# xe_conditions = xe_conditions.sum(dim=1)
|
76 |
+
# ex_conditions = xe_conditions.t()
|
77 |
+
# xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
|
78 |
+
# ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
|
79 |
+
|
80 |
+
# self.transition_model = utils.MarginalTransition(
|
81 |
+
# x_marginals=x_marginals,
|
82 |
+
# e_marginals=e_marginals,
|
83 |
+
# xe_conditions=xe_conditions,
|
84 |
+
# ex_conditions=ex_conditions,
|
85 |
+
# y_classes=self.ydim_output,
|
86 |
+
# n_nodes=self.max_n_nodes,
|
87 |
+
# )
|
88 |
+
# self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
|
89 |
|
90 |
def init_model(self, model_dir):
|
91 |
model_file = os.path.join(model_dir, 'model.pt')
|