Commit
·
3e2d1c7
1
Parent(s):
940d605
Upload model
Browse files- modeling_mxm.py +6 -6
modeling_mxm.py
CHANGED
|
@@ -949,7 +949,7 @@ class MXMModel(PreTrainedModel):
|
|
| 949 |
def __init__(self, config):
|
| 950 |
super().__init__(config)
|
| 951 |
|
| 952 |
-
self.
|
| 953 |
dim=config.dim,
|
| 954 |
n_layer=config.n_layer,
|
| 955 |
cutoff=config.cutoff,
|
|
@@ -961,14 +961,14 @@ class MXMModel(PreTrainedModel):
|
|
| 961 |
smiles=config.smiles,
|
| 962 |
)
|
| 963 |
|
| 964 |
-
self.
|
| 965 |
self.dataset = None
|
| 966 |
self.output = None
|
| 967 |
self.data_loader = None
|
| 968 |
self.pred_data = None
|
| 969 |
|
| 970 |
def forward(self, tensor):
|
| 971 |
-
return self.
|
| 972 |
|
| 973 |
def SmilesProcessor(self, smiles):
|
| 974 |
return self.process.get_data(smiles)
|
|
@@ -982,8 +982,8 @@ class MXMModel(PreTrainedModel):
|
|
| 982 |
drop_last = kwargs.pop('drop_last', False)
|
| 983 |
num_workers = kwargs.pop('num_workers', 0)
|
| 984 |
|
| 985 |
-
self.
|
| 986 |
-
self.
|
| 987 |
|
| 988 |
self.dataset = self.process.get_data(smiles)
|
| 989 |
self.output = ""
|
|
@@ -1004,7 +1004,7 @@ class MXMModel(PreTrainedModel):
|
|
| 1004 |
batch = batch.to(device)
|
| 1005 |
with torch.no_grad():
|
| 1006 |
self.pred_data['smiles'] += batch['smiles']
|
| 1007 |
-
self.pred_data['pred'] += self.
|
| 1008 |
|
| 1009 |
pred = torch.tensor(self.pred_data['pred']).reshape(-1)
|
| 1010 |
if device == 'cuda':
|
|
|
|
| 949 |
def __init__(self, config):
|
| 950 |
super().__init__(config)
|
| 951 |
|
| 952 |
+
self.backbone = MXMNet(
|
| 953 |
dim=config.dim,
|
| 954 |
n_layer=config.n_layer,
|
| 955 |
cutoff=config.cutoff,
|
|
|
|
| 961 |
smiles=config.smiles,
|
| 962 |
)
|
| 963 |
|
| 964 |
+
self.model = None
|
| 965 |
self.dataset = None
|
| 966 |
self.output = None
|
| 967 |
self.data_loader = None
|
| 968 |
self.pred_data = None
|
| 969 |
|
| 970 |
def forward(self, tensor):
|
| 971 |
+
return self.backbone.forward_features(tensor)
|
| 972 |
|
| 973 |
def SmilesProcessor(self, smiles):
|
| 974 |
return self.process.get_data(smiles)
|
|
|
|
| 982 |
drop_last = kwargs.pop('drop_last', False)
|
| 983 |
num_workers = kwargs.pop('num_workers', 0)
|
| 984 |
|
| 985 |
+
self.model = AutoModel.from_pretrained("Huhujingjing/custom-mxm", trust_remote_code=True).to(device)
|
| 986 |
+
self.model.eval()
|
| 987 |
|
| 988 |
self.dataset = self.process.get_data(smiles)
|
| 989 |
self.output = ""
|
|
|
|
| 1004 |
batch = batch.to(device)
|
| 1005 |
with torch.no_grad():
|
| 1006 |
self.pred_data['smiles'] += batch['smiles']
|
| 1007 |
+
self.pred_data['pred'] += self.model(batch).cpu().tolist()
|
| 1008 |
|
| 1009 |
pred = torch.tensor(self.pred_data['pred']).reshape(-1)
|
| 1010 |
if device == 'cuda':
|