import torch
from .posegnn.calculator import PosEGNNCalculator
import ase
from ase import Atoms
from rdkit import Chem
from rdkit.Chem import AllChem
import pandas as pd
import numpy as np
from huggingface_hub import hf_hub_download
from tqdm import tqdm
# Allow TF32 matmuls on supporting GPUs: trades a little float32 precision
# for substantially faster inference; no effect on CPU.
torch.set_float32_matmul_precision("high")
def smiles_to_atoms(smiles):
    """Convert a SMILES string to an ASE ``Atoms`` object with a 3D conformer.

    Hydrogens are added explicitly and a single conformer is embedded with
    RDKit's distance-geometry method (ETKDG).

    Parameters
    ----------
    smiles : str
        SMILES representation of the molecule.

    Returns
    -------
    ase.Atoms
        Atoms with atomic numbers and embedded 3D positions.

    Raises
    ------
    ValueError
        If the SMILES cannot be parsed, or no 3D conformer can be embedded.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals failure by returning None, not by raising;
        # without this check AddHs would fail with a cryptic ArgumentError.
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    mol = Chem.AddHs(mol)
    # EmbedMolecule signals failure with a -1 return code, not an exception;
    # without this check GetConformer() below raises a confusing error.
    if AllChem.EmbedMolecule(mol) == -1:
        raise ValueError(f"Could not embed a 3D conformer for SMILES: {smiles!r}")
    ase_atoms = ase.Atoms(
        numbers=[atom.GetAtomicNum() for atom in mol.GetAtoms()],
        positions=mol.GetConformer().GetPositions(),
    )
    return ase_atoms
class POSEGNN():
    """Thin wrapper around :class:`PosEGNNCalculator` that encodes SMILES
    strings into fixed-size molecular embeddings (mean-pooled per-atom
    invariant embeddings).
    """

    def __init__(self, use_gpu=True):
        """Select the compute device; the model itself is loaded lazily by
        :meth:`load`.

        Parameters
        ----------
        use_gpu : bool, optional
            Use CUDA when available. Falls back to CPU otherwise.
        """
        device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        self.device = device
        self.calculator = None  # populated by load()

    def load(self, checkpoint=None):
        """Load model weights and build the calculator.

        Parameters
        ----------
        checkpoint : str, optional
            Local path to a model checkpoint. When None (default), the
            weights are downloaded from the Hugging Face hub.
            (Bug fix: this argument was previously accepted but ignored.)
        """
        if checkpoint is None:
            repo_id = "ibm-research/materials.pos-egnn"
            filename = "pytorch_model.bin"
            checkpoint = hf_hub_download(repo_id=repo_id, filename=filename)
        self.calculator = PosEGNNCalculator(checkpoint, device=self.device, compute_stress=False)

    def encode(self, smiles_list, return_tensor=False, batch_size=32):
        """Encode SMILES strings into per-molecule embeddings.

        Invalid SMILES are skipped with a printed warning (best-effort).

        Parameters
        ----------
        smiles_list : list[str]
            SMILES strings to encode.
        return_tensor : bool, optional
            If True return a ``torch.Tensor``; otherwise a ``pandas.DataFrame``.
        batch_size : int, optional
            Number of molecules processed per progress-bar step.

        Returns
        -------
        torch.Tensor or pandas.DataFrame
            One row per successfully processed molecule, in input order.

        Raises
        ------
        RuntimeError
            If :meth:`load` has not been called, or no SMILES could be
            processed at all.
        """
        if self.calculator is None:
            # Fail fast: previously a missing load() surfaced as obscure
            # per-molecule failures because atoms.calc was set to None.
            raise RuntimeError("Model not loaded. Call load() before encode().")
        results = []
        # Batch-wise processing with a progress bar.
        for start in tqdm(range(0, len(smiles_list), batch_size), desc="Batch Encoding"):
            batch = smiles_list[start:start + batch_size]
            atoms_batch = []
            for smiles in batch:
                try:
                    atoms = smiles_to_atoms(smiles)
                    atoms.calc = self.calculator
                    atoms_batch.append(atoms)
                except Exception as e:
                    # Deliberate best-effort: skip molecules that fail
                    # parsing/embedding instead of aborting the whole batch.
                    print(f"Skipping {smiles}: {e}")
            if atoms_batch:
                # Mean-pool per-atom embeddings into one vector per molecule.
                # detach() guards the .numpy() call below against tensors
                # that still carry grad.
                embeddings = [
                    a.get_invariant_embeddings().mean(dim=0).detach().cpu()
                    for a in atoms_batch
                ]
                results.append(torch.stack(embeddings))
        if not results:
            raise RuntimeError("No valid SMILES could be processed.")
        all_embeddings = torch.cat(results, dim=0)
        return all_embeddings if return_tensor else pd.DataFrame(all_embeddings.numpy())