File size: 1,804 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from collections import OrderedDict

import torch
from mmengine.fileio import load
from mmengine.runner import save_checkpoint


def convert(src: str, dst: str, prefix: str = 'd2_model') -> None:
    """Convert Detectron2 checkpoint to MMDetection style.

    Every entry of the Detectron2 ``model`` dict is converted to a
    ``torch.Tensor`` and stored in a new ``state_dict`` under the key
    ``f'{prefix}.{name}'``. The result is written to ``dst`` with
    ``mmengine.runner.save_checkpoint``.

    Args:
        src (str): The Detectron2 checkpoint path, should end with `pkl`.
        dst (str): The MMDetection checkpoint path.
        prefix (str): The prefix of MMDetection model, defaults to 'd2_model'.

    Raises:
        ValueError: If ``src`` does not end with ``pkl``, or the loaded
            checkpoint is not a dict with a ``'model'`` entry.
    """
    # Validate explicitly rather than with ``assert`` so the checks still run
    # when Python is started with ``-O``.
    if not src.endswith('pkl'):
        raise ValueError(
            'the source Detectron2 checkpoint should endswith `pkl`.')

    # Load the Detectron2 checkpoint. ``latin1`` is presumably needed to
    # decode pickles written under Python 2 (TODO confirm for newer files).
    # NOTE(review): this unpickles ``src``; only convert trusted checkpoints.
    checkpoint = load(src, encoding='latin1')
    d2_model = checkpoint.get('model') if isinstance(checkpoint, dict) else None
    if d2_model is None:
        raise ValueError(
            f'no "model" entry found in Detectron2 checkpoint {src}')

    # Convert to mmdet style: prefix every parameter name.
    dst_state_dict = OrderedDict()
    for name, value in d2_model.items():
        # ``as_tensor`` shares memory with ndarrays (like ``from_numpy``) and
        # returns tensors unchanged. It also accepts numpy scalars and plain
        # Python numbers, which ``from_numpy`` rejects.
        dst_state_dict[f'{prefix}.{name}'] = torch.as_tensor(value)

    mmdet_model = dict(state_dict=dst_state_dict, meta=dict())
    save_checkpoint(mmdet_model, dst)
    print(f'Convert Detectron2 model {src} to MMDetection model {dst}')


def main(argv=None):
    """Parse command-line arguments and run ``convert``.

    Args:
        argv (list[str], optional): Arguments to parse. Defaults to None,
            in which case argparse reads ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Convert Detectron2 checkpoint to MMDetection style')
    parser.add_argument('src', help='Detectron2 model path')
    parser.add_argument('dst', help='MMDetection model save path')
    parser.add_argument(
        '--prefix', default='d2_model', type=str, help='prefix of the model')
    args = parser.parse_args(argv)
    convert(args.src, args.dst, args.prefix)


if __name__ == '__main__':
    # Run the command-line converter only when this file is executed
    # directly, not when it is imported as a module.
    main()