|
import os
|
|
import sys
|
|
import time
|
|
import __main__
|
|
import traceback
|
|
import mlflow
|
|
|
|
import colbert.utils.distributed as distributed
|
|
|
|
from contextlib import contextmanager
|
|
from colbert.utils.logging import Logger
|
|
from colbert.utils.utils import timestamp, create_directory, print_message
|
|
|
|
|
|
class _RunManager:
    """Hold the state of one experiment run.

    The manager keeps the run's output location, its name, its start time and
    the status that is reported to MLflow when the run's ``context()`` exits.
    Call ``init()`` before ``context()``: the logging helpers used by
    ``context()`` are only attached to the instance by ``init()``.
    """

    def __init__(self):
        """Record the launching script, a default run name and the start time."""
        # Populated by init().
        self.experiments_root = None
        self.experiment = None
        self.path = None

        self.script = self._get_script_name()
        self.name = self._generate_default_run_name()
        # Kept separately because init() replaces self.name.
        self.original_name = self.name
        # Status handed to mlflow.end_run() when context() exits.
        self.exit_status = 'FINISHED'

        self._logger = None
        self.start_time = time.time()

    def init(self, rank, root, experiment, name):
        """Bind the run to ``<root>/<experiment>/<script>/<name>`` and attach a logger.

        Args:
            rank: Rank of the calling process. Only ranks below 1 check for /
                create the run directory; the value is also passed to
                ``distributed.barrier`` and ``Logger``.
            root: Experiments root directory; stored as an absolute path.
            experiment: Experiment name; must not contain ``'/'``.
            name: Run name; must not contain ``'/'``.

        Raises:
            AssertionError: If ``experiment`` or ``name`` contains ``'/'``, or
                if the run directory already exists and the user does not
                answer ``yes`` to the overwrite prompt.
        """
        # Raised explicitly instead of using `assert` statements so the checks
        # still run under `python -O`; AssertionError is kept so callers see
        # the same exception type as before.
        if '/' in experiment:
            raise AssertionError(experiment)
        if '/' in name:
            raise AssertionError(name)

        self.experiments_root = os.path.abspath(root)
        self.experiment = experiment
        self.name = name
        self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name)

        if rank < 1:
            if os.path.exists(self.path):
                print('\n\n')
                print_message("It seems that ", self.path, " already exists.")
                print_message("Do you want to overwrite it? \t yes/no \n")

                response = input()
                # Existence is re-checked after the prompt: the refusal only
                # aborts if the directory is still there.
                if response.strip() != 'yes' and os.path.exists(self.path):
                    raise AssertionError(self.path)
            else:
                create_directory(self.path)

        # Every rank reaches the barrier, including those that skipped the
        # directory handling above.
        distributed.barrier(rank)

        self._logger = Logger(rank, self)

        # Re-export the logger's helpers as attributes of the run manager.
        self._log_args = self._logger._log_args
        self.warn = self._logger.warn
        self.info = self._logger.info
        self.info_all = self._logger.info_all
        self.log_metric = self._logger.log_metric
        self.log_new_artifact = self._logger.log_new_artifact

    def _generate_default_run_name(self):
        """Return the default run name, which is the current ``timestamp()``."""
        return timestamp()

    def _get_script_name(self):
        """Return the base name of the ``__main__`` module's file, or ``'none'``.

        ``'none'`` is returned when ``__main__`` has no ``__file__`` attribute.
        """
        if hasattr(__main__, '__file__'):
            return os.path.basename(__main__.__file__)
        return 'none'

    @contextmanager
    def context(self, consider_failed_if_interrupted=True):
        """Wrap the body of a run, logging failures and finalising the run.

        Requires ``init()`` to have been called, since the logger and
        ``log_new_artifact`` are used here.

        Behaviour by outcome of the wrapped body:

        * ``KeyboardInterrupt``: prints a notice, logs the exception and all
          artifacts through the logger, sets ``exit_status`` to ``'KILLED'``
          when ``consider_failed_if_interrupted`` is true, then exits the
          process with status ``130``.
        * Any other ``Exception``: logs the exception and all artifacts, sets
          ``exit_status`` to ``'FAILED'`` and re-raises.
        * Always: records ``elapsed.txt``, ``name.original.txt`` and
          ``name.txt`` under the logger's ``logs_path`` via
          ``log_new_artifact``, logs all artifacts, and ends the MLflow run
          with ``exit_status``.

        Args:
            consider_failed_if_interrupted: Whether an interrupt marks the run
                as ``'KILLED'`` instead of leaving the current status.
        """
        try:
            yield

        except KeyboardInterrupt as ex:
            print('\n\nInterrupted\n\n')
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            if consider_failed_if_interrupted:
                self.exit_status = 'KILLED'

            # 128 + 2 follows the shell convention of 128 + signal number
            # (SIGINT is signal 2).
            sys.exit(128 + 2)

        except Exception as ex:
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            self.exit_status = 'FAILED'

            # Bare raise re-raises the active exception unchanged.
            raise

        finally:
            # Nested try/finally: the MLflow run must be ended even if
            # recording the final artifacts fails.
            try:
                total_seconds = str(time.time() - self.start_time) + '\n'
                original_name = str(self.original_name)
                name = str(self.name)

                logs_path = self._logger.logs_path
                self.log_new_artifact(os.path.join(logs_path, 'elapsed.txt'), total_seconds)
                self.log_new_artifact(os.path.join(logs_path, 'name.original.txt'), original_name)
                self.log_new_artifact(os.path.join(logs_path, 'name.txt'), name)

                self._logger._log_all_artifacts()
            finally:
                mlflow.end_run(status=self.exit_status)
|
|
|
|
|
|
# Module-level _RunManager instance, constructed when this module is imported.
# Construction records the script name, a timestamp-based default run name and
# the start time; the run directory and logger are only set up by Run.init().
Run = _RunManager()
|
|
|