import random
import warnings
from importlib.util import find_spec
from typing import Callable, Optional

import numpy as np
import torch
from omegaconf import DictConfig

from .logger import RankedLogger
from .rich_utils import enforce_tags, print_config_tree

log = RankedLogger(__name__, rank_zero_only=True)


def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
        - Ignoring python warnings
        - Setting tags from command line
        - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using the Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        print_config_tree(cfg, resolve=True, save_to_file=True)

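# Illustrative usage sketch (not part of this module): `extras` is intended to be called
# at the top of a Hydra entry point, before the task itself runs. The config path and
# file name below are assumptions for the example, not something this module defines.
#
#     @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
#     def main(cfg: DictConfig) -> None:
#         extras(cfg)
#         ...

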
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[dict, dict]:
        ...
        return metric_dict, object_dict
    ```
    """

    def wrap(cfg: DictConfig):
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        except Exception as ex:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors,
            # so when using hparam search plugins like Optuna you might want to disable
            # re-raising the exception below to avoid multirun failure
            raise ex

        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.run_dir}")

            # always close wandb run (even if an exception occurs, so multirun won't fail)
            if find_spec("wandb"):
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap


def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule."""
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value

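# Illustrative usage sketch (the `train` function and `optimized_metric` key are
# assumptions for the example): after a task returns its dict of logged torch metrics,
# the value of the metric being optimized (e.g. by an Optuna sweep) can be extracted
# for the sweeper to consume.
#
#     metric_dict, _ = train(cfg)  # e.g. {"val/acc": tensor(0.9231), ...}
#     return get_metric_value(metric_dict, cfg.get("optimized_metric"))

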
def set_seed(seed: int):
    """Seeds the python, numpy and torch random number generators for reproducibility."""
    # numpy only accepts seeds in [0, 2**32 - 1], so map negative values into the
    # positive range and clamp overly large ones
    if seed < 0:
        seed = -seed
    if seed > (1 << 31):
        seed = 1 << 31

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # make cuDNN deterministic (may reduce performance, but makes results reproducible)
    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
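

# Illustrative usage sketch (the `cfg.seed` key and `train` function are assumptions for
# the example): seeding is typically done once at the start of the task, before any model
# or dataloader is instantiated, so that runs with the same seed are comparable.
#
#     if cfg.get("seed") is not None:
#         set_seed(cfg.seed)
#     metric_dict, object_dict = train(cfg)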