ipd committed on
Commit
d194709
·
verified ·
1 Parent(s): 4e2f180

Update models/selfies_model/load.py

Browse files
Files changed (1) hide show
  1. models/selfies_model/load.py +70 -64
models/selfies_model/load.py CHANGED
@@ -1,96 +1,102 @@
1
- import os
2
- import sys
3
  import torch
4
- import selfies as sf # selfies>=2.1.1
5
- import pickle
6
- import pandas as pd
7
  import numpy as np
8
- from datasets import Dataset
9
  from rdkit import Chem
10
  from transformers import AutoTokenizer, AutoModel
 
 
 
 
11
 
 
 
12
 
13
- class SELFIES(torch.nn.Module):
14
 
 
 
 
 
 
 
 
 
 
 
 
15
  def __init__(self):
16
  super().__init__()
17
  self.model = None
18
  self.tokenizer = None
19
  self.invalid = []
20
 
21
- def get_selfies(self, smiles_list):
22
- self.invalid = []
23
- spaced_selfies_batch = []
24
- for i, smiles in enumerate(smiles_list):
25
  try:
26
- selfies = sf.encoder(smiles.rstrip())
 
27
  except:
28
- try:
29
- smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.rstrip()))
30
- selfies = sf.encoder(smiles)
31
- except:
32
- selfies = "[]"
33
- self.invalid.append(i)
34
-
35
- spaced_selfies_batch.append(selfies.replace('][', '] ['))
36
-
37
- return spaced_selfies_batch
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- def get_embedding(self, selfies):
41
- encoding = self.tokenizer(selfies["selfies"], return_tensors='pt', max_length=128, truncation=True, padding='max_length')
42
- input_ids = encoding['input_ids']
43
- attention_mask = encoding['attention_mask']
44
- outputs = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
45
  model_output = outputs.last_hidden_state
46
-
47
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
48
  sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
49
  sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
50
- model_output = sum_embeddings / sum_mask
51
-
52
- encoding["embedding"] = model_output
53
 
54
- return encoding
55
-
56
-
57
- def load(self, checkpoint="bart-2908.pickle"):
58
- """
59
- inputs :
60
- checkpoint (pickle object)
61
- """
62
 
 
63
  self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
64
  self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted")
65
- """if os.path.isfile(checkpoint):
66
- with open(checkpoint, "rb") as input_file:
67
- self.model, self.tokenizer = pickle.load(input_file)
68
- for p in sys.path:
69
- file = p + "/" + checkpoint
70
- if os.path.isfile(file):
71
- with open(file, "rb") as input_file:
72
- self.model, self.tokenizer = pickle.load(input_file)"""
73
 
 
 
 
74
 
 
 
75
 
 
76
 
77
- # TODO: remove `use_gpu` argument in validation pipeline
78
- def encode(self, smiles_list=[], use_gpu=False, return_tensor=False):
79
- """
80
- inputs :
81
- checkpoint (pickle object)
82
- :return: embedding
83
- """
84
- selfies = self.get_selfies(smiles_list)
85
- selfies_df = pd.DataFrame(selfies,columns=["selfies"])
86
- data = Dataset.from_pandas(selfies_df)
87
- embedding = data.map(self.get_embedding, batched=True, num_proc=1, batch_size=128)
88
- emb = np.asarray(embedding["embedding"].copy())
89
 
90
  for idx in self.invalid:
91
  emb[idx] = np.nan
92
- print("Cannot encode {0} to selfies and embedding replaced by NaN".format(smiles_list[idx]))
93
 
94
- if return_tensor:
95
- return torch.tensor(emb)
96
- return pd.DataFrame(emb)
 
 
 
1
  import torch
2
+ import selfies as sf
 
 
3
  import numpy as np
4
+ import pandas as pd
5
  from rdkit import Chem
6
  from transformers import AutoTokenizer, AutoModel
7
+ import gc
8
+ from torch.utils.data import DataLoader, Dataset
9
+ from multiprocessing import Pool, cpu_count
10
+ from tqdm import tqdm
11
 
12
+ import os
13
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
 
 
15
 
16
class SELFIESDataset(Dataset):
    """Torch ``Dataset`` that serves pre-computed SELFIES strings one at a time.

    A thin read-only wrapper: it keeps a reference to the supplied list and
    indexes straight into it, so ``DataLoader`` can batch the strings.
    """

    def __init__(self, selfies_list: list):
        # No copy is made; the list is only ever read.
        self.selfies = selfies_list

    def __len__(self) -> int:
        return len(self.selfies)

    def __getitem__(self, idx: int):
        return self.selfies[idx]
