davidserra9 commited on
Commit
3a5ddf3
·
verified ·
1 Parent(s): c7626cc

Update models/interactive_model.py

Browse files
Files changed (1) hide show
  1. models/interactive_model.py +2 -2
models/interactive_model.py CHANGED
@@ -10,12 +10,12 @@ from torchvision.transforms import functional as TF
10
  import torch
11
 
12
  class NamedCurves(nn.Module):
13
- def __init__(self, configs: dict):
14
  super().__init__()
15
  self.model_configs = configs
16
 
17
  self.backbone = Backbone(**configs['backbone']['params'])
18
- self.color_naming = ColorNaming(num_categories=configs['color_naming']['num_categories'])
19
  self.bcpe = BCPE(**configs['bezier_control_points_estimator']['params'])
20
  self.local_fusion = LocalFusion(**configs['local_fusion']['params'])
21
 
 
10
  import torch
11
 
12
  class NamedCurves(nn.Module):
13
+ def __init__(self, configs: dict, device="cuda"):
14
  super().__init__()
15
  self.model_configs = configs
16
 
17
  self.backbone = Backbone(**configs['backbone']['params'])
18
+ self.color_naming = ColorNaming(num_categories=configs['color_naming']['num_categories'], device=device)
19
  self.bcpe = BCPE(**configs['bezier_control_points_estimator']['params'])
20
  self.local_fusion = LocalFusion(**configs['local_fusion']['params'])
21