import os
import sys
from typing import List, Tuple

import numpy as np
import tensorflow.compat.v1 as tf

from dataloader_iam import Batch

# Run in TF1 compatibility mode: build a static graph, no eager execution.
tf.compat.v1.disable_v2_behavior()
tf.compat.v1.disable_eager_execution()


class DecoderType:
    """
    CTC decoder types.
    """
    BestPath = 0
    BeamSearch = 1
    WordBeamSearch = 2


class Model:
    """
    Minimalistic TF model for HTR.
    """

    def __init__(self,
                 char_list: List[str],
                 model_dir: str,
                 decoder_type: int = DecoderType.BestPath,
                 must_restore: bool = False,
                 dump: bool = False) -> None:
        """
        Init model: add CNN, RNN and CTC and initialize TF.
        """
        self.dump = dump
        self.char_list = char_list
        self.decoder_type = decoder_type
        self.must_restore = must_restore
        self.snap_ID = 0
        self.model_dir = model_dir
        tf.compat.v1.disable_eager_execution()

        # whether to use normalization over a batch or a population
        self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train')

        # input image batch
        self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, None, None))

        # set up CNN, RNN and CTC
        self.setup_cnn()
        self.setup_rnn()
        self.setup_ctc()

        # set up optimizer to train NN
        self.batches_trained = 0
        self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.update_ops):
            self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss)

        # initialize TF
        self.sess, self.saver = self.setup_tf()

    def setup_cnn(self) -> None:
        """
        Create CNN layers.
        """
        cnn_in4d = tf.expand_dims(input=self.input_imgs, axis=3)

        # list of parameters for the layers
        kernel_vals = [5, 5, 3, 3, 3]
        feature_vals = [1, 32, 64, 128, 128, 256]
        stride_vals = pool_vals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)]
        num_layers = len(stride_vals)

        # create layers
        pool = cnn_in4d  # input to first CNN layer
        for i in range(num_layers):
            kernel = tf.Variable(
                tf.random.truncated_normal([kernel_vals[i], kernel_vals[i], feature_vals[i], feature_vals[i + 1]],
                                           stddev=0.1))
            conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1))
            conv_norm = tf.keras.layers.BatchNormalization()(conv, training=self.is_train)
            relu = tf.nn.relu(conv_norm)
            pool = tf.nn.max_pool2d(input=relu, ksize=(1, pool_vals[i][0], pool_vals[i][1], 1),
                                    strides=(1, stride_vals[i][0], stride_vals[i][1], 1), padding='VALID')

        self.cnn_out_4d = pool
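
    # Shape walk-through (assuming 128x32 width x height images as in the original
    # SimpleHTR dataloader; the placeholder itself is shape-agnostic). With pooling
    # steps (2,2), (2,2), (1,2), (1,2), (1,2):
    #   B x 128 x 32 x 1 -> B x 64 x 16 x 32 -> B x 32 x 8 x 64
    #   -> B x 32 x 4 x 128 -> B x 32 x 2 x 128 -> B x 32 x 1 x 256
    # The width (time axis) shrinks by a factor of 4, which is why train_batch and
    # infer_batch use shape[0] // 4 as the CTC sequence length.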

    def setup_rnn(self) -> None:
        """
        Create RNN layers.
        """
        rnn_in3d = tf.squeeze(self.cnn_out_4d, axis=[2])

        # basic cells which are used to build the RNN
        num_hidden = 256
        cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=num_hidden, state_is_tuple=True) for _ in
                 range(2)]  # 2 layers

        # stack basic cells
        stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

        # bidirectional RNN
        # BxTxF -> BxTx2H
        (fw, bw), _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnn_in3d,
                                                                dtype=rnn_in3d.dtype)

        # BxTxH + BxTxH -> BxTx2H -> BxTx1x2H
        concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)

        # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
        kernel = tf.Variable(tf.random.truncated_normal([1, 1, num_hidden * 2, len(self.char_list) + 1], stddev=0.1))
        self.rnn_out_3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'),
                                     axis=[2])
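
    # Note: atrous_conv2d with rate=1 degenerates to a plain convolution, so the
    # projection above is effectively a 1x1 conv mapping 2H = 512 features to
    # C = len(char_list) + 1 classes (the +1 is the CTC blank label).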

    def setup_ctc(self) -> None:
        """
        Create CTC loss and decoder.
        """
        # BxTxC -> TxBxC
        self.ctc_in_3d_tbc = tf.transpose(a=self.rnn_out_3d, perm=[1, 0, 2])

        # ground truth text as sparse tensor
        self.gt_texts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]),
                                        tf.compat.v1.placeholder(tf.int32, [None]),
                                        tf.compat.v1.placeholder(tf.int64, [2]))

        # calc loss for batch
        self.seq_len = tf.compat.v1.placeholder(tf.int32, [None])
        self.loss = tf.reduce_mean(
            input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.ctc_in_3d_tbc,
                                                  sequence_length=self.seq_len,
                                                  ctc_merge_repeated=True))

        # calc loss for each element to compute label probability
        self.saved_ctc_input = tf.compat.v1.placeholder(tf.float32,
                                                        shape=[None, None, len(self.char_list) + 1])
        self.loss_per_element = tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.saved_ctc_input,
                                                         sequence_length=self.seq_len, ctc_merge_repeated=True)

        # best path decoding or beam search decoding
        if self.decoder_type == DecoderType.BestPath:
            self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len)
        elif self.decoder_type == DecoderType.BeamSearch:
            self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len,
                                                         beam_width=50)
        # word beam search decoding (see https://github.com/githubharald/CTCWordBeamSearch)
        elif self.decoder_type == DecoderType.WordBeamSearch:
            # prepare information about language (dictionary, characters in dataset, characters forming words)
            chars = ''.join(self.char_list)
            word_chars = open('../model/wordCharList.txt').read().splitlines()[0]
            corpus = open('../data/corpus.txt').read()

            # decode using the "Words" mode of word beam search
            from word_beam_search import WordBeamSearch
            self.decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'),
                                          word_chars.encode('utf8'))

            # the input to the decoder must have softmax already applied
            self.wbs_input = tf.nn.softmax(self.ctc_in_3d_tbc, axis=2)
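
    # The TF decoders return a tuple whose first element is a list of SparseTensors
    # (decoded paths), while WordBeamSearch is driven manually in infer_batch with
    # softmaxed inputs; decoder_output_to_text below handles both output shapes.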

    def setup_tf(self) -> Tuple[tf.compat.v1.Session, tf.compat.v1.train.Saver]:
        """
        Initialize TF.
        """
        print('Python: ' + sys.version)
        print('Tensorflow: ' + tf.__version__)

        sess = tf.compat.v1.Session()  # TF session

        saver = tf.compat.v1.train.Saver(max_to_keep=1)  # saver saves model to file
        latest_snapshot = tf.train.latest_checkpoint(self.model_dir)  # is there a saved model?

        # if the model must be restored (for inference), there must be a snapshot
        if self.must_restore and not latest_snapshot:
            raise Exception('No saved model found in: ' + self.model_dir)

        # load saved model if available
        if latest_snapshot:
            print('Init with stored values from ' + latest_snapshot)
            saver.restore(sess, latest_snapshot)
        else:
            print('Init with new values')
            sess.run(tf.compat.v1.global_variables_initializer())

        return sess, saver

    def to_sparse(self, texts: List[str]) -> Tuple[List[List[int]], List[int], List[int]]:
        """
        Put ground truth texts into sparse tensor for ctc_loss.
        """
        indices = []
        values = []
        shape = [len(texts), 0]  # last entry must be the max label-string length over the batch

        # go over all texts
        for batch_element, text in enumerate(texts):
            # convert to a list of labels (i.e. class-ids)
            label_str = [self.char_list.index(c) for c in text]
            # sparse tensor must have size of max. label-string
            if len(label_str) > shape[1]:
                shape[1] = len(label_str)
            # put each label into sparse tensor
            for i, label in enumerate(label_str):
                indices.append([batch_element, i])
                values.append(label)

        return indices, values, shape
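
    # Example (hypothetical char_list ['a', 'b', 'c']): to_sparse(['ab', 'b']) returns
    # indices [[0, 0], [0, 1], [1, 0]], values [0, 1, 1] and shape [2, 2], i.e. the
    # (batch, position) -> label mapping expected by tf.SparseTensor.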

    def decoder_output_to_text(self, ctc_output: tuple, batch_size: int) -> List[str]:
        """
        Extract texts from output of CTC decoder.
        """
        # word beam search: already contains label strings
        if self.decoder_type == DecoderType.WordBeamSearch:
            label_strs = ctc_output

        # TF decoders: label strings are contained in sparse tensor
        else:
            # ctc returns tuple, first element is SparseTensor
            decoded = ctc_output[0][0]

            # contains string of labels for each batch element
            label_strs = [[] for _ in range(batch_size)]

            # go over all indices and save mapping: batch -> values
            for (idx, idx2d) in enumerate(decoded.indices):
                label = decoded.values[idx]
                batch_element = idx2d[0]  # index according to [b,t]
                label_strs[batch_element].append(label)

        # map labels to chars for all batch elements
        return [''.join([self.char_list[c] for c in label_str]) for label_str in label_strs]
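
    # This is the inverse of to_sparse: with the hypothetical char_list ['a', 'b', 'c'],
    # a decoded SparseTensor with indices [[0, 0], [0, 1], [1, 0]] and values [0, 1, 1]
    # maps back to the texts ['ab', 'b'].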

    def train_batch(self, batch: Batch) -> float:
        """
        Feed a batch into the NN to train it.
        """
        num_batch_elements = len(batch.imgs)
        # sequence length depends on input image size (model downsizes width by 4)
        max_text_len = batch.imgs[0].shape[0] // 4
        sparse = self.to_sparse(batch.gt_texts)
        eval_list = [self.optimizer, self.loss]
        feed_dict = {self.input_imgs: batch.imgs, self.gt_texts: sparse,
                     self.seq_len: [max_text_len] * num_batch_elements, self.is_train: True}
        _, loss_val = self.sess.run(eval_list, feed_dict)
        self.batches_trained += 1
        return loss_val
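
    # Typical use (hypothetical loop; Batch objects come from dataloader_iam):
    #   for batch in loader:
    #       loss = model.train_batch(batch)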

    def dump_nn_output(self, rnn_output: np.ndarray) -> None:
        """
        Dump the output of the NN to CSV file(s).
        """
        dump_dir = '../dump/'
        if not os.path.isdir(dump_dir):
            os.mkdir(dump_dir)

        # iterate over all batch elements and create a CSV file for each one
        max_t, max_b, max_c = rnn_output.shape
        for b in range(max_b):
            csv = ''
            for t in range(max_t):
                for c in range(max_c):
                    csv += str(rnn_output[t, b, c]) + ';'
                csv += '\n'
            fn = dump_dir + 'rnnOutput_' + str(b) + '.csv'
            print('Write dump of NN to file: ' + fn)
            with open(fn, 'w') as f:
                f.write(csv)

    def infer_batch(self, batch: Batch, calc_probability: bool = False, probability_of_gt: bool = False):
        """
        Feed a batch into the NN to recognize the texts.
        """
        # decode, optionally save RNN output
        num_batch_elements = len(batch.imgs)

        # put tensors to be evaluated into list
        eval_list = []

        if self.decoder_type == DecoderType.WordBeamSearch:
            eval_list.append(self.wbs_input)
        else:
            eval_list.append(self.decoder)

        if self.dump or calc_probability:
            eval_list.append(self.ctc_in_3d_tbc)

        # sequence length depends on input image size (model downsizes width by 4)
        max_text_len = batch.imgs[0].shape[0] // 4

        # dict containing all tensors fed into the model
        feed_dict = {self.input_imgs: batch.imgs, self.seq_len: [max_text_len] * num_batch_elements,
                     self.is_train: False}

        # evaluate model
        eval_res = self.sess.run(eval_list, feed_dict)

        # TF decoders: decoding already done in TF graph
        if self.decoder_type != DecoderType.WordBeamSearch:
            decoded = eval_res[0]
        # word beam search decoder: decoding is done in C++ function compute()
        else:
            decoded = self.decoder.compute(eval_res[0])

        # map labels (numbers) to character string
        texts = self.decoder_output_to_text(decoded, num_batch_elements)

        # feed RNN output and recognized text into CTC loss to compute labeling probability
        probs = None
        if calc_probability:
            sparse = self.to_sparse(batch.gt_texts) if probability_of_gt else self.to_sparse(texts)
            ctc_input = eval_res[1]
            eval_list = self.loss_per_element
            feed_dict = {self.saved_ctc_input: ctc_input, self.gt_texts: sparse,
                         self.seq_len: [max_text_len] * num_batch_elements, self.is_train: False}
            loss_vals = self.sess.run(eval_list, feed_dict)
            probs = np.exp(-loss_vals)  # ctc_loss returns -log(p), so p = exp(-loss)

        # dump the output of the NN to CSV file(s)
        if self.dump:
            self.dump_nn_output(eval_res[1])

        return texts, probs

    def save(self) -> None:
        """
        Save model to file.
        """
        self.snap_ID += 1
        self.saver.save(self.sess, '../model/snapshot', global_step=self.snap_ID)
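

# Minimal usage sketch, not part of the original module: the real training/inference
# entry point lives elsewhere in this repo. The character set below is a hypothetical
# stand-in for the one normally read from ../model/charList.txt.
if __name__ == '__main__':
    chars = list('abcdefghijklmnopqrstuvwxyz ')  # hypothetical character set
    model = Model(char_list=chars, model_dir='../model/', decoder_type=DecoderType.BestPath)
    print('Graph built with', len(chars) + 1, 'CTC classes (incl. blank)')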