Spaces:
Running
Running
File size: 8,744 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 |
import bisect
import dataclasses
import dis
import sys
from typing import Any, Set, Union
# Opcodes after which execution never falls through to the next instruction
# in the list (returns, raises and unconditional jumps).
TERMINAL_OPCODES = {
    dis.opmap["RETURN_VALUE"],
    dis.opmap["JUMP_FORWARD"],
    dis.opmap["RAISE_VARARGS"],
    # TODO(jansel): double check exception handling
}
if sys.version_info >= (3, 9):
    TERMINAL_OPCODES.add(dis.opmap["RERAISE"])
if sys.version_info >= (3, 11):
    # JUMP_FORWARD is already part of the literal above, so only the
    # backward jump needs to be registered here.
    TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD"])
else:
    TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
if "RETURN_CONST" in dis.opmap:
    # RETURN_CONST (introduced in CPython 3.12) leaves the frame exactly like
    # RETURN_VALUE does; it is looked up by name so interpreters that do not
    # define it are unaffected.
    TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])

# Every opcode that carries a jump target, relative or absolute.
JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs)
JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES}
# Opcodes whose argument names a local variable / a free or cell variable.
HASLOCAL = set(dis.haslocal)
HASFREE = set(dis.hasfree)

stack_effect = dis.stack_effect
def get_indexof(insts):
    """
    Build a lookup table from each instruction to its position in ``insts``.

    The instruction objects themselves are the dictionary keys, so they must
    be hashable. An ``AssertionError`` is raised if the same key shows up
    more than once.
    """
    positions = {}
    for position, instruction in enumerate(insts):
        # every instruction may occupy exactly one slot in the list
        assert instruction not in positions
        positions[instruction] = position
    return positions
def remove_dead_code(instructions):
    """Dead code elimination.

    Starting at the first instruction, marks as live everything reachable by
    falling through to the next instruction, by following a jump target, or
    by following an exception-table handler target. Scanning a straight-line
    run stops at an opcode in ``TERMINAL_OPCODES`` or at an instruction that
    is already marked.

    On Python 3.11+ the ``start``/``end`` of the exception-table entries
    attached to live instructions are moved onto live instructions, so that
    they stay valid once the dead instructions are gone (the entry objects
    are mutated in place).

    Returns a new list holding only the live instructions, in their original
    order.
    """
    indexof = get_indexof(instructions)
    live_code = set()

    # Reachability is computed with an explicit worklist rather than
    # recursion, so a long chain of jumps / handlers cannot exhaust the
    # interpreter's recursion limit. The resulting live set does not depend
    # on the order in which pending starts are processed.
    pending = [0]
    while pending:
        start = pending.pop()
        for i in range(start, len(instructions)):
            if i in live_code:
                # everything reachable from here was already explored
                break
            live_code.add(i)
            inst = instructions[i]
            if inst.exn_tab_entry:
                pending.append(indexof[inst.exn_tab_entry.target])
            if inst.opcode in JUMP_OPCODES:
                pending.append(indexof[inst.target])
            if inst.opcode in TERMINAL_OPCODES:
                # no fallthrough into the next instruction
                break

    # change exception table entries if start/end instructions are dead
    # assumes that exception table entries have been propagated,
    # e.g. with bytecode_transformation.propagate_inst_exn_table_entries,
    # and that instructions with an exn_tab_entry lies within its start/end.
    if sys.version_info >= (3, 11):
        live_idx = sorted(live_code)
        for i, inst in enumerate(instructions):
            if i in live_code and inst.exn_tab_entry:
                # find leftmost live instruction >= start
                start_idx = bisect.bisect_left(
                    live_idx, indexof[inst.exn_tab_entry.start]
                )
                assert start_idx < len(live_idx)
                # find rightmost live instruction <= end
                end_idx = (
                    bisect.bisect_right(live_idx, indexof[inst.exn_tab_entry.end]) - 1
                )
                assert end_idx >= 0
                assert live_idx[start_idx] <= i <= live_idx[end_idx]
                inst.exn_tab_entry.start = instructions[live_idx[start_idx]]
                inst.exn_tab_entry.end = instructions[live_idx[end_idx]]

    return [inst for i, inst in enumerate(instructions) if i in live_code]
def remove_pointless_jumps(instructions):
    """Drop every JUMP_ABSOLUTE whose target is the instruction right after it.

    Returns a new list; the input list is not modified.
    """
    redundant = set()
    for current, following in zip(instructions, instructions[1:]):
        if current.opname != "JUMP_ABSOLUTE":
            continue
        if current.target is following:
            # tracked by identity, matching the ``is`` comparison above
            redundant.add(id(current))
    return [inst for inst in instructions if id(inst) not in redundant]
def propagate_line_nums(instructions):
    """Give every instruction a line number so that removing some is safe.

    Walks the instructions in order and overwrites each ``starts_line`` with
    the most recent truthy ``starts_line`` seen so far (``None`` until the
    first one is met). Mutates the instructions in place and returns nothing.
    """
    latest_line = None
    for inst in instructions:
        # a falsy starts_line inherits the previous value
        latest_line = inst.starts_line or latest_line
        inst.starts_line = latest_line
def remove_extra_line_nums(instructions):
    """Strip repeated ``starts_line`` values before the bytecode is packed.

    Walking in order, an instruction keeps its ``starts_line`` only when it
    differs from the last line number that was kept; repeats are reset to
    ``None``. Instructions that already have ``None`` are skipped. Mutates
    the instructions in place and returns nothing.
    """
    kept_line = None
    for inst in instructions:
        line = inst.starts_line
        if line is None:
            continue
        if line == kept_line:
            inst.starts_line = None
        else:
            kept_line = line
