RxnIM / rxn /reaction /tokenizer.py
CYF200127's picture
Upload 116 files
5e9bd47 verified
raw
history blame
17.5 kB
import json
import copy
import random
import numpy as np
# Control tokens shared by every tokenizer in this module.
PAD, SOS, EOS, UNK, MASK = '<pad>', '<sos>', '<eos>', '<unk>', '<mask>'

# Reaction-structure markers (each closes a group in a reaction sequence).
Rxn = '[Rxn]'  # closes a whole reaction
Rct = '[Rct]'  # closes the reactant boxes
Prd = '[Prd]'  # closes the product boxes
Cnd = '[Cnd]'  # closes the condition boxes

# Bounding-box category tokens (emitted after a box's four coordinates).
Idt = '[Idt]'  # identifier
Mol = '[Mol]'  # molecule
Txt = '[Txt]'  # text
Sup = '[Sup]'  # supplement

# Marks augmented (noise) reactions/boxes and boxes of unknown category.
Noise = '[Nos]'
class ReactionTokenizer(object):
    """Serialize reaction diagrams into token-id sequences and back.

    Each bounding box becomes five tokens: the quantized ``x1, y1, x2, y2``
    followed by a category token (``Mol``/``Txt``/``Idt``/``Sup``, or
    ``Noise`` when the category is unknown). A reaction is its reactant boxes
    followed by ``Rct``, its condition boxes followed by ``Cnd``, its product
    boxes followed by ``Prd``, and finally ``Rxn``.

    Vocabulary layout (default): the special and structural tokens take ids
    ``0 .. offset - 1``; x coordinates take the next ``maxx`` ids; with
    ``sep_xy`` the y coordinates take a further ``maxy`` ids, otherwise x and
    y share one id range. In pix2seq mode coordinates start at id 0, the named
    tokens start at 2001 and ``EOS`` is moved to ``len(self) - 2``.
    """

    def __init__(self, input_size=100, sep_xy=True, pix2seq=False):
        """Build the vocabulary.

        Args:
            input_size: number of quantization bins per coordinate axis.
            sep_xy: if True, x and y coordinates use disjoint id ranges;
                otherwise they share one range.
            pix2seq: use the pix2seq id layout described on the class.
        """
        self.stoi = {}
        self.itos = {}
        self.pix2seq = pix2seq
        self.maxx = input_size  # number of bins for x coordinates
        self.maxy = input_size  # number of bins for y coordinates
        self.sep_xy = sep_xy
        self.special_tokens = [PAD, SOS, EOS, UNK, MASK]
        self.tokens = [Rxn, Rct, Prd, Cnd, Idt, Mol, Txt, Sup, Noise]
        self.fit_tokens(self.tokens)

    def __len__(self):
        """Return the vocabulary size (``output_mask`` builds masks of this length)."""
        if self.pix2seq:
            # Fixed vocabulary size for the pix2seq layout.
            return 2094
        if self.sep_xy:
            return self.offset + self.maxx + self.maxy
        return self.offset + max(self.maxx, self.maxy)

    @property
    def max_len(self):
        """Token budget that ``data_to_sequence`` fills with noise reactions."""
        return 256

    @property
    def PAD_ID(self):
        return self.stoi[PAD]

    @property
    def SOS_ID(self):
        return self.stoi[SOS]

    @property
    def EOS_ID(self):
        return self.stoi[EOS]

    @property
    def UNK_ID(self):
        return self.stoi[UNK]

    @property
    def NOISE_ID(self):
        return self.stoi[Noise]

    @property
    def offset(self):
        """First coordinate id: 0 in pix2seq mode, else the number of named tokens."""
        return 0 if self.pix2seq else len(self.stoi)

    @property
    def output_constraint(self):
        """Always True here (``BboxTokenizer`` overrides it with False).

        NOTE(review): presumably tells the decoder whether to apply
        ``output_mask``; the consumer is not in this file.
        """
        return True

    def fit_tokens(self, tokens):
        """Assign ids to the special tokens plus ``tokens`` and build lookup tables."""
        vocab = self.special_tokens + tokens
        if self.pix2seq:
            for i, s in enumerate(vocab):
                self.stoi[s] = 2001 + i
            self.stoi[EOS] = len(self) - 2
            # The last id (len(self) - 1) is left unassigned.
        else:
            for i, s in enumerate(vocab):
                self.stoi[s] = i
        self.itos = {item[1]: item[0] for item in self.stoi.items()}
        self.bbox_category_to_token = {1: Mol, 2: Txt, 3: Idt, 4: Sup}
        self.token_to_bbox_category = {item[1]: item[0] for item in self.bbox_category_to_token.items()}

    def is_x(self, x):
        """Return True if token id ``x`` encodes an x coordinate."""
        return 0 <= x - self.offset < self.maxx

    def is_y(self, y):
        """Return True if token id ``y`` encodes a y coordinate."""
        if self.sep_xy:
            return self.maxx <= y - self.offset < self.maxx + self.maxy
        return 0 <= y - self.offset < self.maxy

    def _clamp_coord(self, value):
        """Clamp a normalized coordinate into [0, 1], allowing 1e-3 of slack.

        Raises:
            ValueError: if ``value`` lies further outside [0, 1], or is NaN.
        """
        # Written as a positive range test so NaN (all comparisons False) is rejected.
        if not -0.001 <= value <= 1.001:
            raise ValueError('coordinate {} is outside [0, 1]'.format(value))
        return min(max(value, 0), 1)

    def x_to_id(self, x):
        """Quantize a normalized x coordinate in [0, 1] to its token id."""
        x = self._clamp_coord(x)
        return self.offset + round(x * (self.maxx - 1))

    def y_to_id(self, y):
        """Quantize a normalized y coordinate in [0, 1] to its token id."""
        y = self._clamp_coord(y)
        bin_index = round(y * (self.maxy - 1))
        if self.sep_xy:
            return self.offset + self.maxx + bin_index
        return self.offset + bin_index

    def id_to_x(self, id, scale=1):
        """Invert ``x_to_id`` and divide by ``scale``; return -1 for a non-x id."""
        if not self.is_x(id):
            return -1
        return (id - self.offset) / (self.maxx - 1) / scale

    def id_to_y(self, id, scale=1):
        """Invert ``y_to_id`` and divide by ``scale``; return -1 for a non-y id."""
        if not self.is_y(id):
            return -1
        start = self.offset + self.maxx if self.sep_xy else self.offset
        # Divide (like id_to_x and the shared-range case); the sep_xy branch
        # used to multiply, which distorted y whenever scale != 1.
        return (id - start) / (self.maxy - 1) / scale

    def update_state(self, state, idx):
        """Advance the decoding state machine after token ``idx`` is emitted.

        A state is ``(section, last)``. ``section`` is one of ``Rxn``, ``Rct``,
        ``Cnd``, ``Prd`` or ``EOS``. ``last`` names the slot just emitted:
        ``'x1'``, ``'y1'``, ``'x2'``, ``'y2'``, ``'c'`` (category) or ``'e'``
        (section boundary). A ``None`` state starts at ``(Rxn, 'e')`` and
        ignores ``idx`` (presumably the SOS token).
        """
        if state is None:
            return (Rxn, 'e')
        section, last = state
        if last == 'x1':
            return (section, 'y1')
        if last == 'y1':
            return (section, 'x2')
        if last == 'x2':
            return (section, 'y2')
        if last == 'y2':
            return (section, 'c')
        if last == 'c':
            # After a category either another box starts or the section marker follows.
            return (section, 'x1') if self.is_x(idx) else (section, 'e')
        # At a section boundary.
        if section == Rct:
            return (Cnd, 'x1') if self.is_x(idx) else (Cnd, 'e')
        if section == Cnd:
            return (Prd, 'x1')
        if section == Prd:
            return (Rxn, 'e')
        if section == Rxn:
            return (Rct, 'x1') if self.is_x(idx) else (EOS, 'e')
        return (EOS, 'e')

    def output_mask(self, state):
        """Return a boolean array of length ``len(self)``; True means forbidden."""
        mask = np.array([True] * len(self))
        section, last = state
        x_ids = slice(self.offset, self.offset + self.maxx)
        if self.sep_xy:
            y_ids = slice(self.offset + self.maxx, self.offset + self.maxx + self.maxy)
        else:
            y_ids = slice(self.offset, self.offset + self.maxy)
        if last in ['y1', 'c']:
            mask[x_ids] = False
        if last in ['x1', 'x2']:
            mask[y_ids] = False
        if last == 'y2':
            for token in [Idt, Mol, Txt, Sup]:
                mask[self.stoi[token]] = False
        if last == 'c':
            # The current section's closing marker may follow any box.
            mask[self.stoi[section]] = False
        if last == 'e':
            if section in [Rct, Cnd, Rxn]:
                mask[x_ids] = False
            if section == Rct:
                mask[self.stoi[Cnd]] = False
            if section == Prd:
                mask[self.stoi[Rxn]] = False
                mask[self.stoi[Noise]] = False
            if section in [Rxn, EOS]:
                mask[self.EOS_ID] = False
        return mask

    def update_states_and_masks(self, states, ids):
        """Apply ``update_state`` per (state, id) pair and stack the new masks."""
        new_states = [self.update_state(state, idx) for state, idx in zip(states, ids)]
        masks = np.array([self.output_mask(state) for state in new_states])
        return new_states, masks

    def bbox_to_sequence(self, bbox, category):
        """Encode ``(x1, y1, x2, y2)`` plus a category id as five token ids.

        Returns ``[]`` for a degenerate box (``x1 >= x2`` or ``y1 >= y2``).
        Unknown categories are encoded as ``Noise``.
        """
        x1, y1, x2, y2 = bbox
        if x1 >= x2 or y1 >= y2:
            return []
        category_token = self.bbox_category_to_token.get(category, Noise)
        return [self.x_to_id(x1), self.y_to_id(y1), self.x_to_id(x2), self.y_to_id(y2),
                self.stoi[category_token]]

    def sequence_to_bbox(self, sequence, scale=None):
        """Decode five token ids into a bbox dict, or None if they are invalid.

        Args:
            sequence: at least five ids: x1, y1, x2, y2, category.
            scale: ``(sx, sy)`` divisors applied to x and y; ``None`` means
                ``(1, 1)``.

        Returns:
            ``{'category', 'bbox', 'category_id'}`` or None.
        """
        if len(sequence) < 5:
            return None
        if scale is None:
            scale = (1, 1)
        x1, y1 = self.id_to_x(sequence[0], scale[0]), self.id_to_y(sequence[1], scale[1])
        x2, y2 = self.id_to_x(sequence[2], scale[0]), self.id_to_y(sequence[3], scale[1])
        if x1 == -1 or y1 == -1 or x2 == -1 or y2 == -1 or x1 >= x2 or y1 >= y2 or sequence[4] not in self.itos:
            return None
        category = self.itos[sequence[4]]
        if category not in [Mol, Txt, Idt, Sup]:
            return None
        return {'category': category, 'bbox': (x1, y1, x2, y2), 'category_id': self.token_to_bbox_category[category]}

    def perturb_reaction(self, reaction, boxes):
        """Return a randomly edited deep copy of ``reaction``.

        Exactly one box is added (to a random role), deleted, or moved to
        another role. Deletion/move is only offered when some role can lose a
        box while keeping at least one reactant and one product.
        """
        reaction = copy.deepcopy(reaction)
        min_sizes = [('reactants', 1), ('conditions', 0), ('products', 1)]
        deletable = [key for key, min_size in min_sizes if len(reaction[key]) > min_size]
        options = [0]  # Option 0: add
        if deletable:
            options += [1, 2]  # Option 1: delete, Option 2: move
        choice = random.choice(options)
        if choice == 0:
            key = random.choice(['reactants', 'conditions', 'products'])
            # TODO: insert at a random position instead of appending.
            # The added box may duplicate one already in this reaction.
            reaction[key].append(random.randrange(len(boxes)))
            return reaction
        key = random.choice(deletable)
        idx = random.randrange(len(reaction[key]))
        del_box = reaction[key].pop(idx)
        if choice == 2:
            new_key = random.choice([k for k in ['reactants', 'conditions', 'products'] if k != key])
            reaction[new_key].append(del_box)
        return reaction

    def augment_reaction(self, reactions, data):
        """Create one synthetic (noise) reaction as box-index lists, or None.

        Returns None when no box in ``data`` has positive area. When there are
        no real reactions, or with 20% probability, 1-3 reactants, 0-3
        conditions and 1-3 products are sampled with replacement from the
        non-empty boxes; otherwise a random real reaction is perturbed.
        """
        area, boxes = data['area'], data['boxes']
        nonempty_boxes = [i for i in range(len(area)) if area[i] > 0]
        if len(nonempty_boxes) == 0:
            return None
        if len(reactions) == 0 or random.randrange(100) < 20:
            num_reactants = random.randint(1, 3)
            num_conditions = random.randint(0, 3)
            num_products = random.randint(1, 3)
            return {
                'reactants': random.choices(nonempty_boxes, k=num_reactants),
                'conditions': random.choices(nonempty_boxes, k=num_conditions),
                'products': random.choices(nonempty_boxes, k=num_products)
            }
        return self.perturb_reaction(random.choice(reactions), boxes)

    def reaction_to_sequence(self, reaction, data, shuffle_bbox=False):
        """Encode one reaction (box indices into ``data``) as token ids.

        Zero-area boxes are skipped. Returns ``[]`` when every reactant or
        every product has zero area (e.g. cropped out of the image).
        """
        reaction = copy.deepcopy(reaction)
        area, boxes, labels = data['area'], data['boxes'], data['labels']
        if all(area[i] == 0 for i in reaction['reactants']) or all(area[i] == 0 for i in reaction['products']):
            return []
        if shuffle_bbox:
            random.shuffle(reaction['reactants'])
            random.shuffle(reaction['conditions'])
            random.shuffle(reaction['products'])
        sequence = []
        for key, marker in [('reactants', Rct), ('conditions', Cnd), ('products', Prd)]:
            for idx in reaction[key]:
                if area[idx] == 0:
                    continue
                sequence += self.bbox_to_sequence(boxes[idx].tolist(), labels[idx].item())
            sequence.append(self.stoi[marker])
        sequence.append(self.stoi[Rxn])
        return sequence

    def data_to_sequence(self, data, rand_order=False, shuffle_bbox=False, add_noise=False, mix_noise=False):
        """Encode all reactions in ``data`` as an (input, target) id pair.

        Args:
            data: dict with 'reactions', 'area', 'boxes', 'labels'.
            rand_order: shuffle the order of the real reactions.
            shuffle_bbox: shuffle boxes within each role of real reactions.
            add_noise: append synthetic reactions until ``max_len`` tokens.
            mix_noise: shuffle noise reactions among the real ones and close
                them with ``Noise`` instead of ``Rxn`` in the input.

        Returns:
            ``(sequence, sequence_out)``, both wrapped in SOS ... EOS. For noise
            reactions the target is PAD except for a final ``Noise`` token.
        """
        reactions = copy.deepcopy(data['reactions'])
        reactions_seqs = []
        for reaction in reactions:
            seq = self.reaction_to_sequence(reaction, data, shuffle_bbox=shuffle_bbox)
            reactions_seqs.append([seq, seq])
        noise_seqs = []
        if add_noise:
            total_len = sum(len(seq) for seq, _ in reactions_seqs)
            while total_len < self.max_len:
                reaction = self.augment_reaction(reactions, data)
                if reaction is None:
                    break
                seq = self.reaction_to_sequence(reaction, data)
                if len(seq) == 0:
                    continue
                if mix_noise:
                    seq[-1] = self.NOISE_ID
                seq_out = [self.PAD_ID] * (len(seq) - 1) + [self.NOISE_ID]
                noise_seqs.append([seq, seq_out])
                total_len += len(seq)
        if rand_order:
            random.shuffle(reactions_seqs)
        reactions_seqs += noise_seqs
        if mix_noise:
            random.shuffle(reactions_seqs)
        sequence = [self.SOS_ID]
        sequence_out = [self.SOS_ID]
        for seq, seq_out in reactions_seqs:
            sequence += seq
            sequence_out += seq_out
        sequence.append(self.EOS_ID)
        sequence_out.append(self.EOS_ID)
        return sequence, sequence_out

    def sequence_to_data(self, sequence, scores=None, scale=None):
        """Decode a token-id sequence into a list of reactions.

        Each reaction is ``{'reactants', 'conditions', 'products', 'label'}``,
        where the role lists hold ``sequence_to_bbox`` dicts and ``label`` is
        the terminating ``Rxn`` or ``Noise`` token. Reactions without
        reactants or products are dropped. ``scores`` is accepted for
        interface parity with ``BboxTokenizer`` and unused here. ``scale`` is
        passed to ``sequence_to_bbox`` (None means no scaling).
        """
        reactions = []
        cur_reaction = {'reactants': [], 'conditions': [], 'products': []}
        flag = 'reactants'  # role that receives the next box; None after Prd
        i = 1 if len(sequence) > 0 and sequence[0] == self.SOS_ID else 0
        while i < len(sequence):
            token_id = sequence[i]
            if token_id == self.EOS_ID:
                break
            if token_id in self.itos:
                token = self.itos[token_id]
                if token in [Rxn, Noise]:
                    cur_reaction['label'] = token
                    if len(cur_reaction['reactants']) > 0 and len(cur_reaction['products']) > 0:
                        reactions.append(cur_reaction)
                    # Start fresh either way so a dropped reaction's boxes do
                    # not leak into the next one.
                    cur_reaction = {'reactants': [], 'conditions': [], 'products': []}
                    flag = 'reactants'
                elif token == Rct:
                    flag = 'conditions'
                elif token == Cnd:
                    flag = 'products'
                elif token == Prd:
                    flag = None
            elif i + 5 <= len(sequence) and flag is not None:
                bbox = self.sequence_to_bbox(sequence[i:i + 5], scale)
                if bbox is not None:
                    cur_reaction[flag].append(bbox)
                    i += 4  # skip the rest of this box's five tokens
            i += 1
        return reactions

    def sequence_to_tokens(self, sequence):
        """Map ids to token strings, leaving coordinate ids as integers."""
        return [self.itos.get(x, x) for x in sequence]
