Huhujingjing commited on
Commit
c330f05
·
1 Parent(s): 3717306

Update modeling_transmxm.py

Browse files
Files changed (1) hide show
  1. modeling_transmxm.py +30 -30
modeling_transmxm.py CHANGED
@@ -23,7 +23,7 @@ from torch_scatter import scatter
23
  from torch_geometric.nn import global_add_pool, radius
24
  from torch_sparse import SparseTensor
25
 
26
- from transmxm_model.configuration_transmxm import TransmxmConfig
27
 
28
  from tqdm import tqdm
29
  import numpy as np
@@ -1171,35 +1171,35 @@ class Local_MP(torch.nn.Module):
1171
  return h, y
1172
 
1173
 
1174
- # class MXMConfig(PretrainedConfig):
1175
- # model_type = "gcn"
1176
- #
1177
- # def __init__(
1178
- # self,
1179
- # dim: int=128,
1180
- # n_layer: int=6,
1181
- # cutoff: float=5.0,
1182
- # num_spherical: int=7,
1183
- # num_radial: int=6,
1184
- # envelope_exponent: int=5,
1185
- #
1186
- # smiles: List[str] = None,
1187
- # processor_class: str = "SmilesProcessor",
1188
- # **kwargs,
1189
- # ):
1190
- #
1191
- # self.dim = dim # the dimension of input feature
1192
- # self.n_layer = n_layer # the number of GCN layers
1193
- # self.cutoff = cutoff # the cutoff distance for neighbor searching
1194
- # self.num_spherical = num_spherical # the number of spherical harmonics
1195
- # self.num_radial = num_radial # the number of radial basis
1196
- # self.envelope_exponent = envelope_exponent # the envelope exponent
1197
- #
1198
- # self.smiles = smiles # process smiles
1199
- # self.processor_class = processor_class
1200
- #
1201
- #
1202
- # super().__init__(**kwargs)
1203
 
1204
 
1205
 
 
23
  from torch_geometric.nn import global_add_pool, radius
24
  from torch_sparse import SparseTensor
25
 
26
+ # from transmxm_model.configuration_transmxm import TransmxmConfig
27
 
28
  from tqdm import tqdm
29
  import numpy as np
 
1171
  return h, y
1172
 
1173
 
1174
+ class TransmxmConfig(PretrainedConfig):
1175
+ model_type = "transmxm"
1176
+
1177
+ def __init__(
1178
+ self,
1179
+ dim: int=128,
1180
+ n_layer: int=6,
1181
+ cutoff: float=5.0,
1182
+ num_spherical: int=7,
1183
+ num_radial: int=6,
1184
+ envelope_exponent: int=5,
1185
+
1186
+ smiles: List[str] = None,
1187
+ processor_class: str = "SmilesProcessor",
1188
+ **kwargs,
1189
+ ):
1190
+
1191
+ self.dim = dim # the dimension of input feature
1192
+ self.n_layer = n_layer # the number of GCN layers
1193
+ self.cutoff = cutoff # the cutoff distance for neighbor searching
1194
+ self.num_spherical = num_spherical # the number of spherical harmonics
1195
+ self.num_radial = num_radial # the number of radial basis
1196
+ self.envelope_exponent = envelope_exponent # the envelope exponent
1197
+
1198
+ self.smiles = smiles # process smiles
1199
+ self.processor_class = processor_class
1200
+
1201
+
1202
+ super().__init__(**kwargs)
1203
 
1204
 
1205