liuganghuggingface commited on
Commit
50e54f4
·
verified ·
1 Parent(s): 991b396

Update graph_decoder/diffusion_model.py

Browse files
Files changed (1) hide show
  1. graph_decoder/diffusion_model.py +67 -70
graph_decoder/diffusion_model.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import os
3
  import yaml
4
  import json
@@ -20,73 +19,72 @@ class GraphDiT(nn.Module):
20
  model_dtype,
21
  ):
22
  super().__init__()
23
- pass
24
-
25
- # dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
26
-
27
- # input_dims = data_info.input_dims
28
- # output_dims = data_info.output_dims
29
- # nodes_dist = data_info.nodes_dist
30
- # active_index = data_info.active_index
31
-
32
- # self.model_config = dm_cfg
33
- # self.data_info = data_info
34
- # self.T = dm_cfg.diffusion_steps
35
- # self.Xdim = input_dims["X"]
36
- # self.Edim = input_dims["E"]
37
- # self.ydim = input_dims["y"]
38
- # self.Xdim_output = output_dims["X"]
39
- # self.Edim_output = output_dims["E"]
40
- # self.ydim_output = output_dims["y"]
41
- # self.node_dist = nodes_dist
42
- # self.active_index = active_index
43
- # self.max_n_nodes = data_info.max_n_nodes
44
- # self.atom_decoder = data_info.atom_decoder
45
- # self.hidden_size = dm_cfg.hidden_size
46
- # self.mol_visualizer = MolecularVisualization(self.atom_decoder)
47
-
48
- # self.denoiser = Transformer(
49
- # max_n_nodes=self.max_n_nodes,
50
- # hidden_size=dm_cfg.hidden_size,
51
- # depth=dm_cfg.depth,
52
- # num_heads=dm_cfg.num_heads,
53
- # mlp_ratio=dm_cfg.mlp_ratio,
54
- # drop_condition=dm_cfg.drop_condition,
55
- # Xdim=self.Xdim,
56
- # Edim=self.Edim,
57
- # ydim=self.ydim,
58
- # )
59
-
60
- # self.model_dtype = model_dtype
61
- # self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
62
- # dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
63
- # )
64
- # x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
65
- # data_info.node_types.to(self.model_dtype)
66
- # )
67
- # e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
68
- # data_info.edge_types.to(self.model_dtype)
69
- # )
70
- # x_marginals = x_marginals / x_marginals.sum()
71
- # e_marginals = e_marginals / e_marginals.sum()
72
-
73
- # xe_conditions = data_info.transition_E.to(self.model_dtype)
74
- # xe_conditions = xe_conditions[self.active_index][:, self.active_index]
75
-
76
- # xe_conditions = xe_conditions.sum(dim=1)
77
- # ex_conditions = xe_conditions.t()
78
- # xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
79
- # ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
80
-
81
- # self.transition_model = utils.MarginalTransition(
82
- # x_marginals=x_marginals,
83
- # e_marginals=e_marginals,
84
- # xe_conditions=xe_conditions,
85
- # ex_conditions=ex_conditions,
86
- # y_classes=self.ydim_output,
87
- # n_nodes=self.max_n_nodes,
88
- # )
89
- # self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
90
 
91
  def init_model(self, model_dir):
92
  model_file = os.path.join(model_dir, 'model.pt')
@@ -179,8 +177,7 @@ class GraphDiT(nn.Module):
179
  }
180
  return noisy_data
181
 
182
- # @torch.no_grad()
183
- @spaces.GPU
184
  def generate(
185
  self,
186
  properties,
 
 
1
  import os
2
  import yaml
3
  import json
 
19
  model_dtype,
20
  ):
21
  super().__init__()
22
+
23
+ dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
24
+
25
+ input_dims = data_info.input_dims
26
+ output_dims = data_info.output_dims
27
+ nodes_dist = data_info.nodes_dist
28
+ active_index = data_info.active_index
29
+
30
+ self.model_config = dm_cfg
31
+ self.data_info = data_info
32
+ self.T = dm_cfg.diffusion_steps
33
+ self.Xdim = input_dims["X"]
34
+ self.Edim = input_dims["E"]
35
+ self.ydim = input_dims["y"]
36
+ self.Xdim_output = output_dims["X"]
37
+ self.Edim_output = output_dims["E"]
38
+ self.ydim_output = output_dims["y"]
39
+ self.node_dist = nodes_dist
40
+ self.active_index = active_index
41
+ self.max_n_nodes = data_info.max_n_nodes
42
+ self.atom_decoder = data_info.atom_decoder
43
+ self.hidden_size = dm_cfg.hidden_size
44
+ self.mol_visualizer = MolecularVisualization(self.atom_decoder)
45
+
46
+ self.denoiser = Transformer(
47
+ max_n_nodes=self.max_n_nodes,
48
+ hidden_size=dm_cfg.hidden_size,
49
+ depth=dm_cfg.depth,
50
+ num_heads=dm_cfg.num_heads,
51
+ mlp_ratio=dm_cfg.mlp_ratio,
52
+ drop_condition=dm_cfg.drop_condition,
53
+ Xdim=self.Xdim,
54
+ Edim=self.Edim,
55
+ ydim=self.ydim,
56
+ )
57
+
58
+ self.model_dtype = model_dtype
59
+ self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
60
+ dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
61
+ )
62
+ x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
63
+ data_info.node_types.to(self.model_dtype)
64
+ )
65
+ e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
66
+ data_info.edge_types.to(self.model_dtype)
67
+ )
68
+ x_marginals = x_marginals / x_marginals.sum()
69
+ e_marginals = e_marginals / e_marginals.sum()
70
+
71
+ xe_conditions = data_info.transition_E.to(self.model_dtype)
72
+ xe_conditions = xe_conditions[self.active_index][:, self.active_index]
73
+
74
+ xe_conditions = xe_conditions.sum(dim=1)
75
+ ex_conditions = xe_conditions.t()
76
+ xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
77
+ ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
78
+
79
+ self.transition_model = utils.MarginalTransition(
80
+ x_marginals=x_marginals,
81
+ e_marginals=e_marginals,
82
+ xe_conditions=xe_conditions,
83
+ ex_conditions=ex_conditions,
84
+ y_classes=self.ydim_output,
85
+ n_nodes=self.max_n_nodes,
86
+ )
87
+ self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
 
88
 
89
  def init_model(self, model_dir):
90
  model_file = os.path.join(model_dir, 'model.pt')
 
177
  }
178
  return noisy_data
179
 
180
+ @torch.no_grad()
 
181
  def generate(
182
  self,
183
  properties,