Spaces:
Running
on
Zero
Running
on
Zero
Update models/interactive_model.py
Browse files
models/interactive_model.py
CHANGED
@@ -10,12 +10,12 @@ from torchvision.transforms import functional as TF
|
|
10 |
import torch
|
11 |
|
12 |
class NamedCurves(nn.Module):
|
13 |
-
def __init__(self, configs: dict):
|
14 |
super().__init__()
|
15 |
self.model_configs = configs
|
16 |
|
17 |
self.backbone = Backbone(**configs['backbone']['params'])
|
18 |
-
self.color_naming = ColorNaming(num_categories=configs['color_naming']['num_categories'])
|
19 |
self.bcpe = BCPE(**configs['bezier_control_points_estimator']['params'])
|
20 |
self.local_fusion = LocalFusion(**configs['local_fusion']['params'])
|
21 |
|
|
|
10 |
import torch
|
11 |
|
12 |
class NamedCurves(nn.Module):
|
13 |
+
def __init__(self, configs: dict, device="cuda"):
|
14 |
super().__init__()
|
15 |
self.model_configs = configs
|
16 |
|
17 |
self.backbone = Backbone(**configs['backbone']['params'])
|
18 |
+
self.color_naming = ColorNaming(num_categories=configs['color_naming']['num_categories'], device=device)
|
19 |
self.bcpe = BCPE(**configs['bezier_control_points_estimator']['params'])
|
20 |
self.local_fusion = LocalFusion(**configs['local_fusion']['params'])
|
21 |
|