Update graph_decoder/diffusion_model.py
Browse files- graph_decoder/diffusion_model.py +67 -70
graph_decoder/diffusion_model.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import spaces
|
2 |
import os
|
3 |
import yaml
|
4 |
import json
|
@@ -20,73 +19,72 @@ class GraphDiT(nn.Module):
|
|
20 |
model_dtype,
|
21 |
):
|
22 |
super().__init__()
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
# self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
|
90 |
|
91 |
def init_model(self, model_dir):
|
92 |
model_file = os.path.join(model_dir, 'model.pt')
|
@@ -179,8 +177,7 @@ class GraphDiT(nn.Module):
|
|
179 |
}
|
180 |
return noisy_data
|
181 |
|
182 |
-
|
183 |
-
@spaces.GPU
|
184 |
def generate(
|
185 |
self,
|
186 |
properties,
|
|
|
|
|
1 |
import os
|
2 |
import yaml
|
3 |
import json
|
|
|
19 |
model_dtype,
|
20 |
):
|
21 |
super().__init__()
|
22 |
+
|
23 |
+
dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
|
24 |
+
|
25 |
+
input_dims = data_info.input_dims
|
26 |
+
output_dims = data_info.output_dims
|
27 |
+
nodes_dist = data_info.nodes_dist
|
28 |
+
active_index = data_info.active_index
|
29 |
+
|
30 |
+
self.model_config = dm_cfg
|
31 |
+
self.data_info = data_info
|
32 |
+
self.T = dm_cfg.diffusion_steps
|
33 |
+
self.Xdim = input_dims["X"]
|
34 |
+
self.Edim = input_dims["E"]
|
35 |
+
self.ydim = input_dims["y"]
|
36 |
+
self.Xdim_output = output_dims["X"]
|
37 |
+
self.Edim_output = output_dims["E"]
|
38 |
+
self.ydim_output = output_dims["y"]
|
39 |
+
self.node_dist = nodes_dist
|
40 |
+
self.active_index = active_index
|
41 |
+
self.max_n_nodes = data_info.max_n_nodes
|
42 |
+
self.atom_decoder = data_info.atom_decoder
|
43 |
+
self.hidden_size = dm_cfg.hidden_size
|
44 |
+
self.mol_visualizer = MolecularVisualization(self.atom_decoder)
|
45 |
+
|
46 |
+
self.denoiser = Transformer(
|
47 |
+
max_n_nodes=self.max_n_nodes,
|
48 |
+
hidden_size=dm_cfg.hidden_size,
|
49 |
+
depth=dm_cfg.depth,
|
50 |
+
num_heads=dm_cfg.num_heads,
|
51 |
+
mlp_ratio=dm_cfg.mlp_ratio,
|
52 |
+
drop_condition=dm_cfg.drop_condition,
|
53 |
+
Xdim=self.Xdim,
|
54 |
+
Edim=self.Edim,
|
55 |
+
ydim=self.ydim,
|
56 |
+
)
|
57 |
+
|
58 |
+
self.model_dtype = model_dtype
|
59 |
+
self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
|
60 |
+
dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
|
61 |
+
)
|
62 |
+
x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
|
63 |
+
data_info.node_types.to(self.model_dtype)
|
64 |
+
)
|
65 |
+
e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
|
66 |
+
data_info.edge_types.to(self.model_dtype)
|
67 |
+
)
|
68 |
+
x_marginals = x_marginals / x_marginals.sum()
|
69 |
+
e_marginals = e_marginals / e_marginals.sum()
|
70 |
+
|
71 |
+
xe_conditions = data_info.transition_E.to(self.model_dtype)
|
72 |
+
xe_conditions = xe_conditions[self.active_index][:, self.active_index]
|
73 |
+
|
74 |
+
xe_conditions = xe_conditions.sum(dim=1)
|
75 |
+
ex_conditions = xe_conditions.t()
|
76 |
+
xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
|
77 |
+
ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
|
78 |
+
|
79 |
+
self.transition_model = utils.MarginalTransition(
|
80 |
+
x_marginals=x_marginals,
|
81 |
+
e_marginals=e_marginals,
|
82 |
+
xe_conditions=xe_conditions,
|
83 |
+
ex_conditions=ex_conditions,
|
84 |
+
y_classes=self.ydim_output,
|
85 |
+
n_nodes=self.max_n_nodes,
|
86 |
+
)
|
87 |
+
self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
|
|
|
88 |
|
89 |
def init_model(self, model_dir):
|
90 |
model_file = os.path.join(model_dir, 'model.pt')
|
|
|
177 |
}
|
178 |
return noisy_data
|
179 |
|
180 |
+
@torch.no_grad()
|
|
|
181 |
def generate(
|
182 |
self,
|
183 |
properties,
|