orhir commited on
Commit
775a5db
·
verified ·
1 Parent(s): e602703

Update EdgeCape/models/detectors/EdgeCape.py

Browse files
EdgeCape/models/detectors/EdgeCape.py CHANGED
@@ -63,7 +63,14 @@ class EdgeCape(BasePose):
63
  self.keypoint_head_module.init_weights()
64
 
65
  def forward(self,
66
- input_str,
 
 
 
 
 
 
 
67
  **kwargs):
68
  """Calls either forward_train or forward_test depending on whether
69
  return_loss=True. Note this setting will change the expected inputs.
@@ -73,29 +80,6 @@ class EdgeCape(BasePose):
73
  the outer list indicating test time augmentations.
74
  """
75
 
76
- str_dict = json.loads(input_str)
77
- str_dict["img_s"] = [torch.tensor(str_dict["img_s"], dtype=torch.float32).cuda()]
78
- str_dict["img_q"] = torch.tensor(str_dict["img_q"], dtype=torch.float32).cuda()
79
- str_dict["target_weight_s"] = torch.tensor(str_dict["target_weight_s"], dtype=torch.float32).cuda()
80
- str_dict["target_s"] = torch.tensor(str_dict["target_s"], dtype=torch.float32).cuda()
81
-
82
- str_dict['img_metas'][0]['sample_joints_3d'][0] = torch.tensor(str_dict['img_metas'][0]['sample_joints_3d'][0])
83
- str_dict['img_metas'][0]['query_joints_3d'] = torch.tensor(str_dict['img_metas'][0]['query_joints_3d'])
84
- str_dict['img_metas'][0]['sample_center'][0] = torch.tensor(str_dict['img_metas'][0]['sample_center'][0])
85
- str_dict['img_metas'][0]['query_center'] = torch.tensor(str_dict['img_metas'][0]['query_center'])
86
- str_dict['img_metas'][0]['sample_scale'][0] = torch.tensor(str_dict['img_metas'][0]['sample_scale'][0])
87
- str_dict['img_metas'][0]['query_scale'] = torch.tensor(str_dict['img_metas'][0]['query_scale'])
88
-
89
- img_s = str_dict["img_s"]
90
- img_q = str_dict["img_q"]
91
- target_s = str_dict["target_s"]
92
- target_weight_s = str_dict["target_weight_s"]
93
- target_q = str_dict["target_q"]
94
- target_weight_q = str_dict["target_weight_q"]
95
- return_loss = str_dict["return_loss"]
96
- img_metas = str_dict["img_metas"]
97
- kwargs = {}
98
-
99
  if return_loss:
100
  return self.forward_train(img_s, target_s, target_weight_s, img_q,
101
  target_q, target_weight_q, img_metas,
 
63
  self.keypoint_head_module.init_weights()
64
 
65
  def forward(self,
66
+ img_s,
67
+ img_q,
68
+ target_s=None,
69
+ target_weight_s=None,
70
+ target_q=None,
71
+ target_weight_q=None,
72
+ img_metas=None,
73
+ return_loss=True,
74
  **kwargs):
75
  """Calls either forward_train or forward_test depending on whether
76
  return_loss=True. Note this setting will change the expected inputs.
 
80
  the outer list indicating test time augmentations.
81
  """
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  if return_loss:
84
  return self.forward_train(img_s, target_s, target_weight_s, img_q,
85
  target_q, target_weight_q, img_metas,