Spaces:
Running
Running
Dependency-based repeat instead of using regions.
Browse files
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py
CHANGED
|
@@ -246,6 +246,34 @@ class ModelConfig:
|
|
| 246 |
}
|
| 247 |
|
| 248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> ModelConfig:
|
| 250 |
"""Builds the model described in the workspace."""
|
| 251 |
catalog = ops.CATALOGS[ENV]
|
|
@@ -259,6 +287,7 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
|
|
| 259 |
assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
|
| 260 |
[optimizer] = optimizers
|
| 261 |
dependencies = {n.id: [] for n in ws.nodes}
|
|
|
|
| 262 |
in_edges = {}
|
| 263 |
out_edges = {}
|
| 264 |
repeats = []
|
|
@@ -266,6 +295,7 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
|
|
| 266 |
if nodes[e.target].data.title == "Repeat":
|
| 267 |
repeats.append(e.target)
|
| 268 |
dependencies[e.target].append(e.source)
|
|
|
|
| 269 |
in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
|
| 270 |
(e.source, e.sourceHandle)
|
| 271 |
)
|
|
@@ -351,7 +381,7 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
|
|
| 351 |
for out in out_edges.get(node_id, []):
|
| 352 |
i = _to_id(node_id, out)
|
| 353 |
outputs[out] = i
|
| 354 |
-
if
|
| 355 |
if "loss" in regions[node_id]:
|
| 356 |
made_in_loss.add(i)
|
| 357 |
else:
|
|
@@ -374,31 +404,23 @@ def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> Mod
|
|
| 374 |
regions[node_id].add(("repeat", node_id.removeprefix("START ")))
|
| 375 |
else:
|
| 376 |
repeat_id = node_id.removeprefix("END ")
|
|
|
|
| 377 |
print(f"repeat {repeat_id} ending")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
regions[node_id].remove(("repeat", repeat_id))
|
| 379 |
-
for n in
|
| 380 |
-
|
| 381 |
-
if ("repeat", repeat_id) in r:
|
| 382 |
-
print(f"repeating {n}")
|
| 383 |
case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
|
| 384 |
pass
|
| 385 |
case _:
|
| 386 |
-
|
| 387 |
-
for i in op.inputs.keys():
|
| 388 |
-
id = getattr(inputs, i)
|
| 389 |
-
op_inputs.append(OpInput(id, shape=sizes.get(id, 1)))
|
| 390 |
-
if op.func != ops.no_op:
|
| 391 |
-
layer = op.func(*op_inputs, **p)
|
| 392 |
-
else:
|
| 393 |
-
layer = Layer(torch.nn.Identity(), shapes=[i.shape for i in op_inputs])
|
| 394 |
-
input_ids = ", ".join(i.id for i in op_inputs)
|
| 395 |
-
output_ids = []
|
| 396 |
-
for o, shape in zip(op.outputs.keys(), layer.shapes):
|
| 397 |
-
id = getattr(outputs, o)
|
| 398 |
-
output_ids.append(id)
|
| 399 |
-
sizes[id] = shape
|
| 400 |
-
output_ids = ", ".join(output_ids)
|
| 401 |
-
ls.append((layer.module, f"{input_ids} -> {output_ids}"))
|
| 402 |
cfg["model_inputs"] = list(used_in_model - made_in_model)
|
| 403 |
cfg["model_outputs"] = list(made_in_model & used_in_loss)
|
| 404 |
cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
|
|
|
|
| 246 |
}
|
| 247 |
|
| 248 |
|
| 249 |
+
def _add_op(op, params, inputs, outputs, sizes, layers):
    """Creates the layer for one operation and appends it to `layers`.

    Args:
        op: The operation. Its `inputs` and `outputs` mappings supply the
            handle names, and `func` builds the layer.
        params: Keyword arguments passed to `op.func`.
        inputs: Object whose attributes (named after the keys of `op.inputs`)
            hold the IDs of the input tensors.
        outputs: Object whose attributes (named after the keys of `op.outputs`)
            hold the IDs of the output tensors.
        sizes: Maps tensor ID to shape. Read for the inputs (falling back to 1
            when the ID is unknown) and updated in place for the outputs.
        layers: List that receives a `(module, "inputs -> outputs")` tuple.
    """
    op_inputs = []
    for handle in op.inputs.keys():
        tensor_id = getattr(inputs, handle)
        known_shape = sizes.get(tensor_id, 1)
        op_inputs.append(OpInput(tensor_id, shape=known_shape))

    if op.func == ops.no_op:
        # A no-op passes its inputs through unchanged, so the shapes carry over.
        layer = Layer(torch.nn.Identity(), shapes=[x.shape for x in op_inputs])
    else:
        layer = op.func(*op_inputs, **params)

    lhs = ", ".join(x.id for x in op_inputs)

    produced = []
    for handle, shape in zip(op.outputs.keys(), layer.shapes):
        tensor_id = getattr(outputs, handle)
        produced.append(tensor_id)
        # Record the output shape so that downstream ops can look it up.
        sizes[tensor_id] = shape
    rhs = ", ".join(produced)

    layers.append((layer.module, f"{lhs} -> {rhs}"))
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _all_dependencies(node: str, dependencies: dict[str, list[str]]) -> set[str]:
|
| 269 |
+
"""Returns all dependencies of a node."""
|
| 270 |
+
deps = set()
|
| 271 |
+
for dep in dependencies[node]:
|
| 272 |
+
deps.add(dep)
|
| 273 |
+
deps.update(_all_dependencies(dep, dependencies))
|
| 274 |
+
return deps
|
| 275 |
+
|
| 276 |
+
|
| 277 |
def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> ModelConfig:
|
| 278 |
"""Builds the model described in the workspace."""
|
| 279 |
catalog = ops.CATALOGS[ENV]
|
|
|
|
| 287 |
assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
|
| 288 |
[optimizer] = optimizers
|
| 289 |
dependencies = {n.id: [] for n in ws.nodes}
|
| 290 |
+
inv_dependencies = {n.id: [] for n in ws.nodes}
|
| 291 |
in_edges = {}
|
| 292 |
out_edges = {}
|
| 293 |
repeats = []
|
|
|
|
| 295 |
if nodes[e.target].data.title == "Repeat":
|
| 296 |
repeats.append(e.target)
|
| 297 |
dependencies[e.target].append(e.source)
|
| 298 |
+
inv_dependencies[e.source].append(e.target)
|
| 299 |
in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
|
| 300 |
(e.source, e.sourceHandle)
|
| 301 |
)
|
|
|
|
| 381 |
for out in out_edges.get(node_id, []):
|
| 382 |
i = _to_id(node_id, out)
|
| 383 |
outputs[out] = i
|
| 384 |
+
if not t.startswith("Input:"): # The outputs of inputs are not "made" by us.
|
| 385 |
if "loss" in regions[node_id]:
|
| 386 |
made_in_loss.add(i)
|
| 387 |
else:
|
|
|
|
| 404 |
regions[node_id].add(("repeat", node_id.removeprefix("START ")))
|
| 405 |
else:
|
| 406 |
repeat_id = node_id.removeprefix("END ")
|
| 407 |
+
start_id = f"START {repeat_id}"
|
| 408 |
print(f"repeat {repeat_id} ending")
|
| 409 |
+
after_start = _all_dependencies(start_id, inv_dependencies)
|
| 410 |
+
after_end = _all_dependencies(node_id, inv_dependencies)
|
| 411 |
+
before_end = _all_dependencies(node_id, dependencies)
|
| 412 |
+
affected_nodes = after_start - after_end
|
| 413 |
+
repeated_nodes = after_start & before_end
|
| 414 |
+
assert affected_nodes == repeated_nodes, (
|
| 415 |
+
f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
|
| 416 |
+
)
|
| 417 |
regions[node_id].remove(("repeat", repeat_id))
|
| 418 |
+
for n in repeated_nodes:
|
| 419 |
+
print(f"repeating {n}")
|
|
|
|
|
|
|
| 420 |
case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
|
| 421 |
pass
|
| 422 |
case _:
|
| 423 |
+
_add_op(op, p, inputs, outputs, sizes, ls)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
cfg["model_inputs"] = list(used_in_model - made_in_model)
|
| 425 |
cfg["model_outputs"] = list(made_in_model & used_in_loss)
|
| 426 |
cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
|