|
import json |
|
import copy |
|
import random |
|
import numpy as np |
|
|
|
|
|
# Special tokens shared by all tokenizers.
PAD = '<pad>'
SOS = '<sos>'
EOS = '<eos>'
UNK = '<unk>'
MASK = '<mask>'

# Structural / category tokens used when serializing reaction diagrams.
Rxn = '[Rxn]'   # end-of-reaction marker
Rct = '[Rct]'   # end of the reactant group
Prd = '[Prd]'   # end of the product group
Cnd = '[Cnd]'   # end of the condition group
Idt = '[Idt]'   # bbox category: identifier
Mol = '[Mol]'   # bbox category: molecule
Txt = '[Txt]'   # bbox category: text
Sup = '[Sup]'   # bbox category: supplementary
Noise = '[Nos]' # marker for augmentation/noise sequences


class ReactionTokenizer(object):
    """Serializes reaction diagrams into integer token sequences and back.

    A reaction is encoded as groups of bounding boxes (reactants, then
    conditions, then products), each box as 4 quantized coordinate ids plus a
    category token, each group terminated by its marker token (Rct/Cnd/Prd),
    and each reaction terminated by Rxn.

    Vocabulary layout: special + structural tokens occupy the low ids (or ids
    2001+ in pix2seq mode); quantized x/y coordinates follow, in a shared
    range or in disjoint x/y ranges depending on ``sep_xy``.
    """

    def __init__(self, input_size=100, sep_xy=True, pix2seq=False):
        # input_size: number of quantization bins per coordinate axis.
        # sep_xy:     if True, x and y coordinates use disjoint id ranges.
        # pix2seq:    if True, use the fixed Pix2Seq vocabulary layout.
        self.stoi = {}
        self.itos = {}
        self.pix2seq = pix2seq
        self.maxx = input_size
        self.maxy = input_size
        self.sep_xy = sep_xy
        self.special_tokens = [PAD, SOS, EOS, UNK, MASK]
        self.tokens = [Rxn, Rct, Prd, Cnd, Idt, Mol, Txt, Sup, Noise]
        self.fit_tokens(self.tokens)

    def __len__(self):
        """Total vocabulary size (tokens + coordinate bins)."""
        if self.pix2seq:
            # Fixed Pix2Seq vocabulary size.
            return 2094
        if self.sep_xy:
            return self.offset + self.maxx + self.maxy
        else:
            return self.offset + max(self.maxx, self.maxy)

    @property
    def max_len(self):
        """Maximum target sequence length used when adding noise."""
        return 256

    @property
    def PAD_ID(self):
        return self.stoi[PAD]

    @property
    def SOS_ID(self):
        return self.stoi[SOS]

    @property
    def EOS_ID(self):
        return self.stoi[EOS]

    @property
    def UNK_ID(self):
        return self.stoi[UNK]

    @property
    def NOISE_ID(self):
        return self.stoi[Noise]

    @property
    def offset(self):
        """Id at which coordinate bins start (0 in pix2seq mode, where tokens
        live at the top of the id range instead)."""
        return 0 if self.pix2seq else len(self.stoi)

    @property
    def output_constraint(self):
        """Whether decoding should apply the state-machine output mask."""
        return True

    def fit_tokens(self, tokens):
        """Build stoi/itos and the bbox category <-> token maps."""
        vocab = self.special_tokens + tokens
        if self.pix2seq:
            # Pix2Seq reserves ids 0..2000 for coordinates; tokens go above.
            for i, s in enumerate(vocab):
                self.stoi[s] = 2001 + i
            self.stoi[EOS] = len(self) - 2
        else:
            for i, s in enumerate(vocab):
                self.stoi[s] = i
        self.itos = {v: k for k, v in self.stoi.items()}
        self.bbox_category_to_token = {1: Mol, 2: Txt, 3: Idt, 4: Sup}
        self.token_to_bbox_category = {v: k for k, v in self.bbox_category_to_token.items()}

    def is_x(self, x):
        """True if token id `x` encodes an x coordinate."""
        return 0 <= x - self.offset < self.maxx

    def is_y(self, y):
        """True if token id `y` encodes a y coordinate."""
        if self.sep_xy:
            return self.maxx <= y - self.offset < self.maxx + self.maxy
        return 0 <= y - self.offset < self.maxy

    def x_to_id(self, x):
        """Quantize normalized x in [0, 1] to a token id.

        Values slightly outside [0, 1] (float noise) are clamped; values
        beyond the 0.001 tolerance are printed and then fail the assert.
        """
        if x < -0.001 or x > 1.001:
            print(x)
        else:
            x = min(max(x, 0), 1)
        assert 0 <= x <= 1
        return self.offset + round(x * (self.maxx - 1))

    def y_to_id(self, y):
        """Quantize normalized y in [0, 1] to a token id (see x_to_id)."""
        if y < -0.001 or y > 1.001:
            print(y)
        else:
            y = min(max(y, 0), 1)
        assert 0 <= y <= 1
        if self.sep_xy:
            return self.offset + self.maxx + round(y * (self.maxy - 1))
        return self.offset + round(y * (self.maxy - 1))

    def id_to_x(self, id, scale=1):
        """Inverse of x_to_id; returns -1 for ids that are not x coordinates."""
        if not self.is_x(id):
            return -1
        return (id - self.offset) / (self.maxx - 1) / scale

    def id_to_y(self, id, scale=1):
        """Inverse of y_to_id; returns -1 for ids that are not y coordinates."""
        if not self.is_y(id):
            return -1
        if self.sep_xy:
            # Fix: divide by scale, matching id_to_x and the non-sep branch
            # below (the original multiplied here, making the sep_xy path
            # inconsistent with every other coordinate conversion).
            return (id - self.offset - self.maxx) / (self.maxy - 1) / scale
        return (id - self.offset) / (self.maxy - 1) / scale

    def update_state(self, state, idx):
        """Advance the decoding state machine after emitting token id `idx`.

        A state is a pair (section_token, expected_field) where the field is
        one of 'x1'/'y1'/'x2'/'y2' (bbox coordinates), 'c' (bbox category) or
        'e' (end of a group / reaction).
        """
        if state is None:
            # Before the first token: behave as if a reaction just ended.
            new_state = (Rxn, 'e')
        else:
            if state[1] == 'x1':
                new_state = (state[0], 'y1')
            elif state[1] == 'y1':
                new_state = (state[0], 'x2')
            elif state[1] == 'x2':
                new_state = (state[0], 'y2')
            elif state[1] == 'y2':
                new_state = (state[0], 'c')
            elif state[1] == 'c':
                # After a category: either another bbox starts, or the group ends.
                if self.is_x(idx):
                    new_state = (state[0], 'x1')
                else:
                    new_state = (state[0], 'e')
            else:
                # state[1] == 'e': a group marker was emitted; move to the
                # next group (Rct -> Cnd -> Prd -> Rxn) or finish.
                if state[0] == Rct:
                    if self.is_x(idx):
                        new_state = (Cnd, 'x1')
                    else:
                        new_state = (Cnd, 'e')  # empty condition group
                elif state[0] == Cnd:
                    new_state = (Prd, 'x1')
                elif state[0] == Prd:
                    new_state = (Rxn, 'e')
                elif state[0] == Rxn:
                    if self.is_x(idx):
                        new_state = (Rct, 'x1')  # next reaction begins
                    else:
                        new_state = (EOS, 'e')
                else:
                    new_state = (EOS, 'e')
        return new_state

    def output_mask(self, state):
        """Return a boolean mask (True = forbidden) over the vocabulary for
        the next prediction, given the current decoding state."""
        mask = np.array([True] * len(self))
        # After x1 or the category, an x coordinate may follow.
        if state[1] in ['y1', 'c']:
            mask[self.offset:self.offset + self.maxx] = False
        # After an x coordinate, a y coordinate must follow.
        if state[1] in ['x1', 'x2']:
            if self.sep_xy:
                mask[self.offset + self.maxx:self.offset + self.maxx + self.maxy] = False
            else:
                mask[self.offset:self.offset + self.maxy] = False
        # After y2, a bbox category token must follow.
        if state[1] == 'y2':
            for token in [Idt, Mol, Txt, Sup]:
                mask[self.stoi[token]] = False
        # After a category, the current group marker is also allowed.
        if state[1] == 'c':
            mask[self.stoi[state[0]]] = False
        if state[1] == 'e':
            # A new bbox may start after Rct/Cnd/Rxn markers.
            if state[0] in [Rct, Cnd, Rxn]:
                mask[self.offset:self.offset + self.maxx] = False
            if state[0] == Rct:
                mask[self.stoi[Cnd]] = False  # condition group may be empty
            if state[0] == Prd:
                mask[self.stoi[Rxn]] = False
                mask[self.stoi[Noise]] = False
            if state[0] in [Rxn, EOS]:
                mask[self.EOS_ID] = False
        return mask

    def update_states_and_masks(self, states, ids):
        """Batched update_state + output_mask."""
        new_states = [self.update_state(state, idx) for state, idx in zip(states, ids)]
        masks = np.array([self.output_mask(state) for state in new_states])
        return new_states, masks

    def bbox_to_sequence(self, bbox, category):
        """Encode one bbox (x1, y1, x2, y2 in [0, 1]) and its category id as
        5 token ids. Degenerate boxes (x1 >= x2 or y1 >= y2) encode to []."""
        x1, y1, x2, y2 = bbox
        if x1 >= x2 or y1 >= y2:
            return []
        sequence = [
            self.x_to_id(x1),
            self.y_to_id(y1),
            self.x_to_id(x2),
            self.y_to_id(y2),
        ]
        if category in self.bbox_category_to_token:
            sequence.append(self.stoi[self.bbox_category_to_token[category]])
        else:
            sequence.append(self.stoi[Noise])
        return sequence

    def sequence_to_bbox(self, sequence, scale=None):
        """Decode 5 token ids into a bbox dict, or None if invalid.

        `scale` is an (sx, sy) pair; defaults to (1, 1). (Fix: the original
        used a mutable list default and crashed when callers passed None.)
        """
        if scale is None:
            scale = (1, 1)
        if len(sequence) < 5:
            return None
        x1, y1 = self.id_to_x(sequence[0], scale[0]), self.id_to_y(sequence[1], scale[1])
        x2, y2 = self.id_to_x(sequence[2], scale[0]), self.id_to_y(sequence[3], scale[1])
        if x1 == -1 or y1 == -1 or x2 == -1 or y2 == -1 or x1 >= x2 or y1 >= y2 or sequence[4] not in self.itos:
            return None
        category = self.itos[sequence[4]]
        if category not in [Mol, Txt, Idt, Sup]:
            return None
        return {'category': category, 'bbox': (x1, y1, x2, y2), 'category_id': self.token_to_bbox_category[category]}

    def perturb_reaction(self, reaction, boxes):
        """Return a corrupted copy of `reaction` for negative/noise sampling.

        choice 0: add a random box index to a random group;
        choice 1: delete a box from a group (keeping >=1 reactant/product);
        choice 2: delete a box and re-insert it into a different group.
        """
        reaction = copy.deepcopy(reaction)
        options = [0]
        # Deletion/move only possible if the reaction is not already minimal.
        if not (len(reaction['reactants']) == 1 and len(reaction['conditions']) == 0 and len(reaction['products']) == 1):
            options.append(1)
            options.append(2)
        choice = random.choice(options)
        if choice == 0:
            key = random.choice(['reactants', 'conditions', 'products'])
            reaction[key].append(random.randrange(len(boxes)))
        if choice == 1 or choice == 2:
            options = []
            for key, min_size in [('reactants', 1), ('conditions', 0), ('products', 1)]:
                if len(reaction[key]) > min_size:
                    options.append(key)
            key = random.choice(options)
            idx = random.randrange(len(reaction[key]))
            del_box = reaction[key][idx]
            reaction[key] = reaction[key][:idx] + reaction[key][idx + 1:]
            if choice == 2:
                options = ['reactants', 'conditions', 'products']
                options.remove(key)
                newkey = random.choice(options)
                reaction[newkey].append(del_box)
        return reaction

    def augment_reaction(self, reactions, data):
        """Generate a noise reaction: either fully random box assignments or a
        perturbation of an existing reaction. Returns None if no usable box."""
        area, boxes, labels = data['area'], data['boxes'], data['labels']
        nonempty_boxes = [i for i in range(len(area)) if area[i] > 0]
        if len(nonempty_boxes) == 0:
            return None
        # 20% fully random (100% when there is no real reaction to perturb).
        if len(reactions) == 0 or random.randrange(100) < 20:
            num_reactants = random.randint(1, 3)
            num_conditions = random.randint(0, 3)
            num_products = random.randint(1, 3)
            reaction = {
                'reactants': random.choices(nonempty_boxes, k=num_reactants),
                'conditions': random.choices(nonempty_boxes, k=num_conditions),
                'products': random.choices(nonempty_boxes, k=num_products)
            }
        else:
            assert len(reactions) > 0
            reaction = self.perturb_reaction(random.choice(reactions), boxes)
        return reaction

    def reaction_to_sequence(self, reaction, data, shuffle_bbox=False):
        """Encode one reaction as a token sequence ending with Rxn.

        Returns [] if all reactant or all product boxes have zero area.
        """
        reaction = copy.deepcopy(reaction)
        area, boxes, labels = data['area'], data['boxes'], data['labels']
        if all([area[i] == 0 for i in reaction['reactants']]) or all([area[i] == 0 for i in reaction['products']]):
            return []
        if shuffle_bbox:
            random.shuffle(reaction['reactants'])
            random.shuffle(reaction['conditions'])
            random.shuffle(reaction['products'])
        sequence = []
        # Emit each group's boxes followed by its marker token.
        for indices, end_token in ((reaction['reactants'], Rct),
                                   (reaction['conditions'], Cnd),
                                   (reaction['products'], Prd)):
            for idx in indices:
                if area[idx] == 0:
                    continue
                sequence += self.bbox_to_sequence(boxes[idx].tolist(), labels[idx].item())
            sequence.append(self.stoi[end_token])
        sequence.append(self.stoi[Rxn])
        return sequence

    def data_to_sequence(self, data, rand_order=False, shuffle_bbox=False, add_noise=False, mix_noise=False):
        """Encode all reactions in `data` (plus optional noise reactions) as
        an (input, target) sequence pair bracketed by SOS/EOS.

        Noise reactions contribute PAD targets except a final NOISE label;
        with `mix_noise`, their input Rxn marker is also replaced by NOISE and
        noise is shuffled in among the real reactions.
        """
        sequence = [self.SOS_ID]
        sequence_out = [self.SOS_ID]
        reactions = copy.deepcopy(data['reactions'])
        reactions_seqs = []
        for reaction in reactions:
            seq = self.reaction_to_sequence(reaction, data, shuffle_bbox=shuffle_bbox)
            reactions_seqs.append([seq, seq])
        noise_seqs = []
        if add_noise:
            total_len = sum(len(s) for s, _ in reactions_seqs)
            while total_len < self.max_len:
                reaction = self.augment_reaction(reactions, data)
                if reaction is None:
                    break
                seq = self.reaction_to_sequence(reaction, data)
                if len(seq) == 0:
                    continue
                # Target supervises only the final NOISE label.
                seq_out = [self.PAD_ID] * (len(seq) - 1) + [self.NOISE_ID]
                if mix_noise:
                    seq[-1] = self.NOISE_ID
                noise_seqs.append([seq, seq_out])
                total_len += len(seq)
        if rand_order:
            random.shuffle(reactions_seqs)
        reactions_seqs += noise_seqs
        if mix_noise:
            random.shuffle(reactions_seqs)
        for seq, seq_out in reactions_seqs:
            sequence += seq
            sequence_out += seq_out
        sequence.append(self.EOS_ID)
        sequence_out.append(self.EOS_ID)
        return sequence, sequence_out

    def sequence_to_data(self, sequence, scores=None, scale=None):
        """Decode a predicted token sequence into a list of reaction dicts.

        Only reactions with at least one reactant and one product are kept;
        each carries a 'label' of Rxn or Noise. `scale` is forwarded to
        sequence_to_bbox (None means no rescaling).
        """
        reactions = []
        i = 0
        cur_reaction = {'reactants': [], 'conditions': [], 'products': []}
        flag = 'reactants'
        if len(sequence) > 0 and sequence[0] == self.SOS_ID:
            i += 1
        while i < len(sequence):
            if sequence[i] == self.EOS_ID:
                break
            if sequence[i] in self.itos:
                if self.itos[sequence[i]] in [Rxn, Noise]:
                    # End of a reaction; keep it only if non-trivial.
                    cur_reaction['label'] = self.itos[sequence[i]]
                    if len(cur_reaction['reactants']) > 0 and len(cur_reaction['products']) > 0:
                        reactions.append(cur_reaction)
                    cur_reaction = {'reactants': [], 'conditions': [], 'products': []}
                    flag = 'reactants'
                elif self.itos[sequence[i]] == Rct:
                    flag = 'conditions'
                elif self.itos[sequence[i]] == Cnd:
                    flag = 'products'
                elif self.itos[sequence[i]] == Prd:
                    flag = None
            elif i + 5 <= len(sequence) and flag is not None:
                bbox = self.sequence_to_bbox(sequence[i:i + 5], scale)
                if bbox is not None:
                    cur_reaction[flag].append(bbox)
                    i += 4
            i += 1
        return reactions

    def sequence_to_tokens(self, sequence):
        """Map token ids back to token strings (coordinate ids stay as ints)."""
        return [self.itos[x] if x in self.itos else x for x in sequence]
