Spaces:
No application file
No application file
File size: 1,294 Bytes
430de99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import unittest
from typing import Sequence
import torch
from detectron2.structures import ImageList
class TestImageList(unittest.TestCase):
def test_imagelist_padding_shape(self):
class TensorToImageList(torch.nn.Module):
def forward(self, tensors: Sequence[torch.Tensor]):
return ImageList.from_tensors(tensors, 4).tensor
func = torch.jit.trace(
TensorToImageList(), ([torch.ones((3, 10, 10), dtype=torch.float32)],)
)
ret = func([torch.ones((3, 15, 20), dtype=torch.float32)])
self.assertEqual(list(ret.shape), [1, 3, 16, 20], str(ret.shape))
func = torch.jit.trace(
TensorToImageList(),
(
[
torch.ones((3, 16, 10), dtype=torch.float32),
torch.ones((3, 13, 11), dtype=torch.float32),
],
),
)
ret = func(
[
torch.ones((3, 25, 20), dtype=torch.float32),
torch.ones((3, 10, 10), dtype=torch.float32),
]
)
# does not support calling with different #images
self.assertEqual(list(ret.shape), [2, 3, 28, 20], str(ret.shape))
|