orhir committed on
Commit
058cc76
·
verified ·
1 Parent(s): 5305d1f

Update EdgeCape/models/detectors/EdgeCape.py

Browse files
EdgeCape/models/detectors/EdgeCape.py CHANGED
@@ -12,6 +12,7 @@ from mmpose.models.builder import POSENETS
12
  from mmpose.models.detectors.base import BasePose
13
  from EdgeCape.models.backbones.adapter import DPT
14
  from EdgeCape.models.backbones.dino import DINO
 
15
 
16
 
17
  @POSENETS.register_module()
@@ -62,14 +63,7 @@ class EdgeCape(BasePose):
62
  self.keypoint_head_module.init_weights()
63
 
64
  def forward(self,
65
- img_s,
66
- img_q,
67
- target_s=None,
68
- target_weight_s=None,
69
- target_q=None,
70
- target_weight_q=None,
71
- img_metas=None,
72
- return_loss=True,
73
  **kwargs):
74
  """Calls either forward_train or forward_test depending on whether
75
  return_loss=True. Note this setting will change the expected inputs.
@@ -78,6 +72,30 @@ class EdgeCape(BasePose):
78
  should be double nested (i.e. List[Tensor], List[List[dict]]), with
79
  the outer list indicating test time augmentations.
80
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  if return_loss:
82
  return self.forward_train(img_s, target_s, target_weight_s, img_q,
83
  target_q, target_weight_q, img_metas,
 
12
  from mmpose.models.detectors.base import BasePose
13
  from EdgeCape.models.backbones.adapter import DPT
14
  from EdgeCape.models.backbones.dino import DINO
15
+ import json
16
 
17
 
18
  @POSENETS.register_module()
 
63
  self.keypoint_head_module.init_weights()
64
 
65
  def forward(self,
66
+ input_str,
 
 
 
 
 
 
 
67
  **kwargs):
68
  """Calls either forward_train or forward_test depending on whether
69
  return_loss=True. Note this setting will change the expected inputs.
 
72
  should be double nested (i.e. List[Tensor], List[List[dict]]), with
73
  the outer list indicating test time augmentations.
74
  """
75
+
76
+ str_dict = json.loads(input_str)
77
+ str_dict["img_s"] = [torch.tensor(str_dict["img_s"], dtype=torch.float32).cuda()]
78
+ str_dict["img_q"] = torch.tensor(str_dict["img_q"], dtype=torch.float32).cuda()
79
+ str_dict["target_weight_s"] = [torch.tensor(str_dict["target_weight_s"], dtype=torch.float32).cuda()]
80
+ str_dict["target_s"] = [torch.tensor(str_dict["target_s"], dtype=torch.float32).cuda()]
81
+
82
+ str_dict['img_metas'][0]['sample_joints_3d'][0] = torch.tensor(str_dict['img_metas'][0]['sample_joints_3d'][0])
83
+ str_dict['img_metas'][0]['query_joints_3d'] = torch.tensor(str_dict['img_metas'][0]['query_joints_3d'])
84
+ str_dict['img_metas'][0]['sample_center'][0] = torch.tensor(str_dict['img_metas'][0]['sample_center'][0])
85
+ str_dict['img_metas'][0]['query_center'] = torch.tensor(str_dict['img_metas'][0]['query_center'])
86
+ str_dict['img_metas'][0]['sample_scale'][0] = torch.tensor(str_dict['img_metas'][0]['sample_scale'][0])
87
+ str_dict['img_metas'][0]['query_scale'] = torch.tensor(str_dict['img_metas'][0]['query_scale'])
88
+
89
+ img_s = str_dict["img_s"]
90
+ img_q = str_dict["img_q"]
91
+ target_s = str_dict["target_s"]
92
+ target_weight_s = str_dict["target_weight_s"]
93
+ target_q = str_dict["target_q"]
94
+ target_weight_q = str_dict["target_weight_q"]
95
+ return_loss = str_dict["return_loss"]
96
+ img_metas = str_dict["img_metas"]
97
+ kwargs = {}
98
+
99
  if return_loss:
100
  return self.forward_train(img_s, target_s, target_weight_s, img_q,
101
  target_q, target_weight_q, img_metas,