# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py

import atexit
import functools
import logging
import sys
import uuid
from typing import Any, Dict, Optional, Union

from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from numpy import ndarray
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter

from training.utils.train_utils import get_machine_local_and_dist_rank, makedir

Scalar = Union[Tensor, ndarray, int, float]


def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
    makedir(log_dir)
    summary_writer_method = SummaryWriter
    return TensorBoardLogger(
        path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs
    )
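

# Example usage (a minimal sketch; the log directory and values below are
# hypothetical). Extra kwargs such as `flush_secs` are forwarded verbatim to
# SummaryWriter:
#
#   tb_logger = make_tensorboard_logger("/tmp/tb_logs", flush_secs=30)
#   tb_logger.log("train/loss", 0.25, step=100)
#   tb_logger.flush()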


class TensorBoardWriterWrapper:
    """
    A wrapper around a SummaryWriter object.
    """

    def __init__(
        self,
        path: str,
        *args: Any,
        filename_suffix: Optional[str] = None,
        summary_writer_method: Any = SummaryWriter,
        **kwargs: Any,
    ) -> None:
        """Create a new TensorBoard logger.

        On construction, the logger creates a new events file that logs
        will be written to. If the environment variable `RANK` is defined,
        the logger will only write logs on rank 0.

        NOTE: If using the logger with distributed training:
        - This logger can call collective operations.
        - Logs will be written on rank 0 only.
        - The logger must be constructed synchronously *after* initializing
          the distributed process group.

        Args:
            path (str): path to write logs to
            *args, **kwargs: Extra arguments to pass to SummaryWriter
        """
        self._writer: Optional[SummaryWriter] = None
        _, self._rank = get_machine_local_and_dist_rank()
        self._path: str = path
        if self._rank == 0:
            logging.info(
                f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
            )
            self._writer = summary_writer_method(
                log_dir=path,
                *args,
                filename_suffix=filename_suffix or str(uuid.uuid4()),
                **kwargs,
            )
        else:
            logging.debug(
                f"Not logging meters on this host because env RANK: {self._rank} != 0"
            )
        atexit.register(self.close)
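
    # Construction-order sketch for distributed training (illustrative;
    # assumes torch.distributed is used to set up the process group):
    #
    #   torch.distributed.init_process_group(backend="nccl")
    #   logger = TensorBoardLogger(path="/tmp/tb_logs")  # only after init
    #
    # Constructing the logger first can leave get_machine_local_and_dist_rank()
    # without a valid rank, so non-zero ranks might also open writers.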

    @property
    def writer(self) -> Optional[SummaryWriter]:
        return self._writer

    @property
    def path(self) -> str:
        return self._path

    def flush(self) -> None:
        """Writes pending logs to disk."""
        if not self._writer:
            return
        self._writer.flush()

    def close(self) -> None:
        """Close writer, flushing pending logs to disk.

        Logs cannot be written after `close` is called.
        """
        if not self._writer:
            return
        self._writer.close()
        self._writer = None


class TensorBoardLogger(TensorBoardWriterWrapper):
    """
    A simple logger for TensorBoard.
    """

    def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
        """Add multiple scalar values to TensorBoard.

        Args:
            payload (dict): dictionary of tag name and scalar value
            step (int, Optional): step value to record
        """
        if not self._writer:
            return
        for k, v in payload.items():
            self.log(k, v, step)

    def log(self, name: str, data: Scalar, step: int) -> None:
        """Add scalar data to TensorBoard.

        Args:
            name (string): tag name used to group scalars
            data (float/int/Tensor): scalar data to log
            step (int, optional): step value to record
        """
        if not self._writer:
            return
        self._writer.add_scalar(name, data, global_step=step, new_style=True)

    def log_hparams(
        self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
    ) -> None:
        """Add hyperparameter data to TensorBoard.

        Args:
            hparams (dict): dictionary of hyperparameter names and corresponding values
            meters (dict): dictionary of meter names and corresponding values
        """
        if not self._writer:
            return
        self._writer.add_hparams(hparams, meters)
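

# Example usage (a minimal sketch; paths and values are hypothetical):
#
#   logger = TensorBoardLogger(path="/tmp/tb_logs")
#   logger.log_dict({"train/loss": 0.25, "train/acc": 0.91}, step=10)
#   logger.log_hparams(
#       hparams={"lr": 1e-3, "batch_size": 32},
#       meters={"final/val_acc": 0.93},
#   )
#   logger.close()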


class Logger:
    """
    A logger class that can interface with multiple loggers. It currently
    supports only TensorBoard for simplicity, but you can extend it with your
    own loggers.
    """

    def __init__(self, logging_conf):
        # allow turning off TensorBoard with "should_log: false" in the config
        tb_config = logging_conf.tensorboard_writer
        tb_should_log = tb_config and tb_config.pop("should_log", True)
        self.tb_logger = instantiate(tb_config) if tb_should_log else None

    def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
        if self.tb_logger:
            self.tb_logger.log_dict(payload, step)

    def log(self, name: str, data: Scalar, step: int) -> None:
        if self.tb_logger:
            self.tb_logger.log(name, data, step)

    def log_hparams(
        self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
    ) -> None:
        if self.tb_logger:
            self.tb_logger.log_hparams(hparams, meters)
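

# Example Hydra config consumed by Logger (a sketch; the module path in
# `_target_` is an assumption and depends on where this file lives in your
# project):
#
#   logging_conf:
#     tensorboard_writer:
#       _target_: training.utils.logger.make_tensorboard_logger
#       log_dir: /tmp/tb_logs
#       should_log: true
#
# `instantiate(tb_config)` then calls make_tensorboard_logger(log_dir=...)
# after `should_log` has been popped from the config.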


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache()
def _cached_log_stream(filename):
    # we tune the buffering value so that the logs are updated
    # frequently (the `buffering` argument is in bytes).
    log_buffer_kb = 10 * 1024  # 10KB
    io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb)
    atexit.register(io.close)
    return io


def setup_logging(
    name,
    output_dir=None,
    rank=0,
    log_level_primary="INFO",
    log_level_secondary="ERROR",
):
    """
    Set up the logging streams: a stdout handler and, on the master GPU
    (rank 0) only, a file handler.
    """
    # get the filename if we want to log to the file as well
    log_filename = None
    if output_dir:
        makedir(output_dir)
        if rank == 0:
            log_filename = f"{output_dir}/log.txt"

    logger = logging.getLogger(name)
    logger.setLevel(log_level_primary)

    # create formatter
    FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
    formatter = logging.Formatter(FORMAT)

    # clean up any existing handlers
    for h in logger.handlers:
        logger.removeHandler(h)
    logger.root.handlers = []

    # set up the console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    if rank == 0:
        console_handler.setLevel(log_level_primary)
    else:
        console_handler.setLevel(log_level_secondary)

    # we also log to a file if the user wants it
    if log_filename and rank == 0:
        file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
        file_handler.setLevel(log_level_primary)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    logging.root = logger


def shutdown_logging():
    """
    After training is done, ensure all the logger streams are shut down.
    """
    logging.info("Shutting down loggers...")
    handlers = logging.root.handlers
    for handler in handlers:
        handler.close()
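

# End-to-end sketch (hypothetical output directory; a minimal illustration,
# not part of the library itself):
if __name__ == "__main__":
    setup_logging(__name__, output_dir="/tmp/train_logs", rank=0)
    logging.info("Logging configured; messages go to stdout and log.txt.")
    shutdown_logging()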