Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmocr.models.textrecog.preprocessor import (BasePreprocessor, | |
| TPSPreprocessor) | |
def test_tps_preprocessor():
    """Check TPSPreprocessor argument validation and its output shape.

    Each malformed constructor argument must trigger an AssertionError.
    A valid single-channel configuration is then run forward on a
    (1, 1, 32, 100) batch and its output shape is checked.
    """
    # Constructor arguments that are each expected to be rejected:
    # a negative fiducial count, scalar sizes where a pair is used below,
    # and a string where an integer channel count is used below.
    bad_kwargs_list = [
        dict(num_fiducial=-1),
        dict(img_size=32),
        dict(rectified_img_size=100),
        dict(num_img_channel='bgr'),
    ]
    for bad_kwargs in bad_kwargs_list:
        with pytest.raises(AssertionError):
            TPSPreprocessor(**bad_kwargs)

    # Valid configuration: rectified size equals input size (32, 100).
    model = TPSPreprocessor(
        num_fiducial=20,
        img_size=(32, 100),
        rectified_img_size=(32, 100),
        num_img_channel=1)
    model.init_weights()
    model.train()

    inputs = torch.randn(1, 1, 32, 100)
    outputs = model(inputs)
    expected_shape = torch.Size([1, 1, 32, 100])
    assert outputs.shape == expected_shape
def test_base_preprocessor():
    """Check that BasePreprocessor returns a batch of the input's shape."""
    expected_shape = torch.Size([1, 1, 32, 100])

    base = BasePreprocessor()
    base.init_weights()
    base.train()

    # The input batch is created after setup, matching the original order.
    inputs = torch.randn(1, 1, 32, 100)
    outputs = base(inputs)
    assert outputs.shape == expected_shape