Spaces:
Running
Running
from enum import Enum | |
from typing import NamedTuple, Dict, List, Set | |
from torch.fx.node import Node, map_arg | |
class Partition:
    """Partition class contains all the information about an individual partition.

    It also provides the methods used to add nodes to and remove nodes from
    the partition while keeping ``used_mem_bytes`` up to date.
    """

    def __init__(self, partition_id: int) -> None:
        # Nodes of this partition. ``add_node`` also copies in the
        # placeholder/get_attr inputs of every node it adds.
        self.nodes: Set[Node] = set()
        self.partition_id = partition_id
        # Dependency links to other partitions; not maintained by this class.
        self.parents: Set[Partition] = set()
        self.children: Set[Partition] = set()
        # Starts at -1; presumably assigned later by the partitioner.
        self.bfs_level: int = -1
        # Sum of get_extra_size_of() over self.nodes, refreshed by
        # recalculate_mem_size().
        self.used_mem_bytes: int = 0
        self.logical_device_ids: List[int] = []

    def __str__(self) -> str:
        return str(self.partition_id)

    @staticmethod
    def _get_input_nodes(node: Node) -> Dict[Node, None]:
        """Return the direct input nodes of ``node`` (args first, then kwargs),
        de-duplicated and in first-seen order."""
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        return input_nodes

    def recalculate_mem_size(self) -> None:
        """Recompute ``used_mem_bytes`` from scratch over ``self.nodes``."""
        self.used_mem_bytes = 0
        for node in self.nodes:
            self.used_mem_bytes += get_extra_size_of(node, self.nodes)

    def add_node(self, node: Node) -> None:
        """Add ``node`` and its placeholder/get_attr inputs to this partition,
        then refresh ``used_mem_bytes``."""
        for n in self._get_input_nodes(node):
            # Graph inputs and constants are copied into every partition
            # that uses them.
            if n.op in {"placeholder", "get_attr"}:
                self.nodes.add(n)
        self.nodes.add(node)
        self.recalculate_mem_size()

    def remove_node(self, node: Node) -> None:
        """Remove ``node`` from this partition, together with any of its
        placeholder/get_attr inputs that no remaining node here uses.

        Does nothing if ``node`` is not in this partition.
        """
        if node not in self.nodes:
            return
        self.nodes.remove(node)
        for input_node in self._get_input_nodes(node):
            if input_node.op in {"placeholder", "get_attr"} and all(
                user not in self.nodes for user in input_node.users
            ):
                # ``discard`` rather than ``remove``: the input may never have
                # been tracked here (e.g. ``nodes`` populated directly instead
                # of through ``add_node``), which used to raise KeyError.
                self.nodes.discard(input_node)
        self.recalculate_mem_size()
class Device(NamedTuple):
    """Description of one device available to the partitioner."""

    # Device name.
    name: str
    # Memory available on the device, in bytes.
    available_mem_bytes: int
    # Logical id of the device; presumably the value stored in
    # Partition.logical_device_ids — confirm against the partitioner.
    logical_id: int
class NodeLatency(NamedTuple):
    """Latency of a single node, split by its source (seconds)."""

    # Part of the latency caused by memory bandwidth.
    mem_latency_sec: float
    # Part of the latency caused by computation. The field name is public;
    # keep its spelling.
    computer_latency_sec: float
class PartitionLatency(NamedTuple):
    """Latency summary of a partition's critical path (seconds)."""

    # Memory latency summed over the nodes on the critical path.
    mem_latency_sec: float
    # Compute latency summed over the nodes on the critical path.
    computer_latency_sec: float
    # Overall latency of the critical path.
    overall_latency_sec: float
class PartitionMode(Enum):
    """Partitioning strategy selector.

    Values are explicit integers starting at 0; keep them stable.
    """

    size_based = 0
    sparse_nn = 1
    cost_aware = 2
    kl_based = 3
    aot_based = 4
class PartitionerConfig(NamedTuple):
    """Options controlling how a graph is partitioned.

    NOTE(review): each dict default below is a single object shared by every
    config constructed without an explicit value, so mutating it (e.g.
    ``cfg.node_to_latency_mapping[n] = ...``) leaks into all such configs.
    Pass fresh dicts when a mapping will be modified.
    """

    # Devices partitions may be placed on.
    devices: List[Device]
    # Partitioning strategy.
    mode: PartitionMode = PartitionMode.size_based
    # Inter-partition transfer rate.
    transfer_rate_bytes_per_sec: float = 0.0
    # Per-node latency estimates.
    node_to_latency_mapping: Dict[Node, NodeLatency] = {}
    # Node -> partition id assignment.
    node_to_partition_mapping: Dict[Node, int] = {}
    # Partition id -> logical device ids.
    partition_to_logical_device_mapping: Dict[int, List[int]] = {}
    # Saturate host by replicating partitions to the remaining idle devices.
    saturate_host: bool = False
