Commit
·
c330f05
1
Parent(s):
3717306
Update modeling_transmxm.py
Browse files- modeling_transmxm.py +30 -30
modeling_transmxm.py
CHANGED
@@ -23,7 +23,7 @@ from torch_scatter import scatter
|
|
23 |
from torch_geometric.nn import global_add_pool, radius
|
24 |
from torch_sparse import SparseTensor
|
25 |
|
26 |
-
from transmxm_model.configuration_transmxm import TransmxmConfig
|
27 |
|
28 |
from tqdm import tqdm
|
29 |
import numpy as np
|
@@ -1171,35 +1171,35 @@ class Local_MP(torch.nn.Module):
|
|
1171 |
return h, y
|
1172 |
|
1173 |
|
1174 |
-
|
1175 |
-
|
1176 |
-
|
1177 |
-
|
1178 |
-
|
1179 |
-
|
1180 |
-
|
1181 |
-
|
1182 |
-
|
1183 |
-
|
1184 |
-
|
1185 |
-
|
1186 |
-
|
1187 |
-
|
1188 |
-
|
1189 |
-
|
1190 |
-
|
1191 |
-
|
1192 |
-
|
1193 |
-
|
1194 |
-
|
1195 |
-
|
1196 |
-
|
1197 |
-
|
1198 |
-
|
1199 |
-
|
1200 |
-
|
1201 |
-
|
1202 |
-
|
1203 |
|
1204 |
|
1205 |
|
|
|
23 |
from torch_geometric.nn import global_add_pool, radius
|
24 |
from torch_sparse import SparseTensor
|
25 |
|
26 |
+
# from transmxm_model.configuration_transmxm import TransmxmConfig
|
27 |
|
28 |
from tqdm import tqdm
|
29 |
import numpy as np
|
|
|
1171 |
return h, y
|
1172 |
|
1173 |
|
1174 |
+
class TransmxmConfig(PretrainedConfig):
|
1175 |
+
model_type = "transmxm"
|
1176 |
+
|
1177 |
+
def __init__(
|
1178 |
+
self,
|
1179 |
+
dim: int=128,
|
1180 |
+
n_layer: int=6,
|
1181 |
+
cutoff: float=5.0,
|
1182 |
+
num_spherical: int=7,
|
1183 |
+
num_radial: int=6,
|
1184 |
+
envelope_exponent: int=5,
|
1185 |
+
|
1186 |
+
smiles: List[str] = None,
|
1187 |
+
processor_class: str = "SmilesProcessor",
|
1188 |
+
**kwargs,
|
1189 |
+
):
|
1190 |
+
|
1191 |
+
self.dim = dim # the dimension of input feature
|
1192 |
+
self.n_layer = n_layer # the number of GCN layers
|
1193 |
+
self.cutoff = cutoff # the cutoff distance for neighbor searching
|
1194 |
+
self.num_spherical = num_spherical # the number of spherical harmonics
|
1195 |
+
self.num_radial = num_radial # the number of radial basis
|
1196 |
+
self.envelope_exponent = envelope_exponent # the envelope exponent
|
1197 |
+
|
1198 |
+
self.smiles = smiles # process smiles
|
1199 |
+
self.processor_class = processor_class
|
1200 |
+
|
1201 |
+
|
1202 |
+
super().__init__(**kwargs)
|
1203 |
|
1204 |
|
1205 |
|