@dataclasses.dataclass
class ReadsWrites:
    """Three sets bundled together: ``reads``, ``writes`` and ``visited``."""

    # values recorded as read
    reads: Set[Any]
    # values recorded as written
    writes: Set[Any]
    # entries already seen, used to avoid repeating work
    visited: Set[Any]
def livevars_analysis(instructions, instruction):
    """Collect the variable names read when execution continues at ``instruction``.

    Walks forward from ``instruction`` through ``instructions``. Two
    ``ReadsWrites`` records are kept: ``must`` for the straight-line walk that
    begins at ``instruction``, and ``may`` for everything reached through a
    jump target or an exception handler target (and for the remainder of a
    walk once it has passed a jump). A LOAD/DELETE is only recorded as a read
    when its name is not already in ``must.writes``.

    Returns the union of the reads gathered in ``must`` and ``may``.
    Raises ``NotImplementedError`` for a local/free-variable opcode whose name
    contains none of LOAD, DELETE, STORE and is not MAKE_CELL.
    """
    indexof = get_indexof(instructions)
    must = ReadsWrites(set(), set(), set())
    may = ReadsWrites(set(), set(), set())
    total = len(instructions)

    def walk(state, start):
        # each start index is explored at most once per record
        if start in state.visited:
            return
        state.visited.add(start)

        for i in range(start, total):
            inst = instructions[i]
            if inst.opcode in HASLOCAL or inst.opcode in HASFREE:
                name = inst.opname
                if "LOAD" in name or "DELETE" in name:
                    if inst.argval not in must.writes:
                        state.reads.add(inst.argval)
                elif "STORE" in name:
                    state.writes.add(inst.argval)
                elif name != "MAKE_CELL":
                    raise NotImplementedError(f"unhandled {inst.opname}")
            handler = inst.exn_tab_entry
            if handler:
                walk(may, indexof[handler.target])
            if inst.opcode in JUMP_OPCODES:
                walk(may, indexof[inst.target])
                # past a jump, the rest of this walk is recorded in ``may``
                state = may
            if inst.opcode in TERMINAL_OPCODES:
                return

    walk(must, indexof[instruction])
    return must.reads | may.reads
@dataclasses.dataclass
class FixedPointBox:
    """Mutable holder for a single boolean flag that starts out ``True``."""

    # current state of the flag
    value: bool = True
@dataclasses.dataclass
class StackSize:
    """A ``[low, high]`` range of stack depths plus a shared change flag.

    Whenever one of the methods below changes the range, ``fixed_point.value``
    is set to ``False``.
    """

    # smallest depth recorded so far
    low: Union[int, float]
    # largest depth recorded so far
    high: Union[int, float]
    # flag object that is cleared when the range changes
    fixed_point: FixedPointBox

    def zero(self):
        """Reset the range to exactly 0 and clear the fixed-point flag."""
        self.low = self.high = 0
        self.fixed_point.value = False

    def _widen(self, low, high):
        """Grow the range to include ``low`` and ``high``; flag any change."""
        before = (self.low, self.high)
        self.low = min(self.low, low)
        self.high = max(self.high, high)
        if before != (self.low, self.high):
            self.fixed_point.value = False

    def offset_of(self, other, n):
        """Include the range of ``other`` shifted by ``n``."""
        self._widen(other.low + n, other.high + n)

    def exn_tab_jump(self, depth):
        """Include the single depth ``depth``."""
        self._widen(depth, depth)
def stacksize_analysis(instructions) -> Union[int, float]:
    """Compute the largest stack depth recorded for any instruction.

    Every instruction gets a ``StackSize`` range; the first instruction starts
    at depth 0. Ranges are then propagated along fallthrough edges, jump
    edges and exception-table handler edges, repeating whole passes (at most
    100) until a pass changes nothing.

    Returns the maximum ``high`` over all instructions. Asserts that
    ``instructions`` is non-empty, that a fixed point was reached, and that no
    recorded depth is negative.
    """
    assert instructions
    converged = FixedPointBox()
    sizes = {
        inst: StackSize(float("inf"), float("-inf"), converged)
        for inst in instructions
    }
    sizes[instructions[0]].zero()

    # each instruction paired with the one after it; the last has none
    successors = instructions[1:] + [None]
    before_py39 = sys.version_info < (3, 9)

    for _ in range(100):
        if converged.value:
            break
        converged.value = True

        for inst, following in zip(instructions, successors):
            current = sizes[inst]
            # CALL_FINALLY in Python 3.8 is handled differently when determining stack depth.
            # See https://github.com/python/cpython/blob/3.8/Python/compile.c#L5450.
            # Essentially, the stack effect of CALL_FINALLY is computed with jump=True,
            # but the resulting stack depth is propagated to the next instruction, not the
            # jump target.
            is_call_finally = (
                before_py39 and inst.opcode == dis.opmap["CALL_FINALLY"]
            )

            if inst.opcode not in TERMINAL_OPCODES:
                assert following is not None, f"missing next inst: {inst}"
                sizes[following].offset_of(
                    current,
                    stack_effect(inst.opcode, inst.arg, jump=is_call_finally),
                )

            if inst.opcode in JUMP_OPCODES and not is_call_finally:
                sizes[inst.target].offset_of(
                    current, stack_effect(inst.opcode, inst.arg, jump=True)
                )

            entry = inst.exn_tab_entry
            if entry:
                # see https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt
                # on why depth is computed this way.
                depth = entry.depth + int(entry.lasti) + 1
                sizes[entry.target].exn_tab_jump(depth)

    low = min(size.low for size in sizes.values())
    high = max(size.high for size in sizes.values())

    assert converged.value, "failed to reach fixed point"
    assert low >= 0
    return high
|