update
Browse files- train/main_branches.py +6 -8
train/main_branches.py
CHANGED
|
@@ -14,6 +14,7 @@ from torchcfm.optimal_transport import OTPlanSampler
|
|
| 14 |
from branchsbm.branchsbm import BranchSBM
|
| 15 |
from branchsbm.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar
|
| 16 |
from branchsbm.branch_interpolant_train import BranchInterpolantTrain
|
|
|
|
| 17 |
|
| 18 |
from dataloaders.trajectory_data import TemporalDataModule
|
| 19 |
from dataloaders.mouse_data import WeightedBranchedCellDataModule
|
|
@@ -24,15 +25,13 @@ from dataloaders.trametinib_single import TrametinibSingleBranchDataModule
|
|
| 24 |
from dataloaders.lidar_data import WeightedBranchedLidarDataModule
|
| 25 |
from dataloaders.lidar_data_single import LidarSingleDataModule
|
| 26 |
|
| 27 |
-
from networks.
|
| 28 |
-
from networks.
|
| 29 |
-
from networks.
|
| 30 |
-
from networks.unet_base import UNetModelWrapper as UNetModel
|
| 31 |
-
from networks.geopath_networks.unet import GeoPathUNet
|
| 32 |
from utils import set_seed
|
| 33 |
|
| 34 |
from train.parsers import parse_args
|
| 35 |
-
from
|
| 36 |
from train.train_utils import (
|
| 37 |
load_config,
|
| 38 |
merge_config,
|
|
@@ -40,9 +39,8 @@ from train.train_utils import (
|
|
| 40 |
dataset_name2datapath,
|
| 41 |
create_callbacks,
|
| 42 |
)
|
| 43 |
-
from
|
| 44 |
import torch.nn as nn
|
| 45 |
-
from flow_matchers.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar
|
| 46 |
|
| 47 |
def main(args: argparse.Namespace, seed: int, t_exclude: int) -> None:
|
| 48 |
set_seed(seed)
|
|
|
|
| 14 |
from branchsbm.branchsbm import BranchSBM
|
| 15 |
from branchsbm.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar
|
| 16 |
from branchsbm.branch_interpolant_train import BranchInterpolantTrain
|
| 17 |
+
from branchsbm.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar
|
| 18 |
|
| 19 |
from dataloaders.trajectory_data import TemporalDataModule
|
| 20 |
from dataloaders.mouse_data import WeightedBranchedCellDataModule
|
|
|
|
| 25 |
from dataloaders.lidar_data import WeightedBranchedLidarDataModule
|
| 26 |
from dataloaders.lidar_data_single import LidarSingleDataModule
|
| 27 |
|
| 28 |
+
from networks.flow_mlp import VelocityNet
|
| 29 |
+
from networks.growth_mlp import GrowthNet
|
| 30 |
+
from networks.interpolant_mlp import GeoPathMLP
|
|
|
|
|
|
|
| 31 |
from utils import set_seed
|
| 32 |
|
| 33 |
from train.parsers import parse_args
|
| 34 |
+
from branchsbm.ema import EMA
|
| 35 |
from train.train_utils import (
|
| 36 |
load_config,
|
| 37 |
merge_config,
|
|
|
|
| 39 |
dataset_name2datapath,
|
| 40 |
create_callbacks,
|
| 41 |
)
|
| 42 |
+
from state_costs.metric_factory import DataManifoldMetric
|
| 43 |
import torch.nn as nn
|
|
|
|
| 44 |
|
| 45 |
def main(args: argparse.Namespace, seed: int, t_exclude: int) -> None:
|
| 46 |
set_seed(seed)
|