Spaces:
Sleeping
Sleeping
import os | |
import cv2 | |
import time | |
import random | |
import re | |
import string | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader, Dataset | |
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence | |
import albumentations as A | |
from albumentations.pytorch import ToTensorV2 | |
from .indigo import Indigo | |
from .indigo.renderer import IndigoRenderer | |
from .augment import SafeRotate, CropWhite, PadWhite, SaltAndPepperNoise | |
from .utils import FORMAT_INFO | |
from .tokenizer import PAD_ID | |
from .chemistry import get_num_atoms, normalize_nodes | |
from .constants import RGROUP_SYMBOLS, SUBSTITUTIONS, ELEMENTS, COLORS | |
cv2.setNumThreads(1) | |
INDIGO_HYGROGEN_PROB = 0.2 | |
INDIGO_FUNCTIONAL_GROUP_PROB = 0.8 | |
INDIGO_CONDENSED_PROB = 0.5 | |
INDIGO_RGROUP_PROB = 0.5 | |
INDIGO_COMMENT_PROB = 0.3 | |
INDIGO_DEARMOTIZE_PROB = 0.8 | |
INDIGO_COLOR_PROB = 0.2 | |
def get_transforms(input_size, augment=True, rotate=True, debug=False): | |
trans_list = [] | |
if augment and rotate: | |
trans_list.append(SafeRotate(limit=90, border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255))) | |
trans_list.append(CropWhite(pad=5)) | |
if augment: | |
trans_list += [ | |
# NormalizedGridDistortion(num_steps=10, distort_limit=0.3), | |
A.CropAndPad(percent=[-0.01, 0.00], keep_size=False, p=0.5), | |
PadWhite(pad_ratio=0.4, p=0.2), | |
A.Downscale(scale_min=0.2, scale_max=0.5, interpolation=3), | |
A.Blur(), | |
A.GaussNoise(), | |
SaltAndPepperNoise(num_dots=20, p=0.5) | |
] | |
trans_list.append(A.Resize(input_size, input_size)) | |
if not debug: | |
mean = [0.485, 0.456, 0.406] | |
std = [0.229, 0.224, 0.225] | |
trans_list += [ | |
A.ToGray(p=1), | |
A.Normalize(mean=mean, std=std), | |
ToTensorV2(), | |
] | |
return A.Compose(trans_list, keypoint_params=A.KeypointParams(format='xy', remove_invisible=False)) | |
def add_functional_group(indigo, mol, debug=False): | |
if random.random() > INDIGO_FUNCTIONAL_GROUP_PROB: | |
return mol | |
# Delete functional group and add a pseudo atom with its abbrv | |
substitutions = [sub for sub in SUBSTITUTIONS] | |
random.shuffle(substitutions) | |
for sub in substitutions: | |
query = indigo.loadSmarts(sub.smarts) | |
matcher = indigo.substructureMatcher(mol) | |
matched_atoms_ids = set() | |
for match in matcher.iterateMatches(query): | |
if random.random() < sub.probability or debug: | |
atoms = [] | |
atoms_ids = set() | |
for item in query.iterateAtoms(): | |
atom = match.mapAtom(item) | |
atoms.append(atom) | |
atoms_ids.add(atom.index()) | |
if len(matched_atoms_ids.intersection(atoms_ids)) > 0: | |
continue | |
abbrv = random.choice(sub.abbrvs) | |
superatom = mol.addAtom(abbrv) | |
for atom in atoms: | |
for nei in atom.iterateNeighbors(): | |
if nei.index() not in atoms_ids: | |
if nei.symbol() == 'H': | |
# indigo won't match explicit hydrogen, so remove them explicitly | |
atoms_ids.add(nei.index()) | |
else: | |
superatom.addBond(nei, nei.bond().bondOrder()) | |
for id in atoms_ids: | |
mol.getAtom(id).remove() | |
matched_atoms_ids = matched_atoms_ids.union(atoms_ids) | |
return mol | |
def add_explicit_hydrogen(indigo, mol): | |
atoms = [] | |
for atom in mol.iterateAtoms(): | |
try: | |
hs = atom.countImplicitHydrogens() | |
if hs > 0: | |
atoms.append((atom, hs)) | |
except: | |
continue | |
if len(atoms) > 0 and random.random() < INDIGO_HYGROGEN_PROB: | |
atom, hs = random.choice(atoms) | |
for i in range(hs): | |
h = mol.addAtom('H') | |
h.addBond(atom, 1) | |
return mol | |
def add_rgroup(indigo, mol, smiles): | |
atoms = [] | |
for atom in mol.iterateAtoms(): | |
try: | |
hs = atom.countImplicitHydrogens() | |
if hs > 0: | |
atoms.append(atom) | |
except: | |
continue | |
if len(atoms) > 0 and '*' not in smiles: | |
if random.random() < INDIGO_RGROUP_PROB: | |
atom_idx = random.choice(range(len(atoms))) | |
atom = atoms[atom_idx] | |
atoms.pop(atom_idx) | |
symbol = random.choice(RGROUP_SYMBOLS) | |
r = mol.addAtom(symbol) | |
r.addBond(atom, 1) | |
return mol | |
def get_rand_symb(): | |
symb = random.choice(ELEMENTS) | |
if random.random() < 0.1: | |
symb += random.choice(string.ascii_lowercase) | |
if random.random() < 0.1: | |
symb += random.choice(string.ascii_uppercase) | |
if random.random() < 0.1: | |
symb = f'({gen_rand_condensed()})' | |
return symb | |
def get_rand_num(): | |
if random.random() < 0.9: | |
if random.random() < 0.8: | |
return '' | |
else: | |
return str(random.randint(2, 9)) | |
else: | |
return '1' + str(random.randint(2, 9)) | |
def gen_rand_condensed(): | |
tokens = [] | |
for i in range(5): | |
if i >= 1 and random.random() < 0.8: | |
break | |
tokens.append(get_rand_symb()) | |
tokens.append(get_rand_num()) | |
return ''.join(tokens) | |
def add_rand_condensed(indigo, mol): | |
atoms = [] | |
for atom in mol.iterateAtoms(): | |
try: | |
hs = atom.countImplicitHydrogens() | |
if hs > 0: | |
atoms.append(atom) | |
except: | |
continue | |
if len(atoms) > 0 and random.random() < INDIGO_CONDENSED_PROB: | |
atom = random.choice(atoms) | |
symbol = gen_rand_condensed() | |
r = mol.addAtom(symbol) | |
r.addBond(atom, 1) | |
return mol | |
def generate_output_smiles(indigo, mol): | |
# TODO: if using mol.canonicalSmiles(), explicit H will be removed | |
smiles = mol.smiles() | |
mol = indigo.loadMolecule(smiles) | |
if '*' in smiles: | |
part_a, part_b = smiles.split(' ', maxsplit=1) | |
part_b = re.search(r'\$.*\$', part_b).group(0)[1:-1] | |
symbols = [t for t in part_b.split(';') if len(t) > 0] | |
output = '' | |
cnt = 0 | |
for i, c in enumerate(part_a): | |
if c != '*': | |
output += c | |
else: | |
output += f'[{symbols[cnt]}]' | |
cnt += 1 | |
return mol, output | |
else: | |
if ' ' in smiles: | |
# special cases with extension | |
smiles = smiles.split(' ')[0] | |
return mol, smiles | |
def add_comment(indigo): | |
if random.random() < INDIGO_COMMENT_PROB: | |
indigo.setOption('render-comment', str(random.randint(1, 20)) + random.choice(string.ascii_letters)) | |
indigo.setOption('render-comment-font-size', random.randint(40, 60)) | |
indigo.setOption('render-comment-alignment', random.choice([0, 0.5, 1])) | |
indigo.setOption('render-comment-position', random.choice(['top', 'bottom'])) | |
indigo.setOption('render-comment-offset', random.randint(2, 30)) | |
def add_color(indigo, mol): | |
if random.random() < INDIGO_COLOR_PROB: | |
indigo.setOption('render-coloring', True) | |
if random.random() < INDIGO_COLOR_PROB: | |
indigo.setOption('render-base-color', random.choice(list(COLORS.values()))) | |
if random.random() < INDIGO_COLOR_PROB: | |
if random.random() < 0.5: | |
indigo.setOption('render-highlight-color-enabled', True) | |
indigo.setOption('render-highlight-color', random.choice(list(COLORS.values()))) | |
if random.random() < 0.5: | |
indigo.setOption('render-highlight-thickness-enabled', True) | |
for atom in mol.iterateAtoms(): | |
if random.random() < 0.1: | |
atom.highlight() | |
return mol | |
def get_graph(mol, image, shuffle_nodes=False, pseudo_coords=False): | |
mol.layout() | |
coords, symbols = [], [] | |
index_map = {} | |
atoms = [atom for atom in mol.iterateAtoms()] | |
if shuffle_nodes: | |
random.shuffle(atoms) | |
for i, atom in enumerate(atoms): | |
if pseudo_coords: | |
x, y, z = atom.xyz() | |
else: | |
x, y = atom.coords() | |
coords.append([x, y]) | |
symbols.append(atom.symbol()) | |
index_map[atom.index()] = i | |
if pseudo_coords: | |
coords = normalize_nodes(np.array(coords)) | |
h, w, _ = image.shape | |
coords[:, 0] = coords[:, 0] * w | |
coords[:, 1] = coords[:, 1] * h | |
n = len(symbols) | |
edges = np.zeros((n, n), dtype=int) | |
for bond in mol.iterateBonds(): | |
s = index_map[bond.source().index()] | |
t = index_map[bond.destination().index()] | |
# 1/2/3/4 : single/double/triple/aromatic | |
edges[s, t] = bond.bondOrder() | |
edges[t, s] = bond.bondOrder() | |
if bond.bondStereo() in [5, 6]: | |
edges[s, t] = bond.bondStereo() | |
edges[t, s] = 11 - bond.bondStereo() | |
graph = { | |
'coords': coords, | |
'symbols': symbols, | |
'edges': edges, | |
'num_atoms': len(symbols) | |
} | |
return graph | |
def generate_indigo_image(smiles, mol_augment=True, default_option=False, shuffle_nodes=False, pseudo_coords=False, | |
include_condensed=True, debug=False): | |
indigo = Indigo() | |
renderer = IndigoRenderer(indigo) | |
indigo.setOption('render-output-format', 'png') | |
indigo.setOption('render-background-color', '1,1,1') | |
indigo.setOption('render-stereo-style', 'none') | |
indigo.setOption('render-label-mode', 'hetero') | |
indigo.setOption('render-font-family', 'Arial') | |
if not default_option: | |
thickness = random.uniform(0.5, 2) # limit the sum of the following two parameters to be smaller than 4 | |
indigo.setOption('render-relative-thickness', thickness) | |
indigo.setOption('render-bond-line-width', random.uniform(1, 4 - thickness)) | |
if random.random() < 0.5: | |
indigo.setOption('render-font-family', random.choice(['Arial', 'Times', 'Courier', 'Helvetica'])) | |
indigo.setOption('render-label-mode', random.choice(['hetero', 'terminal-hetero'])) | |
indigo.setOption('render-implicit-hydrogens-visible', random.choice([True, False])) | |
if random.random() < 0.1: | |
indigo.setOption('render-stereo-style', 'old') | |
if random.random() < 0.2: | |
indigo.setOption('render-atom-ids-visible', True) | |
try: | |
mol = indigo.loadMolecule(smiles) | |
if mol_augment: | |
if random.random() < INDIGO_DEARMOTIZE_PROB: | |
mol.dearomatize() | |
else: | |
mol.aromatize() | |
smiles = mol.canonicalSmiles() | |
add_comment(indigo) | |
mol = add_explicit_hydrogen(indigo, mol) | |
mol = add_rgroup(indigo, mol, smiles) | |
if include_condensed: | |
mol = add_rand_condensed(indigo, mol) | |
mol = add_functional_group(indigo, mol, debug) | |
mol = add_color(indigo, mol) | |
mol, smiles = generate_output_smiles(indigo, mol) | |
buf = renderer.renderToBuffer(mol) | |
img = cv2.imdecode(np.asarray(bytearray(buf), dtype=np.uint8), 1) # decode buffer to image | |
# img = np.repeat(np.expand_dims(img, 2), 3, axis=2) # expand to RGB | |
graph = get_graph(mol, img, shuffle_nodes, pseudo_coords) | |
success = True | |
except Exception: | |
if debug: | |
raise Exception | |
img = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32) | |
graph = {} | |
success = False | |
return img, smiles, graph, success | |
class TrainDataset(Dataset): | |
def __init__(self, args, df, tokenizer, split='train', dynamic_indigo=False): | |
super().__init__() | |
self.df = df | |
self.args = args | |
self.tokenizer = tokenizer | |
if 'file_path' in df.columns: | |
self.file_paths = df['file_path'].values | |
if not self.file_paths[0].startswith(args.data_path): | |
self.file_paths = [os.path.join(args.data_path, path) for path in df['file_path']] | |
self.smiles = df['SMILES'].values if 'SMILES' in df.columns else None | |
self.formats = args.formats | |
self.labelled = (split == 'train') | |
if self.labelled: | |
self.labels = {} | |
for format_ in self.formats: | |
if format_ in ['atomtok', 'inchi']: | |
field = FORMAT_INFO[format_]['name'] | |
if field in df.columns: | |
self.labels[format_] = df[field].values | |
self.transform = get_transforms(args.input_size, | |
augment=(self.labelled and args.augment)) | |
# self.fix_transform = A.Compose([A.Transpose(p=1), A.VerticalFlip(p=1)]) | |
self.dynamic_indigo = (dynamic_indigo and split == 'train') | |
if self.labelled and not dynamic_indigo and args.coords_file is not None: | |
if args.coords_file == 'aux_file': | |
self.coords_df = df | |
self.pseudo_coords = True | |
else: | |
self.coords_df = pd.read_csv(args.coords_file) | |
self.pseudo_coords = False | |
else: | |
self.coords_df = None | |
self.pseudo_coords = args.pseudo_coords | |
def __len__(self): | |
return len(self.df) | |
def image_transform(self, image, coords=[], renormalize=False): | |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # .astype(np.float32) | |
augmented = self.transform(image=image, keypoints=coords) | |
image = augmented['image'] | |
if len(coords) > 0: | |
coords = np.array(augmented['keypoints']) | |
if renormalize: | |
coords = normalize_nodes(coords, flip_y=False) | |
else: | |
_, height, width = image.shape | |
coords[:, 0] = coords[:, 0] / width | |
coords[:, 1] = coords[:, 1] / height | |
coords = np.array(coords).clip(0, 1) | |
return image, coords | |
return image | |
def __getitem__(self, idx): | |
try: | |
return self.getitem(idx) | |
except Exception as e: | |
with open(os.path.join(self.args.save_path, f'error_dataset_{int(time.time())}.log'), 'w') as f: | |
f.write(str(e)) | |
raise e | |
def getitem(self, idx): | |
ref = {} | |
if self.dynamic_indigo: | |
begin = time.time() | |
image, smiles, graph, success = generate_indigo_image( | |
self.smiles[idx], mol_augment=self.args.mol_augment, default_option=self.args.default_option, | |
shuffle_nodes=self.args.shuffle_nodes, pseudo_coords=self.pseudo_coords, | |
include_condensed=self.args.include_condensed) | |
# raw_image = image | |
end = time.time() | |
if idx < 30 and self.args.save_image: | |
path = os.path.join(self.args.save_path, 'images') | |
os.makedirs(path, exist_ok=True) | |
cv2.imwrite(os.path.join(path, f'{idx}.png'), image) | |
if not success: | |
return idx, None, {} | |
image, coords = self.image_transform(image, graph['coords'], renormalize=self.pseudo_coords) | |
graph['coords'] = coords | |
ref['time'] = end - begin | |
if 'atomtok' in self.formats: | |
max_len = FORMAT_INFO['atomtok']['max_len'] | |
label = self.tokenizer['atomtok'].text_to_sequence(smiles, tokenized=False) | |
ref['atomtok'] = torch.LongTensor(label[:max_len]) | |
if 'edges' in self.formats and 'atomtok_coords' not in self.formats and 'chartok_coords' not in self.formats: | |
ref['edges'] = torch.tensor(graph['edges']) | |
if 'atomtok_coords' in self.formats: | |
self._process_atomtok_coords(idx, ref, smiles, graph['coords'], graph['edges'], | |
mask_ratio=self.args.mask_ratio) | |
if 'chartok_coords' in self.formats: | |
self._process_chartok_coords(idx, ref, smiles, graph['coords'], graph['edges'], | |
mask_ratio=self.args.mask_ratio) | |
return idx, image, ref | |
else: | |
file_path = self.file_paths[idx] | |
image = cv2.imread(file_path) | |
if image is None: | |
image = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32) | |
print(file_path, 'not found!') | |
if self.coords_df is not None: | |
h, w, _ = image.shape | |
coords = np.array(eval(self.coords_df.loc[idx, 'node_coords'])) | |
if self.pseudo_coords: | |
coords = normalize_nodes(coords) | |
coords[:, 0] = coords[:, 0] * w | |
coords[:, 1] = coords[:, 1] * h | |
image, coords = self.image_transform(image, coords, renormalize=self.pseudo_coords) | |
else: | |
image = self.image_transform(image) | |
coords = None | |
if self.labelled: | |
smiles = self.smiles[idx] | |
if 'atomtok' in self.formats: | |
max_len = FORMAT_INFO['atomtok']['max_len'] | |
label = self.tokenizer['atomtok'].text_to_sequence(smiles, False) | |
ref['atomtok'] = torch.LongTensor(label[:max_len]) | |
if 'atomtok_coords' in self.formats: | |
if coords is not None: | |
self._process_atomtok_coords(idx, ref, smiles, coords, mask_ratio=0) | |
else: | |
self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1) | |
if 'chartok_coords' in self.formats: | |
if coords is not None: | |
self._process_chartok_coords(idx, ref, smiles, coords, mask_ratio=0) | |
else: | |
self._process_chartok_coords(idx, ref, smiles, mask_ratio=1) | |
if self.args.predict_coords and ('atomtok_coords' in self.formats or 'chartok_coords' in self.formats): | |
smiles = self.smiles[idx] | |
if 'atomtok_coords' in self.formats: | |
self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1) | |
if 'chartok_coords' in self.formats: | |
self._process_chartok_coords(idx, ref, smiles, mask_ratio=1) | |
return idx, image, ref | |
def _process_atomtok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0): | |
max_len = FORMAT_INFO['atomtok_coords']['max_len'] | |
tokenizer = self.tokenizer['atomtok_coords'] | |
if smiles is None or type(smiles) is not str: | |
smiles = "" | |
label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio) | |
ref['atomtok_coords'] = torch.LongTensor(label[:max_len]) | |
indices = [i for i in indices if i < max_len] | |
ref['atom_indices'] = torch.LongTensor(indices) | |
if tokenizer.continuous_coords: | |
if coords is not None: | |
ref['coords'] = torch.tensor(coords) | |
else: | |
ref['coords'] = torch.ones(len(indices), 2) * -1. | |
if edges is not None: | |
ref['edges'] = torch.tensor(edges)[:len(indices), :len(indices)] | |
else: | |
if 'edges' in self.df.columns: | |
edge_list = eval(self.df.loc[idx, 'edges']) | |
n = len(indices) | |
edges = torch.zeros((n, n), dtype=torch.long) | |
for u, v, t in edge_list: | |
if u < n and v < n: | |
if t <= 4: | |
edges[u, v] = t | |
edges[v, u] = t | |
else: | |
edges[u, v] = t | |
edges[v, u] = 11 - t | |
ref['edges'] = edges | |
else: | |
ref['edges'] = torch.ones(len(indices), len(indices), dtype=torch.long) * (-100) | |
def _process_chartok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0): | |
max_len = FORMAT_INFO['chartok_coords']['max_len'] | |
tokenizer = self.tokenizer['chartok_coords'] | |
if smiles is None or type(smiles) is not str: | |
smiles = "" | |
label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio) | |
ref['chartok_coords'] = torch.LongTensor(label[:max_len]) | |
indices = [i for i in indices if i < max_len] | |
ref['atom_indices'] = torch.LongTensor(indices) | |
if tokenizer.continuous_coords: | |
if coords is not None: | |
ref['coords'] = torch.tensor(coords) | |
else: | |
ref['coords'] = torch.ones(len(indices), 2) * -1. | |
if edges is not None: | |
ref['edges'] = torch.tensor(edges)[:len(indices), :len(indices)] | |
else: | |
if 'edges' in self.df.columns: | |
edge_list = eval(self.df.loc[idx, 'edges']) | |
n = len(indices) | |
edges = torch.zeros((n, n), dtype=torch.long) | |
for u, v, t in edge_list: | |
if u < n and v < n: | |
if t <= 4: | |
edges[u, v] = t | |
edges[v, u] = t | |
else: | |
edges[u, v] = t | |
edges[v, u] = 11 - t | |
ref['edges'] = edges | |
else: | |
ref['edges'] = torch.ones(len(indices), len(indices), dtype=torch.long) * (-100) | |
class AuxTrainDataset(Dataset): | |
def __init__(self, args, train_df, aux_df, tokenizer): | |
super().__init__() | |
self.train_dataset = TrainDataset(args, train_df, tokenizer, dynamic_indigo=args.dynamic_indigo) | |
self.aux_dataset = TrainDataset(args, aux_df, tokenizer, dynamic_indigo=False) | |
def __len__(self): | |
return len(self.train_dataset) + len(self.aux_dataset) | |
def __getitem__(self, idx): | |
if idx < len(self.train_dataset): | |
return self.train_dataset[idx] | |
else: | |
return self.aux_dataset[idx - len(self.train_dataset)] | |
def pad_images(imgs): | |
# B, C, H, W | |
max_shape = [0, 0] | |
for img in imgs: | |
for i in range(len(max_shape)): | |
max_shape[i] = max(max_shape[i], img.shape[-1 - i]) | |
stack = [] | |
for img in imgs: | |
pad = [] | |
for i in range(len(max_shape)): | |
pad = pad + [0, max_shape[i] - img.shape[-1 - i]] | |
stack.append(F.pad(img, pad, value=0)) | |
return torch.stack(stack) | |
def bms_collate(batch): | |
ids = [] | |
imgs = [] | |
batch = [ex for ex in batch if ex[1] is not None] | |
formats = list(batch[0][2].keys()) | |
seq_formats = [k for k in formats if | |
k in ['atomtok', 'inchi', 'nodes', 'atomtok_coords', 'chartok_coords', 'atom_indices']] | |
refs = {key: [[], []] for key in seq_formats} | |
for ex in batch: | |
ids.append(ex[0]) | |
imgs.append(ex[1]) | |
ref = ex[2] | |
for key in seq_formats: | |
refs[key][0].append(ref[key]) | |
refs[key][1].append(torch.LongTensor([len(ref[key])])) | |
# Sequence | |
for key in seq_formats: | |
# this padding should work for atomtok_with_coords too, each of which has shape (length, 4) | |
refs[key][0] = pad_sequence(refs[key][0], batch_first=True, padding_value=PAD_ID) | |
refs[key][1] = torch.stack(refs[key][1]).reshape(-1, 1) | |
# Time | |
# if 'time' in formats: | |
# refs['time'] = [ex[2]['time'] for ex in batch] | |
# Coords | |
if 'coords' in formats: | |
refs['coords'] = pad_sequence([ex[2]['coords'] for ex in batch], batch_first=True, padding_value=-1.) | |
# Edges | |
if 'edges' in formats: | |
edges_list = [ex[2]['edges'] for ex in batch] | |
max_len = max([len(edges) for edges in edges_list]) | |
refs['edges'] = torch.stack( | |
[F.pad(edges, (0, max_len - len(edges), 0, max_len - len(edges)), value=-100) for edges in edges_list], | |
dim=0) | |
return ids, pad_images(imgs), refs | |