File size: 2,777 Bytes
de956c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15216c3
 
de956c8
15216c3
de956c8
 
 
 
 
 
15216c3
de956c8
 
 
15216c3
de956c8
 
 
 
 
 
 
 
15216c3
 
de956c8
 
 
15216c3
de956c8
 
 
15216c3
de956c8
15216c3
de956c8
 
 
 
 
15216c3
 
 
de956c8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import pytest
import os
import sys
import logging

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from protac_degradation_predictor import PROTAC_Model, PROTAC_Predictor

import torch


def test_protac_model():
    """Verify the attributes of a PROTAC_Model built with ``hidden_dim=128``.

    Only ``hidden_dim`` is passed to the constructor; every other attribute
    is checked against the value the instance is expected to expose without
    further configuration.
    """
    model = PROTAC_Model(hidden_dim=128)

    # Attribute name -> expected value, listed in the order they are checked.
    # A ``None`` entry means the attribute must be the ``None`` singleton
    # (identity check); everything else is compared with ``==``.
    expected = {
        'hidden_dim': 128,
        'smiles_emb_dim': 224,
        'poi_emb_dim': 1024,
        'e3_emb_dim': 1024,
        'cell_emb_dim': 768,
        'batch_size': 32,
        'learning_rate': 0.001,
        'dropout': 0.2,
        'join_embeddings': 'concat',
        'train_dataset': None,
        'val_dataset': None,
        'test_dataset': None,
        'disabled_embeddings': [],
        'apply_scaling': False,
    }
    for attr, want in expected.items():
        actual = getattr(model, attr)
        if want is None:
            assert actual is None, f'{attr}: expected None, got {actual!r}'
        else:
            assert actual == want, f'{attr}: expected {want!r}, got {actual!r}'

def test_protac_predictor():
    """Verify the attributes of a PROTAC_Predictor built with ``hidden_dim=128``."""
    predictor = PROTAC_Predictor(hidden_dim=128)

    # Attribute name -> expected value, checked in this order with ``==``.
    expected = {
        'hidden_dim': 128,
        'smiles_emb_dim': 224,
        'poi_emb_dim': 1024,
        'e3_emb_dim': 1024,
        'cell_emb_dim': 768,
        'join_embeddings': 'concat',
        'disabled_embeddings': [],
    }
    for attr, want in expected.items():
        actual = getattr(predictor, attr)
        assert actual == want, f'{attr}: expected {want!r}, got {actual!r}'

def test_load_model(caplog):
    """Restore a PROTAC_Model from the checkpoint file and check its hyperparameters.

    The checkpoint path is relative, so it is resolved against the current
    working directory of the test run.
    """
    # Have the caplog fixture capture log records at WARNING level.
    caplog.set_level(logging.WARNING)

    ckpt_path = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'

    # Remap to CPU only when CUDA is unavailable; ``None`` leaves the
    # device mapping to the loader.
    target_device = None if torch.cuda.is_available() else torch.device("cpu")
    model = PROTAC_Model.load_from_checkpoint(
        ckpt_path,
        map_location=target_device,
    )

    # Hyperparameters this checkpoint is expected to restore, checked in
    # this order with ``==``.
    expected = {
        'hidden_dim': 768,
        'smiles_emb_dim': 224,
        'poi_emb_dim': 1024,
        'e3_emb_dim': 1024,
        'cell_emb_dim': 768,
        'batch_size': 8,
        'learning_rate': 1.843233973932415e-05,
        'dropout': 0.11257777663560328,
        'join_embeddings': 'concat',
        'disabled_embeddings': [],
        'apply_scaling': True,
    }
    for attr, want in expected.items():
        actual = getattr(model, attr)
        assert actual == want, f'{attr}: expected {want!r}, got {actual!r}'

    # Shown for manual inspection only (visible with ``pytest -s``).
    print(model.scalers)


def test_checkpoint_file():
    """Print the raw contents of the checkpoint file for manual inspection.

    There are no assertions: the test fails only if loading the file, one of
    the key lookups, or unpickling the scalers raises.
    """
    import pickle

    ckpt_path = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'

    # Remap to CPU only when CUDA is unavailable; ``None`` leaves the
    # device mapping to the loader.
    target_device = None if torch.cuda.is_available() else torch.device("cpu")
    raw = torch.load(ckpt_path, map_location=target_device)

    print(raw.keys())
    print(raw["hyper_parameters"])
    # Names of the entries in the state dict, in stored order.
    print(list(raw["state_dict"]))

    # NOTE(review): pickle.loads can execute arbitrary code while loading;
    # only run this against checkpoints from a trusted source.
    print(pickle.loads(raw['scalers']))

# Allow running this file directly as a script while keeping it import-safe:
# pytest imports test modules during collection, and an unguarded
# ``pytest.main()`` at module level would start a nested pytest session on
# every import.
if __name__ == "__main__":
    # pytest.main() returns the session's exit code; hand it to the shell so
    # a failing run is reported as a non-zero exit status.
    raise SystemExit(pytest.main())