Spaces:
Sleeping
Sleeping
import numpy as np | |
from PIL import Image | |
import cv2 | |
import layoutparser as lp | |
from rdkit import Chem | |
from rdkit.Chem import Draw | |
from rdkit.Chem import rdDepictor | |
rdDepictor.SetPreferCoordGen(True) | |
from rdkit.Chem.Draw import IPythonConsole | |
from rdkit.Chem import AllChem | |
import re | |
import copy | |
# MolScribe bond-type name -> integer code written into the adjacency ("edges") matrices.
BOND_TO_INT = {
    bond_name: code
    for code, bond_name in enumerate(
        ("", "single", "double", "triple", "aromatic", "solid wedge", "dashed wedge")
    )
}
# Atom labels treated as R-group placeholders. Every label is listed both bare
# (e.g. 'R1', '1*') and wrapped in one extra pair of brackets (e.g. '[R1]', '[1*]').
RGROUP_SYMBOLS = (
    ['R']
    + [f'R{n}' for n in range(1, 13)]
    + ['Ra', 'Rb', 'Rc', 'Rd', 'Rf', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar', 'Ar1', 'Ar2', 'Ari', "R'"]
    + [f'{n}*' for n in range(1, 13)]
    + [f'[{letter}*]' for letter in 'abcd']
)
RGROUP_SYMBOLS += [f'[{symbol}]' for symbol in RGROUP_SYMBOLS]
# SMILES spellings of R-group dummy atoms (numbered/lettered dummies, bare '*', and '[Rf]').
RGROUP_SMILES = (
    [f'[{n}*]' for n in range(1, 13)]
    + [f'[{letter}*]' for letter in 'abcd']
    + ['*', '[Rf]']
)
def get_figures_from_pages(pages, pdfparser):
    """Detect 'Figure' layout blocks on every page and crop them out.

    Args:
        pages: sequence of page images (anything ``np.asarray`` accepts).
        pdfparser: layout detector exposing ``detect(image)`` whose results have
            ``.type`` and ``.crop_image(image)`` (presumably a layoutparser model).

    Returns:
        list of ``{'image': PIL.Image, 'page': page_index}`` dicts, in page order and,
        within a page, in detection order.
    """
    figures = []
    for page_number, page in enumerate(pages):
        page_array = np.asarray(page)
        figure_blocks = lp.Layout(
            [block for block in pdfparser.detect(page_array) if block.type == "Figure"]
        )
        figures.extend(
            {'image': Image.fromarray(block.crop_image(page_array)), 'page': page_number}
            for block in figure_blocks
        )
    return figures
def clean_bbox_output(figures, bboxes):
    """Crop every '[Mol]' detection out of its figure.

    Args:
        figures: images as numpy arrays of shape (H, W) or (H, W, C); ``figures[i]``
            belongs to ``bboxes[i]``.
        bboxes: one list of detections per figure. Each detection is a dict with
            'category', 'bbox' = (x1, y1, x2, y2) and 'score'. The bbox is multiplied
            by the image width/height, so it is expected in relative [0, 1] coordinates.
            Only '[Mol]' detections are kept.

    Returns:
        (results, cropped, references):
        results -- one dict per figure: ``{'image': figure, 'molecules': [mol, ...]}``
        where each mol is ``{'bbox', 'score', 'image'}``;
        cropped -- all molecule crops, flattened across figures;
        references -- the same mol dicts as in ``results``, flattened in the same order
        as ``cropped`` (presumably so per-crop predictions can be written back).
    """
    results = []
    cropped = []
    references = []
    for i, output in enumerate(bboxes):
        figure = figures[i]
        molecules = [elt for elt in output if elt['category'] == '[Mol]']
        data = {'image': figure, 'molecules': []}
        results.append(data)
        if not molecules:
            continue
        # Read once per figure; [:2] also accepts single-channel (H, W) images.
        height, width = figure.shape[:2]
        for elt in molecules:
            bbox, score = elt['bbox'], elt['score']
            x1, y1, x2, y2 = bbox
            cropped_img = figure[int(y1 * height):int(y2 * height), int(x1 * width):int(x2 * width)]
            cur_mol = {
                'bbox': bbox,
                'score': score,
                'image': cropped_img,
            }
            cropped.append(cropped_img)
            data['molecules'].append(cur_mol)
            references.append(cur_mol)
    return results, cropped, references
def convert_to_pil(image):
    """Return *image* as a PIL image.

    numpy arrays are treated as OpenCV-style BGR images. They are converted to RGB
    channel order and wrapped with ``Image.fromarray``. Any other object (e.g. an
    existing PIL image) is returned unchanged.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    return image
def convert_to_cv2(image):
    """Return *image* as a numpy array in OpenCV (BGR) channel order.

    numpy arrays are assumed to be BGR already and are returned unchanged. Anything
    else (e.g. an RGB PIL image) is converted with ``np.array`` and its channels are
    swapped from RGB to BGR.
    """
    if isinstance(image, np.ndarray):
        return image
    # RGB2BGR and BGR2RGB are the same channel swap in OpenCV; RGB2BGR names the
    # actual direction of this conversion.
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
def replace_rgroups_in_figure(figures, results, coref_results, molscribe, batch_size=16):
    """Append one reaction per inline R-group assignment found in a figure's labels.

    For each figure, the texts of '[Idt]' coreference boxes are scanned for assignments
    such as "R1 = Me" (R, X or Y plus an optional digit, '=', then a word). Duplicate
    (name, group) pairs are ignored. If any assignment is found and the figure has at
    least one reaction, the first reaction is re-recognised once per assignment, with
    that R-group label replaced by the assigned group. Each variant is appended to
    ``result['reactions']`` with the original conditions.

    Args:
        figures: per-figure dicts; ``figure['figure']['image']`` is the figure image.
        results: per-figure dicts holding a 'reactions' list; updated in place.
        coref_results: per-figure coreference output holding a 'bboxes' list.
        molscribe: MolScribe model forwarded to get_atoms_and_bonds / get_replaced_reaction.
        batch_size: batch size for MolScribe prediction.

    Returns:
        ``results`` (the same list, mutated in place).
    """
    # Raw string: '\d' / '\w' are regex escapes, not (invalid) Python string escapes.
    pattern = re.compile(r'(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')
    for figure, result, corefs in zip(figures, results, coref_results):
        r_groups = []
        seen_pairs = set()
        for bbox in corefs['bboxes']:
            if bbox['category'] != '[Idt]':
                continue
            for text in bbox['text']:
                res = pattern.search(text)
                if res is None:
                    continue
                pair = (res.group('name'), res.group('group'))
                if pair in seen_pairs:
                    continue
                seen_pairs.add(pair)
                r_groups.append({pair[0]: pair[1]})
        if not (r_groups and result['reactions']):
            continue
        r_group_names = {name for name, _ in seen_pairs}
        orig_reaction = result['reactions'][0]
        graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size)
        # graph index -> [(R-group name, atom index)] for atoms whose bracketed symbol is a known name
        relevant_locs = {}
        for i, graph in enumerate(graphs):
            relevant_locs[i] = [
                (atom[1:-1], j)
                for j, atom in enumerate(graph['chartok_coords']['symbols'])
                if atom[1:-1] in r_group_names
            ]
        for r_group in r_groups:
            reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_group, molscribe)
            result['reactions'].append({
                'reactants': reaction['reactants'][:],
                'conditions': orig_reaction['conditions'][:],
                'products': reaction['products'][:],
            })
    return results
def process_tables(figures, results, molscribe, batch_size=16):
    """Expand each figure's template reaction with the rows of its parsed table.

    For every figure, ``result['page']`` is set from ``figure['page']``. When the figure
    has table content and at least one reaction, the first reaction is the template
    (a warning is printed if there are several). For each table row:

    * columns not tagged 'alkyl group' become '[Table]' condition entries;
    * 'alkyl group' columns whose cell matches ``r_group_pattern`` supply an R-group
      value keyed by the column header.

    Rows that supply at least one R-group produce a new reaction, appended to
    ``result['reactions']``, with the substituted molecules. Other rows only
    contribute their expanded conditions. At the end, the template's 'conditions'
    becomes a list whose first element is the original conditions list, followed by
    those rows' expanded conditions.

    Args:
        figures: per-figure dicts with 'page', 'table' (whose 'content' may be None) and
            'figure' ('image').
        results: per-figure dicts holding a 'reactions' list; updated in place.
        molscribe: MolScribe model forwarded to the graph helpers.
        batch_size: batch size for MolScribe prediction.

    Returns:
        ``results`` (the same list, mutated in place).
    """
    r_group_pattern = re.compile(r'^(\w+-)?(?P<group>[\w-]+)( \(\w+\))?$')
    for figure, result in zip(figures, results):
        result['page'] = figure['page']
        content = figure['table']['content']
        if content is None:
            continue
        if len(result['reactions']) > 1:
            print("Warning: multiple reactions detected for table")
        elif len(result['reactions']) == 0:
            continue
        orig_reaction = result['reactions'][0]
        graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size)
        relevant_locs = find_relevant_groups(graphs, content['columns'])
        conditions_to_extend = []
        for row in content['rows']:
            r_groups = {}
            expanded_conditions = orig_reaction['conditions'][:]
            replaced = False
            for col, entry in zip(content['columns'], row):
                if col['tag'] != 'alkyl group':
                    expanded_conditions.append({
                        'category': '[Table]',
                        'text': entry['text'],
                        'tag': col['tag'],
                        'header': col['text'],
                    })
                else:
                    found = r_group_pattern.match(entry['text'])
                    if found is not None:
                        r_groups[col['text']] = found.group('group')
                        replaced = True
            if replaced:
                # Only re-run MolScribe on the substituted graphs when the row actually
                # supplied an R-group; for other rows the result would be discarded.
                reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_groups, molscribe)
                result['reactions'].append({
                    'reactants': reaction['reactants'][:],
                    'conditions': expanded_conditions,
                    'products': reaction['products'][:],
                })
            else:
                conditions_to_extend.append(expanded_conditions)
        orig_reaction['conditions'] = [orig_reaction['conditions']]
        orig_reaction['conditions'].extend(conditions_to_extend)
    return results
def get_atoms_and_bonds(image, reaction, molscribe, batch_size=16):
    """Recognise every '[Mol]' entity of *reaction* and return its atom/bond graph.

    Each '[Mol]' dict in ``reaction`` (a mapping of role -> list of entities) has its
    relative ``bbox`` cropped from *image*. All crops are then passed to
    ``molscribe.predict_images(..., return_atoms_bonds=True)`` in one call.

    Returns:
        list of graph dicts, one per '[Mol]' entity in iteration order, each with
        'image' (the crop), 'chartok_coords' ({'coords': [(x, y)], 'symbols': [...]}),
        'edges' (square matrix of BOND_TO_INT codes, symmetric) and 'key'
        ((role, index within that role's list)).
    """
    image = convert_to_cv2(image)
    crops = []
    graphs = []
    for role, entities in reaction.items():
        for position, entity in enumerate(entities):
            # Nested lists (e.g. grouped conditions) and non-molecule entities are skipped.
            if type(entity) is not dict or entity['category'] != '[Mol]':
                continue
            x1, y1, x2, y2 = entity['bbox']
            height, width, _ = image.shape
            crop = image[int(y1 * height):int(y2 * height), int(x1 * width):int(x2 * width)]
            crops.append(crop)
            graphs.append({
                'image': crop,
                'chartok_coords': {
                    'coords': [],
                    'symbols': [],
                },
                'edges': [],
                'key': (role, position),
            })
    predictions = molscribe.predict_images(crops, return_atoms_bonds=True, batch_size=batch_size)
    for prediction, graph in zip(predictions, graphs):
        atoms = prediction['atoms']
        chartok = graph['chartok_coords']
        for atom in atoms:
            chartok['coords'].append((atom['x'], atom['y']))
            chartok['symbols'].append(atom['atom_symbol'])
        size = len(atoms)
        adjacency = [[0] * size for _ in range(size)]
        graph['edges'] = adjacency
        for bond in prediction['bonds']:
            a, b = bond['endpoint_atoms']
            adjacency[a][b] = BOND_TO_INT[bond['bond_type']]
            adjacency[b][a] = BOND_TO_INT[bond['bond_type']]
    return graphs
def find_relevant_groups(graphs, columns):
    """Locate atoms whose symbol is an R-group named by an 'alkyl group' table column.

    A column with text 'R1' and tag 'alkyl group' matches atoms whose symbol is '[R1]'.

    Returns:
        dict graph index -> list of (column name, atom index) pairs, with an entry
        (possibly empty) for every graph.
    """
    wanted = {f"[{col['text']}]" for col in columns if col['tag'] == 'alkyl group'}
    return {
        graph_idx: [
            (symbol[1:-1], atom_idx)
            for atom_idx, symbol in enumerate(graph['chartok_coords']['symbols'])
            if symbol in wanted
        ]
        for graph_idx, graph in enumerate(graphs)
    }
def get_replaced_reaction(orig_reaction, graphs, relevant_locs, mappings, molscribe):
    """Return a copy of *orig_reaction* with R-group atoms substituted and molecules re-derived.

    Args:
        orig_reaction: role -> list of entities (dicts, or nested lists of dicts).
        graphs: graphs from ``get_atoms_and_bonds``; their 'key' locates the entity.
        relevant_locs: graph index -> [(R-group name, atom index)].
        mappings: R-group name -> replacement atom symbol/group text.
        molscribe: object providing ``convert_graph_to_output(graphs, images)``.

    Returns:
        A new reaction dict. '[Mol]' entities are shallow-copied, and each graph's entity
        gets fresh 'smiles' and 'molfile'. Other entities are shared with *orig_reaction*.
        Neither *orig_reaction* nor *graphs* is modified.
    """
    graph_copy = []
    for g in graphs:
        coords_src = g['chartok_coords']
        graph_copy.append({
            'image': g['image'],
            'chartok_coords': {
                'coords': coords_src['coords'][:],
                'symbols': coords_src['symbols'][:],
            },
            'edges': g['edges'][:],
            'key': g['key'],
        })
    for graph_idx, atoms in relevant_locs.items():
        for label, atom_idx in atoms:
            if label not in mappings:
                continue
            graph_copy[graph_idx]['chartok_coords']['symbols'][atom_idx] = mappings[label]

    def copy_entity(entity):
        # Molecules are copied because their SMILES/molfile are overwritten below.
        if entity['category'] == '[Mol]':
            return {key: value for key, value in entity.items()}
        return entity

    reaction_copy = {}
    for role, entities in orig_reaction.items():
        copied = []
        for entity in entities:
            if type(entity) == list:
                copied.append([copy_entity(e) for e in entity])
            else:
                copied.append(copy_entity(entity))
        reaction_copy[role] = copied
    for graph in graph_copy:
        output = molscribe.convert_graph_to_output([graph], [graph['image']])
        role, position = graph['key']
        target = reaction_copy[role][position]
        target['smiles'] = output[0]['smiles']
        target['molfile'] = output[0]['molfile']
    return reaction_copy
def get_sites(tar, ref, ref_site = False):
    """Return the template atoms of *tar* that border atoms outside the template.

    *ref* is aligned onto *tar* with ``rdDepictor.GenerateDepictionMatching2DStructure``.
    Both molecules first get fresh 2D coordinates, so both are modified in place. A
    "site" is a matched atom of *tar* with at least one neighbour that is not part of
    the match, i.e. a point where something is attached to the template.

    Args:
        tar: RDKit molecule containing the template.
        ref: RDKit template molecule.
        ref_site: if True, report each site by its atom index in *ref*; otherwise
            (default) by its atom index in *tar*.

    Returns:
        list of unique atom indices, in order of first discovery.
    """
    rdDepictor.Compute2DCoords(ref)
    rdDepictor.Compute2DCoords(tar)
    idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(tar, ref)
    # Pairs are treated as (ref index, tar index). Keep the first pair per tar atom,
    # matching the original list.index lookup, but with O(1) access.
    tar_to_ref = {}
    for ref_idx, tar_idx in idx_pair:
        tar_to_ref.setdefault(tar_idx, ref_idx)
    sites = []
    for atom in tar.GetAtoms():
        if atom.GetIdx() in tar_to_ref:
            continue
        for neighbor in atom.GetNeighbors():
            tar_idx = neighbor.GetIdx()
            if tar_idx not in tar_to_ref:
                continue
            # Previously both branches appended the ref index, so ref_site had no effect,
            # and the duplicate check compared tar indices against stored ref indices.
            site = tar_to_ref[tar_idx] if ref_site else tar_idx
            if site not in sites:
                sites.append(site)
    return sites
def get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = None):
    """Map atom indices of *prod_mol* onto a dummy-atom query built from *prod_smiles*.

    In both *prod_mol* (re-serialised to SMILES) and *prod_smiles*, every labelled
    R-group dummy from RGROUP_SMILES except bare '*' is collapsed to '*'. Both are
    turned into RDKit query molecules whose dummies match any atom
    (``makeDummiesQueries``). *prod_mol* is aligned onto its own template with
    ``GenerateDepictionMatching2DStructure``, and that template is substructure-matched
    inside the *prod_smiles* query. The first match under which every atom index in
    *r_sites_reversed* lands on an atom whose symbol is in RGROUP_SMILES is used. If no
    match satisfies this, the last match is used.

    Note: 2D coordinates of *prod_mol* are recomputed in place.

    Args:
        prod_mol: RDKit molecule whose atom indices are to be mapped.
        prod_smiles: SMILES from which the returned query molecule is built.
        r_sites_reversed: container (e.g. dict keyed by atom index) of *prod_mol* atom
            indices that must map onto R-group dummies. ``None`` means no constraint.

    Returns:
        (prod_mol_to_query, prod_template_mol_query): dict from *prod_mol* atom index to
        query atom index, and the query molecule built from *prod_smiles*.

    Raises:
        ValueError: if the template of *prod_mol* has no match in the query
            (previously this surfaced as an UnboundLocalError).
    """
    if r_sites_reversed is None:
        r_sites_reversed = {}
    prod_template_intermediate = Chem.MolToSmiles(prod_mol)
    prod_template = prod_smiles
    for r in RGROUP_SMILES:
        if r != '*':
            prod_template = prod_template.replace(r, '*')
            prod_template_intermediate = prod_template_intermediate.replace(r, '*')
    prod_template_intermediate_mol = Chem.MolFromSmiles(prod_template_intermediate)
    prod_template_mol = Chem.MolFromSmiles(prod_template)
    p = Chem.AdjustQueryParameters.NoAdjustments()
    p.makeDummiesQueries = True
    prod_template_mol_query = Chem.AdjustQueryProperties(prod_template_mol, p)
    prod_template_intermediate_mol_query = Chem.AdjustQueryProperties(prod_template_intermediate_mol, p)
    rdDepictor.Compute2DCoords(prod_mol)
    rdDepictor.Compute2DCoords(prod_template_mol_query)
    rdDepictor.Compute2DCoords(prod_template_intermediate_mol_query)
    # Pairs are treated as (template atom, prod_mol atom); invert to look up by prod_mol index.
    idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(prod_mol, prod_template_intermediate_mol_query)
    prod_mol_to_intermediate = {b: a for a, b in idx_pair}
    # uniquify=False: each atom ordering of a match is a distinct candidate mapping.
    substructs = prod_template_mol_query.GetSubstructMatches(prod_template_intermediate_mol_query, uniquify=False)
    if not substructs:
        raise ValueError("get_atom_mapping: product template not found in query built from prod_smiles")
    for substruct in substructs:
        # position k of a match is the query atom matched by template atom k
        intermediate_to_query = dict(enumerate(substruct))
        prod_mol_to_query = {a: intermediate_to_query[prod_mol_to_intermediate[a]] for a in prod_mol_to_intermediate}
        # Evaluate every site (no short-circuit), as the original loop did.
        r_sites_ok = [
            prod_template_mol_query.GetAtomWithIdx(prod_mol_to_query[i]).GetSymbol() in RGROUP_SMILES
            for i in r_sites_reversed
        ]
        if all(r_sites_ok):
            break
    return prod_mol_to_query, prod_template_mol_query
def clean_corefs(coref_results_dict, idx):
    """Add OCR-corrected variants of labels that should carry a letter suffix.

    A label list counts as clean when some label matches ``<idx><letters>`` (e.g. '3a'
    for idx '3'). For every other label list, each label that contains ``idx`` followed
    by a digit commonly misread for a letter gets a corrected label appended. Only the
    first applicable rule per label is used, and the rules are checked in order:
    '1' -> 'l', '0' -> 'o', '5' -> 's', '9' -> 'g'. For example, '31' yields an extra
    '3l'.

    Args:
        coref_results_dict: maps molecule SMILES -> list of parsed label strings;
            the lists are extended in place.
        idx: label number (a string of digits, possibly empty).

    Returns:
        None; *coref_results_dict* is modified in place.
    """
    label_pattern = re.compile(rf'{re.escape(idx)}[a-zA-Z]+')
    ocr_fixes = (('1', 'l'), ('0', 'o'), ('5', 's'), ('9', 'g'))
    for labels in coref_results_dict.values():
        if any(label_pattern.search(label) for label in labels):
            continue
        # Collect first, then extend: never mutate the list while iterating over it.
        additions = []
        for label in labels:
            for digit, letter in ocr_fixes:
                if idx + digit in label:
                    additions.append(idx + letter)
                    break
        labels.extend(additions)
def expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe):
    """Rebuild a coreferenced molecule with one R-group label replaced by a concrete group.

    Args:
        res: regex match with named groups 'name' (the R-group label) and 'group'
            (the substituent text that replaces it).
        coref_smiles_to_graphs: maps SMILES -> coreference bbox dict holding MolScribe
            'atoms' (with 'x', 'y', 'atom_symbol') and 'bonds' (with 'endpoint_atoms',
            'bond_type').
        other_prod: SMILES key of the molecule to rebuild.
        molscribe: object providing ``convert_graph_to_output(graphs, images)``.

    Returns:
        RDKit molecule parsed from the regenerated SMILES (``None`` if RDKit cannot
        parse it).
    """
    name = res.group('name')
    group = res.group('group')
    atoms = coref_smiles_to_graphs[other_prod]['atoms']
    bonds = coref_smiles_to_graphs[other_prod]['bonds']
    coords = []
    symbols = []
    for atom in atoms:
        coords.append((atom['x'], atom['y']))
        symbols.append(atom['atom_symbol'])
    size = len(atoms)
    edges = [[0] * size for _ in range(size)]
    for bond in bonds:
        a, b = bond['endpoint_atoms']
        edges[a][b] = BOND_TO_INT[bond['bond_type']]
        edges[b][a] = BOND_TO_INT[bond['bond_type']]
    # Symbols are bracketed (e.g. '[R1]'); every atom carrying the assigned label is swapped.
    symbols = [group if symbol[1:-1] == name else symbol for symbol in symbols]
    graph = {
        'image': None,
        'chartok_coords': {
            'coords': coords,
            'symbols': symbols,
        },
        'edges': edges,
        'key': None,
    }
    output = molscribe.convert_graph_to_output([graph], [graph['image']])
    return Chem.MolFromSmiles(output[0]['smiles'])
def get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn):
    """Extract the R-group substituents of one product variant and graft them onto the reactants.

    ``other_prod_mol`` is substructure-matched against the template query. Every matched
    atom that is not an R-group site is deleted, which ideally leaves one fragment per
    remaining R-group. Each fragment contains the atom that matched the R-group site;
    that atom is where the fragment is bonded to a reactant. Each reactant's R-group
    placeholder atoms are then replaced by the corresponding fragments. One tuple
    ``(reactant_smiles_list, [product_smiles], parsed)`` is appended to ``toreturn``.

    Args:
        other_prod_mol: RDKit molecule of the concrete product variant.
        query: one tuple from ``query_enumeration``: (template query mol,
            {template atom index: R-group label}, set of labels to fill with hydrogen,
            number of remaining R-group sites).
        reactant_mols: RDKit molecules of the template reactants (not modified here).
        reactant_information: reactant index -> list of
            [R-group label, placeholder atom index, index of the atom bonded to it].
            For single-atom reactants only the label is used.
        parsed: label text stored in the result tuple.
        toreturn: output list; appended to in place.

    Returns:
        True if the template matched ``other_prod_mol`` at least once (a result was
        appended), False if there was no match.

    NOTE(review): if no match leaves exactly ``num_r_groups`` fragments, the loop falls
    through and the last match tried is used anyway.
    """
    prod_template_mol_query, r_sites_reversed_new, h_sites, num_r_groups = query
    # Get the substructure matches. uniquify is False because the atom order of a match
    # matters here: different orderings assign the R-group sites to different atoms.
    substructs = other_prod_mol.GetSubstructMatches(prod_template_mol_query, uniquify = False)
    # For each match, map the query atoms onto other_prod_mol and delete every matched atom
    # that is not an R-group site. Then check whether the number of remaining fragments
    # equals the number of R-groups; this is how the correct match is chosen.
    if len(substructs) >= 1:
        for substruct in substructs:
            # query atom index -> other_prod_mol atom index
            query_to_other = {a:b for a,b in enumerate(substruct)}
            other_to_query = {query_to_other[i]:i for i in query_to_other}  # currently unused
            editable = Chem.EditableMol(other_prod_mol)
            # entries are [current atom index of the R-group site, R-group label]
            r_site_correspondence = []
            for r in r_sites_reversed_new:
                #get its id in substruct
                substruct_id = query_to_other[r]
                r_site_correspondence.append([substruct_id, r_sites_reversed_new[r]])
            # Delete in descending index order so pending deletions keep valid indices.
            # R-site indices above a deleted atom shift down by one.
            for idx in tuple(sorted(substruct, reverse = True)):
                if idx not in [query_to_other[i] for i in r_sites_reversed_new]:
                    editable.RemoveAtom(idx)
                    for r_site in r_site_correspondence:
                        if idx < r_site[0]:
                            r_site[0]-=1
            other_prod_removed = editable.GetMol()
            if len(Chem.GetMolFrags(other_prod_removed, asMols = False)) == num_r_groups:
                break
        # Order the sites by their (shifted) atom index in other_prod_removed.
        r_site_correspondence.sort(key = lambda x: x[0])
        f = []   # filled by GetMolFrags: fragment index of each atom
        ff = []  # filled by GetMolFrags: atom indices belonging to each fragment
        frags = Chem.GetMolFrags(other_prod_removed, asMols = True, frags = f, fragsMolAtomMapping = ff)
        # r_group_information maps R-group label --> (fragment molecule for that R-group,
        # index within the fragment of the atom that gets bonded to the reactant)
        r_group_information = {}
        for idx, r_site in enumerate(r_site_correspondence):
            r_group_information[r_site[1]]= (frags[f[r_site[0]]], ff[f[r_site[0]]].index(r_site[0]))
        # Labels in h_sites were dropped from this template variant; substitute a lone hydrogen.
        for r_site in h_sites:
            r_group_information[r_site] = (Chem.MolFromSmiles('[H]'), 0)
        # Now modify every reactant according to the R-groups found: for each reactant,
        # disconnect its R-group placeholder and connect the R-group fragment instead.
        modify_reactants = copy.deepcopy(reactant_mols)
        modified_reactant_smiles = []
        for reactant_idx in reactant_information:
            if len(reactant_information[reactant_idx]) == 0:
                # No R-groups in this reactant: keep it unchanged.
                modified_reactant_smiles.append(Chem.MolToSmiles(modify_reactants[reactant_idx]))
            else:
                combined = reactant_mols[reactant_idx]
                if combined.GetNumAtoms() == 1:
                    # The reactant is a bare R-group: it becomes the fragment of its first R-group.
                    r_group, _, _ = reactant_information[reactant_idx][0]
                    modified_reactant_smiles.append(Chem.MolToSmiles(r_group_information[r_group][0]))
                else:
                    # Append every R-group fragment after the reactant's own atoms ...
                    for r_group, r_index, connect_index in reactant_information[reactant_idx]:
                        combined = Chem.CombineMols(combined, r_group_information[r_group][0])
                    editable = Chem.EditableMol(combined)
                    # ... in the same order, so atomIdxAdder tracks where each fragment starts.
                    atomIdxAdder = reactant_mols[reactant_idx].GetNumAtoms()
                    for r_group, r_index, connect_index in reactant_information[reactant_idx]:
                        # Detach the placeholder and bond its neighbour to the fragment's attachment atom.
                        Chem.EditableMol.RemoveBond(editable, r_index, connect_index)
                        Chem.EditableMol.AddBond(editable, connect_index, atomIdxAdder + r_group_information[r_group][1], Chem.BondType.SINGLE)
                        atomIdxAdder += r_group_information[r_group][0].GetNumAtoms()
                    # Remove the placeholder atoms, highest index first so the rest stay valid.
                    r_indices = [i[1] for i in reactant_information[reactant_idx]]
                    r_indices.sort(reverse = True)
                    for r_index in r_indices:
                        Chem.EditableMol.RemoveAtom(editable, r_index)
                    # SMILES round-trip (presumably to sanitise/canonicalise the edited molecule).
                    modified_reactant_smiles.append(Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(editable.GetMol()))))
        toreturn.append((modified_reactant_smiles, [Chem.MolToSmiles(other_prod_mol)], parsed))
        return True
    else:
        return False
def query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups):
    """Build product-template variants in which subsets of R-group sites are dropped.

    ``generate_subsets(num_r_groups)`` supplies the subsets, in that order, starting with
    the empty subset (the unmodified template). For each subset, the R-group atoms at
    those positions are deleted from a copy of ``prod_template_mol_query``. The labels
    of the deleted sites are returned as "H sites"; ``get_r_group_frags_and_substitute``
    fills those with hydrogen.

    Args:
        prod_template_mol_query: RDKit query molecule of the product template.
        r_sites_reversed_new: template atom index -> R-group label.
        num_r_groups: number of R-group sites; assumed to equal
            len(r_sites_reversed_new) — TODO confirm with callers.

    Returns:
        list of (edited_mol, new_r_sites, new_h_sites, remaining) tuples, one per subset:
        new_r_sites maps each kept site's atom index in edited_mol to its label,
        new_h_sites is the set of labels of the deleted sites, and
        remaining = num_r_groups - len(subset).
    """
    subsets = generate_subsets(num_r_groups)
    toreturn = []
    for subset in subsets:
        # [atom index, label] per R-group site, sorted by atom index. Rebuilt for every
        # subset because the indices are shifted in place below.
        r_sites_list = [[i, r_sites_reversed_new[i]] for i in r_sites_reversed_new]
        r_sites_list.sort(key = lambda x: x[0])
        to_edit = Chem.EditableMol(prod_template_mol_query)
        # Detach each selected R-group atom from its first neighbour.
        for entry in subset:
            pos = r_sites_list[entry][0]
            Chem.EditableMol.RemoveBond(to_edit, r_sites_list[entry][0], prod_template_mol_query.GetAtomWithIdx(r_sites_list[entry][0]).GetNeighbors()[0].GetIdx())
        # Delete the selected atoms. generate_subsets lists positions in descending order
        # and r_sites_list is sorted ascending, so atoms are removed highest index first
        # and the remaining positions stay valid.
        for entry in subset:
            pos = r_sites_list[entry][0]
            Chem.EditableMol.RemoveAtom(to_edit, pos)
        edited = to_edit.GetMol()
        # Each deletion shifts the indices of all later (higher-index) sites down by one.
        for entry in subset:
            for i in range(entry + 1, num_r_groups):
                r_sites_list[i][0]-=1
        new_r_sites = {}
        new_h_sites = set()
        for i in range(num_r_groups):
            if i not in subset:
                new_r_sites[r_sites_list[i][0]] = r_sites_list[i][1]
            else:
                new_h_sites.add(r_sites_list[i][1])
        toreturn.append((edited, new_r_sites, new_h_sites, num_r_groups - len(subset)))
    return toreturn
def generate_subsets(n):
    """Return every subset of ``range(n)`` as a list with elements in descending order.

    Subsets are ordered by size, smallest first (starting with ``[]``). Subsets of equal
    size are in descending lexicographic order. For n=3 the result is::

        [[], [2], [1], [0], [2, 1], [2, 0], [1, 0], [2, 1, 0]]

    Returns ``[[]]`` for n <= 0.
    """
    subsets = [[]]
    # Adding values from largest to smallest keeps every subset in descending order.
    for value in range(n - 1, -1, -1):
        subsets += [subset + [value] for subset in subsets]
    return sorted(subsets, key=lambda subset: (-len(subset), subset), reverse=True)
def backout(results, coref_results, molscribe): | |
toreturn = [] | |
if not results or not results[0]['reactions'] or not coref_results: | |
return toreturn | |
try: | |
reactants = results[0]['reactions'][0]['reactants'] | |
products = [i['smiles'] for i in results[0]['reactions'][0]['products']] | |
coref_results_dict = {coref_results[0]['bboxes'][coref[0]]['smiles']: coref_results[0]['bboxes'][coref[1]]['text'] for coref in coref_results[0]['corefs']} | |
coref_smiles_to_graphs = {coref_results[0]['bboxes'][coref[0]]['smiles']: coref_results[0]['bboxes'][coref[0]] for coref in coref_results[0]['corefs']} | |
if len(products) == 1: | |
if products[0] not in coref_results_dict: | |
print("Warning: No Label Parsed") | |
return | |
product_labels = coref_results_dict[products[0]] | |
prod = products[0] | |
label_idx = product_labels[0] | |
''' | |
if len(product_labels) == 1: | |
# get the coreference label of the product molecule | |
label_idx = product_labels[0] | |
else: | |
print("Warning: Malformed Label Parsed.") | |
return | |
''' | |
else: | |
print("Warning: More than one product detected") | |
return | |
# format the regular expression for labels that correspond to the product label | |
numbers = re.findall(r'\d+', label_idx) | |
label_idx = numbers[0] if len(numbers) > 0 else "" | |
label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+' | |
prod_smiles = prod | |
prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile']) | |
# identify the atom indices of the R groups in the product tempalte | |
h_counter = 0 | |
r_sites = {} | |
for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']): | |
sym = atom['atom_symbol'] | |
if sym == '[H]': | |
h_counter += 1 | |
if sym[0] == '[': | |
sym = sym[1:-1] | |
if sym[0] == 'R' and sym[1:].isdigit(): | |
sym = sym[1:]+"*" | |
sym = f'[{sym}]' | |
if sym in RGROUP_SYMBOLS: | |
if sym not in r_sites: | |
r_sites[sym] = [idx-h_counter] | |
else: | |
r_sites[sym].append(idx-h_counter) | |
r_sites_reversed = {} | |
for sym in r_sites: | |
for pos in r_sites[sym]: | |
r_sites_reversed[pos] = sym | |
num_r_groups = len(r_sites_reversed) | |
#prepare the product template and get the associated mapping | |
prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed) | |
reactant_mols = [] | |
#--------------process the reactants----------------- | |
reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...] | |
for idx, reactant in enumerate(reactants): | |
reactant_information[idx] = [] | |
reactant_mols.append(Chem.MolFromSmiles(reactant['smiles'])) | |
has_r = False | |
r_sites_reactant = {} | |
h_counter = 0 | |
for a_idx, atom in enumerate(reactant['atoms']): | |
#go through all atoms and check if they are an R group, if so add it to reactant information | |
sym = atom['atom_symbol'] | |
if sym == '[H]': | |
h_counter += 1 | |
if sym[0] == '[': | |
sym = sym[1:-1] | |
if sym[0] == 'R' and sym[1:].isdigit(): | |
sym = sym[1:]+"*" | |
sym = f'[{sym}]' | |
if sym in r_sites: | |
if reactant_mols[-1].GetNumAtoms()==1: | |
reactant_information[idx].append([sym, -1, -1]) | |
else: | |
has_r = True | |
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) | |
reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) | |
r_sites_reactant[sym] = a_idx-h_counter | |
elif sym == '[1*]' and '[7*]' in r_sites: | |
if reactant_mols[-1].GetNumAtoms()==1: | |
reactant_information[idx].append(['[7*]', -1, -1]) | |
else: | |
has_r = True | |
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) | |
reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) | |
r_sites_reactant['[7*]'] = a_idx-h_counter | |
elif sym == '[7*]' and '[1*]' in r_sites: | |
if reactant_mols[-1].GetNumAtoms()==1: | |
reactant_information[idx].append(['[1*]', -1, -1]) | |
else: | |
has_r = True | |
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) | |
reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) | |
r_sites_reactant['[1*]'] = a_idx-h_counter | |
elif sym == '[1*]' and '[Rf]' in r_sites: | |
if reactant_mols[-1].GetNumAtoms()==1: | |
reactant_information[idx].append(['[Rf]', -1, -1]) | |
else: | |
has_r = True | |
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) | |
reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) | |
r_sites_reactant['[Rf]'] = a_idx-h_counter | |
elif sym == '[Rf]' and '[1*]' in r_sites: | |
if reactant_mols[-1].GetNumAtoms()==1: | |
reactant_information[idx].append(['[1*]', -1, -1]) | |
else: | |
has_r = True | |
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) | |
reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) | |
r_sites_reactant['[1*]'] = a_idx-h_counter | |
r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant} | |
# if the reactant had r groups, we had to use the molecule generated from the MolBlock. | |
# but the molblock may have unexpanded elemeents that are not R groups | |
# so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles | |
# and adjust the indices of the r groups accordingly | |
if has_r: | |
#get the mapping | |
reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant) | |
#make the adjustment | |
for info in reactant_information[idx]: | |
info[1] = reactant_mol_to_query[info[1]] | |
info[2] = reactant_mol_to_query[info[2]] | |
reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles']) | |
#go through all the molecules in the coreference | |
clean_corefs(coref_results_dict, label_idx) | |
for other_prod in coref_results_dict: | |
#check if they match the product label regex | |
found_good_label = False | |
for parsed in coref_results_dict[other_prod]: | |
if re.search(label_pattern, parsed) and not found_good_label: | |
found_good_label = True | |
other_prod_mol = Chem.MolFromSmiles(other_prod) | |
if other_prod != prod_smiles and other_prod_mol is not None: | |
#check if there are R groups to be resolved in the target product | |
all_other_prod_mols = [] | |
r_group_sub_pattern = re.compile('(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)') | |
for parsed_labels in coref_results_dict[other_prod]: | |
res = r_group_sub_pattern.search(parsed_labels) | |
if res is not None: | |
all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels)) | |
if len(all_other_prod_mols) == 0: | |
if other_prod_mol is not None: | |
all_other_prod_mols.append((other_prod_mol, parsed)) | |
for other_prod_mol, parsed in all_other_prod_mols: | |
other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True) | |
for other_prod_frag in other_prod_frags: | |
substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False) | |
if len(substructs)>0: | |
other_prod_mol = other_prod_frag | |
break | |
r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed} | |
queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups) | |
matched = False | |
for query in queries: | |
if not matched: | |
try: | |
matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn) | |
except: | |
pass | |
except: | |
pass | |
return toreturn | |
def backout_without_coref(results, coref_results, coref_results_dict, coref_smiles_to_graphs, molscribe): | |
toreturn = [] | |
if not results or not results[0]['reactions'] or not coref_results: | |
return toreturn | |
try: | |
reactants = results[0]['reactions'][0]['reactants'] | |
products = [i['smiles'] for i in results[0]['reactions'][0]['products']] | |
coref_results_dict = coref_results_dict | |
coref_smiles_to_graphs = coref_smiles_to_graphs | |
if len(products) == 1: | |
if products[0] not in coref_results_dict: | |
print("Warning: No Label Parsed") | |
return | |
product_labels = coref_results_dict[products[0]] | |
prod = products[0] | |
label_idx = product_labels[0] | |
''' | |
if len(product_labels) == 1: | |
# get the coreference label of the product molecule | |
label_idx = product_labels[0] | |
else: | |
print("Warning: Malformed Label Parsed.") | |
return | |
''' | |
else: | |
print("Warning: More than one product detected") | |
return | |
# format the regular expression for labels that correspond to the product label | |
numbers = re.findall(r'\d+', label_idx) | |
label_idx = numbers[0] if len(numbers) > 0 else "" | |
label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+' | |
prod_smiles = prod | |
prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile']) | |
# identify the atom indices of the R groups in the product tempalte | |
h_counter = 0 | |
r_sites = {} | |
for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']): | |
sym = atom['atom_symbol'] | |
if sym == '[H]': | |
h_counter += 1 | |
if sym[0] == '[': | |
sym = sym[1:-1] | |
if sym[0] == 'R' and sym[1:].isdigit(): | |
sym = sym[1:]+"*" | |
sym = f'[{sym}]' | |
if sym in RGROUP_SYMBOLS: | |
if sym not in r_sites: | |
r_sites[sym] = [idx-h_counter] | |
else: | |
r_sites[sym].append(idx-h_counter) | |
r_sites_reversed = {} | |
for sym in r_sites: | |
for pos in r_sites[sym]: | |
r_sites_reversed[pos] = sym | |
num_r_groups = len(r_sites_reversed) | |
#prepare the product template and get the associated mapping | |
prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed) | |
reactant_mols = [] | |
#--------------process the reactants----------------- | |
reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...] | |
for idx, reactant in enumerate(reactants): | |
reactant_information[idx] = [] | |
reactant_mols.append(Chem.MolFromSmiles(reactant['smiles'])) | |
has_r = False | |
r_sites_reactant = {} | |
h_counter = 0 | |
for a_idx, atom in enumerate(reactant['atoms']): | |
#go through all atoms and check if they are an R group, if so add it to reactant information | |
sym = atom['atom_symbol'] | |
if sym == '[H]': | |
h_counter += 1 | |
if sym[0] == '[': | |
sym = sym[1:-1] | |
if sym[0] == 'R' and sym[1:].isdigit(): | |
sym = sym[1:]+"*" | |
sym = f'[{sym}]' | |
if sym in r_sites: | |
if reactant_mols[-1].GetNumAtoms()==1: | |
reactant_information[idx].append([sym, -1, -1]) | |
else: | |
has_r = True | |
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) | |
reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) | |
r_sites_reactant[sym] = a_idx-h_counter | |
elif sym == '[1*]' and '[7*]' in r_sites: | |
if reactant_mols[-1].GetNumAtoms()==1: | |
reactant_information[idx].append(['[7*]', -1, -1]) | |
else: | |
has_r = True | |
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) | |
reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) | |
r_sites_reactant['[7*]'] = a_idx-h_counter | |
elif sym == '[7*]' and '[1*]' in r_sites: | |
if reactant_mols[-1].GetNumAtoms()==1: | |
reactant_information[idx].append(['[1*]', -1, -1]) | |
else: | |
has_r = True | |
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) | |
reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) | |
r_sites_reactant['[1*]'] = a_idx-h_counter | |
elif sym == '[1*]' and '[Rf]' in r_sites: | |
if reactant_mols[-1].GetNumAtoms()==1: | |
reactant_information[idx].append(['[Rf]', -1, -1]) | |
else: | |
has_r = True | |
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) | |
reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) | |
r_sites_reactant['[Rf]'] = a_idx-h_counter | |
elif sym == '[Rf]' and '[1*]' in r_sites: | |
if reactant_mols[-1].GetNumAtoms()==1: | |
reactant_information[idx].append(['[1*]', -1, -1]) | |
else: | |
has_r = True | |
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile']) | |
reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]]) | |
r_sites_reactant['[1*]'] = a_idx-h_counter | |
r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant} | |
# if the reactant had r groups, we had to use the molecule generated from the MolBlock. | |
# but the molblock may have unexpanded elemeents that are not R groups | |
# so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles | |
# and adjust the indices of the r groups accordingly | |
if has_r: | |
#get the mapping | |
reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant) | |
#make the adjustment | |
for info in reactant_information[idx]: | |
info[1] = reactant_mol_to_query[info[1]] | |
info[2] = reactant_mol_to_query[info[2]] | |
reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles']) | |
#go through all the molecules in the coreference | |
clean_corefs(coref_results_dict, label_idx) | |
for other_prod in coref_results_dict: | |
#check if they match the product label regex | |
found_good_label = False | |
for parsed in coref_results_dict[other_prod]: | |
if re.search(label_pattern, parsed) and not found_good_label: | |
found_good_label = True | |
other_prod_mol = Chem.MolFromSmiles(other_prod) | |
if other_prod != prod_smiles and other_prod_mol is not None: | |
#check if there are R groups to be resolved in the target product | |
all_other_prod_mols = [] | |
r_group_sub_pattern = re.compile('(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)') | |
for parsed_labels in coref_results_dict[other_prod]: | |
res = r_group_sub_pattern.search(parsed_labels) | |
if res is not None: | |
all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels)) | |
if len(all_other_prod_mols) == 0: | |
if other_prod_mol is not None: | |
all_other_prod_mols.append((other_prod_mol, parsed)) | |
for other_prod_mol, parsed in all_other_prod_mols: | |
other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True) | |
for other_prod_frag in other_prod_frags: | |
substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False) | |
if len(substructs)>0: | |
other_prod_mol = other_prod_frag | |
break | |
r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed} | |
queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups) | |
matched = False | |
for query in queries: | |
if not matched: | |
try: | |
matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn) | |
except: | |
pass | |
except: | |
pass | |
return toreturn | |
def _annotate_compound(compound, coref_smiles):
    """Return *compound* with its label annotated by its SMILES, or None.

    ``compound`` is indexed like a tuple whose first element is a label.
    If that label is a key of ``coref_smiles``, a new tuple
    ``("<label> (<smiles>)", *compound[1:])`` is returned; every field after
    the label is preserved. Otherwise None is returned, meaning "leave as-is".
    """
    label = compound[0]
    if label not in coref_smiles:
        return None
    return (f'{label} ({coref_smiles[label]})', *compound[1:])


def _annotate_role(reaction, role, coref_smiles):
    """Annotate ``reaction[role]`` in place using ``coref_smiles``.

    ``reaction[role]`` is either a single compound tuple, which is replaced
    in ``reaction``, or a mutable sequence of compounds, which is updated
    element by element.
    """
    entry = reaction[role]
    if isinstance(entry, tuple):
        annotated = _annotate_compound(entry, coref_smiles)
        if annotated is not None:
            reaction[role] = annotated
        return
    for idx, compound in enumerate(entry):
        annotated = _annotate_compound(compound, coref_smiles)
        if annotated is not None:
            # Item assignment only; the sequence length does not change.
            entry[idx] = annotated


def associate_corefs(results, results_coref):
    """Append the coreferenced SMILES to compound labels in reaction results.

    Step 1 builds a map from label tokens to SMILES. For each coreference
    pair ``(mol, idt)`` in each ``results_coref`` entry:
      - every token matching ``\\b\\d+[a-zA-Z]{0,2}\\b`` (e.g. ``"1a"``,
        ``"12"``) in the ``'text'`` strings of ``bboxes[idt]`` is collected;
      - each token is mapped to ``bboxes[mol]['smiles']``.
    A token seen again later overwrites the earlier mapping.

    Step 2 walks ``results`` as
    ``page['reactions'] -> reactions['reactions'] -> reaction``. For each
    reaction it annotates the ``'Reactants'`` and ``'Product'`` entries, when
    present, as ``"<label> (<smiles>)"``. An entry is either one compound
    tuple or a sequence of them; a compound's first element is its label, and
    all other fields are kept.

    Args:
        results: nested page/reaction structure described above; mutated in
            place.
        results_coref: iterable of dicts with ``'bboxes'`` (indexable; items
            carry ``'text'`` and ``'smiles'``) and ``'corefs'`` (pairs of
            ``(mol_index, identifier_index)``).

    Returns:
        ``results`` (the same object, after mutation).
    """
    idx_pattern = re.compile(r'\b\d+[a-zA-Z]{0,2}\b')
    coref_smiles = {}
    for result_coref in results_coref:
        bboxes = result_coref['bboxes']
        for coref in result_coref['corefs']:
            mol, idt = coref[0], coref[1]
            for text in bboxes[idt]['text']:
                for match in idx_pattern.findall(text):
                    coref_smiles[match] = bboxes[mol]['smiles']

    for page in results:
        for reactions in page['reactions']:
            for reaction in reactions['reactions']:
                for role in ('Reactants', 'Product'):
                    if role in reaction:
                        _annotate_role(reaction, role, coref_smiles)
    return results
def expand_reactions_with_backout(initial_results, results_coref, molscribe):
    """Append reactions recovered by ``backout`` to each figure's reaction list.

    ``initial_results`` and ``results_coref`` are walked in lockstep, like
    ``zip``, so extra entries in the longer one are ignored. ``backout`` is
    run on each figure on its own. Every ``(reactants, products, idt)`` triple
    it returns becomes a new entry appended to ``reactions['reactions']``:
      - reactants and products are stored as
        ``{'category': '[Mol]', 'molfile': None, 'smiles': value}`` dicts;
      - conditions are a shallow copy of the figure's first reaction's
        ``'conditions'``.

    A figure is left unchanged when it has no reactions, when ``backout``
    raises (the failure is logged and the figure skipped), or when
    ``backout`` returns nothing.

    Args:
        initial_results: list of per-figure dicts, each with a ``'reactions'``
            list; mutated in place.
        results_coref: per-figure coreference results, paired with
            ``initial_results`` by position.
        molscribe: passed through to ``backout`` unchanged.

    Returns:
        ``initial_results`` (the same object, after mutation).
    """
    import logging

    logger = logging.getLogger(__name__)
    for reactions, result_coref in zip(initial_results, results_coref):
        if not reactions['reactions']:
            continue
        try:
            backout_results = backout([reactions], [result_coref], molscribe)
        except Exception as err:
            # Best-effort expansion: a failing figure is skipped, but not silently.
            logger.warning('backout failed; skipping figure: %r', err)
            continue
        if not backout_results:
            continue
        # Re-read after backout in case it rebinds the list.
        reaction_list = reactions['reactions']
        conditions = reaction_list[0]['conditions']
        for reactants, products, _idt in backout_results:
            reaction_list.append({
                'reactants': [{'category': '[Mol]', 'molfile': None, 'smiles': reactant} for reactant in reactants],
                'conditions': conditions[:],
                'products': [{'category': '[Mol]', 'molfile': None, 'smiles': product} for product in products],
            })
    return initial_results