Spaces:
Runtime error
Runtime error
Update graph_decoder/diffusion_model.py
Browse files- graph_decoder/diffusion_model.py +67 -66
graph_decoder/diffusion_model.py
CHANGED
|
@@ -19,72 +19,73 @@ class GraphDiT(nn.Module):
|
|
| 19 |
model_dtype,
|
| 20 |
):
|
| 21 |
super().__init__()
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
self.
|
| 32 |
-
self.
|
| 33 |
-
self.
|
| 34 |
-
self.
|
| 35 |
-
self.
|
| 36 |
-
self.
|
| 37 |
-
self.
|
| 38 |
-
self.
|
| 39 |
-
self.
|
| 40 |
-
self.
|
| 41 |
-
self.
|
| 42 |
-
self.
|
| 43 |
-
self.
|
| 44 |
-
self.
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
self.
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
xe_conditions =
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
| 88 |
|
| 89 |
def init_model(self, model_dir):
|
| 90 |
model_file = os.path.join(model_dir, 'model.pt')
|
|
|
|
| 19 |
model_dtype,
|
| 20 |
):
|
| 21 |
super().__init__()
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
# dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
|
| 25 |
+
|
| 26 |
+
# input_dims = data_info.input_dims
|
| 27 |
+
# output_dims = data_info.output_dims
|
| 28 |
+
# nodes_dist = data_info.nodes_dist
|
| 29 |
+
# active_index = data_info.active_index
|
| 30 |
+
|
| 31 |
+
# self.model_config = dm_cfg
|
| 32 |
+
# self.data_info = data_info
|
| 33 |
+
# self.T = dm_cfg.diffusion_steps
|
| 34 |
+
# self.Xdim = input_dims["X"]
|
| 35 |
+
# self.Edim = input_dims["E"]
|
| 36 |
+
# self.ydim = input_dims["y"]
|
| 37 |
+
# self.Xdim_output = output_dims["X"]
|
| 38 |
+
# self.Edim_output = output_dims["E"]
|
| 39 |
+
# self.ydim_output = output_dims["y"]
|
| 40 |
+
# self.node_dist = nodes_dist
|
| 41 |
+
# self.active_index = active_index
|
| 42 |
+
# self.max_n_nodes = data_info.max_n_nodes
|
| 43 |
+
# self.atom_decoder = data_info.atom_decoder
|
| 44 |
+
# self.hidden_size = dm_cfg.hidden_size
|
| 45 |
+
# self.mol_visualizer = MolecularVisualization(self.atom_decoder)
|
| 46 |
+
|
| 47 |
+
# self.denoiser = Transformer(
|
| 48 |
+
# max_n_nodes=self.max_n_nodes,
|
| 49 |
+
# hidden_size=dm_cfg.hidden_size,
|
| 50 |
+
# depth=dm_cfg.depth,
|
| 51 |
+
# num_heads=dm_cfg.num_heads,
|
| 52 |
+
# mlp_ratio=dm_cfg.mlp_ratio,
|
| 53 |
+
# drop_condition=dm_cfg.drop_condition,
|
| 54 |
+
# Xdim=self.Xdim,
|
| 55 |
+
# Edim=self.Edim,
|
| 56 |
+
# ydim=self.ydim,
|
| 57 |
+
# )
|
| 58 |
+
|
| 59 |
+
# self.model_dtype = model_dtype
|
| 60 |
+
# self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
|
| 61 |
+
# dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
|
| 62 |
+
# )
|
| 63 |
+
# x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
|
| 64 |
+
# data_info.node_types.to(self.model_dtype)
|
| 65 |
+
# )
|
| 66 |
+
# e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
|
| 67 |
+
# data_info.edge_types.to(self.model_dtype)
|
| 68 |
+
# )
|
| 69 |
+
# x_marginals = x_marginals / x_marginals.sum()
|
| 70 |
+
# e_marginals = e_marginals / e_marginals.sum()
|
| 71 |
+
|
| 72 |
+
# xe_conditions = data_info.transition_E.to(self.model_dtype)
|
| 73 |
+
# xe_conditions = xe_conditions[self.active_index][:, self.active_index]
|
| 74 |
+
|
| 75 |
+
# xe_conditions = xe_conditions.sum(dim=1)
|
| 76 |
+
# ex_conditions = xe_conditions.t()
|
| 77 |
+
# xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
|
| 78 |
+
# ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
|
| 79 |
+
|
| 80 |
+
# self.transition_model = utils.MarginalTransition(
|
| 81 |
+
# x_marginals=x_marginals,
|
| 82 |
+
# e_marginals=e_marginals,
|
| 83 |
+
# xe_conditions=xe_conditions,
|
| 84 |
+
# ex_conditions=ex_conditions,
|
| 85 |
+
# y_classes=self.ydim_output,
|
| 86 |
+
# n_nodes=self.max_n_nodes,
|
| 87 |
+
# )
|
| 88 |
+
# self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
|
| 89 |
|
| 90 |
def init_model(self, model_dir):
|
| 91 |
model_file = os.path.join(model_dir, 'model.pt')
|