Commit
·
27fdc8e
1
Parent(s):
263ef5a
Upload model
Browse files- modeling_mxm.py +30 -30
modeling_mxm.py
CHANGED
|
@@ -18,7 +18,7 @@ from torch_scatter import scatter
|
|
| 18 |
from torch_geometric.nn import global_add_pool, radius
|
| 19 |
from torch_sparse import SparseTensor
|
| 20 |
|
| 21 |
-
from mxm_model.configuration_mxm import MXMConfig
|
| 22 |
|
| 23 |
from tqdm import tqdm
|
| 24 |
import numpy as np
|
|
@@ -911,35 +911,35 @@ class Local_MP(torch.nn.Module):
|
|
| 911 |
return h, y
|
| 912 |
|
| 913 |
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
|
| 943 |
|
| 944 |
|
| 945 |
|
|
|
|
| 18 |
from torch_geometric.nn import global_add_pool, radius
|
| 19 |
from torch_sparse import SparseTensor
|
| 20 |
|
| 21 |
+
# from mxm_model.configuration_mxm import MXMConfig
|
| 22 |
|
| 23 |
from tqdm import tqdm
|
| 24 |
import numpy as np
|
|
|
|
| 911 |
return h, y
|
| 912 |
|
| 913 |
|
| 914 |
+
class MXMConfig(PretrainedConfig):
|
| 915 |
+
model_type = "mxm"
|
| 916 |
+
|
| 917 |
+
def __init__(
|
| 918 |
+
self,
|
| 919 |
+
dim: int=128,
|
| 920 |
+
n_layer: int=6,
|
| 921 |
+
cutoff: float=5.0,
|
| 922 |
+
num_spherical: int=7,
|
| 923 |
+
num_radial: int=6,
|
| 924 |
+
envelope_exponent: int=5,
|
| 925 |
+
|
| 926 |
+
smiles: List[str] = None,
|
| 927 |
+
processor_class: str = "SmilesProcessor",
|
| 928 |
+
**kwargs,
|
| 929 |
+
):
|
| 930 |
+
|
| 931 |
+
self.dim = dim # the dimension of input feature
|
| 932 |
+
self.n_layer = n_layer # the number of GCN layers
|
| 933 |
+
self.cutoff = cutoff # the cutoff distance for neighbor searching
|
| 934 |
+
self.num_spherical = num_spherical # the number of spherical harmonics
|
| 935 |
+
self.num_radial = num_radial # the number of radial basis
|
| 936 |
+
self.envelope_exponent = envelope_exponent # the envelope exponent
|
| 937 |
+
|
| 938 |
+
self.smiles = smiles # process smiles
|
| 939 |
+
self.processor_class = processor_class
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
super().__init__(**kwargs)
|
| 943 |
|
| 944 |
|
| 945 |
|