File size: 4,349 Bytes
7575eee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Helpers for managing the run/training loop."""

import datetime
import json
import os
import pprint
import time
import types

from typing import Any

from . import submit


class RunContext(object):
    """Helper class for managing the run/training loop.

    The context hides the implementation details of a basic run/training loop:
    it writes the run's bookkeeping files on creation, tells whether the run
    should be stopped, and records the stop time when closed.
    The user should call update() periodically and poll should_stop() to decide
    whether to end the run. The class is also usable as a context manager, in
    which case close() is called automatically on exit.

    Args:
        submit_config: The SubmitConfig that is used for the current run. Its
            run_dir, task_name and host_name attributes are read here.
        config_module: The whole config module that is used for the current run.
            When given, its public, non-callable contents are dumped to config.txt.
        max_epoch: Optional cached value for the max_epoch variable used in update.
    """

    def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
        self.submit_config = submit_config
        self.should_stop_flag = False
        self.has_closed = False

        # Read the clock once so that both timestamps start out exactly equal.
        now = time.time()
        self.start_time = now
        self.last_update_time = now
        self.last_update_interval = 0.0
        self.max_epoch = max_epoch

        # pretty print all the relevant content of the config module to a text file
        if config_module is not None:
            # Private names, modules, functions/lambdas, classes and the SubmitConfig
            # itself are left out of the dump.
            excluded_types = (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type)
            filtered_dict = {
                k: v
                for k, v in config_module.__dict__.items()
                if not k.startswith("_") and not isinstance(v, excluded_types)
            }
            with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
                pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)

        # write out details about the run to a text file
        self.run_txt_data = {
            "task_name": submit_config.task_name,
            "host_name": submit_config.host_name,
            "start_time": datetime.datetime.now().isoformat(sep=" "),
        }
        self._write_run_txt()

    def _write_run_txt(self) -> None:
        """Overwrite run.txt in the run directory with the current run_txt_data."""
        with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
            pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

    def __enter__(self) -> "RunContext":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Returns None, so an exception raised inside the with-block propagates.
        self.close()

    def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
        """Do general housekeeping and keep the state of the context up-to-date.
        Should be called often enough but not in a tight loop.

        Refreshes the update timestamps and raises the stop flag when a file named
        abort.txt exists in the run directory.

        Args:
            loss: Accepted for interface compatibility; not used by this method.
            cur_epoch: Accepted for interface compatibility; not used by this method.
            max_epoch: Accepted for interface compatibility; not used by this method.
        """
        assert not self.has_closed, "update() called on a closed RunContext"

        # Use a single clock reading so the interval and the stored timestamp agree.
        now = time.time()
        self.last_update_interval = now - self.last_update_time
        self.last_update_time = now

        # The stop flag is only ever set here, never cleared.
        if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
            self.should_stop_flag = True

    def should_stop(self) -> bool:
        """Tell whether a stopping condition has been triggered one way or another."""
        return self.should_stop_flag

    def get_time_since_start(self) -> float:
        """How much time (in seconds) has passed since the creation of the context."""
        return time.time() - self.start_time

    def get_time_since_last_update(self) -> float:
        """How much time (in seconds) has passed since the last call to update."""
        return time.time() - self.last_update_time

    def get_last_update_interval(self) -> float:
        """How much time (in seconds) passed between the previous two calls to update."""
        return self.last_update_interval

    def close(self) -> None:
        """Close the context and clean up.

        Rewrites run.txt with the stopping time added. Calling close() again on an
        already closed context does nothing."""
        if self.has_closed:
            return

        # update the run.txt with stopping time
        self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
        self._write_run_txt()

        self.has_closed = True