Spaces:
Running
Running
File size: 12,687 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 |
from enum import Enum
from typing import NamedTuple, Dict, List, Set
from torch.fx.node import Node, map_arg
class Partition:
    """Hold the nodes of one partition together with its bookkeeping data.

    Besides storing the nodes, the class offers helpers for adding and
    removing nodes while keeping ``used_mem_bytes`` up to date.
    """

    def __init__(self, partition_id: int) -> None:
        # Nodes currently assigned to this partition.
        self.nodes: Set[Node] = set()
        self.partition_id = partition_id
        # Neighbouring partitions; this class only creates the empty sets,
        # they are filled in by code outside of it.
        self.parents: Set[Partition] = set()
        self.children: Set[Partition] = set()
        # Starts at -1; presumably means "level not assigned yet" -- confirm
        # against the code that sets it.
        self.bfs_level: int = -1
        # Refreshed by recalculate_mem_size().
        self.used_mem_bytes: int = 0
        self.logical_device_ids: List[int] = []

    def __str__(self):
        return str(self.partition_id)

    def recalculate_mem_size(self):
        """Recompute ``used_mem_bytes`` from scratch over the current nodes."""
        self.used_mem_bytes = sum(
            get_extra_size_of(node, self.nodes) for node in self.nodes
        )

    def add_node(self, node):
        """Add ``node`` and any placeholder/get_attr inputs it reads."""
        # A dict is used as an insertion-ordered set of the node's inputs.
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        # Placeholders and constants travel with the node that consumes them.
        for n in input_nodes:
            if n.op in {"placeholder", "get_attr"}:
                self.nodes.add(n)
        self.nodes.add(node)
        self.recalculate_mem_size()

    def remove_node(self, node):
        """Remove ``node`` and any placeholder/get_attr inputs left unused.

        Does nothing when ``node`` is not part of this partition.
        """
        if node not in self.nodes:
            return
        self.nodes.remove(node)
        # Collect the node's input nodes.
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        # Drop a placeholder/get_attr input once no remaining node of this
        # partition uses it.
        for input_node in input_nodes:
            if input_node.op in {"placeholder", "get_attr"} and all(
                user not in self.nodes for user in input_node.users
            ):
                # discard() rather than remove(): the input may already be
                # gone (removed earlier, or never added because `nodes` was
                # filled directly), which must not raise KeyError.
                self.nodes.discard(input_node)
        self.recalculate_mem_size()
class Device(NamedTuple):
    """Immutable record describing one device.

    Fields are positional in the order: name, available memory, logical id.
    """

    # Device name.
    name: str
    # Memory available on the device, in bytes.
    available_mem_bytes: int
    # Logical identifier of the device.
    logical_id: int
class NodeLatency(NamedTuple):
    """Latency figures of a single node, in seconds."""

    # Time attributed to memory bandwidth.
    mem_latency_sec: float
    # Time attributed to computation.
    computer_latency_sec: float
class PartitionLatency(NamedTuple):
    """Latency figures of a partition's critical path, in seconds."""

    # Memory latency summed over every node on the critical path.
    mem_latency_sec: float
    # Compute latency summed over every node on the critical path.
    computer_latency_sec: float
    # Latency of the critical path as a whole.
    overall_latency_sec: float
class PartitionMode(Enum):
    """Enumerates the available partitioning modes.

    The integer values are fixed explicitly (0 through 4); the strategies
    themselves are implemented outside this section of the file.
    """

    size_based = 0
    sparse_nn = 1
    cost_aware = 2
    kl_based = 3
    aot_based = 4
