Spaces:
Running
Running
File size: 1,625 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import itertools
import logging
from torch.hub import _Faketqdm, tqdm
# Disable progress bar by default, not in dynamo config because otherwise get a circular import
disable_progress = True
# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
return [
logging.getLogger("torch.fx.experimental.symbolic_shapes"),
logging.getLogger("torch._dynamo"),
logging.getLogger("torch._inductor"),
]
# Creates a logging function that logs a message with a step # prepended.
# get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
# so that step numbers are initialized properly. e.g.:
# @functools.lru_cache(None)
# def _step_logger():
# return get_step_logger(logging.getLogger(...))
# def fn():
# _step_logger()(logging.INFO, "msg")
_step_counter = itertools.count(1)
# Update num_steps if more phases are added: Dynamo, AOT, Backend
# This is very inductor centric
# _inductor.utils.has_triton() gives a circular import error here
if not disable_progress:
try:
import triton # noqa: F401
num_steps = 3
except ImportError:
num_steps = 2
pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)
def get_step_logger(logger):
if not disable_progress:
pbar.update(1)
if not isinstance(pbar, _Faketqdm):
pbar.set_postfix_str(f"{logger.name}")
step = next(_step_counter)
def log(level, msg, **kwargs):
logger.log(level, "Step %s: %s", step, msg, **kwargs)
return log
|