def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
    """Return the extra size needed if ``node`` is added to ``nodes``.

    The result is ``node.size_bytes.total_size`` plus
    ``size_bytes.output_size`` of every direct input of ``node`` (args and
    kwargs, each counted once) that is not already in ``nodes``.

    Raises:
        RuntimeError: if ``node`` or one of the counted input nodes has no
            ``size_bytes`` attribute.
    """
    # Collect the direct inputs, de-duplicated, in first-seen order.
    input_nodes: Dict[Node, None] = {}
    map_arg(node.args, input_nodes.setdefault)
    map_arg(node.kwargs, input_nodes.setdefault)

    total_size_of_input_nodes = 0
    for n in input_nodes:
        if n in nodes:
            # Already accounted for in this set.
            continue
        size_bytes = getattr(n, "size_bytes", None)
        # Explicit None-check (as in get_comm_latency_between): a present
        # size_bytes object must never be rejected for being falsy.
        if size_bytes is None:
            raise RuntimeError("node has no size_bytes attr")
        total_size_of_input_nodes += size_bytes.output_size

    # Don't forget the op node itself.
    size_bytes = getattr(node, "size_bytes", None)
    if size_bytes is None:
        raise RuntimeError("node has no size_bytes attr")
    total_size_of_input_nodes += size_bytes.total_size
    return total_size_of_input_nodes
def get_latency_of_one_partition(
    partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
) -> PartitionLatency:
    """Return the latency of the critical path through ``partition``.

    Paths start at every "top" node and follow users that stay inside the
    partition. A top node is a node that is not placeholder/get_attr and
    has no placeholder-free input inside the partition. Along a path:
    - memory latencies and compute latencies are summed separately;
    - the overall latency grows by ``max(compute, mem)`` per node.

    The path with the largest overall latency is returned. An all-zero
    ``PartitionLatency`` is returned when no path has a positive overall
    latency, including when the partition has no top nodes.

    NOTE(review): shared sub-paths are re-explored once per route into them,
    so the cost can grow exponentially with fan-out/fan-in inside the
    partition.

    Raises:
        KeyError: if a visited node is missing from ``node_to_latency_mapping``.
    """
    # These ops are copied into partitions but never start or extend a path.
    non_compute_ops = {"placeholder", "get_attr"}
    zero_latency = PartitionLatency(
        mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
    )

    def inputs_of(node: Node) -> Dict[Node, None]:
        """Collect the direct input nodes of ``node`` (args, then kwargs)."""
        found: Dict[Node, None] = {}
        map_arg(node.args, found.setdefault)
        map_arg(node.kwargs, found.setdefault)
        return found

    def top_nodes_of(part: Partition) -> List[Node]:
        """Nodes of ``part`` on the top bfs level of that partition."""
        return [
            node
            for node in part.nodes
            if node.op not in non_compute_ops
            and all(
                src not in part.nodes or src.op in non_compute_ops
                for src in inputs_of(node)
            )
        ]

    def longest_path_from(
        node: Node, prefix: PartitionLatency
    ) -> PartitionLatency:
        """Extend ``prefix`` by ``node`` and return the worst path below it."""
        own = node_to_latency_mapping[node]
        here = PartitionLatency(
            mem_latency_sec=prefix.mem_latency_sec + own.mem_latency_sec,
            computer_latency_sec=(
                prefix.computer_latency_sec + own.computer_latency_sec
            ),
            overall_latency_sec=prefix.overall_latency_sec
            + max(own.computer_latency_sec, own.mem_latency_sec),
        )
        # Built exactly this way on purpose: the set's iteration order
        # decides which path wins a tie.
        in_partition_users = set(node.users).intersection(partition.nodes)
        if not in_partition_users:
            # Bottom of the partition: the path ends here.
            return here
        worst = zero_latency
        for user in in_partition_users:
            candidate = longest_path_from(user, here)
            if candidate.overall_latency_sec > worst.overall_latency_sec:
                worst = candidate
        return worst

    critical_path_latency = zero_latency
    for top in top_nodes_of(partition):
        candidate = longest_path_from(top, zero_latency)
        if (
            candidate.overall_latency_sec
            > critical_path_latency.overall_latency_sec
        ):
            critical_path_latency = candidate
    return critical_path_latency
