sophtang commited on
Commit
a03ffb8
·
verified ·
1 Parent(s): 5a87d8d
Files changed (1) hide show
  1. train/main_branches.py +6 -8
train/main_branches.py CHANGED
@@ -14,6 +14,7 @@ from torchcfm.optimal_transport import OTPlanSampler
14
  from branchsbm.branchsbm import BranchSBM
15
  from branchsbm.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar
16
  from branchsbm.branch_interpolant_train import BranchInterpolantTrain
 
17
 
18
  from dataloaders.trajectory_data import TemporalDataModule
19
  from dataloaders.mouse_data import WeightedBranchedCellDataModule
@@ -24,15 +25,13 @@ from dataloaders.trametinib_single import TrametinibSingleBranchDataModule
24
  from dataloaders.lidar_data import WeightedBranchedLidarDataModule
25
  from dataloaders.lidar_data_single import LidarSingleDataModule
26
 
27
- from networks.flow_networks.mlp import VelocityNet
28
- from networks.growth_networks.mlp import GrowthNet
29
- from networks.geopath_networks.mlp import GeoPathMLP
30
- from networks.unet_base import UNetModelWrapper as UNetModel
31
- from networks.geopath_networks.unet import GeoPathUNet
32
  from utils import set_seed
33
 
34
  from train.parsers import parse_args
35
- from flow_matchers.ema import EMA
36
  from train.train_utils import (
37
  load_config,
38
  merge_config,
@@ -40,9 +39,8 @@ from train.train_utils import (
40
  dataset_name2datapath,
41
  create_callbacks,
42
  )
43
- from geo_metrics.metric_factory import DataManifoldMetric
44
  import torch.nn as nn
45
- from flow_matchers.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar
46
 
47
  def main(args: argparse.Namespace, seed: int, t_exclude: int) -> None:
48
  set_seed(seed)
 
14
  from branchsbm.branchsbm import BranchSBM
15
  from branchsbm.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar
16
  from branchsbm.branch_interpolant_train import BranchInterpolantTrain
17
+ from branchsbm.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar
18
 
19
  from dataloaders.trajectory_data import TemporalDataModule
20
  from dataloaders.mouse_data import WeightedBranchedCellDataModule
 
25
  from dataloaders.lidar_data import WeightedBranchedLidarDataModule
26
  from dataloaders.lidar_data_single import LidarSingleDataModule
27
 
28
+ from networks.flow_mlp import VelocityNet
29
+ from networks.growth_mlp import GrowthNet
30
+ from networks.interpolant_mlp import GeoPathMLP
 
 
31
  from utils import set_seed
32
 
33
  from train.parsers import parse_args
34
+ from branchsbm.ema import EMA
35
  from train.train_utils import (
36
  load_config,
37
  merge_config,
 
39
  dataset_name2datapath,
40
  create_callbacks,
41
  )
42
+ from state_costs.metric_factory import DataManifoldMetric
43
  import torch.nn as nn
 
44
 
45
  def main(args: argparse.Namespace, seed: int, t_exclude: int) -> None:
46
  set_seed(seed)