Spaces:
Running
Running
File size: 15,091 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 |
# pyre-strict
from typing import List
import torch
from . import config, ir, scheduler
from .dependencies import WeakDep
from .utils import tuple_sorted
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
def sink_waits(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Greedily pushes every wait node as far down the schedule as possible,
    emitting it only once a consumer is reached. Optimal in terms of
    communication overlap.
    """
    result: List["scheduler.BaseSchedulerNode"] = []
    pending_waits = set()
    for node in snodes:
        if isinstance(node.node, ir.Wait):
            # Defer the wait until something actually needs its result.
            pending_waits.add(node)
            continue
        # Flush every deferred wait that this node consumes.
        for deferred in tuple_sorted(pending_waits):
            if node in deferred.node_users:
                result.append(deferred)
                pending_waits.remove(deferred)
        result.append(node)
    # Waits with no remaining user are appended at the very end.
    result.extend(tuple_sorted(pending_waits))
    return result
def raise_comms(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Greedily moves comms as early as possible (i.e. until we reach an input).
    Optimal in terms of communication overlap.

    TODO: We might want to adjust this in the future to account for memory limitations.
    e.g. when we are compiling FSDP, this heuristics will cause the all-gathers to be prefetched as soon as possible,
    which is the beginning of the forwards pass. We'll have to either do a special pass for FSDP,
    or we'll want to redo this pass with memory considerations so we handle the FSDP case in a general way.
    """
    # Walk the schedule backwards, holding comms in `pending_comms` until we
    # reach one of their producers; the reversed accumulator is the answer.
    reversed_order: List["scheduler.BaseSchedulerNode"] = []
    pending_comms: List["scheduler.BaseSchedulerNode"] = []
    for snode in reversed(snodes):
        if isinstance(snode.node, ir.CollectiveKernel):
            pending_comms.append(snode)
            continue
        # Every pending comm is expected to have at least one producer.
        for comm in pending_comms:
            assert len(comm.inverse_users) > 0
        # Emit pending comms (oldest first) while the current node feeds any of them.
        while pending_comms and any(
            snode in comm.inverse_users for comm in pending_comms
        ):
            reversed_order.append(pending_comms.pop(0))
        reversed_order.append(snode)
    assert len(pending_comms) <= 1
    reversed_order.extend(tuple_sorted(pending_comms))
    return reversed_order[::-1]
def get_ancestors(node):
    """Return the set of all transitive producers (ancestors) of `node` via BFS."""
    ancestors = set()
    frontier = [node]
    while frontier:
        next_frontier = []
        for cur in frontier:
            for producer in cur.inverse_users:
                if producer not in ancestors:
                    ancestors.add(producer)
                    next_frontier.append(producer)
        frontier = next_frontier
    return ancestors
def get_descendants(node):
    """Return the set of all transitive consumers (descendants) of `node` via BFS."""
    descendants = set()
    frontier = [node]
    while frontier:
        next_frontier = []
        for cur in frontier:
            for consumer in cur.node_users:
                if consumer not in descendants:
                    descendants.add(consumer)
                    next_frontier.append(consumer)
        frontier = next_frontier
    return descendants
def decide_global_ordering_of_comms(nodes: List["scheduler.BaseSchedulerNode"]):
    """
    Decide global ordering of comms, by just enforcing the ordering that's in the input graph
    (might not be the same ordering as the eager mode program).
    TODO: Come up with a better approach
    """
    comms = [node for node in nodes if isinstance(node.node, ir.CollectiveKernel)]
    # Chain consecutive comms: each comm gets a `WeakDep` on its predecessor.
    for prev_comm, next_comm in zip(comms, comms[1:]):
        next_comm.add_fake_dep(WeakDep(prev_comm.get_name()))
def assert_no_comm_nodes(snodes: List["scheduler.BaseSchedulerNode"]) -> None:
    """Assert that `snodes` contains no collective (communication) kernels."""
    assert all(not isinstance(snode.node, ir.CollectiveKernel) for snode in snodes)
def estimate_op_runtime(snode: "scheduler.BaseSchedulerNode") -> float:
    """
    Returns estimated op runtime in nanoseconds (ns)
    """
    if config.estimate_op_runtime == "default":
        # Fall back to the scheduler's built-in cost model.
        return snode.get_estimated_runtime()
    # Otherwise a user-supplied estimator must have been configured.
    assert callable(config.estimate_op_runtime)
    return config.estimate_op_runtime(snode)
def reorder_compute_for_overlap(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Decides a global ordering of all compute and communication nodes,
    assuming that we already have a global ordering of communication nodes.

    Overall scheduling procedure is:
        Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
            that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N.
        Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
            Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
            We prioritize compute nodes that are needed sooner.
        Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
        Step 4: We schedule comm N + 1.
    Repeat this for subsequent comm nodes.
    """
    final_order = []

    comm_nodes = [
        snode for snode in snodes if isinstance(snode.node, ir.CollectiveKernel)
    ]
    if len(comm_nodes) == 0:
        # If there are no comm nodes, return the current order.
        return snodes

    comm_ancestors = {node: get_ancestors(node) for node in comm_nodes}
    comm_descendants = {node: get_descendants(node) for node in comm_nodes}

    # Topological bookkeeping: `indeg[n]` counts n's unscheduled producers;
    # a node becomes ready to schedule once it drops to 0.
    indeg = dict.fromkeys(snodes, 0)
    for snode in snodes:
        for user in snode.node_users:
            if user in indeg:
                indeg[user] += 1
    ready_to_schedule_nodes = {node for node in snodes if indeg[node] == 0}
    unscheduled_nodes = set(snodes)

    def schedule_node(snode):
        """
        Schedule a single ready node: append it to `final_order` and mark any
        user whose producers are now all scheduled as ready.
        """
        assert snode in unscheduled_nodes
        assert snode in ready_to_schedule_nodes
        ready_to_schedule_nodes.remove(snode)
        unscheduled_nodes.remove(snode)
        final_order.append(snode)
        for user in tuple_sorted(snode.node_users):
            if user in indeg:
                indeg[user] -= 1
                if indeg[user] == 0:
                    ready_to_schedule_nodes.add(user)

    def schedule_nodes(snodes):
        """
        Schedules all nodes in `snodes` in an arbitrary topologically valid order.
        """
        all_nodes = set(snodes)
        assert all(node in unscheduled_nodes for node in all_nodes)
        while len(all_nodes) > 0:
            # NOTE: since model graph is always a DAG and does not have circular dependency inside,
            # there should be at least one node that is a "free node" (i.e. indeg == 0),
            # hence infinite loop is not possible. But we check here just to be safe.
            progress = False
            for node in tuple_sorted(all_nodes):
                if node in ready_to_schedule_nodes:
                    schedule_node(node)
                    all_nodes.remove(node)
                    progress = True
            if not progress:
                raise Exception(
                    "Unable to find a free node (indeg == 0). This is an impossible state to reach. "
                    "Please report a bug to PyTorch."
                )

    # First, schedule all compute nodes that are required by first comm node,
    # as well as the first comm node itself.
    assert len(comm_nodes) > 0
    schedule_nodes(
        list(comm_ancestors[comm_nodes[0]]) + [comm_nodes[0]],
    )

    rolled_over_compute_cost = 0
    for idx in range(1, len(comm_nodes)):
        # Step 1: Given that we've currently scheduled comm `idx-1`, we now schedule
        # all compute nodes that are required for comm `idx` but do not depend on comm `idx-1`,
        # to run at the same time with comm `idx-1`.
        needed_by_next_comm_and_ready_compute_nodes = unscheduled_nodes & (
            comm_ancestors[comm_nodes[idx]] - comm_descendants[comm_nodes[idx - 1]]
        )
        assert_no_comm_nodes(needed_by_next_comm_and_ready_compute_nodes)

        total_compute_runtime_cost = rolled_over_compute_cost + sum(
            estimate_op_runtime(node)
            for node in needed_by_next_comm_and_ready_compute_nodes
        )
        prev_comm_runtime_cost = estimate_op_runtime(comm_nodes[idx - 1])
        schedule_nodes(tuple_sorted(needed_by_next_comm_and_ready_compute_nodes))

        # Step 2: If all those compute nodes are sufficient to overlap comm `idx-1`, we're done.
        # Otherwise, we now need to look elsewhere to find compute that overlaps with comm `idx-1`.
        # We prioritize compute nodes that are needed sooner.
        step1_runtime_cost = total_compute_runtime_cost
        if step1_runtime_cost < prev_comm_runtime_cost:
            # Find all ready to schedule compute nodes that do not depend on comm `idx-1`.
            ready_to_schedule_compute_nodes = tuple_sorted(
                ready_to_schedule_nodes - comm_descendants[comm_nodes[idx - 1]]
            )
            assert_no_comm_nodes(ready_to_schedule_compute_nodes)

            def earliest_comm_descendant(node):
                # Index of the first comm that (transitively) consumes `node`,
                # or len(comm_nodes) if no comm depends on it.
                for i in range(len(comm_nodes)):
                    if node in comm_ancestors[comm_nodes[i]]:
                        return i
                return len(comm_nodes)

            # Prioritize compute nodes that are needed sooner.
            ready_to_schedule_compute_nodes = sorted(
                ready_to_schedule_compute_nodes, key=earliest_comm_descendant
            )

            for snode in ready_to_schedule_compute_nodes:
                if total_compute_runtime_cost >= prev_comm_runtime_cost:
                    # If accumulated compute runtime cost is greater than comm `idx-1` runtime cost,
                    # it means we have maximized overlap for comm `idx-1`, and hence we stop looking
                    # for more compute to schedule.
                    break
                compute_runtime_cost = estimate_op_runtime(snode)
                # If we're not able to leverage more than half of this
                # node's compute to overlap, we skip it.
                # TODO: Smarter heuristics here
                if (
                    prev_comm_runtime_cost - total_compute_runtime_cost
                ) <= compute_runtime_cost / 2:
                    continue
                schedule_node(snode)
                total_compute_runtime_cost += compute_runtime_cost

        rollable_compute_cost = total_compute_runtime_cost - step1_runtime_cost

        # Step 3: We schedule the compute nodes dependent on comm `idx-1` and required for comm `idx`.
        needed_by_next_comm_nodes = unscheduled_nodes & comm_ancestors[comm_nodes[idx]]
        schedule_nodes(list(needed_by_next_comm_nodes))

        # Step 4: We schedule comm `idx`.
        schedule_nodes([comm_nodes[idx]])

        is_prev_comm_blocking_next_comm = len(needed_by_next_comm_nodes) > 0
        # The idea here is that if there are no compute nodes from Step 3
        # (i.e. if prev comm is not blocking next comm), we can roll over the compute nodes
        # in Step 2 to overlap with the next comm, since they're not required to finish
        # before the next comm starts.
        if is_prev_comm_blocking_next_comm:
            rolled_over_compute_cost = 0
        else:
            rolled_over_compute_cost = rollable_compute_cost  # type: ignore[assignment]

    # Whatever is left has no bearing on comm overlap; schedule it last.
    schedule_nodes(unscheduled_nodes)
    return final_order
def node_summary(snode):
    """Build a one-line human-readable description of a scheduler node for logging."""
    inner = snode.node
    detail = (
        f" ({inner.python_kernel_name})"
        if isinstance(inner, ir.ExternKernelOut)
        else ""
    )
    layout = getattr(inner, "layout", None)
    out_tensor_info = ""
    if hasattr(layout, "size") and hasattr(layout, "stride"):
        out_tensor_info = f" (size={layout.size}, stride={layout.stride})"
    node_name = getattr(inner, "name", "")
    return f"{inner.__class__.__name__}{detail}{out_tensor_info} ({node_name})"
def visualize_overlap(order):
    """
    Log a textual visualization of which compute ops are exposed and which
    overlap a running collective, plus a total estimated runtime.
    """
    total_est_runtime: float = 0.0
    cur_comm_node = None
    for snode in order:
        inner = snode.node
        if cur_comm_node is None:
            if isinstance(inner, ir.CollectiveKernel):
                # A collective starts; subsequent compute overlaps with it.
                total_est_runtime += estimate_op_runtime(snode)
                cur_comm_node = inner
            elif isinstance(inner, ir.Wait):
                raise Exception(
                    "Wait is not expected when there is no collective running"
                )
            else:  # exposed compute op
                total_est_runtime += estimate_op_runtime(snode)
                overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
        else:  # cur_comm_node is not None
            if isinstance(inner, ir.CollectiveKernel):
                raise Exception(
                    "Found two collectives running at the same time. "
                    "`visualize_overlap` needs to be updated to handle this case"
                )
            if isinstance(inner, ir.Wait):  # end of this comm op
                overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
                cur_comm_node = None
            else:  # overlapped compute op
                overlap_log.debug(f"| {node_summary(snode)}")  # noqa: G004
    overlap_log.debug(
        f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}"  # noqa: G004
    )
def reorder_compute_and_comm_for_overlap(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Applies each reordering pass listed in
    `config.reorder_for_compute_comm_overlap_passes` (a callable, or the name
    of a builtin pass in this module) to the schedule in turn, logging an
    overlap visualization before and after every pass on rank 0.
    """

    def _log_overlap(when: str, p, order) -> None:
        # Helper factored out of the previously-duplicated before/after blocks.
        # Only rank 0 logs, to avoid duplicate output across the process group.
        if torch.distributed.get_rank() != 0:
            return
        overlap_log.debug(
            f"==== Visualize overlap {when} reordering pass {p} ===="  # noqa: G004
        )
        try:
            visualize_overlap(order)
        except Exception as e:
            # Visualization is best-effort; never fail compilation over it.
            overlap_log.debug(str(e))

    order = snodes
    for p in config.reorder_for_compute_comm_overlap_passes:
        if isinstance(p, str) and p in globals():
            p = globals()[p]  # it is a builtin pass
        _log_overlap("before", p, order)
        order = p(order)  # type: ignore[operator]
        _log_overlap("after", p, order)
    return order
|