def get_partition_to_latency_mapping(
    partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]:
    """Map every partition in ``partitions`` to its critical-path latency.

    Each value is ``get_latency_of_one_partition(partition,
    node_to_latency_mapping)``. Keys follow the order of ``partitions``.
    """
    return {
        partition: get_latency_of_one_partition(partition, node_to_latency_mapping)
        for partition in partitions
    }
def get_comm_latency_between(
    parent_partition: Partition,
    child_partition: Partition,
    transfer_rate_bytes_per_sec: float,
):
    """Given two partitions (parent and child), return the communication
    latency between them.

    The transferred size is the sum of ``size_bytes.output_size`` over the
    distinct nodes of ``parent_partition`` that are direct inputs of some node
    in ``child_partition``. Nodes without a ``size_bytes`` attribute
    contribute nothing. The latency is that size divided by
    ``transfer_rate_bytes_per_sec``.

    Returns 0.0 when:
    - both partitions have the same non-empty ``logical_device_ids``; or
    - nothing needs to be transferred.

    Raises:
        ZeroDivisionError: if data must be transferred and
            ``transfer_rate_bytes_per_sec`` is zero.
    """
    # If two partitions are on the same device, the comm latency is 0.
    if (
        parent_partition.logical_device_ids != []
        and child_partition.logical_device_ids != []
        and parent_partition.logical_device_ids == child_partition.logical_device_ids
    ):
        return 0.0
    # Total output size of parent nodes consumed by the child.
    comm_size = 0
    # Parent nodes already counted, so shared inputs are counted once.
    visited_nodes: Set[Node] = set()
    for node in child_partition.nodes:
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        for n in input_nodes:
            if n in parent_partition.nodes and n not in visited_nodes:
                size_bytes = getattr(n, "size_bytes", None)
                if size_bytes is not None:
                    comm_size += size_bytes.output_size
                visited_nodes.add(n)
    # Nothing crosses the boundary, so there is no transfer latency. This also
    # avoids evaluating 0 / 0.0 when the rate is unset (0.0 by default in
    # PartitionerConfig).
    if comm_size == 0:
        return 0.0
    return comm_size / transfer_rate_bytes_per_sec
def get_latency_of_partitioned_graph(
    partitions: List[Partition],
    partition_to_latency_mapping: Dict[Partition, PartitionLatency],
    transfer_rate_bytes_per_sec: float,
):
    """Return the latency of the critical path through the partitioned graph.

    A path starts at a partition with no parents and follows ``children``
    links until it reaches a partition with no children. Its latency is the
    sum of:
    - each partition's ``overall_latency_sec`` from
      ``partition_to_latency_mapping``;
    - the ``get_comm_latency_between`` cost of every parent -> child hop.

    The largest path latency is returned, or 0.0 if no path has a positive
    latency.

    NOTE(review): paths are re-walked for every route into a shared
    descendant, so the cost can grow exponentially with the DAG's fan-in.
    Memoizing would change float summation order, so it is not done here.
    """

    def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
        """Return the largest latency of any path continuing from
        ``partition``, given the latency accumulated before reaching it."""
        latency_so_far_sec += partition_to_latency_mapping[
            partition
        ].overall_latency_sec
        if not partition.children:
            # End of the path.
            return latency_so_far_sec
        max_latency_sec = 0.0
        for child in partition.children:
            # Cost of moving the parent's outputs to the child.
            comm_latency_sec = get_comm_latency_between(
                partition, child, transfer_rate_bytes_per_sec
            )
            new_latency_sec = dfs_helper(
                child, latency_so_far_sec + comm_latency_sec
            )
            if new_latency_sec > max_latency_sec:
                max_latency_sec = new_latency_sec
        return max_latency_sec

    def get_top_partitions(partitions: List[Partition]) -> List[Partition]:
        """Return the partitions without parents: the starting points of all
        paths."""
        return [partition for partition in partitions if not partition.parents]

    critical_path_latency_sec = 0.0
    for partition in get_top_partitions(partitions):
        latency_sec = dfs_helper(partition, 0.0)
        if latency_sec > critical_path_latency_sec:
            critical_path_latency_sec = latency_sec
    return critical_path_latency_sec