ipd's picture
add pos-egnn
dded3ca
raw
history blame
2.17 kB
import torch
from .posegnn.calculator import PosEGNNCalculator
import ase
from ase import Atoms
from rdkit import Chem
from rdkit.Chem import AllChem
import pandas as pd
import numpy as np
from huggingface_hub import hf_hub_download
from tqdm import tqdm
# Import-time side effect: sets PyTorch's global float32 matmul precision to
# "high" for the whole process, not just for this module. Per the PyTorch
# docs this lets float32 matrix multiplications use faster, lower-precision
# internal formats where the hardware supports them.
torch.set_float32_matmul_precision("high")
def smiles_to_atoms(smiles):
    """Convert a SMILES string into an ASE ``Atoms`` object with 3D coordinates.

    The SMILES is parsed with RDKit, explicit hydrogens are added, and a
    single 3D conformer is embedded. Atomic numbers and conformer positions
    are then copied into a new ``ase.Atoms``.

    Args:
        smiles: SMILES string describing the molecule.

    Returns:
        An ``ase.Atoms`` holding the atomic numbers and the embedded
        conformer positions of the hydrogen-completed molecule.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` or cannot embed a 3D
            conformer for it.
    """
    parsed = Chem.MolFromSmiles(smiles)
    # MolFromSmiles signals a parse failure by returning None instead of
    # raising; fail here with a readable message rather than inside AddHs.
    if parsed is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    mol = Chem.AddHs(parsed)

    # EmbedMolecule returns the id of the new conformer, or -1 when the
    # embedding fails; without this check the failure would only surface
    # later as a "Bad Conformer Id" error from GetConformer().
    if AllChem.EmbedMolecule(mol) < 0:
        raise ValueError(f"Could not embed a 3D conformer for SMILES: {smiles!r}")

    atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    positions = mol.GetConformer().GetPositions()
    return ase.Atoms(numbers=atomic_numbers, positions=positions)
class POSEGNN():
    """Wrapper that encodes SMILES strings into POS-EGNN embeddings.

    Intended use: construct the object, call :meth:`load` once to fetch the
    weights and build the calculator, then call :meth:`encode`.
    """

    def __init__(self, use_gpu=True):
        """Select the compute device; no model is loaded here.

        Args:
            use_gpu: When true and CUDA is available the device is
                ``"cuda"``; otherwise it is ``"cpu"``.
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        # Populated by load(); encode() refuses to run while this is None.
        self.calculator = None

    def load(self, checkpoint=None):
        """Download the pretrained weights and build the calculator.

        The weights file ``pytorch_model.bin`` is fetched from the
        ``ibm-research/materials.pos-egnn`` Hugging Face repository and
        handed to ``PosEGNNCalculator`` with stress computation disabled.

        Args:
            checkpoint: Currently unused; the weights always come from the
                Hugging Face Hub.
                NOTE(review): presumably kept for signature parity with
                other model wrappers - confirm before honouring it.
        """
        repo_id = "ibm-research/materials.pos-egnn"
        filename = "pytorch_model.bin"
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        self.calculator = PosEGNNCalculator(
            model_path, device=self.device, compute_stress=False
        )

    def encode(self, smiles_list, return_tensor=False, batch_size=32):
        """Compute one embedding vector per successfully processed SMILES.

        Each SMILES is converted to an ASE ``Atoms`` object, the calculator
        is attached, and the invariant embeddings are averaged over their
        first dimension (presumably the atom axis - TODO confirm against
        the calculator) to give a single vector per molecule.

        SMILES that fail conversion are reported on stdout and skipped, so
        the output may have fewer rows than ``smiles_list`` and row ``k``
        does not necessarily correspond to ``smiles_list[k]``.

        Args:
            smiles_list: Sliceable sequence of SMILES strings.
            return_tensor: If true return a ``torch.Tensor``; otherwise a
                ``pandas.DataFrame`` built from the same values.
            batch_size: Number of SMILES handled per progress-bar step.

        Returns:
            The embeddings of all processed molecules stacked row-wise, as
            a CPU tensor or as a DataFrame.

        Raises:
            RuntimeError: If :meth:`load` has not been called, or if no
                SMILES could be processed.
        """
        # Fail early with a clear message instead of an obscure error from
        # deep inside the embedding call when no calculator is attached.
        if self.calculator is None:
            raise RuntimeError("Model is not loaded; call load() before encode().")

        results = []
        # Walk the input in chunks so the progress bar advances per batch.
        for i in tqdm(range(0, len(smiles_list), batch_size), desc="Batch Encoding"):
            batch = smiles_list[i:i + batch_size]
            atoms_batch = []
            for smiles in batch:
                try:
                    atoms = smiles_to_atoms(smiles)
                    atoms.calc = self.calculator
                    atoms_batch.append(atoms)
                except Exception as e:
                    # Best-effort: one bad molecule must not abort the run.
                    print(f"Skipping {smiles}: {e}")
            if atoms_batch:
                # get_invariant_embeddings() is assumed to be made available
                # on Atoms by the posegnn package (not visible in this file).
                # detach() guards against tensors that still carry autograd
                # history, which would make the .numpy() call below fail.
                embeddings = [
                    a.get_invariant_embeddings().mean(dim=0).detach().cpu()
                    for a in atoms_batch
                ]
                results.append(torch.stack(embeddings))

        if not results:
            raise RuntimeError("No valid SMILES could be processed.")

        all_embeddings = torch.cat(results, dim=0)
        return all_embeddings if return_tensor else pd.DataFrame(all_embeddings.numpy())