import torch import selfies as sf import numpy as np import pandas as pd from rdkit import Chem from transformers import AutoTokenizer, AutoModel import gc from torch.utils.data import DataLoader, Dataset from multiprocessing import Pool, cpu_count from tqdm import tqdm import os os.environ["TOKENIZERS_PARALLELISM"] = "false" class SELFIESDataset(Dataset): def __init__(self, selfies_list): self.selfies = selfies_list def __len__(self): return len(self.selfies) def __getitem__(self, idx): return self.selfies[idx] class SELFIES(torch.nn.Module): def __init__(self): super().__init__() self.model = None self.tokenizer = None self.invalid = [] def smiles_to_selfies(self, smiles): try: return sf.encoder(smiles.strip()).replace('][', '] [') except: try: smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.strip())) return sf.encoder(smiles).replace('][', '] [') except: return None def get_selfies(self, smiles_list): with Pool(cpu_count()) as pool: selfies = list(pool.map(self.smiles_to_selfies, smiles_list)) self.invalid = [i for i, s in enumerate(selfies) if s is None] selfies = [s if s is not None else '[nop]' for s in selfies] return selfies @torch.no_grad() def get_embedding_batch(self, selfies_batch): encodings = self.tokenizer( selfies_batch, return_tensors='pt', max_length=128, truncation=True, padding='max_length' ) encodings = {k: v.to(self.model.device) for k, v in encodings.items()} outputs = self.model.encoder( input_ids=encodings['input_ids'], attention_mask=encodings['attention_mask'] ) model_output = outputs.last_hidden_state input_mask_expanded = encodings['attention_mask'].unsqueeze(-1).expand(model_output.size()).float() sum_embeddings = torch.sum(model_output * input_mask_expanded, 1) sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) pooled_output = sum_embeddings / sum_mask return pooled_output.cpu().numpy() def load(self, checkpoint=None): self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted") self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted") self.model.eval() def encode(self, smiles_list=[], use_gpu=False, return_tensor=False, batch_size=128, num_workers=4): selfies = self.get_selfies(smiles_list) dataset = SELFIESDataset(selfies) device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu") self.model.to(device) loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) embeddings = [] for batch in tqdm(loader, desc="Encoding"): emb = self.get_embedding_batch(batch) embeddings.append(emb) del emb gc.collect() emb = np.vstack(embeddings) for idx in self.invalid: emb[idx] = np.nan print(f"Cannot encode {smiles_list[idx]} to selfies. Embedding replaced by NaN.") return torch.tensor(emb) if return_tensor else pd.DataFrame(emb)