|
|
|
|
|
class BboxTokenizer(ReactionTokenizer):
    """Tokenizer for plain bounding-box detection sequences (no reaction
    structure): each box is serialized as four coordinate ids plus one
    category token, with optional random/jittered noise boxes appended."""

    def __init__(self, input_size=100, sep_xy=True, pix2seq=False):
        super(BboxTokenizer, self).__init__(input_size, sep_xy, pix2seq)

    @property
    def max_len(self):
        # Detection sequences may be longer than reaction sequences.
        return 500

    @property
    def output_constraint(self):
        # No state-machine masking during bbox decoding.
        return False

    def random_category(self):
        """Return a random valid bbox category id."""
        return random.choice(list(self.bbox_category_to_token.keys()))

    def random_bbox(self):
        """Sample a uniformly random well-ordered bbox and a random category."""
        _x1, _y1, _x2, _y2 = random.random(), random.random(), random.random(), random.random()
        x1, y1, x2, y2 = min(_x1, _x2), min(_y1, _y2), max(_x1, _x2), max(_y1, _y2)
        category = self.random_category()
        return [x1, y1, x2, y2], category

    def jitter_bbox(self, bbox, ratio=0.2):
        """Perturb each corner of `bbox` by up to `ratio` of its width/height,
        assign a random category, and clip the result to [0, 1]."""
        x1, y1, x2, y2 = bbox
        w, h = x2 - x1, y2 - y1
        _x1 = x1 + random.uniform(-w * ratio, w * ratio)
        _y1 = y1 + random.uniform(-h * ratio, h * ratio)
        _x2 = x2 + random.uniform(-w * ratio, w * ratio)
        _y2 = y2 + random.uniform(-h * ratio, h * ratio)
        x1, y1, x2, y2 = min(_x1, _x2), min(_y1, _y2), max(_x1, _x2), max(_y1, _y2)
        category = self.random_category()
        return np.clip([x1, y1, x2, y2], 0, 1), category

    def augment_box(self, bboxes):
        """Generate one noise box: fully random, or (50% of the time, when
        real boxes exist) a jittered copy of a random existing box."""
        if len(bboxes) == 0:
            return self.random_bbox()
        if random.random() < 0.5:
            return self.random_bbox()
        else:
            return self.jitter_bbox(random.choice(bboxes))

    def data_to_sequence(self, data, add_noise=False, rand_order=False):
        """Encode ground-truth boxes (optionally permuted) as an
        (input, target) sequence pair bracketed by SOS/EOS.

        Noise boxes supervise only a final NOISE label (PAD elsewhere).
        """
        sequence = [self.SOS_ID]
        sequence_out = [self.SOS_ID]
        if rand_order:
            perm = np.random.permutation(len(data['boxes']))
            boxes = data['boxes'][perm].tolist()
            labels = data['labels'][perm].tolist()
        else:
            boxes = data['boxes'].tolist()
            labels = data['labels'].tolist()
        for bbox, category in zip(boxes, labels):
            seq = self.bbox_to_sequence(bbox, category)
            sequence += seq
            sequence_out += seq
        if add_noise:
            while len(sequence) < self.max_len:
                bbox, category = self.augment_box(boxes)
                noise_seq = self.bbox_to_sequence(bbox, category)
                # Fix: a degenerate box encodes to []; the original still
                # appended 5 target tokens, desynchronizing sequence and
                # sequence_out. Skip such boxes to keep them aligned.
                if len(noise_seq) == 0:
                    continue
                sequence += noise_seq
                sequence_out += [self.PAD_ID] * (len(noise_seq) - 1) + [self.NOISE_ID]
        sequence.append(self.EOS_ID)
        sequence_out.append(self.EOS_ID)
        return sequence, sequence_out

    def sequence_to_data(self, sequence, scores=None, scale=None):
        """Decode a predicted token sequence into a list of bbox dicts,
        attaching per-box scores (taken at the category position) if given."""
        if scale is None:
            # Fix: the inherited sequence_to_bbox indexes scale[0]/scale[1];
            # forwarding None crashed decoding without an explicit scale.
            scale = (1, 1)
        bboxes = []
        i = 0
        if len(sequence) > 0 and sequence[0] == self.SOS_ID:
            i += 1
        while i < len(sequence):
            if sequence[i] == self.EOS_ID:
                break
            if i + 5 <= len(sequence):
                bbox = self.sequence_to_bbox(sequence[i:i + 5], scale)
                if bbox is not None:
                    if scores is not None:
                        bbox['score'] = scores[i + 4]
                    bboxes.append(bbox)
                    i += 4
            i += 1
        return bboxes
|
|
|
|
|
def get_tokenizer(args):
    """Build the tokenizer requested by ``args.format``.

    Returns a dict mapping the format name ('reaction' or 'bbox') to its
    tokenizer. NOTE: when ``args.pix2seq`` is set this mutates ``args``
    in place (forces coord_bins=2000 and sep_xy=False), matching the fixed
    Pix2Seq vocabulary layout.
    """
    tokenizer = {}
    if args.pix2seq:
        args.coord_bins = 2000
        args.sep_xy = False
    fmt = args.format  # renamed local: `format` shadowed the builtin
    if fmt == 'reaction':
        tokenizer[fmt] = ReactionTokenizer(args.coord_bins, args.sep_xy, args.pix2seq)
    if fmt == 'bbox':
        tokenizer[fmt] = BboxTokenizer(args.coord_bins, args.sep_xy, args.pix2seq)
    return tokenizer
|
|