Spaces:
Running
Running
Update trainer.py
Browse files- trainer.py +6 -4
trainer.py
CHANGED
|
@@ -732,11 +732,12 @@ class Trainer(object):
|
|
| 732 |
G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
|
| 733 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 734 |
|
| 735 |
-
|
|
|
|
| 736 |
if self.submodel == "NoTarget":
|
| 737 |
-
|
| 738 |
else:
|
| 739 |
-
|
| 740 |
|
| 741 |
if self.submodel == "RL":
|
| 742 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
|
@@ -916,7 +917,8 @@ class Trainer(object):
|
|
| 916 |
"Runtime (seconds)": round(et, 2),
|
| 917 |
"Validity": f"{fraction_valid(metric_calc_dr)*100:.2f}%",
|
| 918 |
"Uniqueness": f"{fraction_unique(metric_calc_dr)*100:.2f}%",
|
| 919 |
-
"Novelty": f"{novelty(metric_calc_dr,
|
|
|
|
| 920 |
}
|
| 921 |
# print("Validity: ", fraction_valid(metric_calc_dr), "\n")
|
| 922 |
# print("Uniqueness: ", fraction_unique(metric_calc_dr), "\n")
|
|
|
|
| 732 |
G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
|
| 733 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 734 |
|
| 735 |
+
|
| 736 |
+
smiles_test = [line for line in open("data/chembl_test.smi", 'r').read().splitlines()]
|
| 737 |
if self.submodel == "NoTarget":
|
| 738 |
+
smiles_train = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
| 739 |
else:
|
| 740 |
+
smiles_train = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
| 741 |
|
| 742 |
if self.submodel == "RL":
|
| 743 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
|
|
|
| 917 |
"Runtime (seconds)": round(et, 2),
|
| 918 |
"Validity": f"{fraction_valid(metric_calc_dr)*100:.2f}%",
|
| 919 |
"Uniqueness": f"{fraction_unique(metric_calc_dr)*100:.2f}%",
|
| 920 |
+
"Novelty Train": f"{novelty(metric_calc_dr, smiles_train)*100:.2f}%",
|
| 921 |
+
"Novelty Test": f"{novelty(metric_calc_dr, smiles_test)*100:.2f}%"
|
| 922 |
}
|
| 923 |
# print("Validity: ", fraction_valid(metric_calc_dr), "\n")
|
| 924 |
# print("Uniqueness: ", fraction_unique(metric_calc_dr), "\n")
|