# pyre-strict
from typing import List

import torch

from . import config, ir, scheduler
from .dependencies import WeakDep
from .utils import tuple_sorted

overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")


def sink_waits(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Greedily moves waits as late as possible (i.e. until we reach a use). Optimal in terms of
    communication overlap.
    """
    new_order = []
    cur_waits = set()
    for snode in snodes:
        if isinstance(snode.node, ir.Wait):
            cur_waits.add(snode)
        else:
            for wait in tuple_sorted(cur_waits):
                if snode in wait.node_users:
                    new_order.append(wait)
                    cur_waits.remove(wait)
            new_order.append(snode)
    new_order.extend(tuple_sorted(cur_waits))
    return new_order
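
# For illustration only (hypothetical node names): given an input order
#   [comm0, wait0, mm1, mm2]            where only mm2 consumes wait0's result,
# sink_waits keeps wait0 buffered until its first use and emits
#   [comm0, mm1, wait0, mm2]
# so that mm1 runs while the collective behind wait0 is still in flight.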


def raise_comms(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Greedily moves comms as early as possible (i.e. until we reach an input).
    Optimal in terms of communication overlap.

    TODO: We might want to adjust this in the future to account for memory limitations.
    e.g. when we are compiling FSDP, this heuristic will cause the all-gathers to be prefetched as soon as possible,
    which is at the beginning of the forward pass. We'll have to either do a special pass for FSDP,
    or redo this pass with memory considerations so we handle the FSDP case in a general way.
    """
    new_order_reversed: List["scheduler.BaseSchedulerNode"] = []
    cur_comms: List["scheduler.BaseSchedulerNode"] = []
    for snode in reversed(snodes):
        if isinstance(snode.node, ir.CollectiveKernel):
            cur_comms.append(snode)
        else:
            for comm in cur_comms:
                assert len(comm.inverse_users) > 0
            while len(cur_comms) > 0 and any(
                snode in comm.inverse_users for comm in cur_comms
            ):
                comm = cur_comms.pop(0)
                new_order_reversed.append(comm)
            new_order_reversed.append(snode)
    assert len(cur_comms) <= 1
    new_order_reversed.extend(tuple_sorted(cur_comms))
    return new_order_reversed[::-1]
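
# For illustration only (hypothetical node names): given an input order
#   [mm2, mm1, comm0, wait0]            where only mm2 produces comm0's input,
# raise_comms hoists comm0 to right after its last input producer and emits
#   [mm2, comm0, mm1, wait0]
# so that mm1 can execute while comm0 is in flight.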


def get_ancestors(node):
    ancestors = set()
    cur_nodes = [node]
    while len(cur_nodes) > 0:
        new_nodes = []
        for node in cur_nodes:
            for inp in node.inverse_users:
                if inp not in ancestors:
                    ancestors.add(inp)
                    new_nodes.append(inp)
        cur_nodes = new_nodes
    return ancestors


def get_descendants(node):
    descendants = set()
    cur_nodes = [node]
    while len(cur_nodes) > 0:
        new_nodes = []
        for node in cur_nodes:
            for user in node.node_users:
                if user not in descendants:
                    descendants.add(user)
                    new_nodes.append(user)
        cur_nodes = new_nodes
    return descendants


def decide_global_ordering_of_comms(nodes: List["scheduler.BaseSchedulerNode"]):
    """
    Decide global ordering of comms, by just enforcing the ordering that's in the input graph
    (might not be the same ordering as the eager mode program).
    TODO: Come up with a better approach
    """
    comm_nodes = [n for n in nodes if isinstance(n.node, ir.CollectiveKernel)]
    for i in range(1, len(comm_nodes)):
        # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
        comm_nodes[i].add_fake_dep(WeakDep(comm_nodes[i - 1].get_name()))
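
# For illustration: with collectives c0, c1, c2 appearing in that graph order, the loop above
# adds WeakDep(c0) to c1 and WeakDep(c1) to c2, so the scheduler cannot reorder the
# collectives relative to one another even when they share no data dependency.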


def assert_no_comm_nodes(snodes: List["scheduler.BaseSchedulerNode"]) -> None:
    assert not any(isinstance(snode.node, ir.CollectiveKernel) for snode in snodes)


def estimate_op_runtime(snode: "scheduler.BaseSchedulerNode") -> float:
    """
    Returns estimated op runtime in nanoseconds (ns).
    """
    if config.estimate_op_runtime == "default":
        runtime = snode.get_estimated_runtime()
    else:
        assert callable(config.estimate_op_runtime)
        runtime = config.estimate_op_runtime(snode)
    return runtime
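
# Usage sketch (assumption: you want to override the default estimator): any callable that
# takes a BaseSchedulerNode and returns a runtime in nanoseconds can be plugged in, e.g.
#   torch._inductor.config.estimate_op_runtime = lambda snode: 1_000.0  # pretend every op takes 1us
# With config.estimate_op_runtime == "default", snode.get_estimated_runtime() is used instead.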


def reorder_compute_for_overlap(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Decides a global ordering of all compute and communication nodes,
    assuming that we already have a global ordering of communication nodes.

    Overall scheduling procedure is:
        Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
            that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N.
        Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
            Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
            We prioritize compute nodes that are needed sooner.
        Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
        Step 4: We schedule comm N + 1.
    Repeat this for subsequent comm nodes.
    """
    final_order = []

    comm_nodes = []
    for snode in snodes:
        if isinstance(snode.node, ir.CollectiveKernel):
            comm_nodes.append(snode)
    if len(comm_nodes) == 0:
        # if there are no comm nodes, return the current order
        return snodes

    comm_ancestors = {node: get_ancestors(node) for node in comm_nodes}
    comm_descendants = {node: get_descendants(node) for node in comm_nodes}

    indeg = dict.fromkeys(snodes, 0)
    for snode in snodes:
        for user in snode.node_users:
            if user in indeg:
                indeg[user] += 1
    ready_to_schedule_nodes = {node for node in snodes if indeg[node] == 0}

    unscheduled_nodes = set(snodes)

    def schedule_node(snode):
        """
        Schedule a single node.
        """
        assert snode in unscheduled_nodes
        assert snode in ready_to_schedule_nodes
        ready_to_schedule_nodes.remove(snode)
        unscheduled_nodes.remove(snode)
        final_order.append(snode)
        for user in tuple_sorted(snode.node_users):
            if user in indeg:
                indeg[user] -= 1
                if indeg[user] == 0:
                    ready_to_schedule_nodes.add(user)

    def schedule_nodes(snodes):
        """
        Schedules all nodes in `snodes` in an arbitrary topologically valid order.
        """
        all_nodes = set(snodes)
        assert all(node in unscheduled_nodes for node in all_nodes)
        while len(all_nodes) > 0:
            # NOTE: since the model graph is always a DAG and has no circular dependencies,
            # there should be at least one node that is a "free node" (i.e. indeg == 0),
            # hence an infinite loop is not possible. But we check here just to be safe.
            progress = False
            for node in tuple_sorted(all_nodes):
                if node in ready_to_schedule_nodes:
                    schedule_node(node)
                    all_nodes.remove(node)
                    progress = True
            if not progress:
                raise Exception(
                    "Unable to find a free node (indeg == 0). This is an impossible state to reach. "
                    "Please report a bug to PyTorch."
                )

    # First, schedule all compute nodes that are required by the first comm node,
    # as well as the first comm node itself.
    assert len(comm_nodes) > 0
    schedule_nodes(
        list(comm_ancestors[comm_nodes[0]]) + [comm_nodes[0]],
    )

    rolled_over_compute_cost = 0
    for idx in range(1, len(comm_ancestors)):
        # Step 1: Given that we've currently scheduled comm `idx-1`, we now schedule
        # all compute nodes that are required for comm `idx` but do not depend on comm `idx-1`,
        # to run at the same time with comm `idx-1`.
        needed_by_next_comm_and_ready_compute_nodes = unscheduled_nodes & (
            comm_ancestors[comm_nodes[idx]] - comm_descendants[comm_nodes[idx - 1]]
        )
        assert_no_comm_nodes(needed_by_next_comm_and_ready_compute_nodes)

        total_compute_runtime_cost = rolled_over_compute_cost + sum(
            estimate_op_runtime(node)
            for node in needed_by_next_comm_and_ready_compute_nodes
        )
        prev_comm_runtime_cost = estimate_op_runtime(comm_nodes[idx - 1])
        schedule_nodes(tuple_sorted(needed_by_next_comm_and_ready_compute_nodes))

        # Step 2: If all those compute nodes are sufficient to overlap comm `idx-1`, we're done.
        # Otherwise, we now need to look elsewhere to find compute that overlaps with comm `idx-1`.
        # We prioritize compute nodes that are needed sooner.
        step1_runtime_cost = total_compute_runtime_cost
        if step1_runtime_cost >= prev_comm_runtime_cost:
            pass
        else:
            # Find all ready-to-schedule compute nodes that do not depend on comm `idx-1`.
            ready_to_schedule_compute_nodes = tuple_sorted(
                ready_to_schedule_nodes - comm_descendants[comm_nodes[idx - 1]]
            )
            assert_no_comm_nodes(ready_to_schedule_compute_nodes)

            def earliest_comm_descendant(node):
                for idx in range(len(comm_nodes)):
                    if node in comm_ancestors[comm_nodes[idx]]:
                        return idx
                return len(comm_nodes)

            # Prioritize compute nodes that are needed sooner.
            ready_to_schedule_compute_nodes = sorted(
                ready_to_schedule_compute_nodes, key=earliest_comm_descendant
            )

            for snode in ready_to_schedule_compute_nodes:
                if total_compute_runtime_cost >= prev_comm_runtime_cost:
                    # If the accumulated compute runtime cost is greater than comm `idx-1`'s runtime cost,
                    # it means we have maximized overlap for comm `idx-1`, and hence we stop looking
                    # for more compute to schedule.
                    break
                compute_runtime_cost = estimate_op_runtime(snode)
                # If we're not able to leverage more than half of this
                # node's compute to overlap, we skip it.
                # TODO: Smarter heuristics here
                if (
                    prev_comm_runtime_cost - total_compute_runtime_cost
                ) <= compute_runtime_cost / 2:
                    continue
                schedule_node(snode)
                total_compute_runtime_cost += compute_runtime_cost
        rollable_compute_cost = total_compute_runtime_cost - step1_runtime_cost

        # Step 3: We schedule the compute nodes dependent on comm `idx-1` and required for comm `idx`.
        needed_by_next_comm_nodes = unscheduled_nodes & comm_ancestors[comm_nodes[idx]]
        schedule_nodes(list(needed_by_next_comm_nodes))

        # Step 4: We schedule comm `idx`.
        schedule_nodes([comm_nodes[idx]])

        is_prev_comm_blocking_next_comm = len(needed_by_next_comm_nodes) > 0
        # The idea here is that if there are no compute nodes from Step 3
        # (i.e. if prev comm is not blocking next comm), we can roll over the compute nodes
        # in Step 2 to overlap with the next comm, since they're not required to finish
        # before the next comm starts.
        if is_prev_comm_blocking_next_comm:
            rolled_over_compute_cost = 0
        else:
            rolled_over_compute_cost = rollable_compute_cost  # type: ignore[assignment]

    schedule_nodes(unscheduled_nodes)
    return final_order


def node_summary(snode):
    detail = ""
    if isinstance(snode.node, ir.ExternKernelOut):
        detail = f" ({snode.node.python_kernel_name})"
    out_tensor_info = ""
    if (
        hasattr(snode.node, "layout")
        and hasattr(snode.node.layout, "size")
        and hasattr(snode.node.layout, "stride")
    ):
        out_tensor_info = (
            f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})"
        )
    node_name = ""
    if hasattr(snode.node, "name"):
        node_name = snode.node.name
    return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})"


def visualize_overlap(order):
    total_est_runtime: float = 0.0
    cur_comm_node = None
    for snode in order:
        if cur_comm_node is None:
            if isinstance(snode.node, ir.CollectiveKernel):
                total_est_runtime += estimate_op_runtime(snode)
                cur_comm_node = snode.node
            elif isinstance(snode.node, ir.Wait):
                raise Exception(
                    "Wait is not expected when there is no collective running"
                )
            else:  # exposed compute op
                total_est_runtime += estimate_op_runtime(snode)
            overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
        else:  # cur_comm_node is not None
            if isinstance(snode.node, ir.CollectiveKernel):
                raise Exception(
                    "Found two collectives running at the same time. "
                    "`visualize_overlap` needs to be updated to handle this case"
                )
            elif isinstance(snode.node, ir.Wait):  # end of this comm op
                overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
                cur_comm_node = None
            else:  # overlapped compute op
                overlap_log.debug(f"| {node_summary(snode)}")  # noqa: G004
    overlap_log.debug(
        f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}"  # noqa: G004
    )
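
# Sample of what the overlap artifact log looks like (hypothetical buffer names, shapes, and
# runtime; "| " marks compute that is overlapped with the collective printed above it):
#   CollectiveKernel (buf3)
#   | ExternKernelOut (extern_kernels.mm) (size=[1024, 1024], stride=[1024, 1]) (buf4)
#   Wait (buf5)
#   Est. runtime (ms): 0.42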


def reorder_compute_and_comm_for_overlap(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    order = snodes
    for p in config.reorder_for_compute_comm_overlap_passes:
        if isinstance(p, str) and p in globals():
            p = globals()[p]  # it is a builtin pass
        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap before reordering pass {p} ===="  # noqa: G004
            )
            try:
                visualize_overlap(order)
            except Exception as e:
                overlap_log.debug(str(e))
        order = p(order)  # type: ignore[operator]
        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap after reordering pass {p} ===="  # noqa: G004
            )
            try:
                visualize_overlap(order)
            except Exception as e:
                overlap_log.debug(str(e))
    return order
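
# Usage sketch (assumption: the inductor config knobs below; the pass names are the builtin
# passes defined in this file, resolved via globals() above):
#   import torch._inductor.config as inductor_config
#   inductor_config.reorder_for_compute_comm_overlap = True
#   inductor_config.reorder_for_compute_comm_overlap_passes = [
#       "sink_waits",
#       "raise_comms",
#       "reorder_compute_for_overlap",
#   ]
# Each entry may also be a custom callable that takes and returns a List[BaseSchedulerNode].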