import copy import dataclasses import sys import types from typing import Any, cast, Dict, List, Optional, Tuple from .bytecode_transformation import ( create_call_function, create_call_method, create_dup_top, create_instruction, create_jump_absolute, Instruction, InstructionExnTabEntry, transform_code_object, unique_id, ) from .utils import ExactWeakKeyDictionary # taken from code.h in cpython CO_OPTIMIZED = 0x0001 CO_NEWLOCALS = 0x0002 CO_VARARGS = 0x0004 CO_VARKEYWORDS = 0x0008 CO_NESTED = 0x0010 CO_GENERATOR = 0x0020 CO_NOFREE = 0x0040 CO_COROUTINE = 0x0080 CO_ITERABLE_COROUTINE = 0x0100 CO_ASYNC_GENERATOR = 0x0200 @dataclasses.dataclass(frozen=True) class ReenterWith: stack_index: int target_values: Optional[Tuple[Any, ...]] = None # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol def try_except(self, code_options, cleanup: List[Instruction]): """ Codegen based off of: load args enter context try: (rest) finally: exit context """ load_args = [] if self.target_values: load_args = [ create_instruction("LOAD_CONST", argval=val) for val in self.target_values ] ctx_name = unique_id(f"___context_manager_{self.stack_index}") if ctx_name not in code_options["co_varnames"]: code_options["co_varnames"] += (ctx_name,) for name in ["__enter__", "__exit__"]: if name not in code_options["co_names"]: code_options["co_names"] += (name,) except_jump_target = create_instruction( "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" ) cleanup_complete_jump_target = create_instruction("NOP") setup_finally = [ *load_args, *create_call_function(len(load_args), True), create_instruction("STORE_FAST", argval=ctx_name), create_instruction("LOAD_FAST", argval=ctx_name), create_instruction("LOAD_METHOD", argval="__enter__"), *create_call_method(0), create_instruction("POP_TOP"), ] if sys.version_info < (3, 11): setup_finally.append( create_instruction("SETUP_FINALLY", target=except_jump_target) ) else: exn_tab_begin = create_instruction("NOP") exn_tab_end = create_instruction("NOP") exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( exn_tab_begin, exn_tab_end, except_jump_target, self.stack_index + 1, False, ) setup_finally.append(exn_tab_begin) def create_reset(): return [ create_instruction("LOAD_FAST", argval=ctx_name), create_instruction("LOAD_METHOD", argval="__exit__"), create_instruction("LOAD_CONST", argval=None), create_dup_top(), create_dup_top(), *create_call_method(3), create_instruction("POP_TOP"), ] if sys.version_info < (3, 9): epilogue = [ create_instruction("POP_BLOCK"), create_instruction("BEGIN_FINALLY"), except_jump_target, *create_reset(), create_instruction("END_FINALLY"), ] elif sys.version_info < (3, 11): epilogue = [ create_instruction("POP_BLOCK"), *create_reset(), create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), except_jump_target, *create_reset(), create_instruction("RERAISE"), cleanup_complete_jump_target, ] else: finally_exn_tab_end = create_instruction("RERAISE", arg=0) finally_exn_tab_target = create_instruction("COPY", arg=3) except_jump_target.exn_tab_entry = InstructionExnTabEntry( except_jump_target, finally_exn_tab_end, finally_exn_tab_target, self.stack_index + 2, True, ) epilogue = [ exn_tab_end, *create_reset(), create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), except_jump_target, # PUSH_EXC_INFO *create_reset(), finally_exn_tab_end, # RERAISE 0 finally_exn_tab_target, # COPY 3 create_instruction("POP_EXCEPT"), create_instruction("RERAISE", arg=1), cleanup_complete_jump_target, ] cleanup[:] = epilogue + cleanup return setup_finally def __call__(self, code_options, cleanup): """ Codegen based off of: with ctx(args): (rest) """ load_args = [] if self.target_values: load_args = [ create_instruction("LOAD_CONST", argval=val) for val in self.target_values ] if sys.version_info < (3, 9): with_cleanup_start = create_instruction("WITH_CLEANUP_START") begin_finally = create_instruction("BEGIN_FINALLY") cleanup[:] = [ create_instruction("POP_BLOCK"), begin_finally, with_cleanup_start, create_instruction("WITH_CLEANUP_FINISH"), create_instruction("END_FINALLY"), ] + cleanup return [ *load_args, create_instruction("CALL_FUNCTION", arg=len(load_args)), create_instruction("SETUP_WITH", target=with_cleanup_start), create_instruction("POP_TOP"), ], None elif sys.version_info < (3, 11): with_except_start = create_instruction("WITH_EXCEPT_START") pop_top_after_with_except_start = create_instruction("POP_TOP") cleanup_complete_jump_target = create_instruction("NOP") cleanup[:] = [ create_instruction("POP_BLOCK"), create_instruction("LOAD_CONST", argval=None), create_instruction("DUP_TOP"), create_instruction("DUP_TOP"), create_instruction("CALL_FUNCTION", arg=3), create_instruction("POP_TOP"), create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), with_except_start, create_instruction( "POP_JUMP_IF_TRUE", target=pop_top_after_with_except_start ), create_instruction("RERAISE"), pop_top_after_with_except_start, create_instruction("POP_TOP"), create_instruction("POP_TOP"), create_instruction("POP_EXCEPT"), create_instruction("POP_TOP"), cleanup_complete_jump_target, ] + cleanup return [ *load_args, create_instruction("CALL_FUNCTION", arg=len(load_args)), create_instruction("SETUP_WITH", target=with_except_start), create_instruction("POP_TOP"), ], None else: pop_top_after_with_except_start = create_instruction("POP_TOP") cleanup_complete_jump_target = create_instruction("NOP") def create_load_none(): return create_instruction("LOAD_CONST", argval=None) exn_tab_1_begin = create_instruction("POP_TOP") exn_tab_1_end = create_instruction("NOP") exn_tab_1_target = create_instruction("PUSH_EXC_INFO") exn_tab_2_end = create_instruction("RERAISE", arg=2) exn_tab_2_target = create_instruction("COPY", arg=3) exn_tab_1_begin.exn_tab_entry = InstructionExnTabEntry( exn_tab_1_begin, exn_tab_1_end, exn_tab_1_target, self.stack_index + 1, True, ) exn_tab_1_target.exn_tab_entry = InstructionExnTabEntry( exn_tab_1_target, exn_tab_2_end, exn_tab_2_target, self.stack_index + 3, True, ) pop_top_after_with_except_start.exn_tab_entry = InstructionExnTabEntry( pop_top_after_with_except_start, pop_top_after_with_except_start, exn_tab_2_target, self.stack_index + 3, True, ) cleanup[:] = [ exn_tab_1_end, create_load_none(), create_load_none(), create_load_none(), *create_call_function(2, False), create_instruction("POP_TOP"), create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), exn_tab_1_target, # PUSH_EXC_INFO create_instruction("WITH_EXCEPT_START"), create_instruction( "POP_JUMP_FORWARD_IF_TRUE", target=pop_top_after_with_except_start, ), exn_tab_2_end, # RERAISE 2 exn_tab_2_target, # COPY 3 create_instruction("POP_EXCEPT"), create_instruction("RERAISE", arg=1), pop_top_after_with_except_start, create_instruction("POP_EXCEPT"), create_instruction("POP_TOP"), create_instruction("POP_TOP"), cleanup_complete_jump_target, ] + cleanup return [ *load_args, *create_call_function(len(load_args), True), create_instruction("BEFORE_WITH"), exn_tab_1_begin, # POP_TOP ], exn_tab_1_target @dataclasses.dataclass class ResumeFunctionMetadata: code: types.CodeType instructions: List[Instruction] = dataclasses.field(default_factory=list) # Python 3.11+ fields # NOTE: Python 3.11 removed blocks, but for our purposes, a "block" consists # of instructions of all exception table entries that have the same target. # map from PUSH_EXC_INFO's in the prefix to original block target offset prefix_block_target_offset_remap: List[int] = dataclasses.field( default_factory=list ) # map from new block target offsets to original block target offsets block_target_offset_remap: Optional[Dict[int, int]] = None def _filter_iter(l1, l2, cond): """ Two-pointer conditional filter. e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o) returns the instructions with offsets in sorted_offsets """ it = iter(l2) res = [] try: cur = next(it) for val in l1: if cond(val, cur): res.append(val) cur = next(it) except StopIteration: pass return res class ContinueExecutionCache: cache = ExactWeakKeyDictionary() generated_code_metadata = ExactWeakKeyDictionary() @classmethod def lookup(cls, code, lineno, *key): if code not in cls.cache: cls.cache[code] = dict() key = tuple(key) if key not in cls.cache[code]: cls.cache[code][key] = cls.generate(code, lineno, *key) return cls.cache[code][key] @classmethod def generate( cls, code, lineno, offset: int, setup_fn_target_offsets: Tuple[int], # only used in Python 3.11+ nstack: int, argnames: Tuple[str], setup_fns: Tuple[ReenterWith], null_idxes: Tuple[int], ) -> types.CodeType: assert offset is not None assert not ( code.co_flags & (CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR) ) assert code.co_flags & CO_OPTIMIZED if code in ContinueExecutionCache.generated_code_metadata: return cls.generate_based_on_original_code_object( code, lineno, offset, setup_fn_target_offsets, nstack, argnames, setup_fns, null_idxes, ) is_py311_plus = sys.version_info >= (3, 11) meta = ResumeFunctionMetadata(code) def update(instructions: List[Instruction], code_options: Dict[str, Any]): meta.instructions = copy.deepcopy(instructions) args = [f"___stack{i}" for i in range(nstack)] args.extend(v for v in argnames if v not in args) freevars = tuple(code_options["co_cellvars"] or []) + tuple( code_options["co_freevars"] or [] ) code_options[ "co_name" ] = f"torch_dynamo_resume_in_{code_options['co_name']}_at_{lineno}" if is_py311_plus: qualified_path = code_options["co_qualname"].rsplit(".", maxsplit=1) if len(qualified_path) == 1: code_options["co_qualname"] = code_options["co_name"] else: assert len(qualified_path) == 2 module_name, co_name = qualified_path code_options[ "co_qualname" ] = f"{module_name}.torch_dynamo_resume_in_{co_name}_at_{lineno}" code_options["co_firstlineno"] = lineno code_options["co_cellvars"] = tuple() code_options["co_freevars"] = freevars code_options["co_argcount"] = len(args) code_options["co_posonlyargcount"] = 0 code_options["co_kwonlyargcount"] = 0 code_options["co_varnames"] = tuple( args + [v for v in code_options["co_varnames"] if v not in args] ) code_options["co_flags"] = code_options["co_flags"] & ~( CO_VARARGS | CO_VARKEYWORDS ) target = next(i for i in instructions if i.offset == offset) prefix = [] if is_py311_plus: if freevars: prefix.append( create_instruction("COPY_FREE_VARS", arg=len(freevars)) ) prefix.append(create_instruction("RESUME", arg=0)) cleanup: List[Instruction] = [] hooks = {fn.stack_index: fn for fn in setup_fns} hook_target_offsets = { fn.stack_index: setup_fn_target_offsets[i] for i, fn in enumerate(setup_fns) } offset_to_inst = {inst.offset: inst for inst in instructions} # map old hook targets to new targets generated by the hook old_hook_target_remap = {} null_idxes_i = 0 for i in range(nstack): while ( null_idxes_i < len(null_idxes) and null_idxes[null_idxes_i] == i + null_idxes_i ): prefix.append(create_instruction("PUSH_NULL")) null_idxes_i += 1 prefix.append(create_instruction("LOAD_FAST", argval=f"___stack{i}")) if i in hooks: hook = hooks.pop(i) hook_insts, exn_target = hook(code_options, cleanup) prefix.extend(hook_insts) if is_py311_plus: hook_target_offset = hook_target_offsets.pop(i) old_hook_target = offset_to_inst[hook_target_offset] meta.prefix_block_target_offset_remap.append(hook_target_offset) old_hook_target_remap[old_hook_target] = exn_target if is_py311_plus: # reverse the mapping since targets of later/nested contexts are inserted # into the mapping later, but show up earlier in the prefix. meta.prefix_block_target_offset_remap = list( reversed(meta.prefix_block_target_offset_remap) ) assert not hooks prefix.append(create_jump_absolute(target)) # because the line number table monotonically increases from co_firstlineno # remove starts_line for any instructions before the graph break instruction # this will ensure the instructions after the break have the correct line numbers for inst in instructions: if inst.offset == target.offset: break inst.starts_line = None if sys.version_info >= (3, 11): inst.positions = None if cleanup: prefix.extend(cleanup) prefix.extend(cls.unreachable_codes(code_options)) # remap original instructions' exception table entries if old_hook_target_remap: assert is_py311_plus for inst in instructions: if ( inst.exn_tab_entry and inst.exn_tab_entry.target in old_hook_target_remap ): inst.exn_tab_entry.target = old_hook_target_remap[ inst.exn_tab_entry.target ] # TODO(jansel): add dead code elimination here instructions[:] = prefix + instructions new_code = transform_code_object(code, update) ContinueExecutionCache.generated_code_metadata[new_code] = meta return new_code @staticmethod def unreachable_codes(code_options) -> List[Instruction]: """Codegen a `raise None` to make analysis work for unreachable code""" return [ create_instruction("LOAD_CONST", argval=None), create_instruction("RAISE_VARARGS", arg=1), ] @classmethod def generate_based_on_original_code_object( cls, code, lineno, offset: int, setup_fn_target_offsets: Tuple[int, ...], *args ): """ This handles the case of generating a resume into code generated to resume something else. We want to always generate starting from the original code object so that if control flow paths converge we only generated 1 resume function (rather than 2^n resume functions). """ meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[ code ] new_offset = None def find_new_offset( instructions: List[Instruction], code_options: Dict[str, Any] ): nonlocal new_offset (target,) = (i for i in instructions if i.offset == offset) # match the functions starting at the last instruction as we have added a prefix (new_target,) = ( i2 for i1, i2 in zip(reversed(instructions), reversed(meta.instructions)) if i1 is target ) assert target.opcode == new_target.opcode new_offset = new_target.offset transform_code_object(code, find_new_offset) if sys.version_info >= (3, 11): # setup_fn_target_offsets currently contains the target offset of # each setup_fn, based on `code`. When we codegen the resume function # based on the original code object, `meta.code`, the offsets in # setup_fn_target_offsets must be based on `meta.code` instead. if not meta.block_target_offset_remap: block_target_offset_remap = meta.block_target_offset_remap = {} def remap_block_offsets( instructions: List[Instruction], code_options: Dict[str, Any] ): # NOTE: each prefix block generates exactly one PUSH_EXC_INFO, # so we can tell which block a prefix PUSH_EXC_INFO belongs to, # by counting. Then we can use meta.prefix_block-target_offset_remap # to determine where in the original code the PUSH_EXC_INFO offset # replaced. prefix_blocks: List[Instruction] = [] for inst in instructions: if len(prefix_blocks) == len( meta.prefix_block_target_offset_remap ): break if inst.opname == "PUSH_EXC_INFO": prefix_blocks.append(inst) # offsets into prefix for inst, o in zip( prefix_blocks, meta.prefix_block_target_offset_remap ): block_target_offset_remap[cast(int, inst.offset)] = o # old bytecode targets are after the prefix PUSH_EXC_INFO's old_start_offset = ( cast(int, prefix_blocks[-1].offset) if prefix_blocks else -1 ) # offsets into old bytecode old_inst_offsets = sorted( n for n in setup_fn_target_offsets if n > old_start_offset ) targets = _filter_iter( instructions, old_inst_offsets, lambda inst, o: inst.offset == o ) new_targets = _filter_iter( zip(reversed(instructions), reversed(meta.instructions)), targets, lambda v1, v2: v1[0] is v2, ) for new, old in zip(new_targets, targets): block_target_offset_remap[old.offset] = new[1].offset transform_code_object(code, remap_block_offsets) # if offset is not in setup_fn_target_offsets, it is an error setup_fn_target_offsets = tuple( meta.block_target_offset_remap[n] for n in setup_fn_target_offsets ) return ContinueExecutionCache.lookup( meta.code, lineno, new_offset, setup_fn_target_offsets, *args ) """ # partially finished support for with statements def convert_locals_to_cells( instructions: List[Instruction], code_options: Dict[str, Any]): code_options["co_cellvars"] = tuple( var for var in code_options["co_varnames"] if var not in code_options["co_freevars"] and not var.startswith("___stack") ) cell_and_free = code_options["co_cellvars"] + code_options["co_freevars"] for inst in instructions: if str(inst.argval).startswith("___stack"): continue elif inst.opname == "LOAD_FAST": inst.opname = "LOAD_DEREF" elif inst.opname == "STORE_FAST": inst.opname = "STORE_DEREF" elif inst.opname == "DELETE_FAST": inst.opname = "DELETE_DEREF" else: continue inst.opcode = dis.opmap[inst.opname] assert inst.argval in cell_and_free, inst.argval inst.arg = cell_and_free.index(inst.argval) def patch_setup_with( instructions: List[Instruction], code_options: Dict[str, Any] ): nonlocal need_skip need_skip = True target_index = next( idx for idx, i in enumerate(instructions) if i.offset == offset ) assert instructions[target_index].opname == "SETUP_WITH" convert_locals_to_cells(instructions, code_options) stack_depth_before = nstack + stack_effect(instructions[target_index].opcode, instructions[target_index].arg) inside_with = [] inside_with_resume_at = None stack_depth = stack_depth_before idx = target_index + 1 for idx in range(idx, len(instructions)): inst = instructions[idx] if inst.opname == "BEGIN_FINALLY": inside_with_resume_at = inst break elif inst.target is not None: unimplemented("jump from with not supported") elif inst.opname in ("BEGIN_FINALLY", "WITH_CLEANUP_START", "WITH_CLEANUP_FINISH", "END_FINALLY", "POP_FINALLY", "POP_EXCEPT", "POP_BLOCK", "END_ASYNC_FOR"): unimplemented("block ops not supported") inside_with.append(inst) stack_depth += stack_effect(inst.opcode, inst.arg) assert inside_with_resume_at instructions = [ create_instruction("LOAD_FAST", f"___stack{i}") for i in range(nstack) ] + [ create_instruction("SETUP_WITH", target=instructions[target_index].target) ... call the function ... unpack_tuple ] + [ create_instruction("JUMP_ABSOLUTE", target=inside_with_resume_at) ] """