File size: 1,683 Bytes
a711240
 
 
 
3c0c5aa
7200298
a711240
7200298
a711240
3c0c5aa
 
 
a711240
 
79fd7d0
3c0c5aa
 
7200298
 
 
eeb74de
7200298
 
 
 
 
a711240
7200298
 
 
 
 
 
 
a711240
5ad6755
7200298
a711240
9c03436
7200298
 
9c03436
79fd7d0
a711240
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""Train or fine-tune a model for monocular depth estimation.

Takes one positional command-line argument (read from sys.argv):
  <data_path> - Path to the dataset, which is split into two folders: train and test.
"""
import sys
import yaml
from fastai.vision.all import unet_learner, Path, resnet34, rmse, MSELossFlat
from custom_data_loading import create_data
from dagshub.fastai import DAGsHubLogger


if __name__ == "__main__":
    # Script entry point: load hyper-parameters from params.yml, build a
    # fastai U-Net learner on the dataset at <data_path>, fine-tune it, and
    # save the weights with learner.save('model').

    # Expect exactly one positional argument: the dataset root directory.
    if len(sys.argv) != 2:
        print("usage: %s <data_path>" % sys.argv[0], file=sys.stderr)
        # A usage error must exit non-zero so shells and pipeline runners
        # treat the invocation as failed (exit code 0 means success).
        sys.exit(1)

    # NOTE(review): this path is relative to the current working directory,
    # so the script presumably has to be launched from the repository root.
    with open(r"./src/code/params.yml", encoding="utf-8") as f:
        params = yaml.safe_load(f)

    data = create_data(Path(sys.argv[1]))

    # Supported values for the string-valued options in params.yml.
    metrics = {'rmse': rmse}
    arch = {'resnet34': resnet34}
    loss = {'MSELossFlat': MSELossFlat()}

    def _choose(options, key):
        """Return ``options[params[key]]``.

        Exits with a readable error if the configured value is not one of
        ``options``. Without this check, a typo in params.yml would be passed
        to ``unet_learner`` as ``None``, which either fails obscurely
        (architecture) or silently changes the training setup (loss/metric).
        """
        value = params[key]
        if value not in options:
            sys.exit("unsupported %s %r in params.yml; expected one of: %s"
                     % (key, value, ", ".join(options)))
        return options[value]

    learner = unet_learner(data,
                           _choose(arch, 'architecture'),
                           metrics=_choose(metrics, 'train_metric'),
                           # YAML may yield strings for numeric values
                           # (e.g. "1e-2"), so coerce explicitly.
                           wd=float(params['weight_decay']),
                           n_out=int(params['num_outs']),
                           loss_func=_choose(loss, 'loss_func'),
                           path=params['source_dir'],
                           model_dir=params['model_dir'],
                           cbs=DAGsHubLogger(
                               metrics_path="logs/train_metrics.csv",
                               hparams_path="logs/train_params.yml"))

    print("Training model...")
    learner.fine_tune(epochs=int(params['epochs']),
                      base_lr=float(params['learning_rate']))
    print("Saving model...")
    learner.save('model')
    print("Done!")