Spaces:
Sleeping
Sleeping
from types import TracebackType | |
from typing import List, Optional | |
import tempfile | |
import traceback | |
import contextlib | |
import inspect | |
import os.path | |
# This file contains utilities for ensuring dynamically compile()'d | |
# code fragments display their line numbers in backtraces. | |
# | |
# The constraints: | |
# | |
# - We don't have control over the user exception printer (in particular, | |
# we cannot assume the linecache trick will work, c.f. | |
# https://stackoverflow.com/q/50515651/23845 ) | |
# | |
# - We don't want to create temporary files every time we compile() | |
# some code; file creation should happen lazily only at exception | |
# time. Arguably, you *should* be willing to write out your | |
# generated Python code to file system, but in some situations | |
# (esp. library code) it would violate user expectation to write | |
# to the file system, so we try to avoid it. In particular, we'd | |
# like to keep the files around, so users can open up the files | |
# mentioned in the trace; if the file is invisible, we want to | |
# avoid clogging up the filesystem. | |
# | |
# If this is not a constraint for you, there is a substantially simpler | |
# way to implement the functionality in this PR: instead of using | |
# eval/exec directly, just always write a Python file to filesystem | |
# and compile that. | |
# | |
# - You have control over a context where the compiled code will get | |
# executed, so that we can interpose while the stack is unwinding | |
# (otherwise, we have no way to interpose on the exception printing | |
# process.) | |
# | |
# There are two things you have to do to make use of the utilities here: | |
# | |
# - When you compile your source code, you must save its string source | |
# in its f_globals under the magic name "__compile_source__" | |
# | |
# - Before running the compiled code, enter the | |
# report_compile_source_on_error() context manager. | |
@contextlib.contextmanager
def report_compile_source_on_error():
    """
    Context manager that makes exec/eval'd code show real source lines in tracebacks.

    If an ``Exception`` escapes the ``with`` body, its traceback is walked.
    Every frame whose code filename is ``"<string>"`` and whose globals
    contain ``"__compile_source__"`` is swapped for a stand-in frame whose
    filename points at a temporary ``.py`` file holding that source.  The
    exception is then re-raised with the rebuilt traceback.

    The temporary files are intentionally NOT deleted, so that the user can
    open the paths mentioned in the trace.  One file is written per distinct
    source string seen in the traceback.

    Raises:
        The original exception object, with its traceback replaced.
    """
    try:
        yield
    except Exception as exc:
        tb = exc.__traceback__

        # Walk the traceback, looking for frames that have
        # source attached
        stack = []
        # Maps each distinct __compile_source__ string to the temporary file
        # it was written to, so frames sharing a source share one file.
        source_files = {}

        while tb is not None:
            frame = tb.tb_frame
            filename = frame.f_code.co_filename
            source = frame.f_globals.get("__compile_source__")

            if filename == "<string>" and source is not None:
                # What black magic are we doing here?  Intuitively, what
                # we would like to do is overwrite the co_filename on any
                # frames that were generated from exec/eval so that they
                # point to a temporary file that has the actual line
                # information, so Python's default error printer can print
                # useful line information on it.
                #
                # Writing out the temporary file is easy.  But overwriting
                # co_filename is not!  You can't modify the code object
                # associated with a frame.  You can, however, reconstruct
                # a traceback with entirely new frames from scratch, so that's
                # what we do.  But there's another problem, which is how to
                # make the frame?
                #
                # The black magic is we make a frankenstein frame and code
                # object which resembles the original frame/code enough so
                # that it will print properly under traceback and the default
                # error printer, but IT IS NOT THE ORIGINAL FRAME (you
                # couldn't, e.g., execute its code with different variables
                # and expect it to work.)

                # Don't delete the temporary file so the user can inspect it
                path = source_files.get(source)
                if path is None:
                    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
                        f.write(source)
                    path = source_files[source] = f.name

                # Create a frame.  Python doesn't let you construct
                # FrameType directly, so just make one with compile
                code = compile('__inspect_currentframe()', path, 'eval')
                code = code.replace(co_name=frame.f_code.co_name)
                # Python 3.11 only
                if hasattr(frame.f_code, 'co_linetable'):
                    # We can't copy ALL of the metadata over, because you
                    # can cause Python to segfault this way.  What exactly
                    # do we need?  We need enough information for
                    # traceback to be able to print the exception
                    # correctly.  Code reading Lib/traceback.py reveals
                    # that traceback calls code.co_positions() in order to
                    # get the augmented line/col numbers.  Objects/codeobject.c,
                    # specifically _PyCode_InitAddressRange, reveals that
                    # this iterator is initialized from co_linetable and
                    # co_firstlineno.  So copy these we must!
                    code = code.replace(  # type: ignore[call-arg]
                        co_linetable=frame.f_code.co_linetable,  # type: ignore[attr-defined]
                        co_firstlineno=frame.f_code.co_firstlineno,  # type: ignore[attr-defined]
                    )
                # Evaluating the stub returns the frame it runs in, which
                # carries the rewritten code object.
                fake_frame = eval(
                    code,
                    frame.f_globals,
                    {
                        **frame.f_locals,
                        '__inspect_currentframe': inspect.currentframe
                    }
                )
                fake_tb = TracebackType(
                    None, fake_frame, tb.tb_lasti, tb.tb_lineno
                )
                stack.append(fake_tb)
            else:
                stack.append(tb)

            tb = tb.tb_next

        # Reconstruct the linked list
        tb_next = None
        for tb in reversed(stack):
            tb.tb_next = tb_next
            tb_next = tb

        raise exc.with_traceback(tb_next)  # noqa: TRY200
