File size: 3,383 Bytes
6747ba1
d194709
6747ba1
d194709
6747ba1
 
d194709
 
 
 
6747ba1
d194709
 
6747ba1
 
d194709
 
 
 
 
 
 
 
 
 
 
6747ba1
 
 
 
 
 
d194709
 
 
 
6747ba1
d194709
 
6747ba1
d194709
6747ba1
d194709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6747ba1
 
d194709
6747ba1
 
d194709
6747ba1
d194709
6747ba1
d194709
6747ba1
 
d194709
6747ba1
d194709
 
 
6747ba1
d194709
 
6747ba1
d194709
6747ba1
d194709
 
 
 
 
 
 
 
6747ba1
 
 
d194709
6747ba1
d194709
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import selfies as sf
import numpy as np
import pandas as pd
from rdkit import Chem
from transformers import AutoTokenizer, AutoModel
import gc
from torch.utils.data import DataLoader, Dataset
from multiprocessing import Pool, cpu_count
from tqdm import tqdm

import os
# Module-level side effect: set the TOKENIZERS_PARALLELISM environment variable
# to "false" at import time, before any tokenizer is instantiated below.
# NOTE(review): presumably this disables the Hugging Face tokenizers' internal
# thread pool so it does not clash with the process-based parallelism used in
# this file (multiprocessing.Pool and DataLoader workers) — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class SELFIESDataset(Dataset):
    """Map-style dataset that serves items of an in-memory sequence by index."""

    def __init__(self, selfies_list):
        """Keep a reference to ``selfies_list``; the sequence is not copied."""
        self.selfies = selfies_list

    def __len__(self):
        """Return how many items the wrapped sequence holds."""
        count = len(self.selfies)
        return count

    def __getitem__(self, idx):
        """Return the item stored at position ``idx`` of the wrapped sequence."""
        item = self.selfies[idx]
        return item

def _smiles_to_selfies(smiles):
    """Convert one SMILES string into a space-separated SELFIES string.

    Kept at module level so ``multiprocessing.Pool`` only has to pickle this
    small function, not a bound method (which would drag the whole ``SELFIES``
    module, including any loaded model, into every task).

    Args:
        smiles: SMILES string; surrounding whitespace is stripped.

    Returns:
        The SELFIES string with ``'] ['`` inserted between adjacent tokens, or
        ``None`` when neither the raw nor the RDKit-round-tripped SMILES could
        be encoded (including when ``smiles`` is not a string).
    """
    try:
        return sf.encoder(smiles.strip()).replace('][', '] [')
    except Exception:
        # Best effort: retry on the SMILES that RDKit writes back out after
        # parsing. ``Exception`` (not a bare ``except``) so that
        # KeyboardInterrupt / SystemExit still propagate out of pool workers.
        try:
            rewritten = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.strip()))
            return sf.encoder(rewritten).replace('][', '] [')
        except Exception:
            return None


class SELFIES(torch.nn.Module):
    """Embed SMILES strings with the ``ibm/materials.selfies-ted`` encoder.

    Typical use: ``load()`` once, then ``encode(smiles_list)``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        super().__init__()
        # Populated by load(); None until then.
        self.model = None
        self.tokenizer = None
        # Positions (into the most recent input list) that failed SMILES ->
        # SELFIES conversion; overwritten by every get_selfies()/encode() call.
        self.invalid = []

    def smiles_to_selfies(self, smiles):
        """Convert a single SMILES string to SELFIES, or ``None`` on failure.

        Thin wrapper around the module-level helper, kept for callers that use
        the method form.
        """
        return _smiles_to_selfies(smiles)

    def get_selfies(self, smiles_list):
        """Convert many SMILES strings to SELFIES using a process pool.

        Args:
            smiles_list: Iterable of SMILES strings.

        Returns:
            A list with one SELFIES string per input, in input order. Inputs
            that could not be converted are replaced by the placeholder
            ``'[nop]'`` and their positions are recorded in ``self.invalid``.
        """
        smiles = list(smiles_list)
        if not smiles:
            # Nothing to convert: skip the cost of starting worker processes.
            self.invalid = []
            return []

        # Never start more workers than there are items to convert.
        processes = max(1, min(cpu_count(), len(smiles)))
        with Pool(processes) as pool:
            converted = pool.map(_smiles_to_selfies, smiles)

        self.invalid = [i for i, s in enumerate(converted) if s is None]
        return [s if s is not None else '[nop]' for s in converted]

    @torch.no_grad()
    def get_embedding_batch(self, selfies_batch):
        """Embed one batch of SELFIES strings.

        Args:
            selfies_batch: Sequence of space-separated SELFIES strings.

        Returns:
            NumPy array with one row per input: the attention-mask-weighted
            mean of the encoder's last hidden state over dimension 1.
        """
        encodings = self.tokenizer(
            list(selfies_batch),
            return_tensors='pt',
            max_length=128,
            truncation=True,
            padding='max_length',
        )
        encodings = {k: v.to(self.model.device) for k, v in encodings.items()}

        outputs = self.model.encoder(
            input_ids=encodings['input_ids'],
            attention_mask=encodings['attention_mask'],
        )

        hidden = outputs.last_hidden_state
        # Broadcast the mask to the hidden-state shape so padded positions
        # contribute nothing to the sum.
        mask = encodings['attention_mask'].unsqueeze(-1).expand(hidden.size()).float()
        summed = torch.sum(hidden * mask, 1)
        # Clamp guards against division by zero for an all-masked row.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        pooled = summed / counts

        return pooled.cpu().numpy()

    def load(self, checkpoint=None):
        """Load tokenizer and model from ``ibm/materials.selfies-ted``.

        Args:
            checkpoint: Accepted but currently ignored; the pretrained
                identifier is fixed.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
        self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted")
        self.model.eval()

    def encode(self, smiles_list=None, use_gpu=False, return_tensor=False, batch_size=128, num_workers=4):
        """Embed a collection of SMILES strings.

        Args:
            smiles_list: Iterable of SMILES strings. ``None`` or an empty
                iterable yields an empty result.
            use_gpu: Run on CUDA when requested and available, else on CPU.
            return_tensor: Return a ``torch.Tensor`` instead of a DataFrame.
            batch_size: Number of molecules per forward pass.
            num_workers: Worker processes used by the ``DataLoader``.

        Returns:
            One embedding row per input, as a ``pandas.DataFrame`` (default) or
            ``torch.Tensor``. Rows for inputs that could not be converted to
            SELFIES are filled with NaN. Empty input gives a result with zero
            rows and zero columns.

        Raises:
            RuntimeError: If there is input to encode but ``load()`` has not
                been called.
        """
        # Materialise once so generators work and so positional indexing in
        # the NaN report below is reliable (e.g. for a pandas Series).
        smiles = [] if smiles_list is None else list(smiles_list)

        if not smiles:
            self.invalid = []
            empty = np.empty((0, 0), dtype=np.float32)
            return torch.tensor(empty) if return_tensor else pd.DataFrame(empty)

        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded; call load() before encode().")

        selfies = self.get_selfies(smiles)
        dataset = SELFIESDataset(selfies)

        device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
        self.model.to(device)

        loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

        embeddings = [
            self.get_embedding_batch(batch)
            for batch in tqdm(loader, desc="Encoding")
        ]
        emb = np.vstack(embeddings)

        # The '[nop]' placeholder rows carry no chemical meaning: blank them.
        for idx in self.invalid:
            emb[idx] = np.nan
            print(f"Cannot encode {smiles[idx]} to selfies. Embedding replaced by NaN.")

        return torch.tensor(emb) if return_tensor else pd.DataFrame(emb)