Make new optimizer when model is copied.
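PyTorch optimizers are bound to the exact parameter tensors they were constructed with, so an optimizer cannot simply travel along with a deep-copied model: copy.deepcopy(self.model) creates new tensors, and the old optimizer would keep updating the original's parameters. A minimal sketch of the failure mode this commit addresses (the torch.nn.Linear stand-in is illustrative, not from this repo):

import copy

import torch

model = torch.nn.Linear(4, 2)  # illustrative stand-in, not from this repo
opt = torch.optim.SGD(model.parameters(), lr=0.1)

clone = copy.deepcopy(model)

# The optimizer still references the ORIGINAL model's tensors...
assert opt.param_groups[0]["params"][0] is model.weight
# ...and not the clone's, so opt.step() would never update `clone`.
assert opt.param_groups[0]["params"][0] is not clone.weight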
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py
CHANGED
@@ -189,10 +189,14 @@ class ModelConfig:
     model_outputs: list[str]
     loss_inputs: list[str]
     loss: torch.nn.Module
-    optimizer: torch.optim.Optimizer
+    optimizer_parameters: dict[str, any]
+    optimizer: torch.optim.Optimizer | None = None
     source_workspace: str | None = None
     trained: bool = False
 
+    def __post_init__(self):
+        self._make_optimizer()
+
     def num_parameters(self) -> int:
         return sum(p.numel() for p in self.model.parameters())
 
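With this hunk, optimizer becomes derived state: it is rebuilt from optimizer_parameters in __post_init__ on every construction. One detail worth noting: dataclasses.replace() constructs the new instance through __init__, so __post_init__ (and hence _make_optimizer) also runs on copies. A toy sketch of that behavior (the Config class is illustrative, not ModelConfig itself):

import dataclasses

@dataclasses.dataclass
class Config:  # toy stand-in for ModelConfig
    lr: float
    tag: str = ""

    def __post_init__(self):
        # Derived state is rebuilt on every construction,
        # including copies made with dataclasses.replace().
        self.tag = f"lr={self.lr}"

a = Config(lr=0.1)
b = dataclasses.replace(a, lr=0.2)  # goes through __init__, so __post_init__ runs again
print(a.tag, b.tag)  # -> lr=0.1 lr=0.2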
@@ -222,10 +226,20 @@ class ModelConfig:
         self.optimizer.step()
         return loss.item()
 
+    def _make_optimizer(self):
+        # We need to make a new optimizer when the model is copied. (It's tied to its parameters.)
+        p = self.optimizer_parameters
+        o = getattr(torch.optim, p["type"].name)
+        self.optimizer = o(self.model.parameters(), lr=p["lr"])
+
     def copy(self):
         """Returns a copy of the model."""
-        c = dataclasses.replace(
-            self, model=copy.deepcopy(self.model))
+        c = dataclasses.replace(
+            self,
+            model=copy.deepcopy(self.model),
+        )
+        c._make_optimizer()
+        c.optimizer.load_state_dict(self.optimizer.state_dict())
         return c
 
     def metadata(self):
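The new copy() thus builds a fresh optimizer over the copied model's parameters and then ports the old optimizer's state (step counters, Adam moments, etc.) onto it with load_state_dict, which matches saved state to the current parameters by position. The same idea outside this codebase (model and variable names illustrative):

import copy

import torch

model = torch.nn.Linear(4, 2)  # illustrative model
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(8, 4)).sum().backward()
opt.step()  # populates Adam's per-parameter moment buffers

clone = copy.deepcopy(model)
clone_opt = torch.optim.Adam(clone.parameters(), lr=1e-3)  # fresh optimizer over the new tensors
clone_opt.load_state_dict(opt.state_dict())  # carry over moments/step counts by position

# Training the clone now leaves `model` and `opt` untouched.

Rebuilding and then loading state preserves optimizer momentum while rebinding the optimizer to the copy's tensors, which a plain deepcopy of the optimizer would not do.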
@@ -451,9 +465,7 @@ class ModelBuilder:
         assert not list(cfg["loss"].parameters()), f"loss should have no parameters: {loss_layers}"
         # Create optimizer.
         op = self.catalog["Optimizer"]
-        p = op.convert_params(self.nodes[self.optimizer].data.params)
-        o = getattr(torch.optim, p["type"].name)
-        cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
+        cfg["optimizer_parameters"] = op.convert_params(self.nodes[self.optimizer].data.params)
         return ModelConfig(**cfg)
 
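From this hunk, optimizer_parameters appears to be a plain dict produced by the Optimizer op's convert_params, holding at least a "type" whose .name is a torch.optim class name and an "lr" float; _make_optimizer consumes exactly those two keys. A hedged sketch of that shape (the OptimizerType enum is an assumption for illustration, not the repo's actual type):

import enum

import torch

class OptimizerType(enum.Enum):  # assumed stand-in for the op's "type" parameter
    SGD = "SGD"
    AdamW = "AdamW"

# Assumed shape of optimizer_parameters, inferred from how _make_optimizer reads it:
p = {"type": OptimizerType.AdamW, "lr": 0.001}
o = getattr(torch.optim, p["type"].name)  # -> torch.optim.AdamW
optimizer = o(torch.nn.Linear(2, 1).parameters(), lr=p["lr"])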