import torch
from .posegnn.calculator import PosEGNNCalculator
import ase
from ase import Atoms
from rdkit import Chem
from rdkit.Chem import AllChem
import pandas as pd
import numpy as np
from huggingface_hub import hf_hub_download
from tqdm import tqdm
# Allow TF32 tensor-core matmuls on supported CUDA hardware: trades a little
# float32 precision for a significant speedup during inference.
torch.set_float32_matmul_precision("high")
def smiles_to_atoms(smiles):
    """Convert a SMILES string into an ASE ``Atoms`` object with 3D coordinates.

    Explicit hydrogens are added and a single conformer is generated with
    RDKit's ``AllChem.EmbedMolecule``.

    Args:
        smiles: SMILES string describing one molecule.

    Returns:
        ase.Atoms carrying the atomic numbers and embedded 3D positions.

    Raises:
        ValueError: if the SMILES cannot be parsed, or no 3D conformer
            could be embedded for it.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals parse failure by returning None; fail loudly
        # instead of letting AddHs raise a cryptic RDKit error.
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    mol = Chem.AddHs(mol)
    # EmbedMolecule returns -1 when conformer generation fails; without this
    # check GetConformer() below raises a confusing "Bad Conformer Id" error.
    if AllChem.EmbedMolecule(mol) == -1:
        raise ValueError(f"Could not embed a 3D conformer for SMILES: {smiles!r}")
    return ase.Atoms(
        numbers=[atom.GetAtomicNum() for atom in mol.GetAtoms()],
        positions=mol.GetConformer().GetPositions(),
    )
class POSEGNN():
    """Wrapper around the IBM pos-egnn model for encoding SMILES strings
    into fixed-size invariant molecular embeddings.

    Typical usage::

        model = POSEGNN(use_gpu=False)
        model.load()
        df = model.encode(["CCO", "c1ccccc1"])
    """

    def __init__(self, use_gpu=True):
        """Select the compute device; the model itself is loaded lazily by load().

        Args:
            use_gpu: use CUDA when available, otherwise fall back to CPU.
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        self.calculator = None  # populated by load()

    def load(self, checkpoint=None):
        """Download the pretrained weights from the Hugging Face Hub and
        build the ASE calculator used for embedding extraction.

        Args:
            checkpoint: currently unused; kept for interface compatibility.
        """
        repo_id = "ibm-research/materials.pos-egnn"
        filename = "pytorch_model.bin"
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        self.calculator = PosEGNNCalculator(model_path, device=self.device, compute_stress=False)

    def encode(self, smiles_list, return_tensor=False, batch_size=32):
        """Encode SMILES strings into per-molecule embeddings.

        SMILES that fail parsing or 3D embedding are skipped with a printed
        warning rather than aborting the whole run.

        Args:
            smiles_list: sequence of SMILES strings.
            return_tensor: if True return a torch.Tensor, else a pandas.DataFrame.
            batch_size: number of molecules handled per progress-bar step.

        Returns:
            torch.Tensor or pandas.DataFrame of shape (n_valid, embedding_dim).

        Raises:
            RuntimeError: if load() has not been called, or if none of the
                input SMILES could be processed.
        """
        # Fail fast with a clear message instead of an opaque AttributeError
        # deep inside the batch loop when the model was never loaded.
        if self.calculator is None:
            raise RuntimeError("Model not loaded. Call load() before encode().")
        results = []
        # Batch-wise processing with a progress bar.
        for i in tqdm(range(0, len(smiles_list), batch_size), desc="Batch Encoding"):
            batch = smiles_list[i:i + batch_size]
            atoms_batch = []
            for smiles in batch:
                try:
                    atoms = smiles_to_atoms(smiles)
                    atoms.calc = self.calculator
                    atoms_batch.append(atoms)
                except Exception as e:
                    print(f"Skipping {smiles}: {e}")
            if atoms_batch:
                # Mean over atoms -> one fixed-size embedding per molecule.
                embeddings = [a.get_invariant_embeddings().mean(dim=0).cpu() for a in atoms_batch]
                results.append(torch.stack(embeddings))
        if not results:
            raise RuntimeError("No valid SMILES could be processed.")
        all_embeddings = torch.cat(results, dim=0)
        return all_embeddings if return_tensor else pd.DataFrame(all_embeddings.numpy())