import torch
from .posegnn.calculator import PosEGNNCalculator
import ase
from rdkit import Chem
from rdkit.Chem import AllChem
import pandas as pd
import numpy as np
from huggingface_hub import hf_hub_download
from tqdm import tqdm

torch.set_float32_matmul_precision("high")

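def smiles_to_atoms(smiles):
    """Convert a SMILES string into an ASE Atoms object with an embedded 3D conformer."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    mol = Chem.AddHs(mol)
    # EmbedMolecule returns -1 when no conformer could be generated
    if AllChem.EmbedMolecule(mol) == -1:
        raise ValueError(f"Could not embed a 3D conformer for: {smiles!r}")
    ase_atoms = ase.Atoms(
        numbers=[atom.GetAtomicNum() for atom in mol.GetAtoms()],
        positions=mol.GetConformer().GetPositions(),
    )
    return ase_atoms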

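class POSEGNN:
    """Wrapper that encodes SMILES strings into POS-EGNN embeddings."""

    def __init__(self, use_gpu=True):
        # Fall back to CPU when CUDA is requested but unavailable
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        self.calculator = None  # set by load()

    def load(self, checkpoint=None):
        """Download the model weights and build the calculator.

        NOTE: `checkpoint` is currently ignored; the weights are always
        fetched from the Hugging Face Hub.
        """
        repo_id = "ibm-research/materials.pos-egnn"
        filename = "pytorch_model.bin"
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        self.calculator = PosEGNNCalculator(model_path, device=self.device, compute_stress=False)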

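    def encode(self, smiles_list, return_tensor=False, batch_size=32):
        """Return one mean-pooled embedding per valid SMILES, as a tensor or DataFrame.

        SMILES that fail conversion are skipped, so the output may have
        fewer rows than `smiles_list`.
        """
        if self.calculator is None:
            raise RuntimeError("Model not loaded; call load() before encode().")

        results = []

        # Process the input in batches, with a progress bar
        for i in tqdm(range(0, len(smiles_list), batch_size), desc="Batch Encoding"):
            batch = smiles_list[i:i + batch_size]
            atoms_batch = []

            for smiles in batch:
                try:
                    atoms = smiles_to_atoms(smiles)
                    atoms.calc = self.calculator
                    atoms_batch.append(atoms)
                except Exception as e:
                    print(f"Skipping {smiles}: {e}")

            if atoms_batch:
                # Average over dim 0 to get one vector per molecule; detach in
                # case the embeddings are still attached to the autograd graph
                embeddings = [a.get_invariant_embeddings().mean(dim=0).detach().cpu() for a in atoms_batch]
                batch_tensor = torch.stack(embeddings)
                results.append(batch_tensor)

        if not results:
            raise RuntimeError("No valid SMILES could be processed.")

        all_embeddings = torch.cat(results, dim=0)
        return all_embeddings if return_tensor else pd.DataFrame(all_embeddings.numpy())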