def shorten_filename(fn, *, base=None):
    """
    Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.

    Args:
        fn: The file path to shorten.
        base: Directory to compare against.  Defaults to the parent of the
            directory containing this file.

    Returns:
        ``fn`` with the path prefix it shares with ``base`` (and the
        separator following it) removed.  ``fn`` is returned unchanged when
        the two paths cannot be compared or share no prefix at all.
    """
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))
    # Truncate torch/foo.py to foo.py
    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        # e.g. mixing absolute and relative paths
        return fn
    if not prefix:
        # Nothing in common: slicing would chop the first character of fn.
        return fn
    separators = (os.sep, os.altsep) if os.altsep else (os.sep,)
    if prefix.endswith(separators):
        # The prefix is a filesystem root and already includes its
        # separator, so there is no extra character to skip.
        return fn[len(prefix):]
    return fn[len(prefix) + 1:]
def format_frame(frame, *, base=None, line=False):
    """
    Render a FrameSummary as one compact line: shortened path, line number
    and function name, optionally preceded by the source text of the line.

    Args:
        frame: Object exposing ``filename``, ``lineno``, ``name`` and ``line``.
        base: Forwarded to ``shorten_filename``.
        line: When true, prefix the result with ``"<source line> # "``.
    """
    prefix = f"{frame.line} # " if line else ""
    short_path = shorten_filename(frame.filename, base=base)
    return f"{prefix}{short_path}:{frame.lineno} in {frame.name}"
def format_traceback_short(tb):
    """Return a one-line description of the inner-most frame of a TracebackType."""
    frames = traceback.extract_tb(tb)
    innermost = frames[-1]
    return format_frame(innermost)
class CapturedTraceback:
    """
    A cheaply captured stack trace whose symbolization is deferred.

    ``tb`` holds the raw traceback handle produced by
    ``torch._C._profiler.gather_traceback`` (or ``None`` once released), and
    ``skip`` is the number of leading entries of the symbolized traceback to
    drop when summarizing.
    """

    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        self.tb = tb
        self.skip = skip

    def cleanup(self):
        """Drop the reference to the underlying traceback."""
        self.tb = None

    def summary(self):
        """Symbolize the traceback and return it as a traceback.StackSummary."""
        if self.tb is None:
            # TODO: Maybe indicate that the traceback was elided?
            return traceback.StackSummary()

        # Imported lazily, and only when there is something to symbolize.
        import torch._C._profiler

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
        )

    def __getstate__(self):
        # Returned in the (dict_state, slots_state) form, since this class
        # only has __slots__.
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution.  It returns a CapturedTraceback object that must be
        formatted specially with CapturedTraceback.format (or format_all).

        By default, this only reports Python backtraces (like extract_stack).  You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide extract() frame if we don't have script/cpp frames.  If
            # we do have those frames, it doesn't work so force zero.
            0 if script or cpp else skip + 1
        )

    def format(self):
        """
        Formats a single CapturedTraceback into a list of strings equivalent
        to the output of traceback.format_list.  Note that if you have
        several CapturedTracebacks with C++ traces, it is better not to use
        this function and use the batch formatting API
        CapturedTraceback.format_all to amortize the cost of symbolization.
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format.  Returns a list of list of strings.

        Entries whose ``tb`` is None format to an empty list; all others are
        symbolized together in a single symbolize_tracebacks call.
        """
        rs: List[Optional[List[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        if delayed_idxs:
            import torch._C._profiler

            stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
            # Use the batch result directly instead of calling summary(),
            # which would symbolize every traceback a second time.
            for i, stb in zip(delayed_idxs, stbs):
                rs[i] = traceback.format_list(_extract_symbolized_tb(stb, tbs[i].skip))

        return rs
def _extract_symbolized_tb(tb, skip):
    """
    Convert one symbolized traceback (a sequence of mappings with
    'filename', 'line' and 'name' keys, as produced by symbolize_tracebacks)
    into a traceback.StackSummary.

    The first ``skip`` entries are dropped and the remaining entries are
    emitted in reverse order.
    """
    kept = tb[skip:]
    summary = traceback.StackSummary()
    summary.extend(
        traceback.FrameSummary(entry['filename'], entry['line'], entry['name'])
        for entry in kept[::-1]
    )
    return summary