d-delaurier commited on
Commit
2b7e528
·
1 Parent(s): 2ee6fc9

Create tensorflow_train.py

Browse files
Files changed (1) hide show
  1. tensorflow_train.py +448 -0
tensorflow_train.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+
8
+ os.environ["USE_TF"] = "1"
9
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
10
+
11
+ import datetime
12
+ import hashlib
13
+ import multiprocessing as mp
14
+ import time
15
+ from pathlib import Path
16
+
17
+ import numpy as np
18
+ import tensorflow as tf
19
+ from tensorflow.keras import mixed_precision
20
+ from tqdm.auto import tqdm
21
+
22
+ from doctr.models import login_to_hub, push_to_hf_hub
23
+
24
+ gpu_devices = tf.config.experimental.list_physical_devices("GPU")
25
+ if any(gpu_devices):
26
+ tf.config.experimental.set_memory_growth(gpu_devices[0], True)
27
+
28
+ from doctr import transforms as T
29
+ from doctr.datasets import VOCABS, DataLoader, RecognitionDataset, WordGenerator
30
+ from doctr.models import recognition
31
+ from doctr.utils.metrics import TextMatch
32
+ from utils import EarlyStopper, plot_recorder, plot_samples
33
+
34
+
35
+ def record_lr(
36
+ model: tf.keras.Model,
37
+ train_loader: DataLoader,
38
+ batch_transforms,
39
+ optimizer,
40
+ start_lr: float = 1e-7,
41
+ end_lr: float = 1,
42
+ num_it: int = 100,
43
+ amp: bool = False,
44
+ ):
45
+ """Gridsearch the optimal learning rate for the training.
46
+ Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py
47
+ """
48
+ if num_it > len(train_loader):
49
+ raise ValueError("the value of `num_it` needs to be lower than the number of available batches")
50
+
51
+ # Update param groups & LR
52
+ gamma = (end_lr / start_lr) ** (1 / (num_it - 1))
53
+ optimizer.learning_rate = start_lr
54
+
55
+ lr_recorder = [start_lr * gamma**idx for idx in range(num_it)]
56
+ loss_recorder = []
57
+
58
+ for batch_idx, (images, targets) in enumerate(train_loader):
59
+ images = batch_transforms(images)
60
+
61
+ # Forward, Backward & update
62
+ with tf.GradientTape() as tape:
63
+ train_loss = model(images, targets, training=True)["loss"]
64
+ grads = tape.gradient(train_loss, model.trainable_weights)
65
+
66
+ if amp:
67
+ grads = optimizer.get_unscaled_gradients(grads)
68
+ optimizer.apply_gradients(zip(grads, model.trainable_weights))
69
+
70
+ optimizer.learning_rate = optimizer.learning_rate * gamma
71
+
72
+ # Record
73
+ train_loss = train_loss.numpy()
74
+ if np.any(np.isnan(train_loss)):
75
+ if batch_idx == 0:
76
+ raise ValueError("loss value is NaN or inf.")
77
+ else:
78
+ break
79
+ loss_recorder.append(train_loss.mean())
80
+ # Stop after the number of iterations
81
+ if batch_idx + 1 == num_it:
82
+ break
83
+
84
+ return lr_recorder[: len(loss_recorder)], loss_recorder
85
+
86
+
87
+ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False):
88
+ train_iter = iter(train_loader)
89
+ # Iterate over the batches of the dataset
90
+ pbar = tqdm(train_iter, position=1)
91
+ for images, targets in pbar:
92
+ images = batch_transforms(images)
93
+
94
+ with tf.GradientTape() as tape:
95
+ train_loss = model(images, targets, training=True)["loss"]
96
+ grads = tape.gradient(train_loss, model.trainable_weights)
97
+ if amp:
98
+ grads = optimizer.get_unscaled_gradients(grads)
99
+ optimizer.apply_gradients(zip(grads, model.trainable_weights))
100
+
101
+ pbar.set_description(f"Training loss: {train_loss.numpy().mean():.6}")
102
+
103
+
104
+ def evaluate(model, val_loader, batch_transforms, val_metric):
105
+ # Reset val metric
106
+ val_metric.reset()
107
+ # Validation loop
108
+ val_loss, batch_cnt = 0, 0
109
+ val_iter = iter(val_loader)
110
+ for images, targets in tqdm(val_iter):
111
+ images = batch_transforms(images)
112
+ out = model(images, targets, return_preds=True, training=False)
113
+ # Compute metric
114
+ if len(out["preds"]):
115
+ words, _ = zip(*out["preds"])
116
+ else:
117
+ words = []
118
+ val_metric.update(targets, words)
119
+
120
+ val_loss += out["loss"].numpy().mean()
121
+ batch_cnt += 1
122
+
123
+ val_loss /= batch_cnt
124
+ result = val_metric.summary()
125
+ return val_loss, result["raw"], result["unicase"]
126
+
127
+
128
+ def main(args):
129
+ print(args)
130
+
131
+ if args.push_to_hub:
132
+ login_to_hub()
133
+
134
+ if not isinstance(args.workers, int):
135
+ args.workers = min(16, mp.cpu_count())
136
+
137
+ vocab = VOCABS[args.vocab]
138
+ fonts = args.font.split(",")
139
+
140
+ # AMP
141
+ if args.amp:
142
+ mixed_precision.set_global_policy("mixed_float16")
143
+
144
+ st = time.time()
145
+
146
+ if isinstance(args.val_path, str):
147
+ with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
148
+ val_hash = hashlib.sha256(f.read()).hexdigest()
149
+
150
+ # Load val data generator
151
+ val_set = RecognitionDataset(
152
+ img_folder=os.path.join(args.val_path, "images"),
153
+ labels_path=os.path.join(args.val_path, "labels.json"),
154
+ img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
155
+ )
156
+ else:
157
+ val_hash = None
158
+ # Load synthetic data generator
159
+ val_set = WordGenerator(
160
+ vocab=vocab,
161
+ min_chars=args.min_chars,
162
+ max_chars=args.max_chars,
163
+ num_samples=args.val_samples * len(vocab),
164
+ font_family=fonts,
165
+ img_transforms=T.Compose([
166
+ T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
167
+ # Ensure we have a 90% split of white-background images
168
+ T.RandomApply(T.ColorInversion(), 0.9),
169
+ ]),
170
+ )
171
+
172
+ val_loader = DataLoader(
173
+ val_set,
174
+ batch_size=args.batch_size,
175
+ shuffle=False,
176
+ drop_last=False,
177
+ num_workers=args.workers,
178
+ )
179
+ print(
180
+ f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
181
+ f"{val_loader.num_batches} batches)"
182
+ )
183
+
184
+ # Load doctr model
185
+ model = recognition.__dict__[args.arch](
186
+ pretrained=args.pretrained,
187
+ input_shape=(args.input_size, 4 * args.input_size, 3),
188
+ vocab=vocab,
189
+ )
190
+ # Resume weights
191
+ if isinstance(args.resume, str):
192
+ model.load_weights(args.resume)
193
+
194
+ # Metrics
195
+ val_metric = TextMatch()
196
+
197
+ batch_transforms = T.Compose([
198
+ T.Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)),
199
+ ])
200
+
201
+ if args.test_only:
202
+ print("Running evaluation")
203
+ val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric)
204
+ print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
205
+ return
206
+
207
+ st = time.time()
208
+
209
+ if isinstance(args.train_path, str):
210
+ # Load train data generator
211
+ base_path = Path(args.train_path)
212
+ parts = (
213
+ [base_path]
214
+ if base_path.joinpath("labels.json").is_file()
215
+ else [base_path.joinpath(sub) for sub in os.listdir(base_path)]
216
+ )
217
+ with open(parts[0].joinpath("labels.json"), "rb") as f:
218
+ train_hash = hashlib.sha256(f.read()).hexdigest()
219
+
220
+ train_set = RecognitionDataset(
221
+ parts[0].joinpath("images"),
222
+ parts[0].joinpath("labels.json"),
223
+ img_transforms=T.Compose([
224
+ T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
225
+ # Augmentations
226
+ T.RandomApply(T.ColorInversion(), 0.1),
227
+ T.RandomApply(T.ToGray(num_output_channels=3), 0.1),
228
+ T.RandomJpegQuality(60),
229
+ T.RandomSaturation(0.3),
230
+ T.RandomContrast(0.3),
231
+ T.RandomBrightness(0.3),
232
+ T.RandomApply(T.RandomShadow(), 0.4),
233
+ T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
234
+ T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.3),
235
+ ]),
236
+ )
237
+ if len(parts) > 1:
238
+ for subfolder in parts[1:]:
239
+ train_set.merge_dataset(
240
+ RecognitionDataset(subfolder.joinpath("images"), subfolder.joinpath("labels.json"))
241
+ )
242
+ else:
243
+ train_hash = None
244
+ # Load synthetic data generator
245
+ train_set = WordGenerator(
246
+ vocab=vocab,
247
+ min_chars=args.min_chars,
248
+ max_chars=args.max_chars,
249
+ num_samples=args.train_samples * len(vocab),
250
+ font_family=fonts,
251
+ img_transforms=T.Compose([
252
+ T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
253
+ # Ensure we have a 90% split of white-background images
254
+ T.RandomApply(T.ColorInversion(), 0.9),
255
+ T.RandomApply(T.ToGray(num_output_channels=3), 0.1),
256
+ T.RandomJpegQuality(60),
257
+ T.RandomSaturation(0.3),
258
+ T.RandomContrast(0.3),
259
+ T.RandomBrightness(0.3),
260
+ T.RandomApply(T.RandomShadow(), 0.4),
261
+ T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
262
+ T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.3),
263
+ ]),
264
+ )
265
+
266
+ train_loader = DataLoader(
267
+ train_set,
268
+ batch_size=args.batch_size,
269
+ shuffle=True,
270
+ drop_last=True,
271
+ num_workers=args.workers,
272
+ )
273
+ print(
274
+ f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
275
+ f"{train_loader.num_batches} batches)"
276
+ )
277
+
278
+ if args.show_samples:
279
+ x, target = next(iter(train_loader))
280
+ plot_samples(x, target)
281
+ return
282
+
283
+ # Optimizer
284
+ scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
285
+ args.lr,
286
+ decay_steps=args.epochs * len(train_loader),
287
+ decay_rate=1 / (25e4), # final lr as a fraction of initial lr
288
+ staircase=False,
289
+ name="ExponentialDecay",
290
+ )
291
+ optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5)
292
+ if args.amp:
293
+ optimizer = mixed_precision.LossScaleOptimizer(optimizer)
294
+ # LR Finder
295
+ if args.find_lr:
296
+ lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp)
297
+ plot_recorder(lrs, losses)
298
+ return
299
+
300
+ # Tensorboard to monitor training
301
+ current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
302
+ exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name
303
+
304
+ config = {
305
+ "learning_rate": args.lr,
306
+ "epochs": args.epochs,
307
+ "batch_size": args.batch_size,
308
+ "architecture": args.arch,
309
+ "input_size": args.input_size,
310
+ "optimizer": optimizer.name,
311
+ "framework": "tensorflow",
312
+ "scheduler": scheduler.name,
313
+ "vocab": args.vocab,
314
+ "train_hash": train_hash,
315
+ "val_hash": val_hash,
316
+ "pretrained": args.pretrained,
317
+ }
318
+
319
+ # W&B
320
+ if args.wb:
321
+ import wandb
322
+
323
+ run = wandb.init(
324
+ name=exp_name,
325
+ project="text-recognition",
326
+ config=config,
327
+ )
328
+
329
+ # ClearML
330
+ if args.clearml:
331
+ from clearml import Task
332
+
333
+ task = Task.init(project_name="docTR/text-recognition", task_name=exp_name, reuse_last_task_id=False)
334
+ task.upload_artifact("config", config)
335
+
336
+ # Backbone freezing
337
+ if args.freeze_backbone:
338
+ for layer in model.feat_extractor.layers:
339
+ layer.trainable = False
340
+
341
+ min_loss = np.inf
342
+ if args.early_stop:
343
+ early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta)
344
+ # Training loop
345
+ for epoch in range(args.epochs):
346
+ fit_one_epoch(model, train_loader, batch_transforms, optimizer, args.amp)
347
+
348
+ # Validation loop at the end of each epoch
349
+ val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric)
350
+ if val_loss < min_loss:
351
+ print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
352
+ model.save_weights(f"./{exp_name}/weights")
353
+ min_loss = val_loss
354
+ print(
355
+ f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
356
+ f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})"
357
+ )
358
+ # W&B
359
+ if args.wb:
360
+ wandb.log({
361
+ "val_loss": val_loss,
362
+ "exact_match": exact_match,
363
+ "partial_match": partial_match,
364
+ })
365
+
366
+ # ClearML
367
+ if args.clearml:
368
+ from clearml import Logger
369
+
370
+ logger = Logger.current_logger()
371
+ logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch)
372
+ logger.report_scalar(title="Exact Match", series="exact_match", value=exact_match, iteration=epoch)
373
+ logger.report_scalar(title="Partial Match", series="partial_match", value=partial_match, iteration=epoch)
374
+ if args.early_stop and early_stopper.early_stop(val_loss):
375
+ print("Training halted early due to reaching patience limit.")
376
+ break
377
+ if args.wb:
378
+ run.finish()
379
+
380
+ if args.push_to_hub:
381
+ push_to_hf_hub(model, exp_name, task="recognition", run_config=args)
382
+
383
+
384
+ def parse_args():
385
+ import argparse
386
+
387
+ parser = argparse.ArgumentParser(
388
+ description="DocTR training script for text recognition (TensorFlow)",
389
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
390
+ )
391
+
392
+ parser.add_argument("arch", type=str, help="text-recognition model to train")
393
+ parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)")
394
+ parser.add_argument("--val_path", type=str, default=None, help="path to val data folder")
395
+ parser.add_argument(
396
+ "--train-samples",
397
+ type=int,
398
+ default=1000,
399
+ help="Multiplied by the vocab length gets you the number of synthetic training samples that will be used.",
400
+ )
401
+ parser.add_argument(
402
+ "--val-samples",
403
+ type=int,
404
+ default=20,
405
+ help="Multiplied by the vocab length gets you the number of synthetic validation samples that will be used.",
406
+ )
407
+ parser.add_argument(
408
+ "--font", type=str, default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf", help="Font family to be used"
409
+ )
410
+ parser.add_argument("--min-chars", type=int, default=1, help="Minimum number of characters per synthetic sample")
411
+ parser.add_argument("--max-chars", type=int, default=12, help="Maximum number of characters per synthetic sample")
412
+ parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
413
+ parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
414
+ parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size for training")
415
+ parser.add_argument("--input_size", type=int, default=32, help="input size H for the model, W = 4*H")
416
+ parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam)")
417
+ parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading")
418
+ parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
419
+ parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
420
+ parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
421
+ parser.add_argument(
422
+ "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning"
423
+ )
424
+ parser.add_argument(
425
+ "--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
426
+ )
427
+ parser.add_argument("--wb", dest="wb", action="store_true", help="Log to Weights & Biases")
428
+ parser.add_argument("--clearml", dest="clearml", action="store_true", help="Log to ClearML")
429
+ parser.add_argument("--push-to-hub", dest="push_to_hub", action="store_true", help="Push to Huggingface Hub")
430
+ parser.add_argument(
431
+ "--pretrained",
432
+ dest="pretrained",
433
+ action="store_true",
434
+ help="Load pretrained parameters before starting the training",
435
+ )
436
+ parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
437
+ parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR")
438
+ parser.add_argument("--early-stop", action="store_true", help="Enable early stopping")
439
+ parser.add_argument("--early-stop-epochs", type=int, default=5, help="Patience for early stopping")
440
+ parser.add_argument("--early-stop-delta", type=float, default=0.01, help="Minimum Delta for early stopping")
441
+ args = parser.parse_args()
442
+
443
+ return args
444
+
445
+
446
+ if __name__ == "__main__":
447
+ args = parse_args()
448
+ main(args)