Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmcv.cnn.bricks import ConvModule | |
| from mmocr.utils import revert_sync_batchnorm | |
| def test_revert_sync_batchnorm(): | |
| conv_syncbn = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')).to('cpu') | |
| conv_syncbn.train() | |
| x = torch.randn(1, 3, 10, 10) | |
| # Will raise an ValueError saying SyncBN does not run on CPU | |
| with pytest.raises(ValueError): | |
| y = conv_syncbn(x) | |
| conv_bn = revert_sync_batchnorm(conv_syncbn) | |
| y = conv_bn(x) | |
| assert y.shape == (1, 8, 9, 9) | |
| assert conv_bn.training == conv_syncbn.training | |
| conv_syncbn.eval() | |
| conv_bn = revert_sync_batchnorm(conv_syncbn) | |
| assert conv_bn.training == conv_syncbn.training | |