class PartitionerConfig(NamedTuple):
    """Immutable bundle of settings handed to the partitioner.

    Only ``devices`` is required; every other field has a default.
    """

    # Devices that partitions can be placed on.
    devices: List[Device]
    # Partitioning mode; size-based unless stated otherwise.
    mode: PartitionMode = PartitionMode.size_based
    # Transfer rate in bytes per second. The default of 0.0 cannot be used
    # as a divisor, so callers that need communication latency must set it.
    transfer_rate_bytes_per_sec: float = 0.0
    # NOTE: each `{}` default below is a single dict created when the class
    # is defined, so every instance relying on the default shares the same
    # object. Pass fresh dicts instead of mutating the defaults in place.
    node_to_latency_mapping: Dict[Node, NodeLatency] = {}
    node_to_partition_mapping: Dict[Node, int] = {}
    partition_to_logical_device_mapping: Dict[int, List[int]] = {}
    # Saturate host by replicating partitions to the remaining idle devices.
    saturate_host: bool = False
def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
    """Return how many extra bytes are needed to include ``node`` in ``nodes``.

    The result is the ``total_size`` of ``node`` itself plus the
    ``output_size`` of each of its input nodes that is not already in
    ``nodes``. Sizes are read from each node's ``size_bytes`` attribute.

    Raises:
        RuntimeError: if ``node``, or one of its inputs outside ``nodes``,
            has a missing or falsy ``size_bytes`` attribute.
    """
    # Gather the distinct input nodes (a dict serves as an ordered set).
    inputs: Dict[Node, None] = {}
    map_arg(node.args, inputs.setdefault)
    map_arg(node.kwargs, inputs.setdefault)

    extra_bytes = 0
    for producer in inputs:
        # Inputs already inside the set cost nothing extra.
        if producer in nodes:
            continue
        producer_size = getattr(producer, "size_bytes", None)
        if not producer_size:
            raise RuntimeError("node has no size_bytes attr")
        extra_bytes += producer_size.output_size

    # Finally account for the node itself.
    own_size = getattr(node, "size_bytes", None)
    if not own_size:
        raise RuntimeError("node has no size_bytes attr")
    extra_bytes += own_size.total_size
    return extra_bytes
def get_latency_of_one_partition(
    partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
) -> PartitionLatency:
    """Compute the critical-path latency of ``partition``.

    Every path starting at a top-level node of the partition is walked along
    the nodes' users; the path with the largest overall latency wins. A
    partition without any top-level node yields an all-zero result.
    """
    non_compute_ops = {"placeholder", "get_attr"}
    zero_latency = PartitionLatency(
        mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
    )

    def _slower(best: PartitionLatency, other: PartitionLatency) -> PartitionLatency:
        """Return ``other`` only if its overall latency is strictly larger."""
        if other.overall_latency_sec > best.overall_latency_sec:
            return other
        return best

    def _entry_nodes() -> List[Node]:
        """List the partition's nodes that sit on its first level."""
        entries: List[Node] = []
        for candidate in partition.nodes:
            # Placeholders and get_attrs are never entry points themselves.
            if candidate.op in non_compute_ops:
                continue
            producers: Dict[Node, None] = {}
            map_arg(candidate.args, producers.setdefault)
            map_arg(candidate.kwargs, producers.setdefault)
            # An entry node is one whose inputs are all either outside this
            # partition or placeholders/get_attrs.
            if all(
                p not in partition.nodes or p.op in non_compute_ops
                for p in producers
            ):
                entries.append(candidate)
        return entries

    def _walk(current: Node, accumulated) -> PartitionLatency:
        """Extend the path with ``current`` and return the slowest
        continuation through its users inside the partition.
        """
        cost = node_to_latency_mapping[current]
        path_so_far = PartitionLatency(
            # Memory latency accumulated along this path.
            accumulated.mem_latency_sec + cost.mem_latency_sec,
            # Compute latency accumulated along this path.
            accumulated.computer_latency_sec + cost.computer_latency_sec,
            # A node costs the larger of its compute and memory latency.
            accumulated.overall_latency_sec
            + max(cost.computer_latency_sec, cost.mem_latency_sec),
        )
        # Only users that live in this partition continue the path.
        successors = set(current.users).intersection(partition.nodes)
        if not successors:
            # No user left: this node ends the path.
            return path_so_far
        slowest = zero_latency
        for successor in successors:
            slowest = _slower(slowest, _walk(successor, path_so_far))
        return slowest

    # Try every entry node and keep the slowest path found.
    critical_path = zero_latency
    for entry in _entry_nodes():
        critical_path = _slower(critical_path, _walk(entry, zero_latency))
    return critical_path
