lolcats/src/utils/logging.py
ariG23498's picture
ariG23498 HF staff
chore: adding lolcats configs scrc and src
ae81e0f
raw
history blame
3.99 kB
"""
Logging utilities to make terminal slightly more delightful
"""
import rich.syntax
import rich.tree
from omegaconf import OmegaConf, DictConfig, ListConfig
def _format_arg(arg_name: str, cutoff=2) -> str:
if arg_name is None:
return arg_name
arg_name = str(arg_name)
# Hardcode to handle backslash
name_splits = arg_name.split('/')
if len(name_splits) > 1:
return name_splits[-1]
# Abbreviate based on underscore
name_splits = arg_name.split('_')
if len(name_splits) > 1:
return ''.join([s[0] for s in name_splits])
else:
return arg_name[:cutoff]
def print_header(x: str) -> None:
    """
    Print `x` framed by a dashed rule above and below it.

    Each rule is made of '-' characters and is exactly as long as `x`.
    """
    rule = '-' * len(x)
    print(rule, x, rule, sep='\n')
def print_args(args, return_dict=False, verbose=True):
    """
    Print the arguments passed to the script as a small tree.

    Every attribute of `args` whose name does not start with an underscore
    (as listed by ``dir(args)``) is treated as an argument.

    Args:
        args: Object holding the arguments as attributes.
        return_dict: When True, return the collected ``{name: value}`` dict.
        verbose: When True, print an 'ARGPARSE ARGS' heading followed by one
            line per argument.

    Returns:
        A dict mapping attribute names to values if `return_dict` is True,
        otherwise None.
    """
    attributes = [a for a in dir(args) if not a.startswith('_')]
    arg_dict = {}
    if verbose:
        print('ARGPARSE ARGS')
    last_ix = len(attributes) - 1
    for ix, attr in enumerate(attributes):
        value = getattr(args, attr)
        if verbose:
            # The final entry closes the tree; every other entry uses a tee.
            fancy = '└─' if ix == last_ix else '├─'
            print(f'{fancy} {attr}: {value}')
        arg_dict[attr] = value
    if return_dict:
        return arg_dict
def update_description_metrics(description: str, metrics: dict):
    """
    Append per-split metric values to a progress-bar description.

    `metrics` maps a split name to a dict of ``{metric_name: value}``. Every
    split except 'test' contributes one ' | split/metric_name: value' segment
    per metric, with the value formatted to three decimal places.

    Returns the extended description string.
    """
    pieces = [description]
    for split, split_metrics in metrics.items():
        if split == 'test':
            # Test-split numbers are deliberately left off the display.
            continue
        pieces.extend(
            f' | {split}/{metric_name}: {metric:.3f}'
            for metric_name, metric in split_metrics.items()
        )
    return ''.join(pieces)
# Control how tqdm progress bar looks
def type_of_script():
    """
    Guess which kind of front end is running this code.

    Returns:
        'jupyter'  if the active shell's type name contains 'zmqshell',
        'ipython'  if it contains 'terminal',
        'terminal' if `get_ipython` is not available, or if the shell type
                   matches neither of the above.
    """
    try:
        # `get_ipython` is not defined in this module; it only resolves when
        # the running interpreter provides it (as IPython does).
        shell = get_ipython()  # noqa: F821
    except NameError:
        return 'terminal'

    shell_type = str(type(shell))
    if 'zmqshell' in shell_type:
        return 'jupyter'
    if 'terminal' in shell_type:
        return 'ipython'
    # Unrecognised shell: fall back to the plain default instead of
    # implicitly returning None.
    return 'terminal'
# Progress bar
def update_pbar_display(metrics, batch_ix, pbar, prefix, batch_size, accum_iter=1):
    """
    Rebuild a progress bar's description from the running metrics.

    Args:
        metrics: Dict of metric name -> value. A 'total' entry is used as the
            denominator for 'correct' and for every metric that has no
            dedicated formatting; it is required whenever such a metric is
            present.
        batch_ix: Index of the current batch (shown as an int).
        pbar: Progress bar object; must support ``len()`` and
            ``set_description(str)``.
        prefix: Label shown before the batch counter.
        batch_size: Batch size shown in the description.
        accum_iter: Number of gradient-accumulation batches shown in the
            description.

    Formatting per metric:
        'correct'    -> count / total and the resulting percentage
        'acc'        -> value as-is, three decimals
        'perplexity' -> scientific notation via ``Decimal``
        'total'      -> not shown
        anything else -> value divided by total, three decimals
    """
    # Imported locally so the 'perplexity' branch has Decimal in scope; the
    # module-level imports do not provide it.
    from decimal import Decimal

    description = (
        f'└── {prefix} batch {int(batch_ix)}/{len(pbar)} '
        f'[batch size: {batch_size} - grad. accum. over {accum_iter} batch(es)]'
    )
    for metric_name, metric in metrics.items():
        if metric_name == 'correct':
            description += (
                f' | {metric_name} (acc. %): {int(metric):>5d}/{int(metrics["total"])}'
                f' = {metric / metrics["total"] * 100:.3f}%'
            )
        elif metric_name == 'acc':
            description += f' | {metric_name}: {metric:.3f}'
        elif metric_name == 'perplexity':
            description += f' | {metric_name}: {Decimal(metric):.3E}'
        elif metric_name != 'total':
            # Remaining metrics are displayed divided by 'total'.
            description += f' | {metric_name}: {metric / metrics["total"]:.3f}'
    pbar.set_description(description)
def print_config(config: DictConfig,
                 resolve: bool = True,
                 name: str = 'CONFIG') -> None:
    """Render a DictConfig as a Rich tree and print it.

    One branch is added per top-level key of `config`, in the order returned
    by `config.keys()`. Nested DictConfig / ListConfig sections are dumped
    with `OmegaConf.to_yaml`; any other value is shown via `str()`. Each
    branch body is displayed with YAML syntax highlighting.

    Args:
        config (DictConfig): Configuration to display.
        resolve (bool, optional): Passed to `OmegaConf.to_yaml` to control
            whether reference fields are resolved.
        name (str, optional): Label of the tree's root node.
    """
    tree_style = "bright"  # "dim"
    root = rich.tree.Tree(name, style=tree_style, guide_style=tree_style)

    for key in config.keys():
        node = root.add(key, style=tree_style, guide_style=tree_style)
        section = config.get(key)
        if isinstance(section, (DictConfig, ListConfig)):
            rendered = OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)
        node.add(rich.syntax.Syntax(rendered, "yaml"))

    rich.print(root)