import os
import json
import random

import numpy as np
from SmilesPE.pretokenizer import atomwise_tokenizer

PAD = '<pad>'
SOS = '<sos>'
EOS = '<eos>'
UNK = '<unk>'
MASK = '<mask>'

PAD_ID = 0
SOS_ID = 1
EOS_ID = 2
UNK_ID = 3
MASK_ID = 4


class Tokenizer(object):

    def __init__(self, path=None):
        self.stoi = {}
        self.itos = {}
        if path:
            self.load(path)

    def __len__(self):
        return len(self.stoi)

    def output_constraint(self):
        return False

    def save(self, path):
        with open(path, 'w') as f:
            json.dump(self.stoi, f)

    def load(self, path):
        with open(path) as f:
            self.stoi = json.load(f)
        self.itos = {item[1]: item[0] for item in self.stoi.items()}

    def fit_on_texts(self, texts):
        vocab = set()
        for text in texts:
            vocab.update(text.split(' '))
        vocab = [PAD, SOS, EOS, UNK] + list(vocab)
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        self.itos = {item[1]: item[0] for item in self.stoi.items()}
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID

    def text_to_sequence(self, text, tokenized=True):
        sequence = []
        sequence.append(self.stoi['<sos>'])
        if tokenized:
            tokens = text.split(' ')
        else:
            tokens = atomwise_tokenizer(text)
        for s in tokens:
            if s not in self.stoi:
                s = '<unk>'
            sequence.append(self.stoi[s])
        sequence.append(self.stoi['<eos>'])
        return sequence

    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            sequence = self.text_to_sequence(text)
            sequences.append(sequence)
        return sequences

    def sequence_to_text(self, sequence):
        return ''.join(list(map(lambda i: self.itos[i], sequence)))

    def sequences_to_texts(self, sequences):
        texts = []
        for sequence in sequences:
            text = self.sequence_to_text(sequence)
            texts.append(text)
        return texts

    def predict_caption(self, sequence):
        caption = ''
        for i in sequence:
            if i == self.stoi['<eos>'] or i == self.stoi['<pad>']:
                break
            caption += self.itos[i]
        return caption

    def predict_captions(self, sequences):
        captions = []
        for sequence in sequences:
            caption = self.predict_caption(sequence)
            captions.append(caption)
        return captions

    def sequence_to_smiles(self, sequence):
        return {'smiles': self.predict_caption(sequence)}
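
# A minimal usage sketch of Tokenizer (illustrative only; the space-separated
# SMILES strings and the output path below are made-up examples):
#
#   tokenizer = Tokenizer()
#   tokenizer.fit_on_texts(['C C O', 'C 1 = C C = C C = C 1'])
#   seq = tokenizer.text_to_sequence('C C O')      # [SOS_ID, ..., EOS_ID]
#   tokenizer.predict_caption(seq[1:])             # 'CCO' (skip the leading <sos>)
#   tokenizer.save('vocab_example.json')           # hypothetical output path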


class NodeTokenizer(Tokenizer):

    def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
        super().__init__(path)
        self.maxx = input_size  # height
        self.maxy = input_size  # width
        self.sep_xy = sep_xy
        self.special_tokens = [PAD, SOS, EOS, UNK, MASK]
        self.continuous_coords = continuous_coords
        self.debug = debug

    def __len__(self):
        if self.sep_xy:
            return self.offset + self.maxx + self.maxy
        else:
            return self.offset + max(self.maxx, self.maxy)
    @property
    def offset(self):
        return len(self.stoi)
    def output_constraint(self):
        return not self.continuous_coords

    def len_symbols(self):
        return len(self.stoi)

    def fit_atom_symbols(self, atoms):
        vocab = self.special_tokens + list(set(atoms))
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID
        assert self.stoi[MASK] == MASK_ID
        self.itos = {item[1]: item[0] for item in self.stoi.items()}
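
    # Token-id layout assumed by the methods below: ids in [0, offset) are symbol
    # tokens (the special tokens first, then atom/bond symbols); ids in
    # [offset, offset + maxx) are quantized x coordinates; y coordinates either
    # share the same bins (sep_xy=False) or occupy
    # [offset + maxx, offset + maxx + maxy) when sep_xy=True.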
    def is_x(self, x):
        return self.offset <= x < self.offset + self.maxx

    def is_y(self, y):
        if self.sep_xy:
            return self.offset + self.maxx <= y
        return self.offset <= y

    def is_symbol(self, s):
        return len(self.special_tokens) <= s < self.offset or s == UNK_ID

    def is_atom(self, id):
        if self.is_symbol(id):
            return self.is_atom_token(self.itos[id])
        return False

    def is_atom_token(self, token):
        return token.isalpha() or token.startswith("[") or token == '*' or token == UNK

    def x_to_id(self, x):
        return self.offset + round(x * (self.maxx - 1))

    def y_to_id(self, y):
        if self.sep_xy:
            return self.offset + self.maxx + round(y * (self.maxy - 1))
        return self.offset + round(y * (self.maxy - 1))

    def id_to_x(self, id):
        return (id - self.offset) / (self.maxx - 1)

    def id_to_y(self, id):
        if self.sep_xy:
            return (id - self.offset - self.maxx) / (self.maxy - 1)
        return (id - self.offset) / (self.maxy - 1)

    def get_output_mask(self, id):
        mask = [False] * len(self)
        if self.continuous_coords:
            return mask
        if self.is_atom(id):
            return [True] * self.offset + [False] * self.maxx + [True] * self.maxy
        if self.is_x(id):
            return [True] * (self.offset + self.maxx) + [False] * self.maxy
        if self.is_y(id):
            return [False] * self.offset + [True] * (self.maxx + self.maxy)
        return mask

    def symbol_to_id(self, symbol):
        if symbol not in self.stoi:
            return UNK_ID
        return self.stoi[symbol]

    def symbols_to_labels(self, symbols):
        labels = []
        for symbol in symbols:
            labels.append(self.symbol_to_id(symbol))
        return labels

    def labels_to_symbols(self, labels):
        symbols = []
        for label in labels:
            symbols.append(self.itos[label])
        return symbols

    def nodes_to_grid(self, nodes):
        coords, symbols = nodes['coords'], nodes['symbols']
        grid = np.zeros((self.maxx, self.maxy), dtype=int)
        for [x, y], symbol in zip(coords, symbols):
            x = round(x * (self.maxx - 1))
            y = round(y * (self.maxy - 1))
            grid[x][y] = self.symbol_to_id(symbol)
        return grid

    def grid_to_nodes(self, grid):
        coords, symbols, indices = [], [], []
        for i in range(self.maxx):
            for j in range(self.maxy):
                if grid[i][j] != 0:
                    x = i / (self.maxx - 1)
                    y = j / (self.maxy - 1)
                    coords.append([x, y])
                    symbols.append(self.itos[grid[i][j]])
                    indices.append([i, j])
        return {'coords': coords, 'symbols': symbols, 'indices': indices}
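
    # Grid round-trip sketch (illustrative values; assumes fit_atom_symbols has
    # already put 'C' and 'O' in the vocabulary):
    #
    #   nodes = {'coords': [[0.0, 0.0], [0.5, 1.0]], 'symbols': ['C', 'O']}
    #   grid = tokenizer.nodes_to_grid(nodes)   # (maxx, maxy) array of symbol ids
    #   tokenizer.grid_to_nodes(grid)           # recovers the (quantized) coords and symbols
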
    def nodes_to_sequence(self, nodes):
        coords, symbols = nodes['coords'], nodes['symbols']
        labels = [SOS_ID]
        for (x, y), symbol in zip(coords, symbols):
            assert 0 <= x <= 1
            assert 0 <= y <= 1
            labels.append(self.x_to_id(x))
            labels.append(self.y_to_id(y))
            labels.append(self.symbol_to_id(symbol))
        labels.append(EOS_ID)
        return labels

    def sequence_to_nodes(self, sequence):
        coords, symbols = [], []
        i = 0
        if sequence[0] == SOS_ID:
            i += 1
        while i + 2 < len(sequence):
            if sequence[i] == EOS_ID:
                break
            if self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]):
                x = self.id_to_x(sequence[i])
                y = self.id_to_y(sequence[i+1])
                symbol = self.itos[sequence[i+2]]
                coords.append([x, y])
                symbols.append(symbol)
            i += 3
        return {'coords': coords, 'symbols': symbols}

    def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
        tokens = atomwise_tokenizer(smiles)
        labels = [SOS_ID]
        indices = []
        atom_idx = -1
        for token in tokens:
            if atom_only and not self.is_atom_token(token):
                continue
            if token in self.stoi:
                labels.append(self.stoi[token])
            else:
                if self.debug:
                    print(f'{token} not in vocab')
                labels.append(UNK_ID)
            if self.is_atom_token(token):
                atom_idx += 1
                if not self.continuous_coords:
                    if mask_ratio > 0 and random.random() < mask_ratio:
                        labels.append(MASK_ID)
                        labels.append(MASK_ID)
                    elif coords is not None:
                        if atom_idx < len(coords):
                            x, y = coords[atom_idx]
                            assert 0 <= x <= 1
                            assert 0 <= y <= 1
                        else:
                            x = random.random()
                            y = random.random()
                        labels.append(self.x_to_id(x))
                        labels.append(self.y_to_id(y))
                indices.append(len(labels) - 1)
        labels.append(EOS_ID)
        return labels, indices

    def sequence_to_smiles(self, sequence):
        has_coords = not self.continuous_coords
        smiles = ''
        coords, symbols, indices = [], [], []
        for i, label in enumerate(sequence):
            if label == EOS_ID or label == PAD_ID:
                break
            if self.is_x(label) or self.is_y(label):
                continue
            token = self.itos[label]
            smiles += token
            if self.is_atom_token(token):
                if has_coords:
                    if i+3 < len(sequence) and self.is_x(sequence[i+1]) and self.is_y(sequence[i+2]):
                        x = self.id_to_x(sequence[i+1])
                        y = self.id_to_y(sequence[i+2])
                        coords.append([x, y])
                        symbols.append(token)
                        indices.append(i+3)
                else:
                    if i+1 < len(sequence):
                        symbols.append(token)
                        indices.append(i+1)
        results = {'smiles': smiles, 'symbols': symbols, 'indices': indices}
        if has_coords:
            results['coords'] = coords
        return results
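
# A minimal NodeTokenizer usage sketch (illustrative only; the atom symbols and
# coordinates below are invented for the example):
#
#   tokenizer = NodeTokenizer(input_size=64)
#   tokenizer.fit_atom_symbols(['C', 'O', 'N'])
#   labels, indices = tokenizer.smiles_to_sequence(
#       'CCO', coords=[[0.1, 0.2], [0.5, 0.2], [0.9, 0.8]])
#   out = tokenizer.sequence_to_smiles(labels[1:])   # skip the leading SOS_ID
#   # out['smiles'] == 'CCO'; out['coords'] holds the de-quantized positions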


class CharTokenizer(NodeTokenizer):

    def __init__(self, input_size=100, path=None, sep_xy=False, continuous_coords=False, debug=False):
        super().__init__(input_size, path, sep_xy, continuous_coords, debug)

    def fit_on_texts(self, texts):
        vocab = set()
        for text in texts:
            vocab.update(list(text))
        if ' ' in vocab:
            vocab.remove(' ')
        vocab = [PAD, SOS, EOS, UNK] + list(vocab)
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        self.itos = {item[1]: item[0] for item in self.stoi.items()}
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID

    def text_to_sequence(self, text, tokenized=True):
        sequence = []
        sequence.append(self.stoi['<sos>'])
        if tokenized:
            tokens = text.split(' ')
            assert all(len(s) == 1 for s in tokens)
        else:
            tokens = list(text)
        for s in tokens:
            if s not in self.stoi:
                s = '<unk>'
            sequence.append(self.stoi[s])
        sequence.append(self.stoi['<eos>'])
        return sequence
    def fit_atom_symbols(self, atoms):
        atoms = list(set(atoms))
        chars = []
        for atom in atoms:
            chars.extend(list(atom))
        # Deduplicate characters (atoms such as 'C' and 'Cl' share letters) so that
        # each character gets exactly one id and len(self.stoi) stays consistent
        # with the coordinate offset.
        chars = list(dict.fromkeys(chars))
        vocab = self.special_tokens + chars
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        assert self.stoi[PAD] == PAD_ID
        assert self.stoi[SOS] == SOS_ID
        assert self.stoi[EOS] == EOS_ID
        assert self.stoi[UNK] == UNK_ID
        assert self.stoi[MASK] == MASK_ID
        self.itos = {item[1]: item[0] for item in self.stoi.items()}
    def get_output_mask(self, id):
        ''' TO FIX '''
        mask = [False] * len(self)
        if self.continuous_coords:
            return mask
        if self.is_x(id):
            return [True] * (self.offset + self.maxx) + [False] * self.maxy
        if self.is_y(id):
            return [False] * self.offset + [True] * (self.maxx + self.maxy)
        return mask

    def nodes_to_sequence(self, nodes):
        coords, symbols = nodes['coords'], nodes['symbols']
        labels = [SOS_ID]
        for (x, y), symbol in zip(coords, symbols):
            assert 0 <= x <= 1
            assert 0 <= y <= 1
            labels.append(self.x_to_id(x))
            labels.append(self.y_to_id(y))
            for char in symbol:
                labels.append(self.symbol_to_id(char))
        labels.append(EOS_ID)
        return labels
    def sequence_to_nodes(self, sequence):
        coords, symbols = [], []
        i = 0
        if sequence[0] == SOS_ID:
            i += 1
        while i < len(sequence):
            if sequence[i] == EOS_ID:
                break
            if i+2 < len(sequence) and self.is_x(sequence[i]) and self.is_y(sequence[i+1]) and self.is_symbol(sequence[i+2]):
                x = self.id_to_x(sequence[i])
                y = self.id_to_y(sequence[i+1])
                for j in range(i+2, len(sequence)):
                    if not self.is_symbol(sequence[j]):
                        break
                symbol = ''.join(self.itos[sequence[k]] for k in range(i+2, j))
                coords.append([x, y])
                symbols.append(symbol)
                i = j
            else:
                i += 1
        return {'coords': coords, 'symbols': symbols}
    def smiles_to_sequence(self, smiles, coords=None, mask_ratio=0, atom_only=False):
        tokens = atomwise_tokenizer(smiles)
        labels = [SOS_ID]
        indices = []
        atom_idx = -1
        for token in tokens:
            if atom_only and not self.is_atom_token(token):
                continue
            for c in token:
                if c in self.stoi:
                    labels.append(self.stoi[c])
                else:
                    if self.debug:
                        print(f'{c} not in vocab')
                    labels.append(UNK_ID)
            if self.is_atom_token(token):
                atom_idx += 1
                if not self.continuous_coords:
                    if mask_ratio > 0 and random.random() < mask_ratio:
                        labels.append(MASK_ID)
                        labels.append(MASK_ID)
                    elif coords is not None:
                        if atom_idx < len(coords):
                            x, y = coords[atom_idx]
                            assert 0 <= x <= 1
                            assert 0 <= y <= 1
                        else:
                            x = random.random()
                            y = random.random()
                        labels.append(self.x_to_id(x))
                        labels.append(self.y_to_id(y))
                indices.append(len(labels) - 1)
        labels.append(EOS_ID)
        return labels, indices

    def sequence_to_smiles(self, sequence):
        has_coords = not self.continuous_coords
        smiles = ''
        coords, symbols, indices = [], [], []
        i = 0
        while i < len(sequence):
            label = sequence[i]
            if label == EOS_ID or label == PAD_ID:
                break
            if self.is_x(label) or self.is_y(label):
                i += 1
                continue
            if not self.is_atom(label):
                smiles += self.itos[label]
                i += 1
                continue
            if self.itos[label] == '[':
                j = i + 1
                while j < len(sequence):
                    if not self.is_symbol(sequence[j]):
                        break
                    if self.itos[sequence[j]] == ']':
                        j += 1
                        break
                    j += 1
            else:
                if i+1 < len(sequence) and (self.itos[label] == 'C' and self.is_symbol(sequence[i+1]) and self.itos[sequence[i+1]] == 'l'
                                            or self.itos[label] == 'B' and self.is_symbol(sequence[i+1]) and self.itos[sequence[i+1]] == 'r'):
                    j = i+2
                else:
                    j = i+1
            token = ''.join(self.itos[sequence[k]] for k in range(i, j))
            smiles += token
            if has_coords:
                if j+2 < len(sequence) and self.is_x(sequence[j]) and self.is_y(sequence[j+1]):
                    x = self.id_to_x(sequence[j])
                    y = self.id_to_y(sequence[j+1])
                    coords.append([x, y])
                    symbols.append(token)
                    indices.append(j+2)
                    i = j+2
                else:
                    i = j
            else:
                if j < len(sequence):
                    symbols.append(token)
                    indices.append(j)
                i = j
        results = {'smiles': smiles, 'symbols': symbols, 'indices': indices}
        if has_coords:
            results['coords'] = coords
        return results
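
# CharTokenizer encodes SMILES character by character: multi-character atoms such
# as 'Cl', 'Br', and bracket atoms like '[nH]' are split into characters by
# smiles_to_sequence and reassembled in sequence_to_smiles. A minimal usage
# sketch (illustrative values only):
#
#   tokenizer = CharTokenizer(input_size=64)
#   tokenizer.fit_atom_symbols(['C', 'Cl', 'O'])
#   labels, _ = tokenizer.smiles_to_sequence(
#       'ClCO', coords=[[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
#   out = tokenizer.sequence_to_smiles(labels[1:])   # out['smiles'] == 'ClCO'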


def get_tokenizer(args):
    tokenizer = {}
    for format_ in args.formats:
        if format_ == 'atomtok':
            if args.vocab_file is None:
                args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json')
            tokenizer['atomtok'] = Tokenizer(args.vocab_file)
        elif format_ == "atomtok_coords":
            if args.vocab_file is None:
                args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_uspto.json')
            tokenizer["atomtok_coords"] = NodeTokenizer(args.coord_bins, args.vocab_file, args.sep_xy,
                                                        continuous_coords=args.continuous_coords)
        elif format_ == "chartok_coords":
            if args.vocab_file is None:
                args.vocab_file = os.path.join(os.path.dirname(__file__), 'vocab/vocab_chars.json')
            tokenizer["chartok_coords"] = CharTokenizer(args.coord_bins, args.vocab_file, args.sep_xy,
                                                        continuous_coords=args.continuous_coords)
    return tokenizer
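

# A usage sketch for get_tokenizer (illustrative only): `args` is normally the
# parsed command-line namespace; the attribute values below are hypothetical, and
# loading assumes the bundled vocab files exist next to this module.
#
#   from argparse import Namespace
#   args = Namespace(formats=['atomtok_coords'], vocab_file=None,
#                    coord_bins=64, sep_xy=True, continuous_coords=False)
#   tokenizers = get_tokenizer(args)
#   node_tokenizer = tokenizers['atomtok_coords']   # a NodeTokenizer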