class BboxTokenizer(ReactionTokenizer):
    """Serialize a flat list of bounding boxes (no reaction structure).

    Each box is encoded as ``x1, y1, x2, y2, category`` exactly as in
    ``ReactionTokenizer.bbox_to_sequence``.
    """

    def __init__(self, input_size=100, sep_xy=True, pix2seq=False):
        super(BboxTokenizer, self).__init__(input_size, sep_xy, pix2seq)

    @property
    def max_len(self):
        """Token budget that ``data_to_sequence`` fills with noise boxes."""
        return 500

    @property
    def output_constraint(self):
        """False, overriding the parent's True."""
        return False

    def random_category(self):
        """Return a random known bbox category id (1-4)."""
        return random.choice(list(self.bbox_category_to_token.keys()))

    def random_bbox(self):
        """Return a uniformly random box in [0, 1]^2 and a random category."""
        _x1, _y1, _x2, _y2 = random.random(), random.random(), random.random(), random.random()
        x1, y1, x2, y2 = min(_x1, _x2), min(_y1, _y2), max(_x1, _x2), max(_y1, _y2)
        category = self.random_category()
        return [x1, y1, x2, y2], category

    def jitter_bbox(self, bbox, ratio=0.2):
        """Shift each corner by up to ``ratio`` of the box size and clip to [0, 1].

        Returns the jittered box (as a numpy array) and a random category.
        """
        x1, y1, x2, y2 = bbox
        w, h = x2 - x1, y2 - y1
        _x1 = x1 + random.uniform(-w * ratio, w * ratio)
        _y1 = y1 + random.uniform(-h * ratio, h * ratio)
        _x2 = x2 + random.uniform(-w * ratio, w * ratio)
        _y2 = y2 + random.uniform(-h * ratio, h * ratio)
        x1, y1, x2, y2 = min(_x1, _x2), min(_y1, _y2), max(_x1, _x2), max(_y1, _y2)
        category = self.random_category()
        return np.clip([x1, y1, x2, y2], 0, 1), category

    def augment_box(self, bboxes):
        """Return a noise box: random, or (half the time) a jittered real box."""
        if len(bboxes) == 0 or random.random() < 0.5:
            return self.random_bbox()
        return self.jitter_bbox(random.choice(bboxes))

    def data_to_sequence(self, data, add_noise=False, rand_order=False):
        """Encode the boxes in ``data`` as an (input, target) id pair.

        Args:
            data: dict with array-like 'boxes' and 'labels'.
            add_noise: append noise boxes until ``max_len`` input tokens; their
                target is four PAD tokens followed by ``Noise``.
            rand_order: randomly permute the real boxes first.

        Returns:
            ``(sequence, sequence_out)`` of equal length, wrapped in SOS ... EOS.
        """
        if rand_order:
            perm = np.random.permutation(len(data['boxes']))
            boxes = data['boxes'][perm].tolist()
            labels = data['labels'][perm].tolist()
        else:
            boxes = data['boxes'].tolist()
            labels = data['labels'].tolist()
        sequence = [self.SOS_ID]
        sequence_out = [self.SOS_ID]
        for bbox, category in zip(boxes, labels):
            seq = self.bbox_to_sequence(bbox, category)
            sequence += seq
            sequence_out += seq
        if add_noise:
            while len(sequence) < self.max_len:
                bbox, category = self.augment_box(boxes)
                seq = self.bbox_to_sequence(bbox, category)
                if not seq:
                    # Degenerate box emitted no tokens; adding a target here
                    # would misalign sequence and sequence_out.
                    continue
                sequence += seq
                sequence_out += [self.PAD_ID] * (len(seq) - 1) + [self.NOISE_ID]
        sequence.append(self.EOS_ID)
        sequence_out.append(self.EOS_ID)
        return sequence, sequence_out

    def sequence_to_data(self, sequence, scores=None, scale=None):
        """Decode a token-id sequence into a list of bbox dicts.

        Args:
            sequence: token ids, optionally starting with SOS; decoding stops
                at EOS.
            scores: optional per-token scores. Each box takes the score at its
                category-token position.
            scale: ``(sx, sy)`` divisors for the coordinates; None means ``(1, 1)``.
        """
        if scale is None:
            scale = (1, 1)
        bboxes = []
        i = 1 if len(sequence) > 0 and sequence[0] == self.SOS_ID else 0
        while i < len(sequence):
            if sequence[i] == self.EOS_ID:
                break
            if i + 4 < len(sequence):
                bbox = self.sequence_to_bbox(sequence[i:i + 5], scale)
                if bbox is not None:
                    if scores is not None:
                        bbox['score'] = scores[i + 4]
                    bboxes.append(bbox)
                    i += 4  # skip the rest of this box's five tokens
            i += 1
        return bboxes
def get_tokenizer(args):
    """Build the tokenizer selected by ``args.format``.

    In pix2seq mode ``args`` is updated in place first: ``coord_bins`` is
    forced to 2000 and ``sep_xy`` to False (x and y share one id range).

    Returns:
        A dict mapping ``args.format`` to its tokenizer. The dict is empty when
        the format is neither 'reaction' nor 'bbox'.
    """
    if args.pix2seq:
        args.coord_bins = 2000
        args.sep_xy = False
    fmt = args.format
    if fmt == 'reaction':
        tokenizer_cls = ReactionTokenizer
    elif fmt == 'bbox':
        tokenizer_cls = BboxTokenizer
    else:
        return {}
    return {fmt: tokenizer_cls(args.coord_bins, args.sep_xy, args.pix2seq)}