def get_partition_to_latency_mapping(
    partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]:
    """Map every partition in ``partitions`` to its own latency.

    Each latency is obtained from ``get_latency_of_one_partition`` using the
    supplied per-node latencies.
    """
    return {
        current: get_latency_of_one_partition(current, node_to_latency_mapping)
        for current in partitions
    }
def get_comm_latency_between(
    parent_partition: Partition,
    child_partition: Partition,
    transfer_rate_bytes_per_sec: float,
) -> float:
    """Return the communication latency, in seconds, from parent to child.

    The latency is the total ``output_size`` of the parent-partition nodes
    consumed by nodes of the child partition, divided by
    ``transfer_rate_bytes_per_sec``. Each producing node is counted once,
    however many consumers it has.

    Returns 0.0 when both partitions carry the same non-empty
    ``logical_device_ids`` or when no bytes cross between them.

    Raises:
        ZeroDivisionError: if data has to be transferred but
            ``transfer_rate_bytes_per_sec`` is zero.
    """
    # If two partitions are on the same device, the comm latency is 0.
    if (
        parent_partition.logical_device_ids != []
        and child_partition.logical_device_ids != []
        and parent_partition.logical_device_ids == child_partition.logical_device_ids
    ):
        return 0.0
    # Bytes that must travel from the parent to the child.
    comm_size = 0
    # Parent nodes already accounted for, so shared inputs count once.
    visited_nodes = set()
    for node in child_partition.nodes:
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        for n in input_nodes:
            if n in parent_partition.nodes and n not in visited_nodes:
                size_bytes = getattr(n, "size_bytes", None)
                # Nodes without size information contribute nothing.
                if size_bytes is not None:
                    comm_size += size_bytes.output_size
                visited_nodes.add(n)
    # Nothing to send means no latency whatever the rate is. This also keeps
    # an unset (0.0) transfer rate from raising when no data crosses.
    if comm_size == 0:
        return 0.0
    return comm_size / transfer_rate_bytes_per_sec
def get_latency_of_partitioned_graph(
    partitions: List[Partition],
    partition_to_latency_mapping: Dict[Partition, PartitionLatency],
    transfer_rate_bytes_per_sec: float,
) -> float:
    """Return the latency of the critical path across all partitions.

    Starting from every partition without parents, each chain of child
    partitions is followed. Along the way each partition's overall latency
    and the communication latency between consecutive partitions are added
    up. The largest total is returned; it is 0.0 when there is no partition
    without parents.
    """

    def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
        """Return the largest latency of any path continuing through
        ``partition``, given the latency accumulated before reaching it.
        """
        # Add the current partition's own latency.
        latency_so_far_sec += partition_to_latency_mapping[
            partition
        ].overall_latency_sec
        if not partition.children:
            # A partition without children ends the path.
            return latency_so_far_sec
        max_latency_sec = 0.0
        for child in partition.children:
            # Moving data to the child costs communication time.
            comm_latency_sec = get_comm_latency_between(
                partition, child, transfer_rate_bytes_per_sec
            )
            max_latency_sec = max(
                max_latency_sec,
                dfs_helper(child, latency_so_far_sec + comm_latency_sec),
            )
        return max_latency_sec

    # Partitions without parents are the starting points of all paths.
    top_partitions = [
        partition for partition in partitions if not partition.parents
    ]
    critical_path_latency_sec = 0.0
    for partition in top_partitions:
        critical_path_latency_sec = max(
            critical_path_latency_sec, dfs_helper(partition, 0.0)
        )
    return critical_path_latency_sec
|