Spaces:
Runtime error
Runtime error
| import torch | |
| from torch_geometric.data import Data | |
| from ogb.utils import smiles2graph | |
| from rdkit import Chem | |
| import random | |
| import os | |
| import json | |
| from rdkit import RDLogger | |
| RDLogger.DisableLog('rdApp.*') | |
| from .r_smiles import multi_process | |
| import multiprocessing | |
| def reformat_smiles(smiles, smiles_type='default'): | |
| if not smiles: | |
| return None | |
| if smiles_type == 'default': | |
| return smiles | |
| elif smiles_type=='canonical': | |
| mol = Chem.MolFromSmiles(smiles) | |
| return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False) | |
| elif smiles_type=='restricted': | |
| mol = Chem.MolFromSmiles(smiles) | |
| new_atom_order = list(range(mol.GetNumAtoms())) | |
| random.shuffle(new_atom_order) | |
| random_mol = Chem.RenumberAtoms(mol, newOrder=new_atom_order) | |
| return Chem.MolToSmiles(random_mol, canonical=False, isomericSmiles=False) | |
| elif smiles_type=='unrestricted': | |
| mol = Chem.MolFromSmiles(smiles) | |
| return Chem.MolToSmiles(mol, canonical=False, doRandom=True, isomericSmiles=False) | |
| elif smiles_type=='r_smiles': | |
| # the implementation of root-aligned smiles is in r_smiles.py | |
| return smiles | |
| else: | |
| raise NotImplementedError(f"smiles_type {smiles_type} not implemented") | |
| def json_read(path): | |
| with open(path, 'r') as f: | |
| data = json.load(f) | |
| return data | |
| def json_write(path, data): | |
| with open(path, 'w') as f: | |
| json.dump(data, f, indent=4, ensure_ascii=False) | |
| def format_float_from_string(s): | |
| try: | |
| float_value = float(s) | |
| return f'{float_value:.2f}' | |
| except ValueError: | |
| return s | |
| def make_abstract(mol_dict, abstract_max_len=256, property_max_len=256): | |
| prompt = '' | |
| if 'abstract' in mol_dict: | |
| abstract_string = mol_dict['abstract'][:abstract_max_len] | |
| prompt += f'[Abstract] {abstract_string} ' | |
| property_string = '' | |
| property_dict = mol_dict['property'] if 'property' in mol_dict else {} | |
| for property_key in ['Experimental Properties', 'Computed Properties']: | |
| if not property_key in property_dict: | |
| continue | |
| for key, value in property_dict[property_key].items(): | |
| if isinstance(value, float): | |
| key_value_string = f'{key}: {value:.2f}; ' | |
| elif isinstance(value, str): | |
| float_value = format_float_from_string(value) | |
| key_value_string = f'{key}: {float_value}; ' | |
| else: | |
| key_value_string = f'{key}: {value}; ' | |
| if len(property_string+key_value_string) > property_max_len: | |
| break | |
| property_string += key_value_string | |
| if property_string: | |
| property_string = property_string[:property_max_len] | |
| prompt += f'[Properties] {property_string}. ' | |
| return prompt | |
| def smiles2data(smiles): | |
| graph = smiles2graph(smiles) | |
| x = torch.from_numpy(graph['node_feat']) | |
| edge_index = torch.from_numpy(graph['edge_index'], ) | |
| edge_attr = torch.from_numpy(graph['edge_feat']) | |
| data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) | |
| return data | |
| import re | |
| SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E" | |
| CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])") | |
| def _insert_split_marker(m: re.Match): | |
| """ | |
| Applies split marker based on a regex match of special tokens such as | |
| [START_DNA]. | |
| Parameters | |
| ---------- | |
| n : str | |
| Input text to split | |
| Returns | |
| ---------- | |
| str - the text with the split token added | |
| """ | |
| start_token, _, sequence, end_token = m.groups() | |
| sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL) | |
| return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}" | |
| def escape_custom_split_sequence(text): | |
| """ | |
| Applies custom splitting to the text for GALILEO's tokenization | |
| Parameters | |
| ---------- | |
| text : str | |
| Input text to split | |
| Returns | |
| ---------- | |
| str - the text with the split token added | |
| """ | |
| return CUSTOM_SEQ_RE.sub(_insert_split_marker, text) | |
| def generate_rsmiles(reactants, products, augmentation=20): | |
| """ | |
| reactants: list of N, reactant smiles | |
| products: list of N, product smiles | |
| augmentation: int, number of augmentations | |
| return: list of N x augmentation | |
| """ | |
| data = [{ | |
| 'reactant': r.strip().replace(' ', ''), | |
| 'product': p.strip().replace(' ', ''), | |
| 'augmentation': augmentation, | |
| 'root_aligned': True, | |
| } for r, p in zip(reactants, products)] | |
| pool = multiprocessing.Pool(processes=multiprocessing.cpu_count()) | |
| results = pool.map(func=multi_process,iterable=data) | |
| product_smiles = [smi for r in results for smi in r['src_data']] | |
| reactant_smiles = [smi for r in results for smi in r['tgt_data']] | |
| return reactant_smiles, product_smiles |