25
+
26
class SELFIES(torch.nn.Module):
    """Embed SMILES strings with the ``ibm/materials.selfies-ted`` encoder.

    Pipeline: SMILES -> SELFIES (via the ``selfies`` package, with an RDKit
    canonicalization fallback) -> tokenization -> encoder hidden states ->
    attention-mask-weighted mean pooling. Call :meth:`load` before
    :meth:`encode`.
    """

    def __init__(self):
        super().__init__()
        self.model = None      # set by load()
        self.tokenizer = None  # set by load()
        self.invalid = []      # indices of inputs that could not be converted

    def smiles_to_selfies(self, smiles):
        """Convert one SMILES string to a space-separated SELFIES string.

        Returns ``None`` when the input cannot be converted at all.
        """
        try:
            return sf.encoder(smiles.strip()).replace('][', '] [')
        except Exception:
            # Some SMILES only encode after RDKit canonicalization.
            try:
                smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.strip()))
                return sf.encoder(smiles).replace('][', '] [')
            except Exception:
                return None

    def get_selfies(self, smiles_list):
        """Convert a list of SMILES to SELFIES in parallel.

        Side effect: stores the indices of failed conversions in
        ``self.invalid``; failed entries are replaced by the no-op token
        ``'[nop]'`` so downstream batching keeps positions aligned.
        """
        if not smiles_list:
            # Avoid spawning a process pool for nothing.
            self.invalid = []
            return []
        with Pool(cpu_count()) as pool:
            selfies = pool.map(self.smiles_to_selfies, smiles_list)

        self.invalid = [i for i, s in enumerate(selfies) if s is None]
        return [s if s is not None else '[nop]' for s in selfies]

    @torch.no_grad()
    def get_embedding_batch(self, selfies_batch):
        """Embed one batch of SELFIES strings.

        Returns a ``(batch, hidden)`` numpy array of mean-pooled encoder
        last-hidden states.
        """
        encodings = self.tokenizer(
            selfies_batch,
            return_tensors='pt',
            max_length=128,
            truncation=True,
            padding='max_length',
        )
        encodings = {k: v.to(self.model.device) for k, v in encodings.items()}

        outputs = self.model.encoder(
            input_ids=encodings['input_ids'],
            attention_mask=encodings['attention_mask'],
        )

        # Mean-pool over real (non-padding) tokens only.
        model_output = outputs.last_hidden_state
        input_mask_expanded = encodings['attention_mask'].unsqueeze(-1).expand(model_output.size()).float()
        sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
        # Clamp guards against division by zero for all-padding rows.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        pooled_output = sum_embeddings / sum_mask

        return pooled_output.cpu().numpy()

    def load(self, checkpoint=None):
        """Load tokenizer and model from the Hugging Face hub.

        ``checkpoint`` is unused and kept only for backward compatibility
        with older callers that passed a pickle path.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
        self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted")
        self.model.eval()

    def encode(self, smiles_list=None, use_gpu=False, return_tensor=False, batch_size=128, num_workers=4):
        """Embed a list of SMILES strings.

        Rows whose SMILES could not be converted are filled with NaN and
        reported on stdout. Returns a pandas DataFrame, or a torch tensor
        when ``return_tensor`` is True. An empty input yields an empty
        result instead of raising (the previous version crashed in
        ``np.vstack``).
        """
        # None (not []) as default avoids the mutable-default-argument trap.
        smiles_list = [] if smiles_list is None else smiles_list
        selfies = self.get_selfies(smiles_list)

        device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
        self.model.to(device)

        if selfies:
            loader = DataLoader(SELFIESDataset(selfies), batch_size=batch_size, num_workers=num_workers)
            embeddings = []
            for batch in tqdm(loader, desc="Encoding"):
                emb = self.get_embedding_batch(batch)
                embeddings.append(emb)
                del emb
                gc.collect()  # keep peak memory down on large inputs
            emb = np.vstack(embeddings)
        else:
            # Empty input: preserve the return types without running the model.
            emb = np.empty((0, 0), dtype=np.float32)

        for idx in self.invalid:
            emb[idx] = np.nan
            print(f"Cannot encode {smiles_list[idx]} to selfies. Embedding replaced by NaN.")

        return torch.tensor(emb) if return_tensor else pd.DataFrame(emb)