liuganghuggingface committed
Commit 341f250 · verified · 1 Parent(s): 6348fcc

Update graph_decoder/diffusion_utils.py

Files changed (1)
  1. graph_decoder/diffusion_utils.py +128 -128
graph_decoder/diffusion_utils.py CHANGED
@@ -1,131 +1,131 @@
-import os
-import json
-import yaml
-
-import torch
-import numpy as np
-from torch.nn import functional as F
-from torch_geometric.utils import to_dense_adj, to_dense_batch, remove_self_loops
-from types import SimpleNamespace
-
-def dict_to_namespace(d):
-    return SimpleNamespace(
-        **{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in d.items()}
-    )
-
-class DataInfos:
-    def __init__(self, meta_filename="data.meta.json"):
-        self.all_targets = ['CH4', 'CO2', 'H2', 'N2', 'O2']
-        self.task_type = "gas_permeability"
-        if os.path.exists(meta_filename):
-            with open(meta_filename, "r") as f:
-                meta_dict = json.load(f)
-        else:
-            raise FileNotFoundError(f"Meta file {meta_filename} not found.")
-
-        self.active_atoms = meta_dict["active_atoms"]
-        self.max_n_nodes = meta_dict["max_node"]
-        self.original_max_n_nodes = meta_dict["max_node"]
-        self.n_nodes = torch.Tensor(meta_dict["n_atoms_per_mol_dist"])
-        self.edge_types = torch.Tensor(meta_dict["bond_type_dist"])
-        self.transition_E = torch.Tensor(meta_dict["transition_E"])
-
-        self.atom_decoder = meta_dict["active_atoms"]
-        node_types = torch.Tensor(meta_dict["atom_type_dist"])
-        active_index = (node_types > 0).nonzero().squeeze()
-        self.node_types = torch.Tensor(meta_dict["atom_type_dist"])[active_index]
-        self.nodes_dist = DistributionNodes(self.n_nodes)
-        self.active_index = active_index
-
-        val_len = 3 * self.original_max_n_nodes - 2
-        meta_val = torch.Tensor(meta_dict["valencies"])
-        self.valency_distribution = torch.zeros(val_len)
-        val_len = min(val_len, len(meta_val))
-        self.valency_distribution[:val_len] = meta_val[:val_len]
-        ## for all
-        self.input_dims = {"X": len(self.active_atoms), "E": 5, "y": 5}
-        self.output_dims = {"X": len(self.active_atoms), "E": 5, "y": 5}
-        # self.input_dims = {"X": 11, "E": 5, "y": 5}
-        # self.output_dims = {"X": 11, "E": 5, "y": 5}
-
-def load_config(config_path, data_meta_info_path):
-    if not os.path.exists(config_path):
-        raise FileNotFoundError(f"Configuration file not found: {config_path}")
-
-    if not os.path.exists(data_meta_info_path):
-        raise FileNotFoundError(f"Data meta info file not found: {data_meta_info_path}")
-
-    with open(config_path, "r") as file:
-        cfg_dict = yaml.safe_load(file)
-
-    cfg = dict_to_namespace(cfg_dict)
-
-    data_info = DataInfos(data_meta_info_path)
-    return cfg, data_info
-
-
-#### graph utils
-class PlaceHolder:
-    def __init__(self, X, E, y):
-        self.X = X
-        self.E = E
-        self.y = y
-
-    def type_as(self, x: torch.Tensor, categorical: bool = False):
-        """Changes the device and dtype of X, E, y."""
-        self.X = self.X.type_as(x)
-        self.E = self.E.type_as(x)
-        if categorical:
-            self.y = self.y.type_as(x)
-        return self
-
-    def mask(self, node_mask, collapse=False):
-        x_mask = node_mask.unsqueeze(-1)  # bs, n, 1
-        e_mask1 = x_mask.unsqueeze(2)  # bs, n, 1, 1
-        e_mask2 = x_mask.unsqueeze(1)  # bs, 1, n, 1
-
-        if collapse:
-            self.X = torch.argmax(self.X, dim=-1)
-            self.E = torch.argmax(self.E, dim=-1)
-
-            self.X[node_mask == 0] = -1
-            self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = -1
-        else:
-            self.X = self.X * x_mask
-            self.E = self.E * e_mask1 * e_mask2
-            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
-        return self
-
-
-def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None):
-    X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
-    # node_mask = node_mask.float()
-    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
-    if max_num_nodes is None:
-        max_num_nodes = X.size(1)
-    E = to_dense_adj(
-        edge_index=edge_index,
-        batch=batch,
-        edge_attr=edge_attr,
-        max_num_nodes=max_num_nodes,
-    )
-    E = encode_no_edge(E)
-    return PlaceHolder(X=X, E=E, y=None), node_mask
-
-
-def encode_no_edge(E):
-    assert len(E.shape) == 4
-    if E.shape[-1] == 0:
-        return E
-    no_edge = torch.sum(E, dim=3) == 0
-    first_elt = E[:, :, :, 0]
-    first_elt[no_edge] = 1
-    E[:, :, :, 0] = first_elt
-    diag = (
-        torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
-    )
-    E[diag] = 0
-    return E
+# import os
+# import json
+# import yaml
+
+# import torch
+# import numpy as np
+# from torch.nn import functional as F
+# from torch_geometric.utils import to_dense_adj, to_dense_batch, remove_self_loops
+# from types import SimpleNamespace
+
+# def dict_to_namespace(d):
+#     return SimpleNamespace(
+#         **{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in d.items()}
+#     )
+
+# class DataInfos:
+#     def __init__(self, meta_filename="data.meta.json"):
+#         self.all_targets = ['CH4', 'CO2', 'H2', 'N2', 'O2']
+#         self.task_type = "gas_permeability"
+#         if os.path.exists(meta_filename):
+#             with open(meta_filename, "r") as f:
+#                 meta_dict = json.load(f)
+#         else:
+#             raise FileNotFoundError(f"Meta file {meta_filename} not found.")
+
+#         self.active_atoms = meta_dict["active_atoms"]
+#         self.max_n_nodes = meta_dict["max_node"]
+#         self.original_max_n_nodes = meta_dict["max_node"]
+#         self.n_nodes = torch.Tensor(meta_dict["n_atoms_per_mol_dist"])
+#         self.edge_types = torch.Tensor(meta_dict["bond_type_dist"])
+#         self.transition_E = torch.Tensor(meta_dict["transition_E"])
+
+#         self.atom_decoder = meta_dict["active_atoms"]
+#         node_types = torch.Tensor(meta_dict["atom_type_dist"])
+#         active_index = (node_types > 0).nonzero().squeeze()
+#         self.node_types = torch.Tensor(meta_dict["atom_type_dist"])[active_index]
+#         self.nodes_dist = DistributionNodes(self.n_nodes)
+#         self.active_index = active_index
+
+#         val_len = 3 * self.original_max_n_nodes - 2
+#         meta_val = torch.Tensor(meta_dict["valencies"])
+#         self.valency_distribution = torch.zeros(val_len)
+#         val_len = min(val_len, len(meta_val))
+#         self.valency_distribution[:val_len] = meta_val[:val_len]
+#         ## for all
+#         self.input_dims = {"X": len(self.active_atoms), "E": 5, "y": 5}
+#         self.output_dims = {"X": len(self.active_atoms), "E": 5, "y": 5}
+#         # self.input_dims = {"X": 11, "E": 5, "y": 5}
+#         # self.output_dims = {"X": 11, "E": 5, "y": 5}
+
+# def load_config(config_path, data_meta_info_path):
+#     if not os.path.exists(config_path):
+#         raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+#     if not os.path.exists(data_meta_info_path):
+#         raise FileNotFoundError(f"Data meta info file not found: {data_meta_info_path}")
+
+#     with open(config_path, "r") as file:
+#         cfg_dict = yaml.safe_load(file)
+
+#     cfg = dict_to_namespace(cfg_dict)
+
+#     data_info = DataInfos(data_meta_info_path)
+#     return cfg, data_info
+
+
+# #### graph utils
+# class PlaceHolder:
+#     def __init__(self, X, E, y):
+#         self.X = X
+#         self.E = E
+#         self.y = y
+
+#     def type_as(self, x: torch.Tensor, categorical: bool = False):
+#         """Changes the device and dtype of X, E, y."""
+#         self.X = self.X.type_as(x)
+#         self.E = self.E.type_as(x)
+#         if categorical:
+#             self.y = self.y.type_as(x)
+#         return self
+
+#     def mask(self, node_mask, collapse=False):
+#         x_mask = node_mask.unsqueeze(-1)  # bs, n, 1
+#         e_mask1 = x_mask.unsqueeze(2)  # bs, n, 1, 1
+#         e_mask2 = x_mask.unsqueeze(1)  # bs, 1, n, 1
+
+#         if collapse:
+#             self.X = torch.argmax(self.X, dim=-1)
+#             self.E = torch.argmax(self.E, dim=-1)
+
+#             self.X[node_mask == 0] = -1
+#             self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = -1
+#         else:
+#             self.X = self.X * x_mask
+#             self.E = self.E * e_mask1 * e_mask2
+#             assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
+#         return self
+
+
+# def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None):
+#     X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
+#     # node_mask = node_mask.float()
+#     edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
+#     if max_num_nodes is None:
+#         max_num_nodes = X.size(1)
+#     E = to_dense_adj(
+#         edge_index=edge_index,
+#         batch=batch,
+#         edge_attr=edge_attr,
+#         max_num_nodes=max_num_nodes,
+#     )
+#     E = encode_no_edge(E)
+#     return PlaceHolder(X=X, E=E, y=None), node_mask
+
+
+# def encode_no_edge(E):
+#     assert len(E.shape) == 4
+#     if E.shape[-1] == 0:
+#         return E
+#     no_edge = torch.sum(E, dim=3) == 0
+#     first_elt = E[:, :, :, 0]
+#     first_elt[no_edge] = 1
+#     E[:, :, :, 0] = first_elt
+#     diag = (
+#         torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
+#     )
+#     E[diag] = 0
+#     return E
 
 
 # #### diffusion utils
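
For reference, `dict_to_namespace` (active before this commit, kept only as a comment after it) recursively converts the parsed YAML dictionary into nested `SimpleNamespace` objects, so config keys read as attributes. A minimal self-contained sketch; the config values are made up for illustration:

```python
from types import SimpleNamespace

def dict_to_namespace(d):
    # Recurse into sub-dicts so nested keys also become attribute access.
    return SimpleNamespace(
        **{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in d.items()}
    )

# Hypothetical values standing in for the YAML that load_config would parse.
cfg = dict_to_namespace({"model": {"hidden_dim": 128}, "lr": 1e-4})
print(cfg.model.hidden_dim, cfg.lr)  # -> 128 0.0001
```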
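Likewise, a hedged sketch of how the graph utilities fit together before this commit: `to_dense` densifies a sparse PyG batch, `encode_no_edge` reserves edge channel 0 for "no bond" and clears the diagonal, and `PlaceHolder.mask(..., collapse=True)` collapses the one-hot tensors to integer labels with -1 at padded positions. The toy molecule below is an illustrative assumption, not repository data, and the import only works against the pre-commit (uncommented) module; note also that `DataInfos` references a `DistributionNodes` class that is not defined in this file.

```python
# Illustrative only: assumes the pre-commit module is importable and that a
# toy two-molecule batch stands in for real polymer graphs.
import torch
from torch_geometric.data import Batch, Data

from graph_decoder.diffusion_utils import to_dense  # pre-commit API

# One toy molecule: 4 atoms, one-hot over 4 atom types; a single bond stored
# in both directions, one-hot over 5 bond channels (row 1 = bond type 1).
mol = Data(
    x=torch.eye(4),
    edge_index=torch.tensor([[0, 1], [1, 0]]),
    edge_attr=torch.eye(5)[[1, 1]],
)
batch = Batch.from_data_list([mol, mol])

dense, node_mask = to_dense(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
# encode_no_edge has set channel 0 on every non-edge and zeroed the diagonal,
# so dense.E[b, i, j] is a valid one-hot bond vector off the diagonal.
dense = dense.mask(node_mask, collapse=True)
print(dense.X.shape, dense.E.shape)  # torch.Size([2, 4]) torch.Size([2, 4, 4])
```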