Spaces:
Running
Running
import torch | |
from .posegnn.calculator import PosEGNNCalculator | |
import ase | |
from ase import Atoms | |
from rdkit import Chem | |
from rdkit.Chem import AllChem | |
import pandas as pd | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
from tqdm import tqdm | |
# Favour speed over full float32 precision in matrix multiplications (PyTorch
# may use reduced-precision internal math such as TF32 where supported).
# This is a global torch setting and is applied as soon as the module is imported.
torch.set_float32_matmul_precision(precision="high")
def smiles_to_atoms(smiles, random_seed=-1):
    """Build an ASE ``Atoms`` object with 3D coordinates from a SMILES string.

    Explicit hydrogens are added and a single conformer is embedded with RDKit;
    its coordinates become the atom positions.

    Args:
        smiles: SMILES string of the molecule.
        random_seed: Seed forwarded to ``AllChem.EmbedMolecule``. The default
            ``-1`` keeps the previous (unseeded) behaviour; pass a non-negative
            value for reproducible coordinates.

    Returns:
        ``ase.Atoms`` with one entry per atom of the hydrogen-complete molecule.

    Raises:
        ValueError: If the SMILES cannot be parsed, or no 3D conformer could be
            embedded for it.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # RDKit reports parse failures by returning None instead of raising;
        # passing None on to AddHs would produce an opaque argument error.
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    mol = Chem.AddHs(mol)

    # EmbedMolecule returns the new conformer id, or -1 on failure; without this
    # check the failure only surfaces later as "Bad Conformer Id".
    if AllChem.EmbedMolecule(mol, randomSeed=random_seed) == -1:
        raise ValueError(f"Could not embed a 3D conformer for SMILES: {smiles!r}")

    return ase.Atoms(
        numbers=[atom.GetAtomicNum() for atom in mol.GetAtoms()],
        positions=mol.GetConformer().GetPositions(),
    )
class POSEGNN:
    """Turn SMILES strings into fixed-size POS-EGNN molecule embeddings.

    Typical use: ``model = POSEGNN(); model.load(); df = model.encode(smiles)``.
    """

    def __init__(self, use_gpu=True):
        """Choose the compute device; the model itself is created by ``load()``.

        Args:
            use_gpu: Use CUDA when it is available, otherwise fall back to CPU.
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        # Set by load(); encode() refuses to run while this is still None.
        self.calculator = None

    def load(self, checkpoint=None):
        """Create the PosEGNN calculator from a checkpoint file.

        Args:
            checkpoint: Optional path to a local checkpoint file. When omitted,
                ``pytorch_model.bin`` is fetched from the Hugging Face Hub repo
                ``ibm-research/materials.pos-egnn`` via ``hf_hub_download``.
        """
        if checkpoint is None:
            model_path = hf_hub_download(
                repo_id="ibm-research/materials.pos-egnn",
                filename="pytorch_model.bin",
            )
        else:
            # Previously this argument was accepted but silently ignored.
            model_path = checkpoint
        self.calculator = PosEGNNCalculator(
            model_path, device=self.device, compute_stress=False
        )

    def encode(self, smiles_list, return_tensor=False, batch_size=32):
        """Embed each SMILES as the mean of its invariant embeddings over dim 0.

        SMILES that fail conversion to 3D atoms are reported with ``print`` and
        skipped; errors raised by the model itself are not caught.

        Args:
            smiles_list: Sized, sliceable sequence of SMILES strings.
            return_tensor: Return a ``torch.Tensor`` instead of a DataFrame.
            batch_size: SMILES processed per progress-bar step (must be >= 1).
                Molecules are still evaluated one at a time.

        Returns:
            A CPU tensor with one row per successfully encoded SMILES, or a
            DataFrame with the same rows whose index is each row's position in
            ``smiles_list`` (so skipped inputs can be identified).

        Raises:
            RuntimeError: If ``load()`` has not been called, or if no SMILES
                could be processed.
            ValueError: If ``batch_size`` is less than 1.
        """
        if self.calculator is None:
            raise RuntimeError("Model not loaded; call load() before encode().")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        embeddings = []
        kept_positions = []
        for start in tqdm(range(0, len(smiles_list), batch_size), desc="Batch Encoding"):
            batch = smiles_list[start:start + batch_size]
            for offset, smiles in enumerate(batch):
                try:
                    atoms = smiles_to_atoms(smiles)
                except Exception as e:
                    # Best-effort: one bad SMILES should not abort the whole run.
                    print(f"Skipping {smiles}: {e}")
                    continue
                atoms.calc = self.calculator
                # NOTE(review): get_invariant_embeddings is presumably provided by
                # the posegnn integration and returns per-atom rows — confirm.
                # detach() so .numpy() below works even if the output tracks grad.
                embeddings.append(
                    atoms.get_invariant_embeddings().mean(dim=0).detach().cpu()
                )
                kept_positions.append(start + offset)

        if not embeddings:
            raise RuntimeError("No valid SMILES could be processed.")

        all_embeddings = torch.stack(embeddings)
        if return_tensor:
            return all_embeddings
        return pd.DataFrame(all_embeddings.numpy(), index=kept_positions)