Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import numpy as np | |
| import pytest | |
| import torch | |
| from numpy.testing import assert_array_equal | |
| from mmocr.apis.utils import tensor2grayimgs | |
def test_tensor2grayimgs():
    """Exercise ``tensor2grayimgs`` on rejected and accepted inputs.

    The first five cases each expect an ``AssertionError``; the last one
    compares the returned images against the input cast to ``uint8``.
    """
    # A numpy array passed in place of a torch tensor must be rejected.
    ndarray_input = np.random.rand(2, 3, 3)
    with pytest.raises(AssertionError):
        tensor2grayimgs(ndarray_input)

    # A 3-dimensional torch tensor must be rejected.
    three_dim = torch.randn(2, 3, 3)
    with pytest.raises(AssertionError):
        tensor2grayimgs(three_dim)

    # A 4-dimensional tensor whose dim 1 has size 3 must be rejected.
    wide_dim1 = torch.randn(2, 3, 5, 5)
    with pytest.raises(AssertionError):
        tensor2grayimgs(wide_dim1)

    # A three-element ``mean`` must be rejected.
    batch = torch.randn(2, 1, 5, 5)
    with pytest.raises(AssertionError):
        tensor2grayimgs(batch, mean=(1, 1, 1))

    # A three-element ``std`` must be rejected.
    batch = torch.randn(2, 1, 5, 5)
    with pytest.raises(AssertionError):
        tensor2grayimgs(batch, std=(1, 1, 1))

    # With mean 0 and std 1 each output is expected to equal the matching
    # input item with its leading dim squeezed and cast to uint8.
    batch = torch.randn(2, 1, 5, 5)
    expected = []
    for item in batch:
        expected.append(item.squeeze(0).cpu().numpy().astype(np.uint8))
    results = tensor2grayimgs(batch, mean=(0, ), std=(1, ))
    for want, got in zip(expected, results):
        assert_array_equal(want, got)