File size: 523 Bytes
e8861c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
import torchvision
def format_target(targets):
    """Flatten per-image detection targets into one row-per-box tensor.

    Args:
        targets (List[Dict]): one dict per image. Each dict provides
            ``'boxes'`` in ``xyxy`` format (a 2-D tensor whose last dim is 4,
            as required by ``torchvision.ops.box_convert``) and ``'labels'``
            with exactly one entry per box.

    Returns:
        Tensor: shape ``(total_num_boxes, 6)``. Each row is
        ``[im_id, label, cx, cy, w, h]``, where ``im_id`` is the position of
        the source dict in ``targets``. All columns share the dtype of the
        converted boxes. An empty ``targets`` list yields an empty
        ``(0, 6)`` tensor of the default float dtype.
    """
    if not targets:
        # torch.cat on an empty list raises, so handle "no images" explicitly.
        return torch.zeros((0, 6))

    rows = []
    for im_id, tgt in enumerate(targets):
        boxes = torchvision.ops.box_convert(tgt['boxes'], in_fmt='xyxy', out_fmt='cxcywh')
        # Cast the integer columns to the box dtype up front instead of relying
        # on implicit type promotion inside torch.cat (older torch releases may
        # reject mixed dtypes there).
        labels = tgt['labels'].reshape(-1, 1).to(boxes.dtype)
        im_ids = torch.full_like(labels, im_id)
        rows.append(torch.cat([im_ids, labels, boxes], dim=1))
    return torch.cat(rows, dim=0)
|